Commit 41b51369 authored by novelailab

not too bad

parent 6cbab785
......@@ -11,6 +11,7 @@ except ImportError:
import os
from pathlib import Path
import math
from basedformer import lm_base
def fixed_pos_embedding(dim=None, seq_len=None, x=None):
if x is None:
......@@ -224,3 +225,20 @@ class GPTJModel(nn.Module):
x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.ln_final(x)
return x
class GPTJBaseLM(lm_base.BaseLM):
    """``BaseLM`` wrapper whose underlying network class is ``GPTJModel``.

    The ``BaseLM`` constructors instantiate ``model.model_class(**config)``,
    so pinning ``model_class`` here is what makes them build a GPT-J network.

    Args:
        config: Model configuration, forwarded unchanged to ``BaseLM``.
        lm: Optional already-built network, forwarded unchanged to ``BaseLM``.
    """

    def __init__(self, config=None, lm=None):
        # The base initialiser already runs nn.Module.__init__, so a single
        # cooperative call is enough; invoking nn.Module.__init__ a second
        # time here was redundant.
        super().__init__(config, lm)
        # Must be assigned after the base initialiser, which resets
        # model_class to None.
        self.model_class = GPTJModel
def load_gpt_j(path="models/6b", state_dict=None):
    """Load a GPT-J model wrapped in ``GPTJBaseLM`` with a fixed configuration.

    Args:
        path: Checkpoint location forwarded to ``GPTJBaseLM.load``.
        state_dict: Optional state dict forwarded to ``GPTJBaseLM.load``.

    Returns:
        The object returned by ``GPTJBaseLM.load``.
    """
    # NOTE(review): the visible part of BaseLM.load replaces state_dict with a
    # checkpoint read from `path` whenever `path` is truthy, so with the
    # default path a caller-supplied state_dict looks to be ignored -- confirm.
    gptj_settings = dict(
        n_layer=28,
        n_head=16,
        hidden_dim=4096,
        vocab_dim=50400,
        eps=1e-5,
    )
    return GPTJBaseLM.load(gptj_settings, path, state_dict)
......@@ -5,6 +5,28 @@ from torch import nn
from basedformer import gptj
import os
import json
from dataclasses import dataclass
'''
BaseLM config dataclass:
model_config = {
"model_class":
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
"vocab_dim": 50400,
"eps": 1e-5,
}
'''
@dataclass
class BaseLMConfig:
    """Typed record of the settings listed in the BaseLM config dictionaries.

    The field order defines the positional constructor signature generated by
    ``@dataclass``; keep it stable.
    """

    # Class used to build the underlying network.
    model_class: type
    # Architecture sizes; same keys as the dict-style configs in this project.
    n_layer: int
    n_head: int
    hidden_dim: int
    vocab_dim: int
    # Epsilon value (presumably for layer normalisation -- confirm in the model).
    eps: float
#Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
class BaseLM(nn.Module):
......@@ -12,6 +34,7 @@ class BaseLM(nn.Module):
nn.Module.__init__(self)
self.config = config
self.lm = lm
self.model_class = None
def init_weights(self):
for module in self.lm.modules():
......@@ -33,8 +56,8 @@ class BaseLM(nn.Module):
@classmethod
def init(cls, config):
lm = config["model_class"](**config)
model = cls(config, lm)
model = cls(config)
model.lm = model.model_class(**config)
model.init_weights()
#make this modular later
......@@ -42,8 +65,8 @@ class BaseLM(nn.Module):
# Alternate constructor: creates the wrapper, then builds its network inside
# utils.no_init (presumably to skip weight initialisation -- confirm in utils).
@classmethod
def no_init(cls, config):
# NOTE(review): the next two statements appear to be the pre-change side of
# the diff. Their results are discarded because `model` is rebound just below,
# and `config.model_class` uses attribute access while the other constructors
# index the config as a dict (`config["model_class"]`).
lm = utils.no_init(lambda: config.model_class(**config))
model = cls(config, lm)
# Post-change path: the network class comes from the wrapper's own
# `model_class` attribute instead of from the config.
model = cls(config)
model.lm = utils.no_init(lambda: model.model_class(**config))
return model
@classmethod
......@@ -53,8 +76,8 @@ class BaseLM(nn.Module):
if path:
state_dict = utils.SplitCheckpoint(path, device="cuda")
lm = config["model_class"](**config)
model = cls(config, lm)
model = cls(config)
model.lm = model.model_class(**config)
model.lm.load_state_dict(state_dict, strict=strict)
return model
......@@ -70,16 +93,3 @@ class BaseLM(nn.Module):
checkpoint[x[0]] = f"{path}/b{i}.pt"
torch.save(x[1], f"{path}/b{i}.pt")
torch.save(checkpoint, f"{path}/m.pt")
def load_gpt_j(path="models/6b", state_dict=None):
    """Load GPT-J through ``BaseLM.load`` with the model class in the config.

    Args:
        path: Checkpoint location forwarded to ``BaseLM.load``.
        state_dict: Optional state dict forwarded to ``BaseLM.load``.

    Returns:
        The object returned by ``BaseLM.load``.
    """
    # The network class travels inside the config under "model_class".
    settings = dict(
        model_class=gptj.GPTJModel,
        n_layer=28,
        n_head=16,
        hidden_dim=4096,
        vocab_dim=50400,
        eps=1e-5,
    )
    return BaseLM.load(settings, path, state_dict)
......@@ -146,6 +146,7 @@ class HyperNetworkSingle(nn.Module):
return x.bfloat16()
model_config = {
"model_class":
"n_layer": 28,
"n_head": 16,
"hidden_dim": 4096,
......@@ -178,7 +179,7 @@ gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = load_gpt_j().cuda().bfloat16()
model = lm_base.().cuda().bfloat16()
for param in model.parameters():
param.requires_grad = False
......@@ -196,7 +197,7 @@ opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
train_dataset = utils.FbDataset(2049, train_config["data_path"])
train_dataset = FbDataset(2049, train_config["data_path"])
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
......
from basedformer import lm_base
from basedformer import gptj
from basedformer.utils import *
import time
......@@ -67,7 +67,7 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
with torch.no_grad():
based_model = lm_base.load_gpt_j().cuda().half().eval()
based_model = gptj.load_gpt_j().cuda().half().eval()
based_model = based_model.lm
print("Loaded based model")
hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment