Commit 24459438 authored by novelailab

fairseq works, start neo impl

parent 92b7b187
@@ -11,39 +11,20 @@ except ImportError:
import os
from pathlib import Path
import math
from basedformer import lm_base
from basedformer.models import base_lm
from typing import Optional, Any
def shift_tokens(x, amt, eps = 1e-5):
n, device = x.shape[1], x.device
cumsum = x.cumsum(dim = 1)
*x, x_pass = x.chunk(amt + 1, dim = -1)
*x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
amts = 2 ** torch.arange(amt)
amts = amts.tolist()
shifts = []
denom = torch.arange(n, device = device)
for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
normed_shifted_x = shifted_chunk / (shifted_denom + eps)
shifts.append(normed_shifted_x)
return torch.cat((*shifts, x_pass), dim = -1)
def shift(x, amt, dim = -1):
return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
def _attn(query, key, value, causal_mask, masked_bias,
attention_mask=None, scale_attn=None):
attention_mask=None, scale_attn=None, fp32_attn=True):
if fp32_attn:
attn_weights = torch.matmul(query.float(), key.transpose(-1, -2).float())
else:
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
attn_weights = attn_weights / scale_attn
attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
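A minimal standalone sketch of the new fp32_attn branch above (plain PyTorch, made-up shapes; attn_weights here is a hypothetical helper, not the repo's _attn): the q·k^T product is taken in float32 even when activations are fp16, which is how fairseq-trained checkpoints expect attention to behave.

import torch

def attn_weights(query, key, causal_mask, masked_bias, scale_attn, fp32_attn=True):
    if fp32_attn:
        # do the matmul in full precision, as in the diff above
        w = torch.matmul(query.float(), key.transpose(-1, -2).float())
    else:
        w = torch.matmul(query, key.transpose(-1, -2))
    w = torch.where(causal_mask, w, masked_bias.to(w.dtype))
    return w / scale_attn.to(w.dtype)

q = torch.randn(1, 2, 4, 8, dtype=torch.float16)   # (batch, heads, seq, head_dim)
k = torch.randn(1, 2, 4, 8, dtype=torch.float16)
mask = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)
w = attn_weights(q, k, mask, torch.tensor(-1e9), torch.tensor(8 ** 0.5))
print(w.dtype)  # torch.float32 when fp32_attn=True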
@@ -57,19 +38,34 @@ def _attn(query, key, value, causal_mask, masked_bias,
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, hidden_dim, n_head, device, dtype):
def __init__(self, config, attention_type):
nn.Module.__init__(self)
self.config = config
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
self.head_dim = hidden_dim // n_head
if attention_type == "local":
self.register_buffer(
"bias",
bias ^ torch.tril(bias, -config.window_size),
)
else:
self.register_buffer(
"bias",
bias,
)
self.head_dim = config.hidden_dim // config.n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = hidden_dim
self.n_head = n_head
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
device = config.device
dtype = config.dtype
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.register_buffer("bias", bias)
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
attn_bias = False
attn_bias = True #fairseq has attn_bias
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
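A quick sketch (illustrative values, not repo code) of what the "local" bias above comes out to: XOR-ing the causal mask with a copy shifted down by window_size leaves a banded mask where each position sees at most the last window_size tokens, itself included.

import torch

max_positions, window_size = 6, 3
bias = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.uint8)).bool()
local_bias = bias ^ torch.tril(bias, -window_size)
print(local_bias.long())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0],
#         [0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 1, 1, 1]])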
@@ -93,7 +89,7 @@ class SelfAttention(nn.Module):
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
x = _attn(
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
query, key, value, causal_mask, self.masked_bias, None, self.scale_attn, self.config.fp32_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
@@ -101,14 +97,14 @@ class SelfAttention(nn.Module):
if cache:
return x, (key, value)
else:
return x
return x, None
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
def __init__(self, config):
nn.Module.__init__(self)
self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
self.activation = activation
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim * 4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim * 4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x, act_ck=False):
x = self.ff1(x)
@@ -120,70 +116,104 @@ class FeedForward(nn.Module):
return x
class GPTNeoLayer(nn.Module):
def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
def __init__(self, attn, ff, config, layer_idx):
nn.Module.__init__(self)
self.hidden_dim = hidden_dim
self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
self.hidden_dim = config.hidden_dim
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ff = ff(config)
if layer_idx % 2 == 0:
attn_type = "global"
else:
attn_type = "local"
self.attn = attn(config, attn_type)
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, cache=False, kv=None):
residual = x
if act_ck:
x = ck(self.ln_preattn, x)
attn_out = ck(self.attn, x)
attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
else:
x = self.ln_preattn(x)
attn_out = self.attn(x)
attn_out, kv = self.attn(x, kv=kv, cache=cache)
residual = residual + attn_out
x = residual + attn_out
residual = x
x = self.ln_postattn(x)
ff_out = self.ff(x, act_ck)
x = residual + ff_out
return x
return x, kv
class GPTNeoModel(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
self.default_config = {
'n_layer': 6,
'n_head': 8,
'n_tokens': 2049,
'hidden_dim': 512,
'vocab_dim': 50400,
'fp32_attn': True, #fairseq models are trained with fp32 attn
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': GPTNeoLayer,
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
'window_size': 256,
}
def __init__(self, user_config, **kwargs):
nn.Module.__init__(self)
#configuration
self.user_config = user_config
self.config = self.configure_model()
config = self.config
#modeling
self.n_layer = config.n_layer
self.hidden_dim = config.hidden_dim
self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
for _ in range(config.n_layer):
self.layers.append(
config.Layer(
attn=config.SelfAttention,
ff=config.FeedForward,
config=config,
)
)
# returns sinusoidal embeddings of shape: (1, n_tokens, 768)
self.register_buffer("embed_scale", torch.sqrt(torch.tensor(self.config.hidden_dim, requires_grad=False)))
self.pos_embed = nn.Embedding(self.config.n_tokens, self.config.hidden_dim)
self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False)
#bias=False for fairseq models
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
kv = [None] * self.n_layer
past_length = 0
else:
past_length = kv[0][0].size(-2) #get sequence dim of key
kv_new = []
position_ids = torch.arange(past_length, x[-1] + past_length, dtype=torch.long, device=x.device)
position_ids = position_ids.unsqueeze(0).view(-1, x[-1])
class GPTNeoModel(nn.Module):
def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoLayer, device="cuda", dtype=torch.float16, **kwargs):
nn.Module.__init__(self)
self.n_layer = n_layer
self.hidden_dim = hidden_dim
self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
def forward(self, x, hypernetwork=None, act_ck=False):
x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.lm_head(x)
return x.float()
def get_embeds(self, x, hypernetwork=None, act_ck=False):
x = self.vocab_embed(x)
x = x + self.pos_embed(position_ids)
for layer_id, layer in enumerate(self.layers):
x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.ln_final(x)
return x
x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
kv_new.append(kvi)
class GPTNeoBaseLM(lm_base.BaseLM):
def __init__(self, config=None, lm=None):
nn.Module.__init__(self)
lm_base.BaseLM.__init__(self, config, lm)
self.model_class=GPTNeoModel
def load_gpt_j(path="models/6b", state_dict=None):
config = {
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5
}
model = GPTNeoBaseLM.load(config, path, state_dict)
return model
x = self.ln_final(x)
if cache:
return x, kv_new
else:
return x, None
\ No newline at end of file
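A self-contained sketch of the kv-cache bookkeeping the new get_embeds / SelfAttention path implements (the step helper, shapes, and loop here are illustrative assumptions, not code from this commit): cached keys and values grow along the sequence dimension, and position ids for incoming tokens are offset by the cached length, mirroring past_length = kv[0][0].size(-2) above.

import torch

def step(key_cache, value_cache, new_key, new_value):
    # append the new timestep(s) to the cache along the sequence dim (-2)
    if key_cache is None:
        return new_key, new_value
    return (torch.cat([key_cache, new_key], dim=-2),
            torch.cat([value_cache, new_value], dim=-2))

B, n_head, head_dim = 1, 8, 64
key_cache = value_cache = None
for new_len in (8, 1, 1):  # an 8-token prompt, then two single-token decode steps
    past_length = 0 if key_cache is None else key_cache.size(-2)
    # position ids for the new tokens start where the cache left off
    position_ids = torch.arange(past_length, past_length + new_len).unsqueeze(0)
    new_key = torch.randn(B, n_head, new_len, head_dim)
    new_value = torch.randn(B, n_head, new_len, head_dim)
    key_cache, value_cache = step(key_cache, value_cache, new_key, new_value)
    print(position_ids.tolist(), tuple(key_cache.shape))  # cache length grows: 8, 9, 10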