Commit 39518025 authored by novelailab

future pretrained_model class

parent 5ca754a6
@@ -2,6 +2,7 @@ import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from dotmap import DotMap
import math

class BaseModel(nn.Module):
    def __init__(self, user_config, **kwargs):
@@ -27,7 +28,26 @@ class BaseModel(nn.Module):
                    config=config,
                )
            )

    def init_weights(self):
        n_layer = self.n_layer
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=0.02)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)

            # scale the residual projection (ff2 / out_proj) weights down with depth
            for name, p in module.named_parameters():
                if ("ff2" in name or "out_proj" in name) and "weight" in name:
                    p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))

    def configure_model(self):
        full_config = {}
        if not hasattr(self, 'default_config'):
......
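For reference, the standard deviation used above for the ff2 / out_proj residual projections shrinks with depth. A quick check of the formula (n_layer = 28 is only an example value, not taken from this commit):

import math

n_layer = 28  # example depth, not from this commit
std = 0.02 / math.sqrt(2 * n_layer)  # same formula as init_weights() above
print(round(std, 5))  # 0.00267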
import torch.nn as nn
import torch.nn.functional as F
from basedformer import utils
from dotmap import DotMap
from pathlib import Path
import torch
import json

class PretrainedModel(nn.Module):
    def __init__(self, **kwargs):
        nn.Module.__init__(self)
        self.config = None

    @classmethod
    def no_init(cls, config):
        # construct the module while skipping default weight initialization
        # (useful when weights will be loaded from a checkpoint afterwards)
        model = utils.no_init(lambda: cls(config))
        return model

    @classmethod
    def init(cls, config):
        model = cls(config)
        if hasattr(model, 'init_weights'):
            model.init_weights()
        else:
            raise ValueError("No init_weights found, add one for init to function")
        return model
    def save(self, path, save_as=torch.float16):
        model = self
        original_dtype = model.dtype
        if save_as:
            model = model.to(save_as)
        path = Path(path)
        model_path = path / "model"
        # make the folder if it does not exist yet
        model_path.mkdir(parents=True, exist_ok=True)
        checkpoint = {}
        # save every tensor to its own file and keep an index of paths in m.pt
        for i, x in enumerate(model.state_dict().items()):
            checkpoint[x[0]] = model_path / f"b{i}.pt"
            torch.save(x[1], model_path / f"b{i}.pt")
        torch.save(checkpoint, model_path / "m.pt")
        # write model.config to config.json inside path
        #with open(path / "config.json", "w") as f:
        #    json.dump(serialize_config(model.config), f)
        if save_as:
            model = model.to(original_dtype)
\ No newline at end of file
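A minimal usage sketch of the new PretrainedModel helpers (the import path, the TinyModel subclass, and its config dict are assumptions for illustration; only no_init, init, and save come from this commit):

import torch
import torch.nn as nn
from basedformer.pretrained import PretrainedModel  # assumed module path

class TinyModel(PretrainedModel):
    # hypothetical subclass, only here to exercise the new class methods
    def __init__(self, config):
        PretrainedModel.__init__(self)
        self.config = config
        self.ff = nn.Linear(config["hidden"], config["hidden"])

    def init_weights(self):
        # required by PretrainedModel.init()
        self.ff.weight.data.normal_(mean=0.0, std=0.02)
        self.ff.bias.data.zero_()

    @property
    def dtype(self):
        # save() reads model.dtype; expose it from the parameters here
        return next(self.parameters()).dtype

model = TinyModel.init({"hidden": 8})  # constructs the module, then calls init_weights()
model.save("checkpoints/tiny")         # writes model/b*.pt tensors plus the m.pt index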