Commit fb134d28 authored by novelailab

zero2 works

parent 9d27a5cc
......@@ -95,15 +95,18 @@ class SelfAttention(nn.Module):
sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
self.fused_softmax = FusedScaleMaskSoftmax(
input_in_fp16=False,
input_in_bf16=True,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type="causal",
scaled_masked_softmax_fusion=True,
)
if self.config.masked_softmax_fusion:
self.fused_softmax = FusedScaleMaskSoftmax(
input_in_fp16=False,
input_in_bf16=True,
mask_func=attention_mask_func,
scale=None,
softmax_in_fp32=False,
attn_mask_type="causal",
scaled_masked_softmax_fusion=True,
)
else:
self.fused_softmax = None
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
......@@ -242,7 +245,7 @@ class GPTJModel(base_lm.BaseModel):
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
'masked_softmax_fusion': False,
'masked_softmax_fusion': True,
}
base_lm.BaseModel.__init__(self, user_config, **kwargs)
if self.config.masked_softmax_fusion:
......
......@@ -77,6 +77,10 @@ class BasedOptimizer:
eps=self.eps,
)
elif self.optimizer_name == "zero2":
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
self.optimizer = DistributedFusedAdam(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps, grad_sync_dtype=torch.float32)
elif self.optimizer_name == "adafactor":
try:
from transformers.optimization import Adafactor
......
......@@ -5,7 +5,7 @@ import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from torch.utils import data
from basedformer import optimizer, utils, lm_utils
from basedformer import optimizer, utils, lm_utils, dataset
import yaml
import sys
from tqdm import tqdm
......@@ -16,17 +16,11 @@ import os
from icecream import ic
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
#from torch.nn.parallel import DistributedDataParallel as DDP
from apex.parallel.distributed import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from dotmap import DotMap
import argparse
from torch.distributed.fsdp import (
FullyShardedDataParallel,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
default_auto_wrap_policy,
)
def setup(rank, world_size):
#os.environ['MASTER_ADDR'] = 'localhost'
......@@ -97,14 +91,19 @@ def fsdp_train(args, model, train_loader, opt):
norm = norm.matmul(norm.transpose(-1,-2))
contrastive_loss = torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()
gas_loss += contrastive_loss * args.contrastive_loss
if args["loss_scale"]:
scaler.scale(gas_loss).backward()
with opt.optimizer.no_sync():
scaler.scale(gas_loss).backward()
else:
gas_loss.backward()
with opt.optimizer.no_sync():
gas_loss.backward()
loss += gas_loss.item()
loss = loss / gas
opt.optimizer.grad_sync()
if args["loss_scale"]:
scaler.unscale_(opt.optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
......@@ -116,10 +115,10 @@ def fsdp_train(args, model, train_loader, opt):
if args["loss_scale"]:
scaler.update()
#opt.zero_grad()
model.zero_grad(set_to_none=True)
opt.zero_grad()
#model.zero_grad(set_to_none=True)
sec_per_step = (time.perf_counter() - timex)
flops = get_flops(args, model.module, sec_per_step)
flops = get_flops(args, model, sec_per_step)
step_per_sec = (1. / sec_per_step)
tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
batch_size = bs * gas * world_size
......@@ -153,15 +152,17 @@ def main(rank, global_rank, world_size, args):
setup(rank, world_size)
Path(args["save_path"]).mkdir(parents=True, exist_ok=True)
model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/gpt-j-base").half().to(rank)
#fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
#fsdp_model = DDP(model)
fsdp_model = model
utils.print_parameters(fsdp_model)
ic("model loaded")
opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero1")
opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero2")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
train_dataset = dataset.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
if global_rank == 0:
wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
......@@ -172,21 +173,21 @@ def main(rank, global_rank, world_size, args):
if __name__ == "__main__":
train_config = {
"data_path": "dataset/sigurd-1G.map",
"data_path": "/home/xuser/nvme1/dataset/sigurd-1G.map",
"save_path": "models/gptj-sigurd-1G-vanilla",
"do_save": True,
"do_save": False,
"run_name": "gptj-sigurd-1G-vanilla",
"lr": 6e-5,
"end_lr": 3e-5,
"warmup_steps": 100,
"anneal_steps": 7850,
"bs": 2,
"gas": 2,
"gas": 8,
"seed": 69,
"save_every": 500,
"amp": True,
"amp": False,
"loss_scale": True,
"cast_to": torch.float16,
"cast_to": torch.bfloat16,
"contrastive_loss": False,
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment