Commit eebb1fa8 authored by Wes Brown's avatar Wes Brown

Add `--shuffle` argument, default to `False`.

parent ea32d948
......@@ -284,8 +284,9 @@ parser.add_argument("--logs", type=str, help="log directory location",
default="./logs")
parser.add_argument("--masked", type=bool, help="masked softmax fusion")
parser.add_argument("--sample_vanilla", type=bool, help="sample vanilla model")
parser.add_argument("--shuffle", type=bool, help="shuffle dataset contexts")
parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False,
sample_vanilla=False)
sample_vanilla=False, shuffle=False)
args = parser.parse_args()
if args.output == '':
args.output = f'./{args.run_name}'
......@@ -310,6 +311,7 @@ train_config = {
"eval_every": args.eval_every,
"context_size": args.context_size,
"sample_vanilla": args.sample_vanilla,
"shuffle": args.shuffle,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
......@@ -355,7 +357,7 @@ if last_cp:
train_loader = torch_data.DataLoader(train_dataset,
batch_size=bs * gas,
shuffle=True,
shuffle=train_config["shuffle"],
num_workers=0)
wandb.init(project="hypernetwork-tests",
name=train_config["run_name"],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment