Commit 7e2ee30f authored by novelailab

add fused adam, fix bug in hypertrain, add setup pkgs

parent 94ad5ad7
@@ -62,6 +62,10 @@ class BasedOptimizer:
             import bitsandbytes as bnb
             self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
+        elif self.optimizer_name == "adamwfused":
+            import apex
+            self.optimizer = apex.optimizers.FusedAdam(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
         elif self.optimizer_name == "zero1":
             import bitsandbytes as bnb
             self.optimizer = ZeroRedundancyOptimizer(
......
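Note on the new `adamwfused` branch: `apex.optimizers.FusedAdam` is NVIDIA Apex's fused-kernel Adam; with its default `adam_w_mode=True` it applies decoupled (AdamW-style) weight decay, so it slots in beside the existing `adamw` and `adam8bit` branches. A minimal sketch of the same construction, with a plain-PyTorch fallback for machines without apex (the fallback is illustrative, not part of this commit):

```python
import torch

def build_fused_adam(parameters, weight_decay, betas, eps):
    # Sketch of the new "adamwfused" branch; lr=0 matches the other
    # branches, where BasedOptimizer evidently drives the rate itself.
    try:
        import apex
        # Single fused CUDA kernel; adam_w_mode=True (the default)
        # gives decoupled AdamW-style weight decay.
        return apex.optimizers.FusedAdam(
            parameters, lr=0, weight_decay=weight_decay, betas=betas, eps=eps)
    except ImportError:
        # Illustrative fallback only -- not in the commit.
        return torch.optim.AdamW(
            parameters, lr=0, weight_decay=weight_decay, betas=betas, eps=eps)
```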
@@ -190,11 +190,12 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
-    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-fairseq-6b-2048-enwik9-again",
+    "data_path": "/home/xuser/nvme1/dataset/enwik9-gpt2-2049.map",
+    "save_path": "/home/xuser/models/enwik9-sigurdv4-hypernet2",
+    "lm_path": "/home/xuser/nvme1/pretrained/sigurdv4",
+    "optimizer": "adamwfused",
     "do_save": True,
-    "run_name": "fairseq-6b-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
+    "run_name": "gptj-6b-enwik9-6b-postln-bf16-2e-4-4bsz-every5layer",
     "lr": 2e-4,
     "end_lr": 2e-4,
     "warmup_steps": 50,
@@ -237,7 +238,7 @@ if last_cp:
     opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(), last_cp / "opt")
 else:
-    opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
+    opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, train_config["optimizer"])
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
 print(opt.curr_step)
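Note on the one-line change above: it stops ignoring the new `"optimizer"` key, which is the point of the commit. Since `train_config` is a literal in this same script the key is guaranteed here, but if configs are ever loaded from disk, a tolerant lookup (sketch, not in the commit) would avoid a `KeyError` on configs written before the key existed:

```python
# Sketch: fall back to the previously hardcoded optimizer name for
# train_configs written before the "optimizer" key was introduced.
opt = optimizer.BasedOptimizer(
    hypernetwork.parameters(), train_config,
    train_config.get("optimizer", "adamw"))
```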
@@ -265,7 +266,7 @@ for input_ids, labels in t:
     loss = 0
     for x in range(train_config["gas"]):
         with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
+            logits, _ = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
             #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             logits = logits.view(-1, logits.shape[-1])
             gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
......
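Note on the `logits, _ = model(...)` fix: per the TODO above, `model(...)` now returns a tuple (logits plus, presumably, hidden states), and unpacking it is the hypertrain bug fix the commit message refers to. A self-contained sketch of the full accumulation step the hunk truncates, with the loss/backward lines filled in (the helper name, the `.view(-1).cuda()` on labels, and the use of `F.cross_entropy` are assumptions based on the visible code, not the repo's exact implementation):

```python
import torch
import torch.nn.functional as F

def accumulation_step(model, hypernetwork, input_ids, labels, train_config, bs):
    # One optimizer step's worth of micro-batches ("gas" = gradient
    # accumulation steps), mirroring the loop in the hunk above.
    loss = 0.0
    for x in range(train_config["gas"]):
        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
            # The commit's fix: unpack the tuple the model now returns.
            logits, _ = model(input_ids[x*bs:(x+1)*bs, :].cuda(),
                              hypernetwork=hypernetwork, act_ck=True)
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[x*bs:(x+1)*bs, :].contiguous().view(-1).cuda()
            gas_loss = F.cross_entropy(logits, gas_labels)
        # Scale so the accumulated gradient matches a full-batch step.
        (gas_loss / train_config["gas"]).backward()
        loss += gas_loss.item() / train_config["gas"]
    return loss
```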
@@ -10,5 +10,7 @@ setuptools.setup(
     python_requires='>=3.7',
     package_data={'basedformer': ['*.json']},
     install_requires=['dotmap',
-                      'numpy']
+                      'numpy',
+                      'wandb',
+                      'transformers']
 )
\ No newline at end of file
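Note on the setup.py hunk: with `wandb` and `transformers` added to `install_requires`, `pip install -e .` now pulls in the training script's logging and tokenizer dependencies instead of leaving them as implicit imports. For reference, a sketch of the file after this change (the `name` is inferred from the `package_data` key; `version` and `find_packages` are placeholders, since the diff doesn't show the file's first ten lines):

```python
import setuptools

setuptools.setup(
    name='basedformer',                   # inferred from the package_data key
    version='0.1',                        # placeholder; not shown in the diff
    packages=setuptools.find_packages(),  # placeholder; not shown in the diff
    python_requires='>=3.7',
    package_data={'basedformer': ['*.json']},
    install_requires=['dotmap',
                      'numpy',
                      'wandb',          # added: experiment tracking
                      'transformers'],  # added: tokenizers / pretrained models
)
```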