Commit 176b21d8 authored by Eren Doğan, committed by GitHub

add transform support and stacking

parent 7f9ceee2
@@ -42,10 +42,16 @@ class ShardedDataset(data.Dataset):
         return (data[:-1], data[1:])
 
 class ShardedImageDataset(data.Dataset):
-    def __init__(self, dataset_path: str, metadata_path: str, threads=None, skip=0, bsz=256, world_size=1, rank=0):
+    def __init__(self, dataset_path: str, metadata_path: str, threads=None, inner_transform=None,
+                 outer_transform=None, skip=0, bsz=256, world_size=1, rank=0):
         self.skip = skip
         self.threads = threads
         self.bsz = bsz
+        #for one by one transforms because images can't be batched
+        self.inner_transform = inner_transform
+        #for batched transforms after images become batchable
+        self.outer_transform = outer_transform
         self.dataset_path = dataset_path
         self.world_size = world_size
         self.rank = rank
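Below is a minimal usage sketch, not part of the commit: the paths, the Resize target size, and the normalizing lambda are hypothetical, and it assumes torchvision's tensor-mode transforms. inner_transform runs once per image right after decoding, so it should map every image to one common shape; outer_transform then sees the whole stacked batch at once.

from torchvision import transforms

inner = transforms.Resize((256, 256))        # per image: equalize shapes before stacking
outer = lambda batch: batch.float() / 255.0  # per batch: uint8 -> float in [0, 1]

dataset = ShardedImageDataset(
    dataset_path="data/shards.bin",          # hypothetical paths
    metadata_path="data/metadata.pkl",
    threads=8,
    inner_transform=inner,
    outer_transform=outer,
    bsz=256,
)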
@@ -63,8 +69,6 @@ class ShardedImageDataset(data.Dataset):
         self.metadata = self.metadata[rank::world_size]
         #override possible gil locks by making the metadata map an nparray
         self.metadata = np.array(self.metadata)
-        #getting the threadpoolexecutor to __init__ instead of __getitem__
-        #made it 10x faster lol
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
 
     def __len__(self):
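The rank::world_size slice above is the standard striding idiom for data-parallel sharding; a tiny self-contained illustration with hypothetical values:

# Each worker keeps every world_size-th metadata entry, starting at its own
# rank, so the workers partition the dataset with no overlap and no gaps.
metadata = list(range(10))
world_size, rank = 4, 1
assert metadata[rank::world_size] == [1, 5, 9]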
@@ -75,6 +79,10 @@ class ShardedImageDataset(data.Dataset):
         keys = [*range(key, key+self.bsz)]
         tensors = self.executor.map(self.read_from_metadata_key, keys)
         tensors = list(tensors)
+        tensors = torch.stack(tensors)
+        if self.outer_transform:
+            tensors = self.outer_transform(tensors)
         return tensors
 
     def read_from_metadata_key(self, key):
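The batched read depends on ThreadPoolExecutor.map returning results in the order of its inputs, which is why the stacked batch stays aligned with the requested keys even though the reads run concurrently. A standalone sketch of that pattern, where fake_read is a hypothetical stand-in for read_from_metadata_key:

import concurrent.futures
import torch

executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)

def fake_read(key):
    # Every item must come back with the same shape, or torch.stack will fail.
    return torch.full((3, 2, 2), key)

keys = list(range(8))
tensors = list(executor.map(fake_read, keys))  # results arrive in key order
batch = torch.stack(tensors)                   # shape: (8, 3, 2, 2)
assert batch[5, 0, 0, 0].item() == 5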
@@ -82,4 +90,7 @@ class ShardedImageDataset(data.Dataset):
         data = self.mmap[offset:offset+size]
         data = decode_jpeg(data)
         data = torch.from_numpy(data).permute(2, 0, 1)
+        if self.inner_transform:
+            data = self.inner_transform(data)
         return data
\ No newline at end of file
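On the per-image path, the torch.from_numpy(...).permute(2, 0, 1) call implies that decode_jpeg returns an HxWxC uint8 NumPy array (the diff doesn't show the import, so this is an assumption); the permute reorders it into the CxHxW layout that PyTorch transforms and models expect:

import numpy as np
import torch

# Stand-in for decode_jpeg output: an HxWxC uint8 array (assumption, see above).
hwc = np.zeros((480, 640, 3), dtype=np.uint8)
chw = torch.from_numpy(hwc).permute(2, 0, 1)
assert chw.shape == (3, 480, 640)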