Commit aa6444e9 authored by novelailab

gpt neo working

parent 88f0a90e
This source diff could not be displayed because it is too large.
from . import gptj
from . import gpt2
from . import fairseq
from . import gptneo
MODEL_MAP = {
"gptj": gptj.GPTJModel,
"gpt2": gpt2.GPT2Model,
"gpt-fairseq": fairseq.GPTFairModel
"gpt-fairseq": fairseq.GPTFairModel,
"gpt-neo": gptneo.GPTNeoModel
}
def get_model(model_name: str):
......
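# Example (assumption, not part of the commit): with the "gpt-neo" entry registered
# above, the class can be resolved by name through MODEL_MAP; get_model's body is
# elided here, but presumably it performs this lookup.
from basedformer import models

model_cls = models.MODEL_MAP["gpt-neo"]   # -> gptneo.GPTNeoModel
# likely equivalent to: models.get_model("gpt-neo")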
......@@ -18,12 +18,12 @@ class BaseModel(nn.Module):
self.layers = nn.ModuleList([])
self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
for i in range(config.n_layer):
+config.layer_idx = i
self.layers.append(
config.Layer(
attn=config.SelfAttention,
ff=config.FeedForward,
config=config,
-layer_idx=i,
)
)
......
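# Minimal sketch (assumption, not from the diff) of the new construction contract in
# base_lm.BaseModel: the per-layer index is stamped onto the shared config before each
# layer is built, so layer classes read config.layer_idx instead of a layer_idx argument.
class _ToyLayer:
    def __init__(self, attn, ff, config):
        self.layer_idx = config.layer_idx      # read the stamped index at build time

class _ToyConfig:
    n_layer = 4
    Layer = _ToyLayer

_cfg = _ToyConfig()
_layers = []
for _i in range(_cfg.n_layer):
    _cfg.layer_idx = _i                        # mutate the shared config per layer
    _layers.append(_cfg.Layer(attn=None, ff=None, config=_cfg))
assert [l.layer_idx for l in _layers] == [0, 1, 2, 3]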
......@@ -13,6 +13,7 @@ from pathlib import Path
import math
from basedformer.models import base_lm
from typing import Optional, Any
+from icecream import ic
def _attn(query, key, value, causal_mask, masked_bias,
......@@ -24,7 +25,8 @@ def _attn(query, key, value, causal_mask, masked_bias,
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
-attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
+if scale_attn:
+attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
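# Toy illustration (assumption, not part of the commit) of the masking/optional-scaling
# path above: GPT-Neo sets self.scale_attn = None, so the division is skipped, while
# models that keep a scale buffer divide by sqrt(head_dim).
import torch

_q = torch.randn(1, 1, 4, 8)                   # (batch, heads, seq, head_dim)
_k = torch.randn(1, 1, 4, 8)
_causal = torch.tril(torch.ones(4, 4, dtype=torch.bool)).view(1, 1, 4, 4)
_masked_bias = torch.tensor(-1e9)
_scale = None                                  # e.g. torch.sqrt(torch.tensor(8.0)) for scaled variants

_w = torch.matmul(_q, _k.transpose(-1, -2))
_w = torch.where(_causal, _w, _masked_bias.to(_w.dtype))
if _scale is not None:
    _w = _w / _scale.to(_w.dtype)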
......@@ -38,14 +40,15 @@ def _attn(query, key, value, causal_mask, masked_bias,
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
-def __init__(self, config, attention_type):
+def __init__(self, config, attn_type):
+ic(attn_type)
nn.Module.__init__(self)
self.config = config
max_positions = 2049
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
1, 1, max_positions, max_positions).bool()
-if attention_type == "local":
+if attn_type == "local":
self.register_buffer(
"bias",
bias ^ torch.tril(bias, -config.window_size),
......@@ -63,13 +66,13 @@ class SelfAttention(nn.Module):
device = config.device
dtype = config.dtype
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
self.scale_attn = None
self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
-attn_bias = True #fairseq has attn_bias
+attn_bias = False #fairseq has attn_bias
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=True, device=device, dtype=dtype)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
......@@ -116,13 +119,13 @@ class FeedForward(nn.Module):
return x
class GPTNeoLayer(nn.Module):
-def __init__(self, attn, ff, config, layer_idx):
+def __init__(self, attn, ff, config):
nn.Module.__init__(self)
self.hidden_dim = config.hidden_dim
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ff = ff(config)
-if layer_idx % 2 == 0:
+if config.layer_idx % 2 == 0:
attn_type = "global"
else:
attn_type = "local"
......@@ -154,8 +157,8 @@ class GPTNeoModel(base_lm.BaseModel):
'n_head': 8,
'n_tokens': 2049,
'hidden_dim': 512,
-'vocab_dim': 50400,
-'fp32_attn': True, #fairseq models are trained with fp32 attn
+'vocab_dim': 50257,
+'fp32_attn': False,
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
......@@ -165,29 +168,7 @@ class GPTNeoModel(base_lm.BaseModel):
'FeedForward': FeedForward,
'window_size': 256,
}
def __init__(self, user_config, **kwargs):
-nn.Module.__init__(self)
-#configuration
-self.user_config = user_config
-self.config = self.configure_model()
-config = self.config
-#modeling
-self.n_layer = config.n_layer
-self.hidden_dim = config.hidden_dim
-self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
-self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
-self.layers = nn.ModuleList([])
-self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
-for _ in range(config.n_layer):
-self.layers.append(
-config.Layer(
-attn=config.SelfAttention,
-ff=config.FeedForward,
-config=config,
-)
-)
-# returns sinusoidal embeddings of shape: (1, n_tokens, 768)
-self.register_buffer("embed_scale", torch.sqrt(torch.tensor(self.config.hidden_dim, requires_grad=False)))
+base_lm.BaseModel.__init__(self, user_config, **kwargs)
self.pos_embed = nn.Embedding(self.config.n_tokens, self.config.hidden_dim)
self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False)
#bias=False for fairseq models
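# Example usage (assumption, not from the commit): the defaults above are overridden by
# whatever user_config carries; the exact override format is resolved by configure_model
# (presumably called from base_lm.BaseModel.__init__, not shown here). A 125M-style
# build might look roughly like this hypothetical override dict:
import torch

_user_config = {
    "n_layer": 12,
    "n_head": 12,
    "hidden_dim": 768,
    "device": torch.device("cpu"),
    "dtype": torch.float32,
}
# model = GPTNeoModel(_user_config)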
......@@ -202,8 +183,10 @@ class GPTNeoModel(base_lm.BaseModel):
kv_new = []
-position_ids = torch.arange(past_length, x[-1] + past_length, dtype=torch.long, device=x.device)
-position_ids = position_ids.unsqueeze(0).view(-1, x[-1])
+position_ids = torch.arange(past_length,
+x.shape[-1] + past_length,
+dtype=torch.long, device=x.device)
+position_ids = position_ids.unsqueeze(0).view(-1, x.shape[-1])
x = self.vocab_embed(x)
x = x + self.pos_embed(position_ids)
......
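# Worked example (not from the commit) of the corrected indexing: x[-1] indexed the last
# row of the token tensor, whereas x.shape[-1] is the sequence length actually wanted.
import torch

_x_ids = torch.zeros(2, 4, dtype=torch.long)   # (batch=2, seq=4), no cache -> past_length=0
_past = 0
_pos = torch.arange(_past, _x_ids.shape[-1] + _past, dtype=torch.long, device=_x_ids.device)
_pos = _pos.unsqueeze(0).view(-1, _x_ids.shape[-1])
# _pos -> tensor([[0, 1, 2, 3]]), broadcast against the (2, 4) token embeddings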
......@@ -40,9 +40,9 @@ if False:
#path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")
with always_rerun():
if True:
-env1.sh('pip3 uninstall transformers')
-env1.sh('pip3 install transformers')
-path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m --device 0 --tasks lambada --no_cache")
+#env1.sh('pip3 uninstall transformers')
+#env1.sh('pip3 install transformers')
+path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gpt-neo-125m-ported --device 0 --tasks lambada --no_cache")
#path.sh("python3 ../lm-evaluation-harness/main.py --batch_size 8")
else:
......
import torch
import transformers
import sys
from icecream import ic
import os
from pathlib import Path
"""
Original:
ln_f.weight
ln_f.bias
wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias
attn has biases, unlike GPT-J. QKV matrices are also merged instead of separate. What is the order, though? Probably just QKV.
"""
x = torch.load("pretrained/gpt-neo-125m/pytorch_model.bin")
state_dict = x
ic(x.keys())
new_state_dict = {}
module_map = {
"ln_1": "ln_preattn",
"ln_2": "ln_postattn",
"mlp.c_proj": "ff.ff2",
"mlp.c_fc": "ff.ff1",
"attn.attention.q_proj": "attn.q_proj",
"attn.attention.k_proj": "attn.k_proj",
"attn.attention.v_proj": "attn.v_proj",
"attn.attention.out_proj": "attn.out_proj",
"wte": "vocab_embed",
"wpe": "pos_embed",
'ln_f': 'ln_final',
}
print(type(state_dict))
for key in state_dict.keys():
dotlist = key.split('.')
if len(dotlist) > 3:
layer = dotlist[2]
for x in module_map:
if x in key:
new_state_dict[f"layers.{layer}.{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
print(f"{key} -> layers.{layer}.{module_map[x]}.{dotlist[-1]}")
else:
for x in module_map:
if x in key:
new_state_dict[f"{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
print(f"{key} -> {module_map[x]}.{dotlist[-1]}")
new_state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for k, v in new_state_dict.items():
print(f"{k} -> {v.shape}")
def save(state_dict, path):
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
checkpoint = {}
for i, x in enumerate(state_dict.items()):
checkpoint[x[0]] = f"{path}/b{i}.pt"
torch.save(x[1], f"{path}/b{i}.pt")
torch.save(checkpoint, f"{path}/m.pt")
save(new_state_dict, "pretrained/gpt-neo-125m-ported/lm")
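# Sketch (assumption) of reading back the sharded layout written by save(): m.pt maps
# each parameter name to its own b{i}.pt file. load() here is a hypothetical helper,
# not part of the commit; torch and Path are already imported above.
def load(path):
    path = Path(path)
    checkpoint = torch.load(path / "m.pt")     # {param_name: path to b{i}.pt}
    return {name: torch.load(file) for name, file in checkpoint.items()}

# state_dict = load("pretrained/gpt-neo-125m-ported/lm")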