Commit a28f0299 authored by FIRST_NAME LAST_NAME

push

parent a2b7dffb
@@ -17,6 +17,7 @@ class BaseModel(nn.Module):
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
        self.total_params = sum(p.numel() for p in self.parameters())
        for i in range(config.n_layer):
            config.layer_idx = i
            self.layers.append(
......
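Note on the new total_params line: as written in the hunk above, the sum runs before the layer loop appends the transformer blocks, so it reflects only the parameters registered up to that point (e.g. ln_final and lm_head), not the full stack. A minimal sketch of reading the attribute after construction, assuming a config object with the fields used above (this snippet is not part of the commit):

model = BaseModel(config)  # config is a placeholder assumption
print(f"registered parameters at init time: {model.total_params / 1e6:.2f}M")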
@@ -6,6 +6,7 @@ from dotmap import DotMap
import pickle
import os
from pathlib import Path
from torch.distributed.optim import ZeroRedundancyOptimizer
#Based Optimizer
def lr_schedule(step, warmup_steps, anneal_steps, lr, end_lr, cosine_warmup=False):
@@ -61,6 +62,17 @@ class BasedOptimizer:
            import bitsandbytes as bnb
            self.optimizer = bnb.optim.Adam8bit(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps)
        elif self.optimizer_name == "zero1":
            import bitsandbytes as bnb
            self.optimizer = ZeroRedundancyOptimizer(
                self.parameters,
                optimizer_class=bnb.optim.Adam8bit,
                lr=0,
                weight_decay=self.weight_decay,
                betas=(self.beta1, self.beta2),
                eps=self.eps,
            )
        elif self.optimizer_name == "adafactor":
            try:
                from transformers.optimization import Adafactor
......
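The new "zero1" branch wraps bitsandbytes' 8-bit Adam in torch's ZeroRedundancyOptimizer, so optimizer state is sharded across data-parallel ranks (ZeRO stage 1) while each rank still steps a local Adam8bit over its shard. A minimal standalone sketch of the same combination, assuming a job launched under torchrun and a toy model; it is not the repo's training loop, and the lr here is a placeholder (the commit passes lr=0 and lets the schedule set it):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.optim import ZeroRedundancyOptimizer
import bitsandbytes as bnb

dist.init_process_group(backend="nccl")  # state sharding needs an initialized process group
model = nn.Linear(1024, 1024).cuda()
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=bnb.optim.Adam8bit,  # remaining kwargs are forwarded to Adam8bit
    lr=1e-4,
    weight_decay=0.01,
    betas=(0.9, 0.95),
    eps=1e-8,
)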
@@ -30,6 +30,24 @@ class FbDataset(data.Dataset):
        data = torch.tensor(self.npz[nth].astype(np.int64))
        return (data[:-1], data[1:])

class ShardedDataset(data.Dataset):
    def __init__(self, block_size, map_file, world_size=1, rank=0, skip=0):
        self.npz = np.memmap(map_file, mode="r", dtype="uint16").reshape((-1, block_size))
        #might want to pad later
        self.npz = self.npz[:self.npz.shape[0] - (self.npz.shape[0] % world_size)]
        #shard
        self.npz = self.npz[rank::world_size]
        self.samples = self.npz.shape[0]
        self.skip = skip

    def __len__(self):
        return self.samples

    def __getitem__(self, _id):
        nth = _id + self.skip
        data = torch.tensor(self.npz[nth].astype(np.int64))
        return (data[:-1], data[1:])
# Make loading models faster by not letting pytorch initialize the weights.
# Usage: no_init(lambda: load_model(...))
......
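ShardedDataset mirrors FbDataset but memory-maps a flat uint16 token file, reshapes it into rows of block_size tokens, truncates to a multiple of world_size, and keeps every world_size-th row starting at rank, so each rank reads a disjoint, equally sized slice; skip offsets indexing, e.g. when resuming mid-epoch. A minimal per-rank usage sketch, where the file path, block size, and batch size are placeholder assumptions rather than the repo's settings:

import torch.distributed as dist
from torch.utils import data

rank, world_size = dist.get_rank(), dist.get_world_size()
train_set = ShardedDataset(block_size=2049, map_file="train.map",
                           world_size=world_size, rank=rank, skip=0)
loader = data.DataLoader(train_set, batch_size=4, shuffle=False, pin_memory=True)
for inputs, targets in loader:
    pass  # each pair is one block shifted by one token: (data[:-1], data[1:])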