Commit db8c9ae8 authored by novelailab

fp16 AMP maybe works

parent 347ef912
@@ -25,6 +25,7 @@ env1.sh('pip install tqdm')
 env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
 env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
 env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
+env1.sh('pip3 install dotmap')
 with always_rerun():
     print(f"Running {sys.argv[1]}")
     path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file
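Note: the only functional change in this file is the new dotmap dependency. dotmap provides DotMap, a dict wrapper with attribute-style access, presumably used for config handling elsewhere in the repo. A minimal usage sketch (illustrative, not part of this commit):

from dotmap import DotMap

cfg = DotMap({"lr": 6e-4, "bs": 8})   # hypothetical config values
assert cfg.lr == cfg["lr"]            # keys are reachable as attributes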
@@ -3,6 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.cuda.amp as amp
 import torch.optim as optim
+from pathlib import Path
 from lm_train import optimizer, utils
 from torch.utils import data
 from main import *
@@ -26,18 +27,21 @@ model_config = {
 train_config = {
     "data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
     "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2train",
-    "run_name": "owt2-125m",
+    "run_name": "owt2-125m-fp32",
     "lr": 6e-4,
     "end_lr": 6e-4,
     "warmup_steps": 50,
-    "bs": 16,
-    "gas": 16,
+    "bs": 8,
+    "gas": 32,
     "seed": 69,
     "save_every": 50,
 }
 bs = train_config["bs"]
 gas = train_config["gas"]
-model = GPTModel.neox_init(model_config).cuda().bfloat16()
+Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
+model = GPTModel.neox_init(model_config).cuda().float()
 opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
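Note: the accumulation change keeps the effective batch size constant, 16 × 16 = 8 × 32 = 256 sequences per optimizer step, while halving the per-forward micro-batch. The model also moves from bfloat16 to fp32 weights, which is the standard setup for the torch.cuda.amp GradScaler added below.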
@@ -48,6 +52,9 @@ wandb.init(project="basedformer-tests", name=train_config["run_name"], config={*
 t = tqdm(train_loader)
 curr_step = 0
+scaler = torch.cuda.amp.GradScaler()
 for input_ids, labels in t:
     timex = time.perf_counter()
     input_ids = input_ids.cuda()
@@ -59,11 +66,15 @@ for input_ids, labels in t:
         gas_labels = labels[x*bs:(x+1)*bs, :]
         gas_labels = gas_labels.view(-1)
         gas_loss = F.cross_entropy(logits, gas_labels)
-        gas_loss.backward()
+        scaler.scale(gas_loss).backward()
         loss += gas_loss.item()
     loss = loss / gas
-    opt.step()
+    scaler.unscale_(opt.optimizer)
+    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
+    scaler.step(opt.optimizer)
+    scaler.update()
+    #opt.step()
     opt.zero_grad()
     sec_per_step = (time.perf_counter() - timex) / (bs*gas)
     step_per_sec = (1. / sec_per_step)
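For context, the scaler calls above follow the standard torch.cuda.amp recipe. A minimal self-contained sketch of that recipe (illustrative only: the autocast-wrapped forward pass is not visible in these hunks, and the model and optimizer here are placeholders, not the repo's GPTModel/BasedOptimizer):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(512, 512).cuda()        # placeholder model
opt = torch.optim.AdamW(model.parameters(), lr=6e-4)
scaler = torch.cuda.amp.GradScaler()

def train_step(x, y):
    opt.zero_grad()
    with torch.cuda.amp.autocast():              # run forward/loss in mixed precision
        loss = F.mse_loss(model(x), y)
    scaler.scale(loss).backward()                # scale loss to avoid fp16 gradient underflow
    scaler.unscale_(opt)                         # unscale grads before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(opt)                             # skips the step if grads contain inf/nan
    scaler.update()                              # adjust the scale factor for the next step
    return loss.item()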
@@ -72,5 +83,5 @@ for input_ids, labels in t:
     wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr})
     curr_step += 1
     if curr_step % train_config["save_every"] == 0:
-        model.save(train_config["save_path"])
+        model.save(train_config["save_path"] + f"/{curr_step}")
         print(f"Saved model at step {curr_step}")