Commit 8056efa1 authored by novelailab's avatar novelailab

make the permute faster by doing it on batch

parent df7a70ab
...@@ -85,6 +85,7 @@ class ShardedImageDataset(data.Dataset): ...@@ -85,6 +85,7 @@ class ShardedImageDataset(data.Dataset):
#make sure these operations are fast! #make sure these operations are fast!
ids = [t[1] for t in tensors] ids = [t[1] for t in tensors]
tensors = torch.stack([t[0] for t in tensors]) tensors = torch.stack([t[0] for t in tensors])
tensors = tensors.permute(0, 3, 1, 2).float()
##################################### #####################################
if self.outer_transform: if self.outer_transform:
tensors = self.outer_transform(tensors) tensors = self.outer_transform(tensors)
...@@ -95,7 +96,7 @@ class ShardedImageDataset(data.Dataset): ...@@ -95,7 +96,7 @@ class ShardedImageDataset(data.Dataset):
offset, size, id = self.metadata[key] offset, size, id = self.metadata[key]
data = self.mmap[offset:offset+size] data = self.mmap[offset:offset+size]
data = decode_jpeg(data) data = decode_jpeg(data)
data = torch.from_numpy(data).permute(2, 0, 1) data = torch.from_numpy(data)#.permute(2, 0, 1)
if self.inner_transform: if self.inner_transform:
data = self.inner_transform(data) data = self.inner_transform(data)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment