Commit 5ff36559 authored by novelailab's avatar novelailab

might have broken stuff, will clean shit up later

parent 2a8307f2
...@@ -5,6 +5,8 @@ from time import perf_counter, perf_counter_ns ...@@ -5,6 +5,8 @@ from time import perf_counter, perf_counter_ns
import numpy as np import numpy as np
from tqdm import tqdm from tqdm import tqdm
from contextlib import contextmanager from contextlib import contextmanager
from lm_arch.hypernet import *
import sys
#replicating timeit magic function of ipython #replicating timeit magic function of ipython
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True): def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
precision = 'ns' precision = 'ns'
...@@ -64,13 +66,27 @@ def test_thing(graph, input): ...@@ -64,13 +66,27 @@ def test_thing(graph, input):
graph.replay() graph.replay()
torch.cuda.synchronize() torch.cuda.synchronize()
model_config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTLayer
}
with torch.no_grad(): with torch.no_grad():
model = init_6b().cuda().half() model = init_6b().cuda().bfloat16()
shape = (1, 1) shape = (1, 2048)
hypernet = HyperNetworkSingle(model_config).cuda()
x = torch.zeros(shape).cuda().long() x = torch.zeros(shape).cuda().long()
print(shape) print(shape)
print("PyTorch Eager") print("PyTorch Eager")
timeit(r=1, n=100, func=lambda: model(x), do_tqdm=False, first=False) timeit(r=1, n=100, func=lambda: model(x, hypernetwork=None), do_tqdm=False, first=False)
print("PyTorch Eager + Hypernet")
timeit(r=1, n=100, func=lambda: model(x, hypernetwork=hypernet), do_tqdm=False, first=False)
sys.exit(0)
print("PyTorch CUDAGraph+JIT+NVFuser") print("PyTorch CUDAGraph+JIT+NVFuser")
with torch.jit.fuser("fuser2"): with torch.jit.fuser("fuser2"):
module = torch.jit.trace(model, torch.zeros(shape).long().cuda()) module = torch.jit.trace(model, torch.zeros(shape).long().cuda())
......
...@@ -146,7 +146,7 @@ class HyperNetworkSingle(nn.Module): ...@@ -146,7 +146,7 @@ class HyperNetworkSingle(nn.Module):
#x = shift_tokens(x, self.num_shifts) #x = shift_tokens(x, self.num_shifts)
x = self.linear(x) x = self.linear(x)
x = x.mul(torch.sigmoid(x)) x = x.mul(torch.sigmoid(x))
return x.bfloat16() return x.half()
model_config = { model_config = {
...@@ -174,15 +174,16 @@ train_config = { ...@@ -174,15 +174,16 @@ train_config = {
"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map", "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map", #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map", #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/fixedj", "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs16",
"run_name": "gpt-j-enwik9-6b-postln-bf16-5e-4", "do_save": False,
"lr": 5e-4, "run_name": "gpt-j-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
"end_lr": 5e-4, "lr": 2e-4,
"end_lr": 2e-4,
"warmup_steps": 50, "warmup_steps": 50,
"bs": 1, "bs": 1,
"gas": 16, "gas": 4,
"seed": 69, "seed": 69,
"save_every": 500, "save_every": 100,
"amp": False, "amp": False,
"loss_scale": False, "loss_scale": False,
} }
...@@ -193,7 +194,7 @@ gas = train_config["gas"] ...@@ -193,7 +194,7 @@ gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True) Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float() #model = GPTModel.gpt2_init(model_config).cuda().float()
model = load_gpt_j().cuda().bfloat16() model = load_gpt_j().cuda().half()
for param in model.parameters(): for param in model.parameters():
param.requires_grad = False param.requires_grad = False
...@@ -202,7 +203,8 @@ for name, p in model.named_parameters(): ...@@ -202,7 +203,8 @@ for name, p in model.named_parameters():
p.requires_grad = True p.requires_grad = True
hypernetwork = HyperNetworkSingle(model_config).cuda().float() hypernetwork = HyperNetworkSingle(model_config).cuda().float()
#hypernetwork = nn.ModuleList([HyperNetwork(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)]) #hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(2)])
for param in hypernetwork.parameters(): for param in hypernetwork.parameters():
param.requires_grad = True param.requires_grad = True
...@@ -252,12 +254,14 @@ for input_ids, labels in t: ...@@ -252,12 +254,14 @@ for input_ids, labels in t:
scaler.update() scaler.update()
#opt.step() #opt.step()
opt.zero_grad() opt.zero_grad()
sec_per_step = (time.perf_counter() - timex) / (bs*gas) sec_per_step = (time.perf_counter() - timex)
step_per_sec = (1. / sec_per_step) step_per_sec = (1. / sec_per_step)
tokens_per_sec = step_per_sec * 2048 tokens_per_sec = (step_per_sec * 2048) * bs * gas
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}") t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()}) wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
curr_step += 1 curr_step += 1
if curr_step % train_config["save_every"] == 0: if train_config["do_save"]:
#model.save(train_config["save_path"] + f"/{curr_step}") if curr_step % train_config["save_every"] == 0 or curr_step == 1:
print(f"Saved model at step {curr_step}") torch.save(hypernetwork.state_dict(), train_config["save_path"] + f"/{curr_step}.hyper")
#model.save(train_config["save_path"] + f"/{curr_step}")
print(f"Saved model at step {curr_step}")
import torch.nn as nn
import torch
import torch.nn.functional as F
import math
import torch.utils.checkpoint as ck
def gelu_new(x):
    """Tanh-based approximation of the GELU activation, applied element-wise.

    Evaluates ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``
    on the tensor ``x`` and returns a tensor of the same shape.
    """
    cubic_term = 0.044715 * torch.pow(x, 3.0)
    tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
    return 0.5 * x * (1.0 + torch.tanh(tanh_arg))
def _init_weights(module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
class HyperNetworkGRU(nn.Module):
    """Hypernetwork with a GRU bottleneck at 1/8 of the hidden width.

    Pipeline: ``linear1`` (hidden_dim -> hidden_dim // 8) -> single-layer
    unidirectional GRU -> LayerNorm -> ``linear2`` (hidden_dim // 8 ->
    hidden_dim) -> ``gelu_new``.

    Args:
        config: mapping that must provide ``"hidden_dim"`` (feature width) and
            ``"n_layer"`` (used only to scale the initialization std).
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        self.linear1 = nn.Linear(embed_dim, embed_dim // 8)
        self.gru = nn.GRU(embed_dim // 8, embed_dim // 8, num_layers=1, bidirectional=False, batch_first=True)
        self.linear2 = nn.Linear(embed_dim // 8, embed_dim)
        self.ln_1 = nn.LayerNorm(embed_dim // 8, eps=1e-5)
        self.activation = gelu_new

        # Baseline init for every submodule first ...
        for module in self.modules():
            _init_weights(module)

        # ... then overwrite the output projection and the GRU with a smaller
        # std scaled by 1 / sqrt(2 * n_layer). Note this also re-draws the
        # biases from the normal distribution instead of leaving them at zero.
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0, std=scaled_std)
        for param in self.gru.parameters():
            param.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Run the hypernetwork on ``x`` and return a bfloat16 tensor.

        The computation is done in float32 (``x`` is upcast first).
        ``batch_first=True`` on the GRU means the input is laid out as
        (batch, seq, feature) -- presumably feature == hidden_dim; confirm
        against the caller.
        """
        x = x.float()
        x = self.linear1(x)
        x = self.gru(x)[0]  # nn.GRU returns (output, h_n); keep the per-step output
        x = self.ln_1(x)
        x = self.linear2(x)
        # `ck` is the torch.utils.checkpoint *module*; calling it directly
        # raises TypeError, so the checkpoint() function must be used.
        x = ck.checkpoint(self.activation, x)
        return x.bfloat16()
class HyperNetwork(nn.Module):
    """Two-layer bottleneck hypernetwork (hidden_dim -> hidden_dim // 4 -> hidden_dim).

    Pipeline: ``linear`` -> ``gelu_new`` (activation-checkpointed) ->
    ``linear2`` -> ``x * sigmoid(x)`` gate -> cast to bfloat16.

    Args:
        config: mapping that must provide ``"hidden_dim"`` (feature width) and
            ``"n_layer"`` (used only to scale the initialization std).
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        self.linear = nn.Linear(embed_dim, embed_dim // 4, bias=True)
        self.linear2 = nn.Linear(embed_dim // 4, embed_dim, bias=True)
        self.activation = gelu_new
        # Only `math` is imported in this module, so ceil/log2 must be
        # qualified; the bare names raised NameError at construction time.
        # Currently only consumed by the disabled shift_tokens call in forward.
        self.num_shifts = math.ceil(math.log2(2048)) - 1

        # Baseline init for every submodule, then overwrite the output
        # projection (weight and bias) with std scaled by 1 / sqrt(2 * n_layer).
        for module in self.modules():
            _init_weights(module)
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Run the hypernetwork on ``x`` and return a bfloat16 tensor.

        The computation is done in float32 (``x`` is upcast first).
        """
        x = x.float()
        #x = shift_tokens(x, self.num_shifts)
        x = self.linear(x)
        # `ck` is the torch.utils.checkpoint *module*; calling it directly
        # raises TypeError, so the checkpoint() function must be used.
        x = ck.checkpoint(self.activation, x)
        x = self.linear2(x)
        x = x.mul(torch.sigmoid(x))  # self-gating: x * sigmoid(x)
        return x.bfloat16()
class HyperNetworkSingle(nn.Module):
    """Single square linear layer followed by an ``x * sigmoid(x)`` gate.

    Args:
        config: mapping that must provide ``"hidden_dim"`` (input and output
            width of the linear layer) and ``"n_layer"`` (used only to scale
            the initialization std).
    """

    def __init__(self, config):
        super().__init__()
        width = config["hidden_dim"]
        self.linear = nn.Linear(width, width, bias=True)
        self.activation = gelu_new

        # Baseline init over every submodule first.
        for submodule in self.modules():
            _init_weights(submodule)

        # Then re-draw both weight and bias of the projection with a std
        # shrunk by 1 / sqrt(2 * n_layer).
        scaled_std = 0.02 / math.sqrt(2 * config["n_layer"])
        for tensor in self.linear.parameters():
            tensor.data.normal_(mean=0.0, std=scaled_std)

    def forward(self, x):
        """Project ``x`` in float32, gate it with its own sigmoid, return bfloat16."""
        projected = self.linear(x.float())
        gated = projected.mul(torch.sigmoid(projected))
        return gated.bfloat16()
\ No newline at end of file
...@@ -43,6 +43,13 @@ class BasedOptimizer: ...@@ -43,6 +43,13 @@ class BasedOptimizer:
elif optimizer == "adamw8bit": elif optimizer == "adamw8bit":
import bitsandbytes as bnb import bitsandbytes as bnb
self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps) self.optimizer = bnb.optim.Adam8bit(parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
elif optimizer == "adafactor":
try:
from transformers.optimization import Adafactor
except ImportError:
raise ImportError("Please install transformers for Adafactor")
self.optimizer = Adafactor(params=parameters)
def step(self, scaler=None): def step(self, scaler=None):
if scaler: if scaler:
......
...@@ -227,7 +227,7 @@ class GPTLayer(nn.Module): ...@@ -227,7 +227,7 @@ class GPTLayer(nn.Module):
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype) self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False): def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
residual = x residual = x
if act_ck: if act_ck:
...@@ -238,17 +238,29 @@ class GPTLayer(nn.Module): ...@@ -238,17 +238,29 @@ class GPTLayer(nn.Module):
x = self.ln_preattn(x) x = self.ln_preattn(x)
attn_out = self.attn(x) attn_out = self.attn(x)
if diff_hypernets and hypernetwork: if hypernetwork:
if layer_id % 1 == 0: if diff_hypernets:
hyper_out = hypernetwork[(layer_id // 5) - 1](x) if interleaving_layers and layer_id % every_n == 0:
else: if self.tick:
hyper_out = hypernetwork(x) hyper_out = hypernetwork[0](x)
self.tick = False
else:
hyper_out = hypernetwork[1](x)
self.tick = True
elif layer_id % every_n == 0:
hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
else:
if layer_id % every_n == 0:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck) ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here. #order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match. #x = residual + attn_out + ff_out -> doesn't match.
if hypernetwork and not diff_hypernets or layer_id % 5 == 0: if hypernetwork and layer_id % every_n == 0:
#if hypernetwork and layer_id % 5 == 0:
x = x + hyper_out x = x + hyper_out
return x return x
......
...@@ -8,13 +8,13 @@ bash = False ...@@ -8,13 +8,13 @@ bash = False
config_obj = KubeConfig() config_obj = KubeConfig()
config_obj.set_name(name) config_obj.set_name(name)
config_obj.set_gpu(gpu_name=GPU.A100_PCIE_40GB, amount=1) config_obj.set_gpu(gpu_name=GPU.A40, amount=1)
config_obj.set_ram(16) config_obj.set_ram(16)
config_obj.set_cpu(4) config_obj.set_cpu(4)
config_obj.dry_run(dry) config_obj.dry_run(dry)
config_obj.print_information() config_obj.print_information()
config_obj.create_deployment(overwrite=True) #config_obj.create_deployment(overwrite=True)
config_obj.create_service(overwrite=True) #config_obj.create_service(overwrite=True)
remote = config_obj.get_pyfra_remote() remote = config_obj.get_pyfra_remote()
env1 = remote.env('noname', python_version=None) env1 = remote.env('noname', python_version=None)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment