Commit 8056efa1 authored by novelailab's avatar novelailab

make the permute faster by doing it on batch

parent df7a70ab
......@@ -85,6 +85,7 @@ class ShardedImageDataset(data.Dataset):
#make sure these operations are fast!
ids = [t[1] for t in tensors]
tensors = torch.stack([t[0] for t in tensors])
tensors = tensors.permute(0, 3, 1, 2).float()
#####################################
if self.outer_transform:
tensors = self.outer_transform(tensors)
......@@ -95,7 +96,7 @@ class ShardedImageDataset(data.Dataset):
offset, size, id = self.metadata[key]
data = self.mmap[offset:offset+size]
data = decode_jpeg(data)
data = torch.from_numpy(data).permute(2, 0, 1)
data = torch.from_numpy(data)#.permute(2, 0, 1)
if self.inner_transform:
data = self.inner_transform(data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment