Commit 7e65fc56 authored by novelailab

bighyper

parent b0155c91
import math  # math.sqrt is used in the init scaling below
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from lm_train import optimizer, utils
from torch.utils import data
from main import *  # presumably provides gelu_new, GPTLayer, GPTModel and load_gpt_j used below
import yaml
import sys
from tqdm import tqdm
import time
import wandb
from lm_arch.gpt2 import GPT2Model
import numpy as np
from transformers import AutoTokenizer
def _init_weights(module):
    """Initialize the weights."""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
class HyperNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        # Bottleneck MLP: project hidden states down to hidden_dim // 4 and back up.
        self.linear = nn.Linear(embed_dim, embed_dim//4, bias=True)
        self.linear2 = nn.Linear(embed_dim//4, embed_dim, bias=True)
        self.activation = gelu_new
        #self.linear.weight.data.normal_(mean=0.0, std=0.02)
        for module in self.modules():
            _init_weights(module)
        # Scale the output projection's init by 1/sqrt(2 * n_layer), mirroring
        # GPT-2-style residual-branch initialization.
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
        #state = self.state_dict()
        #for k in state:
        #    state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
        #self.load_state_dict(state)

    def forward(self, x):
        x = self.linear(x.float())
        x = self.activation(x)
        x = self.linear2(x)
        # SiLU-style gate on the output, then cast back to the frozen model's bfloat16.
        x = x.mul(torch.sigmoid(x))
        return x.bfloat16()
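# Illustrative shape check (not executed by this script): the adapter maps
# (batch, seq, hidden_dim) -> (batch, seq, hidden_dim), e.g.
#   HyperNetwork({"hidden_dim": 4096, "n_layer": 28})(torch.zeros(1, 2049, 4096))
# returns a bfloat16 tensor of the same shape.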
model_config = {
    "n_layer": 12,
    "n_head": 12,
    "hidden_dim": 768,
    "vocab_dim": 50400,
    "eps": 1e-5,
    "activation": gelu_new,
    "Layer": GPTLayer
}
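# NOTE: the GPT-J 6B config below immediately overrides the 12-layer config above.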
model_config = {
    "n_layer": 28,
    "n_head": 16,
    "hidden_dim": 4096,
    "vocab_dim": 50400,
    "eps": 1e-5,
    "activation": gelu_new,
    "Layer": GPTLayer
}
# we need 250 batch size to train the small GPT.
train_config = {
    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
    #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
    #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/fixedj",
    "run_name": "bighyper-gpt-j-enwik9-6b-postln-bf16-1e-4",
    "lr": 1e-4,
    "end_lr": 1e-4,
    "warmup_steps": 50,
    "bs": 1,
    "gas": 16,
    "seed": 69,
    "save_every": 500,
    "amp": False,
    "loss_scale": False,
}
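# Effective batch per optimizer step: bs * gas = 16 sequences of 2049 tokens each,
# accumulated over "gas" micro-batches of "bs" sequences.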
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = load_gpt_j().cuda().bfloat16()
for param in model.parameters():
    param.requires_grad = False
for name, p in model.named_parameters():
    if ("ln" in name or "vocab_embed" in name):
        p.requires_grad = True
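# Note: only the hypernetwork's parameters are passed to the optimizer below, so the
# LayerNorm / vocab_embed weights unfrozen here receive gradients but are not updated
# by opt.step().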
hypernetwork = HyperNetwork(model_config).cuda().float()
for param in hypernetwork.parameters():
    param.requires_grad = True
opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
train_dataset = utils.FbDataset(2049, train_config["data_path"])
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
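# The DataLoader yields macro-batches of bs * gas sequences; the training loop below
# slices each macro-batch into "gas" micro-batches of "bs" sequences for gradient accumulation.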
wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
t = tqdm(train_loader)
curr_step = 0
scaler = torch.cuda.amp.GradScaler()
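# GradScaler is constructed unconditionally but only engaged when train_config["loss_scale"]
# is True; with amp and loss_scale both False, the loop trains the bf16 model without loss scaling.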
for input_ids, labels in t:
    timex = time.perf_counter()
    input_ids = input_ids.cuda()
    labels = labels.cuda()
    loss = 0
    for x in range(train_config["gas"]):
        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=False)
            #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
            gas_labels = gas_labels.view(-1)
            gas_loss = F.cross_entropy(logits, gas_labels)

        if train_config["loss_scale"]:
            scaler.scale(gas_loss).backward()
        else:
            gas_loss.backward()

        loss += gas_loss.item()

    loss = loss / gas
    if train_config["loss_scale"]:
        scaler.unscale_(opt.optimizer)
    torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
    if train_config["loss_scale"]:
        opt.step(scaler=scaler)
    else:
        opt.step()
    if train_config["loss_scale"]:
        scaler.update()
    #opt.step()
    opt.zero_grad()
    sec_per_step = (time.perf_counter() - timex) / (bs*gas)
    step_per_sec = (1. / sec_per_step)
    tokens_per_sec = step_per_sec * 1024
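    # Throughput note: sec_per_step is seconds per *sequence* (elapsed time / (bs * gas)),
    # and tokens_per_sec assumes 1024 tokens per sequence even though the dataset uses
    # 2049-token contexts, so the reported tokens/s likely understates actual throughput.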
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
    curr_step += 1
    if curr_step % train_config["save_every"] == 0:
        # Checkpoint saving is disabled in this script; only the progress message is printed.
        #model.save(train_config["save_path"] + f"/{curr_step}")
        print(f"Saved model at step {curr_step}")
@@ -14,25 +14,6 @@ import time
import wandb
from lm_arch.gpt2 import GPT2Model
import numpy as np
from transformers import AutoTokenizer
class HyperNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
        self.linear.weight.data.normal_(mean=0.0, std=0.02)
        for param in self.linear.parameters():
            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
        #state = self.state_dict()
        #for k in state:
        #    state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
        #self.load_state_dict(state)
    def forward(self, hidden_states):
        hidden_states = self.linear(hidden_states.float())
        hidden_states = hidden_states.mul(torch.sigmoid(hidden_states))
        return hidden_states.bfloat16()
model_config = {
@@ -45,16 +26,6 @@ model_config = {
"Layer": GPTLayer
}
model_config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5,
"activation": gelu_new,
"Layer": GPTLayer
}
# we need 250 batch size to train the small GPT.
train_config = {
    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
@@ -62,15 +33,15 @@ train_config = {
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/fixedj",
"run_name": "gpt-j-owt2-6b-preattn",
"lr": 5e-4,
"end_lr": 5e-4,
"lr": 1e-4,
"end_lr": 1e-4,
"warmup_steps": 50,
"bs": 1,
"gas": 16,
"bs": 12,
"gas": 10,
"seed": 69,
"save_every": 500,
"amp": False,
"loss_scale": False,
"amp": True,
"loss_scale": True,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
@@ -78,20 +49,8 @@ gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = load_gpt_j().cuda().bfloat16()
for param in model.parameters():
    param.requires_grad = False
for name, p in model.named_parameters():
    if ("ln" in name or "vocab_embed" in name):
        p.requires_grad = True
hypernetwork = HyperNetwork(model_config).cuda().float()
for param in hypernetwork.parameters():
    param.requires_grad = True
opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
model = GPTModel.gpt2_init(model_config).cuda().float()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
@@ -111,8 +70,7 @@ for input_ids, labels in t:
    loss = 0
    for x in range(train_config["gas"]):
        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=False)
            #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=None, act_ck=False)
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
            gas_labels = gas_labels.view(-1)
@@ -128,14 +86,16 @@ for input_ids, labels in t:
    loss = loss / gas
    if train_config["loss_scale"]:
        scaler.unscale_(opt.optimizer)
    torch.nn.utils.clip_grad_norm_(hypernetwork.parameters(), 1)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    if train_config["loss_scale"]:
        opt.step(scaler=scaler)
    else:
        opt.step()
    if train_config["loss_scale"]:
        scaler.update()
    #opt.step()
    opt.zero_grad()
    sec_per_step = (time.perf_counter() - timex) / (bs*gas)
    step_per_sec = (1. / sec_per_step)
@@ -144,5 +104,5 @@ for input_ids, labels in t:
    wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
    curr_step += 1
    if curr_step % train_config["save_every"] == 0:
        #model.save(train_config["save_path"] + f"/{curr_step}")
        model.save(train_config["save_path"] + f"/{curr_step}")
        print(f"Saved model at step {curr_step}")