Commit 9240d7ea authored by Wes Brown

Fix wandb reporting; do the first zero-eval *before* training the first step.

parent 60893ad7
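
For context, a minimal sketch of the loop shape this commit moves toward. It is not the repository's code: eval_fn, train_step, and the config values below are hypothetical stand-ins, and the run is opened with mode="disabled" so the sketch runs without a wandb login. It illustrates the two points named in the message: the running token count is logged in the same wandb.log call, keyed by the training step, and the step-zero eval runs once before the loop so the in-loop eval can skip step 0.

import time
import wandb

# Hypothetical stand-ins for pieces not shown in the diff.
train_config = {"project_id": "basedformer", "run_name": "sketch",
                "context_size": 2048, "bs": 8, "gas": 4,
                "eval_every": 2, "steps": 5}
tokens_per_step = train_config["context_size"] * train_config["bs"] * train_config["gas"]

def eval_fn(step):
    """Placeholder eval hook; the real one runs the validation set."""
    print(f"eval at step {step}")

def train_step():
    """Placeholder training step; the real one returns the batch loss."""
    return 0.0

wandb.init(mode="disabled",
           project=train_config["project_id"],
           name=train_config["run_name"],
           config=train_config)

curr_step = 0
eval_fn(curr_step)  # step-zero eval happens *before* any training

for _ in range(train_config["steps"]):  # stand-in for the DataLoader loop
    timex = time.perf_counter()
    loss = train_step()
    sec_per_step = time.perf_counter() - timex
    curr_tokens = tokens_per_step * (curr_step + 1)  # tokens seen after this step
    # One log call per step: the token count rides along under the same step axis.
    wandb.log({"train/loss": loss,
               "train/sec_per_step": sec_per_step,
               "train/tokens": curr_tokens},
              step=curr_step)
    if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
        eval_fn(curr_step)  # step 0 was already evaluated before the loop
    curr_step += 1

The diff below makes the corresponding changes to the real training script and bumps the hypertrainer image tag to the parent commit.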
@@ -95,7 +95,7 @@ spec:
   - name: hypertrainer_image
     value: 'docker.io/gooseai/basedformer'
   - name: hypertrainer_tag
-    value: '3b75904'
+    value: '60893ad'
   templates:
   - name: main
...
@@ -379,6 +379,7 @@ train_loader = torch_data.DataLoader(train_dataset,
 wandb.init(project=train_config["project_id"],
            name=train_config["run_name"],
            config={**train_config, **model.config})
+print("wandb initialized")
 if last_cp:
     curr_step = opt.curr_step
@@ -391,6 +392,8 @@ tokens_per_step = train_config['context_size'] * \
                   train_config['bs'] * \
                   train_config['gas']
+eval_fn(curr_step)
 with tqdm(total=total_steps, initial=curr_step) as t:
     for epoch in range(train_config['epochs']):
         for input_ids, labels in train_loader:
@@ -433,11 +436,12 @@ with tqdm(total=total_steps, initial=curr_step) as t:
             sec_per_step = (time.perf_counter() - timex)
             step_per_sec = (1. / sec_per_step)
             tokens_per_sec = step_per_sec * tokens_per_step
-            curr_tokens = tokens_per_step * curr_step
+            curr_tokens = tokens_per_step * (curr_step + 1)
             t.set_description(f"{step_per_sec:.2f} steps/s, "
                               f"{sec_per_step:.2f}s/step, "
                               f"{tokens_per_sec:.2f}tokens/s, "
-                              f"loss={loss:.4f}")
+                              f"loss={loss:.4f}, "
+                              f"{curr_tokens} tokens processed")
             wandb.log(
                 {
                     "train/epoch": float(curr_step) / float(epoch_steps),
@@ -446,29 +450,18 @@ with tqdm(total=total_steps, initial=curr_step) as t:
                     "train/sec_per_step": sec_per_step,
                     "train/step_per_sec": step_per_sec,
                     "train/lr": opt.curr_lr,
-                    "train/loss_scale": scaler.get_scale()
+                    "train/loss_scale": scaler.get_scale(),
+                    "train/tokens": curr_tokens,
                 },
                 step=curr_step)
-            wandb.log(
-                {
-                    "train_tokens/epoch": float(curr_step) / float(epoch_steps),
-                    "train_tokens/loss": loss,
-                    "train_tokens/tokens_per_sec": tokens_per_sec,
-                    "train_tokens/sec_per_step": sec_per_step,
-                    "train_tokens/step_per_sec": step_per_sec,
-                    "train_tokens/lr": opt.curr_lr,
-                    "train_tokens/loss_scale": scaler.get_scale()
-                },
-                step=curr_tokens)
             if train_config["do_save"] and \
                     curr_step % train_config["save_every"] == 0 and \
                     curr_step != 0:
                 hypernetwork_saver(f"step_{curr_step}")
                 print(f"\nSaved model at step {curr_step}")
-            if curr_step % train_config["eval_every"] == 0:
+            if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
                 eval_fn(curr_step)
             curr_step += 1
...
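
A note on the dropped second wandb.log call: wandb expects the step argument to be non-decreasing within a run and, on a live run, warns and drops rows whose step is lower than one already logged. Interleaving step=curr_step with step=curr_tokens, which is larger by a factor of tokens_per_step, therefore tends to break the step-keyed charts, and that is presumably the reporting problem the commit message refers to; logging the token count as an ordinary metric under train/tokens, as the new hunk does, avoids it. A toy reproduction of the interleaving (names here are illustrative, not from the repository):

import wandb

run = wandb.init(mode="disabled", project="step-order-demo")
tokens_per_step = 8192
for curr_step in range(3):
    wandb.log({"train/loss": 1.0}, step=curr_step)
    # Jumping the step counter to the token count (8192, 16384, ...) means the
    # next call with step=curr_step + 1 arrives behind it; on a live run wandb
    # warns that steps must only increase and drops those rows.
    wandb.log({"train_tokens/loss": 1.0}, step=tokens_per_step * (curr_step + 1))
run.finish()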