Commit cb67abcb authored by novelailab's avatar novelailab

Make it so tensors can be transferred to CUDA after being reduced

parent 5c846daf
......@@ -46,7 +46,7 @@ class ShardedDataset(data.Dataset):
class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, metadata_path: str, threads=None, inner_transform=None,
outer_transform=None, skip=0, bsz=256, world_size=1, rank=0):
outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
self.skip = skip
self.threads = threads
......@@ -57,7 +57,9 @@ class ShardedImageDataset(data.Dataset):
self.outer_transform = outer_transform
self.dataset_path = dataset_path
self.world_size = world_size
self.rank = rank
self.local_rank = local_rank
self.global_rank = global_rank
self.device = device
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
......@@ -69,7 +71,7 @@ class ShardedImageDataset(data.Dataset):
self.metadata = self.metadata[:len(self.metadata) - (len(self.metadata) % (bsz * world_size))]
#shard the dataset according to the rank
self.metadata = self.metadata[rank::world_size]
self.metadata = self.metadata[global_rank::world_size]
#override possible gil locks by making the metadata map an nparray
self.metadata = np.array(self.metadata)
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
......@@ -85,6 +87,9 @@ class ShardedImageDataset(data.Dataset):
#make sure these operations are fast!
ids = [t[1] for t in tensors]
tensors = torch.stack([t[0] for t in tensors])
if self.device == "cuda":
tensors = tensors.to(self.local_rank)
tensors = tensors.permute(0, 3, 1, 2).float()
#####################################
if self.outer_transform:
......@@ -246,7 +251,7 @@ class ImageDatasetBuilder():
if not self.dataset and not silent:
print("Warning: Dataset not built, couldn't flush")
return
#close the dataset filehandle and dump the pickle index
self.flush()
self.dataset.close()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment