Commit fccba472 authored by discus0434's avatar discus0434

add an option to avoid dying relu

parent dcb45dfe
......@@ -32,7 +32,6 @@ class HypernetworkModule(torch.nn.Module):
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
assert activation_func not in self.activation_dict.keys() + "linear", f"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
linears = []
for i in range(len(layer_structure) - 1):
......@@ -43,12 +42,13 @@ class HypernetworkModule(torch.nn.Module):
# Add an activation func
if activation_func == "linear" or activation_func is None:
pass
# If ReLU, Skip adding it to the first layer to avoid dying ReLU
elif activation_func == "relu" and i < 1:
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise RuntimeError(
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
)
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add dropout
if use_dropout:
......@@ -166,8 +166,8 @@ class Hypernetwork:
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
)
self.name = state_dict.get('name', self.name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment