Commit a0b66a4b authored by novelailab's avatar novelailab

add more stuff

parent 2b10353c
...@@ -7,6 +7,7 @@ from torch.utils import data ...@@ -7,6 +7,7 @@ from torch.utils import data
from simplejpeg import decode_jpeg from simplejpeg import decode_jpeg
import simplejpeg import simplejpeg
import pickle import pickle
from pathlib import Path
# Does this work with other block_sizes? doesn't seem to. # Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset): class FbDataset(data.Dataset):
...@@ -98,10 +99,12 @@ class ShardedImageDataset(data.Dataset): ...@@ -98,10 +99,12 @@ class ShardedImageDataset(data.Dataset):
return data return data
class ImageDatasetBuilder(): class ImageDatasetBuilder():
def __init__(self, folder, name, threads=None): def __init__(self, folder_path, name, threads=None):
self.folder = folder self.folder_path = Path(folder_path)
self.name = name self.dataset_name = name + ".ds"
self.index_name = self.name + ".index" self.index_name = name + ".index"
self.dataset_path = self.folder_path / self.dataset_name
self.index_path = self.folder_path / self.index_name
self.dataset = None self.dataset = None
self.index = None self.index = None
self.threads = threads self.threads = threads
...@@ -119,6 +122,7 @@ class ImageDatasetBuilder(): ...@@ -119,6 +122,7 @@ class ImageDatasetBuilder():
if self.is_open: if self.is_open:
raise Exception("Dataset already built") raise Exception("Dataset already built")
self.folder_path.mkdir(parents=True, exist_ok=True)
dataset = open(self.dataset_path, mode="ab+") dataset = open(self.dataset_path, mode="ab+")
dataset.flush() dataset.flush()
self.index = [] self.index = []
...@@ -131,6 +135,9 @@ class ImageDatasetBuilder(): ...@@ -131,6 +135,9 @@ class ImageDatasetBuilder():
self.close() self.close()
print("Dataset closed and flushed.") print("Dataset closed and flushed.")
if not self.dataset_path.is_file() or not self.index_path.is_file():
raise Exception("Dataset files not found")
self.dataset = open(self.dataset_path, mode="ab+") self.dataset = open(self.dataset_path, mode="ab+")
with open(self.index_name, 'rb') as f: with open(self.index_name, 'rb') as f:
self.index = pickle.load(f) self.index = pickle.load(f)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment