Commit 0bde514b authored by Wes Brown's avatar Wes Brown

Add defaults, project id, logs.

parent 2db358ff
......@@ -291,9 +291,9 @@ parser.add_argument("--logs", type=str, help="log directory location",
default="./logs")
# NOTE(review): `type=bool` is an argparse pitfall — bool("False") is True, so any
# non-empty value (e.g. `--masked False`) parses as True. Consider
# `action="store_true"` (CLI-breaking) or an explicit str-to-bool converter.
parser.add_argument("--masked", type=bool, help="masked softmax fusion")
parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model")
parser.add_argument("--sample_tokens", type=int,
parser.add_argument("--sample_tokens", type=int, default=500,
help="number of tokens to sample")
parser.add_argument("--sample_num", type=int,
parser.add_argument("--sample_num", type=int, default=3,
help="number of samples per prompt")
# NOTE(review): same `type=bool` pitfall — `--shuffle False` evaluates truthy;
# any non-empty string becomes True.
parser.add_argument("--shuffle", type=bool, help="shuffle dataset contexts")
parser.add_argument("--epochs", type=int, help="number of epochs to train for")
......@@ -304,6 +304,7 @@ if args.output == '':
args.output = f'./{args.run_name}'
# we need 250 batch size to train the small GPT.
train_config = {
"project_id": args.project_id,
"data_path": args.dataset,
"save_path": args.output,
"lm_path": args.model,
......@@ -327,6 +328,7 @@ train_config = {
"num_tokens": args.sample_tokens,
"shuffle": args.shuffle,
"epochs": args.epochs,
"logs": args.logs,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
......@@ -374,7 +376,7 @@ train_loader = torch_data.DataLoader(train_dataset,
batch_size=bs * gas,
shuffle=train_config["shuffle"],
num_workers=0)
wandb.init(project="hypernetwork-tests",
# Initialize the experiment tracker, mirroring the full training config.
# Fix: the config dict is built with the key "project_id" (underscore) —
# the hyphenated lookup "project-id" raises KeyError at runtime.
wandb.init(project=train_config["project_id"],
           name=train_config["run_name"],
           config={**train_config, **model.config})
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment