Commit 62baf4ad authored by novelailab

move dataset code to dataset.py

parent a1d18aaa
import numpy as np
import torch
import mmap
import pickle
import concurrent.futures
from torch.utils import data
from simplejpeg import decode_jpeg


# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
    def __init__(self, block_size, map_file, max_samples=None, skip=0):
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        self.samples = self.npz.shape[0]
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = skip

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        nth = _id + self.skip
        data = torch.tensor(self.npz[nth].astype(np.int64))
        # shift by one token: (input, target)
        return (data[:-1], data[1:])


class ShardedDataset(data.Dataset):
    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        # might want to pad later
        self.npz = self.npz[:self.npz.shape[0] - (self.npz.shape[0] % world_size)]
        # shard
        self.npz = self.npz[rank::world_size]
        self.samples = self.npz.shape[0]
        self.skip = skip

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        nth = _id + self.skip
        data = torch.tensor(self.npz[nth].astype(np.int64))
        return (data[:-1], data[1:])


class ShardedImageDataset(data.Dataset):
    def __init__(self, dataset_path: str, metadata_path: str, threads=None, skip=0, bsz=256, world_size=1, rank=0):
        self.skip = skip
        self.threads = threads
        self.bsz = bsz
        self.dataset_path = dataset_path
        self.world_size = world_size
        self.rank = rank
        with open(metadata_path, 'rb') as f:
            self.metadata = pickle.load(f)
        with open(self.dataset_path, mode="rb") as file_obj:
            self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)

        # make the metadata shardable by world_size (num_gpus) and batch_size
        self.metadata = self.metadata[:len(self.metadata) - (len(self.metadata) % (bsz * world_size))]
        # shard the dataset according to the rank
        self.metadata = self.metadata[rank::world_size]
        # avoid possible GIL locks by making the metadata map an nparray
        self.metadata = np.array(self.metadata)
        # moving the ThreadPoolExecutor from __getitem__ to __init__
        # made it 10x faster
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)

    def __len__(self):
        return len(self.metadata) // (self.bsz * self.world_size)

    def __getitem__(self, key):
        key = self.skip + key
        keys = [*range(key, key + self.bsz)]
        tensors = self.executor.map(self.read_from_metadata_key, keys)
        return tensors

    def read_from_metadata_key(self, key):
        offset, size, d_id = self.metadata[key]
        data = self.mmap[offset:offset + size]
        data = decode_jpeg(data)
        data = torch.from_numpy(data).permute(2, 0, 1)
        return data
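A minimal usage sketch of these classes, assuming a pre-tokenized map file and an image shard plus pickled metadata on disk; the file names, block_size, and DataLoader settings below are placeholders, not values from this repository:

from torch.utils import data

# Text datasets: each item is an (input, target) pair of int64 tensors, the
# target shifted one token to the right, so they drop into a standard LM loop.
train_set = ShardedDataset(block_size=2048, map_file="tokens.map", world_size=8, rank=0)
loader = data.DataLoader(train_set, batch_size=4, shuffle=False)
for inputs, targets in loader:
    pass  # inputs/targets: shape (4, 2047) each

# Image dataset: __getitem__ already returns a whole batch of bsz decoded
# images (CHW uint8 tensors), so it is usually indexed directly rather than
# wrapped in a DataLoader.
img_set = ShardedImageDataset("images.ds", "images.pkl", threads=16, bsz=256,
                              world_size=8, rank=0)
batch = list(img_set[0])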
@@ -6,96 +6,10 @@ except ImportError:
from pathlib import Path
import os
import math
from torch.utils import data
import numpy as np
import torch
from tqdm import tqdm
import time
from simplejpeg import decode_jpeg
import mmap
from timeit import default_timer as timer
import pickle
import concurrent
from itertools import repeat
[FbDataset, ShardedDataset, and ShardedImageDataset removed from this file — identical to the code added in dataset.py above]
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
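The no_init helper itself is not part of this diff. As a rough sketch of the idea the comment above describes (skipping PyTorch's random weight initialization because a checkpoint load overwrites the weights anyway), one hypothetical implementation temporarily replaces the torch.nn.init routines with no-ops; this is an illustration, not the repository's actual code:

import torch

def no_init(loading_code):
    # Hypothetical sketch: swap the common torch.nn.init functions for no-ops
    # while the model is being constructed, then restore them. Modules built
    # inside loading_code() keep their uninitialized parameter values, which
    # is fine when a state_dict is loaded immediately afterwards.
    names = ["uniform_", "normal_", "constant_", "ones_", "zeros_",
             "kaiming_uniform_", "kaiming_normal_",
             "xavier_uniform_", "xavier_normal_"]
    saved = {n: getattr(torch.nn.init, n) for n in names}
    try:
        for n in names:
            setattr(torch.nn.init, n, lambda tensor, *args, **kwargs: tensor)
        return loading_code()
    finally:
        for n, fn in saved.items():
            setattr(torch.nn.init, n, fn)

# Usage, as in the comment above:
# model = no_init(lambda: load_model(...))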