Commit a5d9beec authored by novelailab

remove mmaptest.py

parent f79775ee
import mmap
import concurrent.futures
import pickle
from timeit import default_timer as timer
from itertools import repeat
import numpy as np
import torch
import torchvision.transforms as transforms
from simplejpeg import decode_jpeg
from tqdm import tqdm
dataset_path = "/home/xuser/hugessd/danbooru/danbooru.fia"
metadata_path = "/home/xuser/diffusionstorage/danbooru_meta_fast.pkl"
d_id_ptr_path = "/home/xuser/diffusionstorage/danbooru_db.pkl"
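# Assumed layout (inferred from the reads below, not documented in this file): the
# pickled metadata maps an image key to a tuple whose first element is the byte
# offset of that image's JPEG inside the .fia file and whose second element is its
# length in bytes, i.e. metadata[key] ~ (offset, size).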
class FIAReader():
    def __init__(self, dataset_path: str, metadata_path: str, transform=None, local_transform=None,
                 skip=0, batch_size=8500, image_cnt=100000):
        self.skip = skip  # not used for now
        self.threads = 16  # it seems 16 is the ideal thread count for this machine
        self.image_cnt = image_cnt  # the number of images to read on each run of FIAReader[x]
        self.batch_size = batch_size
        self.transform = transform
        self.local_transform = local_transform
        self.dataset_path = dataset_path
        with open(metadata_path, 'rb') as f:
            self.metadata = pickle.load(f)

    def __len__(self):
        return len(self.metadata)
    def __getitem__(self, key):
        # Currently we just iterate over the dataset, decoding each JPEG into a tensor
        # and doing nothing with it; this code is only used for benchmarks.
        # See the tensors declaration below.
        start_time = timer()
        keys = [*range(key, key + self.image_cnt)]
        for i in tqdm(range(self.image_cnt // self.batch_size)):
            start_val = self.metadata[key + (i * self.batch_size)]
            end_val = self.metadata[key + ((i + 1) * self.batch_size)]
            start_ptr = start_val[0]
            end_ptr = end_val[0] + end_val[1]
            # Read the byte range covering all pictures in the current batch through mmap.
            with open(self.dataset_path, mode="rb") as file_obj:
                with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                    mmap_obj.seek(start_ptr)
                    curr_mmap = mmap_obj.read(end_ptr - start_ptr)
                    # We use a with statement to ensure threads are cleaned up promptly.
                    with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
                        # The tensors object is not saved anywhere due to memory constraints.
                        tensors = list(executor.map(self.read_from_metadata_key,
                                                    repeat(curr_mmap), repeat(start_ptr),
                                                    keys[i * self.batch_size:(i + 1) * self.batch_size]))
        end_time = timer()
        print('image reading time: ', end_time - start_time)
        # The code below the return statement has not been tested yet.
        return
        if self.local_transform:
            globo1_list = []
            globo2_list = []
            local_list = []
            for t in tensors:
                globo1, globo2, local = self.local_transform(t.cuda())
                globo1_list.append(globo1)
                globo2_list.append(globo2)
                local_list.append(local)
            globo1 = torch.stack(globo1_list).cuda()
            globo2 = torch.stack(globo2_list).cuda()
            local = torch.cat(local_list, dim=0).cuda()
        if self.transform:
            globo1, globo2, local = self.transform(globo1, globo2, local)
        imagelist = [globo1, globo2, *local.split(self.image_cnt)]
        return imagelist
    def read_from_metadata_key(self, dataset_mmap, start_ptr, key):
        # Slice this image's bytes out of the batch buffer using its (offset, length)
        # metadata entry, relative to the batch's starting offset.
        val = self.metadata[key]
        data = dataset_mmap[val[0] - start_ptr: val[0] + val[1] - start_ptr]
        #data = torch.frombuffer(data, dtype=torch.uint8)
        #data = torchvision_decode_jpeg(data, device="cpu")
        #data = np.frombuffer(data, dtype=np.uint8)
        # decode_jpeg returns an HWC uint8 ndarray; convert it to a CHW tensor.
        data = decode_jpeg(data)
        data = torch.from_numpy(data).permute(2, 0, 1)
        return data
if __name__ == "__main__":
    reader = FIAReader(dataset_path, metadata_path)
    # edge case tests
    reader[0]
    reader[len(reader) - 1 - 100000]
    print(len(reader))
    print("success!")