Commit e00841d8 authored by Eren Doğan's avatar Eren Doğan Committed by GitHub

10x faster

parent 63f92eba
...@@ -55,9 +55,9 @@ class ShardedDataset(data.Dataset): ...@@ -55,9 +55,9 @@ class ShardedDataset(data.Dataset):
return (data[:-1], data[1:]) return (data[:-1], data[1:])
class ShardedImageDataset(data.Dataset): class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, metadata_path: str, skip=0, bsz=256, world_size=1, rank=0): def __init__(self, dataset_path: str, metadata_path: str, threads=None, skip=0, bsz=256, world_size=1, rank=0):
self.skip = skip # not used for now self.skip = skip
self.threads = 16 # it seems 16 is the ideal thread count for this machine self.threads = threads
self.bsz = bsz self.bsz = bsz
self.dataset_path = dataset_path self.dataset_path = dataset_path
self.world_size = world_size self.world_size = world_size
...@@ -74,18 +74,19 @@ class ShardedImageDataset(data.Dataset): ...@@ -74,18 +74,19 @@ class ShardedImageDataset(data.Dataset):
#shard the dataset according to the rank #shard the dataset according to the rank
self.metadata = self.metadata[rank::world_size] self.metadata = self.metadata[rank::world_size]
self.samples = len(self.metadata) #override possible gil locks by making the metadata map an nparray
self.metadata = np.array(self.metadata)
#getting the threadpoolexecutor to __init__ instead of __getitem__
#made it 10x faster lol
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
def __len__(self): def __len__(self):
return self.samples // (self.bsz * self.world_size) return len(self.metadata) // (self.bsz * self.world_size)
def __getitem__(self, key): def __getitem__(self, key):
key = self.skip + key key = self.skip + key
keys = [*range(key, key+self.bsz)] keys = [*range(key, key+self.bsz)]
# We can use a with statement to ensure threads are cleaned up promptly tensors = self.executor.map(self.read_from_metadata_key, keys)
with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
tensors = list(executor.map(self.read_from_metadata_key, keys))
return tensors return tensors
def read_from_metadata_key(self, key): def read_from_metadata_key(self, key):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment