Commit 4596b61e authored by novelailab

hypertrain GRU and pass layer_ids

parent c99ffa47
@@ -68,6 +68,34 @@ def discounted_cumsum(t, gamma):

def shift(x, amt, dim = -1):
    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)

class HyperNetworkGRU(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        self.linear1 = nn.Linear(embed_dim, embed_dim // 8)
        self.gru = nn.GRU(embed_dim // 8, embed_dim // 8, num_layers=1, bidirectional=False, batch_first=True)
        self.linear2 = nn.Linear(embed_dim // 8, embed_dim)
        self.ln_1 = nn.LayerNorm(embed_dim // 8, eps=1e-5)
        self.activation = gelu_new

        for module in self.modules():
            _init_weights(module)
        for param in self.linear2.parameters():
            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
        for param in self.gru.parameters():
            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))

    def forward(self, x):
        x = x.float()
        x = self.linear1(x)
        x = self.gru(x)[0]
        x = self.ln_1(x)
        x = self.linear2(x)
        x = ck(self.activation, x)
        return x.bfloat16()
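For reference, HyperNetworkGRU is a bottleneck adapter: an 8x down-projection, a single-layer unidirectional GRU over the sequence, LayerNorm, an up-projection, and a GELU, computed in fp32 and cast back to bf16. Below is a minimal standalone sketch for shape-checking only; it is not part of the commit, and it substitutes assumed stand-ins for the file's helpers (gelu_new approximated with tanh-GELU, the ck checkpoint wrapper and _init_weights dropped, hidden size chosen arbitrarily).

# Minimal standalone sketch of the GRU hypernetwork block above, for shape-checking only.
# Assumptions (not from the diff): hidden_dim=4096 is an arbitrary example value, gelu_new
# is approximated with tanh-GELU, and ck()/_init_weights() are omitted.
import torch
import torch.nn as nn
import torch.nn.functional as F

class HyperNetworkGRUSketch(nn.Module):
    def __init__(self, hidden_dim=4096):
        super().__init__()
        bottleneck = hidden_dim // 8                      # 8x down-projection
        self.linear1 = nn.Linear(hidden_dim, bottleneck)
        self.gru = nn.GRU(bottleneck, bottleneck, num_layers=1, batch_first=True)
        self.ln_1 = nn.LayerNorm(bottleneck, eps=1e-5)
        self.linear2 = nn.Linear(bottleneck, hidden_dim)

    def forward(self, x):                                 # x: [batch, seq, hidden_dim]
        x = x.float()
        x = self.linear1(x)                               # [batch, seq, hidden_dim // 8]
        x = self.gru(x)[0]                                # GRU outputs, same shape
        x = self.ln_1(x)
        x = self.linear2(x)                               # back to [batch, seq, hidden_dim]
        x = F.gelu(x, approximate="tanh")                 # stand-in for gelu_new
        return x.bfloat16()

out = HyperNetworkGRUSketch()(torch.randn(1, 16, 4096))
print(out.shape, out.dtype)                               # torch.Size([1, 16, 4096]) torch.bfloat16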
class HyperNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -89,13 +117,37 @@ class HyperNetwork(nn.Module):
    def forward(self, x):
        x = x.float()
        x = shift_tokens(x, self.num_shifts)
        #x = shift_tokens(x, self.num_shifts)
        x = self.linear(x)
        x = ck(self.activation, x)
        x = self.linear2(x)
        x = x.mul(torch.sigmoid(x))
        return x.bfloat16()

class HyperNetworkSingle(nn.Module):
    def __init__(self, config):
        super().__init__()
        embed_dim = config["hidden_dim"]
        self.linear = nn.Linear(embed_dim, embed_dim, bias=True)
        self.activation = gelu_new
        #self.linear.weight.data.normal_(mean=0.0, std=0.02)

        for module in self.modules():
            _init_weights(module)
        for param in self.linear.parameters():
            param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
        #state = self.state_dict()
        #for k in state:
        #    state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
        #self.load_state_dict(state)

    def forward(self, x):
        x = x.float()
        #x = shift_tokens(x, self.num_shifts)
        x = self.linear(x)
        x = x.mul(torch.sigmoid(x))
        return x.bfloat16()
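Both forward passes gate with x.mul(torch.sigmoid(x)), which is exactly SiLU (swish). A quick equivalence check against the PyTorch built-in (illustrative, not part of the commit):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16)
# x * sigmoid(x) is SiLU/swish; F.silu is the fused built-in.
assert torch.allclose(x.mul(torch.sigmoid(x)), F.silu(x), atol=1e-6)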
model_config = {
    "n_layer": 12,

@@ -123,9 +175,9 @@ train_config = {
#"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
#"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
"save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/fixedj",
"run_name": "bighyper-gpt-j-enwik9-6b-postln-bf16-1e-4",
"lr": 1e-4,
"end_lr": 1e-4,
"run_name": "gpt-j-enwik9-6b-postln-bf16-5e-4",
"lr": 5e-4,
"end_lr": 5e-4,
"warmup_steps": 50,
"bs": 1,
"gas": 16,
@@ -149,7 +201,8 @@ for name, p in model.named_parameters():
    if ("ln" in name or "vocab_embed" in name):
        p.requires_grad = True

hypernetwork = HyperNetwork(model_config).cuda().float()
#hypernetwork = HyperNetwork(model_config).cuda().float()
hypernetwork = nn.ModuleList([HyperNetwork(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])

for param in hypernetwork.parameters():
    param.requires_grad = True
@@ -201,7 +254,7 @@ for input_ids, labels in t:
    opt.zero_grad()
    sec_per_step = (time.perf_counter() - timex) / (bs*gas)
    step_per_sec = (1. / sec_per_step)
    tokens_per_sec = step_per_sec * 1024
    tokens_per_sec = step_per_sec * 2048
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log({"train/loss": loss, "train/tokens_per_sec": tokens_per_sec, "train/sec_per_step": sec_per_step, "train/step_per_sec": step_per_sec, "train/lr": opt.curr_lr, "train/loss_scale": scaler.get_scale()})
    curr_step += 1
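The per-step token multiplier changes from 1024 to 2048. Since sec_per_step already divides wall time by bs*gas, the multiplier is the token count of one sequence, so the logged tokens/s now assumes a 2048-token context. A quick arithmetic check (illustrative; the context length is inferred from the new multiplier, not stated in the diff):

# Illustrative throughput arithmetic, assuming a 2048-token context per sequence.
bs, gas, ctx = 1, 16, 2048                    # bs and gas from train_config above; ctx assumed
tokens_per_micro_step = ctx                   # what the progress bar reports per step
tokens_per_optimizer_update = bs * gas * ctx  # tokens consumed per optimizer update
print(tokens_per_micro_step, tokens_per_optimizer_update)   # 2048 32768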
@@ -225,10 +225,11 @@ class GPTLayer(nn.Module):
        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
        self.tick = True

    def forward(self, x, hypernetwork=None, act_ck=False):
    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
        residual = x

        if act_ck:
            x = ck(self.ln_preattn, x)
            attn_out = ck(self.attn, x)
@@ -237,14 +238,14 @@ class GPTLayer(nn.Module):
            x = self.ln_preattn(x)
            attn_out = self.attn(x)

        if hypernetwork:
            hyper_out = hypernetwork(x)
        if hypernetwork and layer_id % 5 == 0:
            hyper_out = hypernetwork[(layer_id // 5) - 1](x)

        ff_out = self.ff(x, act_ck)
        #order of addition matters, i had no idea... fixed a bug here.
        x = attn_out + ff_out + residual
        #x = residual + attn_out + ff_out -> doesn't match.

        if hypernetwork:
        if hypernetwork and layer_id % 5 == 0:
            x = x + hyper_out

        return x
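The new routing can be traced directly: a hypernetwork runs only on layers where layer_id % 5 == 0, and the module is selected with (layer_id // 5) - 1. The snippet below (illustrative, not part of the commit) prints that mapping for the 12-layer model_config; note that layer 0 resolves to index -1, which Python maps to the last entry of the two-element ModuleList built in the training script.

# Trace which layers apply a hypernetwork and which ModuleList entry they pick,
# following the expressions in GPTLayer.forward above (illustrative only).
n_layer = 12
list_len = n_layer // 5                       # 2 entries, matching the training script
for layer_id in range(n_layer):
    if layer_id % 5 == 0:
        idx = (layer_id // 5) - 1
        print(f"layer {layer_id:2d} -> hypernetwork[{idx}] (resolves to entry {idx % list_len})")
# layer  0 -> hypernetwork[-1] (resolves to entry 1)
# layer  5 -> hypernetwork[0] (resolves to entry 0)
# layer 10 -> hypernetwork[1] (resolves to entry 1)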
@@ -284,8 +285,8 @@ class GPTModel(nn.Module):
    def get_embeds(self, x, hypernetwork=None, act_ck=False):
        x = self.vocab_embed(x)

        for layer in self.layers:
            x = layer(x, hypernetwork, act_ck)
        for layer_id, layer in enumerate(self.layers):
            x = layer(x, layer_id, hypernetwork, act_ck)

        x = self.ln_final(x)
        return x