Commit 8073ccfc authored by Wes Brown

Add argument handling, a closured `hypernetwork_saver`, and save the final result.

parent 704947b4
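For context, the new saving logic is a closure: the config and hypernetwork are bound once, and the returned function only takes a checkpoint id. A minimal standalone sketch of that pattern follows (not part of the commit; in the diff below `opt` is captured from module scope rather than passed in):

from pathlib import Path
import torch
import torch.nn as nn

def make_saver(train_config: dict, hypernetwork: nn.Module, opt=None):
    # Bind config and module once; callers only supply a checkpoint id.
    def saver(checkpoint_id: str):
        save_folder = Path(train_config["save_path"]) / checkpoint_id
        save_folder.mkdir(parents=True, exist_ok=True)
        torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
        if opt is not None:
            opt.save(save_folder / "opt")  # mirrors the opt.save(...) call in the diff
    return saver

# The same closure then serves both periodic checkpoints and the final save:
#   saver(f"step_{curr_step}")
#   saver("final")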
@@ -8,6 +8,7 @@ from basedformer.utils import *
from transformers import AutoTokenizer
from basedformer import sampling
from termcolor import colored
import argparse
gpu = "cuda"
amp = torch.cuda.amp
@@ -15,8 +16,18 @@ if gpu != "cuda":
amp = torch.amp
scaler = torch.cuda.amp.GradScaler()
prompts = ["<|endoftext|>"]
prompts = ["<|endoftext|>",
"The year was",
"I grabbed my",
"She lifted the",
"He was known as the",
"The tavern was full again, so I ended up sharing a table with three very different creatures: a",
"I had been hiking in the wilderness when suddenly a",
"She spread her",
"The mercurial and beautiful woman laughed",
"[ Author:",
"[ Tags:",
"***"]
def _init_weights(module):
if isinstance(module, nn.Linear):
@@ -158,39 +169,91 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
return data
def report_wandb(data):
columns = ["Step", "Prompt", "Generated Text", "Vanilla Model"]
wandb.log({"Generations": wandb.Table(data=data, columns=columns)})
def report_console(data):
for gen in data[3]:
for gen in data[2]:
print(colored("======================================================",
"red"))
print(colored(gen, "green"))
print(colored("======================================================",
"red"))
def make_hypernet_saver(train_config, hypernetwork):
def hypernet_saver(id: str):
save_folder = Path(train_config["save_path"]) / id
save_folder.mkdir(parents=True, exist_ok=True)
torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
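# note: `opt` is not a parameter of this closure; it is the module-level
# optimizer (BasedOptimizer) created further down and resolved when the saver runs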
opt.save(save_folder / "opt")
return hypernet_saver
parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
parser.add_argument('--run_name', type=str, help='the run name to use',
required=True)
parser.add_argument('--model', type=str, help='the model to train against',
required=True)
parser.add_argument('--dataset', type=str, help='pre-tokenized dataset to use',
required=True)
parser.add_argument("--output", type=str, help='output path',
default='')
parser.add_argument('--optimizer', type=str, help='the optimizer to use',
default='adamw')
parser.add_argument('--lr', type=float, help='learning rate', default=2e-4)
parser.add_argument('--end_lr', type=float, help='end learning rate',
default=2e-4)
parser.add_argument('--warmup', type=int, help='warmup steps', default=50)
parser.add_argument('--bs', type=int, help='batch size', default=4)
parser.add_argument('--gas', type=int, help='gradient accumulation steps', default=1)
parser.add_argument('--seed', type=int, help="Random seed value",
default=42)
parser.add_argument("--save_steps", type=int,
help='# of steps between checkpoint saves',
default=300)
parser.add_argument("--amp", type=bool, help='enable amp', default=False)
parser.add_argument('--loss_scale', type=bool, help='whether to scale loss',
default=False)
parser.add_argument("--eval_every", type=int,
help='evaluate hypernetwork every x steps',
default=100)
parser.add_argument('--output_path', type=str, help="Root path of all output",
default="./")
parser.add_argument('--no_resume', action='store_true',
help="Do not resume from last checkpoint")
parser.add_argument("--context_size", type=int, help="Dataset context sizes",
default=2048)
parser.add_argument("--project_id", type=str, help="Project ID for reporting",
default="hypernetwork-training")
parser.add_argument("--logs", type=str, help="log directory location",
default="./logs")
parser.add_argument("--masked", type=bool, help="masked softmax fusion")
parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False)
args = parser.parse_args()
if args.output == '':
args.output = f'./{args.run_name}'
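# Example invocation (illustrative only; the script filename is assumed and the
# run name is a placeholder, the paths come from the previous hard-coded config):
#   python finetune_hypernetwork.py --run_name my-hypernet-run \
#       --model pretrained/sigurdv4 --dataset dataset/cassandra.map \
#       --save_steps 300 --eval_every 100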
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": "dataset/cassandra.map",
"save_path": "models/sigurdv4-cassandra-hypernet2",
"lm_path": "pretrained/sigurdv4",
"optimizer": "adamw",
"masked_softmax_fusion": False,
"do_save": True,
"run_name": "sigurdv4-cassandra-6b-postln-bf16-2e-4-4bsz-every5layer",
"lr": 2e-4,
"end_lr": 2e-4,
"warmup_steps": 50,
"bs": 4,
"gas": 1,
"seed": 69,
"save_every": 300,
"amp": False,
"loss_scale": False,
"eval_every": 100,
"data_path": args.dataset,
"save_path": args.model,
"lm_path": args.model,
"optimizer": args.optimizer,
"masked_softmax_fusion": args.masked,
"do_save": args.save_steps != 0,
"run_name": args.run_name,
"lr": args.lr,
"end_lr": args.end_lr,
"warmup_steps": args.warmup,
"bs": args.bs,
"gas": args.gas,
"seed": args.seed,
"save_every": args.save_steps0,
"amp": args.amp,
"loss_scale": args.loss_scale,
"eval_every": args.eval_every,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
@@ -209,6 +272,7 @@ for name, p in model.named_parameters():
hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
for param in hypernetwork.parameters():
param.requires_grad = True
hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)
cp_list = sorted(os.listdir(train_config["save_path"]),
key=lambda x: int(x.split("_")[-1]))
@@ -216,7 +280,7 @@ last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(
cp_list) > 0 else None
print(last_cp)
if last_cp:
if last_cp and not args.no_resume:
print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
@@ -303,12 +367,10 @@ for input_ids, labels in t:
},
step=curr_step)
if train_config["do_save"] and curr_step % train_config[
"save_every"] == 0 and curr_step != 0:
save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
save_folder.mkdir(parents=True, exist_ok=True)
torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
opt.save(save_folder / "opt")
if train_config["do_save"] and \
curr_step % train_config["save_every"] == 0 and \
curr_step != 0:
hypernetwork_saver(f"step_{curr_step}")
print(f"\nSaved model at step {curr_step}")
if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
@@ -320,3 +382,5 @@ for input_ids, labels in t:
report_wandb(sample_data)
curr_step += 1
hypernetwork_saver("final")
\ No newline at end of file