Commit 326fe7d4 authored by AUTOMATIC's avatar AUTOMATIC

Merge remote-tracking branch 'Melanpan/master'

parents 989a552d 8636b50a
......@@ -5,6 +5,7 @@ import os
import sys
import traceback
import tqdm
import csv
import torch
......@@ -262,6 +263,20 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
if write_csv_header:
csv_writer.writeheader()
csv_writer.writerow({"step": hypernetwork.step,
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
......
......@@ -6,6 +6,7 @@ import torch
import tqdm
import html
import datetime
import csv
from PIL import Image, PngImagePlugin
......@@ -256,6 +257,21 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"])
if write_csv_header:
csv_writer.writeheader()
csv_writer.writerow({"epoch": epoch_num + 1,
"epoch_step": epoch_step - 1,
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
......
......@@ -1172,6 +1172,7 @@ def create_ui(wrap_gradio_gpu_call):
training_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=500, precision=0)
write_csv_every = gr.Number(label='Save an csv containing the loss to log directory every N steps, 0 to disable', value=500, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
......@@ -1250,6 +1251,7 @@ def create_ui(wrap_gradio_gpu_call):
steps,
create_image_every,
save_embedding_every,
write_csv_every,
template_file,
save_image_with_stored_embedding,
preview_from_txt2img,
......@@ -1272,6 +1274,7 @@ def create_ui(wrap_gradio_gpu_call):
steps,
create_image_every,
save_embedding_every,
write_csv_every,
template_file,
preview_from_txt2img,
*txt2img_preview_params,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment