Commit db8c9ae8 authored by novelailab

fp16 AMP maybe works

parent 347ef912
@@ -25,6 +25,7 @@ env1.sh('pip install tqdm')
 env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
 env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
 env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
+env1.sh('pip3 install dotmap')
 with always_rerun():
     print(f"Running {sys.argv[1]}")
     path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.cuda.amp as amp
 import torch.optim as optim
+from pathlib import Path
 from lm_train import optimizer, utils
 from torch.utils import data
 from main import *
@@ -26,18 +27,21 @@ model_config = {
 train_config = {
     "data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
     "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2train",
-    "run_name": "owt2-125m",
+    "run_name": "owt2-125m-fp32",
     "lr": 6e-4,
     "end_lr": 6e-4,
     "warmup_steps": 50,
-    "bs": 16,
-    "gas": 16,
+    "bs": 8,
+    "gas": 32,
     "seed": 69,
     "save_every": 50,
 }
 bs = train_config["bs"]
 gas = train_config["gas"]
-model = GPTModel.neox_init(model_config).cuda().bfloat16()
+Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
+model = GPTModel.neox_init(model_config).cuda().float()
 opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
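Note on the hunk above: the batch settings change from bs 16 / gas 16 to bs 8 / gas 32, which keeps the effective batch per optimizer step the same (16 × 16 = 8 × 32 = 256 sequences) while halving the per-microbatch memory footprint, presumably to make room for the AMP bookkeeping. The model is now initialized in fp32 instead of bf16, which matches how torch.cuda.amp is meant to be used: fp32 master weights, with lower-precision compute handled by autocast rather than by casting the parameters themselves.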
@@ -48,6 +52,9 @@ wandb.init(project="basedformer-tests", name=train_config["run_name"], config={*
 t = tqdm(train_loader)
 curr_step = 0
+scaler = torch.cuda.amp.GradScaler()
 for input_ids, labels in t:
     timex = time.perf_counter()
     input_ids = input_ids.cuda()
@@ -59,11 +66,15 @@ for input_ids, labels in t:
         gas_labels = labels[x*bs:(x+1)*bs, :]
         gas_labels = gas_labels.view(-1)
         gas_loss = F.cross_entropy(logits, gas_labels)
-        gas_loss.backward()
+        scaler.scale(gas_loss).backward()
         loss += gas_loss.item()
     loss = loss / gas
-    opt.step()
+    scaler.unscale_(opt.optimizer)
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
+    scaler.step(opt.optimizer)
+    scaler.update()
+    #opt.step()
     opt.zero_grad()
     sec_per_step = (time.perf_counter() - timex) / (bs*gas)
     step_per_sec = (1. / sec_per_step)
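For reference, the hunk above follows the standard torch.cuda.amp recipe for fp16 training with gradient accumulation: scale each microbatch loss before backward, unscale once per optimizer step so gradient clipping sees true gradient magnitudes, then step through the scaler (which skips the update if inf/NaN gradients are detected) and update the loss scale. A minimal, self-contained sketch of that recipe follows; the model, optimizer, and data are placeholders (not the repo's GPTModel/BasedOptimizer), and the autocast() context around the forward pass is an assumption, since the forward call is not visible in the hunks shown.

    # Sketch only: placeholder model/optimizer/data, not the repo's code.
    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(512, 512).cuda().float()      # fp32 master weights
    opt = torch.optim.AdamW(model.parameters(), lr=6e-4)
    scaler = torch.cuda.amp.GradScaler()
    bs, gas = 8, 32                                        # microbatch size, accumulation steps

    for step in range(100):
        for _ in range(gas):
            x = torch.randn(bs, 512, device="cuda")
            y = torch.randint(0, 512, (bs,), device="cuda")
            with torch.cuda.amp.autocast():                # fp16 forward + loss (assumed; not shown in the diff)
                logits = model(x)
                loss = F.cross_entropy(logits, y)
            scaler.scale(loss).backward()                  # backward on the scaled loss
        scaler.unscale_(opt)                               # restore true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(opt)                                   # skips the step if grads contain inf/NaN
        scaler.update()                                    # adjust the loss scale
        opt.zero_grad()

The order matters: calling unscale_ before clip_grad_norm_ means the max-norm of 1 is applied to the real gradients; clipping the still-scaled gradients would make the threshold effectively far smaller once the scale is divided out.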
@@ -72,5 +83,5 @@ for input_ids, labels in t:
     wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
     curr_step += 1
     if curr_step % train_config["save_every"] == 0:
-        model.save(train_config["save_path"])
+        model.save(train_config["save_path"] + f"/{curr_step}")
         print(f"Saved model at step {curr_step}")