Commit e4877722 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #3197 from AUTOMATIC1111/training-help-text

Training UI Changes
parents 2273e752 0c5522ea
......@@ -396,7 +396,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
......
......@@ -10,9 +10,10 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
......
......@@ -11,7 +11,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try:
if process_caption:
shared.interrogator.load()
......@@ -21,7 +21,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru)
finally:
......@@ -33,7 +33,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
......@@ -48,7 +48,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
def save_pic_with_caption(image, index):
def save_pic_with_caption(image, index, existing_caption=None):
caption = ""
if process_caption:
......@@ -66,17 +66,26 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
if preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
caption = caption.strip()
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
def save_pic(image, index):
save_pic_with_caption(image, index)
def save_pic(image, index, existing_caption=None):
save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
......@@ -86,6 +95,13 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
except Exception:
continue
existing_caption = None
try:
existing_caption = open(os.path.splitext(filename)[0] + '.txt', 'r').read()
except Exception as e:
print(e)
if shared.state.interrupted:
break
......@@ -97,20 +113,20 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
img = img.resize((width, height * img.height // img.width))
top = img.crop((0, 0, width, height))
save_pic(top, index)
save_pic(top, index, existing_caption=existing_caption)
bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
save_pic(bot, index, existing_caption=existing_caption)
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
left = img.crop((0, 0, width, height))
save_pic(left, index)
save_pic(left, index, existing_caption=existing_caption)
right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
save_pic(right, index, existing_caption=existing_caption)
else:
img = images.resize_image(1, img, width, height)
save_pic(img, index)
save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
......@@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'):
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
......@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
......
......@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
def create_embedding(name, initialization_text, nvpt, overwrite_old):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
......
......@@ -1211,6 +1211,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
......@@ -1224,6 +1225,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
......@@ -1238,6 +1240,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
......@@ -1253,14 +1256,17 @@ def create_ui(wrap_gradio_gpu_call):
run_preprocess = gr.Button(value="Preprocess", variant='primary')
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
......@@ -1294,6 +1300,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
overwrite_old_embedding,
],
outputs=[
train_embedding_name,
......@@ -1307,6 +1314,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
......@@ -1326,6 +1334,7 @@ def create_ui(wrap_gradio_gpu_call):
process_dst,
process_width,
process_height,
preprocess_txt_action,
process_flip,
process_split,
process_caption,
......@@ -1342,7 +1351,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
learn_rate,
embedding_learn_rate,
batch_size,
dataset_directory,
log_directory,
......@@ -1367,7 +1376,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
learn_rate,
hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment