Commit 2a8307f2 authored by novelailab

reproduced golden

parent 4596b61e
@@ -201,8 +201,8 @@ for name, p in model.named_parameters():
     if ("ln" in name or "vocab_embed" in name):
         p.requires_grad = True
 
-#hypernetwork = HyperNetwork(model_config).cuda().float()
-hypernetwork = nn.ModuleList([HyperNetwork(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
+hypernetwork = HyperNetworkSingle(model_config).cuda().float()
+#hypernetwork = nn.ModuleList([HyperNetwork(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
 
 for param in hypernetwork.parameters():
     param.requires_grad = True
...
@@ -227,7 +227,7 @@ class GPTLayer(nn.Module):
         self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
         self.tick = True
 
-    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
+    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, diff_hypernets=False):
         residual = x
 
         if act_ck:
@@ -238,14 +238,17 @@ class GPTLayer(nn.Module):
         x = self.ln_preattn(x)
         attn_out = self.attn(x)
 
-        if hypernetwork and layer_id % 5 == 0:
-            hyper_out = hypernetwork[(layer_id // 5) - 1](x)
+        if diff_hypernets and hypernetwork:
+            if layer_id % 1 == 0:
+                hyper_out = hypernetwork[(layer_id // 5) - 1](x)
+        else:
+            hyper_out = hypernetwork(x)
 
         ff_out = self.ff(x, act_ck)
         #order of addition matters, i had no idea... fixed a bug here.
         x = attn_out + ff_out + residual
         #x = residual + attn_out + ff_out -> doesn't match.
 
-        if hypernetwork and layer_id % 5 == 0:
+        if hypernetwork and not diff_hypernets or layer_id % 5 == 0:
             x = x + hyper_out
 
         return x
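Note on the GPTLayer hunk above: with diff_hypernets left at its False default and a single shared HyperNetworkSingle passed in (as the training script now does), the branching collapses to applying the same hypernetwork at every layer. A rough sketch of that effective data flow follows; it uses only names that appear in the diff, and layer_forward itself is a hypothetical helper for illustration, not part of the commit.

# Sketch only: effective per-layer flow for the single-hypernetwork path
# (diff_hypernets=False), inferred from the GPTLayer.forward change above.
def layer_forward(layer, x, hypernetwork, act_ck=False):
    residual = x
    x = layer.ln_preattn(x)
    attn_out = layer.attn(x)
    hyper_out = hypernetwork(x)        # shared HyperNetworkSingle output
    ff_out = layer.ff(x, act_ck)
    # Floating-point addition is not associative, so the summation order below
    # changes bit-exact results; the commit's own comment credits keeping
    # attn_out + ff_out + residual for matching the "golden" outputs.
    x = attn_out + ff_out + residual
    x = x + hyper_out                  # hypernetwork correction added last
    return x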
@@ -286,7 +289,7 @@ class GPTModel(nn.Module):
     def get_embeds(self, x, hypernetwork=None, act_ck=False):
         x = self.vocab_embed(x)
 
         for layer_id, layer in enumerate(self.layers):
-            x = layer(x, layer_id, hypernetwork, act_ck)
+            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
         x = self.ln_final(x)
         return x
...
@@ -4,6 +4,7 @@ import sys
 name = 'pyfra-basedformer'
 dry = False
+bash = False
 
 config_obj = KubeConfig()
 config_obj.set_name(name)
@@ -27,5 +28,9 @@ env1.sh('pip3 install einops==0.4.1 pyyaml wandb')
 env1.sh('wandb login 21a9442d42a35e15ce421f2b702ec58508b9adc4')
 env1.sh('pip3 install dotmap')
 with always_rerun():
-    print(f"Running {sys.argv[1]}")
-    path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file
+    if bash:
+        path.sh("bash")
+    else:
+        print(f"Running {sys.argv[1]}")
+        path.sh(f'python3 {sys.argv[1]}')
\ No newline at end of file