Commit 2db358ff authored by Wes Brown

Prompt updates, sampling configuration, and report in `tokens` space.

parent bd34e4d4
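The core of the change is bookkeeping in token space rather than optimizer steps. A minimal sketch of the arithmetic, with illustrative values for `context_size`, `bs` (batch size), and `gas` (gradient accumulation steps):

```python
# Token-space accounting introduced by this commit (values illustrative).
context_size = 2049  # tokens per context window
bs = 4               # batch size -- assumed example value
gas = 8              # gradient accumulation steps -- assumed example value

# Each optimizer step consumes bs * gas contexts of context_size tokens.
tokens_per_step = context_size * bs * gas      # 65568
curr_step = 100
curr_tokens = tokens_per_step * curr_step      # 6556800
print(f"Step: {curr_step} @ {curr_tokens} tokens processed")
```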
@@ -18,18 +18,18 @@ if gpu != "cuda":
     scaler = torch.cuda.amp.GradScaler()
 prompts = ["<|endoftext|>",
-           "The year was",
-           "I grabbed my",
-           "She lifted the",
-           "He was known as the",
-           "The tavern was full again, so I ended up sharing a table with three very different creatures: a",
-           "I had been hiking in the wilderness when suddenly a",
-           "She spread her",
-           "The mercurial and beautiful",
+           " The year was",
+           " I grabbed my",
+           " He was known as the",
+           " The tavern was full again, so I ended up sharing a table with three very different creatures: a",
+           " She spread her",
+           " The mercurial and beautiful",
            "<|endoftext|>[ Author:",
+           "<|endoftext|>[ Genre:",
            "***",
-           "----"]
+           "----",
+           "> You look around.\n",
+           "John:",
+           "Jane:"]
 def _init_weights(module):
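The prompt edits mostly add a leading space. With a GPT-2-style BPE vocabulary, `" The"` is a single token, while a string-initial `"The"` tokenizes differently from the mid-text form the model mostly saw during training, so space-prefixed prompts match the training distribution better. A quick way to see the difference, assuming the `transformers` GPT-2 tokenizer is available:

```python
# The first token differs: " The" is one BPE token, bare "The" another.
from transformers import GPT2TokenizerFast

tok = GPT2TokenizerFast.from_pretrained("gpt2")
print(tok.encode("The year was"))   # starts with the no-space "The" token
print(tok.encode(" The year was"))  # starts with the " The" token
```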
@@ -215,13 +215,19 @@ def make_eval_function(hypernetwork: HyperNetworkSingle, config: dict) -> \
     sample_data = {'rows': []}
     gen_vanilla = config.get('generate_vanilla', False)
     run_name = config.get('run_name', '')
+    tokens_step = config.get('context_size', 2049) * \
+                  config.get('bs', 1) * \
+                  config.get('gas', 1)
+    num_samples = config.get('num_samples', 3)
+    num_tokens = config.get('num_tokens', 500)
     def eval_function(curr_step: int) -> None:
+        curr_tokens_step = tokens_step * (curr_step + 1)
         print()
         print_colored_bars('yellow')
-        print(f"Step: {curr_step}")
+        print(f"Step: {curr_step} @ {curr_tokens_step} tokens processed")
         for prompt in prompts:
-            sampled = sample(prompt, 500, 3,
+            sampled = sample(prompt, num_tokens, num_samples,
                              run_name=run_name,
                              hypernetwork=hypernetwork,
                              step=curr_step,
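With the `config.get` fallbacks above (2049 × 1 × 1), the new eval banner reports `tokens_step * (curr_step + 1)` tokens; the `+ 1` counts the zero-indexed step that has just finished. For example:

```python
# Eval-banner arithmetic using the .get() defaults from the diff above.
tokens_step = 2049 * 1 * 1
for curr_step in (0, 99):
    print(f"Step: {curr_step} @ {tokens_step * (curr_step + 1)} tokens processed")
# Step: 0 @ 2049 tokens processed
# Step: 99 @ 204900 tokens processed
```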
@@ -285,6 +291,10 @@ parser.add_argument("--logs", type=str, help="log directory location",
                     default="./logs")
 parser.add_argument("--masked", type=bool, help="masked softmax fusion")
 parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model")
+parser.add_argument("--sample_tokens", type=int,
+                    help="number of tokens to sample")
+parser.add_argument("--sample_num", type=int,
+                    help="number of samples per prompt")
 parser.add_argument("--shuffle", type=bool, help="shuffle dataset contexts")
 parser.add_argument("--epochs", type=int, help="number of epochs to train for")
 parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False,
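The two new flags plug into `train_config` below as `num_samples` and `num_tokens`, falling back to 3 samples of 500 tokens via the `config.get` defaults in the eval function. An illustrative invocation (script name assumed): `python train.py --sample_tokens 250 --sample_num 2`.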
@@ -313,6 +323,8 @@ train_config = {
     "eval_every": args.eval_every,
     "context_size": args.context_size,
     "sample_vanilla": args.sample_vanilla,
+    "num_samples": args.sample_num,
+    "num_tokens": args.sample_tokens,
     "shuffle": args.shuffle,
     "epochs": args.epochs,
 }
@@ -373,6 +385,9 @@ else:
 epoch_steps = len(train_loader)
 total_steps = epoch_steps * train_config['epochs']
+tokens_per_step = train_config['context_size'] * \
+                  train_config['bs'] * \
+                  train_config['gas']
 with tqdm(total=total_steps, initial=curr_step) as t:
     for epoch in range(train_config['epochs']):
@@ -415,8 +430,8 @@ with tqdm(total=total_steps, initial=curr_step) as t:
                 opt.zero_grad()
             sec_per_step = (time.perf_counter() - timex)
             step_per_sec = (1. / sec_per_step)
-            tokens_per_sec = (step_per_sec * train_config["context_size"]) * \
-                             bs * gas
+            tokens_per_sec = step_per_sec * tokens_per_step
+            curr_tokens = tokens_per_step * curr_step
             t.set_description(f"{step_per_sec:.2f} steps/s, "
                               f"{sec_per_step:.2f}s/step, "
                               f"{tokens_per_sec:.2f}tokens/s, "
@@ -433,6 +448,18 @@ with tqdm(total=total_steps, initial=curr_step) as t:
                 },
                 step=curr_step)
+            wandb.log(
+                {
+                    "train_tokens/epoch": float(curr_step) / float(epoch_steps),
+                    "train_tokens/loss": loss,
+                    "train_tokens/tokens_per_sec": tokens_per_sec,
+                    "train_tokens/sec_per_step": sec_per_step,
+                    "train_tokens/step_per_sec": step_per_sec,
+                    "train_tokens/lr": opt.curr_lr,
+                    "train_tokens/loss_scale": scaler.get_scale()
+                },
+                step=curr_tokens)
             if train_config["do_save"] and \
                 curr_step % train_config["save_every"] == 0 and \
                 curr_step != 0:
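One caveat with the duplicated log call: both calls share a run, and wandb expects its global `step` to be non-decreasing, so interleaving small `curr_step` values with the much larger `curr_tokens` values can get entries dropped with a warning. A common alternative, sketched here under the assumption of wandb ≥ 0.10 (this is not what the commit does), is a custom x-axis via `define_metric`:

```python
# Sketch: token count as a custom x-axis instead of reusing the global step.
# Project name and metric values are illustrative; offline mode avoids login.
import wandb

run = wandb.init(project="hypertrain", mode="offline")
run.define_metric("tokens")                              # declare the x-axis
run.define_metric("train_tokens/*", step_metric="tokens")

tokens_per_step = 2049  # illustrative: context_size * bs * gas
for curr_step in range(3):
    run.log({"tokens": tokens_per_step * curr_step,
             "train_tokens/loss": 2.5})  # loss value illustrative
run.finish()
```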