Commit 2b10353c authored by novelailab's avatar novelailab

ImageDatasetBuilder

parent 176b21d8
......@@ -5,6 +5,8 @@ import pickle
import concurrent
from torch.utils import data
from simplejpeg import decode_jpeg
import simplejpeg
import pickle
# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
......@@ -58,7 +60,7 @@ class ShardedImageDataset(data.Dataset):
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
with open(self.dataset_path, mode="r", encoding="utf8") as file_obj:
with open(self.dataset_path, mode="r") as file_obj:
self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)
#make so metadata is shardable by world_size(num_gpus)
......@@ -93,4 +95,88 @@ class ShardedImageDataset(data.Dataset):
if self.inner_transform:
data = self.inner_transform(data)
return data
\ No newline at end of file
return data
class ImageDatasetBuilder():
    """Builds a flat binary dataset of JPEG blobs plus a pickled index.

    Blobs are appended to a single data file; ``self.index`` collects one
    ``[offset, length, identity]`` record per blob and is pickled to
    ``self.index_name`` on close, for later consumption by a sharded
    mmap-based reader.
    """
    def __init__(self, folder, name, threads=None):
        self.folder = folder
        self.name = name
        # BUG FIX: dataset_path was referenced by build()/open() but never
        # defined, raising AttributeError. Data file lives inside `folder`;
        # NOTE(review): the index is still written relative to the CWD as in
        # the original — confirm whether it should also go under `folder`.
        self.dataset_path = f"{folder}/{name}"
        self.index_name = self.name + ".index"
        self.dataset = None   # open file handle for the data file, or None
        self.index = None     # list of [offset, length, identity], or None
        self.threads = threads  # max workers for operate(); None = default
    @property
    def is_open(self):
        # BUG FIX: the original evaluated this expression but had no
        # `return`, so the property always yielded None (falsy).
        return self.dataset is not None or self.index is not None
    @property
    def is_close(self):
        # BUG FIX: same missing `return` as is_open.
        return self.dataset is None or self.index is None
    def build(self):
        """Create a fresh dataset file and an empty in-memory index."""
        # be careful with not nuking the files if they exist
        if self.is_open:
            raise Exception("Dataset already built")
        # BUG FIX: the handle was bound to a local variable and discarded;
        # every later self.dataset access would have failed.
        self.dataset = open(self.dataset_path, mode="ab+")
        self.dataset.flush()
        self.index = []
    def open(self, overwrite=False):
        """Re-open an existing dataset file and load its pickled index.

        With overwrite=True an already-open dataset is flushed and closed
        first; otherwise an open dataset raises.
        """
        if overwrite is False and self.is_open:
            raise Exception("A dataset is already open! If you wish to continue set overwrite to True.")
        if overwrite is True and self.is_open:
            self.close()
            print("Dataset closed and flushed.")
        self.dataset = open(self.dataset_path, mode="ab+")
        with open(self.index_name, 'rb') as f:
            self.index = pickle.load(f)
    def operate(self, operation, data_batch, identities):
        """Apply `operation` to each item of `data_batch` in a thread pool
        and write each result under the matching identity, in order."""
        # `with` ensures the executor's worker threads are shut down
        # (the original leaked the executor).
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
            results = list(executor.map(operation, data_batch))
        for data, identity in zip(results, identities):
            self.write(data, identity)
    def encode_op(self, data):
        """Return `data` as JPEG bytes, encoding only if not already JPEG."""
        if not simplejpeg.is_jpeg(data):
            data = simplejpeg.encode_jpeg(data, quality=91)
        return data
    def write(self, data, identity, flush=False):
        """Append one blob and record it in the index.

        NOTE(review): the recorded offset is tell() AFTER the write, i.e.
        the blob's END position — kept as-is so existing readers that
        compute start = offset - length keep working; confirm against the
        reader before changing.
        """
        if self.is_close:
            raise Exception("Dataset not built")
        self.dataset.write(data)
        self.index.append([self.dataset.tell(), len(data), identity])
        if flush:
            self.dataset.flush()
    def flush_index(self):
        """Pickle the in-memory index to disk."""
        if self.is_close:
            raise Exception("Dataset not built")
        with open(self.index_name, 'wb') as f:
            pickle.dump(self.index, f)
    def flush(self):
        """Flush the data file's write buffer."""
        if self.is_close:
            raise Exception("Dataset not built")
        self.dataset.flush()
    def close(self):
        """Flush and close the data file, then persist the index."""
        if self.is_close:
            raise Exception("Dataset not built")
        # close the dataset filehandle and dump the pickle index
        self.flush()
        self.dataset.close()
        self.flush_index()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment