Commit 2a8307f2 authored by novelailab

reproduced golden

parent 4596b61e
......@@ -201,8 +201,8 @@ for name, p in model.named_parameters():
if ("ln" in name or "vocab_embed" in name):
p.requires_grad = True
#hypernetwork = HyperNetwork(model_config).cuda().float()
hypernetwork = nn.ModuleList([HyperNetwork(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
hypernetwork = HyperNetworkSingle(model_config).cuda().float()
#hypernetwork = nn.ModuleList([HyperNetwork(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
for param in hypernetwork.parameters():
param.requires_grad = True
......
......@@ -227,7 +227,7 @@ class GPTLayer(nn.Module):
self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
self.tick = True
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False):
residual = x
if act_ck:
......@@ -238,14 +238,17 @@ class GPTLayer(nn.Module):
x = self.ln_preattn(x)
attn_out = self.attn(x)
if hypernetwork and layer_id % 5 == 0:
hyper_out = hypernetwork[(layer_id // 5) - 1](x)
if diff_hypernets and hypernetwork:
if layer_id % 1 == 0:
hyper_out = hypernetwork[(layer_id // 5) - 1](x)
else:
hyper_out = hypernetwork(x)
ff_out = self.ff(x, act_ck)
#order of addition matters, i had no idea... fixed a bug here.
x = attn_out + ff_out + residual
#x = residual + attn_out + ff_out -> doesn't match.
if hypernetwork and layer_id % 5 == 0:
if hypernetwork and not diff_hypernets or layer_id % 5 == 0:
x = x + hyper_out
return x
......@@ -286,7 +289,7 @@ class GPTModel(nn.Module):
def get_embeds(self, x, hypernetwork=None, act_ck=False):
x = self.vocab_embed(x)
for layer_id, layer in enumerate(self.layers):
x = layer(x, layer_id, hypernetwork, act_ck)
x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
x = self.ln_final(x)
return x
......
......@@ -4,6 +4,7 @@ import sys
name = 'pyfra-basedformer'
dry = False
bash = False
config_obj = KubeConfig()
config_obj.set_name(name)
......@@ -27,5 +28,9 @@ env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
env1.sh('pip3 install dotmap')
with always_rerun():
print(f"Running {sys.argv[1]}")
path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file
if bash:
path.sh("bash")
else:
print(f"Running {sys.argv[1]}")
path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment