Commit c58dfef8 authored by FIRST_NAME LAST_NAME

push

parent a28f0299
......@@ -2,7 +2,7 @@ from basedformer import utils
from basedformer import models
import math
import torch
from torch import nn
from torch import nn, distributed
import os
import json
from dataclasses import dataclass
......@@ -36,14 +36,41 @@ def no_init(config):
model = utils.no_init(lambda: model_class(config))
return model
def save(model, path):
try: os.mkdir(path)
except: pass
def serialize_config(config):
serialized_dict = {
"model_class": "gptj",
"model_path": ".",
'model_config': {
'n_layer': config.n_layer,
'n_head': config.n_head,
'n_tokens': config.n_tokens,
'hidden_dim': config.hidden_dim,
'vocab_dim': config.vocab_dim,
'eps': config.eps,
}
}
return serialized_dict
def save(model, path, save_fp16=True):
if distributed.is_initialized() and distributed.get_rank() != 0:
return
if save_fp16:
model = model.half()
path = Path(path)
lm_path = path / "lm"
#make folder
lm_path.mkdir(parents=True, exist_ok=True)
checkpoint = {}
for i, x in enumerate(model.state_dict().items()):
checkpoint[x[0]] = f"{path}/b{i}.pt"
torch.save(x[1], f"{path}/b{i}.pt")
torch.save(checkpoint, f"{path}/m.pt")
checkpoint[x[0]] = lm_path / f"b{i}.pt"
torch.save(x[1], lm_path / f"b{i}.pt")
torch.save(checkpoint, lm_path / "m.pt")
#write model.config to config.json inside path
with open(path / "config.json", "w") as f:
json.dump(serialize_config(model.config), f)
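Note: a minimal sketch (checkpoint path hypothetical) of what the reworked save() above leaves on disk and how to inspect it: one b{i}.pt shard per state_dict entry under lm/, an m.pt index mapping parameter names to shard paths, and a config.json written from serialize_config.

import json
import torch
from pathlib import Path

ckpt = Path("models/example-checkpoint")      # hypothetical folder passed to save(model, path)
index = torch.load(ckpt / "lm" / "m.pt")      # dict: parameter name -> shard path
name, shard_path = next(iter(index.items()))
tensor = torch.load(shard_path)               # one tensor per file, fp16 when save_fp16=True
config = json.loads((ckpt / "config.json").read_text())
print(name, tuple(tensor.shape), config["model_class"])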
def load_from_path(config_folder=None, strict=False):
config_folder = Path(config_folder)
......@@ -51,13 +78,12 @@ def load_from_path(config_folder=None, strict=False):
model_class = models.get_model(config["model_class"])
model_path = config["model_path"]
model_config = config["model_config"]
print(model_config)
if model_path == ".":
# model_path is the config_folder directory.
model_path = config_folder
model_path = Path(model_path) / "lm"
model_path = str(Path(model_path) / "lm")
model = _load_dict_model(model_class, model_config, model_path, strict=strict)
return model
......
......@@ -44,7 +44,7 @@ class BaseModel(nn.Module):
full_config = DotMap(full_config)
return full_config
def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
def forward_with_hidden_states(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
x = self.lm_head(x)
if target:
......@@ -64,6 +64,26 @@ class BaseModel(nn.Module):
else:
return x.float()
def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
hidden_states, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
x = self.lm_head(hidden_states)
if target:
logits = x.view(-1, x.shape[-1])
labels = target.view(-1)
loss = F.cross_entropy(logits, labels)
#clean this mess later
if cache:
if target:
return loss, x.float(), kv
else:
return x.float(), kv
else:
if target:
return loss, x.float()
else:
return x.float(), hidden_states
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
kv = [None] * self.n_layer
......
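Note: a minimal usage sketch of the new forward() above (checkpoint path and shapes are illustrative): with no target and cache=False it now returns the logits together with the final hidden states, which is exactly what the contrastive path in finetune.py unpacks.

from basedformer import lm_utils
import torch

model = lm_utils.load_from_path("models/example-checkpoint").float().cuda()  # hypothetical path
input_ids = torch.randint(0, model.config.vocab_dim, (1, 2048)).cuda()
logits, hidden_states = model(input_ids)
# logits: [1, 2048, vocab_dim] as float32; hidden_states: [1, 2048, hidden_dim]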
......@@ -98,7 +98,7 @@ class SplitCheckpoint(MutableMapping):
def __len__(self):
return len(self.checkpoint)
def __getitem__(self, key):
name = self.checkpoint[key]
name = str(self.checkpoint[key])  # m.pt may now store pathlib.Path values (see save() above); normalize to str for the path handling below
if type(name) is tuple:
return self._load(name[0].split('/')[-1], name[1], map_location=self.device)
else:
......
import os
import torch
from dotmap import DotMap
from finetune import main
if __name__ == "__main__":
train_config = {
"data_path": "dataset/sigurd-1G.map",
"save_path": "models/gptj-sigurd-1G-contrastive-0.3weight",
"do_save": True,
"run_name": "gptj-sigurd-1G-contrastive0.3weight",
"lr": 6e-5,
"end_lr": 3e-5,
"warmup_steps": 100,
"anneal_steps": 7850,
"bs": 2,
"gas": 2,
"seed": 69,
"save_every": 500,
"amp": True,
"loss_scale": True,
"cast_to": torch.float16,
"contrastive_loss": 0.3,
}
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
main(rank, global_rank, world_size, DotMap(train_config))
\ No newline at end of file
from basedformer import lm_utils as lmu
import torch
model = lmu.load_from_path("models/gptj-sigurd-1G-vanilla/final")
lmu.save(model, "models/gptj-sigurd-1G-vanilla/final_fp16")
\ No newline at end of file
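The converted checkpoint can then be loaded again the same way; a one-line sketch, assuming the default save_fp16=True above so the shards on disk are fp16:

from basedformer import lm_utils as lmu

model_fp16 = lmu.load_from_path("models/gptj-sigurd-1G-vanilla/final_fp16")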
NCCL_DEBUG=INFO torchrun \
--nnodes=2 \
--nproc_per_node=8 \
--rdzv_endpoint=10.0.155.233:29300 \
finetune.py
\ No newline at end of file
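For reference, a sketch of how the environment torchrun provides for the command above maps onto the variables read at the bottom of the training scripts (the concrete values are assumed for this 2-node x 8-GPU layout, not stated in the source):

import os

# On the first GPU of the second node this would be roughly:
#   WORLD_SIZE=16, RANK=8, LOCAL_RANK=0
world_size = int(os.environ["WORLD_SIZE"])   # total processes across both nodes (2 x 8)
global_rank = int(os.environ["RANK"])        # unique rank across the whole job
local_rank = int(os.environ["LOCAL_RANK"])   # GPU index on this node, passed to torch.cuda.set_device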
......@@ -18,6 +18,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from dotmap import DotMap
import argparse
from torch.distributed.fsdp import (
FullyShardedDataParallel,
......@@ -51,19 +52,20 @@ def get_world():
def get_flops(args, model, iter_time_s):
ff = model.total_params * 6
attn = 2048 * args.hidden_size * args.n_layers * 60
attn = 2048 * model.config.hidden_dim * model.config.n_layer * 60
flops = (
args.bs * args.gas
* 2048
* (ff + attn)
/ (iter_time_s)
)
return flops
return flops / 1e12
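As a worked example of the estimate in get_flops() above (model numbers are assumed, roughly GPT-J-6B: about 6.05e9 parameters, hidden_dim 4096, n_layer 28, with bs=2, gas=2 and 5 s per step):

ff = 6.05e9 * 6                                  # parameter term of the heuristic, ~3.6e10
attn = 2048 * 4096 * 28 * 60                     # attention term, ~1.4e10
tflops = 2 * 2 * 2048 * (ff + attn) / 5 / 1e12   # ~83 TFLOPS per process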
def fsdp_train(args, model, train_loader, opt):
bs = args["bs"]
gas = args["gas"]
rank = get_rank()
global_rank = get_rank()
rank = int(os.environ["LOCAL_RANK"])
world_size = get_world()
model.train()
ddp_loss = torch.zeros(1).cuda()
......@@ -73,19 +75,28 @@ def fsdp_train(args, model, train_loader, opt):
t = train_loader
scaler = torch.cuda.amp.GradScaler()
counter = 0
for input_ids, labels in t:
timex = time.perf_counter()
input_ids = input_ids.to(rank)
labels = labels.to(rank)
loss = 0
for x in range(args["gas"]):
with torch.cuda.amp.autocast(enabled=args["amp"], dtype=torch.float16):
logits = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)
with torch.cuda.amp.autocast(enabled=args["amp"], dtype=args["cast_to"]):
logits, hidden_states = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :2048].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
if args.contrastive_loss:
#print("contrastive enabled")
with torch.no_grad():
max = hidden_states.abs().amax().detach()
hs = hidden_states.div(max)
norm = hs.norm(dim=-1, keepdim=True)
norm = norm.matmul(norm.transpose(-1,-2))
contrastive_loss = torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()
gas_loss += contrastive_loss * args.contrastive_loss
if args["loss_scale"]:
scaler.scale(gas_loss).backward()
else:
......@@ -108,12 +119,13 @@ def fsdp_train(args, model, train_loader, opt):
#opt.zero_grad()
model.zero_grad(set_to_none=True)
sec_per_step = (time.perf_counter() - timex)
flops = get_flops(args, model.module, sec_per_step)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
batch_size = bs * gas * world_size
ddp_loss[0] = loss
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
if global_rank == 0:
wandb.log({
"train_loss": ddp_loss[0] / world_size,
......@@ -122,57 +134,64 @@ def fsdp_train(args, model, train_loader, opt):
"train/step_per_sec": step_per_sec,
"train/lr": opt.curr_lr,
"train/batch_size": batch_size,
"train/loss_scale": scaler.get_scale()
"train/loss_scale": scaler.get_scale(),
"train/flops": flops,
})
if counter != 0 and counter % args["save_every"] == 0:
if global_rank == 0:
lm_utils.save(model.module, Path(args["save_path"]) / f"step_{str(counter)}")
dist.barrier()
counter += 1
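Restated as a standalone helper, a sketch of the contrastive term added to fsdp_train above (shapes assumed): it is the mean absolute pairwise cosine similarity between token hidden states, so minimizing it pushes hidden states at different positions toward orthogonality (the diagonal self-similarity terms are included in the mean).

import torch

def pairwise_cosine_penalty(hidden_states: torch.Tensor) -> torch.Tensor:
    # hidden_states: [batch, seq, hidden_dim]
    with torch.no_grad():
        scale = hidden_states.abs().amax()       # global max, used only to keep values in range
    hs = hidden_states / scale
    norms = hs.norm(dim=-1, keepdim=True)        # [batch, seq, 1]
    denom = norms @ norms.transpose(-1, -2)      # [batch, seq, seq], outer product of norms
    cos = (hs @ hs.transpose(-2, -1)) / denom    # pairwise cosine similarities
    return cos.abs().mean()

# used as: gas_loss += pairwise_cosine_penalty(hidden_states) * args.contrastive_loss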
# we need 250 batch size to train the small GPT.
def main(rank, world_size, args):
def main(rank, global_rank, world_size, args):
bs = args["bs"]
gas = args["gas"]
torch.manual_seed(train_config["seed"])
torch.manual_seed(args["seed"])
setup(rank, world_size)
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
Path(args["save_path"]).mkdir(parents=True, exist_ok=True)
model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
#fsdp_model = FullyShardedDataParallel(
#model,
#fsdp_auto_wrap_policy=default_auto_wrap_policy,
#cpu_offload=CPUOffload(offload_params=True),
#)
fsdp_model = DDP(model, device_ids=[rank], gradient_as_bucket_view=True)
fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
utils.print_parameters(fsdp_model)
ic("model loaded")
opt = optimizer.BasedOptimizer(fsdp_model.parameters(), train_config, "zero1")
opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero1")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = utils.ShardedDataset(2049, train_config["data_path"], world_size=world_size, rank=rank)
train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
if rank == 0:
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})
if global_rank == 0:
wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
fsdp_train(args, fsdp_model, train_loader, opt)
lm_utils.save(fsdp_model.module, Path(args["save_path"]) / "final")
dist.barrier()
cleanup()
if __name__ == "__main__":
train_config = {
"data_path": "dataset/sigurd-1G.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-sigurd-1G-vanilla",
"do_save": False,
"save_path": "models/gptj-sigurd-1G-vanilla",
"do_save": True,
"run_name": "gptj-sigurd-1G-vanilla",
"lr": 6e-5,
"end_lr": 2e-5,
"end_lr": 3e-5,
"warmup_steps": 100,
"anneal_steps": 10000,
"anneal_steps": 7850,
"bs": 2,
"gas": 1,
"gas": 2,
"seed": 69,
"save_every": 500,
"amp": True,
"loss_scale": True,
"cast_to": torch.float16,
"contrastive_loss": False,
}
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
main(rank, world_size, train_config)
\ No newline at end of file
main(rank, global_rank, world_size, DotMap(train_config))
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from torch.utils import data
from basedformer import optimizer, utils, lm_utils
import yaml
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
import os
from icecream import ic
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from dotmap import DotMap
import argparse
from torch.distributed.fsdp import (
FullyShardedDataParallel,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
default_auto_wrap_policy,
)
def setup(rank, world_size):
#os.environ['MASTER_ADDR'] = 'localhost'
#os.environ['MASTER_PORT'] = '12355'
# initialize the process group
dist.init_process_group(backend="nccl")
if dist.is_initialized():
print("Initialized process group")
else:
print("Failed to initialize process group")
def cleanup():
dist.destroy_process_group()
def get_rank():
if dist.is_initialized():
return dist.get_rank()
def get_world():
if dist.is_initialized():
return dist.get_world_size()
def get_flops(args, model, iter_time_s):
ff = model.total_params * 6
attn = 2048 * model.config.hidden_dim * model.config.n_layer * 60
flops = (
args.bs * args.gas
* 2048
* (ff + attn)
/ (iter_time_s)
)
return flops / 1e12
def fsdp_train(args, model, train_loader, opt):
bs = args["bs"]
gas = args["gas"]
rank = get_rank()
world_size = get_world()
model.train()
ddp_loss = torch.zeros(1).cuda()
if rank == 0:
t = tqdm(train_loader)
else:
t = train_loader
scaler = torch.cuda.amp.GradScaler()
for input_ids, labels in t:
timex = time.perf_counter()
input_ids = input_ids.to(rank)
labels = labels.to(rank)
loss = 0
for x in range(args["gas"]):
with torch.cuda.amp.autocast(enabled=args["amp"], dtype=args["cast_to"]):
logits = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :2048].contiguous()
gas_labels = gas_labels.view(-1)
gas_loss = F.cross_entropy(logits, gas_labels)
if args["loss_scale"]:
scaler.scale(gas_loss).backward()
else:
gas_loss.backward()
loss += gas_loss.item()
loss = loss / gas
if args["loss_scale"]:
scaler.unscale_(opt.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
if args["loss_scale"]:
opt.step(scaler=scaler)
else:
opt.step()
if args["loss_scale"]:
scaler.update()
#opt.zero_grad()
model.zero_grad(set_to_none=True)
sec_per_step = (time.perf_counter() - timex)
flops = get_flops(args, model.module, sec_per_step)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
batch_size = bs * gas * world_size
ddp_loss[0] = loss
dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
if rank == 0:
wandb.log({
"train_loss": ddp_loss[0] / world_size,
"train/tokens_per_sec": tokens_per_sec,
"train/sec_per_step": sec_per_step,
"train/step_per_sec": step_per_sec,
"train/lr": opt.curr_lr,
"train/batch_size": batch_size,
"train/loss_scale": scaler.get_scale(),
"train/flops": flops,
})
# we need 250 batch size to train the small GPT.
def main(rank, global_rank, world_size, args):
bs = args["bs"]
gas = args["gas"]
torch.manual_seed(train_config["seed"])
setup(rank, world_size)
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
utils.print_parameters(fsdp_model)
ic("model loaded")
opt = optimizer.BasedOptimizer(fsdp_model.parameters(), train_config, "zero1")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = utils.ShardedDataset(2049, train_config["data_path"], world_size=world_size, rank=global_rank)
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
if rank == 0:
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})
fsdp_train(args, fsdp_model, train_loader, opt)
dist.barrier()
if rank == 0:
lm_utils.save(fsdp_model.module, Path(train_config["save_path"]) / "final")
cleanup()
if __name__ == "__main__":
train_config = {
"data_path": "dataset/sigurd-1G.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-sigurd-1G-vanilla",
"do_save": False,
"run_name": "gptj-sigurd-1G-vanilla",
"lr": 6e-5,
"end_lr": 3e-5,
"warmup_steps": 100,
"anneal_steps": 7861,
"bs": 2,
"gas": 2,
"seed": 69,
"save_every": 500,
"amp": True,
"loss_scale": False,
"cast_to": torch.bfloat16,
}
world_size = int(os.environ["WORLD_SIZE"])
rank = int(os.environ["LOCAL_RANK"])
global_rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)
main(rank, global_rank, world_size, DotMap(train_config))
\ No newline at end of file