Commit a28f0299 authored by FIRST_NAME LAST_NAME's avatar FIRST_NAME LAST_NAME

push

parent a2b7dffb
......@@ -17,6 +17,7 @@ class BaseModel(nn.Module):
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
self.total_params = sum(p.numel() for p in self.parameters())
for i in range(config.n_layer):
config.layer_idx = i
self.layers.append(
......
......@@ -6,6 +6,7 @@ from dotmap import DotMap
import pickle
import os
from pathlib import Path
from torch.distributed.optim import ZeroRedundancyOptimizer
#Based Optimizer
def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
......@@ -61,6 +62,17 @@ class BasedOptimizer:
import bitsandbytes as bnb
self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
elif self.optimizer_name == "zero1":
import bitsandbytes as bnb
self.optimizer = ZeroRedundancyOptimizer(
self.parameters,
optimizer_class=bnb.optim.Adam8bit,
lr=0,
weight_decay=self.weight_decay,
betas=(self.beta1, self.beta2),
eps=self.eps,
)
elif self.optimizer_name == "adafactor":
try:
from transformers.optimization import Adafactor
......
......@@ -30,6 +30,24 @@ class FbDataset(data.Dataset):
data = torch.tensor(self.npz[nth].astype(np.int64))
return (data[:-1], data[1:])
class ShardedDataset(data.Dataset):
    """Memory-mapped token dataset sharded across distributed ranks.

    The map file is a flat array of uint16 token ids, viewed as rows of
    ``block_size`` tokens. Rank ``r`` of ``world_size`` sees rows
    ``r, r + world_size, r + 2 * world_size, ...`` so the ranks together
    cover the (truncated) dataset without overlap.
    """

    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
        # Read-only memmap so huge datasets never need to fit in RAM.
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        # Truncate the tail so every rank gets the same number of rows.
        # Might want to pad instead of truncating later.
        self.npz = self.npz[:self.npz.shape[0] - (self.npz.shape[0] % world_size)]
        # Shard: strided slice picks this rank's rows.
        self.npz = self.npz[rank::world_size]
        self.samples = self.npz.shape[0]
        # Number of leading samples to skip (used when resuming training).
        self.skip = skip

    def __len__(self):
        # Bug fix: account for `skip` — previously this returned the full
        # shard size, so a DataLoader resuming with skip > 0 would call
        # __getitem__ with indices past the end of the shard.
        return self.samples - self.skip

    def __getitem__(self, _id):
        nth = _id + self.skip
        row = torch.tensor(self.npz[nth].astype(np.int64))
        # Next-token prediction: inputs are tokens [:-1], labels are [1:].
        return (row[:-1], row[1:])
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
......
......@@ -14,104 +14,79 @@ import wandb
import numpy as np
import os
from icecream import ic
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import argparse
from torch.distributed.fsdp import (
FullyShardedDataParallel,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
default_auto_wrap_policy,
)
def setup(rank, world_size):
    """Join the default NCCL process group for distributed training.

    Rendezvous info (MASTER_ADDR/MASTER_PORT) is expected to come from the
    launcher (e.g. torchrun), so it is not set here. ``rank`` and
    ``world_size`` are likewise picked up from the environment by
    ``init_process_group``.
    """
    #os.environ['MASTER_ADDR'] = 'localhost'
    #os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group(backend="nccl")
    if not dist.is_initialized():
        print("Failed to initialize process group")
    else:
        print("Initialized process group")
def cleanup():
    """Tear down the default process group created by setup()."""
    dist.destroy_process_group()
def get_rank():
    """Return this process's distributed rank, or None when no process
    group has been initialized."""
    if not dist.is_initialized():
        return None
    return dist.get_rank()
def get_world():
    """Return the distributed world size, or None when no process group
    has been initialized."""
    if not dist.is_initialized():
        return None
    return dist.get_world_size()
def get_flops(args, model, iter_time_s):
    """Estimate achieved FLOP/s for one optimizer step.

    Uses the 6*N forward+backward rule for the dense parameters plus an
    attention term, assuming a fixed sequence length of 2048 tokens.

    Args:
        args: config with ``bs``, ``gas``, ``hidden_size``, ``n_layers``.
        model: model exposing ``total_params``.
        iter_time_s: wall-clock seconds taken by the step.
    """
    seq_len = 2048
    ff = 6 * model.total_params
    # NOTE(review): the 60 factor presumably folds attention constants for
    # this architecture — confirm against the derivation.
    attn = seq_len * args.hidden_size * args.n_layers * 60
    tokens_per_step = args.bs * args.gas * seq_len
    return tokens_per_step * (ff + attn) / iter_time_s
def fsdp_train(args, model, train_loader, opt):
bs = args["bs"]
gas = args["gas"]
rank = get_rank()
world_size = get_world()
model.train()
ddp_loss = torch.zeros(1).cuda()
if rank == 0:
t = tqdm(train_loader)
else:
t = train_loader
# we need 250 batch size to train the small GPT.
train_config = {
"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-finetune",
"do_save": False,
"run_name": "gptj-finetune",
"lr": 2e-5,
"end_lr": 6e-5,
"warmup_steps": 100,
"anneal_steps": 10000,
"bs": 1,
"gas": 4,
"seed": 69,
"save_every": 500,
"amp": True,
"loss_scale": True,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = lm_utils.load_from_path("pretrained/gptj-6b").cuda().float()
model.train()
ic("model loaded")
cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
print(last_cp)
trainable_params = [
model.vocab_embed,
model.lm_head,
model.layers[10],
]
for param in model.parameters():
param.requires_grad = False
for name, p in model.named_parameters():
if ("ln" in name or "vocab_embed" in name):
p.requires_grad = True
for module in trainable_params:
for p in module.parameters():
module.requires_grad = True
'''
if last_cp:
print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
model.load(model_config, last_cp / "lm", strict=True)
opt = optimizer.BasedOptimizer.load(model.parameters(), last_cp / "opt")
else:
opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
'''
opt = optimizer.BasedOptimizer(model.layers[10].parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = utils.FbDataset(2049, train_config["data_path"])
if last_cp:
train_dataset.skip = opt.curr_step * bs * gas
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})
if last_cp:
curr_step = opt.curr_step
else:
curr_step = 0
t = tqdm(train_loader, initial=curr_step)
scaler = torch.cuda.amp.GradScaler()
for input_ids, labels in t:
scaler = torch.cuda.amp.GradScaler()
for input_ids, labels in t:
timex = time.perf_counter()
input_ids = input_ids.cuda()
labels = labels.cuda()
input_ids = input_ids.to(rank)
labels = labels.to(rank)
loss = 0
for x in range(train_config["gas"]):
with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
logits = model(input_ids[x*bs:(x+1)*bs, :2048].cuda(), act_ck=True)
#print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
#roll down the sequence
for x in range(args["gas"]):
with torch.cuda.amp.autocast(enabled=args["amp"], dtype=torch.float16):
logits = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :2048].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
if train_config["loss_scale"]:
if args["loss_scale"]:
scaler.scale(gas_loss).backward()
else:
gas_loss.backward()
......@@ -119,39 +94,85 @@ for input_ids, labels in t:
loss += gas_loss.item()
loss = loss / gas
if train_config["loss_scale"]:
if args["loss_scale"]:
scaler.unscale_(opt.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
if train_config["loss_scale"]:
if args["loss_scale"]:
opt.step(scaler=scaler)
else:
opt.step()
if train_config["loss_scale"]:
if args["loss_scale"]:
scaler.update()
opt.zero_grad()
#opt.zero_grad()
model.zero_grad(set_to_none=True)
sec_per_step = (time.perf_counter() - timex)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 2048) * bs * gas
t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
wandb.log(
{
"train/loss": loss,
tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
batch_size = bs * gas * world_size
ddp_loss[0] = loss
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
wandb.log({
"train_loss": ddp_loss[0] / world_size,
"train/tokens_per_sec": tokens_per_sec,
"train/sec_per_step": sec_per_step,
"train/step_per_sec": step_per_sec,
"train/lr": opt.curr_lr,
"train/batch_size": batch_size,
"train/loss_scale": scaler.get_scale()
},
step=curr_step)
if train_config["do_save"]:
if curr_step % train_config["save_every"] == 0:
save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
save_folder.mkdir(parents=True, exist_ok=True)
model.save(save_folder / "lm")
opt.save(save_folder / "opt")
print(f"Saved model at step {curr_step}")
curr_step += 1
\ No newline at end of file
})
# we need 250 batch size to train the small GPT.
def main(rank, world_size, args):
    """Per-process distributed training entry point.

    Joins the process group, wraps the model in DDP, builds this rank's
    shard of the dataset, and runs the training loop.

    Args:
        rank: local rank / CUDA device index for this process.
        world_size: total number of distributed processes.
        args: training-config dict (paths, lr schedule, batch sizes, ...).

    Fix: the body previously read the module-global ``train_config``
    instead of its ``args`` parameter, which only worked by accident of
    globals (NameError if called standalone). It now uses ``args``
    consistently — identical behavior for the actual call site, which
    passes ``train_config`` as ``args``.
    """
    bs = args["bs"]
    gas = args["gas"]
    # Same seed on every rank so initialization lines up before DDP sync.
    torch.manual_seed(args["seed"])
    setup(rank, world_size)
    Path(args["save_path"]).mkdir(parents=True, exist_ok=True)
    model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
    # FSDP variant kept for reference; plain DDP is used for now.
    #fsdp_model = FullyShardedDataParallel(
    #    model,
    #    fsdp_auto_wrap_policy=default_auto_wrap_policy,
    #    cpu_offload=CPUOffload(offload_params=True),
    #)
    fsdp_model = DDP(model, device_ids=[rank], gradient_as_bucket_view=True)
    utils.print_parameters(fsdp_model)
    ic("model loaded")
    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero1")
    # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting
    # hidden states from the get_logits function.
    print(opt.curr_step)
    train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=rank)
    train_loader = data.DataLoader(train_dataset, batch_size=bs * gas, shuffle=False, num_workers=0)
    if rank == 0:
        # Only rank 0 logs to wandb to avoid duplicate runs.
        wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
    fsdp_train(args, fsdp_model, train_loader, opt)
    # Make sure every rank finished before tearing the group down.
    dist.barrier()
    cleanup()
if __name__ == "__main__":
    # Hyperparameters and paths for this run; passed into main() as `args`.
    train_config = {
        "data_path": "dataset/sigurd-1G.map",
        "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-sigurd-1G-vanilla",
        "do_save": False,  # checkpointing disabled for this run
        "run_name": "gptj-sigurd-1G-vanilla",
        "lr": 6e-5,
        "end_lr": 2e-5,
        "warmup_steps": 100,
        "anneal_steps": 10000,
        "bs": 2,   # per-rank micro-batch size
        "gas": 1,  # gradient accumulation steps
        "seed": 69,
        "save_every": 500,
        "amp": True,
        "loss_scale": True,
    }
    # WORLD_SIZE and LOCAL_RANK are provided by the distributed launcher
    # (e.g. torchrun); this script is one process per GPU.
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["LOCAL_RANK"])
    # Bind this process to its GPU before any CUDA work happens.
    torch.cuda.set_device(rank)
    main(rank, world_size, train_config)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment