Commit fb25b47c authored by novelailab

set seed, everything works

parent 4f87dce5
@@ -13,6 +13,7 @@ from tqdm import tqdm
 import time
 import wandb
 from lm_arch.gpt2 import GPT2Model
+import numpy as np
 model_config = {
     "n_layer": 12,
@@ -38,12 +39,13 @@ train_config = {
     "save_every": 500,
     "amp": True,
 }
+torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
 gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
-model = GPT2Model.gpt2_init(model_config).cuda().float()
+model = GPTModel.gpt2_init(model_config).cuda().float()
 opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
@@ -65,6 +67,8 @@ for input_ids, labels in t:
     for x in range(train_config["gas"]):
         if train_config["amp"]:
             with torch.cuda.amp.autocast():
+                #with torch.jit.fuser("fuser2"):
+                #    module = torch.jit.trace(model, torch.randint(0, 50256, (12, 1024)).long().cuda())
                 logits = model(input_ids[x*bs:(x+1)*bs, :1024].cuda(), hypernetwork=None, act_ck=False)
                 logits = logits.view(-1, logits.shape[-1])
                 gas_labels = labels[x*bs:(x+1)*bs, :1024].contiguous()
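
A note on what the added seed line covers: torch.manual_seed seeds PyTorch's CPU generator (and, in recent versions, the CUDA generators), but any shuffling or sampling that goes through NumPy or Python's random module keeps its own state. The snippet below is a minimal sketch of seeding all three; seed_everything is a hypothetical helper, not code from this repository, and the literal seed is just an example.

```python
# Hypothetical seed_everything helper (not in this commit); sketches what a
# fully reproducible setup would seed in addition to torch.manual_seed.
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    random.seed(seed)                 # Python's built-in RNG
    np.random.seed(seed)              # NumPy RNG (data shuffling/sampling)
    torch.manual_seed(seed)           # CPU generator; also seeds CUDA generators
    torch.cuda.manual_seed_all(seed)  # explicit per-device CUDA seeding


seed_everything(42)  # the script would pass train_config["seed"] instead
```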
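
The last hunk wraps the forward pass in torch.cuda.amp.autocast and slices each micro-batch out of the accumulated batch (bs rows per step, gas steps), but the diff is cut off before the loss, backward pass, and optimizer step. The sketch below shows the standard autocast + GradScaler accumulation pattern under the assumption that the truncated part follows that recipe; the tiny model, optimizer, and random batch are stand-ins, not the repository's GPT2Model or BasedOptimizer.

```python
# Sketch of the usual torch.cuda.amp + gradient-accumulation recipe; assumes
# the truncated part of the script follows this standard pattern. The model,
# optimizer, and batch below are stand-ins for the objects in the real script.
import torch
import torch.nn as nn
import torch.nn.functional as F

bs, gas, vocab, ctx = 2, 4, 50257, 1024
model = nn.Sequential(nn.Embedding(vocab, 64), nn.Linear(64, vocab)).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
input_ids = torch.randint(0, vocab, (bs * gas, ctx))
labels = input_ids.clone()

scaler = torch.cuda.amp.GradScaler()

for x in range(gas):
    ids = input_ids[x * bs:(x + 1) * bs, :ctx].cuda()
    gas_labels = labels[x * bs:(x + 1) * bs, :ctx].contiguous().cuda()

    with torch.cuda.amp.autocast():
        logits = model(ids)                          # forward in mixed precision
        logits = logits.view(-1, logits.shape[-1])
        loss = F.cross_entropy(logits, gas_labels.view(-1)) / gas  # average over micro-batches

    scaler.scale(loss).backward()                    # scaled backward; grads accumulate

scaler.step(opt)       # unscales grads, skips the step if any are inf/nan
scaler.update()
opt.zero_grad()
```

Dividing the loss by gas keeps the accumulated gradient equivalent to one large-batch step, and scaler.step applies the update only after all gas micro-batches have contributed.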