Commit 03a1e288 authored by AUTOMATIC

turns out LayerNorm also has weight and bias and needs to be pre-multiplied...

turns out LayerNorm also has weight and bias and needs to be pre-multiplied and trained for hypernets
parent e4877722
@@ -52,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
             self.load_state_dict(state_dict)
         else:
             for layer in self.linear:
-                if type(layer) == torch.nn.Linear:
+                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                     layer.weight.data.normal_(mean=0.0, std=0.01)
                     layer.bias.data.zero_()
@@ -80,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
     def trainables(self):
         layer_structure = []
         for layer in self.linear:
-            if type(layer) == torch.nn.Linear:
+            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                 layer_structure += [layer.weight, layer.bias]
         return layer_structure
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment