Commit 1f2c39ba authored by kurumuz

update

parent f7c8c4fe
......@@ -11,6 +11,9 @@ import requests
import hashlib
import io
import os
from simplejpeg import decode_jpeg
import simplejpeg
from PIL import Image
# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
......@@ -48,12 +51,8 @@ class ShardedDataset(data.Dataset):
return (data[:-1], data[1:])
class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
def __init__(self, dataset_path: str, name:str, shuffle=False, metadata_path=None, threads=None, inner_transform=None,
outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
from simplejpeg import decode_jpeg
import simplejpeg
from PIL import Image
self.skip = skip
self.threads = threads
self.bsz = bsz
......@@ -61,12 +60,15 @@ class ShardedImageDataset(data.Dataset):
self.inner_transform = inner_transform
#for batched transforms after images become batchable
self.outer_transform = outer_transform
self.dataset_path = dataset_path
self.dataset_path = Path(dataset_path)
self.index_path = self.dataset_path / f"{name}.index"
self.pointer_path = self.dataset_path / f"{name}.pointer"
self.dataset_path = self.dataset_path / f"{name}.ds"
self.world_size = world_size
self.local_rank = local_rank
self.global_rank = global_rank
self.device = device
with open(index_path, 'rb') as f:
with open(self.index_path, 'rb') as f:
self.index = pickle.load(f)
if metadata_path:
......@@ -77,17 +79,22 @@ class ShardedImageDataset(data.Dataset):
self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)
#precompute pointer lookup dict for faster random read
self.pointer_lookup = {}
for t in self.index:
offset, length, id = t
self.pointer_lookup[id] = (offset, length)
if not self.pointer_path.is_file():
self.pointer_lookup = {}
for t in tqdm(self.index):
offset, length, id = t
self.pointer_lookup[id] = (offset, length)
with open(self.pointer_path, 'wb') as f:
pickle.dump(self.pointer_lookup, f)
else:
with open(self.pointer_path, 'rb') as f:
self.pointer_lookup = pickle.load(f)
#make so metadata is shardable by world_size(num_gpus)
#and batch_size
self.index = self.index[:len(self.index) - (len(self.index) % (bsz * world_size))]
#shard the dataset according to the rank
self.index = self.index[global_rank::world_size]
self.original_index = self.index
self.index = self.shard(shuffle=shuffle)
#override possible gil locks by making the index map an nparray
self.index = np.array(self.index)
self.ids = self.index.transpose(1, 0)[2]
......@@ -95,6 +102,16 @@ class ShardedImageDataset(data.Dataset):
def __len__(self):
    """Return the number of whole batches this rank can serve."""
    full_batches, _ = divmod(len(self.index), self.bsz)
    return full_batches
def shard(self, shuffle=False):
    """Re-slice the full index for this rank.

    Call again at the start of every epoch (with shuffle=True) to get a
    fresh permutation; always starts from ``self.original_index``.
    """
    index = self.original_index
    if shuffle:
        index = np.random.permutation(index)
    # Drop the tail so every rank ends up with the same number of full batches.
    usable = len(index) - (len(index) % (self.bsz * self.world_size))
    # Round-robin shard: this rank takes every world_size-th remaining entry.
    self.index = index[:usable][self.global_rank::self.world_size]
    return self.index
def __getitem__(self, key):
key = self.skip + key
......@@ -107,7 +124,8 @@ class ShardedImageDataset(data.Dataset):
if self.device == "cuda":
tensors = tensors.to(self.local_rank)
tensors = tensors.permute(0, 3, 1, 2).float()
tensors = tensors.float()#permute#(0, 3, 1, 2).float() / 255.0
tensors = tensors / 127.5 - 1
#####################################
if self.outer_transform:
tensors = self.outer_transform(tensors)
......@@ -118,10 +136,10 @@ class ShardedImageDataset(data.Dataset):
offset, size, id = self.index[key]
data = self.mmap[offset:offset+size]
data = decode_jpeg(data)
data = torch.from_numpy(data)#.permute(2, 0, 1)
if self.inner_transform:
data = self.inner_transform(data)
data = torch.from_numpy(data)#.permute(2, 0, 1)
return data, id
def read_from_id(self, id, decode=True):
......@@ -134,6 +152,96 @@ class ShardedImageDataset(data.Dataset):
def get_metadata(self, id):
    """Look up the metadata record stored for a given sample id."""
    record = self.metadata[id]
    return record
class CPUTransforms():
    """CPU-side image transforms operating on (h, w, c) numpy arrays.

    All transforms are static; the instance only carries a thread-count
    hint for callers that want to parallelize.
    """

    def __init__(self, threads=None):
        # Bug fix: the `threads` argument was previously discarded
        # (`self.threads = None` regardless of the parameter).
        self.threads = threads

    @staticmethod
    def scale(data, res, pil=True):
        """Resize `data` (shape (h, w, c)).

        `res` is either an int — the short side is scaled to `res`,
        preserving aspect ratio — or an (h, w) tuple for an exact target
        size.  Uses PIL LANCZOS when `pil` is True, else cv2 INTER_AREA.

        Bug fix: a tuple `res` previously left `hw` unbound and raised
        NameError despite the comment promising tuple support.
        """
        h, w = data.shape[:2]
        if isinstance(res, int):
            if h > w:
                # Width becomes res; height scales to keep the aspect ratio.
                scale = res / w
                hw = (res, int(h * scale))
            elif h == w:
                hw = (res, res)
            else:
                # Height becomes res; width scales to keep the aspect ratio.
                scale = res / h
                hw = (int(w * scale), res)
        else:
            # (h, w) tuple -> (w, h) order expected by PIL/cv2 resize.
            hw = (res[1], res[0])
        if pil:
            data = Image.fromarray(data)
            data = data.resize(hw, Image.LANCZOS)
            data = np.asarray(data)
        else:
            data = cv2.resize(data, hw, interpolation=cv2.INTER_AREA)
        return data

    @staticmethod
    def centercrop(data, res: int):
        """Crop the central (res, res) window out of `data`."""
        h_offset = (data.shape[0] - res) // 2
        w_offset = (data.shape[1] - res) // 2
        data = data[h_offset:h_offset+res, w_offset:w_offset+res]
        return data

    @staticmethod
    def cast_to_rgb(data, pil=False):
        """Coerce grayscale/1-channel/RGBA input to a 3-channel RGB array.

        3-channel input is returned unchanged.  `pil` is currently unused;
        kept for interface compatibility.
        """
        if len(data.shape) < 3:
            # Grayscale (h, w): add a channel axis and replicate it.
            data = np.expand_dims(data, axis=2)
            data = np.repeat(data, 3, axis=2)
            return data
        if data.shape[2] == 1:
            data = np.repeat(data, 3, axis=2)
            return data
        if data.shape[2] == 3:
            return data
        if data.shape[2] == 4:
            # Alpha blending: composite onto white, then drop the alpha
            # channel.  Bug fix: the composite was previously returned as
            # RGBA, so the alpha channel was never actually removed.
            png = Image.fromarray(data)  # NOTE(review): may fail on uint16 input — confirm
            background = Image.new('RGBA', png.size, (255, 255, 255))
            alpha_composite = Image.alpha_composite(background, png)
            data = np.asarray(alpha_composite.convert('RGB'))
            return data
        else:
            return data

    @staticmethod
    def randomcrop(data, res):
        """Crop a random (res, res) window; degrades gracefully when the
        image is smaller than `res` (offset clamps to 0)."""
        h, w = data.shape[:2]
        if h - res > 0:
            h_offset = np.random.randint(0, h - res)
        else:
            h_offset = 0
        if w - res > 0:
            w_offset = np.random.randint(0, w - res)
        else:
            w_offset = 0
        data = data[h_offset:h_offset+res, w_offset:w_offset+res]
        return data
class ImageDatasetBuilder():
def __init__(self, folder_path, name, dataset=True, index=True, metadata=False, threads=None, block_size=4096, align_fs_blocks=True):
self.folder_path = Path(folder_path)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.