Commit ec79f2ea authored by Eren Doğan's avatar Eren Doğan Committed by GitHub

Rewrite, around %70 faster than FIAReader thx to mmap

parent a5d9beec
......@@ -54,6 +54,47 @@ class ShardedDataset(data.Dataset):
data = torch.tensor(self.npz[nth].astype(np.int64))
return (data[:-1], data[1:])
class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, metadata_path: str, skip=0, bsz=256, world_size=1, rank=0):
self.skip = skip # not used for now
self.threads = 16 # it seems 16 is the ideal thread count for this machine
self.bsz = bsz
self.dataset_path = dataset_path
self.world_size = world_size
self.rank = rank
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
with open(self.dataset_path, mode="r", encoding="utf8") as file_obj:
self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)
#make so metadata is shardable by world_size(num_gpus)
#and batch_size
self.metadata = self.metadata[:len(self.metadata) - (len(self.metadata) % (bsz * world_size))]
#shard the dataset according to the rank
self.metadata = self.metadata[rank::world_size]
self.samples = len(self.metadata)
def __len__(self):
return self.samples // (self.bsz * self.world_size)
def __getitem__(self, key):
key = self.skip + key
keys = [*range(key, key+self.bsz)]
# We can use a with statement to ensure threads are cleaned up promptly
with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
tensors = list(executor.map(self.read_from_metadata_key, keys))
return tensors
def read_from_metadata_key(self, key):
offset, size, d_id = self.metadata[key]
data = self.mmap[offset:offset+size]
data = decode_jpeg(data)
data = torch.from_numpy(data).permute(2, 0, 1)
return data
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
......@@ -178,78 +219,4 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
print(f"{func.__name__}: {best[0]:.4f}{precision} ± {best[1]:.4f}{precision} per loop (mean ± std. dev. of {str(r)} runs, {str(n)} loops each)")
def gelu_new(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class FIAReader():
def __init__(self, dataset_path: str, metadata_path: str, transform=None, local_transform=None,
skip=0, batch_size=8500, image_cnt=100000):
self.skip = skip # not used for now
self.threads = 16 # it seems 16 is the ideal thread count for this machine
self.image_cnt = image_cnt # The image count to be read at each run of FIAReader[x]
self.batch_size = batch_size
self.transform = transform
self.local_transform = local_transform
self.dataset_path = dataset_path
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
def __len__(self):
return len(self.metadata)
def __getitem__(self, key):
# Currently, we're just iterating over the dataset, decoding each JPEGs into a tensor, and doing nothing with a tensor
# this code is currently only used for benchmarks. See the tensors object declaration below
start_time = timer()
keys = [*range(key, key+self.image_cnt)]
for i in tqdm(range(self.image_cnt // self.batch_size)):
start_val = self.metadata[key + (i * self.batch_size)]
end_val = self.metadata[key + ((i + 1) * self.batch_size)]
start_ptr = start_val[0]
end_ptr = end_val[0] + end_val[1]
# At this part, we're reading the file using mmap for all pictures at the current batch
with open(self.dataset_path, mode="r", encoding="utf8") as file_obj:
with mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
mmap_obj.seek(start_ptr)
curr_mmap = mmap_obj.read(end_ptr - start_ptr)
# We can use a with statement to ensure threads are cleaned up promptly
with concurrent.futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
# tensors object is not saved to anywhere due to memory constaints.
tensors = list(executor.map(self.read_from_metadata_key, repeat(curr_mmap), repeat(start_ptr), keys[i*self.batch_size:(i+1)*self.batch_size - 1]))
mmap_obj.close()
end_time = timer()
print('image reading time: ', end_time - start_time)
# The code below the return expression has not been tested yet
return
if self.local_transform:
globo1_list = []
globo2_list = []
local_list = []
for i, t in enumerate(tensors):
globo1, globo2, local = self.local_transform(t.cuda())
globo1_list.append(globo1)
globo2_list.append(globo2)
local_list.append(local)
globo1 = torch.stack(globo1_list).cuda()
globo2 = torch.stack(globo2_list).cuda()
local = torch.cat(local_list, dim=0).cuda()
if self.transform:
globo1, globo2, local = self.transform(globo1, globo2, local)
imagelist = []
imagelist.append(globo1)
imagelist.append(globo2)
imagelist = [*imagelist, *local.split(self.image_cnt)]
return imagelist
def read_from_metadata_key(self, dataset_mmap, start_ptr, key):
val = self.metadata[key]
data = dataset_mmap[val[0]-start_ptr: val[0]+val[1]-start_ptr]
#data = torch.frombuffer(data, dtype=torch.uint8)
#data = torchvision_decode_jpeg(data, device="cpu")
#data = np.frombuffer(data, dtype=np.uint8)
data = decode_jpeg(data)
data = torch.from_numpy(data).permute(2, 0, 1)
return data
\ No newline at end of file
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment