Commit cf0224e4 authored by Eren Doğan, committed by GitHub

Merge pull request #7 from NovelAI/fiareader

mmap based FIAReader
parents b47ef0e9 da4bc5ad
import numpy as np
import torch
import mmap
import pickle
import concurrent
from torch.utils import data
from simplejpeg import decode_jpeg
import simplejpeg
import pickle
from pathlib import Path
# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
def __init__(self, block_size, map_file, max_samples=None, skip=0):
self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
self.samples = self.npz.shape[0]
if max_samples is not None:
self.samples = min(self.samples, int(max_samples))
self.skip = skip
def __len__(self):
return self.samples
def __getitem__(self, _id):
nth = _id + self.skip
data = torch.tensor(self.npz[nth].astype(np.int64))
return (data[:-1], data[1:])
class ShardedDataset(data.Dataset):
def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
#might want to pad later
self.npz = self.npz[:self.npz.shape[0] - (self.npz.shape[0] % world_size)]
#shard
self.npz = self.npz[rank::world_size]
self.samples = self.npz.shape[0]
self.skip = skip
def __len__(self):
return self.samples
def __getitem__(self, _id):
nth = _id + self.skip
data = torch.tensor(self.npz[nth].astype(np.int64))
return (data[:-1], data[1:])
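# Hedged usage sketch (not part of this commit): how the token datasets above are
# expected to be consumed. "tokens.map" and the block size of 2049 are placeholder
# values; the map file must hold uint16 token ids laid out in fixed-size blocks.
#
#   train_set = ShardedDataset(2049, "tokens.map", world_size=8, rank=0)
#   loader = data.DataLoader(train_set, batch_size=4, shuffle=False)
#   xs, ys = next(iter(loader))
#   # xs, ys: int64 tensors of shape [4, 2048]; ys is xs shifted left by one token.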
class ShardedImageDataset(data.Dataset):
def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
self.skip = skip
self.threads = threads
self.bsz = bsz
        #per-image transforms, applied while images are decoded one by one and cannot yet be batched
        self.inner_transform = inner_transform
        #batched transforms, applied after the decoded images are stacked into a batch
        self.outer_transform = outer_transform
self.dataset_path = dataset_path
self.world_size = world_size
self.local_rank = local_rank
self.global_rank = global_rank
self.device = device
with open(index_path, 'rb') as f:
self.index = pickle.load(f)
if metadata_path:
with open(metadata_path, 'rb') as f:
self.metadata = pickle.load(f)
        with open(self.dataset_path, mode="rb") as file_obj:
            self.mmap = mmap.mmap(file_obj.fileno(), length=0, access=mmap.ACCESS_READ)
        #trim the index so its length divides evenly by batch size and world_size (num GPUs)
        self.index = self.index[:len(self.index) - (len(self.index) % (bsz * world_size))]
#shard the dataset according to the rank
self.index = self.index[global_rank::world_size]
        #store the index as a NumPy array so it can be sliced cheaply from multiple reader threads
self.index = np.array(self.index)
self.ids = self.index.transpose(1, 0)[2]
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
def __len__(self):
return len(self.index) // self.bsz
    def __getitem__(self, key):
        key = self.skip + key
        #each index is one batch: map it to bsz consecutive, non-overlapping samples
        keys = [*range(key * self.bsz, (key + 1) * self.bsz)]
tensors = self.executor.map(self.read_from_index_key, keys)
tensors = list(tensors)
#make sure these operations are fast!
ids = [t[1] for t in tensors]
tensors = torch.stack([t[0] for t in tensors])
if self.device == "cuda":
tensors = tensors.to(self.local_rank)
tensors = tensors.permute(0, 3, 1, 2).float()
#####################################
if self.outer_transform:
tensors = self.outer_transform(tensors)
return tensors, ids
def read_from_index_key(self, key):
offset, size, id = self.index[key]
data = self.mmap[offset:offset+size]
data = decode_jpeg(data)
data = torch.from_numpy(data)#.permute(2, 0, 1)
if self.inner_transform:
data = self.inner_transform(data)
return data, id
def read_from_id(self, id):
#to be used standalone
offset, size, _ = self.index[self.ids == id][0]
data = self.mmap[offset:offset+size]
data = decode_jpeg(data)
return data
def get_metadata(self, id):
return self.metadata[id]
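# Hedged usage sketch (not part of this commit): reading batches from the mmap-backed
# image shard. The .ds/.index file names are placeholders produced by ImageDatasetBuilder
# below; each __getitem__ returns one batch of bsz decoded images plus their ids, and
# assumes inner_transform yields same-sized images so torch.stack works.
#
#   dataset = ShardedImageDataset("images.ds", "images.index", bsz=256,
#                                 world_size=1, local_rank=0, global_rank=0,
#                                 threads=8, device="cpu")
#   images, ids = dataset[0]
#   # images: stacked uint8 tensor [256, H, W, C] (moved to GPU, permuted to NCHW
#   # and cast to float when device == "cuda"); ids: 256 image identifiers.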
class ImageDatasetBuilder():
def __init__(self, folder_path, name, dataset=True, index=True, metadata=False, threads=None):
self.folder_path = Path(folder_path)
self.dataset_name = name + ".ds"
self.index_name = name + ".index"
self.metadata_name = name + ".metadata"
self.dataset_path = self.folder_path / self.dataset_name
self.index_path = self.folder_path / self.index_name
self.metadata_path = self.folder_path / self.metadata_name
self.open_dataset = dataset
self.open_index = index
self.open_metadata = metadata
self.dataset = None
self.index = None
self.metadata = None
self.threads = threads
@property
def is_open(self):
        return self.dataset is not None or self.index is not None or self.metadata is not None
@property
def is_close(self):
        return self.dataset is None or self.index is None
@property
def biggest_id(self):
if self.index is None:
return -1
else:
return np.max(self.np_index[:, 2])
@property
def biggest_item(self):
if self.index is None:
return -1
else:
return np.max(self.np_index[:, 1])
@property
def np_index(self):
return np.array(self.index)
def build(self):
#be careful with not nuking the files if they exist
if self.is_open:
raise Exception("Dataset already built")
self.folder_path.mkdir(parents=True, exist_ok=True)
        if self.open_dataset:
            self.dataset = open(self.dataset_path, mode="ab+")
            self.dataset.flush()
if self.open_index:
self.index = []
if self.open_metadata:
self.metadata = {}
def open(self, overwrite=False):
if overwrite is False and self.is_open:
raise Exception("A dataset is already open! If you wish to continue set overwrite to True.")
if overwrite is True:
self.close(silent=True)
self.flush_index(silent=True)
self.flush_metadata(silent=True)
print("Dataset closed and flushed.")
        if self.open_dataset:
            if self.dataset_path.is_file():
                self.dataset = open(self.dataset_path, mode="ab+")
            else:
                raise Exception("Dataset file not found at {}".format(self.dataset_path))

        if self.open_index:
            if self.index_path.is_file():
                with open(self.index_path, 'rb') as f:
                    self.index = pickle.load(f)
            else:
                raise Exception("Index file not found at {}".format(self.index_path))

        if self.open_metadata:
            if self.metadata_path.is_file():
                with open(self.metadata_path, 'rb') as f:
                    self.metadata = pickle.load(f)
            else:
                raise Exception("Metadata file not found at {}".format(self.metadata_path))
    def operate(self, operation, data_batch, identities, metadata=None):
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.threads)
        futures = executor.map(operation, data_batch)
        futures = list(futures)
        for i, (data, identity) in enumerate(zip(futures, identities)):
            self.write(data, identity, metadata[i] if metadata else None)
def encode_op(self, data):
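        # pass data through untouched if it is already a JPEG byte string; otherwise
        # assume an HxWxC uint8 ndarray and encode it (simplejpeg defaults to 'rgb')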
if simplejpeg.is_jpeg(data):
pass
else:
data = simplejpeg.encode_jpeg(data, quality=91)
return data
    def write(self, data, identity, metadata=None, flush=False):
        if self.is_close:
            raise Exception("Dataset not built")

        #record the offset before writing so readers can slice mmap[offset:offset+size]
        offset = self.dataset.tell()
        self.dataset.write(data)
        self.index.append([offset, len(data), identity])
        if self.metadata is not None and metadata is not None:
            self.metadata[identity] = metadata

        if flush:
            self.flush()
def write_metadata(self, id, metadata):
self.metadata[id] = metadata
    def flush_index(self, silent=False):
        if self.index is None:
            if not silent:
                print("Warning: Index not built, couldn't flush")
            return

        with open(self.index_path, 'wb') as f:
            pickle.dump(self.index, f)

    def flush_metadata(self, silent=False):
        if self.metadata is None:
            if not silent:
                print("Warning: Metadata not built, couldn't flush")
            return

        with open(self.metadata_path, 'wb') as f:
            pickle.dump(self.metadata, f)

    def flush(self, silent=False):
        if self.dataset is None:
            if not silent:
                print("Warning: Dataset not built, couldn't flush")
            return

        self.dataset.flush()

    def close(self, silent=False):
        if self.dataset is None:
            if not silent:
                print("Warning: Dataset not built, couldn't close")
            return

        #flush and close the dataset filehandle; the index and metadata are
        #persisted separately via flush_index()/flush_metadata()
        self.flush()
        self.dataset.close()
\ No newline at end of file
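# Hedged usage sketch (not part of this commit): building a shard with
# ImageDatasetBuilder and the files it produces. The folder/file names are
# placeholders; images is assumed to be a list of JPEG byte strings or uint8
# HxWxC arrays, with one integer id per image in ids.
#
#   builder = ImageDatasetBuilder("shards", "images", metadata=True, threads=8)
#   builder.build()                                  # creates shards/images.ds
#   builder.operate(builder.encode_op, images, ids)  # encode + append each image
#   builder.flush_index()                            # writes shards/images.index
#   builder.flush_metadata()                         # writes shards/images.metadata
#   builder.close()                                  # flush + close the .ds file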
@@ -2,6 +2,7 @@ import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from dotmap import DotMap
import math
class BaseModel(nn.Module):
def __init__(self, user_config, **kwargs):
@@ -27,7 +28,26 @@ class BaseModel(nn.Module):
config=config,
)
)
def init_weights(self):
n_layer = self.n_layer
for module in self.modules():
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
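            # GPT-2-style scaled init (presumed intent): the second FF projection and the
            # attention output projection get std 0.02/sqrt(2 * n_layer) so the variance
            # added by the residual branches stays roughly constant with depth.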
for name, p in module.named_parameters():
if ("ff2" in name or "out_proj" in name) and "weight" in name:
p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))
def configure_model(self):
full_config = {}
if not hasattr(self, 'default_config'):
......
import torch.nn as nn
import torch.nn.functional as F
from basedformer import utils
from dotmap import DotMap
from pathlib import Path
import torch
import json
class PretrainedModel(nn.Module):
def __init__(self, **kwargs):
nn.Module.__init__(self)
self.config = None
@classmethod
def no_init(cls, config):
model = utils.no_init(lambda: cls(config))
return model
@classmethod
def init(cls, config):
model = cls(config)
if hasattr(model, 'init_weights'):
model.init_weights()
else:
raise ValueError("No init_weights found, add one for init to function")
return model
def save(self, path, save_as=torch.float16):
        model = self
        original_dtype = model.dtype
if save_as:
model = model.to(save_as)
path = Path(path)
model_path = path / "model"
#make folder
model_path.mkdir(parents=True, exist_ok=True)
checkpoint = {}
for i, x in enumerate(model.state_dict().items()):
checkpoint[x[0]] = model_path / f"b{i}.pt"
torch.save(x[1], model_path / f"b{i}.pt")
torch.save(checkpoint, model_path / "m.pt")
#write model.config to config.json inside path
#with open(path / "config.json", "w") as f:
# json.dump(serialize_config(model.config), f)
if save_as:
model = model.to(original_dtype)
\ No newline at end of file
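# Hedged sketch (no load routine appears in this diff): reading back a checkpoint
# written by save() above, assuming m.pt maps parameter names to per-tensor files.
#
#   def load_state(model_dir):
#       model_path = Path(model_dir) / "model"
#       manifest = torch.load(model_path / "m.pt")
#       state_dict = {name: torch.load(file) for name, file in manifest.items()}
#       return state_dict
#
#   # model.load_state_dict(load_state("checkpoints/gptj"))  # path is a placeholder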
from cmath import exp
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -26,28 +27,29 @@ class ResBlock(nn.Module):
return F.relu(out)
class ResBlockBottleNeck(nn.Module):
def __init__(self, in_channels, out_channels) -> None:
def __init__(self, in_channels, out_channels, expansion, needs_downsample=False) -> None:
super().__init__()
downsample = True if in_channels != out_channels else False
self.residual = nn.Sequential()
if downsample:
self.expansion = expansion
if needs_downsample or in_channels != out_channels * self.expansion:
self.residual = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
nn.BatchNorm2d(out_channels)
nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=2 if needs_downsample else 1),
nn.BatchNorm2d(out_channels * self.expansion)
)
self.conv1 = nn.Conv2d(in_channels, out_channels//4, kernel_size=1, stride=1)
self.conv2 = nn.Conv2d(out_channels//4, out_channels//4, kernel_size=3, stride=2 if downsample else 1, padding=1)
self.conv3 = nn.Conv2d(out_channels//4, out_channels, kernel_size=1, stride=1)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2 if needs_downsample else 1, padding=1)
self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1)
self.bn1 = nn.BatchNorm2d(out_channels//4)
self.bn2 = nn.BatchNorm2d(out_channels//4)
self.bn3 = nn.BatchNorm2d(out_channels)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
def forward(self, x):
residual = self.residual(x)
out = F.relu((self.bn1(self.conv1(x))))
out = F.relu((self.bn2(self.conv2(out))))
out = F.relu((self.bn3(self.conv3(out)))) + self.residual(x)
return F.relu(out)
out = F.relu((self.bn3(self.conv3(out))))
return F.relu(out + residual)
class ResNet(base_image.BaseVisionModel):
@@ -76,15 +78,26 @@ class ResNet(base_image.BaseVisionModel):
is_bottleneck = self.network_config[0]
curr_chan = 64
prev_chan = curr_chan
#dirty hack for downscaling at bottleneck layers
firstlayer = True
for i in self.network_config[1]:
for _ in range(i):
resblock = ResBlockBottleNeck(prev_chan, curr_chan) if is_bottleneck else ResBlock(prev_chan, curr_chan)
needs_downsample = True
if is_bottleneck:
if firstlayer:
resblock = ResBlockBottleNeck(prev_chan, curr_chan, 4)
firstlayer = False
else:
resblock = ResBlockBottleNeck(prev_chan * 4, curr_chan, 4, needs_downsample)
needs_downsample = False
else:
resblock = ResBlock(prev_chan, curr_chan)
self.resblocks.append(resblock)
prev_chan = curr_chan
curr_chan *= 2
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(prev_chan, self.config.n_class)
self.fc = nn.Linear(prev_chan * 4 if is_bottleneck else prev_chan, self.config.n_class)
def forward(self, x):
out = self.layerin(x)
......
from typing import Callable, KeysView
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
from basedformer.models import base_lm
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
x = torch.empty(0)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
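# Hedged shape sketch (not part of this commit): how the rotary helpers above are
# wired together in Attention.forward below. Sizes are placeholder example values.
#
#   sin, cos = fixed_pos_embedding(dim=64, seq_len=2049)   # each: [2049, 32]
#   q = torch.randn(1, 128, 16, 256)                       # [batch, seq, heads, head_dim]
#   q_rot = apply_rotary_pos_emb(q[..., :64], (sin, cos))  # rotate the first 64 dims
#   q = torch.cat([q_rot, q[..., 64:]], dim=-1)            # pass the rest through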
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class Attention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config, causal=True, null_kv=False):
nn.Module.__init__(self)
max_positions = 2049
self.head_dim = config.hidden_dim // config.n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
self.q_only = config.q_only
self.causal = causal
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
if self.causal:
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
if config.q_only:
self.k_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.head_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
else:
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=config.device, dtype=config.dtype)
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
# allowing for attending to nothing (null function)
# and to save attention from breaking if all retrieved chunks are padded out
self.null_k = nn.Parameter(torch.randn(self.hidden_dim)) if null_kv else None
self.null_v = nn.Parameter(torch.randn(self.hidden_dim)) if null_kv else None
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
if self.q_only:
key = self.k_proj(x).view(B, S, 1, self.head_dim)
value = self.v_proj(x).view(B, S, 1, self.head_dim).transpose(1, 2)
else:
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
offset = kv[0].shape[-2]
else:
offset = 0
if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, :self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
query = query.transpose(1, 2)
key = key.transpose(1, 2)
if kv:
k, v = kv
            # concatenate the cached keys/values with the new token's, so the
            # query can attend to the full sequence seen so far
            key = torch.cat([k, key], dim=-2)
            value = torch.cat([v, value], dim=-2)
query_length, key_length = query.size(-2), key.size(-2)
#causal mask with generation in mind
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, [key, value]
else:
return x, None
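# Hedged usage sketch (not part of this commit): incremental decoding with the
# kv cache returned above. config is a placeholder carrying the fields read in
# __init__ (hidden_dim, n_head, q_only, device, dtype).
#
#   attn = Attention(config)
#   out, kv = attn(prompt_embeds, cache=True)        # prompt_embeds: [B, S, hidden_dim]
#   out, kv = attn(next_embed, kv=kv, cache=True)    # next_embed:   [B, 1, hidden_dim]
#   # each step appends the new key/value to kv so the single new token can attend
#   # to the whole prefix; the rotary offset is taken from kv[0].shape[-2].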
class FeedForward(nn.Module):
def __init__(self, config):
nn.Module.__init__(self)
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
x = ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class GPTJLayer(nn.Module):
def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
#self.ln_preattn = nn.LogSoftmax(dim=-2)
self.ff = ff(config)
self.attn = attn(config)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out, kv = ck(self.attn, x, kv, cache)
#attn_out, kv = self.attn(x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
if hypernetwork:
if diff_hypernets:
if interleaving_layers and layer_id % every_n == 0:
if self.tick:
hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
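        # the two orderings differ only in floating-point rounding; in low precision
        # (fp16) that drift is presumably enough to break exact parity with the
        # reference implementation being matched against.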
if hypernetwork and layer_id % every_n == 0:
x = x + hyper_out
return x, kv
class GPTJModel(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
self.default_config = {
'n_layer': 6,
'n_head': 8,
'n_tokens': 2048,
'hidden_dim': 512,
'vocab_dim': 50400,
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': GPTJLayer,
'activation': gelu_new,
            'SelfAttention': Attention,
'FeedForward': FeedForward,
}
base_lm.BaseModel.__init__(self, user_config, **kwargs)
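# Hedged usage sketch (not part of this commit): building a small GPTJModel with
# config overrides. base_lm.BaseModel is only partially shown in this diff, so the
# exact forward signature is an assumption.
#
#   config = {'n_layer': 12, 'n_head': 12, 'hidden_dim': 768,
#             'device': torch.device('cpu'), 'dtype': torch.float32}
#   model = GPTJModel(config)
#   # logits = model(tokens)   # assumed: tokens is an int64 tensor of shape [batch, seq]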
@@ -6,48 +6,10 @@ except ImportError:
from pathlib import Path
import os
import math
from torch.utils import data
import numpy as np
import torch
from tqdm import tqdm
import time
# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
def __init__(self, block_size, map_file, max_samples=None, skip=0):
self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
self.samples = self.npz.shape[0]
if max_samples is not None:
self.samples = min(self.samples, int(max_samples))
self.skip = skip
def __len__(self):
return self.samples
def __getitem__(self, _id):
nth = _id + self.skip
data = torch.tensor(self.npz[nth].astype(np.int64))
return (data[:-1], data[1:])
class ShardedDataset(data.Dataset):
def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
#might want to pad later
self.npz = self.npz[:self.npz.shape[0] - (self.npz.shape[0] % world_size)]
#shard
self.npz = self.npz[rank::world_size]
self.samples = self.npz.shape[0]
self.skip = skip
def __len__(self):
return self.samples
def __getitem__(self, _id):
nth = _id + self.skip
data = torch.tensor(self.npz[nth].astype(np.int64))
return (data[:-1], data[1:])
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
......