Commit 3ce2bfdf authored by Muhammad Rizqi Nur's avatar Muhammad Rizqi Nur

Add cleanup after training

parent ab27c111
...@@ -398,6 +398,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -398,6 +398,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
forced_filename = "<none>" forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
try:
for i, entries in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
if len(loss_dict) > 0: if len(loss_dict) > 0:
...@@ -510,6 +512,13 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/> ...@@ -510,6 +512,13 @@ Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
finally:
if weights:
for weight in weights:
weight.requires_grad = False
if unload:
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
report_statistics(loss_dict) report_statistics(loss_dict)
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
......
...@@ -283,6 +283,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -283,6 +283,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
embedding_yet_to_be_embedded = False embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
try:
for i, entries in pbar: for i, entries in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
...@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}<br/> ...@@ -396,6 +398,9 @@ Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
""" """
finally:
if embedding and embedding.vec is not None:
embedding.vec.requires_grad = False
checkpoint = sd_models.select_checkpoint() checkpoint = sd_models.select_checkpoint()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment