Commit 53bcc538 authored by Wes Brown

Fix shuffling and use provided context size on CLI.

parent 6c1a2d67
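
The diff below wires a CLI-provided context size into `train_config` and the dataset constructor. The flag definition itself sits outside these hunks; the following is a minimal sketch of how the argparse flag backing `args.context_size` might be declared (the flag name, default, and help text are assumptions — only `args.context_size` appears in the diff, and the default is taken from the previously hard-coded 2049):

```python
# Hypothetical sketch: a CLI flag that would populate args.context_size.
# Not shown in this commit; flag name, default, and help text are assumptions.
import argparse

parser = argparse.ArgumentParser(description="hypernetwork training")
parser.add_argument("--context_size",
                    type=int,
                    default=2049,  # value previously hard-coded in ShardedDataset
                    help="token context window per training sample")
args = parser.parse_args()
```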
@@ -263,6 +263,7 @@ train_config = {
     "amp": args.amp,
     "loss_scale": args.loss_scale,
     "eval_every": args.eval_every,
+    "context_size": args.context_size,
 }
 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
@@ -303,13 +304,14 @@ else:
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden
 # states from the get_logits function.
 print(opt.curr_step)
-train_dataset = dataset.ShardedDataset(2049, train_config["data_path"])
+train_dataset = dataset.ShardedDataset(train_config["context_size"],
+                                       train_config["data_path"])
 if last_cp:
     train_dataset.skip = opt.curr_step * bs * gas
 train_loader = data.DataLoader(train_dataset,
                                batch_size=bs * gas,
-                               shuffle=False,
+                               shuffle=True,
                                num_workers=0)
 wandb.init(project="hypernetwork-tests",
            name=train_config["run_name"],
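
The `shuffle=True` change makes the `DataLoader` draw indices from a seeded `RandomSampler` instead of iterating the dataset in order, which only applies if `ShardedDataset` is a map-style dataset (implements `__len__`/`__getitem__`); `IterableDataset` instances reject the flag. A minimal sketch of the effect, using `TensorDataset` as a stand-in for `ShardedDataset`:

```python
# Sketch of what shuffle=True changes: with a map-style dataset, DataLoader
# samples indices randomly rather than in 0..len-1 order. TensorDataset stands
# in for ShardedDataset here.
import torch
from torch.utils import data

torch.manual_seed(0)  # mirrors torch.manual_seed(train_config["seed"]) above
dummy = data.TensorDataset(torch.arange(16))

ordered = data.DataLoader(dummy, batch_size=4, shuffle=False)
shuffled = data.DataLoader(dummy, batch_size=4, shuffle=True)

print([b[0].tolist() for b in ordered])   # [[0, 1, 2, 3], [4, 5, 6, 7], ...]
print([b[0].tolist() for b in shuffled])  # same items, seeded random order
```

Because the manual seed is set before the loader is built, the shuffled order is reproducible across runs with the same seed.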