Commit 94c0ad6f authored by Wes Brown

Set default to `2049` and correctly calculate the tokens per second.

parent 53bcc538
......@@ -233,7 +233,7 @@ parser.add_argument('--output_path', type=str, help="Root path of all output",
parser.add_argument('--no_resume', type=bool, default=False,
help="Do not resume from last checkpoint")
parser.add_argument("--context_size", type=int, help="Dataset context sizes",
default=2048)
default=2049)
parser.add_argument("--project_id", type=str, help="Project ID for reporting",
default="hypernetwork-training")
parser.add_argument("--logs", type=str, help="log directory location",
......@@ -336,7 +336,6 @@ for input_ids, labels in t:
logits, _ = model(input_ids[x * bs:(x + 1) * bs, :].to(gpu),
hypernetwork=hypernetwork,
act_ck=True)
# print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x * bs:(x + 1) * bs, :].contiguous()
gas_labels = gas_labels.view(-1)
......@@ -364,7 +363,7 @@ for input_ids, labels in t:
opt.zero_grad()
sec_per_step = (time.perf_counter() - timex)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 2048) * bs * gas
tokens_per_sec = (step_per_sec * train_config["context_size"]) * bs * gas
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step,"
+ f"{tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment