Commit 5c846daf authored by novelailab

add metadata and better error handling

parent 1a66f013
@@ -103,19 +103,25 @@ class ShardedImageDataset(data.Dataset):
         return data, id
 
 class ImageDatasetBuilder():
-    def __init__(self, folder_path, name, threads=None):
+    def __init__(self, folder_path, name, dataset=True, index=True, metadata=True, threads=None):
         self.folder_path = Path(folder_path)
         self.dataset_name = name + ".ds"
         self.index_name = name + ".index"
+        self.metadata_name = name + ".metadata"
         self.dataset_path = self.folder_path / self.dataset_name
         self.index_path = self.folder_path / self.index_name
+        self.metadata_path = self.folder_path / self.metadata_name
+        self.open_dataset = dataset
+        self.open_index = index
+        self.open_metadata = metadata
         self.dataset = None
         self.index = None
+        self.metadata = None
         self.threads = threads
 
     @property
     def is_open(self):
-        self.dataset is not None or self.index is not None
+        # True if any of the three components is currently loaded
+        return self.dataset is not None or self.index is not None or self.metadata is not None
 
     @property
     def is_close(self):
@@ -126,15 +132,15 @@ class ImageDatasetBuilder():
         if self.index is None:
             return -1
         else:
-            return np.max(self.index[:, 2])
+            return np.max(self.np_index[:, 2])
 
     @property
     def biggest_item(self):
         if self.index is None:
             return -1
         else:
-            return np.max(self.index[:, 1])
+            return np.max(self.np_index[:, 1])
 
     @property
     def np_index(self):
         return np.array(self.index)
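biggest_id and biggest_item now go through np_index because self.index is a plain Python list of [offset, length, identity] rows, and column slicing like [:, 2] only works on a NumPy array. A small illustration with made-up rows:

    import numpy as np

    index = [[4096, 4096, 17], [8192, 4096, 42]]  # [byte offset, length, identity]
    # index[:, 2] on the plain list raises TypeError; convert first, as np_index does
    np_index = np.array(index)
    print(np.max(np_index[:, 2]))  # 42, the biggest identity
    print(np.max(np_index[:, 1]))  # 4096, the biggest item size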
@@ -145,26 +151,44 @@ class ImageDatasetBuilder():
             raise Exception("Dataset already built")
 
         self.folder_path.mkdir(parents=True, exist_ok=True)
-        dataset = open(self.dataset_path, mode="ab+")
-        dataset.flush()
-        self.index = []
+        if self.open_dataset:
+            # keep the handle on self so write()/flush() can use it
+            self.dataset = open(self.dataset_path, mode="ab+")
+            self.dataset.flush()
+        if self.open_index:
+            self.index = []
+        if self.open_metadata:
+            self.metadata = {}
 
     def open(self, overwrite=False):
         if overwrite is False and self.is_open:
             raise Exception("A dataset is already open! If you wish to continue set overwrite to True.")
-        if overwrite is True and self.is_open:
-            self.close()
+        if overwrite is True:
+            self.close(silent=True)
+            self.flush_index(silent=True)
+            self.flush_metadata(silent=True)
             print("Dataset closed and flushed.")
 
-        if not self.dataset_path.is_file() or not self.index_path.is_file():
-            raise Exception("Dataset files not found")
-        self.dataset = open(self.dataset_path, mode="ab+")
-        with open(self.index_name, 'rb') as f:
-            self.index = pickle.load(f)
+        # only check for (and load) the files that were actually requested,
+        # so disabling a component doesn't raise a spurious not-found error
+        if self.open_dataset:
+            if not self.dataset_path.is_file():
+                raise Exception("Dataset file not found at {}".format(self.dataset_path))
+            self.dataset = open(self.dataset_path, mode="ab+")
+        if self.open_index:
+            if not self.index_path.is_file():
+                raise Exception("Index file not found at {}".format(self.index_path))
+            with open(self.index_path, 'rb') as f:
+                self.index = pickle.load(f)
+        if self.open_metadata:
+            if not self.metadata_path.is_file():
+                raise Exception("Metadata file not found at {}".format(self.metadata_path))
+            with open(self.metadata_path, 'rb') as f:
+                self.metadata = pickle.load(f)
 
-    def operate(self, operation, data_batch, identities):
+    def operate(self, operation, data_batch, identities, metadata=None):
         executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
         futures = executor.map(operation, data_batch)
         futures = list(futures)
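operate() fans a callable out over data_batch on a thread pool; the rest of its body is elided by this hunk. A sketch of the kind of per-item callable it could be handed, with PIL as an assumed dependency and resize_and_reencode as a hypothetical example:

    import io
    from PIL import Image

    def resize_and_reencode(raw_bytes):
        # decode, normalize to RGB, resize, and re-encode one image
        img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
        img = img.resize((256, 256))
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        return buf.getvalue()

    # builder.operate(resize_and_reencode, raw_images, identities)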
@@ -180,32 +204,49 @@ class ImageDatasetBuilder():
         return data
 
-    def write(self, data, identity, flush=False):
+    def write(self, data, identity, metadata=None, flush=False):
         if self.is_close:
             raise Exception("Dataset not built")
 
         self.dataset.write(data)
         self.index.append([self.dataset.tell(), len(data), identity])
+        # compare against None: an empty (but open) metadata dict is falsy
+        # and would otherwise silently drop the first entries
+        if self.metadata is not None and metadata is not None:
+            self.metadata[identity] = metadata
         if flush:
-            self.dataset.flush()
+            self.flush()
 
+    def write_metadata(self, id, metadata):
+        self.metadata[id] = metadata
+
-    def flush_index(self):
-        if self.is_close:
-            raise Exception("Dataset not built")
-        with open(self.index_name, 'wb') as f:
+    def flush_index(self, silent=False):
+        if not self.index:
+            # silent only suppresses the warning; either way there is nothing to dump
+            if not silent:
+                print("Warning: Index not built, couldn't flush")
+            return
+        with open(self.index_path, 'wb') as f:
             pickle.dump(self.index, f)
 
+    def flush_metadata(self, silent=False):
+        if not self.metadata:
+            if not silent:
+                print("Warning: Metadata not built, couldn't flush")
+            return
+        with open(self.metadata_path, 'wb') as f:
+            pickle.dump(self.metadata, f)
+
-    def flush(self):
-        if self.is_close:
-            raise Exception("Dataset not built")
+    def flush(self, silent=False):
+        if not self.dataset:
+            if not silent:
+                print("Warning: Dataset not built, couldn't flush")
+            return
         self.dataset.flush()
 
-    def close(self):
-        if self.is_close:
-            raise Exception("Dataset not built")
+    def close(self, silent=False):
+        if not self.dataset:
+            if not silent:
+                print("Warning: Dataset not built, couldn't close")
+            return
 
-        #close the dataset filehandle and dump the pickle index
+        # flush and close the dataset filehandle; the index and metadata
+        # pickles are now flushed separately via flush_index()/flush_metadata()
         self.flush()
-        self.dataset.close()
-        self.flush_index()
\ No newline at end of file
+        self.dataset.close()
\ No newline at end of file
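Taken together, a round trip through the new metadata path might look like the sketch below. Paths, identities, and the tag dicts are all placeholders, and it assumes build() keeps its file handle on self.dataset:

    jpeg_bytes = b"\xff\xd8\xff..."  # placeholder image bytes
    builder = ImageDatasetBuilder("/tmp/shards", "example")
    builder.build()
    builder.write(jpeg_bytes, identity=1, metadata={"tags": ["landscape"]})
    builder.write_metadata(1, {"tags": ["landscape", "night"]})  # amend an entry later
    builder.flush()           # flush the shard file itself
    builder.flush_index()     # pickle the [offset, length, identity] rows
    builder.flush_metadata()  # pickle the identity -> metadata dict
    builder.close()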