Commit 210cb3f9 authored by kurumuz's avatar kurumuz

dataset changes

parent ebb032d8
......@@ -51,7 +51,7 @@ class ShardedDataset(data.Dataset):
return (data[:-1], data[1:])
class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, name:str, shuffle=False, metadata_path=None, threads=None, inner_transform=None,
def __init__(self, dataset_path: str, name:str, index_path:str=None, shuffle=False, metadata_path=None, threads=None, inner_transform=None,
outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
self.skip = skip
self.threads = threads
......@@ -61,7 +61,11 @@ class ShardedImageDataset(data.Dataset):
#for batched transforms after images become batchable
self.outer_transform = outer_transform
self.dataset_path = Path(dataset_path)
self.index_path = self.dataset_path / f"{name}.index"
if index_path is None:
self.index_path = self.dataset_path / f"{name}.index"
else:
self.index_pth = Path(index_path)
self.pointer_path = self.dataset_path / f"{name}.pointer"
self.dataset_path = self.dataset_path / f"{name}.ds"
self.world_size = world_size
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment