Commit 02d93202 authored by novelailab

remove lm_train, merge utils, move optimizer

parent 773e886f
@@ -2,22 +2,19 @@ from re import A
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from lm_train import optimizer, utils
from lm_train import utils
from torch.utils import data
from main import *
import yaml
import math
import sys
from tqdm import tqdm
import time
import wandb
from lm_arch.gpt2 import GPT2Model
import numpy as np
from transformers import AutoTokenizer
from torch.utils.checkpoint import checkpoint as ck
from math import log2, ceil
from lm_arch import gptj, lm_base, optimizer
from lm_arch import util
def _init_weights(module):
    """Initialize the weights."""
......
@@ -11,9 +11,6 @@ import os
from pathlib import Path
import math
def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
    if x is None:
        x = torch.empty(0)
......
@@ -5,6 +5,29 @@ except ImportError:
from collections import MutableMapping
from pathlib import Path
import os
import math
from torch.utils import data
import numpy as np
import torch
# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
    def __init__(self, block_size, map_file, max_samples=None):
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        self.samples = self.npz.shape[0]
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = 0
    def __len__(self):
        return self.samples
    def __getitem__(self, _id):
        nth = _id + self.skip
        data = torch.tensor(self.npz[nth].astype(np.int64))
        return (data[:-1], data[1:])
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
def no_init(loading_code):
    def dummy(self):
@@ -22,6 +45,11 @@ def no_init(loading_code):
    return result
# Count the parameters of a given pytorch model.
def count_parameters(model, only_trainable=False):
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not only_trainable)
SPLIT_WEIGHTS_NAME = "m.pt"
class SplitCheckpoint(MutableMapping):
    def __init__(self, name_or_path, device="cpu", subfolder=None):
@@ -60,4 +88,7 @@ class SplitCheckpoint(MutableMapping):
    def __copy__(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)
    def copy(self):
        return SplitCheckpoint(self.chkpt_dir, device=self.device)
\ No newline at end of file
        return SplitCheckpoint(self.chkpt_dir, device=self.device)
def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
\ No newline at end of file
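A minimal usage sketch for the FbDataset class merged in above, assuming the merged helpers live in lm_arch.util (per the import list earlier in this commit); the dataset.map path and the block size of 2049 are illustrative placeholders, not values taken from this repository.
from torch.utils import data
from lm_arch.util import FbDataset  # assumed module path for the merged utils
# Each memmapped row holds block_size uint16 token ids; __getitem__ shifts by one
# token, so every sample is an (input, target) pair of length block_size - 1.
train_set = FbDataset(block_size=2049, map_file="dataset.map")
loader = data.DataLoader(train_set, batch_size=4, shuffle=True, num_workers=2)
for inputs, targets in loader:
    # inputs, targets: int64 tensors of shape (batch, block_size - 1)
    print(inputs.shape, targets.shape)
    break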
from transformers import GPTNeoForCausalLM, AutoConfig
import torch
from lm_train.utils import *
import math
class GPT:
    def __init__(self, model_dtype="bf16", model_device="cuda"):
        self.config = self.get_config(model_dtype, model_device)
        self.checkpoint = self.get_checkpoint()
        self.model = None
        return
    def get_config(self, model_dtype="bf16", model_device="cuda"):
        print("Using device:", model_device)
        config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
        config.num_layers = 28
        config.attention_layers = ["global"] * config.num_layers
        config.attention_types = [["global"], config.num_layers]
        config.num_heads = 16
        config.hidden_size = 256 * config.num_heads
        config.vocab_size = 50400
        config.rotary = True
        config.rotary_dim = 64
        config.jax = True
        config.model_dtype = model_dtype
        config.model_device = model_device
        if model_dtype == "bf16":
            config.full_bf16 = True
        return config
    def get_checkpoint(self):
        try:
            from collections.abc import MutableMapping
        except ImportError:
            from collections import MutableMapping
        from pathlib import Path
        class Checkpoint(MutableMapping):
            def __init__(self, chkpt_dir, device="cpu"):
                self.device = device
                self.chkpt_dir = Path(chkpt_dir)
                self.checkpoint = torch.load(str(chkpt_dir / Path("m.pt")))
            def __len__(self):
                return len(self.checkpoint)
            def __getitem__(self, key):
                path = self.chkpt_dir / Path(self.checkpoint[key]).name
                return torch.load(str(path), map_location=self.device)
            def __setitem__(self, key, value):
                return
            def __delitem__(self, key):
                return
            def keys(self):
                return self.checkpoint.keys()
            def __iter__(self):
                for key in self.checkpoint:
                    yield (key, self.__getitem__(key))
            def __copy__(self):
                return Checkpoint(self.chkpt_dir, device=self.device)
            def copy(self):
                return Checkpoint(self.chkpt_dir, device=self.device)
        return Checkpoint
    def load_model(self, model_path=None, model_name=None, config=None, checkpoint=None):
        if config is None:
            config = self.config
        if checkpoint is None:
            Checkpoint = self.checkpoint
        else:
            Checkpoint = checkpoint
        if model_name is not None:
            model_path = self.assign_path(model_name)
        print("Loading model from: " + model_path)
        model = no_init(lambda: GPTNeoForCausalLM.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=Checkpoint(model_path)))
        self.model = model
        return model
    def assign_path(self, model_name):
        if model_name == "gptj":
            return "/home/xuser/models/j6b_ckpt_14001"
        # Raise error if model name not recognized
        else:
            raise ValueError("Model name not recognized")
    def init_model(self, config=None, method='wang'):
        neox_init = True
        if config is None:
            config = self.config
        model = no_init(lambda: GPTNeoForCausalLM(config))
        if neox_init:
            modules = [*model.transformer.h[:-1], model.transformer.wte, model.transformer.ln_f]
            init = small_init_method(self.config.hidden_size)
            for module in modules:
                for param in module.parameters():
                    init(param)
            last_layer = model.transformer.h[-1]
            last_layer_init = wang_init_method(self.config.num_layers, self.config.hidden_size)
            for param in last_layer.parameters():
                last_layer_init(param)
        self.model = model
        return model
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        embs=None,
    ):
        if isinstance(self.model, GPTNeoForCausalLM):
            outputs = self.model(input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, embs)
            # outputs: dict(loss, logits, past_key_values, hidden_states, attentions)
            return outputs
#def init_module()
def wang_init_method(n_layers, dim):
    std = 2 / n_layers / math.sqrt(dim)
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_
# Stolen from NeoX. For the 20B run, wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
    std = math.sqrt(2 / (5 * dim))
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_
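The two helpers above differ only in the standard deviation they hand to torch.nn.init.normal_. A small illustrative sketch of the values they produce for the configuration defined earlier in this file (hidden_size = 256 * 16 = 4096, num_layers = 28), with toy nn.Linear modules standing in for real transformer layers:
import math
import torch.nn as nn
hidden_size = 256 * 16  # 4096, as in get_config above
num_layers = 28
small_std = math.sqrt(2 / (5 * hidden_size))        # ~0.00988, used on most layers
wang_std = 2 / num_layers / math.sqrt(hidden_size)  # ~0.00112, used on the last layer
print(f"small_init std: {small_std:.5f}, wang_init std: {wang_std:.5f}")
hidden = nn.Linear(hidden_size, hidden_size)  # stand-in for a non-output layer
output = nn.Linear(hidden_size, hidden_size)  # stand-in for the final layer
small_init_method(hidden_size)(hidden.weight)
wang_init_method(num_layers, hidden_size)(output.weight)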
from torch.utils import data
from transformers.modeling_utils import no_init_weights
import numpy as np
import torch
# Does this work with other block_sizes? doesn't seem to.
class FbDataset(data.Dataset):
    def __init__(self, block_size, map_file, max_samples=None):
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        self.samples = self.npz.shape[0]
        if max_samples is not None:
            self.samples = min(self.samples, int(max_samples))
        self.skip = 0
    def __len__(self):
        return self.samples
    def __getitem__(self, _id):
        nth = _id + self.skip
        data = torch.tensor(self.npz[nth].astype(np.int64))
        return (data[:-1], data[1:])
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
def no_init(loading_code):
    def dummy(self):
        return
    modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
    original = {}
    for mod in modules:
        original[mod] = mod.reset_parameters
        mod.reset_parameters = dummy
    with no_init_weights():
        result = loading_code()
    for mod in modules:
        mod.reset_parameters = original[mod]
    return result
# Count the parameters of a given pytorch model.
def count_parameters(model, only_trainable=False):
    return sum(p.numel() for p in model.parameters() if p.requires_grad or not only_trainable)
\ No newline at end of file
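A sketch of what no_init above buys: it swaps out reset_parameters on nn.Linear, nn.Embedding, and nn.LayerNorm (and enters transformers' no_init_weights context) so that building a model skips the random weight initialization a checkpoint load would overwrite anyway. The builder below is a hypothetical stand-in, not this repository's model; the timing comparison is illustrative and depends on hardware.
import time
import torch.nn as nn
def build():
    # Placeholder constructor; in this repo it would wrap something like
    # GPTNeoForCausalLM.from_pretrained(None, config=config, state_dict=Checkpoint(path)).
    return nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)])
t0 = time.time()
initialized = build()            # weights get a random init that is later thrown away
t1 = time.time()
uninitialized = no_init(build)   # reset_parameters is a no-op, so construction is faster
t2 = time.time()
print(f"with init: {t1 - t0:.3f}s, no_init: {t2 - t1:.3f}s")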
@@ -4,15 +4,14 @@ import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from lm_train import optimizer, utils
from lm_train import utils
from torch.utils import data
from lm_arch import lm_base
from lm_arch import lm_base, optimizer
import yaml
import sys
from tqdm import tqdm
import time
import wandb
from lm_arch.gpt2 import GPT2Model
import numpy as np
......