Commit 9f01112b authored by novelailab

fix attn and equal with HF output

parent 8382affa
from main import *
import time
from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    GPTNeoForCausalLM,
    AutoConfig,
)
# replicating the %timeit magic function of IPython
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    precision = 'ns'
    r_arr = np.empty([2, r])  # [0] = mean, [1] = std
    if function:
        func.__name__ = function.__name__

    for i in tqdm(range(r)) if do_tqdm else range(r):
        n_arr = np.empty(n)
        for k in range(n):
            start = perf_counter_ns()
            func()
            n_arr[k] = perf_counter_ns() - start

        if not first:
            # drop the first (warm-up) measurement from n_arr
            n_arr = np.delete(n_arr, 0)

        r_arr[0, i] = np.mean(n_arr)
        r_arr[1, i] = np.std(n_arr)

    best = r_arr[:, np.argmin(r_arr[0])]  # [0] = mean, [1] = std

    # pick a display unit from the magnitude of the best mean (measured in ns)
    if best[0] < 1e3:
        precision = 'ns'
    elif best[0] >= 1e9:
        best[0] = best[0] * 1e-9
        best[1] = best[1] * 1e-9
        precision = 's'
    elif best[0] >= 1e6:
        best[0] = best[0] * 1e-6
        best[1] = best[1] * 1e-6
        precision = 'ms'
    elif best[0] >= 1e3:
        best[0] = best[0] * 1e-3
        best[1] = best[1] * 1e-3
        precision = 'μs'

    if not quiet:
        if precision == 'ns':
            print(f"{func.__name__}: {best[0]:.0f}{precision} ± {best[1]:.0f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
        elif precision == 'μs':
            print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
        elif precision == 'ms':
            print(f"{func.__name__}: {best[0]:.2f}{precision} ± {best[1]:.2f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
        elif precision == 's':
            print(f"{func.__name__}: {best[0]:.4f}{precision} ± {best[1]:.4f}{precision} per loop (mean ± std. dev. of {r} runs, {n} loops each)")
with torch.no_grad():
    based_model = load_gpt_j().cuda().half().eval()
    print("Loaded based model")
    hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
    print("Loaded hf model")

    x = torch.randint(0, 50256, (1, 2048)).cuda().long()

    assert torch.allclose(hf_model.transformer.wte(x), based_model.vocab_embed(x))
    hidden = hf_model.transformer.wte(x)

    # compare every transformer block module-by-module against the HF reference
    for layer in range(28):
        assert torch.allclose(hf_model.transformer.h[layer].ln_1(hidden), based_model.layers[layer].ln_preattn(hidden))
        hidden = hf_model.transformer.h[layer].ln_1(hidden)

        assert torch.allclose(hf_model.transformer.h[layer].mlp(hidden), based_model.layers[layer].ff(hidden))
        hidden = hf_model.transformer.h[layer].mlp(hidden)

        assert torch.allclose(hf_model.transformer.h[layer].attn(hidden)[0], based_model.layers[layer].attn(hidden))
        hidden = hf_model.transformer.h[layer].attn(hidden)[0]

    assert torch.allclose(hf_model.transformer.ln_f(hidden), based_model.ln_final(hidden))
    hidden = hf_model.transformer.ln_f(hidden)
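The module-by-module checks above stop at ln_f. A natural end-to-end sanity check, continuing inside the same no_grad block, would compare the full forward passes; this is only a sketch, not part of the commit, and it assumes based_model(x) returns logits directly while the HF model carries them on .logits. With fp16 activations an explicit tolerance is usually needed:

    hf_logits = hf_model(x).logits
    based_logits = based_model(x)
    print((hf_logits - based_logits).abs().max())
    print(torch.allclose(hf_logits, based_logits, rtol=1e-3, atol=1e-3))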
@@ -151,17 +151,17 @@ class SelfAttention(nn.Module):
         bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
             1, 1, max_positions, max_positions).bool()
         self.head_dim = hidden_dim // n_head
-        self.rotary_dim = self.head_dim // 4
         self.hidden_dim = hidden_dim
         self.n_head = n_head
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
         self.register_buffer("bias", bias)
-        self.register_buffer("masked_bias", torch.tensor(-1e10, requires_grad=False)) #-1e10 is what mtj uses.
+        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
         self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.rotary_dim = self.head_dim
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)
@@ -176,9 +176,22 @@ class SelfAttention(nn.Module):
         value = _split_heads(value, self.n_head, self.head_dim, False)
         offset = 0
+        if self.rotary_dim < self.head_dim:
+            k_rot = key[:, :, :, :self.rotary_dim]
+            k_pass = key[:, :, :, self.rotary_dim:]
-        key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
-        query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
+            q_rot = query[:, :, :, :self.rotary_dim]
+            q_pass = query[:, :, :, self.rotary_dim:]
+            k_rot = apply_rotary_pos_emb(k_rot, (self.sin, self.cos), offset=offset).to(k_rot.dtype)
+            q_rot = apply_rotary_pos_emb(q_rot, (self.sin, self.cos), offset=offset).to(q_rot.dtype)
+            key = torch.cat([k_rot, k_pass], dim=-1)
+            query = torch.cat([q_rot, q_pass], dim=-1)
+        else:
+            key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
+            query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
         key = key.permute(0, 2, 1, 3)
         query = query.permute(0, 2, 1, 3)
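The hunk above relies on fixed_pos_embedding and apply_rotary_pos_emb; their actual implementations live elsewhere in the repo, but they are assumed to follow the standard GPT-J rotary formulation, roughly as in this sketch. Tensors are laid out as [batch, seq, heads, head_dim], matching the pre-permute layout above.

import torch

def fixed_pos_embedding(dim, seq_len):
    # sin/cos tables of shape [seq_len, dim // 2]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    sinusoid = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
    return torch.sin(sinusoid), torch.cos(sinusoid)

def rotate_every_two(x):
    # pair up adjacent feature dims and rotate each pair by 90 degrees
    x1, x2 = x[..., ::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def apply_rotary_pos_emb(x, sincos, offset=0):
    # x: [batch, seq, heads, rotary_dim]; broadcast sin/cos over batch and heads
    sin, cos = (t[None, offset:x.shape[1] + offset, None, :].repeat_interleave(2, 3) for t in sincos)
    return (x * cos) + (rotate_every_two(x) * sin)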
@@ -7,13 +7,13 @@ dry = False
 config_obj = KubeConfig()
 config_obj.set_name(name)
-config_obj.set_gpu(gpu_name=GPU.RTX_A5000, amount=1)
+config_obj.set_gpu(gpu_name=GPU.RTX_A6000, amount=1)
 config_obj.set_ram(16)
 config_obj.set_cpu(4)
 config_obj.dry_run(dry)
 config_obj.print_information()
-#config_obj.create_deployment(overwrite=True)
-#config_obj.create_service(overwrite=True)
+config_obj.create_deployment(overwrite=True)
+config_obj.create_service(overwrite=True)
 remote = config_obj.get_pyfra_remote()
 env1 = remote.env('noname', python_version=None)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim

# Based Optimizer
class BasedOptimizer:
    def __init__(self, model, config, optimizer):
        self.min_lr = config["min_lr"] if "min_lr" in config else 1e-06
        self.warmup_end = config["lr"] if "lr" in config else 5e-06
        self.warmup_init = config["warmup_init"] if "warmup_init" in config else 0
        self.warmup_steps = config["warmup_steps"] if "warmup_steps" in config else 1
        self.total_steps = config["total_steps"] if "total_steps" in config else None
        self.weight_decay = config["weight_decay"] if "weight_decay" in config else 0
        self.start_step = config["start_step"] if "start_step" in config else 0
        self.curr_step = self.start_step
        self.curr_lr = 0

        optim_func = optim.AdamW
        # kept as a list so the step()/zero_grad() loops below work with one or more optimizers
        self.optimizers = [optim_func(model.parameters(), lr=self.warmup_init, weight_decay=self.weight_decay, betas=config["betas"], eps=config["eps"])]

    def get_current_lr(self):
        # cosine decay from warmup_end down to min_lr after the warmup phase, with a
        # linear ramp from warmup_init during warmup; `inter` is assumed to be a
        # linear-interpolation helper defined elsewhere in the codebase.
        cosine_lr = self.min_lr + 0.5 * (self.warmup_end - self.min_lr) * (1 + math.cos(math.pi * min(1.0, max(0, self.curr_step - self.warmup_steps) / (self.total_steps - self.warmup_steps))))
        target_lr = self.warmup_end if self.curr_step < self.warmup_steps else cosine_lr
        return inter(self.warmup_init, target_lr, max(0, self.curr_step - self.start_step) / max(1, self.warmup_steps))
        #return min(self.end_lr * (self.curr_step / self.warmup_steps), self.end_lr)

    def backward(self, loss):
        # assumes an apex-style optimizer that owns the backward pass
        self.optimizers[0].backward(loss, update_master_grads=False)
        #loss.backward()

    def step(self, scaler=None):
        self.curr_lr = self.get_current_lr()
        for optimizer in self.optimizers:
            for paramx in optimizer.param_groups:
                paramx['lr'] = self.curr_lr
            optimizer.update_master_grads()  # apex-style API; plain torch.optim optimizers do not expose this

        if scaler:
            for optimizer in self.optimizers:
                scaler.step(optimizer)
        else:
            for optimizer in self.optimizers:
                optimizer.step()

        self.curr_step += 1

    def zero_grad(self):
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def print_info(self):
        print(f"min_lr: {str(self.min_lr)}")
        print(f"warmup_end: {str(self.warmup_end)}")
        print(f"warmup_init: {str(self.warmup_init)}")
        print(f"warmup_steps: {str(self.warmup_steps)}")
        print(f"start_step: {str(self.start_step)}")
        print(f"total_steps: {str(self.total_steps)}")
        print(f"weight_decay: {str(self.weight_decay)}")
        print(f"step: {str(self.curr_step)}")
        print(f"curr_lr: {str(self.get_current_lr())}")
\ No newline at end of file
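A minimal usage sketch for BasedOptimizer follows (not part of the commit). The config keys mirror the ones read in __init__; inter is a hypothetical stand-in for the linear-interpolation helper the class expects to find in scope, and the tiny nn.Linear is just a placeholder model.

def inter(start, end, t):
    # hypothetical stand-in lerp helper assumed by get_current_lr
    return start + (end - start) * t

config = {"lr": 5e-6, "min_lr": 1e-6, "warmup_init": 0.0, "warmup_steps": 100,
          "total_steps": 10000, "weight_decay": 0.01,
          "betas": (0.9, 0.95), "eps": 1e-8}
opt = BasedOptimizer(nn.Linear(16, 16), config, optimizer="adamw")
opt.print_info()  # prints the schedule hyperparameters and the lr at the current step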