Commit 6a02841f authored by discus0434's avatar discus0434 Committed by GitHub

Merge pull request #2 from aria1th/patch-6

generalized some functions and option for ignoring first layer
parents f8733ad0 f89829ec
...@@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler ...@@ -21,21 +21,27 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU, "elu": torch.nn.ELU,
"swish": torch.nn.Hardswish}
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = [] linears = []
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu": # if skip_first_layer because first parameters potentially contain negative values
linears.append(torch.nn.ReLU()) # if i < 1: continue
if activation_func == "leakyrelu": if activation_func in HypernetworkModule.activation_dict:
linears.append(torch.nn.LeakyReLU()) linears.append(HypernetworkModule.activation_dict[activation_func]())
else:
print("Invalid key {} encountered as activation function!".format(activation_func))
# if use_dropout:
# linears.append(torch.nn.Dropout(p=0.3))
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
...@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module): ...@@ -46,7 +52,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict) self.load_state_dict(state_dict)
else: else:
for layer in self.linear: for layer in self.linear:
if not "ReLU" in layer.__str__(): if isinstance(layer, torch.nn.Linear):
layer.weight.data.normal_(mean=0.0, std=0.01) layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_() layer.bias.data.zero_()
...@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module): ...@@ -74,7 +80,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self): def trainables(self):
layer_structure = [] layer_structure = []
for layer in self.linear: for layer in self.linear:
if not "ReLU" in layer.__str__(): if isinstance(layer, torch.nn.Linear):
layer_structure += [layer.weight, layer.bias] layer_structure += [layer.weight, layer.bias]
return layer_structure return layer_structure
...@@ -298,6 +304,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -298,6 +304,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment