Commit dcb45dfe authored by discus0434's avatar discus0434

Merge branch 'master' of upstream

parents 0e8ca8e7 50b55044
......@@ -45,6 +45,8 @@ body:
attributes:
label: Commit where the problem happens
description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
validations:
required: true
- type: dropdown
id: platforms
attributes:
......
blank_issues_enabled: false
contact_links:
- name: WebUI Community Support
url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
about: Please ask and answer questions here.
......@@ -27,4 +27,5 @@ __pycache__
notification.mp3
/SwinIR
/textual_inversion
.vscode
\ No newline at end of file
.vscode
/extensions
......@@ -11,6 +11,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
- Color Sketch
- Prompt Matrix
- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
......@@ -23,6 +24,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- have as many embeddings as you want and use any names you like for them
- use multiple embeddings with different numbers of vectors per token
- works with half precision floating point numbers
- train embeddings on 8GB (also reports of 6GB working)
- Extras tab with:
- GFPGAN, neural network that fixes faces
- CodeFormer, face restoration tool as an alternative to GFPGAN
......@@ -37,14 +39,14 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- Interrupt processing at any time
- 4GB video card support (also reports of 2GB working)
- Correct seeds for batches
- Prompt length validation
- get length of prompt in tokens as you type
- get a warning after generation if some text was truncated
- Live prompt token length validation
- Generation parameters
- parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings
- drag and drop an image/text-parameters to promptbox
- Read Generation Parameters Button, loads parameters in promptbox to UI
- Settings page
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
......@@ -59,10 +61,10 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- CLIP interrogator, a button that tries to guess prompt from an image
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
- Batch Processing, process a group of files using img2img
- Img2img Alternative
- Img2img Alternative, reverse Euler method of cross attention control
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
- Reloading checkpoints on the fly
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
- separate prompts using uppercase `AND`
......@@ -70,14 +72,35 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- History tab: view, direct and delete images conveniently within the UI
- Generate forever option
- Training tab
- hypernetworks and embeddings options
- Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
- Clip skip
- Use Hypernetworks
- Use VAEs
- Estimated completion time in progress bar
- API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
## Where are Aesthetic Gradients?!?!
Aesthetic Gradients are now an extension. You can install it using git:
```commandline
git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
```
After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
Alternatively, use Google Colab:
Alternatively, use online services (like Google Colab):
- [Colab, maintained by Akaibu](https://colab.research.google.com/drive/1kw3egmSn-KgWsikYvOMjJkVDsPLjEMzl)
- [Colab, original by me, outdated](https://colab.research.google.com/drive/1Iy-xW9t1-OQWhb0hNxueGij8phCyluOh).
- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
### Automatic Installation on Windows
1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
......
......@@ -3,12 +3,12 @@ let currentWidth = null;
let currentHeight = null;
let arFrameTimeout = setTimeout(function(){},0);
function dimensionChange(e,dimname){
function dimensionChange(e, is_width, is_height){
if(dimname == 'Width'){
if(is_width){
currentWidth = e.target.value*1.0
}
if(dimname == 'Height'){
if(is_height){
currentHeight = e.target.value*1.0
}
......@@ -18,22 +18,13 @@ function dimensionChange(e,dimname){
return;
}
var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
if(img2imgMode){
img2imgMode=img2imgMode.innerText
}else{
return;
}
var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')
var targetElement = null;
if(img2imgMode=='img2img' && redrawImage){
targetElement = redrawImage;
}else if(img2imgMode=='Inpaint' && inpaintImage){
targetElement = inpaintImage;
var tabIndex = get_tab_index('mode_img2img')
if(tabIndex == 0){
targetElement = gradioApp().querySelector('div[data-testid=image] img');
} else if(tabIndex == 1){
targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
}
if(targetElement){
......@@ -98,22 +89,20 @@ onUiUpdate(function(){
var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
if(inImg2img){
let inputs = gradioApp().querySelectorAll('input');
inputs.forEach(function(e){
let parentLabel = e.parentElement.querySelector('label')
if(parentLabel && parentLabel.innerText){
if(!e.classList.contains('scrollwatch')){
if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
e.classList.add('scrollwatch')
}
if(parentLabel.innerText == 'Width'){
currentWidth = e.value*1.0
}
if(parentLabel.innerText == 'Height'){
currentHeight = e.value*1.0
}
}
}
inputs.forEach(function(e){
var is_width = e.parentElement.id == "img2img_width"
var is_height = e.parentElement.id == "img2img_height"
if((is_width || is_height) && !e.classList.contains('scrollwatch')){
e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
e.classList.add('scrollwatch')
}
if(is_width){
currentWidth = e.value*1.0
}
if(is_height){
currentHeight = e.value*1.0
}
})
}
});
......@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
window.document.addEventListener('dragover', e => {
const target = e.composedPath()[0];
const imgWrap = target.closest('[data-testid="image"]');
if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
return;
}
e.stopPropagation();
......
......@@ -17,14 +17,6 @@ var images_history_click_image = function(){
images_history_set_image_info(this);
}
var images_history_click_tab = function(){
var tabs_box = gradioApp().getElementById("images_history_tab");
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
tabs_box.classList.add(this.getAttribute("tabname"))
}
}
function images_history_disabled_del(){
gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
btn.setAttribute('disabled','disabled');
......@@ -43,7 +35,6 @@ function images_history_get_parent_by_tagname(item, tagname){
var parent = item.parentElement;
tagname = tagname.toUpperCase()
while(parent.tagName != tagname){
console.log(parent.tagName, tagname)
parent = parent.parentElement;
}
return parent;
......@@ -88,15 +79,15 @@ function images_history_set_image_info(button){
}
function images_history_get_current_img(tabname, image_path, files){
function images_history_get_current_img(tabname, img_index, files){
return [
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
image_path,
tabname,
gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
files
];
}
function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
function images_history_delete(del_num, tabname, image_index){
image_index = parseInt(image_index);
var tab = gradioApp().getElementById(tabname + '_images_history');
var set_btn = tab.querySelector(".images_history_set_index");
......@@ -107,6 +98,7 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
}
});
var img_num = buttons.length / 2;
del_num = Math.min(img_num - image_index, del_num)
if (img_num <= del_num){
setTimeout(function(tabname){
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
......@@ -114,30 +106,28 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
} else {
var next_img
for (var i = 0; i < del_num; i++){
if (image_index + i < image_index + img_num){
buttons[image_index + i].style.display = 'none';
buttons[image_index + img_num + 1].style.display = 'none';
next_img = image_index + i + 1
}
buttons[image_index + i].style.display = 'none';
buttons[image_index + i + img_num].style.display = 'none';
next_img = image_index + i + 1
}
var bnt;
if (next_img >= img_num){
btn = buttons[image_index - del_num];
btn = buttons[image_index - 1];
} else {
btn = buttons[next_img];
}
setTimeout(function(btn){btn.click()}, 30, btn);
}
images_history_disabled_del();
return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
}
function images_history_turnpage(img_path, page_index, image_index, tabname){
function images_history_turnpage(tabname){
gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled');
var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
buttons.forEach(function(elem) {
elem.style.display = 'block';
})
return [img_path, page_index, image_index, tabname];
})
}
function images_history_enable_del_buttons(){
......@@ -147,60 +137,64 @@ function images_history_enable_del_buttons(){
}
function images_history_init(){
var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
if (load_txt2img_button){
var tabnames = gradioApp().getElementById("images_history_tabnames_list")
if (tabnames){
images_history_tab_list = tabnames.querySelector("textarea").value.split(",")
for (var i in images_history_tab_list ){
tab = images_history_tab_list[i];
var tab = images_history_tab_list[i];
gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px");
}
//preload
if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
tabs_box.setAttribute("id", "images_history_tab");
var tab_btns = tabs_box.querySelectorAll("button");
for (var i in images_history_tab_list){
var tabname = images_history_tab_list[i]
tab_btns[i].setAttribute("tabname", tabname);
tab_btns[i].addEventListener('click', function(){
var tabs_box = gradioApp().getElementById("images_history_tab");
if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click();
tabs_box.classList.add(this.getAttribute("tabname"))
}
});
}
tab_btns[0].click()
}
var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
tabs_box.setAttribute("id", "images_history_tab");
var tab_btns = tabs_box.querySelectorAll("button");
for (var i in images_history_tab_list){
var tabname = images_history_tab_list[i]
tab_btns[i].setAttribute("tabname", tabname);
// this refreshes history upon tab switch
// until the history is known to work well, which is not the case now, we do not do this at startup
//tab_btns[i].addEventListener('click', images_history_click_tab);
}
tabs_box.classList.add(images_history_tab_list[0]);
// same as above, at page load
//load_txt2img_button.click();
} else {
setTimeout(images_history_init, 500);
}
}
var images_history_tab_list = ["txt2img", "img2img", "extras"];
var images_history_tab_list = "";
setTimeout(images_history_init, 500);
document.addEventListener("DOMContentLoaded", function() {
var mutationObserver = new MutationObserver(function(m){
for (var i in images_history_tab_list ){
let tabname = images_history_tab_list[i]
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
buttons.forEach(function(bnt){
bnt.addEventListener('click', images_history_click_image, true);
});
// same as load_txt2img_button.click() above
/*
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
if (cls_btn){
cls_btn.addEventListener('click', function(){
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
}, false);
}*/
}
if (images_history_tab_list != ""){
for (var i in images_history_tab_list ){
let tabname = images_history_tab_list[i]
var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
buttons.forEach(function(bnt){
bnt.addEventListener('click', images_history_click_image, true);
});
var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
if (cls_btn){
cls_btn.addEventListener('click', function(){
gradioApp().getElementById(tabname + '_images_history_renew_page').click();
}, false);
}
}
}
});
mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
mutationObserver.observe(gradioApp(), { childList:true, subtree:true });
});
import sys, os, shlex
import contextlib
import torch
from modules import errors
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
......@@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False)
cpu = torch.device("cpu")
def extract_device_id(args, name):
for x in range(len(args)):
if name in args[x]: return args[x+1]
return None
def get_optimal_device():
if torch.cuda.is_available():
return torch.device("cuda")
from modules import shared
device_id = shared.cmd_opts.device_id
if device_id is not None:
cuda_device = f"cuda:{device_id}"
return torch.device(cuda_device)
else:
return torch.device("cuda")
if has_mps:
return torch.device("mps")
......@@ -34,7 +45,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
......
......@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list:
image = Image.open(img)
try:
image = Image.open(img)
except Exception:
continue
imageArr.append(image)
imageNameArr.append(img)
else:
......@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
forced_filename=image_name if opts.use_original_name_batch else None)
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
basename = ''
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
......
......@@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path
from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
def quote(text):
if ',' not in str(text):
return text
text = str(text)
text = text.replace('\\', '\\\\')
text = text.replace('"', '\\"')
return f'"{text}"'
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
......@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
else:
try:
valtype = type(output.value)
val = valtype(v)
if valtype == bool and v == "False":
val = False
else:
val = valtype(v)
res.append(gr.update(value=val))
except Exception:
res.append(gr.update())
......
......@@ -41,12 +41,12 @@ class HypernetworkModule(torch.nn.Module):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
# Add an activation func
if activation_func == "linear":
if activation_func == "linear" or activation_func is None:
pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else:
raise NotImplementedError(
raise RuntimeError(
"Valid activation funcs: 'linear', 'relu', 'leakyrelu', 'elu', 'swish'"
)
......@@ -65,7 +65,7 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
if isinstance(layer, torch.nn.Linear):
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
......@@ -93,7 +93,7 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
if isinstance(layer, torch.nn.Linear):
if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
layer_structure += [layer.weight, layer.bias]
return layer_structure
......@@ -272,6 +272,9 @@ def stack_conds(conds):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
......@@ -314,6 +317,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
......@@ -353,7 +357,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
# Before saving, change name to match current checkpoint.
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
......@@ -362,7 +368,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
......@@ -398,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if image is not None:
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
......@@ -408,7 +415,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
......@@ -417,6 +424,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
hypernetwork.name = hypernetwork_name
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(filename)
return hypernetwork, filename
......
......@@ -9,9 +9,13 @@ from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
if type(layer_structure) == str:
layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
......
import os
import shutil
import sys
import time
import hashlib
import gradio
system_bak_path = "webui_log_and_bak"
custom_tab_name = "custom fold"
faverate_tab_name = "favorites"
tabs_list = ["txt2img", "img2img", "extras", faverate_tab_name]
def is_valid_date(date):
try:
time.strptime(date, "%Y%m%d")
return True
except:
return False
def reduplicative_file_move(src, dst):
def same_name_file(basename, path):
name, ext = os.path.splitext(basename)
f_list = os.listdir(path)
max_num = 0
for f in f_list:
if len(f) <= len(basename):
continue
f_ext = f[-len(ext):] if len(ext) > 0 else ""
if f[:len(name)] == name and f_ext == ext:
if f[len(name)] == "(" and f[-len(ext)-1] == ")":
number = f[len(name)+1:-len(ext)-1]
if number.isdigit():
if int(number) > max_num:
max_num = int(number)
return f"{name}({max_num + 1}){ext}"
name = os.path.basename(src)
save_name = os.path.join(dst, name)
if not os.path.exists(save_name):
shutil.move(src, dst)
else:
name = same_name_file(name, dst)
shutil.move(src, os.path.join(dst, name))
def traverse_all_files(output_dir, image_list, curr_dir=None):
curr_path = output_dir if curr_dir is None else os.path.join(output_dir, curr_dir)
def traverse_all_files(curr_path, image_list, all_type=False):
try:
f_list = os.listdir(curr_path)
except:
if curr_dir[-10:].rfind(".") > 0 and curr_dir[-4:] != ".txt":
image_list.append(curr_dir)
if all_type or (curr_path[-10:].rfind(".") > 0 and curr_path[-4:] != ".txt" and curr_path[-4:] != ".csv"):
image_list.append(curr_path)
return image_list
for file in f_list:
file = file if curr_dir is None else os.path.join(curr_dir, file)
file_path = os.path.join(curr_path, file)
if file[-4:] == ".txt":
file = os.path.join(curr_path, file)
if (not all_type) and (file[-4:] == ".txt" or file[-4:] == ".csv"):
pass
elif os.path.isfile(file_path) and file[-10:].rfind(".") > 0:
elif os.path.isfile(file) and file[-10:].rfind(".") > 0:
image_list.append(file)
else:
image_list = traverse_all_files(output_dir, image_list, file)
image_list = traverse_all_files(file, image_list)
return image_list
def auto_sorting(dir_name):
bak_path = os.path.join(dir_name, system_bak_path)
if not os.path.exists(bak_path):
os.mkdir(bak_path)
log_file = None
files_list = []
f_list = os.listdir(dir_name)
for file in f_list:
if file == system_bak_path:
continue
file_path = os.path.join(dir_name, file)
if not is_valid_date(file):
if file[-10:].rfind(".") > 0:
files_list.append(file_path)
else:
files_list = traverse_all_files(file_path, files_list, all_type=True)
def get_recent_images(dir_name, page_index, step, image_index, tabname):
page_index = int(page_index)
image_list = []
if not os.path.exists(dir_name):
pass
elif os.path.isdir(dir_name):
image_list = traverse_all_files(dir_name, image_list)
image_list = sorted(image_list, key=lambda file: -os.path.getctime(os.path.join(dir_name, file)))
else:
print(f'ERROR: "{dir_name}" is not a directory. Check the path in the settings.', file=sys.stderr)
num = 48 if tabname != "extras" else 12
max_page_index = len(image_list) // num + 1
page_index = max_page_index if page_index == -1 else page_index + step
page_index = 1 if page_index < 1 else page_index
page_index = max_page_index if page_index > max_page_index else page_index
idx_frm = (page_index - 1) * num
image_list = image_list[idx_frm:idx_frm + num]
image_index = int(image_index)
if image_index < 0 or image_index > len(image_list) - 1:
current_file = None
hidden = None
else:
current_file = image_list[int(image_index)]
hidden = os.path.join(dir_name, current_file)
return [os.path.join(dir_name, file) for file in image_list], page_index, image_list, current_file, hidden, ""
def first_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, 1, 0, image_index, tabname)
def end_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, -1, 0, image_index, tabname)
def prev_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, -1, image_index, tabname)
def next_page_click(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, 1, image_index, tabname)
def page_index_change(dir_name, page_index, image_index, tabname):
return get_recent_images(dir_name, page_index, 0, image_index, tabname)
def show_image_info(num, image_path, filenames):
# print(f"select image {num}")
file = filenames[int(num)]
return file, num, os.path.join(image_path, file)
for file in files_list:
date_str = time.strftime("%Y%m%d",time.localtime(os.path.getmtime(file)))
file_path = os.path.dirname(file)
hash_path = hashlib.md5(file_path.encode()).hexdigest()
path = os.path.join(dir_name, date_str, hash_path)
if not os.path.exists(path):
os.makedirs(path)
if log_file is None:
log_file = open(os.path.join(bak_path,"path_mapping.csv"),"a")
log_file.write(f"{hash_path},{file_path}\n")
reduplicative_file_move(file, path)
date_list = []
f_list = os.listdir(dir_name)
for f in f_list:
if is_valid_date(f):
date_list.append(f)
elif f == system_bak_path:
continue
else:
try:
reduplicative_file_move(os.path.join(dir_name, f), bak_path)
except:
pass
today = time.strftime("%Y%m%d",time.localtime(time.time()))
if today not in date_list:
date_list.append(today)
return sorted(date_list, reverse=True)
def archive_images(dir_name, date_to):
filenames = []
batch_size =int(opts.images_history_num_per_page * opts.images_history_pages_num)
if batch_size <= 0:
batch_size = opts.images_history_num_per_page * 6
today = time.strftime("%Y%m%d",time.localtime(time.time()))
date_to = today if date_to is None or date_to == "" else date_to
date_to_bak = date_to
if False: #opts.images_history_reconstruct_directory:
date_list = auto_sorting(dir_name)
for date in date_list:
if date <= date_to:
path = os.path.join(dir_name, date)
if date == today and not os.path.exists(path):
continue
filenames = traverse_all_files(path, filenames)
if len(filenames) > batch_size:
break
filenames = sorted(filenames, key=lambda file: -os.path.getmtime(file))
else:
filenames = traverse_all_files(dir_name, filenames)
total_num = len(filenames)
tmparray = [(os.path.getmtime(file), file) for file in filenames ]
date_stamp = time.mktime(time.strptime(date_to, "%Y%m%d")) + 86400
filenames = []
date_list = {date_to:None}
date = time.strftime("%Y%m%d",time.localtime(time.time()))
for t, f in tmparray:
date = time.strftime("%Y%m%d",time.localtime(t))
date_list[date] = None
if t <= date_stamp:
filenames.append((t, f ,date))
date_list = sorted(list(date_list.keys()), reverse=True)
sort_array = sorted(filenames, key=lambda x:-x[0])
if len(sort_array) > batch_size:
date = sort_array[batch_size][2]
filenames = [x[1] for x in sort_array]
else:
date = date_to if len(sort_array) == 0 else sort_array[-1][2]
filenames = [x[1] for x in sort_array]
filenames = [x[1] for x in sort_array if x[2]>= date]
num = len(filenames)
last_date_from = date_to_bak if num == 0 else time.strftime("%Y%m%d", time.localtime(time.mktime(time.strptime(date, "%Y%m%d")) - 1000))
date = date[:4] + "/" + date[4:6] + "/" + date[6:8]
date_to_bak = date_to_bak[:4] + "/" + date_to_bak[4:6] + "/" + date_to_bak[6:8]
load_info = "<div style='color:#999' align='center'>"
load_info += f"{total_num} images in this directory. Loaded {num} images during {date} - {date_to_bak}, divided into {int((num + 1) // opts.images_history_num_per_page + 1)} pages"
load_info += "</div>"
_, image_list, _, _, visible_num = get_recent_images(1, 0, filenames)
return (
date_to,
load_info,
filenames,
1,
image_list,
"",
"",
visible_num,
last_date_from,
gradio.update(visible=total_num > num)
)
def delete_image(delete_num, tabname, dir_name, name, page_index, filenames, image_index):
def delete_image(delete_num, name, filenames, image_index, visible_num):
if name == "":
return filenames, delete_num
else:
delete_num = int(delete_num)
visible_num = int(visible_num)
image_index = int(image_index)
index = list(filenames).index(name)
i = 0
new_file_list = []
for name in filenames:
if i >= index and i < index + delete_num:
path = os.path.join(dir_name, name)
if os.path.exists(path):
print(f"Delete file {path}")
os.remove(path)
txt_file = os.path.splitext(path)[0] + ".txt"
if os.path.exists(name):
if visible_num == image_index:
new_file_list.append(name)
i += 1
continue
print(f"Delete file {name}")
os.remove(name)
visible_num -= 1
txt_file = os.path.splitext(name)[0] + ".txt"
if os.path.exists(txt_file):
os.remove(txt_file)
else:
print(f"Not exists file {path}")
print(f"Not exists file {name}")
else:
new_file_list.append(name)
i += 1
return new_file_list, 1
return new_file_list, 1, visible_num
def save_image(file_name):
if file_name is not None and os.path.exists(file_name):
shutil.copy(file_name, opts.outdir_save)
def get_recent_images(page_index, step, filenames):
page_index = int(page_index)
num_of_imgs_per_page = int(opts.images_history_num_per_page)
max_page_index = len(filenames) // num_of_imgs_per_page + 1
page_index = max_page_index if page_index == -1 else page_index + step
page_index = 1 if page_index < 1 else page_index
page_index = max_page_index if page_index > max_page_index else page_index
idx_frm = (page_index - 1) * num_of_imgs_per_page
image_list = filenames[idx_frm:idx_frm + num_of_imgs_per_page]
length = len(filenames)
visible_num = num_of_imgs_per_page if idx_frm + num_of_imgs_per_page <= length else length % num_of_imgs_per_page
visible_num = num_of_imgs_per_page if visible_num == 0 else visible_num
return page_index, image_list, "", "", visible_num
def loac_batch_click(date_to):
if date_to is None:
return time.strftime("%Y%m%d",time.localtime(time.time())), []
else:
return None, []
def forward_click(last_date_from, date_to_recorder):
if len(date_to_recorder) == 0:
return None, []
if last_date_from == date_to_recorder[-1]:
date_to_recorder = date_to_recorder[:-1]
if len(date_to_recorder) == 0:
return None, []
return date_to_recorder[-1], date_to_recorder[:-1]
def backward_click(last_date_from, date_to_recorder):
if last_date_from is None or last_date_from == "":
return time.strftime("%Y%m%d",time.localtime(time.time())), []
if len(date_to_recorder) == 0 or last_date_from != date_to_recorder[-1]:
date_to_recorder.append(last_date_from)
return last_date_from, date_to_recorder
def first_page_click(page_index, filenames):
return get_recent_images(1, 0, filenames)
def end_page_click(page_index, filenames):
return get_recent_images(-1, 0, filenames)
def prev_page_click(page_index, filenames):
return get_recent_images(page_index, -1, filenames)
def next_page_click(page_index, filenames):
return get_recent_images(page_index, 1, filenames)
def page_index_change(page_index, filenames):
return get_recent_images(page_index, 0, filenames)
def show_image_info(tabname_box, num, page_index, filenames):
file = filenames[int(num) + int((page_index - 1) * int(opts.images_history_num_per_page))]
tm = "<div style='color:#999' align='right'>" + time.strftime("%Y-%m-%d %H:%M:%S",time.localtime(os.path.getmtime(file))) + "</div>"
return file, tm, num, file
def enable_page_buttons():
return gradio.update(visible=True)
def change_dir(img_dir, date_to):
warning = None
try:
if os.path.exists(img_dir):
try:
f = os.listdir(img_dir)
except:
warning = f"'{img_dir} is not a directory"
else:
warning = "The directory is not exist"
except:
warning = "The format of the directory is incorrect"
if warning is None:
today = time.strftime("%Y%m%d",time.localtime(time.time()))
return gradio.update(visible=False), gradio.update(visible=True), None, None if date_to != today else today, gradio.update(visible=True), gradio.update(visible=True)
else:
return gradio.update(visible=True), gradio.update(visible=False), warning, date_to, gradio.update(visible=False), gradio.update(visible=False)
def show_images_history(gr, opts, tabname, run_pnginfo, switch_dict):
if opts.outdir_samples != "":
dir_name = opts.outdir_samples
elif tabname == "txt2img":
custom_dir = False
if tabname == "txt2img":
dir_name = opts.outdir_txt2img_samples
elif tabname == "img2img":
dir_name = opts.outdir_img2img_samples
elif tabname == "extras":
dir_name = opts.outdir_extras_samples
elif tabname == faverate_tab_name:
dir_name = opts.outdir_save
else:
return
with gr.Row():
renew_page = gr.Button('Renew Page', elem_id=tabname + "_images_history_renew_page")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
next_page = gr.Button('Next Page')
end_page = gr.Button('End Page')
with gr.Row(elem_id=tabname + "_images_history"):
with gr.Row():
with gr.Column(scale=2):
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=6)
with gr.Row():
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
with gr.Column():
with gr.Row():
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
pnginfo_send_to_img2img = gr.Button('Send to img2img')
with gr.Row():
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False)
img_file_name = gr.Textbox(label="File Name", interactive=False)
with gr.Row():
custom_dir = True
dir_name = None
if not custom_dir:
d = dir_name.split("/")
dir_name = d[0]
for p in d[1:]:
dir_name = os.path.join(dir_name, p)
if not os.path.exists(dir_name):
os.makedirs(dir_name)
with gr.Column() as page_panel:
with gr.Row():
with gr.Column(scale=1, visible=not custom_dir) as load_batch_box:
load_batch = gr.Button('Load', elem_id=tabname + "_images_history_start", full_width=True)
with gr.Column(scale=4):
with gr.Row():
img_path = gr.Textbox(dir_name, label="Images directory", placeholder="Input images directory", interactive=custom_dir)
with gr.Row():
with gr.Column(visible=False, scale=1) as batch_panel:
with gr.Row():
forward = gr.Button('Prev batch')
backward = gr.Button('Next batch')
with gr.Column(scale=3):
load_info = gr.HTML(visible=not custom_dir)
with gr.Row(visible=False) as warning:
warning_box = gr.Textbox("Message", interactive=False)
with gr.Row(visible=not custom_dir, elem_id=tabname + "_images_history") as main_panel:
with gr.Column(scale=2):
with gr.Row(visible=True) as turn_page_buttons:
#date_to = gr.Dropdown(label="Date to")
first_page = gr.Button('First Page')
prev_page = gr.Button('Prev Page')
page_index = gr.Number(value=1, label="Page Index")
next_page = gr.Button('Next Page')
end_page = gr.Button('End Page')
history_gallery = gr.Gallery(show_label=False, elem_id=tabname + "_images_history_gallery").style(grid=opts.images_history_grid_num)
with gr.Row():
delete_num = gr.Number(value=1, interactive=True, label="number of images to delete consecutively next")
delete = gr.Button('Delete', elem_id=tabname + "_images_history_del_button")
with gr.Column():
with gr.Row():
with gr.Column():
img_file_info = gr.Textbox(label="Generate Info", interactive=False, lines=6)
gr.HTML("<hr>")
img_file_name = gr.Textbox(value="", label="File Name", interactive=False)
img_file_time= gr.HTML()
with gr.Row():
if tabname != faverate_tab_name:
save_btn = gr.Button('Collect')
pnginfo_send_to_txt2img = gr.Button('Send to txt2img')
pnginfo_send_to_img2img = gr.Button('Send to img2img')
# hiden items
with gr.Row(visible=False):
renew_page = gr.Button('Refresh page', elem_id=tabname + "_images_history_renew_page")
batch_date_to = gr.Textbox(label="Date to")
visible_img_num = gr.Number()
date_to_recorder = gr.State([])
last_date_from = gr.Textbox()
tabname_box = gr.Textbox(tabname)
image_index = gr.Textbox(value=-1)
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index")
filenames = gr.State()
all_images_list = gr.State()
hidden = gr.Image(type="pil")
info1 = gr.Textbox()
info2 = gr.Textbox()
img_path = gr.Textbox(dir_name.rstrip("/"), visible=False)
tabname_box = gr.Textbox(tabname, visible=False)
image_index = gr.Textbox(value=-1, visible=False)
set_index = gr.Button('set_index', elem_id=tabname + "_images_history_set_index", visible=False)
filenames = gr.State()
hidden = gr.Image(type="pil", visible=False)
info1 = gr.Textbox(visible=False)
info2 = gr.Textbox(visible=False)
# turn pages
gallery_inputs = [img_path, page_index, image_index, tabname_box]
gallery_outputs = [history_gallery, page_index, filenames, img_file_name, hidden, img_file_name]
first_page.click(first_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
end_page.click(end_page_click, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
page_index.submit(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
renew_page.click(page_index_change, _js="images_history_turnpage", inputs=gallery_inputs, outputs=gallery_outputs)
# page_index.change(page_index_change, inputs=[tabname_box, img_path, page_index], outputs=[history_gallery, page_index])
img_path.submit(change_dir, inputs=[img_path, batch_date_to], outputs=[warning, main_panel, warning_box, batch_date_to, load_batch_box, load_info])
#change batch
change_date_output = [batch_date_to, load_info, filenames, page_index, history_gallery, img_file_name, img_file_time, visible_img_num, last_date_from, batch_panel]
batch_date_to.change(archive_images, inputs=[img_path, batch_date_to], outputs=change_date_output)
batch_date_to.change(enable_page_buttons, inputs=None, outputs=[turn_page_buttons])
batch_date_to.change(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
load_batch.click(loac_batch_click, inputs=[batch_date_to], outputs=[batch_date_to, date_to_recorder])
forward.click(forward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
backward.click(backward_click, inputs=[last_date_from, date_to_recorder], outputs=[batch_date_to, date_to_recorder])
#delete
delete.click(delete_image, inputs=[delete_num, img_file_name, filenames, image_index, visible_img_num], outputs=[filenames, delete_num, visible_img_num])
delete.click(fn=None, _js="images_history_delete", inputs=[delete_num, tabname_box, image_index], outputs=None)
if tabname != faverate_tab_name:
save_btn.click(save_image, inputs=[img_file_name], outputs=None)
#turn page
gallery_inputs = [page_index, filenames]
gallery_outputs = [page_index, history_gallery, img_file_name, img_file_time, visible_img_num]
first_page.click(first_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
next_page.click(next_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
prev_page.click(prev_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
end_page.click(end_page_click, inputs=gallery_inputs, outputs=gallery_outputs)
page_index.submit(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
renew_page.click(page_index_change, inputs=gallery_inputs, outputs=gallery_outputs)
first_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
next_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
prev_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
end_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
page_index.submit(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
renew_page.click(fn=None, inputs=[tabname_box], outputs=None, _js="images_history_turnpage")
# other funcitons
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, img_path, filenames], outputs=[img_file_name, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
delete.click(delete_image, _js="images_history_delete", inputs=[delete_num, tabname_box, img_path, img_file_name, page_index, filenames, image_index], outputs=[filenames, delete_num])
set_index.click(show_image_info, _js="images_history_get_current_img", inputs=[tabname_box, image_index, page_index, filenames], outputs=[img_file_name, img_file_time, image_index, hidden])
img_file_name.change(fn=None, _js="images_history_enable_del_buttons", inputs=None, outputs=None)
hidden.change(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
# pnginfo.click(fn=run_pnginfo, inputs=[hidden], outputs=[info1, img_file_info, info2])
switch_dict["fn"](pnginfo_send_to_txt2img, switch_dict["t2i"], img_file_info, 'switch_to_txt2img')
switch_dict["fn"](pnginfo_send_to_img2img, switch_dict["i2i"], img_file_info, 'switch_to_img2img_img2img')
def create_history_tabs(gr, opts, run_pnginfo, switch_dict):
def create_history_tabs(gr, sys_opts, cmp_ops, run_pnginfo, switch_dict):
global opts;
opts = sys_opts
loads_files_num = int(opts.images_history_num_per_page)
num_of_imgs_per_page = int(opts.images_history_num_per_page * opts.images_history_pages_num)
if cmp_ops.browse_all_images:
tabs_list.append(custom_tab_name)
with gr.Blocks(analytics_enabled=False) as images_history:
with gr.Tabs() as tabs:
with gr.Tab("txt2img history"):
with gr.Blocks(analytics_enabled=False) as images_history_txt2img:
show_images_history(gr, opts, "txt2img", run_pnginfo, switch_dict)
with gr.Tab("img2img history"):
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
show_images_history(gr, opts, "img2img", run_pnginfo, switch_dict)
with gr.Tab("extras history"):
with gr.Blocks(analytics_enabled=False) as images_history_img2img:
show_images_history(gr, opts, "extras", run_pnginfo, switch_dict)
for tab in tabs_list:
with gr.Tab(tab):
with gr.Blocks(analytics_enabled=False) :
show_images_history(gr, opts, tab, run_pnginfo, switch_dict)
gradio.Checkbox(opts.images_history_preload, elem_id="images_history_preload", visible=False)
gradio.Textbox(",".join(tabs_list), elem_id="images_history_tabnames_list", visible=False)
return images_history
......@@ -109,6 +109,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert,
)
p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
......
......@@ -28,9 +28,11 @@ class InterrogateModels:
clip_preprocess = None
categories = None
dtype = None
running_on_cpu = None
def __init__(self, content_dir):
self.categories = []
self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
if os.path.exists(content_dir):
for filename in os.listdir(content_dir):
......@@ -53,7 +55,11 @@ class InterrogateModels:
def load_clip_model(self):
import clip
model, preprocess = clip.load(clip_model_name)
if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu")
else:
model, preprocess = clip.load(clip_model_name)
model.eval()
model = model.to(devices.device_interrogate)
......@@ -62,14 +68,14 @@ class InterrogateModels:
def load(self):
if self.blip_model is None:
self.blip_model = self.load_blip_model()
if not shared.cmd_opts.no_half:
if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.blip_model = self.blip_model.half()
self.blip_model = self.blip_model.to(devices.device_interrogate)
if self.clip_model is None:
self.clip_model, self.clip_preprocess = self.load_clip_model()
if not shared.cmd_opts.no_half:
if not shared.cmd_opts.no_half and not self.running_on_cpu:
self.clip_model = self.clip_model.half()
self.clip_model = self.clip_model.to(devices.device_interrogate)
......
import torch
from modules.devices import get_optimal_device
from modules import devices
module_in_gpu = None
cpu = torch.device("cpu")
device = gpu = get_optimal_device()
def send_everything_to_cpu():
......@@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module.to(gpu)
module.to(devices.device)
module_in_gpu = module
# see below for register_forward_pre_hook;
......@@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
sd_model.to(device)
sd_model.to(devices.device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
......@@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
sd_model.model.to(devices.device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
......
......@@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
......@@ -104,6 +104,12 @@ class StableDiffusionProcessing():
self.seed_resize_from_h = 0
self.seed_resize_from_w = 0
self.scripts = None
self.script_args = None
self.all_prompts = None
self.all_seeds = None
self.all_subseeds = None
def init(self, all_prompts, all_seeds, all_subseeds):
pass
......@@ -304,7 +310,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.filename.split('\\')[-1].split('.')[0]),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
......@@ -318,7 +324,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""
......@@ -350,32 +356,35 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
shared.prompt_styles.apply_styles(p)
if type(p.prompt) == list:
all_prompts = p.prompt
p.all_prompts = p.prompt
else:
all_prompts = p.batch_size * p.n_iter * [p.prompt]
p.all_prompts = p.batch_size * p.n_iter * [p.prompt]
if type(seed) == list:
all_seeds = seed
p.all_seeds = seed
else:
all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(all_prompts))]
p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))]
if type(subseed) == list:
all_subseeds = subseed
p.all_subseeds = subseed
else:
all_subseeds = [int(subseed) + x for x in range(len(all_prompts))]
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
if p.scripts is not None:
p.scripts.run_alwayson_scripts(p)
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(all_prompts, all_seeds, all_subseeds)
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
if state.job_count == -1:
state.job_count = p.n_iter
......@@ -387,9 +396,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
break
......@@ -490,10 +499,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", all_seeds[0], all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
devices.torch_gc()
return Processed(p, output_images, all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=all_subseeds[0], all_prompts=all_prompts, all_seeds=all_seeds, all_subseeds=all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
......@@ -540,17 +549,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
......@@ -587,7 +616,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
......@@ -613,6 +642,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
......@@ -714,10 +744,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
self.init_latent.shape[0], 5, 1, 1,
dtype=self.init_latent.dtype,
device=self.init_latent.device
)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
......
callbacks_model_loaded = []
callbacks_ui_tabs = []
def clear_callbacks():
callbacks_model_loaded.clear()
callbacks_ui_tabs.clear()
def model_loaded_callback(sd_model):
for callback in callbacks_model_loaded:
callback(sd_model)
def ui_tabs_callback():
res = []
for callback in callbacks_ui_tabs:
res += callback() or []
return res
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
callbacks_model_loaded.append(callback)
def on_ui_tabs(callback):
"""register a function to be called when the UI is creating new tabs.
The function must either return a None, which means no new tabs to be added, or a list, where
each element is a tuple:
(gradio_component, title, elem_id)
gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
title is tab text displayed to user in the UI
elem_id is HTML id for the tab
"""
callbacks_ui_tabs.append(callback)
import os
import sys
import traceback
from collections import namedtuple
import modules.ui as ui
import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import shared
from modules import shared, paths, script_callbacks
AlwaysVisible = object()
class Script:
filename = None
args_from = None
args_to = None
alwayson = False
infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
"""
# The title of the script. This is what will be displayed in the dropdown menu.
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
raise NotImplementedError()
# How the script is displayed in the UI. See https://gradio.app/docs/#components
# for the different UI components you can use and how to create them.
# Most UI components can return a value, such as a boolean for a checkbox.
# The returned values are passed to the run method as parameters.
def ui(self, is_img2img):
"""this function should create gradio UI elements. See https://gradio.app/docs/#components
The return value should be an array of all components that are used in processing.
Values of those returned componenbts will be passed to run() and process() functions.
"""
pass
# Determines when the script should be shown in the dropdown menu via the
# returned value. As an example:
# is_img2img is True if the current tab is img2img, and False if it is txt2img.
# Thus, return is_img2img to only show the script on the img2img tab.
def show(self, is_img2img):
"""
is_img2img is True if this function is called for the img2img interface, and Fasle otherwise
This function should return:
- False if the script should not be shown in UI at all
- True if the script should be shown in UI if it's scelected in the scripts drowpdown
- script.AlwaysVisible if the script should be shown in UI at all times
"""
return True
# This is where the additional processing is implemented. The parameters include
# self, the model object "p" (a StableDiffusionProcessing class, see
# processing.py), and the parameters returned by the ui method.
# Custom functions can be defined here, and additional libraries can be imported
# to be used in processing. The return value should be a Processed object, which is
# what is returned by the process_images method.
def run(self, *args):
def run(self, p, *args):
"""
This function is called if the script has been selected in the script dropdown.
It must do all processing and return the Processed object with results, same as
one returned by processing.process_images.
Usually the processing is done by calling the processing.process_images function.
args contains all values returned by components from ui()
"""
raise NotImplementedError()
# The description method is currently unused.
# To add a description that appears when hovering over the title, amend the "titles"
# dict in script.js to include the script title (returned by title) as a key, and
# your description as the value.
def process(self, p, *args):
"""
This function is called before processing begins for AlwaysVisible scripts.
scripts. You can modify the processing object (p) here, inject hooks, etc.
"""
pass
def describe(self):
"""unused"""
return ""
current_basedir = paths.script_path
def basedir():
"""returns the base directory for the current script. For scripts in the main scripts directory,
this is the main directory (where webui.py resides), and for scripts in extensions directory
(ie extensions/aesthetic/script/aesthetic.py), this is extension's directory (extensions/aesthetic)
"""
return current_basedir
scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
def load_scripts(basedir):
if not os.path.exists(basedir):
return
def list_scripts(scriptdirname, extension):
scripts_list = []
for filename in sorted(os.listdir(basedir)):
path = os.path.join(basedir, filename)
basedir = os.path.join(paths.script_path, scriptdirname)
if os.path.exists(basedir):
for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
if os.path.splitext(path)[1].lower() != '.py':
continue
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
for dirname in sorted(os.listdir(extdir)):
dirpath = os.path.join(extdir, dirname)
scriptdirpath = os.path.join(dirpath, scriptdirname)
if not os.path.isdir(scriptdirpath):
continue
for filename in sorted(os.listdir(scriptdirpath)):
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return scripts_list
def list_files_with_name(filename):
res = []
dirs = [paths.script_path]
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
if not os.path.isfile(path):
for dirpath in dirs:
if not os.path.isdir(dirpath):
continue
path = os.path.join(dirpath, filename)
if os.path.isfile(filename):
res.append(path)
return res
def load_scripts():
global current_basedir
scripts_data.clear()
script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py")
syspath = sys.path
for scriptfile in sorted(scripts_list):
try:
with open(path, "r", encoding="utf8") as file:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
with open(scriptfile.path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
compiled = compile(text, path, 'exec')
module = ModuleType(filename)
compiled = compile(text, scriptfile.path, 'exec')
module = ModuleType(scriptfile.filename)
exec(compiled, module.__dict__)
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
scripts_data.append((script_class, path))
scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
except Exception:
print(f"Error loading script: {filename}", file=sys.stderr)
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
finally:
sys.path = syspath
current_basedir = paths.script_path
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
......@@ -96,56 +185,80 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
class ScriptRunner:
def __init__(self):
self.scripts = []
self.selectable_scripts = []
self.alwayson_scripts = []
self.titles = []
self.infotext_fields = []
def setup_ui(self, is_img2img):
for script_class, path in scripts_data:
for script_class, path, basedir in scripts_data:
script = script_class()
script.filename = path
if not script.show(is_img2img):
continue
visibility = script.show(is_img2img)
self.scripts.append(script)
if visibility == AlwaysVisible:
self.scripts.append(script)
self.alwayson_scripts.append(script)
script.alwayson = True
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
elif visibility:
self.scripts.append(script)
self.selectable_scripts.append(script)
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs = [dropdown]
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
inputs = [None]
inputs_alwayson = [True]
for script in self.scripts:
def create_script_ui(script, inputs, inputs_alwayson):
script.args_from = len(inputs)
script.args_to = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
if controls is None:
continue
return
for control in controls:
control.custom_script_source = os.path.basename(script.filename)
control.visible = False
if not script.alwayson:
control.visible = False
if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields
inputs += controls
inputs_alwayson += [script.alwayson for _ in controls]
script.args_to = len(inputs)
for script in self.alwayson_scripts:
with gr.Group():
create_script_ui(script, inputs, inputs_alwayson)
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
create_script_ui(script, inputs, inputs_alwayson)
def select_script(script_index):
if 0 < script_index <= len(self.scripts):
script = self.scripts[script_index-1]
if 0 < script_index <= len(self.selectable_scripts):
script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)]
def init_field(title):
if title == 'None':
return
script_index = self.titles.index(title)
script = self.scripts[script_index]
script = self.selectable_scripts[script_index]
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
......@@ -164,7 +277,7 @@ class ScriptRunner:
if script_index == 0:
return None
script = self.scripts[script_index-1]
script = self.selectable_scripts[script_index-1]
if script is None:
return None
......@@ -176,7 +289,16 @@ class ScriptRunner:
return processed
def reload_sources(self):
def run_alwayson_scripts(self, p):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
with open(script.filename, "r", encoding="utf8") as file:
args_from = script.args_from
......@@ -186,9 +308,12 @@ class ScriptRunner:
from types import ModuleType
compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
module = cache.get(filename, None)
if module is None:
compiled = compile(text, filename, 'exec')
module = ModuleType(script.filename)
exec(compiled, module.__dict__)
cache[filename] = module
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
......@@ -197,19 +322,22 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
def reload_script_body_only():
scripts_txt2img.reload_sources()
scripts_img2img.reload_sources()
cache = {}
scripts_txt2img.reload_sources(cache)
scripts_img2img.reload_sources(cache)
def reload_scripts(basedir):
def reload_scripts():
global scripts_txt2img, scripts_img2img
scripts_data.clear()
load_scripts(basedir)
load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
......@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations():
undo_optimizations()
......@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
remade_tokens = remade_tokens[:last_comma]
length = len(remade_tokens)
rem = int(math.ceil(length / 75)) * 75 - length
remade_tokens += [id_end] * rem + reloc_tokens
multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
......@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
......@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
......@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def forward(self, text):
use_old = opts.use_old_emphasis_implementation
if use_old:
......@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
if use_old:
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
z = None
i = 0
while max(map(len, remade_batch_tokens)) != 0:
......@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if fix[0] == i:
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
......@@ -333,19 +333,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers
i += 1
return z
def process_tokens(self, remade_batch_tokens, batch_multipliers):
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
......@@ -385,8 +384,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
vecs.append(tensor)
......
import torch
from einops import repeat
from omegaconf import ListConfig
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# =================================================================================================
# Monkey patch DDIMSampler methods from RunwayML repo directly.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
def sample_ddim(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
# =================================================================================================
# Monkey patch PLMSSampler methods.
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
# Adapted from:
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
# =================================================================================================
@torch.no_grad()
def sample_plms(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
# =================================================================================================
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
def should_hijack_inpainting(checkpoint_info):
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
......@@ -7,8 +7,9 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices
from modules import shared, modelloader, devices, script_callbacks
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
......@@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
from transformers import logging, CLIPModel
logging.set_verbosity_error()
except Exception:
......@@ -154,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
......@@ -185,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
......@@ -203,14 +207,26 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info
def load_model():
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint()
checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
......@@ -222,6 +238,9 @@ def load_model():
sd_hijack.model_hijack.hijack(sd_model)
sd_model.eval()
shared.sd_model = sd_model
script_callbacks.model_loaded_callback(sd_model)
print(f"Model loaded.")
return sd_model
......@@ -234,9 +253,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
shared.sd_model = load_model()
load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
......
......@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def number_of_needed_noises(self, p):
return 0
......@@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
# Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
......@@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
......@@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
......@@ -196,20 +210,33 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
self.last_latent = x
self.step = 0
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
self.last_latent = x
self.step = 0
steps = steps or p.steps
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
......@@ -228,7 +255,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
......@@ -239,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
......@@ -306,6 +334,8 @@ class KDiffusionSampler:
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
......@@ -361,7 +391,7 @@ class KDiffusionSampler:
return extra_params_kwargs
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
......@@ -388,12 +418,18 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigma_sched
self.model_wrap_cfg.init_latent = x
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
......@@ -414,7 +450,13 @@ class KDiffusionSampler:
else:
extra_params_kwargs['sigmas'] = sigmas
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
......@@ -3,6 +3,7 @@ import datetime
import json
import os
import sys
from collections import OrderedDict
import gradio as gr
import tqdm
......@@ -78,6 +79,8 @@ parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencode
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
cmd_opts = parser.parse_args()
restricted_opts = [
......@@ -249,7 +252,7 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
......@@ -315,6 +318,14 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
}))
options_templates.update(options_section(('images-history', "Images Browser"), {
#"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
"images_history_preload": OptionInfo(False, "Preload images at startup"),
"images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
"images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
"images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
}))
class Options:
data = None
......@@ -387,6 +398,8 @@ sd_upscalers = []
sd_model = None
clip_model = None
progress_print_out = sys.stdout
......
......@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
self.dataset.append(entry)
assert len(self.dataset) > 1, "No images have been found in the dataset."
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
......@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
def create_text(self, filename_text):
text = random.choice(self.lines)
......
......@@ -5,6 +5,7 @@ import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
from modules.shared import opts
class EmbeddingEncoder(json.JSONEncoder):
......@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
from math import cos
image = srcimage.copy()
fontsize = 32
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
......@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
fontsize = 32
font = ImageFont.truetype(textfont, fontsize)
padding = 10
......
import os
from PIL import Image, ImageOps
import math
import platform
import sys
import tqdm
......@@ -11,7 +12,7 @@ if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
try:
if process_caption:
shared.interrogator.load()
......@@ -21,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
finally:
......@@ -33,11 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
width = process_width
height = process_height
src = os.path.abspath(process_src)
dst = os.path.abspath(process_dst)
split_threshold = max(0.0, min(1.0, split_threshold))
overlap_ratio = max(0.0, min(0.9, overlap_ratio))
assert src != dst, 'same directory specified as source and destination'
......@@ -48,7 +51,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
def save_pic_with_caption(image, index):
def save_pic_with_caption(image, index, existing_caption=None):
caption = ""
if process_caption:
......@@ -66,17 +69,49 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
if preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
elif preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
elif preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
caption = caption.strip()
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
subindex[0] += 1
def save_pic(image, index):
save_pic_with_caption(image, index)
def save_pic(image, index, existing_caption=None):
save_pic_with_caption(image, index, existing_caption=existing_caption)
if process_flip:
save_pic_with_caption(ImageOps.mirror(image), index)
save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
def split_pic(image, inverse_xy):
if inverse_xy:
from_w, from_h = image.height, image.width
to_w, to_h = height, width
else:
from_w, from_h = image.width, image.height
to_w, to_h = width, height
h = from_h * to_w // from_w
if inverse_xy:
image = image.resize((h, to_w))
else:
image = image.resize((to_w, h))
split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
for i in range(split_count):
y = int(y_step * i)
if inverse_xy:
splitted = image.crop((y, 0, y + to_h, to_w))
else:
splitted = image.crop((0, y, to_w, y + to_h))
yield splitted
for index, imagefile in enumerate(tqdm.tqdm(files)):
subindex = [0]
......@@ -86,31 +121,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
except Exception:
continue
existing_caption = None
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
if os.path.exists(existing_caption_filename):
with open(existing_caption_filename, 'r', encoding="utf8") as file:
existing_caption = file.read()
if shared.state.interrupted:
break
ratio = img.height / img.width
is_tall = ratio > 1.35
is_wide = ratio < 1 / 1.35
if process_split and is_tall:
img = img.resize((width, height * img.height // img.width))
top = img.crop((0, 0, width, height))
save_pic(top, index)
bot = img.crop((0, img.height - height, width, img.height))
save_pic(bot, index)
elif process_split and is_wide:
img = img.resize((width * img.width // img.height, height))
left = img.crop((0, 0, width, height))
save_pic(left, index)
if img.height > img.width:
ratio = (img.width * height) / (img.height * width)
inverse_xy = False
else:
ratio = (img.height * width) / (img.width * height)
inverse_xy = True
right = img.crop((img.width - width, 0, img.width, height))
save_pic(right, index)
if process_split and ratio < 1.0 and ratio <= split_threshold:
for splitted in split_pic(img, inverse_xy):
save_pic(splitted, index, existing_caption=existing_caption)
else:
img = images.resize_image(1, img, width, height)
save_pic(img, index)
save_pic(img, index, existing_caption=existing_caption)
shared.state.nextjob()
......@@ -153,7 +153,7 @@ class EmbeddingDatabase:
return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'):
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
......@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
if not overwrite_old:
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
......@@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
......
......@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
def create_embedding(name, initialization_text, nvpt, overwrite_old):
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
......
import modules.scripts
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
......@@ -35,6 +36,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None,
)
p.scripts = modules.scripts.scripts_txt2img
p.script_args = args
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
......@@ -53,4 +57,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info)
......@@ -22,9 +22,14 @@ import piexif
import torch
from PIL import Image, PngImagePlugin
from modules import localization, sd_hijack, sd_models
import gradio as gr
import gradio.utils
import gradio.routes
from modules import sd_hijack, sd_models, localization, script_callbacks
from modules.paths import script_path
from modules.shared import cmd_opts, opts, restricted_opts
from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags
......@@ -43,6 +48,11 @@ from modules import prompt_parser
from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui
import modules.hypernetworks.ui
import modules.images_history as img_his
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
mimetypes.init()
......@@ -593,27 +603,29 @@ def apply_setting(key, value):
return value
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
setattr(refresh_component, k, v)
for k, v in args.items():
setattr(refresh_component, k, v)
return gr.update(**(args or {}))
return gr.update(**(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[refresh_component]
)
return refresh_button
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id)
refresh_button.click(
fn = refresh,
inputs = [],
outputs = [refresh_component]
)
return refresh_button
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
......@@ -711,6 +723,7 @@ def create_ui(wrap_gradio_gpu_call):
firstphase_width,
firstphase_height,
] + custom_inputs,
outputs=[
txt2img_gallery,
generation_info,
......@@ -787,6 +800,7 @@ def create_ui(wrap_gradio_gpu_call):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
(firstphase_width, "First pass size-1"),
(firstphase_height, "First pass size-2"),
*modules.scripts.scripts_txt2img.infotext_fields
]
txt2img_preview_params = [
......@@ -854,8 +868,8 @@ def create_ui(wrap_gradio_gpu_call):
sampler_index = gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index")
with gr.Group():
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512, elem_id="img2img_width")
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512, elem_id="img2img_height")
with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
......@@ -1052,6 +1066,7 @@ def create_ui(wrap_gradio_gpu_call):
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
(denoising_strength, "Denoising strength"),
*modules.scripts.scripts_img2img.infotext_fields
]
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
......@@ -1174,12 +1189,12 @@ def create_ui(wrap_gradio_gpu_call):
)
#images history
images_history_switch_dict = {
"fn":modules.generation_parameters_copypaste.connect_paste,
"t2i":txt2img_paste_fields,
"i2i":img2img_paste_fields
"fn": modules.generation_parameters_copypaste.connect_paste,
"t2i": txt2img_paste_fields,
"i2i": img2img_paste_fields
}
images_history = img_his.create_history_tabs(gr, opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
images_history = img_his.create_history_tabs(gr, opts, cmd_opts, wrap_gradio_call(modules.extras.run_pnginfo), images_history_switch_dict)
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
......@@ -1212,6 +1227,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
overwrite_old_embedding = gr.Checkbox(value=False, label="Overwrite Old Embedding")
with gr.Row():
with gr.Column(scale=3):
......@@ -1227,6 +1243,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
......@@ -1240,13 +1258,18 @@ def create_ui(wrap_gradio_gpu_call):
process_dst = gr.Textbox(label='Destination directory')
process_width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
process_height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
preprocess_txt_action = gr.Dropdown(label='Existing Caption txt Action', value="ignore", choices=["ignore", "copy", "prepend", "append"])
with gr.Row():
process_flip = gr.Checkbox(label='Create flipped copies')
process_split = gr.Checkbox(label='Split oversized images into two')
process_split = gr.Checkbox(label='Split oversized images')
process_caption = gr.Checkbox(label='Use BLIP for caption')
process_caption_deepbooru = gr.Checkbox(label='Use deepbooru for caption', visible=True if cmd_opts.deepdanbooru else False)
with gr.Row(visible=False) as process_split_extra_row:
process_split_threshold = gr.Slider(label='Split image threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05)
process_overlap_ratio = gr.Slider(label='Split image overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
......@@ -1254,15 +1277,24 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Column():
run_preprocess = gr.Button(value="Preprocess", variant='primary')
process_split.change(
fn=lambda show: gr_show(show),
inputs=[process_split],
outputs=[process_split_extra_row],
)
with gr.Tab(label="Train"):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 1:1 ratio images</p>")
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding or Hypernetwork; you must specify a directory with a set of 1:1 ratio images <a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\" style=\"font-weight:bold;\">[wiki]</a></p>")
with gr.Row():
train_embedding_name = gr.Dropdown(label='Embedding', elem_id="train_embedding", choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
create_refresh_button(train_embedding_name, sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings, lambda: {"choices": sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())}, "refresh_train_embedding_name")
with gr.Row():
train_hypernetwork_name = gr.Dropdown(label='Hypernetwork', elem_id="train_hypernetwork", choices=[x for x in shared.hypernetworks.keys()])
create_refresh_button(train_hypernetwork_name, shared.reload_hypernetworks, lambda: {"choices": sorted([x for x in shared.hypernetworks.keys()])}, "refresh_train_hypernetwork_name")
learn_rate = gr.Textbox(label='Learning rate', placeholder="Learning rate", value="0.005")
with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
......@@ -1296,6 +1328,7 @@ def create_ui(wrap_gradio_gpu_call):
new_embedding_name,
initialization_text,
nvpt,
overwrite_old_embedding,
],
outputs=[
train_embedding_name,
......@@ -1309,6 +1342,7 @@ def create_ui(wrap_gradio_gpu_call):
inputs=[
new_hypernetwork_name,
new_hypernetwork_sizes,
overwrite_old_hypernetwork,
new_hypernetwork_layer_structure,
new_hypernetwork_activation_func,
new_hypernetwork_add_layer_norm,
......@@ -1329,10 +1363,13 @@ def create_ui(wrap_gradio_gpu_call):
process_dst,
process_width,
process_height,
preprocess_txt_action,
process_flip,
process_split,
process_caption,
process_caption_deepbooru
process_caption_deepbooru,
process_split_threshold,
process_overlap_ratio,
],
outputs=[
ti_output,
......@@ -1345,7 +1382,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
learn_rate,
embedding_learn_rate,
batch_size,
dataset_directory,
log_directory,
......@@ -1370,7 +1407,7 @@ def create_ui(wrap_gradio_gpu_call):
_js="start_training_textual_inversion",
inputs=[
train_hypernetwork_name,
learn_rate,
hypernetwork_learn_rate,
batch_size,
dataset_directory,
log_directory,
......@@ -1491,10 +1528,10 @@ Requested path was: {f}
if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
return gr.update(value=oldval), opts.dumpjson()
oldval = opts.data.get(key, None)
opts.data[key] = value
if oldval != value:
......@@ -1600,19 +1637,24 @@ Requested path was: {f}
(img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(images_history, "History", "images_history"),
(images_history, "Image Browser", "images_history"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(train_interface, "Train", "ti"),
(settings_interface, "Settings", "settings"),
]
with open(os.path.join(script_path, "style.css"), "r", encoding="utf8") as file:
css = file.read()
interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")]
css = ""
for cssfile in modules.scripts.list_files_with_name("style.css"):
with open(cssfile, "r", encoding="utf8") as file:
css += file.read() + "\n"
if os.path.exists(os.path.join(script_path, "user.css")):
with open(os.path.join(script_path, "user.css"), "r", encoding="utf8") as file:
usercss = file.read()
css += usercss
css += file.read() + "\n"
if not cmd_opts.no_progressbar_hiding:
css += css_hide_progressbar
......@@ -1835,9 +1877,9 @@ def load_javascript(raw_response):
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
jsdir = os.path.join(script_path, "javascript")
for filename in sorted(os.listdir(jsdir)):
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
scripts_list = modules.scripts.list_scripts("javascript", ".js")
for basedir, filename, path in scripts_list:
with open(path, "r", encoding="utf8") as jsfile:
javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None:
......@@ -1855,6 +1897,5 @@ def load_javascript(raw_response):
gradio.routes.templates.TemplateResponse = template_response
reload_javascript = partial(load_javascript,
gradio.routes.templates.TemplateResponse)
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
......@@ -172,54 +172,54 @@ class Script(scripts.Script):
if down > 0:
down = target_h - init_img.height - up
init_image = p.init_images[0]
state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
res_w = init.width + pixels_horiz
res_h = init.height + pixels_vert
process_res_w = math.ceil(res_w / 64) * 64
process_res_h = math.ceil(res_h / 64) * 64
img = Image.new("RGB", (process_res_w, process_res_h))
img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
mask = Image.new("RGB", (process_res_w, process_res_h), "white")
draw = ImageDraw.Draw(mask)
draw.rectangle((
expand_pixels + mask_blur if is_left else 0,
expand_pixels + mask_blur if is_top else 0,
mask.width - expand_pixels - mask_blur if is_right else res_w,
mask.height - expand_pixels - mask_blur if is_bottom else res_h,
), fill="black")
np_image = (np.asarray(img) / 255.0).astype(np.float64)
np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
crop_region = (
0 if is_left else out.width - target_width,
0 if is_top else out.height - target_height,
target_width if is_left else out.width,
target_height if is_top else out.height,
)
image_to_process = out.crop(crop_region)
mask = mask.crop(crop_region)
p.width = target_width if is_horiz else img.width
p.height = target_height if is_vert else img.height
p.init_images = [image_to_process]
p.image_mask = mask
images_to_process = []
output_images = []
for n in range(count):
res_w = init[n].width + pixels_horiz
res_h = init[n].height + pixels_vert
process_res_w = math.ceil(res_w / 64) * 64
process_res_h = math.ceil(res_h / 64) * 64
img = Image.new("RGB", (process_res_w, process_res_h))
img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
mask = Image.new("RGB", (process_res_w, process_res_h), "white")
draw = ImageDraw.Draw(mask)
draw.rectangle((
expand_pixels + mask_blur if is_left else 0,
expand_pixels + mask_blur if is_top else 0,
mask.width - expand_pixels - mask_blur if is_right else res_w,
mask.height - expand_pixels - mask_blur if is_bottom else res_h,
), fill="black")
np_image = (np.asarray(img) / 255.0).astype(np.float64)
np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
p.width = target_width if is_horiz else img.width
p.height = target_height if is_vert else img.height
crop_region = (
0 if is_left else output_images[n].width - target_width,
0 if is_top else output_images[n].height - target_height,
target_width if is_left else output_images[n].width,
target_height if is_top else output_images[n].height,
)
mask = mask.crop(crop_region)
p.image_mask = mask
image_to_process = output_images[n].crop(crop_region)
images_to_process.append(image_to_process)
p.init_images = images_to_process
latent_mask = Image.new("RGB", (p.width, p.height), "white")
draw = ImageDraw.Draw(latent_mask)
......@@ -232,31 +232,52 @@ class Script(scripts.Script):
p.latent_mask = latent_mask
proc = process_images(p)
proc_img = proc.images[0]
if initial_seed_and_info[0] is None:
initial_seed_and_info[0] = proc.seed
initial_seed_and_info[1] = proc.info
out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
out = out.crop((0, 0, res_w, res_h))
return out
for n in range(count):
output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
img = init_image
return output_images
if left > 0:
img = expand(img, left, is_left=True)
if right > 0:
img = expand(img, right, is_right=True)
if up > 0:
img = expand(img, up, is_top=True)
if down > 0:
img = expand(img, down, is_bottom=True)
batch_count = p.n_iter
batch_size = p.batch_size
p.n_iter = 1
state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
all_processed_images = []
for i in range(batch_count):
imgs = [init_img] * batch_size
state.job = f"Batch {i + 1} out of {batch_count}"
if left > 0:
imgs = expand(imgs, batch_size, left, is_left=True)
if right > 0:
imgs = expand(imgs, batch_size, right, is_right=True)
if up > 0:
imgs = expand(imgs, batch_size, up, is_top=True)
if down > 0:
imgs = expand(imgs, batch_size, down, is_bottom=True)
res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
all_processed_images += imgs
all_images = all_processed_images
combined_grid_image = images.image_grid(all_processed_images)
unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
if opts.return_grid and not unwanted_grid_because_of_img_count:
all_images = [combined_grid_image] + all_processed_images
res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
if opts.samples_save:
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
for img in all_processed_images:
images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
return res
if opts.grid_save and not unwanted_grid_because_of_img_count:
images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
return res
......@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs):
......
......@@ -71,6 +71,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
def initialize():
modelloader.cleanup_models()
modules.sd_models.setup_model()
......@@ -79,9 +80,9 @@ def initialize():
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
modules.scripts.load_scripts()
shared.sd_model = modules.sd_models.load_model()
modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
......@@ -118,7 +119,8 @@ def api_only():
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui(launch_api=False):
def webui():
launch_api = cmd_opts.api
initialize()
while 1:
......@@ -144,7 +146,7 @@ def webui(launch_api=False):
sd_samplers.set_samplers()
print('Reloading Custom Scripts')
modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
modules.scripts.reload_scripts()
print('Reloading modules: modules.ui')
importlib.reload(modules.ui)
print('Refreshing Model List')
......@@ -158,4 +160,4 @@ if __name__ == "__main__":
if cmd_opts.nowebui:
api_only()
else:
webui(cmd_opts.api)
webui()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment