Commit 835709ca authored by novelailab

happy with BaseModel abstraction

parent c8d491e1
@@ -128,7 +128,7 @@ dmypy.json
 # Pyre type checker
 .pyre/
-models
+/models
 gptjconvert
 j6b_vanilla
 wandb
...
-from . import gptj
+from .models import gptj
 MODEL_MAP = {
-    "gptj": (gptj.GPTJModel, gptj.GPTJConfig),
+    "gptj": gptj.GPTJModel,
 }
 def get_model(model_name: str):
...
 from basedformer import utils
-import basedformer
+from basedformer import models
 import math
 import torch
 from torch import nn
-from basedformer import gptj
 import os
 import json
 from dataclasses import dataclass
@@ -48,11 +47,9 @@ def save(model, path):
 def load_from_path(config_folder=None, strict=False):
     config_folder = Path(config_folder)
     config = _load_config_file(config_folder / "config.json")
-    model_class = basedformer.get_model(config["model_class"])[0]
-    config_class = basedformer.get_model(config["model_class"])[1]
+    model_class = models.get_model(config["model_class"])
     model_path = config["model_path"]
     model_config = config["model_config"]
-    model_config = config_class(**model_config)
     print(model_config)
     if model_path == ".":
...
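With this hunk, load_from_path resolves the model class through models.get_model and keeps model_config as a plain dict instead of instantiating a config dataclass first. A rough sketch of a checkpoint folder it could consume, assuming the three keys shown above are the whole config.json schema (the folder name and override values are illustrative only):

# hypothetical checkpoint layout; only model_class/model_path/model_config are
# assumed to be read by load_from_path
import json
from pathlib import Path

ckpt = Path("checkpoints/gptj-small")
ckpt.mkdir(parents=True, exist_ok=True)
(ckpt / "config.json").write_text(json.dumps({
    "model_class": "gptj",                              # resolved via models.get_model
    "model_path": ".",                                  # "." takes the special branch above
    "model_config": {"n_layer": 6, "hidden_dim": 512},  # presumably forwarded to the model as user_config
}, indent=2))

# from basedformer import lm_utils
# model = lm_utils.load_from_path(ckpt)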
from . import gptj

MODEL_MAP = {
    "gptj": gptj.GPTJModel,
}

def get_model(model_name: str):
    return MODEL_MAP[model_name]
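This new registry (the basedformer.models package that lm_utils now imports) maps each name directly to its model class; the old (model, config) tuple is gone because the defaults now live on the class. A minimal sketch of resolving and building a model through it, assuming an empty override dict is acceptable:

from basedformer import models

model_class = models.get_model("gptj")  # -> gptj.GPTJModel
model = model_class({})                 # empty user_config: every default_config value applies
                                        # (note: the gptj defaults put the model on CUDA in float16)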
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from dotmap import DotMap


class BaseModel(nn.Module):
    def __init__(self, user_config, **kwargs):
        nn.Module.__init__(self)
        # configuration
        self.user_config = user_config
        self.config = self.configure_model()
        config = self.config

        # modeling
        self.n_layer = config.n_layer
        self.hidden_dim = config.hidden_dim
        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
        for _ in range(config.n_layer):
            self.layers.append(
                config.Layer(
                    attn=config.SelfAttention,
                    ff=config.FeedForward,
                    config=config,
                )
            )

    def configure_model(self):
        # merge the subclass defaults with any user-supplied overrides
        full_config = {}
        if not hasattr(self, 'default_config'):
            raise ValueError("No default config found, add one for the model to function")
        # apply defaults
        for k, v in self.default_config.items():
            full_config[k] = v
        # apply user defined config if provided
        for k, v in self.user_config.items():
            full_config[k] = v
        full_config = DotMap(full_config)
        return full_config

    def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
        x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
        x = self.lm_head(x)
        if target is not None:
            # flatten logits and labels for the LM cross-entropy loss
            logits = x.view(-1, x.shape[-1])
            labels = target.view(-1)
            loss = F.cross_entropy(logits, labels)

        # clean this mess later
        if cache:
            if target is not None:
                return loss, x.float(), kv
            else:
                return x.float(), kv
        else:
            if target is not None:
                return loss, x.float()
            else:
                return x.float()

    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
        if kv is None:
            kv = [None] * self.n_layer
        kv_new = []
        x = self.vocab_embed(x)
        for layer_id, layer in enumerate(self.layers):
            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
            kv_new.append(kvi)
        x = self.ln_final(x)
        if cache:
            return x, kv_new
        else:
            return x, None
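base_lm.BaseModel above owns the embedding, the layer stack, the final LayerNorm, the LM head, and the forward/get_embeds plumbing; a subclass only needs to set self.default_config with every key the base constructor reads before delegating to BaseModel.__init__. A minimal sketch of that contract using throwaway Tiny* classes (all hypothetical, not part of basedformer):

import torch
import torch.nn as nn
from basedformer.models import base_lm

class TinyAttention(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)

class TinyFF(nn.Module):
    def __init__(self, config):
        nn.Module.__init__(self)

class TinyLayer(nn.Module):
    # mirrors the interface BaseModel expects: built as Layer(attn=..., ff=..., config=...)
    # and called with (x, layer_id=..., hypernetwork=..., act_ck=..., kv=..., cache=...)
    def __init__(self, attn, ff, config):
        nn.Module.__init__(self)
        self.attn = attn(config)
        self.ff = ff(config)

    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
        return x, None  # identity layer: returns (hidden states, kv)

class TinyModel(base_lm.BaseModel):
    def __init__(self, user_config, **kwargs):
        # defaults must exist before BaseModel.__init__ runs configure_model()
        self.default_config = {
            'n_layer': 2,
            'hidden_dim': 64,
            'vocab_dim': 50400,
            'eps': 1e-5,
            'device': torch.device('cpu'),
            'dtype': torch.float32,
            'Layer': TinyLayer,
            'SelfAttention': TinyAttention,
            'FeedForward': TinyFF,
        }
        base_lm.BaseModel.__init__(self, user_config, **kwargs)

model = TinyModel({'n_layer': 4})              # user_config overrides the defaults
out = model(torch.randint(0, 50400, (1, 8)))   # float logits of shape (1, 8, 50400)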
@@ -261,14 +261,3 @@ class GPTJBaseLM(lm_base.BaseLM):
         nn.Module.__init__(self)
         lm_base.BaseLM.__init__(self, config, lm)
         self.model_class=GPTJModel
-
-def load_gpt_j(path="models/6b", state_dict=None):
-    config = {
-        "n_layer": 28,
-        "n_head": 16,
-        "hidden_dim": 4096,
-        "vocab_dim": 50400,
-        "eps": 1e-5
-    }
-    model = GPTJBaseLM.load(config, path, state_dict)
-    return model
@@ -14,7 +14,10 @@ import os
 from pathlib import Path
 import math
 from basedformer import lm_utils
+from basedformer.models import base_lm
 from dataclasses import dataclass
+#import dotmap
+from dotmap import DotMap
 def fixed_pos_embedding(dim=None, seq_len=None, x=None):
     if x is None:
@@ -189,83 +192,20 @@ class GPTJLayer(nn.Module):
         return x, kv
 
-class GPTJModel(nn.Module):
-    def __init__(self, config, **kwargs):
-        nn.Module.__init__(self)
-        self.config = config
-        self.n_layer = config.n_layer
-        self.hidden_dim = config.hidden_dim
-        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
-        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
-        self.layers = nn.ModuleList([])
-        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
-        for _ in range(config.n_layer):
-            self.layers.append(
-                config.Layer(
-                    attn=SelfAttention,
-                    ff=FeedForward,
-                    config=config,
-                )
-            )
-
-    def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
-        x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
-        x = self.lm_head(x)
-        if target:
-            logits = x.view(-1, logits.shape[-1])
-            labels = target.view(-1)
-            loss = F.cross_entropy(logits, labels)
-        #clean this mess later
-        if cache:
-            if target:
-                return loss, x.float(), kv
-            else:
-                return x.float(), kv
-        else:
-            if target:
-                return loss, x.float()
-            else:
-                return x.float()
-
-    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
-        if kv is None:
-            kv = [None] * self.n_layer
-        kv_new = []
-        x = self.vocab_embed(x)
-        for layer_id, layer in enumerate(self.layers):
-            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
-            kv_new.append(kvi)
-        x = self.ln_final(x)
-        if cache:
-            return x, kv_new
-        else:
-            return x, None
-
-@dataclass
-class GPTJConfig:
-    n_layer: int = 6
-    n_head: int = 8
-    n_tokens: int = 2048
-    hidden_dim: int = 512
-    vocab_dim: int = 50400
-    eps: float = 1e-5
-    device: torch.device = torch.device('cuda')
-    dtype: torch.dtype = torch.float16
-    Layer: nn.Module = GPTJLayer
-    activation: Callable = gelu_new
-
-def load_gpt_j(path="models/6b", state_dict=None):
-    config = {
-        "n_layer": 28,
-        "n_head": 16,
-        "hidden_dim": 4096,
-        "vocab_dim": 50400,
-        "eps": 1e-5
-    }
-    config = GPTJConfig(**config)
-    model = lm_utils._load_dict_model(GPTJModel, config, path)
-    return model
+class GPTJModel(base_lm.BaseModel):
+    def __init__(self, user_config, **kwargs):
+        self.default_config = {
+            'n_layer': 6,
+            'n_head': 8,
+            'n_tokens': 2048,
+            'hidden_dim': 512,
+            'vocab_dim': 50400,
+            'eps': 1e-5,
+            'device': torch.device('cuda'),
+            'dtype': torch.float16,
+            'Layer': GPTJLayer,
+            'activation': gelu_new,
+            'SelfAttention': SelfAttention,
+            'FeedForward': FeedForward,
+        }
+        base_lm.BaseModel.__init__(self, user_config, **kwargs)
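GPTJConfig and load_gpt_j are gone from this file; the 6B-scale settings they carried would now travel as a plain user_config dict handed to the class (or written into a config.json for lm_utils.load_from_path). A rough sketch of the equivalent construction, reusing the values from the removed loader; how the weights are actually loaded afterwards is not shown in this hunk:

from basedformer.models import gptj

# the settings from the removed load_gpt_j, expressed as user_config overrides;
# the class defaults (CUDA device, float16, GPTJLayer, gelu_new, ...) fill in the rest
user_config = {
    "n_layer": 28,
    "n_head": 16,
    "hidden_dim": 4096,
    "vocab_dim": 50400,
    "eps": 1e-5,
}
model = gptj.GPTJModel(user_config)  # randomly initialized; loading a checkpoint is a separate step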