Commit 41b51369 authored by novelailab

not too bad

parent 6cbab785
@@ -11,6 +11,7 @@ except ImportError:
 import os
 from pathlib import Path
 import math
+from basedformer import lm_base
 
 def fixed_pos_embedding(dim=None, seq_len=None, x=None):
     if x is None:
@@ -223,4 +224,21 @@ class GPTJModel(nn.Module):
         for layer_id, layer in enumerate(self.layers):
             x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
         x = self.ln_final(x)
-        return x
\ No newline at end of file
+        return x
+
+class GPTJBaseLM(lm_base.BaseLM):
+    def __init__(self, config=None, lm=None):
+        nn.Module.__init__(self)
+        lm_base.BaseLM.__init__(self, config, lm)
+        self.model_class=GPTJModel
+
+def load_gpt_j(path="models/6b", state_dict=None):
+    config = {
+        "n_layer": 28,
+        "n_head": 16,
+        "hidden_dim": 4096,
+        "vocab_dim": 50400,
+        "eps": 1e-5
+    }
+    model = GPTJBaseLM.load(config, path, state_dict)
+    return model
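The hunks above add a thin GPTJBaseLM wrapper and move the load_gpt_j helper into the gptj module. As a rough usage sketch of that new entry point (the checkpoint path, fp16 cast, and dummy input below are assumptions for illustration, not part of this commit):

import torch
from basedformer import gptj

model = gptj.load_gpt_j(path="models/6b").cuda().half().eval()  # GPTJBaseLM wrapper around GPTJModel
lm = model.lm  # the underlying GPTJModel, same unwrap step the comparison script further down uses

tokens = torch.randint(0, 50400, (1, 32)).cuda()  # placeholder token ids
with torch.no_grad():
    out = lm(tokens)  # assumes GPTJModel.forward takes token ids; its signature and output live outside this diff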
@@ -5,6 +5,28 @@ from torch import nn
 from basedformer import gptj
 import os
 import json
+from dataclasses import dataclass
+
+'''
+BaseLM config dataclass:
+model_config = {
+    "model_class":
+    "n_layer": 28,
+    "n_head": 16,
+    "hidden_dim": 4096,
+    "vocab_dim": 50400,
+    "eps": 1e-5,
+}
+'''
+
+@dataclass
+class BaseLMConfig():
+    model_class: type
+    n_layer: int
+    n_head: int
+    hidden_dim: int
+    vocab_dim: int
+    eps: float
 
 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
@@ -12,6 +34,7 @@ class BaseLM(nn.Module):
         nn.Module.__init__(self)
         self.config = config
         self.lm = lm
+        self.model_class = None
 
     def init_weights(self):
         for module in self.lm.modules():
@@ -33,8 +56,8 @@ class BaseLM(nn.Module):
     @classmethod
     def init(cls, config):
-        lm = config["model_class"](**config)
-        model = cls(config, lm)
+        model = cls(config)
+        model.lm = model.model_class(**config)
         model.init_weights()
         #make this modular later
@@ -42,8 +65,8 @@ class BaseLM(nn.Module):
     @classmethod
     def no_init(cls, config):
-        lm = utils.no_init(lambda: config.model_class(**config))
-        model = cls(config, lm)
+        model = cls(config)
+        model.lm = utils.no_init(lambda: model.model_class(**config))
         return model
 
     @classmethod
@@ -53,8 +76,8 @@ class BaseLM(nn.Module):
         if path:
             state_dict = utils.SplitCheckpoint(path, device="cuda")
-        lm = config["model_class"](**config)
-        model = cls(config, lm)
+        model = cls(config)
+        model.lm = model.model_class(**config)
         model.lm.load_state_dict(state_dict, strict=strict)
         return model
@@ -70,16 +93,3 @@ class BaseLM(nn.Module):
             checkpoint[x[0]] = f"{path}/b{i}.pt"
             torch.save(x[1], f"{path}/b{i}.pt")
         torch.save(checkpoint, f"{path}/m.pt")
-
-def load_gpt_j(path="models/6b", state_dict=None):
-    config = {
-        "model_class": gptj.GPTJModel,
-        "n_layer": 28,
-        "n_head": 16,
-        "hidden_dim": 4096,
-        "vocab_dim": 50400,
-        "eps": 1e-5
-    }
-    model = BaseLM.load(config, path, state_dict)
-    return model
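The BaseLM refactor above stops reading config["model_class"] and instead relies on a model_class attribute that each subclass pins after calling the base constructor, exactly as GPTJBaseLM does in the gptj hunk. A minimal sketch of that pattern for a hypothetical model family (MyModel and MyBaseLM are invented names for illustration):

from torch import nn
from basedformer import lm_base

class MyModel(nn.Module):  # hypothetical torch LM standing in for GPTJModel
    def __init__(self, n_layer, n_head, hidden_dim, vocab_dim, eps):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dim, eps=eps)  # placeholder layer

class MyBaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        self.model_class = MyModel  # init/no_init/load now read this attribute

# the config dict no longer needs a "model_class" entry; the subclass supplies it
config = {"n_layer": 2, "n_head": 2, "hidden_dim": 64, "vocab_dim": 256, "eps": 1e-5}
model = MyBaseLM.no_init(config)  # wraps MyModel(**config) in utils.no_init, mirroring the refactored classmethod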
@@ -146,6 +146,7 @@ class HyperNetworkSingle(nn.Module):
         return x.bfloat16()
 
 model_config = {
+    "model_class":
     "n_layer": 28,
     "n_head": 16,
     "hidden_dim": 4096,
@@ -178,7 +179,7 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
 
 #model = GPTModel.gpt2_init(model_config).cuda().float()
-model = load_gpt_j().cuda().bfloat16()
+model = lm_base.().cuda().bfloat16()
 for param in model.parameters():
     param.requires_grad = False
@@ -196,7 +197,7 @@ opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
-train_dataset = utils.FbDataset(2049, train_config["data_path"])
+train_dataset = FbDataset(2049, train_config["data_path"])
 train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
 
 wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
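The training-script hunks above keep the existing setup: the GPT-J base model is loaded, all of its parameters are frozen, and only the hypernetwork's parameters go to the optimizer. A generic sketch of that freeze-and-train pattern (the modules and plain AdamW below are stand-ins; BasedOptimizer and HyperNetworkSingle live in the script, not in this diff):

import torch
from torch import nn

base_lm = nn.Linear(16, 16)       # stand-in for the loaded GPT-J model
hypernetwork = nn.Linear(16, 16)  # stand-in for HyperNetworkSingle

# Freeze every base-model parameter, as the script does right after loading.
for param in base_lm.parameters():
    param.requires_grad = False

# Only hypernetwork parameters are optimized (the script builds a BasedOptimizer with "adamw").
opt = torch.optim.AdamW(hypernetwork.parameters(), lr=1e-4)

x = torch.randn(4, 16)
loss = hypernetwork(base_lm(x)).pow(2).mean()  # dummy objective, just to show where gradients flow
loss.backward()                                # gradients reach the hypernetwork only
opt.step()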
-from basedformer import lm_base
+from basedformer import gptj
 from basedformer.utils import *
 import time
@@ -67,7 +67,7 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 with torch.no_grad():
-    based_model = lm_base.load_gpt_j().cuda().half().eval()
+    based_model = gptj.load_gpt_j().cuda().half().eval()
     based_model = based_model.lm
     print("Loaded based model")
     hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
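The last hunks switch the comparison script from lm_base.load_gpt_j to the new gptj.load_gpt_j before benchmarking against the Hugging Face checkpoint. A rough sketch of the kind of timing measurement its timeit helper is used for, with plain time.perf_counter and a placeholder input (context length and repeat count are arbitrary here):

import time
import torch
from basedformer import gptj

with torch.no_grad():
    based_model = gptj.load_gpt_j().cuda().half().eval().lm  # same load-then-unwrap pattern as the script
    tokens = torch.randint(0, 50400, (1, 2048)).cuda()       # placeholder token ids

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(5):
        based_model(tokens)  # assumes GPTJModel.forward takes token ids
    torch.cuda.synchronize()
    print(f"avg forward pass: {(time.perf_counter() - start) / 5:.3f}s")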