Commit a28f0299 authored by FIRST_NAME LAST_NAME

push

parent a2b7dffb
@@ -17,6 +17,7 @@ class BaseModel(nn.Module):
         self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
         self.layers = nn.ModuleList([])
         self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
+        self.total_params = sum(p.numel() for p in self.parameters())
         for i in range(config.n_layer):
             config.layer_idx = i
             self.layers.append(
...
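The cached total_params value feeds the new get_flops helper in the training script, which multiplies the parameter count by 6 (the usual ~6 FLOPs per parameter per token for a forward plus backward pass) and scales by tokens per step to estimate throughput. A minimal sketch of that arithmetic, using made-up numbers rather than anything from this commit (the helper itself also adds a separate attention term):

# Rough throughput estimate in the spirit of get_flops; all numbers below
# are illustrative assumptions, not values taken from the commit.
total_params = 6_050_000_000          # e.g. a GPT-J-sized model
seq_len = 2048
batch = 2 * 1                         # bs * gas on one rank
step_time_s = 3.0                     # hypothetical measured step time

ff_flops_per_token = 6 * total_params            # fwd + bwd matmul estimate
flops_per_s = batch * seq_len * ff_flops_per_token / step_time_s
print(f"~{flops_per_s / 1e12:.1f} TFLOP/s per rank")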
@@ -6,6 +6,7 @@ from dotmap import DotMap
 import pickle
 import os
 from pathlib import Path
+from torch.distributed.optim import ZeroRedundancyOptimizer
 #Based Optimizer
 def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
@@ -61,6 +62,17 @@ class BasedOptimizer:
             import bitsandbytes as bnb
             self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
+        elif self.optimizer_name == "zero1":
+            import bitsandbytes as bnb
+            self.optimizer = ZeroRedundancyOptimizer(
+                self.parameters,
+                optimizer_class=bnb.optim.Adam8bit,
+                lr=0,
+                weight_decay=self.weight_decay,
+                betas=(self.beta1, self.beta2),
+                eps=self.eps,
+            )
         elif self.optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
...
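The new "zero1" branch wraps the bitsandbytes 8-bit Adam in PyTorch's ZeroRedundancyOptimizer, so each data-parallel rank only keeps optimizer state for its own shard of the parameters (ZeRO stage 1). A self-contained sketch of the same pattern, assuming a torchrun-style launch and using the stock torch.optim.Adam in place of bnb.optim.Adam8bit:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer

# Minimal ZeRO-1 sketch; assumes torchrun set WORLD_SIZE/LOCAL_RANK and one GPU per rank.
dist.init_process_group(backend="nccl")
rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(rank)

model = DDP(torch.nn.Linear(1024, 1024).cuda(), device_ids=[rank])
opt = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.Adam,   # stand-in for bnb.optim.Adam8bit
    lr=1e-4,
    weight_decay=0.01,
)

x = torch.randn(8, 1024, device=rank)
model(x).pow(2).mean().backward()       # DDP averages gradients across ranks
opt.step()                              # each rank updates only its shard's optimizer state
opt.zero_grad()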
@@ -30,6 +30,24 @@ class FbDataset(data.Dataset):
         data = torch.tensor(self.npz[nth].astype(np.int64))
         return (data[:-1], data[1:])
+class ShardedDataset(data.Dataset):
+    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
+        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
+        #might want to pad later
+        self.npz = self.npz[:self.npz.shape[0] - (self.npz.shape[0] % world_size)]
+        #shard
+        self.npz = self.npz[rank::world_size]
+        self.samples = self.npz.shape[0]
+        self.skip = skip
+    def __len__(self):
+        return self.samples
+    def __getitem__(self, _id):
+        nth = _id + self.skip
+        data = torch.tensor(self.npz[nth].astype(np.int64))
+        return (data[:-1], data[1:])
 # Make loading models faster by not letting pytorch initialize the weights.
 # Usage: no_init(lambda: load_model(...))
...
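ShardedDataset gives each rank a disjoint slice of the memmapped token file without a DistributedSampler: it first trims the sample count to a multiple of world_size, then takes every world_size-th row starting at its own rank. A tiny sketch of that slicing on a toy array (shapes and rank values are made up for illustration):

import numpy as np

# Toy version of the slicing done in ShardedDataset.__init__.
world_size, rank = 3, 1
tokens = np.arange(40, dtype=np.uint16).reshape(-1, 4)               # 10 "samples" of block_size 4

tokens = tokens[:tokens.shape[0] - (tokens.shape[0] % world_size)]   # drop the ragged tail -> 9 rows
shard = tokens[rank::world_size]                                     # rank 1 keeps rows 1, 4, 7
print(shard.shape)                                                   # (3, 4)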
@@ -14,104 +14,79 @@ import wandb
 import numpy as np
 import os
 from icecream import ic
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+import argparse
+from torch.distributed.fsdp import (
+    FullyShardedDataParallel,
+    CPUOffload,
+)
+from torch.distributed.fsdp.wrap import (
+    default_auto_wrap_policy,
+)
+def setup(rank, world_size):
+    #os.environ['MASTER_ADDR'] = 'localhost'
+    #os.environ['MASTER_PORT'] = '12355'
+    # initialize the process group
+    dist.init_process_group(backend="nccl")
+    if dist.is_initialized():
+        print("Initialized process group")
+    else:
+        print("Failed to initialize process group")
+def cleanup():
+    dist.destroy_process_group()
+def get_rank():
+    if dist.is_initialized():
+        return dist.get_rank()
+def get_world():
+    if dist.is_initialized():
+        return dist.get_world_size()
+def get_flops(args, model, iter_time_s):
+    ff = model.total_params * 6
+    attn = 2048 * args.hidden_size * args.n_layers * 60
+    flops = (
+        args.bs * args.gas
+        * 2048
+        * (ff + attn)
+        / (iter_time_s)
+    )
+    return flops
+def fsdp_train(args, model, train_loader, opt):
+    bs = args["bs"]
+    gas = args["gas"]
+    rank = get_rank()
+    world_size = get_world()
+    model.train()
+    ddp_loss = torch.zeros(1).cuda()
+    if rank == 0:
+        t = tqdm(train_loader)
+    else:
+        t = train_loader
-# we need 250 batch size to train the small GPT.
-train_config = {
-    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
-    #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
-    #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
-    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-finetune",
-    "do_save": False,
-    "run_name": "gptj-finetune",
-    "lr": 2e-5,
-    "end_lr": 6e-5,
-    "warmup_steps": 100,
-    "anneal_steps": 10000,
-    "bs": 1,
-    "gas": 4,
-    "seed": 69,
-    "save_every": 500,
-    "amp": True,
-    "loss_scale": True,
-}
-torch.manual_seed(train_config["seed"])
-bs = train_config["bs"]
-gas = train_config["gas"]
-Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
-#model = GPTModel.gpt2_init(model_config).cuda().float()
-model = lm_utils.load_from_path("pretrained/gptj-6b").cuda().float()
-model.train()
-ic("model loaded")
-cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
-last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
-print(last_cp)
-trainable_params = [
-    model.vocab_embed,
-    model.lm_head,
-    model.layers[10],
-]
-for param in model.parameters():
-    param.requires_grad = False
-for name, p in model.named_parameters():
-    if ("ln" in name or "vocab_embed" in name):
-        p.requires_grad = True
-for module in trainable_params:
-    for p in module.parameters():
-        module.requires_grad = True
-'''
-if last_cp:
-    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
-    model.load(model_config, last_cp / "lm", strict=True)
-    opt = optimizer.BasedOptimizer.load(model.parameters(), last_cp / "opt")
-else:
-    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
-'''
-opt = optimizer.BasedOptimizer(model.layers[10].parameters(), train_config, "adamw")
-# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
-print(opt.curr_step)
-train_dataset = utils.FbDataset(2049, train_config["data_path"])
-if last_cp:
-    train_dataset.skip = opt.curr_step * bs * gas
-train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
-wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})
-if last_cp:
-    curr_step = opt.curr_step
-else:
-    curr_step = 0
-t = tqdm(train_loader, initial=curr_step)
-scaler = torch.cuda.amp.GradScaler()
-for input_ids, labels in t:
+    scaler = torch.cuda.amp.GradScaler()
+    for input_ids, labels in t:
         timex = time.perf_counter()
-    input_ids = input_ids.cuda()
-    labels = labels.cuda()
+        input_ids = input_ids.to(rank)
+        labels = labels.to(rank)
         loss = 0
-    for x in range(train_config["gas"]):
-        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model(input_ids[x*bs:(x+1)*bs, :2048].cuda(), act_ck=True)
+        for x in range(args["gas"]):
+            with torch.cuda.amp.autocast(enabled=args["amp"], dtype=torch.float16):
+                logits = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)
                 #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
                 #roll down the sequence
                 logits = logits.view(-1, logits.shape[-1])
                 gas_labels = labels[x*bs:(x+1)*bs, :2048].contiguous()
                 gas_labels = gas_labels.view(-1)
                 gas_loss = F.cross_entropy(logits, gas_labels)
-        if train_config["loss_scale"]:
+            if args["loss_scale"]:
                 scaler.scale(gas_loss).backward()
             else:
                 gas_loss.backward()
@@ -119,39 +94,85 @@ for input_ids, labels in t:
             loss += gas_loss.item()
         loss = loss / gas
-    if train_config["loss_scale"]:
+        if args["loss_scale"]:
             scaler.unscale_(opt.optimizer)
         torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
-    if train_config["loss_scale"]:
+        if args["loss_scale"]:
             opt.step(scaler=scaler)
         else:
             opt.step()
-    if train_config["loss_scale"]:
+        if args["loss_scale"]:
             scaler.update()
-    opt.zero_grad()
+        #opt.zero_grad()
+        model.zero_grad(set_to_none=True)
         sec_per_step = (time.perf_counter() - timex)
         step_per_sec = (1. / sec_per_step)
-    tokens_per_sec = (step_per_sec * 2048) * bs * gas
-    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
-    wandb.log(
-        {
-            "train/loss": loss,
+        tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
+        batch_size = bs * gas * world_size
+        ddp_loss[0] = loss
+        dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+        if rank == 0:
+            wandb.log({
+                "train_loss": ddp_loss[0] / world_size,
                 "train/tokens_per_sec": tokens_per_sec,
                 "train/sec_per_step": sec_per_step,
                 "train/step_per_sec": step_per_sec,
                 "train/lr": opt.curr_lr,
+                "train/batch_size": batch_size,
                 "train/loss_scale": scaler.get_scale()
-        },
-        step=curr_step)
+            })
-    if train_config["do_save"]:
-        if curr_step % train_config["save_every"] == 0:
-            save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
-            save_folder.mkdir(parents=True, exist_ok=True)
-            model.save(save_folder / "lm")
-            opt.save(save_folder / "opt")
-            print(f"Saved model at step {curr_step}")
-    curr_step += 1
\ No newline at end of file
+# we need 250 batch size to train the small GPT.
+def main(rank, world_size, args):
+    bs = args["bs"]
+    gas = args["gas"]
+    torch.manual_seed(train_config["seed"])
+    setup(rank, world_size)
+    Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
+    model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
+    #fsdp_model = FullyShardedDataParallel(
+    #model,
+    #fsdp_auto_wrap_policy=default_auto_wrap_policy,
+    #cpu_offload=CPUOffload(offload_params=True),
+    #)
+    fsdp_model = DDP(model, device_ids=[rank], gradient_as_bucket_view=True)
+    utils.print_parameters(fsdp_model)
+    ic("model loaded")
+    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), train_config, "zero1")
+    # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
+    print(opt.curr_step)
+    train_dataset = utils.ShardedDataset(2049, train_config["data_path"], world_size=world_size, rank=rank)
+    train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
+    if rank == 0:
+        wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})
+    fsdp_train(args, fsdp_model, train_loader, opt)
+    dist.barrier()
+    cleanup()
+if __name__ == "__main__":
+    train_config = {
+        "data_path": "dataset/sigurd-1G.map",
+        "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-sigurd-1G-vanilla",
+        "do_save": False,
+        "run_name": "gptj-sigurd-1G-vanilla",
+        "lr": 6e-5,
+        "end_lr": 2e-5,
+        "warmup_steps": 100,
+        "anneal_steps": 10000,
+        "bs": 2,
+        "gas": 1,
+        "seed": 69,
+        "save_every": 500,
+        "amp": True,
+        "loss_scale": True,
+    }
+    world_size = int(os.environ["WORLD_SIZE"])
+    rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(rank)
+    main(rank, world_size, train_config)
\ No newline at end of file
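In the new fsdp_train loop each rank writes its own step loss into ddp_loss, sums it across ranks with dist.all_reduce, and only rank 0 logs the mean to wandb. A standalone sketch of that reduction pattern (the helper name is invented for illustration):

import torch
import torch.distributed as dist

# Sketch of the loss averaging used for logging in fsdp_train.
# Assumes the "nccl" process group is already initialized and one GPU per rank.
def log_mean_loss(loss: float, step: int) -> None:
    world_size = dist.get_world_size()
    buf = torch.tensor([loss], device="cuda")
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)   # sum of per-rank losses
    if dist.get_rank() == 0:
        print(f"step {step}: mean loss {buf.item() / world_size:.4f}")

Since the entry point reads WORLD_SIZE and LOCAL_RANK from the environment and calls dist.init_process_group("nccl") itself, the script is presumably meant to be started with a launcher that sets those variables, e.g. torchrun --nproc_per_node=<num_gpus> <script>.py.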