Commit e00841d8 authored by Eren Doğan, committed by GitHub

10x faster

parent 63f92eba
@@ -55,9 +55,9 @@ class ShardedDataset(data.Dataset):
         return (data[:-1], data[1:])
 
 class ShardedImageDataset(data.Dataset):
-    def __init__(self, dataset_path: str, metadata_path: str, skip=0, bsz=256, world_size=1, rank=0):
-        self.skip = skip # not used for now
-        self.threads = 16 # it seems 16 is the ideal thread count for this machine
+    def __init__(self, dataset_path: str, metadata_path: str, threads=None, skip=0, bsz=256, world_size=1, rank=0):
+        self.skip = skip
+        self.threads = threads
         self.bsz = bsz
         self.dataset_path = dataset_path
         self.world_size = world_size
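
The hard-coded thread count becomes a constructor argument; threads=None flows straight into ThreadPoolExecutor(max_workers=None), which on Python 3.8+ defaults to min(32, os.cpu_count() + 4) workers. A minimal construction sketch, with hypothetical paths (not taken from this repo):

    # Hypothetical usage sketch; both paths are placeholders.
    dataset = ShardedImageDataset(
        dataset_path="data/shards.dat",     # placeholder
        metadata_path="data/metadata.pkl",  # placeholder
        threads=16,                         # explicit count; None = executor default
        bsz=256,
    )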
@@ -74,18 +74,19 @@ class ShardedImageDataset(data.Dataset):
         #shard the dataset according to the rank
         self.metadata = self.metadata[rank::world_size]
         self.samples = len(self.metadata)
+        #override possible gil locks by making the metadata map an nparray
+        self.metadata = np.array(self.metadata)
+        #getting the threadpoolexecutor to __init__ instead of __getitem__
+        #made it 10x faster lol
+        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
 
     def __len__(self):
-        return self.samples // (self.bsz * self.world_size)
+        return len(self.metadata) // (self.bsz * self.world_size)
 
     def __getitem__(self, key):
         key = self.skip + key
         keys = [*range(key, key+self.bsz)]
-        # We can use a with statement to ensure threads are cleaned up promptly
-        with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
-            tensors = list(executor.map(self.read_from_metadata_key, keys))
+        tensors = self.executor.map(self.read_from_metadata_key, keys)
 
         return tensors
 
     def read_from_metadata_key(self, key):
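
The headline change: a single ThreadPoolExecutor is created once in __init__ and reused across batches, instead of being constructed and torn down inside every __getitem__ call. A self-contained sketch of that pattern under a stand-in workload (fake_read is hypothetical, not the repo's read_from_metadata_key, and the commit's 10x figure is the author's own measurement):

    import concurrent.futures
    import time

    def fake_read(key):
        # Stand-in for read_from_metadata_key; the real code reads from a shard file.
        return key

    def per_call(keys, n_batches, threads=16):
        # Old pattern: a fresh executor (and fresh threads) for every batch.
        for _ in range(n_batches):
            with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as ex:
                list(ex.map(fake_read, keys))

    def persistent(keys, n_batches, threads=16):
        # New pattern: one executor, reused for every batch.
        ex = concurrent.futures.ThreadPoolExecutor(max_workers=threads)
        for _ in range(n_batches):
            list(ex.map(fake_read, keys))
        ex.shutdown()

    keys = range(256)  # mirrors bsz=256
    for fn in (per_call, persistent):
        t0 = time.perf_counter()
        fn(keys, n_batches=100)
        print(fn.__name__, time.perf_counter() - t0)

Thread creation and teardown dominate when the per-item work is small, which is why hoisting the executor out of __getitem__ pays off so dramatically.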
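One behavioral change to note: with the list(...) wrapper gone, __getitem__ now returns the iterator from executor.map as-is. The reads are submitted to the pool immediately, but results are only gathered when the caller consumes the iterator, e.g.:

    # Hypothetical caller; materializes the lazily returned results.
    batch = list(dataset[0])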