Commit 1fa53dab authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into cuda-device-id-selection

parents 29bfacd6 5aa95250
...@@ -28,3 +28,4 @@ notification.mp3 ...@@ -28,3 +28,4 @@ notification.mp3
/SwinIR /SwinIR
/textual_inversion /textual_inversion
.vscode .vscode
/extensions
...@@ -83,8 +83,17 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web ...@@ -83,8 +83,17 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- Estimated completion time in progress bar - Estimated completion time in progress bar
- API - API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- Aesthetic Gradients, a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
## Where are Aesthetic Gradients?!?!
Aesthetic Gradients are now an extension. You can install it using git:
```commandline
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
```
After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
......
...@@ -17,14 +17,6 @@ var images_history_click_image = function(){ ...@@ -17,14 +17,6 @@ var images_history_click_image = function(){
images_history_set_image_info(this); images_history_set_image_info(this);
} }
var images_history_click_tab = function(){
var tabs_box = gradioApp().getElementById("images_history_tab");
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
tabs_box.classList.add(this.getAttribute("tabname"))
}
}
function images_history_disabled_del(){ function images_history_disabled_del(){
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){ gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
btn.setAttribute('disabled','disabled'); btn.setAttribute('disabled','disabled');
...@@ -43,7 +35,6 @@ function images_history_get_parent_by_tagname(item, tagname){ ...@@ -43,7 +35,6 @@ function images_history_get_parent_by_tagname(item, tagname){
var parent = item.parentElement; var parent = item.parentElement;
tagname = tagname.toUpperCase() tagname = tagname.toUpperCase()
while(parent.tagName != tagname){ while(parent.tagName != tagname){
console.log(parent.tagName, tagname)
parent = parent.parentElement; parent = parent.parentElement;
} }
return parent; return parent;
...@@ -88,15 +79,15 @@ function images_history_set_image_info(button){ ...@@ -88,15 +79,15 @@ function images_history_set_image_info(button){
} }
function images_history_get_current_img(tabname, image_path, files){ function images_history_get_current_img(tabname, img_index, files){
return [ return [
tabname,
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"), gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
image_path,
files files
]; ];
} }
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){ function images_history_delete(del_num, tabname, image_index){
image_index = parseInt(image_index); image_index = parseInt(image_index);
var tab = gradioApp().getElementById(tabname + '_images_history'); var tab = gradioApp().getElementById(tabname + '_images_history');
var set_btn = tab.querySelector(".images_history_set_index"); var set_btn = tab.querySelector(".images_history_set_index");
...@@ -107,6 +98,7 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i ...@@ -107,6 +98,7 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
} }
}); });
var img_num = buttons.length / 2; var img_num = buttons.length / 2;
del_num = Math.min(img_num - image_index, del_num)
if (img_num <= del_num){ if (img_num <= del_num){
setTimeout(function(tabname){ setTimeout(function(tabname){
gradioApp().getElementById(tabname + '_images_history_renew_page').click(); gradioApp().getElementById(tabname + '_images_history_renew_page').click();
...@@ -114,30 +106,28 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i ...@@ -114,30 +106,28 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
} else { } else {
var next_img var next_img
for (var i = 0; i < del_num; i++){ for (var i = 0; i < del_num; i++){
if (image_index + i < image_index + img_num){
buttons[image_index + i].style.display = 'none'; buttons[image_index + i].style.display = 'none';
buttons[image_index + img_num + 1].style.display = 'none'; buttons[image_index + i + img_num].style.display = 'none';
next_img = image_index + i + 1 next_img = image_index + i + 1
} }
}
var bnt; var bnt;
if (next_img >= img_num){ if (next_img >= img_num){
btn = buttons[image_index - del_num]; btn = buttons[image_index - 1];
} else { } else {
btn = buttons[next_img]; btn = buttons[next_img];
} }
setTimeout(function(btn){btn.click()}, 30, btn); setTimeout(function(btn){btn.click()}, 30, btn);
} }
images_history_disabled_del(); images_history_disabled_del();
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
} }
function images_history_turnpage(img_path, page_index, image_index, tabname){ function images_history_turnpage(tabname){
gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled');
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item"); var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
buttons.forEach(function(elem) { buttons.forEach(function(elem) {
elem.style.display = 'block'; elem.style.display = 'block';
}) })
return [img_path, page_index, image_index, tabname];
} }
function images_history_enable_del_buttons(){ function images_history_enable_del_buttons(){
...@@ -147,40 +137,46 @@ function images_history_enable_del_buttons(){ ...@@ -147,40 +137,46 @@ function images_history_enable_del_buttons(){
} }
function images_history_init(){ function images_history_init(){
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page') var tabnames = gradioApp().getElementById("images_history_tabnames_list")
if (load_txt2img_button){ if (tabnames){
images_history_tab_list = tabnames.querySelector("textarea").value.split(",")
for (var i in images_history_tab_list ){ for (var i in images_history_tab_list ){
tab = images_history_tab_list[i]; var tab = images_history_tab_list[i];
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor"); gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index"); gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button"); gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery"); gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px");
} }
//preload
if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div"); var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
tabs_box.setAttribute("id", "images_history_tab"); tabs_box.setAttribute("id", "images_history_tab");
var tab_btns = tabs_box.querySelectorAll("button"); var tab_btns = tabs_box.querySelectorAll("button");
for (var i in images_history_tab_list){ for (var i in images_history_tab_list){
var tabname = images_history_tab_list[i] var tabname = images_history_tab_list[i]
tab_btns[i].setAttribute("tabname", tabname); tab_btns[i].setAttribute("tabname", tabname);
tab_btns[i].addEventListener('click', function(){
// this refreshes history upon tab switch var tabs_box = gradioApp().getElementById("images_history_tab");
// until the history is known to work well, which is not the case now, we do not do this at startup if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
//tab_btns[i].addEventListener('click', images_history_click_tab); gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click();
tabs_box.classList.add(this.getAttribute("tabname"))
}
});
}
tab_btns[0].click()
} }
tabs_box.classList.add(images_history_tab_list[0]);
// same as above, at page load
//load_txt2img_button.click();
} else { } else {
setTimeout(images_history_init, 500); setTimeout(images_history_init, 500);
} }
} }
var images_history_tab_list = ["txt2img", "img2img", "extras"]; var images_history_tab_list = "";
setTimeout(images_history_init, 500); setTimeout(images_history_init, 500);
document.addEventListener("DOMContentLoaded", function() { document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){ var mutationObserver = new MutationObserver(function(m){
if (images_history_tab_list != ""){
for (var i in images_history_tab_list ){ for (var i in images_history_tab_list ){
let tabname = images_history_tab_list[i] let tabname = images_history_tab_list[i]
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item'); var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
...@@ -188,19 +184,17 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -188,19 +184,17 @@ document.addEventListener("DOMContentLoaded", function() {
bnt.addEventListener('click', images_history_click_image, true); bnt.addEventListener('click', images_history_click_image, true);
}); });
// same as load_txt2img_button.click() above
/*
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg"); var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
if (cls_btn){ if (cls_btn){
cls_btn.addEventListener('click', function(){ cls_btn.addEventListener('click', function(){
gradioApp().getElementById(tabname + '_images_history_renew_page').click(); gradioApp().getElementById(tabname + '_images_history_renew_page').click();
}, false); }, false);
}*/ }
} }
}
}); });
mutationObserver.observe( gradioApp(), { childList:true, subtree:true }); mutationObserver.observe(gradioApp(), { childList:true, subtree:true });
}); });
import copy
import itertools
import os
from pathlib import Path
import html
import gc
import gradio as gr
import torch
from PIL import Image
from torch import optim
from modules import shared
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
from tqdm.auto import tqdm, trange
from modules.shared import opts, device
def get_all_images_in_folder(folder):
return [os.path.join(folder, f) for f in os.listdir(folder) if
os.path.isfile(os.path.join(folder, f)) and check_is_valid_image_file(f)]
def check_is_valid_image_file(filename):
return filename.lower().endswith(('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp"))
def batched(dataset, total, n=1):
for ndx in range(0, total, n):
yield [dataset.__getitem__(i) for i in range(ndx, min(ndx + n, total))]
def iter_to_batched(iterable, n=1):
it = iter(iterable)
while True:
chunk = tuple(itertools.islice(it, n))
if not chunk:
return
yield chunk
def create_ui():
import modules.ui
with gr.Group():
with gr.Accordion("Open for Clip Aesthetic!", open=False):
with gr.Row():
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
value=0.9)
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
with gr.Row():
aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
placeholder="Aesthetic learning rate", value="0.0001")
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
label="Aesthetic imgs embedding",
value="None")
modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
placeholder="This text is used to rotate the feature space of the imgs embs",
value="")
aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
value=0.1)
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
aesthetic_clip_model = None
def aesthetic_clip():
global aesthetic_clip_model
if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
aesthetic_clip_model.cpu()
return aesthetic_clip_model
def generate_imgs_embd(name, folder, batch_size):
model = aesthetic_clip().to(device)
processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad():
embs = []
for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
desc=f"Generating embeddings for {name}"):
if shared.state.interrupted:
break
inputs = processor(images=[Image.open(path) for path in paths], return_tensors="pt").to(device)
outputs = model.get_image_features(**inputs).cpu()
embs.append(torch.clone(outputs))
inputs.to("cpu")
del inputs, outputs
embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
# The generated embedding will be located here
path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
torch.save(embs, path)
model.cpu()
del processor
del embs
gc.collect()
torch.cuda.empty_cache()
res = f"""
Done generating embedding for {name}!
Aesthetic embedding saved to {html.escape(path)}
"""
shared.update_aesthetic_embeddings()
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
value="None"), \
gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
label="Imgs embedding",
value="None"), res, ""
def slerp(low, high, val):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
class AestheticCLIP:
def __init__(self):
self.skip = False
self.aesthetic_steps = 0
self.aesthetic_weight = 0
self.aesthetic_lr = 0
self.slerp = False
self.aesthetic_text_negative = ""
self.aesthetic_slerp_angle = 0
self.aesthetic_imgs_text = ""
self.image_embs_name = None
self.image_embs = None
self.load_image_embs(None)
def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False):
self.aesthetic_imgs_text = aesthetic_imgs_text
self.aesthetic_slerp_angle = aesthetic_slerp_angle
self.aesthetic_text_negative = aesthetic_text_negative
self.slerp = aesthetic_slerp
self.aesthetic_lr = aesthetic_lr
self.aesthetic_weight = aesthetic_weight
self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name)
if self.image_embs_name is not None:
p.extra_generation_params.update({
"Aesthetic LR": aesthetic_lr,
"Aesthetic weight": aesthetic_weight,
"Aesthetic steps": aesthetic_steps,
"Aesthetic embedding": self.image_embs_name,
"Aesthetic slerp": aesthetic_slerp,
"Aesthetic text": aesthetic_imgs_text,
"Aesthetic text negative": aesthetic_text_negative,
"Aesthetic slerp angle": aesthetic_slerp_angle,
})
def set_skip(self, skip):
self.skip = skip
def load_image_embs(self, image_embs_name):
if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
image_embs_name = None
self.image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
self.image_embs.requires_grad_(False)
def __call__(self, z, remade_batch_tokens):
if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
tokenizer = shared.sd_model.cond_stage_model.tokenizer
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [
[tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens).to(device)
model = copy.deepcopy(aesthetic_clip()).to(device)
model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features(
**tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
if self.aesthetic_text_negative:
text_embs_2 = self.image_embs - text_embs_2
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
else:
img_embs = self.image_embs
with torch.enable_grad():
# We optimize the model to maximize the similarity
optimizer = optim.Adam(
model.text_model.parameters(), lr=self.aesthetic_lr
)
for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
text_embs = model.get_text_features(input_ids=tokens)
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
sim = text_embs @ img_embs.T
loss = -sim
optimizer.zero_grad()
loss.mean().backward()
optimizer.step()
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
zn = model.text_model.final_layer_norm(zn)
else:
zn = zn.last_hidden_state
model.cpu()
del model
gc.collect()
torch.cuda.empty_cache()
zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
if self.slerp:
z = slerp(z, zn, self.aesthetic_weight)
else:
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
return z
import os import os
import shutil import shutil
import sys import time
import hashlib
import gradio
system_bak_path = "webui_log_and_bak"
custom_tab_name = "custom fold"
faverate_tab_name = "favorites"
tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
return True
except:
return False
def reduplicative_file_move(src, dst):
def same_name_file(basename, path):
name, ext = os.path.splitext(basename)
f_list = os.listdir(path)
max_num = 0
for f in f_list:
if len(f) <= len(basename):
continue
f_ext = f[-len(ext):] if len(ext) > 0 else ""
if f[:len(name)] == name and f_ext == ext:
if f[len(name)] == "(" and f[-len(ext)-1] == ")":
number = f[len(name)+1:-len(ext)-1]
if number.isdigit():
if int(number) > max_num:
max_num = int(number)
return f"{name}({max_num + 1}){ext}"
name = os.path.basename(src)
save_name = os.path.join(dst, name)
if not os.path.exists(save_name):
shutil.move(src, dst)
else:
name = same_name_file(name, dst)
shutil.move(src, os.path.join(dst, name))
def traverse_all_files(output_dir, image_list, curr_dir=None): def traverse_all_files(curr_path, image_list, all_type=False):
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
try: try:
f_list = os.listdir(curr_path) f_list = os.listdir(curr_path)
except: except:
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt": if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"):
image_list.append(curr_dir) image_list.append(curr_path)
return image_list return image_list
for file in f_list: for file in f_list:
file = file if curr_dir is None else os.path.join(curr_dir, file) file = os.path.join(curr_path, file)
file_path = os.path.join(curr_path, file) if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
if file[-4:] == ".txt":
pass pass
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0: elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
image_list.append(file) image_list.append(file)
else: else:
image_list = traverse_all_files(output_dir, image_list, file) image_list = traverse_all_files(file, image_list)
return image_list return image_list
def auto_sorting(dir_name):
def get_recent_images(dir_name, page_index, step, image_index, tabname): bak_path = os.path.join(dir_name, system_bak_path)
page_index = int(page_index) if not os.path.exists(bak_path):
image_list = [] os.mkdir(bak_path)
if not os.path.exists(dir_name): log_file = None
pass files_list = []
elif os.path.isdir(dir_name): f_list = os.listdir(dir_name)
image_list = traverse_all_files(dir_name, image_list) for file in f_list:
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file))) if file == system_bak_path:
else: continue
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr) file_path = os.path.join(dir_name, file)
num = 48 if tabname != "extras" else 12 if not is_valid_date(file):
max_page_index = len(image_list) // num + 1 if file[-10:].rfind(".") > 0:
page_index = max_page_index if page_index == -1 else page_index + step files_list.append(file_path)
page_index = 1 if page_index < 1 else page_index
page_index = max_page_index if page_index > max_page_index else page_index
idx_frm = (page_index - 1) * num
image_list = image_list[idx_frm:idx_frm + num]
image_index = int(image_index)
if image_index < 0 or image_index > len(image_list) - 1:
current_file = None
hidden = None
else: else:
current_file = image_list[int(image_index)] files_list = traverse_all_files(file_path, files_list, all_type=True)
hidden = os.path.join(dir_name, current_file)
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
def first_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, 1, 0, image_index, tabname)
def end_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, -1, 0, image_index, tabname)
def prev_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
def next_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
for file in files_list:
date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file)))
file_path = os.path.dirname(file)
hash_path = hashlib.md5(file_path.encode()).hexdigest()
path = os.path.join(dir_name, date_str, hash_path)
if not os.path.exists(path):
os.makedirs(path)
if log_file is None:
log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
log_file.write(f"{hash_path},{file_path}\n")
reduplicative_file_move(file, path)
def page_index_change(dir_name, page_index, image_index, tabname): date_list = []
return get_recent_images(dir_name, page_index, 0, image_index, tabname) f_list = os.listdir(dir_name)
for f in f_list:
if is_valid_date(f):
date_list.append(f)
elif f == system_bak_path:
continue
else:
try:
reduplicative_file_move(os.path.join(dir_name, f), bak_path)
except:
pass
def show_image_info(num, image_path, filenames): today = time.strftime("%Y%m%d",time.localtime(time.time()))
# print(f"select image {num}") if today not in date_list:
file = filenames[int(num)] date_list.append(today)
return file, num, os.path.join(image_path, file) return sorted(date_list, reverse=True)
def archive_images(dir_name, date_to):
filenames = []
batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
if batch_size <= 0:
batch_size = opts.images_history_num_per_page * 6
today = time.strftime("%Y%m%d",time.localtime(time.time()))
date_to = today if date_to is None or date_to == "" else date_to
date_to_bak = date_to
if False: #opts.images_history_reconstruct_directory:
date_list = auto_sorting(dir_name)
for date in date_list:
if date <= date_to:
path = os.path.join(dir_name, date)
if date == today and not os.path.exists(path):
continue
filenames = traverse_all_files(path, filenames)
if len(filenames) > batch_size:
break
filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
else:
filenames = traverse_all_files(dir_name, filenames)
total_num = len(filenames)
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
date_list = {date_to:None}
date = time.strftime("%Y%m%d",time.localtime(time.time()))
for t, f in tmparray:
date = time.strftime("%Y%m%d",time.localtime(t))
date_list[date] = None
if t <= date_stamp:
filenames.append((t, f ,date))
date_list = sorted(list(date_list.keys()), reverse=True)
sort_array = sorted(filenames, key=lambda x:-x[0])
if len(sort_array) > batch_size:
date = sort_array[batch_size][2]
filenames = [x[1] for x in sort_array]
else:
date = date_to if len(sort_array) == 0 else sort_array[-1][2]
filenames = [x[1] for x in sort_array]
filenames = [x[1] for x in sort_array if x[2]>= date]
num = len(filenames)
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
load_info = "<div style='color:#999' align='center'>"
load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
load_info += "</div>"
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
date_to,
load_info,
filenames,
1,
image_list,
"",
"",
visible_num,
last_date_from,
gradio.update(visible=total_num > num)
)
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index): def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "": if name == "":
return filenames, delete_num return filenames, delete_num
else: else:
delete_num = int(delete_num) delete_num = int(delete_num)
visible_num = int(visible_num)
image_index = int(image_index)
index = list(filenames).index(name) index = list(filenames).index(name)
i = 0 i = 0
new_file_list = [] new_file_list = []
for name in filenames: for name in filenames:
if i >= index and i < index + delete_num: if i >= index and i < index + delete_num:
path = os.path.join(dir_name, name) if os.path.exists(name):
if os.path.exists(path): if visible_num == image_index:
print(f"Delete file {path}") new_file_list.append(name)
os.remove(path) i += 1
txt_file = os.path.splitext(path)[0] + ".txt" continue
print(f"Delete file {name}")
os.remove(name)
visible_num -= 1
txt_file = os.path.splitext(name)[0] + ".txt"
if os.path.exists(txt_file): if os.path.exists(txt_file):
os.remove(txt_file) os.remove(txt_file)
else: else:
print(f"Not exists file {path}") print(f"Not exists file {name}")
else: else:
new_file_list.append(name) new_file_list.append(name)
i += 1 i += 1
return new_file_list, 1 return new_file_list, 1, visible_num
def save_image(file_name):
if file_name is not None and os.path.exists(file_name):
shutil.copy(file_name, opts.outdir_save)
def get_recent_images(page_index, step, filenames):
page_index = int(page_index)
num_of_imgs_per_page = int(opts.images_history_num_per_page)
max_page_index = len(filenames) // num_of_imgs_per_page + 1
page_index = max_page_index if page_index == -1 else page_index + step
page_index = 1 if page_index < 1 else page_index
page_index = max_page_index if page_index > max_page_index else page_index
idx_frm = (page_index - 1) * num_of_imgs_per_page
image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
length = len(filenames)
visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
return page_index, image_list, "", "", visible_num
def loac_batch_click(date_to):
if date_to is None:
return time.strftime("%Y%m%d",time.localtime(time.time())), []
else:
return None, []
def forward_click(last_date_from, date_to_recorder):
if len(date_to_recorder) == 0:
return None, []
if last_date_from == date_to_recorder[-1]:
date_to_recorder = date_to_recorder[:-1]
if len(date_to_recorder) == 0:
return None, []
return date_to_recorder[-1], date_to_recorder[:-1]
def backward_click(last_date_from, date_to_recorder):
if last_date_from is None or last_date_from == "":
return time.strftime("%Y%m%d",time.localtime(time.time())), []
if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
date_to_recorder.append(last_date_from)
return last_date_from, date_to_recorder
def first_page_click(page_index, filenames):
return get_recent_images(1, 0, filenames)
def end_page_click(page_index, filenames):
return get_recent_images(-1, 0, filenames)
def prev_page_click(page_index, filenames):
return get_recent_images(page_index, -1, filenames)
def next_page_click(page_index, filenames):
return get_recent_images(page_index, 1, filenames)
def page_index_change(page_index, filenames):
return get_recent_images(page_index, 0, filenames)
def show_image_info(tabname_box, num, page_index, filenames):
file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
tm = "<div style='color:#999' align='right'>" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "</div>"
return file, tm, num, file
def enable_page_buttons():
return gradio.update(visible=True)
def change_dir(img_dir, date_to):
warning = None
try:
if os.path.exists(img_dir):
try:
f = os.listdir(img_dir)
except:
warning = f"'{img_dir} is not a directory"
else:
warning = "The directory is not exist"
except:
warning = "The format of the directory is incorrect"
if warning is None:
today = time.strftime("%Y%m%d",time.localtime(time.time()))
return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
else:
return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict): def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
if opts.outdir_samples != "": custom_dir = False
dir_name = opts.outdir_samples if tabname == "txt2img":
elif tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img": elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples dir_name = opts.outdir_img2img_samples
elif tabname == "extras": elif tabname == "extras":
dir_name = opts.outdir_extras_samples dir_name = opts.outdir_extras_samples
elif tabname == faverate_tab_name:
dir_name = opts.outdir_save
else: else:
return custom_dir = True
dir_name = None
if not custom_dir:
d = dir_name.split("/")
dir_name = d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
if not os.path.exists(dir_name):
os.makedirs(dir_name)
with gr.Column() as page_panel:
with gr.Row():
with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
with gr.Column(scale=4):
with gr.Row():
img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
with gr.Row(): with gr.Row():
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page") with gr.Column(visible=False, scale=1) as batch_panel:
with gr.Row():
forward = gr.Button('Prev batch')
backward = gr.Button('Next batch')
with gr.Column(scale=3):
load_info = gr.HTML(visible=not custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
with gr.Column(scale=2):
with gr.Row(visible=True) as turn_page_buttons:
#date_to = gr.Dropdown(label="Date to")
first_page = gr.Button('First Page') first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page') prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index") page_index = gr.Number(value=1, label="Page Index")
next_page = gr.Button('Next Page') next_page = gr.Button('Next Page')
end_page = gr.Button('End Page') end_page = gr.Button('End Page')
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Row(): history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
with gr.Column(scale=2):
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row(): with gr.Row():
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next") delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button") delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
with gr.Column(): with gr.Column():
with gr.Row():
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
pnginfo_send_to_img2img = gr.Button('Send to img2img')
with gr.Row(): with gr.Row():
with gr.Column(): with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False) img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
img_file_name = gr.Textbox(label="File Name", interactive=False) gr.HTML("<hr>")
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
img_file_time= gr.HTML()
with gr.Row(): with gr.Row():
# hiden items if tabname != faverate_tab_name:
save_btn = gr.Button('Collect')
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
pnginfo_send_to_img2img = gr.Button('Send to img2img')
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False) # hiden items
tabname_box = gr.Textbox(tabname, visible=False) with gr.Row(visible=False):
image_index = gr.Textbox(value=-1, visible=False) renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False) batch_date_to = gr.Textbox(label="Date to")
visible_img_num = gr.Number()
date_to_recorder = gr.State([])
last_date_from = gr.Textbox()
tabname_box = gr.Textbox(tabname)
image_index = gr.Textbox(value=-1)
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
filenames = gr.State() filenames = gr.State()
hidden = gr.Image(type="pil", visible=False) all_images_list = gr.State()
info1 = gr.Textbox(visible=False) hidden = gr.Image(type="pil")
info2 = gr.Textbox(visible=False) info1 = gr.Textbox()
info2 = gr.Textbox()
# turn pages
gallery_inputs = [img_path, page_index, image_index, tabname_box] img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
#change batch
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs) batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index]) load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
#delete
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
if tabname != faverate_tab_name:
save_btn.click(save_image, inputs=[img_file_name], outputs=None)
#turn page
gallery_inputs = [page_index, filenames]
gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other funcitons # other funcitons
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden]) set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None) img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2]) hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img') switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img') switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
global opts;
opts = sys_opts
loads_files_num = int(opts.images_history_num_per_page)
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
if cmp_ops.browse_all_images:
tabs_list.append(custom_tab_name)
with gr.Blocks(analytics_enabled=False) as images_history: with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs: with gr.Tabs() as tabs:
with gr.Tab("txt2img history"): for tab in tabs_list:
with gr.Blocks(analytics_enabled=False) as images_history_txt2img: with gr.Tab(tab):
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict) with gr.Blocks(analytics_enabled=False) :
with gr.Tab("img2img history"): show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
with gr.Blocks(analytics_enabled=False) as images_history_img2img: gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict) gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
with gr.Tab("extras history"):
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
return images_history return images_history
...@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
is_inpaint = mode == 1 is_inpaint = mode == 1
is_batch = mode == 2 is_batch = mode == 2
...@@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -109,7 +109,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
) )
shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
......
...@@ -104,6 +104,12 @@ class StableDiffusionProcessing(): ...@@ -104,6 +104,12 @@ class StableDiffusionProcessing():
self.seed_resize_from_h = 0 self.seed_resize_from_h = 0
self.seed_resize_from_w = 0 self.seed_resize_from_w = 0
self.scripts = None
self.script_args = None
self.all_prompts = None
self.all_seeds = None
self.all_subseeds = None
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
pass pass
...@@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles(p) shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list: if type(p.prompt) == list:
all_prompts = p.prompt p.all_prompts = p.prompt
else: else:
all_prompts = p.batch_size * p.n_iter * [p.prompt] p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(seed) == list: if type(seed) == list:
all_seeds = seed p.all_seeds = seed
else: else:
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))] p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list: if type(subseed) == list:
all_subseeds = subseed p.all_subseeds = subseed
else: else:
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))] p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0): def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch) return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings() model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
p.scripts.run_alwayson_scripts(p)
infotexts = [] infotexts = []
output_images = [] output_images = []
with torch.no_grad(), p.sd_model.ema_scope(): with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast(): with devices.autocast():
p.init(all_prompts, all_seeds, all_subseeds) p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1: if state.job_count == -1:
state.job_count = p.n_iter state.job_count = p.n_iter
...@@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.interrupted: if state.interrupted:
break break
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size] prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size] seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0): if (len(prompts) == 0):
break break
...@@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1 index_of_first_image = 1
if opts.grid_save: if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc() devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts) return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
......
callbacks_model_loaded = []
callbacks_ui_tabs = []
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
def model_loaded_callback(sd_model):
for callback in callbacks_model_loaded:
callback(sd_model)
def ui_tabs_callback():
res = []
for callback in callbacks_ui_tabs:
res += callback() or []
return res
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
callbacks_model_loaded.append(callback)
def on_ui_tabs(callback):
"""register a function to be called when the UI is creating new tabs.
The function must either return a None, which means no new tabs to be added, or a list, where
each element is a tuple:
(gradio_component, title, elem_id)
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
callbacks_ui_tabs.append(callback)
import os import os
import sys import sys
import traceback import traceback
from collections import namedtuple
import modules.ui as ui import modules.ui as ui
import gradio as gr import gradio as gr
from modules.processing import StableDiffusionProcessing from modules.processing import StableDiffusionProcessing
from modules import shared from modules import shared, paths, script_callbacks
AlwaysVisible = object()
class Script: class Script:
filename = None filename = None
args_from = None args_from = None
args_to = None args_to = None
alwayson = False
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
"""
# The title of the script. This is what will be displayed in the dropdown menu.
def title(self): def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
raise NotImplementedError() raise NotImplementedError()
# How the script is displayed in the UI. See https://gradio.app/docs/#components
# for the different UI components you can use and how to create them.
# Most UI components can return a value, such as a boolean for a checkbox.
# The returned values are passed to the run method as parameters.
def ui(self, is_img2img): def ui(self, is_img2img):
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
Values of those returned componenbts will be passed to run() and process() functions.
"""
pass pass
# Determines when the script should be shown in the dropdown menu via the
# returned value. As an example:
# is_img2img is True if the current tab is img2img, and False if it is txt2img.
# Thus, return is_img2img to only show the script on the img2img tab.
def show(self, is_img2img): def show(self, is_img2img):
"""
is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
This function should return:
- False if the script should not be shown in UI at all
- True if the script should be shown in UI if it's scelected in the scripts drowpdown
- script.AlwaysVisible if the script should be shown in UI at all times
"""
return True return True
# This is where the additional processing is implemented. The parameters include def run(self, p, *args):
# self, the model object "p" (a StableDiffusionProcessing class, see """
# processing.py), and the parameters returned by the ui method. This function is called if the script has been selected in the script dropdown.
# Custom functions can be defined here, and additional libraries can be imported It must do all processing and return the Processed object with results, same as
# to be used in processing. The return value should be a Processed object, which is one returned by processing.process_images.
# what is returned by the process_images method.
def run(self, *args): Usually the processing is done by calling the processing.process_images function.
args contains all values returned by components from ui()
"""
raise NotImplementedError() raise NotImplementedError()
# The description method is currently unused. def process(self, p, *args):
# To add a description that appears when hovering over the title, amend the "titles" """
# dict in script.js to include the script title (returned by title) as a key, and This function is called before processing begins for AlwaysVisible scripts.
# your description as the value. scripts. You can modify the processing object (p) here, inject hooks, etc.
"""
pass
def describe(self): def describe(self):
"""unused"""
return "" return ""
current_basedir = paths.script_path
def basedir():
"""returns the base directory for the current script. For scripts in the main scripts directory,
this is the main directory (where webui.py resides), and for scripts in extensions directory
(ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
"""
return current_basedir
scripts_data = [] scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
def load_scripts(basedir): def list_scripts(scriptdirname, extension):
if not os.path.exists(basedir): scripts_list = []
return
basedir = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(basedir):
for filename in sorted(os.listdir(basedir)): for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename) scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
for dirname in sorted(os.listdir(extdir)):
dirpath = os.path.join(extdir, dirname)
scriptdirpath = os.path.join(dirpath, scriptdirname)
if os.path.splitext(path)[1].lower() != '.py': if not os.path.isdir(scriptdirpath):
continue continue
if not os.path.isfile(path): for filename in sorted(os.listdir(scriptdirpath)):
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return scripts_list
def list_files_with_name(filename):
res = []
dirs = [paths.script_path]
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
for dirpath in dirs:
if not os.path.isdir(dirpath):
continue continue
path = os.path.join(dirpath, filename)
if os.path.isfile(filename):
res.append(path)
return res
def load_scripts():
global current_basedir
scripts_data.clear()
script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py")
syspath = sys.path
for scriptfile in sorted(scripts_list):
try: try:
with open(path, "r", encoding="utf8") as file: if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read() text = file.read()
from types import ModuleType from types import ModuleType
compiled = compile(text, path, 'exec') compiled = compile(text, scriptfile.path, 'exec')
module = ModuleType(filename) module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__) exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items(): for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script): if type(script_class) == type and issubclass(script_class, Script):
scripts_data.append((script_class, path)) scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
except Exception: except Exception:
print(f"Error loading script: {filename}", file=sys.stderr) print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
finally:
sys.path = syspath
current_basedir = paths.script_path
def wrap_call(func, filename, funcname, *args, default=None, **kwargs): def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try: try:
...@@ -96,56 +185,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs): ...@@ -96,56 +185,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner: class ScriptRunner:
def __init__(self): def __init__(self):
self.scripts = [] self.scripts = []
self.selectable_scripts = []
self.alwayson_scripts = []
self.titles = [] self.titles = []
self.infotext_fields = []
def setup_ui(self, is_img2img): def setup_ui(self, is_img2img):
for script_class, path in scripts_data: for script_class, path, basedir in scripts_data:
script = script_class() script = script_class()
script.filename = path script.filename = path
if not script.show(is_img2img): visibility = script.show(is_img2img)
continue
if visibility == AlwaysVisible:
self.scripts.append(script) self.scripts.append(script)
self.alwayson_scripts.append(script)
script.alwayson = True
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts] elif visibility:
self.scripts.append(script)
self.selectable_scripts.append(script)
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index") self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
dropdown.save_to_config = True
inputs = [dropdown]
for script in self.scripts: inputs = [None]
inputs_alwayson = [True]
def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs) script.args_from = len(inputs)
script.args_to = len(inputs) script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img) controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
if controls is None: if controls is None:
continue return
for control in controls: for control in controls:
control.custom_script_source = os.path.basename(script.filename) control.custom_script_source = os.path.basename(script.filename)
if not script.alwayson:
control.visible = False control.visible = False
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
inputs += controls inputs += controls
inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs) script.args_to = len(inputs)
for script in self.alwayson_scripts:
with gr.Group():
create_script_ui(script, inputs, inputs_alwayson)
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
create_script_ui(script, inputs, inputs_alwayson)
def select_script(script_index): def select_script(script_index):
if 0 < script_index <= len(self.scripts): if 0 < script_index <= len(self.selectable_scripts):
script = self.scripts[script_index-1] script = self.selectable_scripts[script_index-1]
args_from = script.args_from args_from = script.args_from
args_to = script.args_to args_to = script.args_to
else: else:
args_from = 0 args_from = 0
args_to = 0 args_to = 0
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))] return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
def init_field(title): def init_field(title):
if title == 'None': if title == 'None':
return return
script_index = self.titles.index(title) script_index = self.titles.index(title)
script = self.scripts[script_index] script = self.selectable_scripts[script_index]
for i in range(script.args_from, script.args_to): for i in range(script.args_from, script.args_to):
inputs[i].visible = True inputs[i].visible = True
...@@ -164,7 +277,7 @@ class ScriptRunner: ...@@ -164,7 +277,7 @@ class ScriptRunner:
if script_index == 0: if script_index == 0:
return None return None
script = self.scripts[script_index-1] script = self.selectable_scripts[script_index-1]
if script is None: if script is None:
return None return None
...@@ -176,7 +289,16 @@ class ScriptRunner: ...@@ -176,7 +289,16 @@ class ScriptRunner:
return processed return processed
def reload_sources(self): def run_alwayson_scripts(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)): for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file: with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from args_from = script.args_from
...@@ -186,9 +308,12 @@ class ScriptRunner: ...@@ -186,9 +308,12 @@ class ScriptRunner:
from types import ModuleType from types import ModuleType
module = cache.get(filename, None)
if module is None:
compiled = compile(text, filename, 'exec') compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename) module = ModuleType(script.filename)
exec(compiled, module.__dict__) exec(compiled, module.__dict__)
cache[filename] = module
for key, script_class in module.__dict__.items(): for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script): if type(script_class) == type and issubclass(script_class, Script):
...@@ -197,19 +322,22 @@ class ScriptRunner: ...@@ -197,19 +322,22 @@ class ScriptRunner:
self.scripts[si].args_from = args_from self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner() scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner() scripts_img2img = ScriptRunner()
def reload_script_body_only(): def reload_script_body_only():
scripts_txt2img.reload_sources() cache = {}
scripts_img2img.reload_sources() scripts_txt2img.reload_sources(cache)
scripts_img2img.reload_sources(cache)
def reload_scripts(basedir): def reload_scripts():
global scripts_txt2img, scripts_img2img global scripts_txt2img, scripts_img2img
scripts_data.clear() load_scripts()
load_scripts(basedir)
scripts_txt2img = ScriptRunner() scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner() scripts_img2img = ScriptRunner()
...@@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -332,7 +332,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
multipliers.append([1.0] * 75) multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers) z1 = self.process_tokens(tokens, multipliers)
z1 = shared.aesthetic_clip(z1, remade_batch_tokens)
z = z1 if z is None else torch.cat((z, z1), axis=-2) z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens remade_batch_tokens = rem_tokens
......
...@@ -7,7 +7,7 @@ from omegaconf import OmegaConf ...@@ -7,7 +7,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices from modules import shared, modelloader, devices, script_callbacks
from modules.paths import models_path from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
...@@ -238,6 +238,9 @@ def load_model(checkpoint_info=None): ...@@ -238,6 +238,9 @@ def load_model(checkpoint_info=None):
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval() sd_model.eval()
shared.sd_model = sd_model
script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.") print(f"Model loaded.")
return sd_model return sd_model
...@@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None): ...@@ -252,7 +255,7 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear() checkpoints_loaded.clear()
shared.sd_model = load_model(checkpoint_info) load_model(checkpoint_info)
return shared.sd_model return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
......
...@@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th ...@@ -31,7 +31,6 @@ parser.add_argument("--no-half-vae", action='store_true', help="do not switch th
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)") parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--aesthetic_embeddings-dir", type=str, default=os.path.join(models_path, 'aesthetic_embeddings'), help="aesthetic_embeddings directory(default: aesthetic_embeddings)")
parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory")
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
...@@ -81,6 +80,7 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl ...@@ -81,6 +80,7 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui") parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui") parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
restricted_opts = [ restricted_opts = [
...@@ -109,21 +109,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) ...@@ -109,21 +109,6 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None loaded_hypernetwork = None
os.makedirs(cmd_opts.aesthetic_embeddings_dir, exist_ok=True)
aesthetic_embeddings = {}
def update_aesthetic_embeddings():
global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt", ""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
update_aesthetic_embeddings()
def reload_hypernetworks(): def reload_hypernetworks():
global hypernetworks global hypernetworks
...@@ -333,6 +318,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" ...@@ -333,6 +318,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
})) }))
options_templates.update(options_section(('images-history', "Images Browser"), {
#"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
"images_history_preload": OptionInfo(False, "Preload images at startup"),
"images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
"images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
"images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
}))
class Options: class Options:
data = None data = None
...@@ -407,9 +400,6 @@ sd_model = None ...@@ -407,9 +400,6 @@ sd_model = None
clip_model = None clip_model = None
from modules.aesthetic_clip import AestheticCLIP
aesthetic_clip = AestheticCLIP()
progress_print_out = sys.stdout progress_print_out = sys.stdout
......
...@@ -7,7 +7,7 @@ import modules.processing as processing ...@@ -7,7 +7,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args): def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args):
p = StableDiffusionProcessingTxt2Img( p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
...@@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -36,7 +36,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None, firstphase_height=firstphase_height if enable_hr else None,
) )
shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
......
...@@ -23,10 +23,10 @@ import gradio as gr ...@@ -23,10 +23,10 @@ import gradio as gr
import gradio.utils import gradio.utils
import gradio.routes import gradio.routes
from modules import sd_hijack, sd_models, localization from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts, aesthetic_embeddings from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru: if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags from modules.deepbooru import get_deepbooru_tags
...@@ -44,7 +44,6 @@ from modules.images import save_image ...@@ -44,7 +44,6 @@ from modules.images import save_image
import modules.textual_inversion.ui import modules.textual_inversion.ui
import modules.hypernetworks.ui import modules.hypernetworks.ui
import modules.aesthetic_clip as aesthetic_clip
import modules.images_history as img_his import modules.images_history as img_his
...@@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -662,8 +661,6 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
with gr.Group(): with gr.Group():
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
...@@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -718,14 +715,6 @@ def create_ui(wrap_gradio_gpu_call):
denoising_strength, denoising_strength,
firstphase_width, firstphase_width,
firstphase_height, firstphase_height,
aesthetic_lr,
aesthetic_weight,
aesthetic_steps,
aesthetic_imgs,
aesthetic_slerp,
aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
...@@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -804,14 +793,7 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"), (firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"), (firstphase_height, "First pass size-2"),
(aesthetic_lr, "Aesthetic LR"), *modules.scripts.scripts_txt2img.infotext_fields
(aesthetic_weight, "Aesthetic weight"),
(aesthetic_steps, "Aesthetic steps"),
(aesthetic_imgs, "Aesthetic embedding"),
(aesthetic_slerp, "Aesthetic slerp"),
(aesthetic_imgs_text, "Aesthetic text"),
(aesthetic_text_negative, "Aesthetic text negative"),
(aesthetic_slerp_angle, "Aesthetic slerp angle"),
] ]
txt2img_preview_params = [ txt2img_preview_params = [
...@@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -896,8 +878,6 @@ def create_ui(wrap_gradio_gpu_call):
seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()
aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
with gr.Group(): with gr.Group():
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
...@@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -988,14 +968,6 @@ def create_ui(wrap_gradio_gpu_call):
inpainting_mask_invert, inpainting_mask_invert,
img2img_batch_input_dir, img2img_batch_input_dir,
img2img_batch_output_dir, img2img_batch_output_dir,
aesthetic_lr_im,
aesthetic_weight_im,
aesthetic_steps_im,
aesthetic_imgs_im,
aesthetic_slerp_im,
aesthetic_imgs_text_im,
aesthetic_slerp_angle_im,
aesthetic_text_negative_im,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
img2img_gallery, img2img_gallery,
...@@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1087,14 +1059,7 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"), (seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(aesthetic_lr_im, "Aesthetic LR"), *modules.scripts.scripts_img2img.infotext_fields
(aesthetic_weight_im, "Aesthetic weight"),
(aesthetic_steps_im, "Aesthetic steps"),
(aesthetic_imgs_im, "Aesthetic embedding"),
(aesthetic_slerp_im, "Aesthetic slerp"),
(aesthetic_imgs_text_im, "Aesthetic text"),
(aesthetic_text_negative_im, "Aesthetic text negative"),
(aesthetic_slerp_angle_im, "Aesthetic slerp angle"),
] ]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
...@@ -1217,12 +1182,12 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1217,12 +1182,12 @@ def create_ui(wrap_gradio_gpu_call):
) )
#images history #images history
images_history_switch_dict = { images_history_switch_dict = {
"fn":modules.generation_parameters_copypaste.connect_paste, "fn": modules.generation_parameters_copypaste.connect_paste,
"t2i":txt2img_paste_fields, "t2i": txt2img_paste_fields,
"i2i":img2img_paste_fields "i2i": img2img_paste_fields
} }
images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict) images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface: with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
...@@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1264,18 +1229,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column(): with gr.Column():
create_embedding = gr.Button(value="Create embedding", variant='primary') create_embedding = gr.Button(value="Create embedding", variant='primary')
with gr.Tab(label="Create aesthetic images embedding"):
new_embedding_name_ae = gr.Textbox(label="Name")
process_src_ae = gr.Textbox(label='Source directory')
batch_ae = gr.Slider(minimum=1, maximum=1024, step=1, label="Batch size", value=256)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding_ae = gr.Button(value="Create images embedding", variant='primary')
with gr.Tab(label="Create hypernetwork"): with gr.Tab(label="Create hypernetwork"):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
...@@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1375,21 +1328,6 @@ def create_ui(wrap_gradio_gpu_call):
] ]
) )
create_embedding_ae.click(
fn=aesthetic_clip.generate_imgs_embd,
inputs=[
new_embedding_name_ae,
process_src_ae,
batch_ae
],
outputs=[
aesthetic_imgs,
aesthetic_imgs_im,
ti_output,
ti_outcome,
]
)
create_hypernetwork.click( create_hypernetwork.click(
fn=modules.hypernetworks.ui.create_hypernetwork, fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[ inputs=[
...@@ -1580,10 +1518,10 @@ Requested path was: {f} ...@@ -1580,10 +1518,10 @@ Requested path was: {f}
if not opts.same_type(value, opts.data_labels[key].default): if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson() return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts: if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson() return gr.update(value=oldval), opts.dumpjson()
oldval = opts.data.get(key, None)
opts.data[key] = value opts.data[key] = value
if oldval != value: if oldval != value:
...@@ -1689,19 +1627,24 @@ Requested path was: {f} ...@@ -1689,19 +1627,24 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(images_history, "History", "images_history"), (images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"), (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"), (train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"),
] ]
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file: interfaces += script_callbacks.ui_tabs_callback()
css = file.read()
interfaces += [(settings_interface, "Settings", "settings")]
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")): if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file: with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
usercss = file.read() css += file.read() + "\n"
css += usercss
if not cmd_opts.no_progressbar_hiding: if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar css += css_hide_progressbar
...@@ -1924,9 +1867,9 @@ def load_javascript(raw_response): ...@@ -1924,9 +1867,9 @@ def load_javascript(raw_response):
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>' javascript = f'<script>{jsfile.read()}</script>'
jsdir = os.path.join(script_path, "javascript") scripts_list = modules.scripts.list_scripts("javascript", ".js")
for filename in sorted(os.listdir(jsdir)): for basedir, filename, path in scripts_list:
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>" javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None: if cmd_opts.theme is not None:
...@@ -1944,6 +1887,5 @@ def load_javascript(raw_response): ...@@ -1944,6 +1887,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response gradio.routes.templates.TemplateResponse = template_response
reload_javascript = partial(load_javascript, reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
gradio.routes.templates.TemplateResponse)
reload_javascript() reload_javascript()
...@@ -477,7 +477,7 @@ input[type="range"]{ ...@@ -477,7 +477,7 @@ input[type="range"]{
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{ #refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{
max-width: 2.5em; max-width: 2.5em;
min-width: 2.5em; min-width: 2.5em;
height: 2.4em; height: 2.4em;
......
...@@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): ...@@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
def initialize(): def initialize():
modelloader.cleanup_models() modelloader.cleanup_models()
modules.sd_models.setup_model() modules.sd_models.setup_model()
...@@ -79,9 +80,9 @@ def initialize(): ...@@ -79,9 +80,9 @@ def initialize():
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers() modelloader.load_upscalers()
modules.scripts.load_scripts(os.path.join(script_path, "scripts")) modules.scripts.load_scripts()
shared.sd_model = modules.sd_models.load_model() modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
...@@ -145,7 +146,7 @@ def webui(): ...@@ -145,7 +146,7 @@ def webui():
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading Custom Scripts') print('Reloading Custom Scripts')
modules.scripts.reload_scripts(os.path.join(script_path, "scripts")) modules.scripts.reload_scripts()
print('Reloading modules: modules.ui') print('Reloading modules: modules.ui')
importlib.reload(modules.ui) importlib.reload(modules.ui)
print('Refreshing Model List') print('Refreshing Model List')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment