Commit 939f1652 authored by DepFA's avatar DepFA Committed by AUTOMATIC1111

only save 1 image per embedding

parent 9e846083
...@@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
last_saved_file = "<none>" last_saved_file = "<none>"
last_saved_image = "<none>" last_saved_image = "<none>"
embedding_yet_to_be_embedded = False
ititial_step = embedding.step or 0 ititial_step = embedding.step or 0
if ititial_step > steps: if ititial_step > steps:
...@@ -281,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -281,6 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0: if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file) embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), { write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}", "loss": f"{losses.mean():.7f}",
...@@ -318,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -318,7 +320,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file): if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png') last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
...@@ -342,6 +344,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -342,6 +344,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
captioned_image = insert_image_data_embed(captioned_image, data) captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info) captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
embedding_yet_to_be_embedded = False
image.save(last_saved_image) image.save(last_saved_image)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment