Commit 4f87dce5 authored by novelailab's avatar novelailab

messy

parent db8c9ae8
import torch
import transformers
import sys
"""
Original:
ln_f.weight
ln_f.bias
wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias
attn has biases unlike GPT-J. QKV Matrices are also merged instead of separate. what is the order though?
"""
x = torch.load("models/gpt2_vanilla/pytorch_model.bin")
print(x["h.0.attn.c_attn.weight"].reshape(-1, 768, 768).shape)
sys.exit(0)
new_state_dict = {}
module_map = {
"ln_1": "ln_preattn",
"mlp.c_proj": "ff.ff2",
"mlp.c_fc": "ff.ff1",
"attn.attention.out_proj": "attn.out_proj",
"attn.attention.k_proj": "attn.k_proj",
"attn.attention.v_proj": "attn.v_proj",
"attn.attention.q_proj": "attn.q_proj",
"wte": "vocab_embed",
'ln_f': 'ln_final',
'lm_head': 'lm_head',
}
print(type(state_dict))
for key in state_dict.keys():
dotlist = key.split('.')
if len(dotlist) > 3:
layer = dotlist[2]
for x in module_map:
if x in key:
new_state_dict[f"layers.{layer}.{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
print(f"{key} -> layers.{layer}.{module_map[x]}.{dotlist[-1]}")
else:
for x in module_map:
if x in key:
new_state_dict[f"{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
print(f"{key} -> {module_map[x]}.{dotlist[-1]}")
#print(new_state_dict)
def save(state_dict, path):
try: os.mkdir(path)
except: pass
checkpoint = {}
for i, x in enumerate(state_dict.items()):
checkpoint[x[0]] = f"{path}/b{i}.pt"
torch.save(x[1], f"{path}/b{i}.pt")
torch.save(checkpoint, f"{path}/m.pt")
save(new_state_dict, "models/6b_vanilla")
\ No newline at end of file
This diff is collapsed.
from torch import nn
import torch
import os
import math
from lm_arch.utils import *
# Can access and change every module from here, as both Layer class and ff and attn classes are passed from GPTModel.
class GPTModel(nn.Module):
def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=None, SelfAttention=None, FeedForward=None, device="cuda", dtype=torch.float16):
super(GPTModel, self).__init__()
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
#TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
#TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
def forward(self, x, hypernetwork=None, act_ck=False):
x = self.vocab_embed(x)
for layer in self.layers:
x = layer(x, hypernetwork, act_ck)
x = self.ln_final(x)
return x
def get_logits(self, x, hypernetwork=None, act_ck=False):
x = self.forward(x, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.lm_head(x)
return x.float()
@classmethod
def load(cls, config, path=None, state_dict=None):
if path:
state_dict = SplitCheckpoint(path, device="cuda")
model = no_init(lambda: cls(**config))
model.load_state_dict(state_dict, strict=False)
return model
@classmethod
def init(cls, config):
model = cls(**config)
return model
@classmethod
def neox_init(cls, config):
model = cls(**config)
modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
init = small_init_method(config["hidden_dim"])
for module in modules:
for param in module.parameters():
init(param)
last_layer = model.layers[-1]
last_layer_init = wang_init_method(config["n_layer"], config["hidden_dim"])
for param in last_layer.parameters():
last_layer_init(param)
return model
@classmethod
def simple_init(cls, config):
model = cls(**config)
state = model.state_dict()
for k in state:
state[k] = state[k] / math.sqrt(2 * config["n_layer"])
model.load_state_dict(state)
return model
def save(self, path):
try: os.mkdir(path)
except: pass
checkpoint = {}
for i, x in enumerate(self.state_dict().items()):
checkpoint[x[0]] = f"{path}/b{i}.pt"
torch.save(x[1], f"{path}/b{i}.pt")
torch.save(checkpoint, f"{path}/m.pt")
def wang_init_method(n_layers, dim):
std = 2 / n_layers / math.sqrt(dim)
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
# Stolen from NeoX. For the 20B run wang_init used on the output layer and small_init on rest of the layers.
def small_init_method(dim):
"""Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2010), using a normal distribution."""
std = math.sqrt(2 / (5 * dim))
def init_(tensor):
return torch.nn.init.normal_(tensor, mean=0.0, std=std)
return init_
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as ck
from einops import rearrange, repeat
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
import os
from pathlib import Path
import math
import lm_arch.gpt_arch as gpt_arch
#TODO: Might change with non einsum functions?
def get_logits(x, embedding):
return embedding(x)
def gelu_new(x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
x = torch.empty(0)
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2) / dim)).to(x.dtype).to(x.device)
sinusoid_inp = torch.einsum('i , j -> i j', torch.arange(seq_len).to(x.device), inv_freq).float()
return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp)
def rotate_every_two(x):
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, '... d j -> ... (d j)')
def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2), sincos)
return (x * cos) + (rotate_every_two(x) * sin)
def _split_heads(tensor, num_heads, attn_head_size, rotary):
"""
Splits hidden_size dim into attn_head_size and num_heads
"""
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(*new_shape)
if rotary:
return tensor
if len(tensor.shape) == 5:
return tensor.permute(0, 1, 3, 2, 4) # (batch, blocks, head, block_length, head_features)
elif len(tensor.shape) == 4:
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
def _merge_heads(tensor, num_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden_size
"""
if len(tensor.shape) == 5:
tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
elif len(tensor.shape) == 4:
tensor = tensor.permute(0, 2, 1, 3).contiguous()
else:
raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device="cuda", dtype=torch.float16):
super(SelfAttention, self).__init__()
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x):
query = self.q_proj(x)
key = self.k_proj(x)
value = self.v_proj(x)
query = _split_heads(query, self.n_head, self.head_dim, True)
key = _split_heads(key, self.n_head, self.head_dim, True)
value = _split_heads(value, self.n_head, self.head_dim, False)
offset = 0
if self.rotary_dim < self.head_dim:
k_rot = key[:, :, :, :self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim:]
q_rot = query[:, :, :, :self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim:]
k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
key = key.permute(0, 2, 1, 3)
query = query.permute(0, 2, 1, 3)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
)
x = _merge_heads(x, self.n_head, self.head_dim)
x = self.out_proj(x)
return x
class FeedForward(nn.Module):
def __init__(self, dim=768, hidden_dim=768*4, activation=nn.GELU(), device="cuda", dtype=torch.float16):
super(FeedForward, self).__init__()
self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
self.activation = activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
if act_ck:
ck(self.activation, x)
else:
x = self.activation(x)
x = self.ff2(x)
return x
class GPTJLayer(nn.Module):
def __init__(self, attn=SelfAttention, ff=FeedForward, hidden_dim=768, n_head=4, eps=1e-6, activation=nn.GELU(), device="cuda", dtype=torch.float16):
super(GPTJLayer, self).__init__()
self.hidden_dim = hidden_dim
self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
def forward(self, x, hypernetwork=None, act_ck=False):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out = ck(self.attn, x)
else:
x = self.ln_preattn(x)
attn_out = self.attn(x)
ff_out = self.ff(x, act_ck)
x = residual + attn_out + ff_out
if hypernetwork:
hyper_out = hypernetwork(x)
x = x + hyper_out
return x
class GPTJModel(gpt_arch.GPTModel):
def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=GPTJLayer, SelfAttention=SelfAttention, FeedForward=FeedForward, device="cuda", dtype=torch.float16):
super(GPTJModel, self).__init__(hidden_dim=hidden_dim, n_layer=n_layer, n_head=n_head, vocab_dim=vocab_dim, eps=eps, activation=activation, Layer=Layer, SelfAttention=SelfAttention, FeedForward=FeedForward, device=device, dtype=dtype)
def load_gpt_j(path="models/6b", state_dict=None):
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTJLayer
}
model = GPTJModel.load(config, path, state_dict)
return model
def init_6b():
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTJLayer
}
model = GPTJModel.init(config)
return model
def init_125m():
config = {
"n_layer": 12,
"n_head": 12,
"hidden_dim": 768,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTJLayer
}
model = GPTJModel.init(config)
return model
def init_1_3b():
config = {
"n_layer": 24,
"n_head": 16,
"hidden_dim": 2048,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTJLayer
}
model = GPTJModel(**config)
return model
\ No newline at end of file
import torch
try:
from collections.abc import MutableMapping
except ImportError:
from collections import MutableMapping
from pathlib import Path
import os
def no_init(loading_code):
def dummy(self):
return
modules = [torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm]
original = {}
for mod in modules:
original[mod] = mod.reset_parameters
mod.reset_parameters = dummy
result = loading_code()
for mod in modules:
mod.reset_parameters = original[mod]
return result
SPLIT_WEIGHTS_NAME = "m.pt"
class SplitCheckpoint(MutableMapping):
def __init__(self, name_or_path, device="cpu", subfolder=None):
self.device = device
localpath = Path(name_or_path)
if subfolder is not None:
localpath = localpath / subfolder
if os.path.isfile(localpath):
self.chkpt_dir = localpath.parent
self.remote = False
elif os.path.isfile(localpath / SPLIT_WEIGHTS_NAME):
self.chkpt_dir = localpath
self.checkpoint = torch.load(str(localpath / SPLIT_WEIGHTS_NAME))
self.remote = False
self.checkpoint = self._load(SPLIT_WEIGHTS_NAME, None)
def _load(self, name, shape, **kwparams):
path = str(self.chkpt_dir / name)
return torch.load(path, **kwparams)
def __len__(self):
return len(self.checkpoint)
def __getitem__(self, key):
name = self.checkpoint[key]
if type(name) is tuple:
return self._load(name[0].split('/')[-1], name[1], map_location=self.device)
else:
return self._load(name.split('/')[-1], None, map_location=self.device)
def __setitem__(self, key, value):
return
def __delitem__(self, key, value):
return
def keys(self):
return self.checkpoint.keys()
def __iter__(self):
for key in self.checkpoint:
yield (key, self.__getitem__(key))
def __copy__(self):
return SplitCheckpoint(self.chkpt_dir, device=self.device)
def copy(self):
return SplitCheckpoint(self.chkpt_dir, device=self.device)
\ No newline at end of file
......@@ -20,7 +20,7 @@ class BasedOptimizer:
"warmup_steps": 1,
"anneal_steps": 1,
"total_steps": None,
"weight_decay": 0,
"weight_decay": 0.01,
"tokens": None,
"epochs": None,
"beta1": 0.9,
......@@ -41,8 +41,12 @@ class BasedOptimizer:
if optimizer == "adamw":
self.optimizer = optim.AdamW(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
def step(self):
self.optimizer.step()
def step(self, scaler=None):
if scaler:
scaler.step(self.optimizer)
else:
self.optimizer.step()
self.curr_step = self.curr_step + 1
self.curr_lr = lr_schedule(self.curr_step, self.warmup_steps, self.anneal_steps, self.lr, self.end_lr)
......
......@@ -10,6 +10,7 @@ except ImportError:
import os
from pathlib import Path
import math
from lm_arch.gptj import GPTJModel
def no_init(loading_code):
def dummy(self):
......@@ -248,6 +249,7 @@ class GPTLayer(nn.Module):
class GPTModel(nn.Module):
def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=GPTLayer, device="cuda", dtype=torch.float16):
super(GPTModel, self).__init__()
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
......@@ -258,15 +260,33 @@ class GPTModel(nn.Module):
#TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
#TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
def forward(self, x, hypernetwork=None, act_ck=False):
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
for name, p in module.named_parameters():
if ("ff2" in name or "out_proj" in name) and "weight" in name:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.n_layer)))
def get_embeds(self, x, hypernetwork=None, act_ck=False):
x = self.vocab_embed(x)
for layer in self.layers:
x = layer(x, hypernetwork, act_ck)
x = self.ln_final(x)
return x
def get_logits(self, x, hypernetwork=None, act_ck=False):
x = self.forward(x, hypernetwork=hypernetwork, act_ck=act_ck)
def forward(self, x, hypernetwork=None, act_ck=False):
x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.lm_head(x)
return x.float()
......@@ -289,15 +309,26 @@ class GPTModel(nn.Module):
model = cls(**config)
modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
init = small_init_method(config["hidden_dim"])
for module in modules:
for param in module.parameters():
init(param)
last_layer = model.layers[-1]
last_layer_init = wang_init_method(config["n_layer"], config["hidden_dim"])
for param in last_layer.parameters():
last_layer_init(param)
for param in model.parameters():
init(param)
return model
@classmethod
def simple_init(cls, config):
model = cls(**config)
state = model.state_dict()
for k in state:
state[k] = state[k] / math.sqrt(2 * config["n_layer"])
model.load_state_dict(state)
return model
@classmethod
def gpt2_init(cls, config):
model = cls(**config)
for module in model.modules():
model._init_weights(module)
return model
def save(self, path):
......
......@@ -12,6 +12,7 @@ import sys
from tqdm import tqdm
import time
import wandb
from lm_arch.gpt2 import GPT2Model
model_config = {
"n_layer": 12,
......@@ -26,22 +27,23 @@ model_config = {
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2train",
"run_name": "owt2-125m-fp32",
"lr": 6e-4,
"end_lr": 6e-4,
"warmup_steps": 50,
"bs": 8,
"gas": 32,
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2fp16amp2",
"run_name": "owt2-125m-fp16AMP-1024ctx-120bs-1e-4lr",
"lr": 1e-4,
"end_lr": 1e-4,
"warmup_steps": 100,
"bs": 12,
"gas": 10,
"seed": 69,
"save_every": 50,
"save_every": 500,
"amp": True,
}
bs = train_config["bs"]
gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
model = GPTModel.neox_init(model_config).cuda().float()
model = GPT2Model.gpt2_init(model_config).cuda().float()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
......@@ -61,26 +63,45 @@ for input_ids, labels in t:
labels = labels.cuda()
loss = 0
for x in range(train_config["gas"]):
logits = model.get_logits(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=None, act_ck=True)
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :]
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
scaler.scale(gas_loss).backward()
if train_config["amp"]:
with torch.cuda.amp.autocast():
logits = model(input_ids[x*bs:(x+1)*bs, :1024].cuda(), hypernetwork=None, act_ck=False)
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :1024].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
else:
logits = model(input_ids[x*bs:(x+1)*bs, :1024].cuda(), hypernetwork=None, act_ck=False)
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :1024].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
if train_config["amp"]:
scaler.scale(gas_loss).backward()
else:
gas_loss.backward()
loss += gas_loss.item()
loss = loss / gas
scaler.unscale_(opt.optimizer)
if train_config["amp"]:
scaler.unscale_(opt.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
scaler.step(opt.optimizer)
scaler.update()
if train_config["amp"]:
opt.step(scaler=scaler)
else:
opt.step()
if train_config["amp"]:
scaler.update()
#opt.step()
opt.zero_grad()
sec_per_step = (time.perf_counter() - timex) / (bs*gas)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = step_per_sec * 2048
tokens_per_sec = step_per_sec * 1024
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
curr_step += 1
if curr_step % train_config["save_every"] == 0:
model.save(train_config["save_path"] + f"/{curr_step}")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment