Commit b07251f0 authored by novelailab's avatar novelailab

add todos and some stuff

parent d9df990b
......@@ -24,9 +24,12 @@ model_config = {
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/owt2train",
"run_name": "owt2-125m",
"lr": 6e-4,
"end_lr": 6e-4,
"warmup_steps": 20,
"warmup_steps": 50,
"bs": 16,
"gas": 16,
"seed": 69,
......@@ -35,9 +38,12 @@ bs = train_config["bs"]
gas = train_config["gas"]
model = GPTModel.neox_init(model_config).cuda().bfloat16()
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
train_dataset = utils.FbDataset(2049, "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_2049.map")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel.
train_dataset = utils.FbDataset(2049, train_config["data_path"])
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
wandb.init(project="basedformer-tests", name="sigurd_v5_2049")
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model_config})
t = tqdm(train_loader)
for input_ids, labels in t:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment