Commit f89829ec authored by aria1th

Revert "fix bugs and optimizations"

This reverts commit 108be155.
parent 108be155
@@ -36,14 +36,14 @@ class HypernetworkModule(torch.nn.Module):
             linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
             # if skip_first_layer because first parameters potentially contain negative values
             # if i < 1: continue
-            if add_layer_norm:
-                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
             if activation_func in HypernetworkModule.activation_dict:
                 linears.append(HypernetworkModule.activation_dict[activation_func]())
             else:
                 print("Invalid key {} encountered as activation function!".format(activation_func))
             # if use_dropout:
             #     linears.append(torch.nn.Dropout(p=0.3))
+            if add_layer_norm:
+                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
 
         self.linear = torch.nn.Sequential(*linears)
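For reference, the post-revert ordering in each block is Linear -> activation -> LayerNorm, rather than Linear -> LayerNorm -> activation. A minimal standalone sketch of the module this loop builds (the dim and layer_structure values here are illustrative assumptions):

    import torch

    dim, layer_structure = 768, [1, 2, 1]   # assumed example values
    add_layer_norm, activation_func = True, "relu"
    activation_dict = {"relu": torch.nn.ReLU, "leakyrelu": torch.nn.LeakyReLU}

    linears = []
    for i in range(len(layer_structure) - 1):
        linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
        if activation_func in activation_dict:
            linears.append(activation_dict[activation_func]())
        if add_layer_norm:
            # post-revert order: normalization comes after the activation
            linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

    model = torch.nn.Sequential(*linears)
    print(model(torch.randn(4, dim)).shape)  # torch.Size([4, 768])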
@@ -115,24 +115,11 @@ class Hypernetwork:
         for k, layers in self.layers.items():
             for layer in layers:
-                layer.train()
                 res += layer.trainables()
 
         return res
 
-    def eval(self):
-        for k, layers in self.layers.items():
-            for layer in layers:
-                layer.eval()
-        for items in self.weights():
-            items.requires_grad = False
-
-    def train(self):
-        for k, layers in self.layers.items():
-            for layer in layers:
-                layer.train()
-        for items in self.weights():
-            items.requires_grad = True
-
     def save(self, filename):
         state_dict = {}
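For context, the removed eval()/train() helpers toggled both the module mode (Dropout/LayerNorm behavior) and gradient tracking in one call. A minimal standalone equivalent, assuming the same {size: (module, module)} layout as Hypernetwork.layers:

    import torch

    layers = {320: (torch.nn.Linear(320, 320), torch.nn.Linear(320, 320))}  # assumed layout

    def set_training(layers, mode: bool):
        for k, pair in layers.items():
            for layer in pair:
                layer.train(mode)              # switch train/eval behavior
                for p in layer.parameters():
                    p.requires_grad = mode     # freeze or unfreeze the weights

    set_training(layers, False)  # what the removed Hypernetwork.eval() did
    set_training(layers, True)   # what the removed Hypernetwork.train() did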
@@ -303,6 +290,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     shared.sd_model.first_stage_model.to(devices.cpu)
 
     hypernetwork = shared.loaded_hypernetwork
+    weights = hypernetwork.weights()
+    for weight in weights:
+        weight.requires_grad = True
+
     losses = torch.zeros((32,))
 
     last_saved_file = "<none>"
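The restored lines mark only the hypernetwork's own tensors as trainable while the frozen SD model stays untouched. A toy sketch of that freeze/unfreeze split (all names here are stand-ins, not the real objects):

    import torch

    backbone = torch.nn.Linear(8, 8).requires_grad_(False)  # stand-in for the frozen sd_model
    weights = [torch.nn.Parameter(torch.zeros(8, 8))]       # stand-in for hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    out = backbone(torch.randn(2, 8)) @ weights[0]
    out.sum().backward()
    print(backbone.weight.grad, weights[0].grad is not None)  # None True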
@@ -313,10 +304,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         return hypernetwork, filename
 
     scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
-    # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
-    optimizer = torch.optim.AdamW(hypernetwork.weights(), lr=scheduler.learn_rate)
+    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
-    hypernetwork.train()
     for i, entries in pbar:
         hypernetwork.step = i + ititial_step
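Post-revert, the optimizer is built directly over the weights list captured above, and the commented-out note about pluggable optimizers goes away. A minimal sketch of the construction (the learn rate value is an assumption; the real one comes from LearnRateScheduler):

    import torch

    weights = [torch.nn.Parameter(torch.zeros(4, 4))]
    learn_rate = 5e-6  # assumed; scheduler.learn_rate in the real code
    optimizer = torch.optim.AdamW(weights, lr=learn_rate)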
@@ -337,9 +328,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         losses[hypernetwork.step % losses.shape[0]] = loss.item()
 
-        optimizer.zero_grad(set_to_none=True)
+        optimizer.zero_grad()
         loss.backward()
-        del loss
         optimizer.step()
         mean_loss = losses.mean()
         if torch.isnan(mean_loss):
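The revert also swaps optimizer.zero_grad(set_to_none=True) back to the plain form. The difference is observable on the parameter's .grad; a small sketch (the explicit set_to_none=False mirrors the old default, which PyTorch 2.0 later flipped to True):

    import torch

    w = torch.nn.Parameter(torch.ones(2))
    opt = torch.optim.AdamW([w], lr=1e-3)

    (w ** 2).sum().backward()
    opt.zero_grad(set_to_none=False)  # post-revert behavior: grad zero-filled in place
    print(w.grad)                     # tensor([0., 0.])

    (w ** 2).sum().backward()
    opt.zero_grad(set_to_none=True)   # reverted-away variant: grad tensor freed
    print(w.grad)                     # None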
@@ -356,10 +346,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
         })
 
         if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
-            torch.cuda.empty_cache()
             last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
 
-            with torch.no_grad():
-                hypernetwork.eval()
+            optimizer.zero_grad()
             shared.sd_model.cond_stage_model.to(devices.device)
             shared.sd_model.first_stage_model.to(devices.device)
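The reverted-away preview path wrapped image generation in torch.no_grad() with the hypernetwork in eval mode; the revert restores a bare optimizer.zero_grad() instead. A minimal sketch of the gradient-free inference pattern that was removed (the model here is a stand-in):

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.3))

    model.eval()                  # deterministic: Dropout disabled
    with torch.no_grad():         # no autograd graph built for the preview
        preview = model(torch.randn(1, 4))
    model.train()                 # back to training mode afterwards

    print(preview.requires_grad)  # False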
@@ -396,8 +385,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             image.save(last_saved_image)
 
             last_saved_image += f", prompt: {preview_text}"
 
-            hypernetwork.train()
-
         shared.state.job_no = hypernetwork.step
 
         shared.state.textinfo = f"""