Commit 41f39980 authored by novelailab

make lm_base functions functional

parent 8d44445e
@@ -7,88 +7,54 @@ import os
import json
from dataclasses import dataclass
from pathlib import Path
'''
BaseLM config dataclass:
model_config = {
    "model_class":
    "n_layer": 28,
    "n_head": 16,
    "hidden_dim": 4096,
    "vocab_dim": 50400,
    "eps": 1e-5,
}
'''

@dataclass
class BaseLMConfig():
    model_class: type
    n_layer: int
    n_head: int
    hidden_dim: int
    vocab_dim: int
    eps: float
-#Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
-class BaseLM(nn.Module):
-    def __init__(self, config=None, lm=None):
-        nn.Module.__init__(self)
-        self.config = config
-        self.lm = lm
-        self.model_class = None
-
-    def init_weights(self):
-        for module in self.lm.modules():
-            if isinstance(module, nn.Linear):
-                module.weight.data.normal_(mean=0.0, std=0.02)
-                if module.bias is not None:
-                    module.bias.data.zero_()
-            elif isinstance(module, nn.Embedding):
-                module.weight.data.normal_(mean=0.0, std=0.02)
-            elif isinstance(module, nn.LayerNorm):
-                module.weight.data.fill_(1.0)
-            for name, p in module.named_parameters():
-                if ("ff2" in name or "out_proj" in name) and "weight" in name:
-                    p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.config["n_layer"])))
-
-    @classmethod
-    def init(cls, config):
-        model = cls(config)
-        model.lm = model.model_class(**config)
-        model.init_weights()
-        #make this modular later
-        return model
-
-    @classmethod
-    def no_init(cls, config):
-        model = cls(config)
-        model.lm = utils.no_init(lambda: model.model_class(**config))
-        return model
-
-    @classmethod
-    def load(cls, config, path=None, state_dict=None, strict=False):
-        # I am kinda sad that we will not have a load function in lm object itself.
-        # might be better to add load functions -- actually nope.
-        if path:
-            state_dict = utils.SplitCheckpoint(path, device="cuda")
-        model = cls(config)
-        model.lm = utils.no_init(lambda: model.model_class(**config))
-        model.lm.load_state_dict(state_dict, strict=strict)
-        return model
-
-    def save(self, path):
-        if self.lm is None:
-            print("No LM object to save. Please first init a model.")
-            return
-        try: os.mkdir(path)
-        except: pass
-        checkpoint = {}
-        for i, x in enumerate(self.lm.state_dict().items()):
-            checkpoint[x[0]] = f"{path}/b{i}.pt"
-            torch.save(x[1], f"{path}/b{i}.pt")
-        torch.save(checkpoint, f"{path}/m.pt")
+def init_weights(model, n_layer):
+    for module in model.modules():
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=0.02)
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        for name, p in module.named_parameters():
+            if ("ff2" in name or "out_proj" in name) and "weight" in name:
+                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))
+
+@classmethod
+def init(model_class, config):
+    model = model_class(config)
+    model.init_weights()
+    return model
+
+@classmethod
+def no_init(model_class, config):
+    model = utils.no_init(lambda: model_class(config))
+    return model
+
+@classmethod
+def load(config, model_class, path=None, state_dict=None, strict=False):
+    # I am kinda sad that we will not have a load function in lm object itself.
+    # might be better to add load functions -- actually nope.
+    if path:
+        state_dict = utils.SplitCheckpoint(path, device="cuda")
+    model = utils.no_init(lambda: model_class(**config))
+    model.load_state_dict(state_dict, strict=strict)
+    return model
+
+def save(model, path):
+    try: os.mkdir(path)
+    except: pass
+    checkpoint = {}
+    for i, x in enumerate(model.state_dict().items()):
+        checkpoint[x[0]] = f"{path}/b{i}.pt"
+        torch.save(x[1], f"{path}/b{i}.pt")
+    torch.save(checkpoint, f"{path}/m.pt")
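
For orientation, a short usage sketch of the reworked module-level API follows. It is not part of the commit: the module name lm_base, the stand-in model class MyLM, and the checkpoint directory are assumptions, and it treats the leftover @classmethod decorators on the new top-level functions as an oversight so that init_weights, save, and load can be called as plain functions.

import torch.nn as nn
import lm_base  # the module shown in this diff; name assumed

# Stand-in for a real transformer LM; accepts the same config keys as the docstring above.
class MyLM(nn.Module):
    def __init__(self, n_layer=28, n_head=16, hidden_dim=4096, vocab_dim=50400, eps=1e-5):
        super().__init__()
        self.embed = nn.Embedding(vocab_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, vocab_dim)

config = {"n_layer": 28, "n_head": 16, "hidden_dim": 4096, "vocab_dim": 50400, "eps": 1e-5}

model = MyLM(**config)
lm_base.init_weights(model, config["n_layer"])   # normal init, then scaled init (std = 0.02/sqrt(2*n_layer)) for ff2/out_proj weights

lm_base.save(model, "ckpt")                      # writes ckpt/b0.pt, ckpt/b1.pt, ... plus ckpt/m.pt mapping names to files
model = lm_base.load(config, MyLM, path="ckpt")  # rebuilds MyLM without init and reads the split checkpoint (device is hard-coded to cuda)

The on-disk layout produced by save, one tensor per bN.pt file plus an m.pt index from parameter names to those files, is what utils.SplitCheckpoint is expected to read back in load.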