Commit ae114e18 authored by novelailab

get_metadata, from id

parent dff538a0
......@@ -45,7 +45,7 @@ class ShardedDataset(data.Dataset):
return (data[:-1], data[1:])
class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, index_path: str, threads=None, inner_transform=None,
def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
self.skip = skip
......@@ -63,6 +63,10 @@ class ShardedImageDataset(data.Dataset):
with open(index_path, 'rb') as f:
self.index = pickle.load(f)
if metadata_path:
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
with open(self.dataset_path, mode="r") as file_obj:
self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)
......@@ -115,6 +119,9 @@ class ShardedImageDataset(data.Dataset):
data = decode_jpeg(data)
return data
def get_metadata(self, id):
    """Return the metadata entry stored under ``id``.

    ``self.metadata`` is whatever object was unpickled from the
    ``metadata_path`` given to the constructor; it is indexed directly with
    ``id``, so the accepted key type depends on that object (presumably a
    dict or list keyed by image id -- confirm against the metadata builder).

    Args:
        id: Key or index of the entry to look up. The name shadows the
            builtin but is kept so existing keyword callers keep working.

    Returns:
        The object stored at ``self.metadata[id]``.

    Raises:
        AttributeError: If no metadata has been loaded on this instance.
        KeyError / IndexError: Propagated from the underlying container
            when ``id`` is not present.
    """
    try:
        metadata = self.metadata
    except AttributeError:
        # In the visible constructor code the attribute is only assigned
        # when a metadata_path was supplied. Keep the exception type, but
        # tell the caller how to fix it.
        raise AttributeError(
            "no metadata loaded for this dataset; "
            "pass metadata_path when constructing it"
        ) from None
    return metadata[id]
class ImageDatasetBuilder():
def __init__(self, folder_path, name, dataset=True, index=True, metadata=True, threads=None):
self.folder_path = Path(folder_path)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment