Commit 87e5a1f8 authored by kurumuz's avatar kurumuz

bug fix to the dataset

parent 210cb3f9
Pipeline #17100 failed with stages
......@@ -64,7 +64,7 @@ class ShardedImageDataset(data.Dataset):
if index_path is None:
self.index_path = self.dataset_path / f"{name}.index"
else:
self.index_pth = Path(index_path)
self.index_path = Path(index_path)
self.pointer_path = self.dataset_path / f"{name}.pointer"
self.dataset_path = self.dataset_path / f"{name}.ds"
......@@ -98,11 +98,10 @@ class ShardedImageDataset(data.Dataset):
#make so metadata is shardable by world_size(num_gpus)
#and batch_size
self.original_index = self.index
self.shard(shuffle=shuffle)
#self.shard(shuffle=shuffle)
#override possible gil locks by making the index map an nparray
self.index = np.array(self.index)
self.ids = self.index.transpose(1, 0)[2]
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
def __len__(self):
......@@ -114,12 +113,18 @@ class ShardedImageDataset(data.Dataset):
#set numpy seed
np.random.seed(seed)
#use this function to shuffle every new epoch as well.
self.index = self.original_index
shuffled_indexes = []
if shuffle:
#repeat index n_epoch times
self.index = np.repeat(self.index, epoch, axis=0)
#shuffle the index
self.index = np.random.permutation(self.index)
#shuffle on the epoch boundaries
for _ in range(epoch):
#shuffle the index
shuffled = np.random.permutation(self.index)
#append the index
shuffled_indexes.append(shuffled)
#del shuffled
#concatenate the indexes
self.index = np.concatenate(shuffled_indexes)
#del shuffled_indexes
self.index = self.index[:len(self.index) - (len(self.index) % (self.bsz * self.world_size))]
self.index = self.index[self.global_rank::self.world_size]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment