Commit 6a4fa73a authored by discus0434's avatar discus0434

small fix

parent 97749b7c
......@@ -51,10 +51,9 @@ class HypernetworkModule(torch.nn.Module):
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout
if use_dropout:
p = 0.5 if 0 <= i <= len(layer_structure) - 3 else 0.2
linears.append(torch.nn.Dropout(p=p))
# Add dropout expect last layer
if use_dropout and i < len(layer_structure) - 3:
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment