Commit aa6444e9 authored by novelailab

gpt neo working

parent 88f0a90e
from . import gptj
from . import gpt2
from . import fairseq
+from . import gptneo

MODEL_MAP = {
    "gptj": gptj.GPTJModel,
    "gpt2": gpt2.GPT2Model,
-    "gpt-fairseq": fairseq.GPTFairModel
+    "gpt-fairseq": fairseq.GPTFairModel,
+    "gpt-neo": gptneo.GPTNeoModel
}

def get_model(model_name: str):
......
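With the new registry entry, "gpt-neo" resolves to gptneo.GPTNeoModel by name. A minimal usage sketch, assuming this file is basedformer/models/__init__.py and that get_model (collapsed above) simply indexes MODEL_MAP:

from basedformer.models import MODEL_MAP

model_class = MODEL_MAP["gpt-neo"]   # -> gptneo.GPTNeoModel
print(model_class.__name__)          # "GPTNeoModel"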
@@ -18,12 +18,12 @@ class BaseModel(nn.Module):
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
        for i in range(config.n_layer):
+            config.layer_idx = i
            self.layers.append(
                config.Layer(
                    attn=config.SelfAttention,
                    ff=config.FeedForward,
                    config=config,
-                    layer_idx=i,
                )
            )
......
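The base constructor now stamps the loop index onto the shared config (config.layer_idx = i) instead of passing layer_idx as a keyword, so layer classes whose constructors only accept (attn, ff, config), like the reworked GPTNeoLayer below, can still recover their position. A self-contained sketch of the pattern; ToyLayer and the SimpleNamespace config are illustrative stand-ins, not basedformer's API:

import types
import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self, attn, ff, config):
        super().__init__()
        # read the index stamped onto the config just before construction
        self.layer_idx = config.layer_idx

config = types.SimpleNamespace(n_layer=3)
layers = nn.ModuleList([])
for i in range(config.n_layer):
    config.layer_idx = i   # same trick as BaseModel above
    layers.append(ToyLayer(attn=None, ff=None, config=config))

print([layer.layer_idx for layer in layers])   # [0, 1, 2]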
@@ -13,6 +13,7 @@ from pathlib import Path
import math
from basedformer.models import base_lm
from typing import Optional, Any
+from icecream import ic

def _attn(query, key, value, causal_mask, masked_bias,
@@ -24,7 +25,8 @@ def _attn(query, key, value, causal_mask, masked_bias,
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
-    attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
+    if scale_attn:
+        attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
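The new guard reflects that GPT-Neo checkpoints are trained without the usual 1/sqrt(d_head) scaling of attention logits, so SelfAttention now sets scale_attn = None (next hunk) and the division only runs when a scale is provided. A rough standalone sketch of the resulting attention core; argument names follow _attn, but this is an illustration rather than the module's exact code:

import torch

def attn_core(query, key, value, causal_mask, masked_bias,
              attention_mask=None, scale_attn=None):
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = torch.where(causal_mask, attn_weights,
                               masked_bias.to(attn_weights.dtype))
    if scale_attn is not None:   # skipped for GPT-Neo, where scale_attn is None
        attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights, value)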
@@ -38,14 +40,15 @@ def _attn(query, key, value, causal_mask, masked_bias,
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
-    def __init__(self, config, attention_type):
+    def __init__(self, config, attn_type):
+        ic(attn_type)
        nn.Module.__init__(self)
        self.config = config
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
-        if attention_type == "local":
+        if attn_type == "local":
            self.register_buffer(
                "bias",
                bias ^ torch.tril(bias, -config.window_size),
@@ -63,13 +66,13 @@ class SelfAttention(nn.Module):
        device = config.device
        dtype = config.dtype
-        self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
+        self.scale_attn = None
        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
-        attn_bias = True #fairseq has attn_bias
+        attn_bias = False #fairseq has attn_bias
        self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
        self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
-        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
+        self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=True, device=device, dtype=dtype)

    def forward(self, x, kv=None, cache=False):
        B, S, H = x.shape # batch, sequence, hidden_dim
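For the "local" branch earlier in this file, bias ^ torch.tril(bias, -config.window_size) turns the full causal mask into a band: each position may attend only to itself and the previous window_size - 1 tokens. A tiny check with toy sizes (variable names are illustrative):

import torch

max_positions, window_size = 6, 2
full = torch.tril(torch.ones(max_positions, max_positions, dtype=torch.bool))
local = full ^ torch.tril(full, -window_size)   # keep only the last window_size positions
print(local.int())
# tensor([[1, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0],
#         [0, 1, 1, 0, 0, 0],
#         [0, 0, 1, 1, 0, 0],
#         [0, 0, 0, 1, 1, 0],
#         [0, 0, 0, 0, 1, 1]])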
@@ -116,13 +119,13 @@ class FeedForward(nn.Module):
        return x

class GPTNeoLayer(nn.Module):
-    def __init__(self, attn, ff, config, layer_idx):
+    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.hidden_dim = config.hidden_dim
        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.ff = ff(config)
-        if layer_idx % 2 == 0:
+        if config.layer_idx % 2 == 0:
            attn_type = "global"
        else:
            attn_type = "local"
@@ -154,8 +157,8 @@ class GPTNeoModel(base_lm.BaseModel):
        'n_head': 8,
        'n_tokens': 2049,
        'hidden_dim': 512,
-        'vocab_dim': 50400,
-        'fp32_attn': True, #fairseq models are trained with fp32 attn
+        'vocab_dim': 50257,
+        'fp32_attn': False,
        'eps': 1e-5,
        'device': torch.device('cuda'),
        'dtype': torch.float16,
@@ -165,29 +168,7 @@ class GPTNeoModel(base_lm.BaseModel):
        'FeedForward': FeedForward,
        'window_size': 256,
    }

    def __init__(self, user_config, **kwargs):
-        nn.Module.__init__(self)
-        #configuration
-        self.user_config = user_config
-        self.config = self.configure_model()
-        config = self.config
-        #modeling
-        self.n_layer = config.n_layer
-        self.hidden_dim = config.hidden_dim
-        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
-        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
-        self.layers = nn.ModuleList([])
-        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
-        for _ in range(config.n_layer):
-            self.layers.append(
-                config.Layer(
-                    attn=config.SelfAttention,
-                    ff=config.FeedForward,
-                    config=config,
-                )
-            )
-        # returns sinusoidal embeddings of shape: (1, n_tokens, 768)
-        self.register_buffer("embed_scale", torch.sqrt(torch.tensor(self.config.hidden_dim, requires_grad=False)))
+        base_lm.BaseModel.__init__(self, user_config, **kwargs)
+        self.pos_embed = nn.Embedding(self.config.n_tokens, self.config.hidden_dim)
+        self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False) #bias=False for fairseq models
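With this rewrite GPTNeoModel no longer rebuilds the embedding, layers and head by hand; it defers to base_lm.BaseModel.__init__ and then only adds what GPT-Neo needs on top: learned absolute position embeddings and a bias-free lm_head (whose weight the conversion script below ties to wte). A stripped-down sketch of the same composition pattern; TinyBase and TinyNeo are illustrative stand-ins, not the real classes:

import torch.nn as nn

class TinyBase(nn.Module):
    def __init__(self, hidden_dim, vocab_dim):
        super().__init__()
        self.vocab_embed = nn.Embedding(vocab_dim, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)

class TinyNeo(TinyBase):
    def __init__(self, hidden_dim, vocab_dim, n_tokens):
        TinyBase.__init__(self, hidden_dim, vocab_dim)
        self.pos_embed = nn.Embedding(n_tokens, hidden_dim)           # learned positions
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=False)   # no head bias

model = TinyNeo(hidden_dim=512, vocab_dim=50257, n_tokens=2049)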
@@ -202,8 +183,10 @@ class GPTNeoModel(base_lm.BaseModel):
            kv_new = []

-        position_ids = torch.arange(past_length, x[-1] + past_length, dtype=torch.long, device=x.device)
-        position_ids = position_ids.unsqueeze(0).view(-1, x[-1])
+        position_ids = torch.arange(past_length,
+                                    x.shape[-1] + past_length,
+                                    dtype=torch.long, device=x.device)
+        position_ids = position_ids.unsqueeze(0).view(-1, x.shape[-1])

        x = self.vocab_embed(x)
        x = x + self.pos_embed(position_ids)
......
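The position-id fix in the last hunk matters because x is still a (batch, seq) tensor of token ids at that point: x[-1] selects the last row of ids, not the sequence length, while x.shape[-1] is the integer that torch.arange actually needs. A quick illustration with dummy ids:

import torch

x = torch.randint(0, 50257, (2, 5))   # (batch, seq) token ids
past_length = 0

print(x[-1])         # last row of ids, a length-5 tensor
print(x.shape[-1])   # 5, the sequence length

position_ids = torch.arange(past_length, x.shape[-1] + past_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, x.shape[-1])
print(position_ids)  # tensor([[0, 1, 2, 3, 4]])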
@@ -40,9 +40,9 @@ if False:
    #path.sh("pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113")

with always_rerun():
    if True:
-        env1.sh('pip3 uninstall transformers')
-        env1.sh('pip3 install transformers')
-        path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/fairseq_125m --device 0 --tasks lambada --no_cache")
+        #env1.sh('pip3 uninstall transformers')
+        #env1.sh('pip3 install transformers')
+        path.sh("python3 ../lm-evaluation-harness/main.py --model basedformer --batch_size 8 --model_args pretrained=/home/xuser/diffusionstorage/workspace/kuru/basedformer/pretrained/gpt-neo-125m-ported --device 0 --tasks lambada --no_cache")
        #path.sh("python3 ../lm-evaluation-harness/main.py --batch_size 8")
    else:
......
import torch
import transformers
import sys
from icecream import ic
import os
from pathlib import Path
"""
Original:
ln_f.weight
ln_f.bias
wte.weight
wpe.weight
h.0.ln_1.weight
h.0.ln_1.bias
h.0.attn.bias
h.0.attn.c_attn.weight
h.0.attn.c_attn.bias
h.0.attn.c_proj.weight
h.0.attn.c_proj.bias
h.0.ln_2.weight
h.0.ln_2.bias
h.0.mlp.c_fc.weight
h.0.mlp.c_fc.bias
h.0.mlp.c_proj.weight
h.0.mlp.c_proj.bias
attn has biases, unlike GPT-J. The QKV matrices are also merged instead of separate. What is the order, though? Probably just Q, K, V.
"""
x = torch.load("pretrained/gpt-neo-125m/pytorch_model.bin")
state_dict = x
ic(x.keys())
new_state_dict = {}
module_map = {
    "ln_1": "ln_preattn",
    "ln_2": "ln_postattn",
    "mlp.c_proj": "ff.ff2",
    "mlp.c_fc": "ff.ff1",
    "attn.attention.q_proj": "attn.q_proj",
    "attn.attention.k_proj": "attn.k_proj",
    "attn.attention.v_proj": "attn.v_proj",
    "attn.attention.out_proj": "attn.out_proj",
    "wte": "vocab_embed",
    "wpe": "pos_embed",
    'ln_f': 'ln_final',
}
print(type(state_dict))
for key in state_dict.keys():
    dotlist = key.split('.')
    if len(dotlist) > 3:
        layer = dotlist[2]
        for x in module_map:
            if x in key:
                new_state_dict[f"layers.{layer}.{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
                print(f"{key} -> layers.{layer}.{module_map[x]}.{dotlist[-1]}")
    else:
        for x in module_map:
            if x in key:
                new_state_dict[f"{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
                print(f"{key} -> {module_map[x]}.{dotlist[-1]}")
new_state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
for k, v in new_state_dict.items():
    print(f"{k} -> {v.shape}")
def save(state_dict, path):
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    checkpoint = {}
    for i, x in enumerate(state_dict.items()):
        checkpoint[x[0]] = f"{path}/b{i}.pt"
        torch.save(x[1], f"{path}/b{i}.pt")
    torch.save(checkpoint, f"{path}/m.pt")
save(new_state_dict, "pretrained/gpt-neo-125m-ported/lm")
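save() writes one .pt shard per tensor plus an m.pt index that maps parameter names to shard paths. Reading it back is the mirror image; a minimal sketch under the assumption that the loader on the basedformer side does essentially this (load_sharded is a hypothetical name, not the library's actual API):

def load_sharded(path):
    path = Path(path)
    index = torch.load(path / "m.pt")   # {param_name: shard_path}
    return {name: torch.load(shard) for name, shard in index.items()}

# state_dict = load_sharded("pretrained/gpt-neo-125m-ported/lm")
# model.load_state_dict(state_dict)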