Commit 0f8be39c authored by novelailab

change metadata -> index in ShardedImageDataset

parent cb67abcb
@@ -45,7 +45,7 @@ class ShardedDataset(data.Dataset):
         return (data[:-1], data[1:])

 class ShardedImageDataset(data.Dataset):
-    def __init__(self, dataset_path: str, metadata_path: str, threads=None, inner_transform=None,
+    def __init__(self, dataset_path: str, index_path: str, threads=None, inner_transform=None,
         outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
         self.skip = skip
@@ -60,36 +60,36 @@ class ShardedImageDataset(data.Dataset):
         self.local_rank = local_rank
         self.global_rank = global_rank
         self.device = device
-        with open(metadata_path, 'rb') as f:
-            self.metadata = pickle.load(f)
+        with open(index_path, 'rb') as f:
+            self.index = pickle.load(f)
         with open(self.dataset_path, mode="r") as file_obj:
             self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)
         #make so metadata is shardable by world_size(num_gpus)
         #and batch_size
-        self.metadata = self.metadata[:len(self.metadata) - (len(self.metadata) % (bsz * world_size))]
+        self.index = self.index[:len(self.index) - (len(self.index) % (bsz * world_size))]
         #shard the dataset according to the rank
-        self.metadata = self.metadata[global_rank::world_size]
-        #override possible gil locks by making the metadata map an nparray
-        self.metadata = np.array(self.metadata)
+        self.index = self.index[global_rank::world_size]
+        #override possible gil locks by making the index map an nparray
+        self.index = np.array(self.index)
         self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)

     def __len__(self):
-        return len(self.metadata) // self.bsz
+        return len(self.index) // self.bsz

     def __getitem__(self, key):
         key = self.skip + key
         keys = [*range(key, key+self.bsz)]
-        tensors = self.executor.map(self.read_from_metadata_key, keys)
+        tensors = self.executor.map(self.read_from_index_key, keys)
         tensors = list(tensors)
         #make sure these operations are fast!
         ids = [t[1] for t in tensors]
         tensors = torch.stack([t[0] for t in tensors])
         if self.device == "cuda":
             tensors = tensors.to(self.local_rank)
         tensors = tensors.permute(0, 3, 1, 2).float()
         #####################################
         if self.outer_transform:
@@ -97,8 +97,8 @@ class ShardedImageDataset(data.Dataset):
         return tensors, ids

-    def read_from_metadata_key(self, key):
-        offset, size, id = self.metadata[key]
+    def read_from_index_key(self, key):
+        offset, size, id = self.index[key]
         data = self.mmap[offset:offset+size]
         data = decode_jpeg(data)
         data = torch.from_numpy(data)#.permute(2, 0, 1)
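Nothing in this commit changes the on-disk layout, only the naming: judging from `read_from_index_key`, the pickled index is a sequence of `(offset, size, id)` triples pointing into one flat file of concatenated JPEG bytes that is mmap'd at startup. Below is a minimal sketch of a writer that would produce a matching pair of files; the function name `build_shard` and the file-name arguments are illustrative, not part of this repo.

```python
import pickle

def build_shard(jpeg_paths, dataset_path, index_path):
    # Concatenate raw JPEG bytes into one flat shard file and record
    # an (offset, size, id) triple per image -- the exact tuple that
    # read_from_index_key unpacks before slicing the mmap.
    index = []
    offset = 0
    with open(dataset_path, "wb") as shard:
        for img_id, path in enumerate(jpeg_paths):
            with open(path, "rb") as f:
                buf = f.read()
            shard.write(buf)
            index.append((offset, len(buf), img_id))
            offset += len(buf)
    with open(index_path, "wb") as f:
        pickle.dump(index, f)
```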
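The two slicing lines in `__init__` carry the (unchanged) sharding logic and are easy to misread, so here is the same arithmetic on toy numbers rather than the defaults:

```python
import numpy as np

bsz, world_size = 4, 2
index = list(range(27))  # 27 entries in the full index

# Drop the tail so the index length is a multiple of bsz * world_size:
# 27 - (27 % 8) = 24 entries survive.
index = index[:len(index) - (len(index) % (bsz * world_size))]

# Rank r then takes every world_size-th entry starting at r.
for global_rank in range(world_size):
    shard = np.array(index[global_rank::world_size])
    print(global_rank, len(shard), shard[:4])
    # rank 0 -> 12 entries: [0 2 4 6]
    # rank 1 -> 12 entries: [1 3 5 7]
    # each rank's __len__ is then 12 // 4 = 3 full batches
```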