Commit ebb032d8 authored by kurumuz

sharding and epoch support with global shuffle

parent 1f2c39ba
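The change below makes shard() deterministic and epoch-aware: it reseeds numpy, optionally repeats the index `epoch` times before a single global permutation (so entries from different epochs are interleaved rather than shuffled per epoch), trims the index to a multiple of bsz * world_size, and gives each rank a strided slice. A minimal standalone sketch of that scheme, reusing the attribute names from the class (bsz, world_size, global_rank); the function name and example sizes are illustrative, not from the commit:

import numpy as np

def shard_index(index, bsz, world_size, global_rank, shuffle=True, epoch=1, seed=69):
    #save and later restore the global numpy RNG state, as the commit does
    state = np.random.get_state()
    np.random.seed(seed)
    index = np.asarray(index)
    if shuffle:
        #repeat the index `epoch` times, then permute once globally,
        #so entries from different epochs are interleaved
        index = np.repeat(index, epoch, axis=0)
        index = np.random.permutation(index)
    #trim so the length divides evenly into per-rank, per-batch chunks
    index = index[:len(index) - (len(index) % (bsz * world_size))]
    #each rank takes a strided slice: global_rank, global_rank + world_size, ...
    index = index[global_rank::world_size]
    np.random.set_state(state)
    return index

#10 samples, batch size 2, 2 GPUs, 2 epochs -> each rank gets 10 entries
print(shard_index(np.arange(10), bsz=2, world_size=2, global_rank=0, epoch=2))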
@@ -94,24 +94,33 @@ class ShardedImageDataset(data.Dataset):
         #make so metadata is shardable by world_size(num_gpus)
         #and batch_size
         self.original_index = self.index
-        self.index = self.shard(shuffle=shuffle)
+        self.shard(shuffle=shuffle)
         #override possible gil locks by making the index map an nparray
         self.index = np.array(self.index)
         self.ids = self.index.transpose(1, 0)[2]
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
 
     def __len__(self):
         return len(self.index) // self.bsz
 
-    def shard(self, shuffle=False):
+    def shard(self, shuffle=False, epoch=1, seed=69):
+        #get numpy random state
+        state = np.random.get_state()
+        #set numpy seed
+        np.random.seed(seed)
+        #use this function to shuffle every new epoch as well.
+        self.index = self.original_index
         if shuffle:
+            #repeat index n_epoch times
+            self.index = np.repeat(self.index, epoch, axis=0)
             #shuffle the index
             self.index = np.random.permutation(self.index)
         self.index = self.index[:len(self.index) - (len(self.index) % (self.bsz * self.world_size))]
         self.index = self.index[self.global_rank::self.world_size]
-        return self.index
+        #reset numpy random state
+        np.random.set_state(state)
 
     def __getitem__(self, key):
         key = self.skip + key
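Since shard() now mutates self.index in place instead of returning it, a caller reshuffles by invoking it again; with the repeat-then-permute layout, passing epoch=n schedules n passes over the data in one shuffled index. A hedged usage sketch; the constructor arguments, n_epochs, and the loop are assumptions, not part of the commit:

#hypothetical usage; ShardedImageDataset is from this repo, the loop is assumed
dataset = ShardedImageDataset(...)  #constructor arguments elided

#lay out n_epochs shuffled passes at once; every rank must pass the same
#seed so the global permutation agrees across GPUs before the rank slice
dataset.shard(shuffle=True, epoch=n_epochs, seed=69)

for step in range(len(dataset)):
    batch = dataset[step]  #one per-rank batch per step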