Commit 840dd7f4 authored by novelailab

cleanup

parent 1c9d3a31
from torch import nn
import torch
import os
import math
from lm_arch.utils import *
# Every module can be accessed and changed from here, since the Layer class and the ff and attn classes are passed into GPTModel.
class GPTModel(nn.Module):
    def __init__(self, hidden_dim=512, n_layer=12, n_head=4, vocab_dim=50400, eps=1e-4, activation=nn.GELU(), Layer=None, SelfAttention=None, FeedForward=None, device="cuda", dtype=torch.float16):
        super(GPTModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
        self.layers = nn.ModuleList([])
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

    #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
    #TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)
        for layer in self.layers:
            x = layer(x, hypernetwork, act_ck)
        x = self.ln_final(x)
        return x

    def get_logits(self, x, hypernetwork=None, act_ck=False):
        x = self.forward(x, hypernetwork=hypernetwork, act_ck=act_ck)
        x = self.lm_head(x)
        return x.float()
    @classmethod
    def load(cls, config, path=None, state_dict=None):
        if path:
            state_dict = SplitCheckpoint(path, device="cuda")
        model = no_init(lambda: cls(**config))
        model.load_state_dict(state_dict, strict=False)
        return model

    @classmethod
    def init(cls, config):
        model = cls(**config)
        return model
    @classmethod
    def neox_init(cls, config):
        # NeoX-style init: small_init on everything except the last layer, which gets wang_init.
        model = cls(**config)
        modules = [*model.layers[:-1], model.vocab_embed, model.ln_final, model.lm_head]
        init = small_init_method(config["hidden_dim"])
        for module in modules:
            for param in module.parameters():
                init(param)
        last_layer = model.layers[-1]
        last_layer_init = wang_init_method(config["n_layer"], config["hidden_dim"])
        for param in last_layer.parameters():
            last_layer_init(param)
        return model
    @classmethod
    def simple_init(cls, config):
        # Rescale the default-initialized weights by 1/sqrt(2 * n_layer).
        model = cls(**config)
        state = model.state_dict()
        for k in state:
            state[k] = state[k] / math.sqrt(2 * config["n_layer"])
        model.load_state_dict(state)
        return model
    def save(self, path):
        try:
            os.mkdir(path)
        except FileExistsError:
            pass
        # One tensor file per parameter (b{i}.pt) plus an index (m.pt) mapping names to files.
        checkpoint = {}
        for i, (name, tensor) in enumerate(self.state_dict().items()):
            checkpoint[name] = f"{path}/b{i}.pt"
            torch.save(tensor, f"{path}/b{i}.pt")
        torch.save(checkpoint, f"{path}/m.pt")
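For reference, the save format above is one tensor file per parameter plus an m.pt index that maps parameter names to those files. A minimal sketch of reading such a directory back by hand, without the SplitCheckpoint helper (the path below is hypothetical, not from this commit):

import torch

path = "checkpoints/step1000"  # hypothetical directory previously written by GPTModel.save(path)
index = torch.load(f"{path}/m.pt")  # {parameter_name: f"{path}/b{i}.pt", ...}
state_dict = {name: torch.load(file, map_location="cuda") for name, file in index.items()}
# model.load_state_dict(state_dict, strict=False)  # with a model built via GPTModel.load or GPTModel.init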

def wang_init_method(n_layers, dim):
    std = 2 / n_layers / math.sqrt(dim)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_

# Stolen from NeoX. For the 20B run, wang_init was used on the output layer and small_init on the rest of the layers.
def small_init_method(dim):
    """Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
    the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
    std = math.sqrt(2 / (5 * dim))

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)
    return init_
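For a rough sense of scale, with a hypothetical 512-wide, 12-layer config (values not taken from this commit), the two schemes give:

import math

print(math.sqrt(2 / (5 * 512)))  # small_init std ~= 0.0280
print(2 / 12 / math.sqrt(512))   # wang_init std  ~= 0.0074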
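A runnable sketch of the dependency-injection pattern the comment at the top of the file describes: GPTModel only glues together whatever Layer, SelfAttention, and FeedForward classes it is given. The stand-in classes below are toys for illustration, not this repo's real SelfAttention/FeedForward/GPTJLayer, and the CPU/float32 config values are assumptions chosen so the sketch runs anywhere (the real defaults are cuda/float16).

import torch
from torch import nn

class ToyFeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
        super().__init__()
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation

    def forward(self, x):
        return self.ff2(self.activation(self.ff1(x)))

class ToySelfAttention(nn.Module):
    def __init__(self, hidden_dim, n_head, device, dtype):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_dim, n_head, batch_first=True, device=device, dtype=dtype)

    def forward(self, x):
        return self.attn(x, x, x, need_weights=False)[0]

class ToyLayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim * 4, activation=activation, device=device, dtype=dtype)

    def forward(self, x, hypernetwork=None, act_ck=False):
        h = self.ln(x)
        return x + self.attn(h) + self.ff(h)

config = dict(hidden_dim=64, n_layer=2, n_head=4, vocab_dim=1000, eps=1e-4, activation=nn.GELU(),
              Layer=ToyLayer, SelfAttention=ToySelfAttention, FeedForward=ToyFeedForward,
              device="cpu", dtype=torch.float32)
model = GPTModel.init(config)
hidden = model(torch.randint(0, 1000, (1, 16)))  # final hidden states, shape (1, 16, 64)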
@@ -79,7 +79,7 @@ def _attn(query, key, value, causal_mask, masked_bias,
class SelfAttention(nn.Module):
    # Code copied from HF, might want to sanity check later.
    def __init__(self, hidden_dim, n_head, device, dtype):
-        super(SelfAttention, self).__init__()
+        super().__init__()
        max_positions = 2049
        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
            1, 1, max_positions, max_positions).bool()
@@ -143,7 +143,7 @@ class SelfAttention(nn.Module):
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, activation, device, dtype):
-        super(FeedForward, self).__init__()
+        super().__init__()
        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
        self.activation = activation
@@ -159,14 +159,16 @@ class FeedForward(nn.Module):
class GPTJLayer(nn.Module):
    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
-        super(GPTJLayer, self).__init__()
+        super().__init__()
        self.hidden_dim = hidden_dim
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
+        self.tick = True

-    def forward(self, x, hypernetwork=None, act_ck=False):
+    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False, interleaving_layers=False, every_n=5):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
@@ -175,17 +177,35 @@ class GPTJLayer(nn.Module):
            x = self.ln_preattn(x)
            attn_out = self.attn(x)

-        ff_out = self.ff(x, act_ck)
-        x = residual + attn_out + ff_out
        if hypernetwork:
-            hyper_out = hypernetwork(x)
+            if diff_hypernets:
+                if interleaving_layers and layer_id % every_n == 0:
+                    if self.tick:
+                        hyper_out = hypernetwork[0](x)
+                        self.tick = False
+                    else:
+                        hyper_out = hypernetwork[1](x)
+                        self.tick = True
+                elif layer_id % every_n == 0:
+                    hyper_out = hypernetwork[(layer_id // every_n) - 1](x)
+            else:
+                if layer_id % every_n == 0:
+                    hyper_out = hypernetwork(x)

+        ff_out = self.ff(x, act_ck)
+        #order of addition matters, i had no idea... fixed a bug here.
+        x = attn_out + ff_out + residual
+        #x = residual + attn_out + ff_out -> doesn't match.

+        if hypernetwork and layer_id % every_n == 0:
            x = x + hyper_out

        return x

-class GPTModel(nn.Module):
+class GPTJModel(nn.Module):
    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation, Layer, device, dtype):
-        super(GPTModel, self).__init__()
+        super().__init__()
        self.n_layer = n_layer
        self.hidden_dim = hidden_dim
        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
@@ -194,25 +214,6 @@ class GPTModel(nn.Module):
        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
        for _ in range(n_layer):
            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))

-    #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
-    #TODO: Do we want to pass a config object everywhere? I don't exactly like that but passing a lot of variables is a bit ugly too.
-    def _init_weights(self, module):
-        """Initialize the weights."""
-        if isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-            if module.bias is not None:
-                module.bias.data.zero_()
-        elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=0.02)
-        elif isinstance(module, nn.LayerNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-
-        for name, p in module.named_parameters():
-            if ("ff2" in name or "out_proj" in name) and "weight" in name:
-                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
-                p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * self.n_layer)))

    def forward(self, x, hypernetwork=None, act_ck=False):
        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
......
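The "order of addition matters" comment in the GPTJLayer hunk above comes down to float16 rounding: the layer runs in float16 by default, and addition at that precision is not associative, so `attn_out + ff_out + residual` and `residual + attn_out + ff_out` can round to different values. A toy illustration, with values chosen to make the effect obvious and unrelated to the actual model:

import torch

a = torch.tensor([2048.0], dtype=torch.float16)
b = torch.tensor([1.0], dtype=torch.float16)
c = torch.tensor([1.0], dtype=torch.float16)
print((a + b) + c)  # tensor([2048.], dtype=torch.float16): 2048 + 1 rounds back to 2048 each time
print(a + (b + c))  # tensor([2050.], dtype=torch.float16): 1 + 1 = 2 is added exactly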
@@ -4,6 +4,7 @@ import torch
from torch import nn
import os

+#Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
class BaseLM(nn.Module):
    def __init__(self, config=None, lm=None):
        self.config = config
@@ -57,6 +58,8 @@ class BaseLM(nn.Module):
    def save(self, path):
        if self.lm is None:
            print("No LM object to save. Please first init a model.")
+            return
        try: os.mkdir(path)
        except: pass
        checkpoint = {}
......
@@ -6,7 +6,6 @@ except ImportError:
from pathlib import Path
import os

def no_init(loading_code):
    def dummy(self):
        return
......
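The body of no_init is collapsed above; the visible dummy(self) stub and its use in GPTModel.load (no_init(lambda: cls(**config))) suggest it temporarily disables module weight initialization while loading_code() runs. A generic sketch of that pattern, as an assumption about the intent rather than the collapsed code itself:

from torch import nn

def no_init_sketch(loading_code):
    # Skip expensive weight initialization when the parameters are about to be
    # overwritten by load_state_dict anyway.
    def dummy(self):
        return

    module_classes = [nn.Linear, nn.Embedding, nn.LayerNorm]
    saved = {cls: cls.reset_parameters for cls in module_classes}
    for cls in module_classes:
        cls.reset_parameters = dummy
    try:
        return loading_code()
    finally:
        for cls, fn in saved.items():
            cls.reset_parameters = fn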