Commit af144ebd authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into ckpt-cache

parents e21f01f6 6a4e8467
...@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) { ...@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
window.document.addEventListener('dragover', e => { window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap ) { if ( !imgWrap && target.placeholder != "Prompt") {
return; return;
} }
e.stopPropagation(); e.stopPropagation();
...@@ -53,6 +53,9 @@ window.document.addEventListener('dragover', e => { ...@@ -53,6 +53,9 @@ window.document.addEventListener('dragover', e => {
window.document.addEventListener('drop', e => { window.document.addEventListener('drop', e => {
const target = e.composedPath()[0]; const target = e.composedPath()[0];
if (target.placeholder === "Prompt") {
return;
}
const imgWrap = target.closest('[data-testid="image"]'); const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap ) { if ( !imgWrap ) {
return; return;
......
...@@ -85,7 +85,10 @@ titles = { ...@@ -85,7 +85,10 @@ titles = {
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
"Filename join string": "This string will be used to join split words into a single line if the option above is enabled.", "Filename join string": "This string will be used to join split words into a single line if the option above is enabled.",
"Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply." "Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.",
"Weighted sum": "Result = A * (1 - M) + B * M",
"Add difference": "Result = A + (B - C) * M",
} }
......
// Assigns window.onload a function that registers a 'drop' listener on window.
// The listener forwards files dropped onto an element whose placeholder is "Prompt"
// to the file input found inside the prompt-image element of the active tab.
window.onload = (function(){
    window.addEventListener('drop', e => {
        const target = e.composedPath()[0];
        // Drops anywhere else are left to other handlers.
        if (target.placeholder !== "Prompt") return;

        // Tab index 1 selects the img2img element id; every other index uses the txt2img one.
        const prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";

        e.stopPropagation();
        e.preventDefault();

        // The container may be absent from the page; in that case there is nothing to fill in.
        const imgParent = gradioApp().getElementById(prompt_target);
        const fileInput = imgParent?.querySelector('input[type="file"]');
        if ( fileInput ) {
            fileInput.files = e.dataTransfer.files;
            // Fire 'change' manually so listeners on the input see the new files.
            fileInput.dispatchEvent(new Event('change'));
        }
    });
});
// Click handler for a gallery item (`this` is the clicked element).
// Unless the element carries the "transform" class, the positions of all currently
// hidden ".gallery-item" elements in the enclosing "images_history_cantainor" element
// are collected and, if there are any, images_history_hide_buttons is scheduled 10ms
// later with them. The image info is then updated for the clicked element in every case.
var images_history_click_image = function(){
    if (!this.classList.contains("transform")){
        var container = images_history_get_parent_by_class(this, "images_history_cantainor");
        var hidden_positions = [];
        container.querySelectorAll(".gallery-item").forEach(function(item, position){
            if (item.style.display == "none"){
                hidden_positions.push(position);
            }
        });
        if (hidden_positions.length > 0){
            setTimeout(images_history_hide_buttons, 10, hidden_positions, container);
        }
    }
    images_history_set_image_info(this);
}
// Click handler for a history tab button (`this` is the button).
// If the "images_history_tab" element does not yet carry the button's "tabname" value
// as a class, the matching "<tabname>_images_history_renew_page" element is clicked and
// the class is added, so the click happens at most once per tab name.
var images_history_click_tab = function(){
    var tabs_box = gradioApp().getElementById("images_history_tab");
    var tabname = this.getAttribute("tabname");
    if (tabs_box.classList.contains(tabname)) {
        return;
    }
    gradioApp().getElementById(tabname + "_images_history_renew_page").click();
    tabs_box.classList.add(tabname);
}
// Sets disabled="disabled" on every element with the "images_history_del_button" class.
function images_history_disabled_del(){
    var del_buttons = gradioApp().querySelectorAll(".images_history_del_button");
    for (var k = 0; k < del_buttons.length; k++){
        del_buttons[k].setAttribute('disabled','disabled');
    }
}
// Returns the nearest ancestor of `item` (starting at its parent) whose classList
// contains `class_name`. There is no guard for running out of ancestors: the lookup
// throws once parentElement becomes null.
function images_history_get_parent_by_class(item, class_name){
    var ancestor = item.parentElement;
    for (; !ancestor.classList.contains(class_name); ancestor = ancestor.parentElement){
        // keep climbing until the class is found
    }
    return ancestor;
}
// Returns the nearest ancestor of `item` (starting at its parent) whose tagName equals
// `tagname` upper-cased. There is no guard for running out of ancestors: the lookup
// throws once parentElement becomes null.
// The per-iteration console.log debug output of the earlier version has been removed.
function images_history_get_parent_by_tagname(item, tagname){
    var wanted = tagname.toUpperCase();
    var parent = item.parentElement;
    while (parent.tagName != wanted){
        parent = parent.parentElement;
    }
    return parent;
}
// Sets display "none" on the ".gallery-item" elements of `gallery` at the positions
// listed in `hidden_list`.
// Before doing so it counts the items that are already hidden; when that count equals
// hidden_list.length it schedules itself again in 10ms.
// NOTE(review): the re-scheduling presumably keeps polling until the gallery has been
// re-rendered and the items became visible again - confirm against the gallery behavior.
function images_history_hide_buttons(hidden_list, gallery){
    var buttons = gallery.querySelectorAll(".gallery-item");
    var num = 0;
    buttons.forEach(function(e){
        if (e.style.display == "none"){
            num += 1;
        }
    });
    if (num == hidden_list.length){
        setTimeout(images_history_hide_buttons, 10, hidden_list, gallery);
    }
    // Locally declared index loop: the earlier `for( i in hidden_list)` leaked a global `i`.
    for (var k = 0; k < hidden_list.length; k++){
        var target = buttons[hidden_list[k]];
        // The gallery may now hold fewer items than when the positions were recorded.
        if (target){
            target.style.display = "none";
        }
    }
}
// Works out the position of `button` among the ".gallery-item" elements of its nearest
// DIV ancestor, counting only items that are not hidden before it, and stores that
// position in the "img_index" attribute of the ".images_history_set_index" element of the
// enclosing "images_history_cantainor" element. When the stored value changes, the delete
// buttons are disabled. The set-index element is clicked in every case.
function images_history_set_image_info(button){
    var siblings = images_history_get_parent_by_tagname(button, "DIV").querySelectorAll(".gallery-item");
    var index = -1;
    var visible_before = 0;
    for (var k = 0; k < siblings.length; k++){
        var candidate = siblings[k];
        if (candidate == button){
            index = visible_before;
        }
        if (candidate.style.display != "none"){
            visible_before += 1;
        }
    }
    var container = images_history_get_parent_by_class(button, "images_history_cantainor");
    var set_btn = container.querySelector(".images_history_set_index");
    // The second argument is kept from the original call; Element.getAttribute only uses the name.
    var curr_idx = set_btn.getAttribute("img_index", index);
    // Loose comparison on purpose: the attribute comes back as a string, index is a number.
    if (curr_idx != index) {
        set_btn.setAttribute("img_index", index);
        images_history_disabled_del();
    }
    set_btn.click();
}
// Returns [img_index, image_path, files], where img_index is the "img_index" attribute
// of the "<tabname>_images_history_set_index" element and the other two values are
// passed through unchanged.
function images_history_get_current_img(tabname, image_path, files){
    var set_btn = gradioApp().getElementById(tabname + '_images_history_set_index');
    var img_index = set_btn.getAttribute("img_index");
    return [img_index, image_path, files];
}
// Hides the gallery items of deleted images and selects the next image to show.
//
// The visible ".gallery-item" elements of "<tabname>_images_history" are collected and
// their count is halved to get the number of images.
// NOTE(review): the halving implies every image is represented by two items, the second
// one sitting img_num positions after the first - confirm against the gallery markup.
//
// - If del_num covers every image, the "<tabname>_images_history_renew_page" element is
//   clicked after 30ms instead of hiding anything.
// - Otherwise up to del_num images starting at image_index are hidden (both items each),
//   and after 30ms the following image is clicked, or the one just before image_index
//   when the deletion reached the end of the page.
//
// The delete buttons are disabled in every case. All arguments are returned in the same
// order, with image_index converted to an integer.
//
// Fixes over the earlier version: `btn` is now declared (it used to be an implicit
// global next to an unused `var bnt`), the paired item index follows the loop counter,
// the bounds check compares against img_num, and a missing button is no longer clicked.
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
    image_index = parseInt(image_index, 10);
    var tab = gradioApp().getElementById(tabname + '_images_history');
    var buttons = [];
    tab.querySelectorAll(".gallery-item").forEach(function(e){
        if (e.style.display != 'none'){
            buttons.push(e);
        }
    });
    var img_num = buttons.length / 2;
    if (img_num <= del_num){
        setTimeout(function(tabname){
            gradioApp().getElementById(tabname + '_images_history_renew_page').click();
        }, 30, tabname);
    } else {
        var next_img;
        for (var i = 0; i < del_num; i++){
            // Stop hiding once the requested range runs past the last image.
            if (image_index + i < img_num){
                buttons[image_index + i].style.display = 'none';
                buttons[image_index + i + img_num].style.display = 'none';
                next_img = image_index + i + 1;
            }
        }
        var btn;
        if (next_img >= img_num){
            // Everything from image_index to the end is hidden; fall back to the previous image.
            btn = buttons[image_index - 1];
        } else {
            btn = buttons[next_img];
        }
        if (btn){
            setTimeout(function(btn){btn.click()}, 30, btn);
        }
    }
    images_history_disabled_del();
    return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
}
// Makes every ".gallery-item" inside "<tabname>_images_history" visible again
// (display "block") and returns the four arguments unchanged, in the same order.
function images_history_turnpage(img_path, page_index, image_index, tabname){
    var items = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
    for (var k = 0; k < items.length; k++){
        items[k].style.display = 'block';
    }
    return [img_path, page_index, image_index, tabname];
}
// Removes the "disabled" attribute from every element with the
// "images_history_del_button" class.
function images_history_enable_del_buttons(){
    var del_buttons = gradioApp().querySelectorAll(".images_history_del_button");
    for (var k = 0; k < del_buttons.length; k++){
        del_buttons[k].removeAttribute('disabled');
    }
}
// One-time setup of the image history tab.
//
// Waits for the "txt2img_images_history_renew_page" element to exist, retrying every
// 500ms until it does. Then, for every name in images_history_tab_list, it tags that
// tab's history container, set-index element, delete button and gallery with shared
// class names. It also gives the tab bar the id "images_history_tab", stores each tab
// name on its button as a "tabname" attribute, and marks the first tab name as a class
// on the tab bar.
//
// Fixes over the earlier version: no implicit global `tab`, and the arrays are walked
// with forEach instead of for-in with a twice-declared `var i`.
function images_history_init(){
    var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page');
    if (!load_txt2img_button){
        // The UI is not built yet; try again later.
        setTimeout(images_history_init, 500);
        return;
    }

    images_history_tab_list.forEach(function(tab){
        gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
        gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
        gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
        gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
    });

    // The tab bar is reached as the third nested div under "tab_images_history".
    var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
    tabs_box.setAttribute("id", "images_history_tab");
    var tab_btns = tabs_box.querySelectorAll("button");
    images_history_tab_list.forEach(function(tabname, position){
        tab_btns[position].setAttribute("tabname", tabname);
        // this refreshes history upon tab switch
        // until the history is known to work well, which is not the case now, we do not do this at startup
        //tab_btns[position].addEventListener('click', images_history_click_tab);
    });
    tabs_box.classList.add(images_history_tab_list[0]);
    // same as above, at page load
    //load_txt2img_button.click();
}
// Names of the tabs that have an image history; used as id prefixes throughout.
var images_history_tab_list = ["txt2img", "img2img", "extras"];

// First setup attempt half a second after the script is evaluated.
setTimeout(images_history_init, 500);

// Once the DOM is ready, watch the whole app for child-list changes and (re)attach the
// capturing click handler to every gallery item of every history tab on each change.
document.addEventListener("DOMContentLoaded", function() {
    var mutationObserver = new MutationObserver(function(m){
        images_history_tab_list.forEach(function(tabname){
            gradioApp()
                .querySelectorAll('#' + tabname + '_images_history .gallery-item')
                .forEach(function(item){
                    item.addEventListener('click', images_history_click_image, true);
                });

            // same as load_txt2img_button.click() above
            /*
            var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
            if (cls_btn){
                cls_btn.addEventListener('click', function(){
                    gradioApp().getElementById(tabname + '_images_history_renew_page').click();
                }, false);
            }*/
        });
    });
    mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
});
// code related to showing and updating progressbar shown as the image is being made // code related to showing and updating progressbar shown as the image is being made
global_progressbars = {} global_progressbars = {}
galleries = {}
galleryObservers = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar) var progressbar = gradioApp().getElementById(id_progressbar)
...@@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -31,13 +33,24 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
preview.style.width = gallery.clientWidth + "px" preview.style.width = gallery.clientWidth + "px"
preview.style.height = gallery.clientHeight + "px" preview.style.height = gallery.clientHeight + "px"
//only watch gallery if there is a generation process going on
check_gallery(id_gallery);
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){ if(!progressDiv){
if (skip) { if (skip) {
skip.style.display = "none" skip.style.display = "none"
} }
interrupt.style.display = "none" interrupt.style.display = "none"
//disconnect observer once generation finished, so user can close selected image if they want
if (galleryObservers[id_gallery]) {
galleryObservers[id_gallery].disconnect();
galleries[id_gallery] = null;
}
} }
} }
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500) window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
...@@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -46,6 +59,28 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
} }
} }
// Starts observing the gallery element with id `id_gallery` for child-list changes.
// Nothing happens when the element is missing or is the same element that is already
// recorded in `galleries`. Otherwise the element is recorded, any previous observer for
// that id is disconnected, and a new observer is installed. On each change, if an image
// was selected when the observer was created, that position still exists, and no item is
// currently marked selected, the new observer clicks the item at the remembered position
// and calls showGalleryImage().
function check_gallery(id_gallery){
    let gallery = gradioApp().getElementById(id_gallery);
    // if gallery has no change, no need to setting up observer again.
    if (!gallery || galleries[id_gallery] === gallery){
        return;
    }

    galleries[id_gallery] = gallery;
    let previousObserver = galleryObservers[id_gallery];
    if (previousObserver){
        previousObserver.disconnect();
    }

    let prevSelectedIndex = selected_gallery_index();
    let itemSelector = '#'+id_gallery+' .gallery-item';
    let observer = new MutationObserver(function (){
        let galleryButtons = gradioApp().querySelectorAll(itemSelector);
        let galleryBtnSelected = gradioApp().querySelector(itemSelector + '.\\!ring-2');
        if (prevSelectedIndex === -1 || galleryButtons.length <= prevSelectedIndex || galleryBtnSelected) {
            return;
        }
        //automatically re-open previously selected index (if exists)
        galleryButtons[prevSelectedIndex].click();
        showGalleryImage();
    });
    galleryObservers[id_gallery] = observer;
    observer.observe( gallery, { childList:true, subtree:false });
}
onUiUpdate(function(){ onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
......
...@@ -187,12 +187,10 @@ onUiUpdate(function(){ ...@@ -187,12 +187,10 @@ onUiUpdate(function(){
if (!txt2img_textarea) { if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
} }
if (!img2img_textarea) { if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
} }
}) })
...@@ -220,14 +218,6 @@ function update_token_counter(button_id) { ...@@ -220,14 +218,6 @@ function update_token_counter(button_id) {
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
} }
// Key handler helper: when the event has Alt held and keyCode 13, the default action is
// prevented and the element with id `generate_button_id` is clicked. Every other event
// is ignored.
function submit_prompt(event, generate_button_id) {
    var is_alt_enter = event.altKey && event.keyCode === 13;
    if (!is_alt_enter) {
        return;
    }
    event.preventDefault();
    gradioApp().getElementById(generate_button_id).click();
}
function restart_reload(){ function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000) setTimeout(function(){location.reload()},2000)
......
...@@ -9,6 +9,7 @@ import platform ...@@ -9,6 +9,7 @@ import platform
dir_repos = "repositories" dir_repos = "repositories"
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
def extract_arg(args, name): def extract_arg(args, name):
...@@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None): ...@@ -57,7 +58,8 @@ def run_python(code, desc=None, errdesc=None):
def run_pip(args, desc=None): def run_pip(args, desc=None):
return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") index_url_line = f' --index-url {index_url}' if index_url != '' else ''
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
def check_run_python(code): def check_run_python(code):
...@@ -76,7 +78,7 @@ def git_clone(url, dir, name, commithash=None): ...@@ -76,7 +78,7 @@ def git_clone(url, dir, name, commithash=None):
return return
run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}") run(f'"{git}" -C {dir} fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commint for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}") run(f'"{git}" -C {dir} checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
return return
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}") run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
......
...@@ -102,7 +102,7 @@ def get_deepbooru_tags_model(): ...@@ -102,7 +102,7 @@ def get_deepbooru_tags_model():
tags = dd.project.load_tags_from_project(model_path) tags = dd.project.load_tags_from_project(model_path)
model = dd.project.load_model_from_project( model = dd.project.load_model_from_project(
model_path, compile_model=True model_path, compile_model=False
) )
return model, tags return model, tags
......
...@@ -34,7 +34,7 @@ def enable_tf32(): ...@@ -34,7 +34,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
device = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16 dtype_vae = torch.float16
......
...@@ -159,48 +159,52 @@ def run_pnginfo(image): ...@@ -159,48 +159,52 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name): def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) def weighted_sum(theta0, theta1, theta2, alpha):
def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
# Smoothstep (https://en.wikipedia.org/wiki/Smoothstep) def add_difference(theta0, theta1, theta2, alpha):
def sigmoid(theta0, theta1, alpha): return theta0 + (theta1 - theta2) * alpha
alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha)
# Inverse Smoothstep (https://en.wikipedia.org/wiki/Smoothstep)
def inv_sigmoid(theta0, theta1, alpha):
import math
alpha = 0.5 - math.sin(math.asin(1.0 - 2.0 * alpha) / 3.0)
return theta0 + ((theta1 - theta0) * alpha)
primary_model_info = sd_models.checkpoints_list[primary_model_name] primary_model_info = sd_models.checkpoints_list[primary_model_name]
secondary_model_info = sd_models.checkpoints_list[secondary_model_name] secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
print(f"Loading {primary_model_info.filename}...") print(f"Loading {primary_model_info.filename}...")
primary_model = torch.load(primary_model_info.filename, map_location='cpu') primary_model = torch.load(primary_model_info.filename, map_location='cpu')
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
print(f"Loading {secondary_model_info.filename}...") print(f"Loading {secondary_model_info.filename}...")
secondary_model = torch.load(secondary_model_info.filename, map_location='cpu') secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model) theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
if teritary_model_info is not None:
print(f"Loading {teritary_model_info.filename}...")
teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
else:
theta_2 = None
theta_funcs = { theta_funcs = {
"Weighted Sum": weighted_sum, "Weighted sum": weighted_sum,
"Sigmoid": sigmoid, "Add difference": add_difference,
"Inverse Sigmoid": inv_sigmoid,
} }
theta_func = theta_funcs[interp_method] theta_func = theta_funcs[interp_method]
print(f"Merging...") print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()): for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ration in the merged checkpoint t2 = (theta_2 or {}).get(key)
if t2 is None:
t2 = torch.zeros_like(theta_0[key])
theta_0[key] = theta_func(theta_0[key], theta_1[key], t2, multiplier)
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
# I believe this part should be discarded, but I'll leave it for now until I am sure
for key in theta_1.keys(): for key in theta_1.keys():
if 'model' in key and key not in theta_0: if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key] theta_0[key] = theta_1[key]
...@@ -209,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int ...@@ -209,7 +213,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt') filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(ckpt_dir, filename) output_modelname = os.path.join(ckpt_dir, filename)
...@@ -219,4 +223,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int ...@@ -219,4 +223,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
sd_models.list_models() sd_models.list_models()
print(f"Checkpoint saved.") print(f"Checkpoint saved.")
return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(3)] return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
...@@ -5,6 +5,7 @@ import os ...@@ -5,6 +5,7 @@ import os
import sys import sys
import traceback import traceback
import tqdm import tqdm
import csv
import torch import torch
...@@ -14,6 +15,7 @@ import torch ...@@ -14,6 +15,7 @@ import torch
from torch import einsum from torch import einsum
from einops import rearrange, repeat from einops import rearrange, repeat
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
...@@ -180,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): ...@@ -180,7 +182,21 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
return self.to_out(out) return self.to_out(out)
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt): def stack_conds(conds):
if len(conds) == 1:
return torch.stack(conds)
# same as in reconstruct_multicond_batch
token_count = max([x.shape[0] for x in conds])
for i in range(len(conds)):
if conds[i].shape[0] != token_count:
last_vector = conds[i][-1:]
last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
conds[i] = torch.vstack([conds[i], last_vector_repeated])
return torch.stack(conds)
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected' assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
...@@ -209,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -209,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
...@@ -233,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -233,7 +249,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entry in pbar: for i, entries in pbar:
hypernetwork.step = i + ititial_step hypernetwork.step = i + ititial_step
scheduler.apply(optimizer, hypernetwork.step) scheduler.apply(optimizer, hypernetwork.step)
...@@ -244,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -244,11 +260,12 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
cond = entry.cond.to(devices.device) c = stack_conds([entry.cond for entry in entries]).to(devices.device)
x = entry.latent.to(devices.device) # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0] x = torch.stack([entry.latent for entry in entries]).to(devices.device)
loss = shared.sd_model(x, c)[0]
del x del x
del cond del c
losses[hypernetwork.step % losses.shape[0]] = loss.item() losses[hypernetwork.step % losses.shape[0]] = loss.item()
...@@ -262,23 +279,39 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -262,23 +279,39 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt') last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file) hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0: if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png') last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad() optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
prompt=preview_text,
steps=20,
do_not_save_grid=True, do_not_save_grid=True,
do_not_save_samples=True, do_not_save_samples=True,
) )
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] if len(processed.images)>0 else None image = processed.images[0] if len(processed.images)>0 else None
...@@ -297,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, ...@@ -297,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/> Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entry.cond_text)}<br/> Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
......
import datetime import datetime
import io
import math import math
import os import os
from collections import namedtuple from collections import namedtuple
...@@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None): ...@@ -23,6 +24,10 @@ def image_grid(imgs, batch_size=1, rows=None):
rows = opts.n_rows rows = opts.n_rows
elif opts.n_rows == 0: elif opts.n_rows == 0:
rows = batch_size rows = batch_size
elif opts.grid_prevent_empty_spots:
rows = math.floor(math.sqrt(len(imgs)))
while len(imgs) % rows != 0:
rows -= 1
else: else:
rows = math.sqrt(len(imgs)) rows = math.sqrt(len(imgs))
rows = round(rows) rows = round(rows)
...@@ -463,3 +468,22 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -463,3 +468,22 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn = None txt_fullfn = None
return fullfn, txt_fullfn return fullfn, txt_fullfn
def image_data(data):
    """Extract generation text from a raw payload.

    Two interpretations of ``data`` are tried in order:

    1. As an image: it is opened from memory and the ``"parameters"`` entry of its
       ``text`` mapping is returned.
    2. As UTF-8 text: the decoded string is returned if it is shorter than 10000
       characters.

    Returns a ``(text, None)`` tuple; ``text`` is ``''`` when neither interpretation
    applies. This is a best-effort helper and never raises for an ordinary ``Exception``.
    """
    try:
        image = Image.open(io.BytesIO(data))
        textinfo = image.text["parameters"]
        return textinfo, None
    except Exception:
        # Not an image, or an image without a "parameters" entry: try plain text next.
        pass

    try:
        text = data.decode('utf8')
    except Exception:
        return '', None

    # Explicit length check instead of `assert`, which is stripped under `python -O`.
    if len(text) < 10000:
        return text, None

    return '', None
import os
import shutil
def traverse_all_files(output_dir, image_list, curr_dir=None):
    """Recursively collect image-like files below ``output_dir``.

    Args:
        output_dir: Root directory of the walk.
        image_list: List that found paths are appended to (mutated in place).
        curr_dir: Path relative to ``output_dir`` that is currently being
            visited, or ``None`` for the root itself.

    Returns:
        ``image_list``, containing paths relative to ``output_dir``. A path is
        kept when it does not end in ``.txt`` and has a ``.`` somewhere after
        the first of its last ten characters (i.e. it looks like it has an
        extension).
    """
    curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
    try:
        f_list = os.listdir(curr_path)
    except OSError:
        # curr_path is not a listable directory. For a nested entry this means it
        # is a plain file (or unreadable); keep it if it looks like an image.
        # For the root (curr_dir is None) there is simply nothing to collect.
        if curr_dir is not None and curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
            image_list.append(curr_dir)
        return image_list

    for entry in f_list:
        rel_path = entry if curr_dir is None else os.path.join(curr_dir, entry)
        # rel_path already contains curr_dir, so it must be joined to the root,
        # not to curr_path (which would repeat curr_dir in the result).
        file_path = os.path.join(output_dir, rel_path)
        if rel_path[-4:] == ".txt":
            continue
        if os.path.isfile(file_path) and rel_path[-10:].rfind(".") > 0:
            image_list.append(rel_path)
        else:
            image_list = traverse_all_files(output_dir, image_list, rel_path)
    return image_list
def get_recent_images(dir_name, page_index, step, image_index, tabname):
    """Return one page of images from ``dir_name``, newest first.

    Args:
        dir_name: Directory that is scanned (recursively) for images.
        page_index: Requested page, 1-based; ``-1`` selects the last page.
        step: Offset added to ``page_index`` (ignored when ``page_index`` is -1).
        image_index: Index of the selected image within the returned page.
        tabname: Tab identifier; ``"extras"`` uses 12 images per page, every
            other value uses 48.

    Returns:
        A 6-tuple ``(full_paths, page_index, relative_paths, current_file,
        hidden, "")`` where ``current_file``/``hidden`` are the relative and
        joined path of the selected image, or ``None`` when ``image_index`` is
        outside the page.
    """
    page_index = int(page_index)
    # Kept only for its side effect: raises OSError early if dir_name cannot be listed.
    os.listdir(dir_name)

    image_list = traverse_all_files(dir_name, [])
    image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))

    num = 48 if tabname != "extras" else 12
    # Ceiling division with a minimum of one page; `len // num + 1` yields a
    # trailing empty page whenever the image count is an exact multiple of num.
    max_page_index = max(1, -(-len(image_list) // num))

    page_index = max_page_index if page_index == -1 else page_index + step
    page_index = min(max(page_index, 1), max_page_index)

    idx_frm = (page_index - 1) * num
    image_list = image_list[idx_frm:idx_frm + num]

    image_index = int(image_index)
    if 0 <= image_index < len(image_list):
        current_file = image_list[image_index]
        hidden = os.path.join(dir_name, current_file)
    else:
        current_file = None
        hidden = None

    return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
def first_page_click(dir_name, page_index, image_index, tabname):
    """Load page 1 of the history; the incoming ``page_index`` is ignored."""
    return get_recent_images(dir_name, page_index=1, step=0, image_index=image_index, tabname=tabname)
def end_page_click(dir_name, page_index, image_index, tabname):
    """Load the last page of the history by passing ``-1`` as the page; the incoming ``page_index`` is ignored."""
    return get_recent_images(dir_name, page_index=-1, step=0, image_index=image_index, tabname=tabname)
def prev_page_click(dir_name, page_index, image_index, tabname):
    """Load the page before ``page_index`` (step of -1)."""
    return get_recent_images(dir_name, page_index=page_index, step=-1, image_index=image_index, tabname=tabname)
def next_page_click(dir_name, page_index, image_index, tabname):
    """Load the page after ``page_index`` (step of +1)."""
    return get_recent_images(dir_name, page_index=page_index, step=1, image_index=image_index, tabname=tabname)
def page_index_change(dir_name, page_index, image_index, tabname):
    """Reload the page given by ``page_index`` without moving (step of 0)."""
    return get_recent_images(dir_name, page_index=page_index, step=0, image_index=image_index, tabname=tabname)
def show_image_info(num, image_path, filenames):
    """Look up the file selected in the gallery.

    Args:
        num: Index of the selected image; converted with ``int()`` for the lookup.
        image_path: Directory that the entries of ``filenames`` are relative to.
        filenames: Sequence of file names for the current page.

    Returns:
        ``(file_name, num, joined_path)``; ``num`` is passed back unchanged.
    """
    selected = filenames[int(num)]
    full_path = os.path.join(image_path, selected)
    return selected, num, full_path
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
    """Delete ``delete_num`` consecutive images starting at ``name``.

    For every deleted image the sibling ``.txt`` file (same base name) is
    removed as well when it exists.

    Args:
        delete_num: How many consecutive entries to delete; converted with ``int()``.
        tabname: Unused here.
        dir_name: Directory that the entries of ``filenames`` are relative to.
        name: File name of the first image to delete; ``""`` means nothing is selected.
        page_index: Unused here.
        filenames: File names of the current page, in display order.
        image_index: Unused here.

    Returns:
        ``(remaining_filenames, 1)`` after a deletion, or
        ``(filenames, delete_num)`` unchanged when there is nothing to delete.
    """
    if name == "":
        return filenames, delete_num

    names = list(filenames)
    if name not in names:
        # Stale selection (e.g. the list was refreshed): treat like "nothing selected"
        # instead of failing with ValueError.
        return filenames, delete_num

    delete_num = int(delete_num)
    first = names.index(name)
    new_file_list = []
    for position, file_name in enumerate(names):
        if not (first <= position < first + delete_num):
            new_file_list.append(file_name)
            continue

        path = os.path.join(dir_name, file_name)
        if os.path.exists(path):
            print(f"Delete file {path}")
            os.remove(path)
            txt_file = os.path.splitext(path)[0] + ".txt"
            if os.path.exists(txt_file):
                os.remove(txt_file)
        else:
            print(f"Not exists file {path}")

    return new_file_list, 1
# Build the image-history UI for one tab and wire up its event handlers.
#   gr           - UI toolkit object providing Row/Column/Button/... (presumably the gradio module -- confirm)
#   opts         - options object; the outdir_* attributes are read to pick the image directory
#   tabname      - "txt2img", "img2img" or "extras"; also used as prefix for element ids
#   run_pnginfo  - callback run when the hidden image changes; its three outputs go to info1, img_file_info, info2
#   switch_dict  - dict with keys "fn", "t2i" and "i2i" used to wire the "Send to ..." buttons
# Returns nothing.
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
# A non-empty opts.outdir_samples overrides the per-tab output directories.
# NOTE(review): dir_name stays unbound if outdir_samples is "" and tabname is not one of the three known tabs.
if opts.outdir_samples != "":
dir_name = opts.outdir_samples
elif tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
# Split the path on "/" and rebuild it with os.path.join
# (presumably to get the host OS path separator -- confirm).
d = dir_name.split("/")
dir_name = "/" if dir_name.startswith("/") else d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
# Page navigation controls.
with gr.Row():
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
next_page = gr.Button('Next Page')
end_page = gr.Button('End Page')
# Gallery with delete controls, plus the info panel for the selected image.
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Row():
with gr.Column(scale=2):
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row():
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
with gr.Column():
with gr.Row():
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
pnginfo_send_to_img2img = gr.Button('Send to img2img')
with gr.Row():
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
img_file_name = gr.Textbox(label="File Name", interactive=False)
with gr.Row():
# hidden items: invisible components that carry state between the handlers below
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
tabname_box = gr.Textbox(tabname, visible=False)
image_index = gr.Textbox(value=-1, visible=False)
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
filenames = gr.State()
hidden = gr.Image(type="pil", visible=False)
info1 = gr.Textbox(visible=False)
info2 = gr.Textbox(visible=False)
# turn pages
# gallery_inputs follows the parameter order of the *_page_click handlers.
# gallery_outputs has six entries, matching the six values returned by get_recent_images;
# img_file_name is listed twice, so it receives both current_file and the trailing ""
# (presumably the later value wins -- confirm).
gallery_inputs = [img_path, page_index, image_index, tabname_box]
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
# other functions
# NOTE(review): show_image_info takes (num, image_path, filenames) but the first input here is
# tabname_box; presumably the "images_history_get_current_img" JS hook substitutes the index -- confirm.
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
# Inputs follow the parameter order of delete_image; its two return values update filenames and delete_num.
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
# Whenever the hidden image changes, run_pnginfo is called on it; img_file_info receives its second output.
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
# Wire the "Send to ..." buttons through the caller-supplied switch_dict["fn"] helper.
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
    """Create the container holding one history tab each for txt2img, img2img and extras.

    Every tab wraps its own ``gr.Blocks`` and is filled in by
    ``show_images_history``. Returns the outer ``gr.Blocks`` object.
    """
    tab_specs = (
        ("txt2img", "txt2img history"),
        ("img2img", "img2img history"),
        ("extras", "extras history"),
    )
    with gr.Blocks(analytics_enabled=False) as images_history:
        with gr.Tabs():
            for tab_key, tab_label in tab_specs:
                with gr.Tab(tab_label):
                    with gr.Blocks(analytics_enabled=False):
                        show_images_history(gr, opts, tab_key, run_pnginfo, switch_dict)
    return images_history
...@@ -55,7 +55,7 @@ class InterrogateModels: ...@@ -55,7 +55,7 @@ class InterrogateModels:
model, preprocess = clip.load(clip_model_name) model, preprocess = clip.load(clip_model_name)
model.eval() model.eval()
model = model.to(shared.device) model = model.to(devices.device_interrogate)
return model, preprocess return model, preprocess
...@@ -65,14 +65,14 @@ class InterrogateModels: ...@@ -65,14 +65,14 @@ class InterrogateModels:
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
self.blip_model = self.blip_model.half() self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(shared.device) self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None: if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model() self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
self.clip_model = self.clip_model.half() self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(shared.device) self.clip_model = self.clip_model.to(devices.device_interrogate)
self.dtype = next(self.clip_model.parameters()).dtype self.dtype = next(self.clip_model.parameters()).dtype
...@@ -99,11 +99,11 @@ class InterrogateModels: ...@@ -99,11 +99,11 @@ class InterrogateModels:
text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)] text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]
top_count = min(top_count, len(text_array)) top_count = min(top_count, len(text_array))
text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(shared.device) text_tokens = clip.tokenize([text for text in text_array], truncate=True).to(devices.device_interrogate)
text_features = self.clip_model.encode_text(text_tokens).type(self.dtype) text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
text_features /= text_features.norm(dim=-1, keepdim=True) text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = torch.zeros((1, len(text_array))).to(shared.device) similarity = torch.zeros((1, len(text_array))).to(devices.device_interrogate)
for i in range(image_features.shape[0]): for i in range(image_features.shape[0]):
similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1) similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
similarity /= image_features.shape[0] similarity /= image_features.shape[0]
...@@ -116,7 +116,7 @@ class InterrogateModels: ...@@ -116,7 +116,7 @@ class InterrogateModels:
transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC), transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) ])(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
with torch.no_grad(): with torch.no_grad():
caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length) caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)
...@@ -140,7 +140,7 @@ class InterrogateModels: ...@@ -140,7 +140,7 @@ class InterrogateModels:
res = caption res = caption
clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device) clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)
precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), precision_scope("cuda"):
......
...@@ -140,7 +140,7 @@ class Processed: ...@@ -140,7 +140,7 @@ class Processed:
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
...@@ -501,16 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -501,16 +501,15 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
firstphase_width = 0
firstphase_height = 0
firstphase_width_truncated = 0
firstphase_height_truncated = 0
def __init__(self, enable_hr=False, scale_latent=True, denoising_strength=0.75, **kwargs): def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.scale_latent = scale_latent
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength
self.firstphase_width = firstphase_width
self.firstphase_height = firstphase_height
self.truncate_x = 0
self.truncate_y = 0
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr: if self.enable_hr:
...@@ -519,14 +518,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -519,14 +518,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else: else:
state.job_count = state.job_count * 2 state.job_count = state.job_count * 2
if self.firstphase_width == 0 or self.firstphase_height == 0:
desired_pixel_count = 512 * 512 desired_pixel_count = 512 * 512
actual_pixel_count = self.width * self.height actual_pixel_count = self.width * self.height
scale = math.sqrt(desired_pixel_count / actual_pixel_count) scale = math.sqrt(desired_pixel_count / actual_pixel_count)
self.firstphase_width = math.ceil(scale * self.width / 64) * 64 self.firstphase_width = math.ceil(scale * self.width / 64) * 64
self.firstphase_height = math.ceil(scale * self.height / 64) * 64 self.firstphase_height = math.ceil(scale * self.height / 64) * 64
self.firstphase_width_truncated = int(scale * self.width) firstphase_width_truncated = int(scale * self.width)
self.firstphase_height_truncated = int(scale * self.height) firstphase_height_truncated = int(scale * self.height)
else:
self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}"
width_ratio = self.width / self.firstphase_width
height_ratio = self.height / self.firstphase_height
if width_ratio > height_ratio:
firstphase_width_truncated = self.firstphase_width
firstphase_height_truncated = self.firstphase_width * self.height / self.width
else:
firstphase_width_truncated = self.firstphase_height * self.width / self.height
firstphase_height_truncated = self.firstphase_height
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
...@@ -539,14 +555,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -539,14 +555,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
truncate_x = (self.firstphase_width - self.firstphase_width_truncated) // opt_f samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
truncate_y = (self.firstphase_height - self.firstphase_height_truncated) // opt_f
samples = samples[:, :, truncate_y//2:samples.shape[2]-truncate_y//2, truncate_x//2:samples.shape[3]-truncate_x//2]
if self.scale_latent:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
else:
decoded_samples = decode_first_stage(self.sd_model, samples) decoded_samples = decode_first_stage(self.sd_model, samples)
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None": if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None":
......
...@@ -96,11 +96,18 @@ def load(filename, *args, **kwargs): ...@@ -96,11 +96,18 @@ def load(filename, *args, **kwargs):
if not shared.cmd_opts.disable_safe_unpickle: if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename) check_pt(filename)
except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr)
return None
except Exception: except Exception:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr) print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr)
return None return None
return unsafe_torch_load(filename, *args, **kwargs) return unsafe_torch_load(filename, *args, **kwargs)
......
...@@ -136,7 +136,7 @@ def load_model_weights(model, checkpoint_info): ...@@ -136,7 +136,7 @@ def load_model_weights(model, checkpoint_info):
if checkpoint_info not in checkpoints_loaded: if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
pl_sd = torch.load(checkpoint_file, map_location="cpu") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
if "global_step" in pl_sd: if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
...@@ -159,9 +159,8 @@ def load_model_weights(model, checkpoint_info): ...@@ -159,9 +159,8 @@ def load_model_weights(model, checkpoint_info):
if os.path.exists(vae_file): if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}") print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location="cpu") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
......
...@@ -34,6 +34,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_ ...@@ -34,6 +34,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
...@@ -54,7 +55,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en ...@@ -54,7 +55,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--use-cpu", nargs='+',choices=['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer'], help="use CPU as torch device for specified modules", default=[]) parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
...@@ -76,10 +77,11 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl ...@@ -76,10 +77,11 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
devices.device, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_bsrgan, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if x in cmd_opts.use_cpu else devices.get_optimal_device() for x in ['SD', 'GFPGAN', 'BSRGAN', 'ESRGAN', 'SCUNet', 'CodeFormer']) (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'bsrgan', 'esrgan', 'scunet', 'codeformer'])
device = devices.device device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu"
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
...@@ -175,6 +177,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" ...@@ -175,6 +177,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
"grid_format": OptionInfo('png', 'File format for grids'), "grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
...@@ -233,7 +236,8 @@ options_templates.update(options_section(('training', "Training"), { ...@@ -233,7 +236,8 @@ options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"), "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
......
...@@ -24,11 +24,12 @@ class DatasetEntry: ...@@ -24,11 +24,12 @@ class DatasetEntry:
class PersonalizedBase(Dataset): class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token self.placeholder_token = placeholder_token
self.batch_size = batch_size
self.width = width self.width = width
self.height = height self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.flip = transforms.RandomHorizontalFlip(p=flip_p)
...@@ -78,13 +79,14 @@ class PersonalizedBase(Dataset): ...@@ -78,13 +79,14 @@ class PersonalizedBase(Dataset):
if include_cond: if include_cond:
entry.cond_text = self.create_text(filename_text) entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu) entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry) self.dataset.append(entry)
self.length = len(self.dataset) * repeats assert len(self.dataset) > 1, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(self.length) % len(self.dataset) self.initial_indexes = np.arange(len(self.dataset))
self.indexes = None self.indexes = None
self.shuffle() self.shuffle()
...@@ -101,13 +103,19 @@ class PersonalizedBase(Dataset): ...@@ -101,13 +103,19 @@ class PersonalizedBase(Dataset):
return self.length return self.length
def __getitem__(self, i): def __getitem__(self, i):
if i % len(self.dataset) == 0: res = []
for j in range(self.batch_size):
position = i * self.batch_size + j
if position % len(self.indexes) == 0:
self.shuffle() self.shuffle()
index = self.indexes[i % len(self.indexes)] index = self.indexes[position % len(self.indexes)]
entry = self.dataset[index] entry = self.dataset[index]
if entry.cond is None: if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text) entry.cond_text = self.create_text(entry.filename_text)
return entry res.append(entry)
return res
...@@ -6,6 +6,7 @@ import torch ...@@ -6,6 +6,7 @@ import torch
import tqdm import tqdm
import html import html
import datetime import datetime
import csv
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
...@@ -172,7 +173,33 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): ...@@ -172,7 +173,33 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt): def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if step % shared.opts.training_write_csv_every != 0:
return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
if write_csv_header:
csv_writer.writeheader()
epoch = step // epoch_len
epoch_step = step - epoch * epoch_len
csv_writer.writerow({
"step": step + 1,
"epoch": epoch + 1,
"epoch_step": epoch_step + 1,
**values,
})
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected' assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..." shared.state.textinfo = "Initializing textual inversion training..."
...@@ -204,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -204,7 +231,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack hijack = sd_hijack.model_hijack
...@@ -224,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -224,7 +251,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step) pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entry in pbar: for i, entries in pbar:
embedding.step = i + ititial_step embedding.step = i + ititial_step
scheduler.apply(optimizer, embedding.step) scheduler.apply(optimizer, embedding.step)
...@@ -235,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -235,10 +262,9 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
break break
with torch.autocast("cuda"): with torch.autocast("cuda"):
c = cond_model([entry.cond_text]) c = cond_model([entry.cond_text for entry in entries])
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
x = entry.latent.to(devices.device) loss = shared.sd_model(x, c)[0]
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x del x
losses[embedding.step % losses.shape[0]] = loss.item() losses[embedding.step % losses.shape[0]] = loss.item()
...@@ -256,21 +282,37 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -256,21 +282,37 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt') last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file) embedding.save(last_saved_file)
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
"loss": f"{losses.mean():.7f}",
"learn_rate": scheduler.learn_rate
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0: if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png') last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
prompt=preview_text,
steps=20,
height=training_height,
width=training_width,
do_not_save_grid=True, do_not_save_grid=True,
do_not_save_samples=True, do_not_save_samples=True,
) )
if preview_from_txt2img:
p.prompt = preview_prompt
p.negative_prompt = preview_negative_prompt
p.steps = preview_steps
p.sampler_index = preview_sampler_index
p.cfg_scale = preview_cfg_scale
p.seed = preview_seed
p.width = preview_width
p.height = preview_height
else:
p.prompt = entries[0].cond_text
p.steps = 20
p.width = training_width
p.height = training_height
preview_text = p.prompt
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0]
...@@ -305,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini ...@@ -305,7 +347,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p> <p>
Loss: {losses.mean():.7f}<br/> Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/> Step: {embedding.step}<br/>
Last prompt: {html.escape(entry.cond_text)}<br/> Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/> Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/> Last saved image: {html.escape(last_saved_image)}<br/>
</p> </p>
......
...@@ -6,7 +6,7 @@ import modules.processing as processing ...@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, scale_latent: bool, denoising_strength: float, *args): def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -30,8 +30,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -30,8 +30,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
restore_faces=restore_faces, restore_faces=restore_faces,
tiling=tiling, tiling=tiling,
enable_hr=enable_hr, enable_hr=enable_hr,
scale_latent=scale_latent if enable_hr else None,
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
firstphase_width=firstphase_width if enable_hr else None,
firstphase_height=firstphase_height if enable_hr else None,
) )
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
......
This diff is collapsed.
...@@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -50,9 +50,9 @@ document.addEventListener("DOMContentLoaded", function() {
document.addEventListener('keydown', function(e) { document.addEventListener('keydown', function(e) {
var handled = false; var handled = false;
if (e.key !== undefined) { if (e.key !== undefined) {
if((e.key == "Enter" && (e.metaKey || e.ctrlKey))) handled = true; if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} else if (e.keyCode !== undefined) { } else if (e.keyCode !== undefined) {
if((e.keyCode == 13 && (e.metaKey || e.ctrlKey))) handled = true; if((e.keyCode == 13 && (e.metaKey || e.ctrlKey || e.altKey))) handled = true;
} }
if (handled) { if (handled) {
button = get_uiCurrentTabContent().querySelector('button[id$=_generate]'); button = get_uiCurrentTabContent().querySelector('button[id$=_generate]');
......
...@@ -12,7 +12,7 @@ import gradio as gr ...@@ -12,7 +12,7 @@ import gradio as gr
from modules import images from modules import images
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
import modules.sd_samplers import modules.sd_samplers
...@@ -176,7 +176,7 @@ axis_options = [ ...@@ -176,7 +176,7 @@ axis_options = [
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
] ]
...@@ -338,7 +338,7 @@ class Script(scripts.Script): ...@@ -338,7 +338,7 @@ class Script(scripts.Script):
ys = process_axis(y_opt, y_values) ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list): def fix_axis_seeds(axis_opt, axis_list):
if axis_opt.label == 'Seed': if axis_opt.label in ['Seed','Var. seed']:
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
else: else:
return axis_list return axis_list
...@@ -354,6 +354,9 @@ class Script(scripts.Script): ...@@ -354,6 +354,9 @@ class Script(scripts.Script):
else: else:
total_steps = p.steps * len(xs) * len(ys) total_steps = p.steps * len(xs) * len(ys)
if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
total_steps *= 2
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
shared.total_tqdm.updateTotal(total_steps * p.n_iter) shared.total_tqdm.updateTotal(total_steps * p.n_iter)
......
...@@ -167,14 +167,6 @@ button{ ...@@ -167,14 +167,6 @@ button{
align-self: stretch !important; align-self: stretch !important;
} }
#prompt, #negative_prompt{
border: none !important;
}
#prompt textarea, #negative_prompt textarea{
border: none !important;
}
#img2maskimg .h-60{ #img2maskimg .h-60{
height: 30rem; height: 30rem;
} }
......
...@@ -82,8 +82,8 @@ then ...@@ -82,8 +82,8 @@ then
clone_dir="${PWD##*/}" clone_dir="${PWD##*/}"
fi fi
# Check prequisites # Check prerequisites
for preq in git python3 for preq in "${GIT}" "${python_cmd}"
do do
if ! hash "${preq}" &>/dev/null if ! hash "${preq}" &>/dev/null
then then
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment