Commit 6a9ea40d authored by yfszzx

Move browser and Inspiration into extension

parents 67b78f0e 1ef32c8b
@@ -45,6 +45,8 @@ body:
     attributes:
       label: Commit where the problem happens
       description: Which commit are you running ? (copy the **Commit hash** shown in the cmd/terminal when you launch the UI)
+    validations:
+      required: true
   - type: dropdown
     id: platforms
     attributes:
...
+blank_issues_enabled: false
+contact_links:
+  - name: WebUI Community Support
+    url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions
+    about: Please ask and answer questions here.
@@ -28,4 +28,6 @@ notification.mp3
 /SwinIR
 /textual_inversion
 .vscode
-/inspiration
\ No newline at end of file
+/extensions
+/inspiration
@@ -11,6 +11,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - One click install and run script (but you still must install python and git)
 - Outpainting
 - Inpainting
+- Color Sketch
 - Prompt Matrix
 - Stable Diffusion Upscale
 - Attention, specify parts of text that the model should pay more attention to
@@ -23,6 +24,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
     - have as many embeddings as you want and use any names you like for them
     - use multiple embeddings with different numbers of vectors per token
     - works with half precision floating point numbers
+    - train embeddings on 8GB (also reports of 6GB working)
 - Extras tab with:
     - GFPGAN, neural network that fixes faces
     - CodeFormer, face restoration tool as an alternative to GFPGAN
@@ -37,14 +39,14 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - Interrupt processing at any time
 - 4GB video card support (also reports of 2GB working)
 - Correct seeds for batches
-- Prompt length validation
-    - get length of prompt in tokens as you type
-    - get a warning after generation if some text was truncated
+- Live prompt token length validation
 - Generation parameters
     - parameters you used to generate images are saved with that image
     - in PNG chunks for PNG, in EXIF for JPEG
     - can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
     - can be disabled in settings
+    - drag and drop an image/text-parameters to promptbox
+- Read Generation Parameters Button, loads parameters in promptbox to UI
 - Settings page
 - Running arbitrary python code from UI (must run with --allow-code to enable)
 - Mouseover hints for most UI elements
@@ -59,10 +61,10 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - CLIP interrogator, a button that tries to guess prompt from an image
 - Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
 - Batch Processing, process a group of files using img2img
-- Img2img Alternative
+- Img2img Alternative, reverse Euler method of cross attention control
 - Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
 - Reloading checkpoints on the fly
-- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
+- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one
 - [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
 - [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
     - separate prompts using uppercase `AND`
@@ -70,14 +72,35 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
 - No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
 - DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
 - [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
+- History tab: view, direct and delete images conveniently within the UI
+- Generate forever option
+- Training tab
+    - hypernetworks and embeddings options
+    - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime)
+- Clip skip
+- Use Hypernetworks
+- Use VAEs
+- Estimated completion time in progress bar
+- API
+- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
+- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
+
+## Where are Aesthetic Gradients?!?!
+Aesthetic Gradients are now an extension. You can install it using git:
+
+```commandline
+git clone https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients extensions/aesthetic-gradients
+```
+
+After running this command, make sure that you have `aesthetic-gradients` dir in webui's `extensions` directory and restart
+the UI. The interface for Aesthetic Gradients should appear exactly the same as it was.
+
 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
 
-Alternatively, use Google Colab:
-
-- [Colab, maintained by Akaibu](https://colab.research.google.com/drive/1kw3egmSn-KgWsikYvOMjJkVDsPLjEMzl)
-- [Colab, original by me, outdated](https://colab.research.google.com/drive/1Iy-xW9t1-OQWhb0hNxueGij8phCyluOh).
+Alternatively, use online services (like Google Colab):
+
+- [List of Online Services](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Online-Services)
 
 ### Automatic Installation on Windows
 1. Install [Python 3.10.6](https://www.python.org/downloads/windows/), checking "Add Python to PATH"
...
@@ -3,12 +3,12 @@ let currentWidth = null;
 let currentHeight = null;
 let arFrameTimeout = setTimeout(function(){},0);
 
-function dimensionChange(e,dimname){
+function dimensionChange(e, is_width, is_height){
 
-    if(dimname == 'Width'){
+    if(is_width){
         currentWidth = e.target.value*1.0
     }
-    if(dimname == 'Height'){
+    if(is_height){
         currentHeight = e.target.value*1.0
     }
@@ -18,22 +18,13 @@ function dimensionChange(e,dimname){
         return;
     }
 
-    var img2imgMode = gradioApp().querySelector('#mode_img2img.tabs > div > button.rounded-t-lg.border-gray-200')
-    if(img2imgMode){
-        img2imgMode=img2imgMode.innerText
-    }else{
-        return;
-    }
-
-    var redrawImage = gradioApp().querySelector('div[data-testid=image] img');
-    var inpaintImage = gradioApp().querySelector('#img2maskimg div[data-testid=image] img')
-
     var targetElement = null;
-    if(img2imgMode=='img2img' && redrawImage){
-        targetElement = redrawImage;
-    }else if(img2imgMode=='Inpaint' && inpaintImage){
-        targetElement = inpaintImage;
+
+    var tabIndex = get_tab_index('mode_img2img')
+    if(tabIndex == 0){
+        targetElement = gradioApp().querySelector('div[data-testid=image] img');
+    } else if(tabIndex == 1){
+        targetElement = gradioApp().querySelector('#img2maskimg div[data-testid=image] img');
     }
 
     if(targetElement){
@@ -98,22 +89,20 @@ onUiUpdate(function(){
     var inImg2img = Boolean(gradioApp().querySelector("button.rounded-t-lg.border-gray-200"))
     if(inImg2img){
         let inputs = gradioApp().querySelectorAll('input');
         inputs.forEach(function(e){
-            let parentLabel = e.parentElement.querySelector('label')
-            if(parentLabel && parentLabel.innerText){
-                if(!e.classList.contains('scrollwatch')){
-                    if(parentLabel.innerText == 'Width' || parentLabel.innerText == 'Height'){
-                        e.addEventListener('input', function(e){dimensionChange(e,parentLabel.innerText)} )
-                        e.classList.add('scrollwatch')
-                    }
-                    if(parentLabel.innerText == 'Width'){
-                        currentWidth = e.value*1.0
-                    }
-                    if(parentLabel.innerText == 'Height'){
-                        currentHeight = e.value*1.0
-                    }
-                }
-            }
+            var is_width = e.parentElement.id == "img2img_width"
+            var is_height = e.parentElement.id == "img2img_height"
+
+            if((is_width || is_height) && !e.classList.contains('scrollwatch')){
+                e.addEventListener('input', function(e){dimensionChange(e, is_width, is_height)} )
+                e.classList.add('scrollwatch')
+            }
+
+            if(is_width){
+                currentWidth = e.value*1.0
+            }
+            if(is_height){
+                currentHeight = e.value*1.0
+            }
         })
     }
 });
@@ -43,7 +43,7 @@ function dropReplaceImage( imgWrap, files ) {
 window.document.addEventListener('dragover', e => {
     const target = e.composedPath()[0];
     const imgWrap = target.closest('[data-testid="image"]');
-    if ( !imgWrap && target.placeholder.indexOf("Prompt") == -1) {
+    if ( !imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) {
         return;
     }
     e.stopPropagation();
...
@@ -17,14 +17,6 @@ var images_history_click_image = function(){
     images_history_set_image_info(this);
 }
 
-var images_history_click_tab = function(){
-    var tabs_box = gradioApp().getElementById("images_history_tab");
-    if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
-        gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_renew_page").click();
-        tabs_box.classList.add(this.getAttribute("tabname"))
-    }
-}
-
 function images_history_disabled_del(){
     gradioApp().querySelectorAll(".images_history_del_button").forEach(function(btn){
         btn.setAttribute('disabled','disabled');
@@ -43,7 +35,6 @@ function images_history_get_parent_by_tagname(item, tagname){
     var parent = item.parentElement;
     tagname = tagname.toUpperCase()
     while(parent.tagName != tagname){
-        console.log(parent.tagName, tagname)
         parent = parent.parentElement;
     }
     return parent;
@@ -88,15 +79,15 @@ function images_history_set_image_info(button){
 }
 
-function images_history_get_current_img(tabname, image_path, files){
+function images_history_get_current_img(tabname, img_index, files){
     return [
-        gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
-        image_path,
+        tabname,
+        gradioApp().getElementById(tabname + '_images_history_set_index').getAttribute("img_index"),
         files
     ];
 }
 
-function images_history_delete(del_num, tabname, img_path, img_file_name, page_index, filenames, image_index){
+function images_history_delete(del_num, tabname, image_index){
     image_index = parseInt(image_index);
     var tab = gradioApp().getElementById(tabname + '_images_history');
     var set_btn = tab.querySelector(".images_history_set_index");
@@ -107,6 +98,7 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
         }
     });
     var img_num = buttons.length / 2;
+    del_num = Math.min(img_num - image_index, del_num)
    if (img_num <= del_num){
        setTimeout(function(tabname){
            gradioApp().getElementById(tabname + '_images_history_renew_page').click();
@@ -114,30 +106,28 @@ function images_history_delete(del_num, tabname, img_path, img_file_name, page_i
    } else {
        var next_img
        for (var i = 0; i < del_num; i++){
-            if (image_index + i < image_index + img_num){
-                buttons[image_index + i].style.display = 'none';
-                buttons[image_index + img_num + 1].style.display = 'none';
-                next_img = image_index + i + 1
-            }
+            buttons[image_index + i].style.display = 'none';
+            buttons[image_index + i + img_num].style.display = 'none';
+            next_img = image_index + i + 1
        }
        var bnt;
        if (next_img >= img_num){
-            btn = buttons[image_index - del_num];
+            btn = buttons[image_index - 1];
        } else {
            btn = buttons[next_img];
        }
        setTimeout(function(btn){btn.click()}, 30, btn);
    }
    images_history_disabled_del();
-    return [del_num, tabname, img_path, img_file_name, page_index, filenames, image_index];
 }
 
-function images_history_turnpage(img_path, page_index, image_index, tabname){
+function images_history_turnpage(tabname){
+    gradioApp().getElementById(tabname + '_images_history_del_button').setAttribute('disabled','disabled');
    var buttons = gradioApp().getElementById(tabname + '_images_history').querySelectorAll(".gallery-item");
    buttons.forEach(function(elem) {
        elem.style.display = 'block';
    })
-    return [img_path, page_index, image_index, tabname];
 }
 
 function images_history_enable_del_buttons(){
@@ -147,60 +137,64 @@ function images_history_enable_del_buttons(){
 }
 
 function images_history_init(){
-    var load_txt2img_button = gradioApp().getElementById('txt2img_images_history_renew_page')
-    if (load_txt2img_button){
+    var tabnames = gradioApp().getElementById("images_history_tabnames_list")
+    if (tabnames){
+        images_history_tab_list = tabnames.querySelector("textarea").value.split(",")
        for (var i in images_history_tab_list ){
-            tab = images_history_tab_list[i];
+            var tab = images_history_tab_list[i];
            gradioApp().getElementById(tab + '_images_history').classList.add("images_history_cantainor");
            gradioApp().getElementById(tab + '_images_history_set_index').classList.add("images_history_set_index");
            gradioApp().getElementById(tab + '_images_history_del_button').classList.add("images_history_del_button");
            gradioApp().getElementById(tab + '_images_history_gallery').classList.add("images_history_gallery");
+            gradioApp().getElementById(tab + "_images_history_start").setAttribute("style","padding:20px;font-size:25px");
        }
-        var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
-        tabs_box.setAttribute("id", "images_history_tab");
-        var tab_btns = tabs_box.querySelectorAll("button");
-        for (var i in images_history_tab_list){
-            var tabname = images_history_tab_list[i]
-            tab_btns[i].setAttribute("tabname", tabname);
-
-            // this refreshes history upon tab switch
-            // until the history is known to work well, which is not the case now, we do not do this at startup
-            //tab_btns[i].addEventListener('click', images_history_click_tab);
-        }
-        tabs_box.classList.add(images_history_tab_list[0]);
-
-        // same as above, at page load
-        //load_txt2img_button.click();
+        //preload
+        if (gradioApp().getElementById("images_history_preload").querySelector("input").checked ){
+            var tabs_box = gradioApp().getElementById("tab_images_history").querySelector("div").querySelector("div").querySelector("div");
+            tabs_box.setAttribute("id", "images_history_tab");
+            var tab_btns = tabs_box.querySelectorAll("button");
+            for (var i in images_history_tab_list){
+                var tabname = images_history_tab_list[i]
+                tab_btns[i].setAttribute("tabname", tabname);
+                tab_btns[i].addEventListener('click', function(){
+                    var tabs_box = gradioApp().getElementById("images_history_tab");
+                    if (!tabs_box.classList.contains(this.getAttribute("tabname"))) {
+                        gradioApp().getElementById(this.getAttribute("tabname") + "_images_history_start").click();
+                        tabs_box.classList.add(this.getAttribute("tabname"))
+                    }
+                });
+            }
+            tab_btns[0].click()
+        }
    } else {
        setTimeout(images_history_init, 500);
    }
 }
 
-var images_history_tab_list = ["txt2img", "img2img", "extras"];
+var images_history_tab_list = "";
 setTimeout(images_history_init, 500);
 document.addEventListener("DOMContentLoaded", function() {
     var mutationObserver = new MutationObserver(function(m){
-        for (var i in images_history_tab_list ){
-            let tabname = images_history_tab_list[i]
-            var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
-            buttons.forEach(function(bnt){
-                bnt.addEventListener('click', images_history_click_image, true);
-            });
-
-            // same as load_txt2img_button.click() above
-            /*
-            var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
-            if (cls_btn){
-                cls_btn.addEventListener('click', function(){
-                    gradioApp().getElementById(tabname + '_images_history_renew_page').click();
-                }, false);
-            }*/
+        if (images_history_tab_list != ""){
+            for (var i in images_history_tab_list ){
+                let tabname = images_history_tab_list[i]
+                var buttons = gradioApp().querySelectorAll('#' + tabname + '_images_history .gallery-item');
+                buttons.forEach(function(bnt){
+                    bnt.addEventListener('click', images_history_click_image, true);
+                });
+                var cls_btn = gradioApp().getElementById(tabname + '_images_history_gallery').querySelector("svg");
+                if (cls_btn){
+                    cls_btn.addEventListener('click', function(){
+                        gradioApp().getElementById(tabname + '_images_history_renew_page').click();
+                    }, false);
+                }
+            }
        }
    });
-    mutationObserver.observe( gradioApp(), { childList:true, subtree:true });
+    mutationObserver.observe(gradioApp(), { childList:true, subtree:true });
 });
-from modules.api.processing import StableDiffusionProcessingAPI
-from modules.processing import StableDiffusionProcessingTxt2Img, process_images
+from modules.api.processing import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.sd_samplers import all_samplers
 from modules.extras import run_pnginfo
 import modules.shared as shared
@@ -10,6 +10,7 @@ from pydantic import BaseModel, Field, Json
 import json
 import io
 import base64
+from PIL import Image
 
 sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
@@ -18,6 +19,11 @@ class TextToImageResponse(BaseModel):
     parameters: Json
     info: Json
 
+class ImageToImageResponse(BaseModel):
+    images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
+    parameters: Json
+    info: Json
+
 
 class Api:
     def __init__(self, app, queue_lock):
@@ -25,8 +31,17 @@ class Api:
         self.app = app
         self.queue_lock = queue_lock
         self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
+        self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
 
-    def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
+    def __base64_to_image(self, base64_string):
+        # if has a comma, deal with prefix
+        if "," in base64_string:
+            base64_string = base64_string.split(",")[1]
+        imgdata = base64.b64decode(base64_string)
+        # convert base64 to PIL image
+        return Image.open(io.BytesIO(imgdata))
+
+    def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
         sampler_index = sampler_to_index(txt2imgreq.sampler_index)
 
         if sampler_index is None:
@@ -54,8 +69,49 @@ class Api:
-    def img2imgapi(self):
-        raise NotImplementedError
+    def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
+        sampler_index = sampler_to_index(img2imgreq.sampler_index)
+
+        if sampler_index is None:
+            raise HTTPException(status_code=404, detail="Sampler not found")
+
+        init_images = img2imgreq.init_images
+        if init_images is None:
+            raise HTTPException(status_code=404, detail="Init image not found")
+
+        mask = img2imgreq.mask
+        if mask:
+            mask = self.__base64_to_image(mask)
+
+        populate = img2imgreq.copy(update={ # Override __init__ params
+            "sd_model": shared.sd_model,
+            "sampler_index": sampler_index[0],
+            "do_not_save_samples": True,
+            "do_not_save_grid": True,
+            "mask": mask
+            }
+        )
+        p = StableDiffusionProcessingImg2Img(**vars(populate))
+
+        imgs = []
+        for img in init_images:
+            img = self.__base64_to_image(img)
+            imgs = [img] * p.batch_size
+
+        p.init_images = imgs
+        # Override object param
+        with self.queue_lock:
+            processed = process_images(p)
+
+        b64images = []
+        for i in processed.images:
+            buffer = io.BytesIO()
+            i.save(buffer, format="png")
+            b64images.append(base64.b64encode(buffer.getvalue()))
+
+        return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=json.dumps(processed.info))
 
     def extrasapi(self):
         raise NotImplementedError
...
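For orientation, here is a minimal sketch of exercising the new endpoint from a client, assuming the webui was started locally with `--api`; the host/port and file names are assumptions, while the payload and response fields mirror the models defined in this diff:

```python
# Hedged sketch of a client for the new /sdapi/v1/img2img route.
# Assumes a local webui started with --api; "input.png" is a made-up file.
import base64
import requests

with open("input.png", "rb") as f:
    init_image = base64.b64encode(f.read()).decode()

payload = {
    "init_images": [init_image],     # base64 strings, per the new Pydantic model
    "prompt": "a watercolor landscape",
    "denoising_strength": 0.75,      # the default declared for the img2img model
    "sampler_index": "Euler",
}

resp = requests.post("http://127.0.0.1:7860/sdapi/v1/img2img", json=payload)
resp.raise_for_status()

# ImageToImageResponse returns base64-encoded images
for n, img in enumerate(resp.json()["images"]):
    with open(f"out-{n}.png", "wb") as f:
        f.write(base64.b64decode(img))
```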
+from array import array
 from inflection import underscore
 from typing import Any, Dict, Optional
 from pydantic import BaseModel, Field, create_model
-from modules.processing import StableDiffusionProcessingTxt2Img
+from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
 import inspect
@@ -92,8 +93,14 @@ class PydanticModelGenerator:
         DynamicModel.__config__.allow_mutation = True
         return DynamicModel
 
-StableDiffusionProcessingAPI = PydanticModelGenerator(
+StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
     "StableDiffusionProcessingTxt2Img",
     StableDiffusionProcessingTxt2Img,
     [{"key": "sampler_index", "type": str, "default": "Euler"}]
+).generate_model()
+
+StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
+    "StableDiffusionProcessingImg2Img",
+    StableDiffusionProcessingImg2Img,
+    [{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}]
 ).generate_model()
\ No newline at end of file
@@ -50,11 +50,12 @@ def create_deepbooru_process(threshold, deepbooru_opts):
     the tags.
     """
     from modules import shared  # prevents circular reference
-    shared.deepbooru_process_manager = multiprocessing.Manager()
+    context = multiprocessing.get_context("spawn")
+    shared.deepbooru_process_manager = context.Manager()
     shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
     shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
     shared.deepbooru_process_return["value"] = -1
-    shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
+    shared.deepbooru_process = context.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
     shared.deepbooru_process.start()
...
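The switch to an explicit "spawn" context is the interesting part: a forked child inherits the parent's CUDA/torch state, which is unsafe, while a spawned child starts from a fresh interpreter. A self-contained illustration of the same pattern (toy worker, not the deepbooru code):

```python
# Toy demonstration of multiprocessing.get_context("spawn"): the child
# process gets a fresh interpreter rather than a fork of the parent.
import multiprocessing

def worker(queue, result):
    result["value"] = queue.get() * 2

if __name__ == "__main__":
    context = multiprocessing.get_context("spawn")
    manager = context.Manager()
    queue = manager.Queue()
    result = manager.dict()
    result["value"] = -1

    queue.put(21)
    p = context.Process(target=worker, args=(queue, result))
    p.start()
    p.join()
    print(result["value"])  # 42
```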
+import sys, os, shlex
 import contextlib
 import torch
 from modules import errors
 
 # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False)
 cpu = torch.device("cpu")
 
+def extract_device_id(args, name):
+    for x in range(len(args)):
+        if name in args[x]: return args[x+1]
+    return None
+
 def get_optimal_device():
     if torch.cuda.is_available():
-        return torch.device("cuda")
+        from modules import shared
+
+        device_id = shared.cmd_opts.device_id
+
+        if device_id is not None:
+            cuda_device = f"cuda:{device_id}"
+            return torch.device(cuda_device)
+        else:
+            return torch.device("cuda")
 
     if has_mps:
         return torch.device("mps")
@@ -34,7 +45,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
...
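With the change above, device selection can be steered from the command line. A plausible invocation on a two-GPU Linux machine (the script name and device numbering are assumptions; as the new `--device-id` help text notes, exporting `CUDA_VISIBLE_DEVICES` first may be needed):

```commandline
export CUDA_VISIBLE_DEVICES=0,1
python launch.py --device-id 1
```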
@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         if input_dir == '':
             return outputs, "Please select an input directory.", ''
-        image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+        image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
         for img in image_list:
-            image = Image.open(img)
+            try:
+                image = Image.open(img)
+            except Exception:
+                continue
             imageArr.append(image)
             imageNameArr.append(img)
     else:
@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
         while len(cached_images) > 2:
             del cached_images[next(iter(cached_images.keys()))]
 
-        images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
-                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
-                          forced_filename=image_name if opts.use_original_name_batch else None)
+        if opts.use_original_name_batch and image_name != None:
+            basename = os.path.splitext(os.path.basename(image_name))[0]
+        else:
+            basename = ''
+
+        images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
+                          no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
 
         if opts.enable_pnginfo:
             image.info = existing_pnginfo
...
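For what the new `basename` computation produces, a quick check (the path is hypothetical):

```python
# os.path.splitext(os.path.basename(...)) keeps just the stem of the
# original batch file name, which save_image then reuses.
import os

image_name = "/inputs/cat picture.png"  # hypothetical batch input
print(os.path.splitext(os.path.basename(image_name))[0])  # -> cat picture
```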
@@ -4,13 +4,22 @@ import gradio as gr
 from modules.shared import script_path
 from modules import shared
 
-re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
+re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
 re_param = re.compile(re_param_code)
 re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
 re_imagesize = re.compile(r"^(\d+)x(\d+)$")
 type_of_gr_update = type(gr.update())
 
+def quote(text):
+    if ',' not in str(text):
+        return text
+
+    text = str(text)
+    text = text.replace('\\', '\\\\')
+    text = text.replace('"', '\\"')
+    return f'"{text}"'
+
 def parse_generation_parameters(x: str):
     """parses generation parameters string, the one you see in text field under the picture in UI:
 ```
@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
             else:
                 try:
                     valtype = type(output.value)
-                    val = valtype(v)
+
+                    if valtype == bool and v == "False":
+                        val = False
+                    else:
+                        val = valtype(v)
+
                     res.append(gr.update(value=val))
                 except Exception:
                     res.append(gr.update())
...
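The point of the widened regex plus `quote()` is that values containing commas (hypernetwork names, for instance) now round-trip through the infotext format. A standalone sketch using the regex exactly as defined above (the example line is made up):

```python
# Demonstrates that a quoted, comma-containing value survives re_param.
import re

re_param = re.compile(r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)')

line = 'Steps: 20, Sampler: Euler a, Hypernet: "anime, v2"'
print(re_param.findall(line))
# [('Steps', '20'), ('Sampler', 'Euler a'), ('Hypernet', '"anime, v2"')]
```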
@@ -3,16 +3,19 @@ import os
 import re
 
 import gradio as gr
-import modules.textual_inversion.textual_inversion
 import modules.textual_inversion.preprocess
-from modules import sd_hijack, shared, devices
+import modules.textual_inversion.textual_inversion
+from modules import devices, sd_hijack, shared
 from modules.hypernetworks import hypernetwork
 
-def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
+def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
+    # Remove illegal characters from name.
+    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
+
     fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"
 
     if type(layer_structure) == str:
         layer_structure = [float(x.strip()) for x in layer_structure.split(",")]
@@ -21,7 +24,9 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
         name=name,
         enable_sizes=[int(x) for x in enable_sizes],
         layer_structure=layer_structure,
+        activation_func=activation_func,
         add_layer_norm=add_layer_norm,
+        use_dropout=use_dropout,
     )
     hypernet.save(fn)
...
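The new sanitization line keeps only characters that are alphanumeric or in `._- `, so a user-supplied name cannot escape the hypernetwork directory; for example (input is made up):

```python
# Path separators and other special characters are silently dropped.
name = "my:hyper/net v2!"
print("".join(x for x in name if (x.isalnum() or x in "._- ")))  # -> myhypernet v2
```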
@@ -109,6 +109,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
         inpainting_mask_invert=inpainting_mask_invert,
     )
 
+    p.scripts = modules.scripts.scripts_txt2img
+    p.script_args = args
+
     if shared.cmd_opts.enable_console_prompts:
         print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
...
@@ -28,9 +28,11 @@ class InterrogateModels:
     clip_preprocess = None
     categories = None
     dtype = None
+    running_on_cpu = None
 
     def __init__(self, content_dir):
         self.categories = []
+        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")
 
         if os.path.exists(content_dir):
             for filename in os.listdir(content_dir):
@@ -53,7 +55,11 @@ class InterrogateModels:
     def load_clip_model(self):
         import clip
 
-        model, preprocess = clip.load(clip_model_name)
+        if self.running_on_cpu:
+            model, preprocess = clip.load(clip_model_name, device="cpu")
+        else:
+            model, preprocess = clip.load(clip_model_name)
+
         model.eval()
         model = model.to(devices.device_interrogate)
 
@@ -62,14 +68,14 @@ class InterrogateModels:
     def load(self):
         if self.blip_model is None:
             self.blip_model = self.load_blip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.blip_model = self.blip_model.half()
 
         self.blip_model = self.blip_model.to(devices.device_interrogate)
 
         if self.clip_model is None:
             self.clip_model, self.clip_preprocess = self.load_clip_model()
-            if not shared.cmd_opts.no_half:
+            if not shared.cmd_opts.no_half and not self.running_on_cpu:
                 self.clip_model = self.clip_model.half()
 
         self.clip_model = self.clip_model.to(devices.device_interrogate)
...
 import torch
-from modules.devices import get_optimal_device
+from modules import devices
 
 module_in_gpu = None
 cpu = torch.device("cpu")
-device = gpu = get_optimal_device()
 
 def send_everything_to_cpu():
@@ -33,7 +32,7 @@ def setup_for_low_vram(sd_model, use_medvram):
         if module_in_gpu is not None:
             module_in_gpu.to(cpu)
 
-        module.to(gpu)
+        module.to(devices.device)
         module_in_gpu = module
 
     # see below for register_forward_pre_hook;
@@ -51,7 +50,7 @@ def setup_for_low_vram(sd_model, use_medvram):
     # send the model to GPU. Then put modules back. the modules will be in CPU.
     stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
     sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
-    sd_model.to(device)
+    sd_model.to(devices.device)
     sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
 
     # register hooks for those the first two models
@@ -70,7 +69,7 @@ def setup_for_low_vram(sd_model, use_medvram):
     # so that only one of them is in GPU at a time
     stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
     diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
-    sd_model.model.to(device)
+    sd_model.model.to(devices.device)
     diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
 
     # install hooks for bits of third model
...
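The `devices.device` indirection feeds the module-swapping trick this file is built around: a forward pre-hook pages a module onto the GPU just before it runs and evicts the previous one, so only one large module holds VRAM at a time. A self-contained sketch of that underlying pattern (toy layers, not the actual unet/vae wiring):

```python
# Toy version of lowvram's swap strategy using forward pre-hooks.
import torch
import torch.nn as nn

gpu = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
cpu = torch.device("cpu")
module_in_gpu = None

def send_me_to_gpu(module, _inputs):
    """Move `module` to the GPU and evict whichever module was there before."""
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module.to(gpu)
    module_in_gpu = module

layers = nn.ModuleList([nn.Linear(64, 64) for _ in range(3)])
for layer in layers:
    layer.register_forward_pre_hook(send_me_to_gpu)

x = torch.randn(1, 64, device=gpu)
for layer in layers:
    x = layer(x)  # each layer is paged in on demand
print(x.shape)  # torch.Size([1, 64])
```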
+callbacks_model_loaded = []
+callbacks_ui_tabs = []
+callbacks_ui_settings = []
+
+
+def clear_callbacks():
+    callbacks_model_loaded.clear()
+    callbacks_ui_tabs.clear()
+
+
+def model_loaded_callback(sd_model):
+    for callback in callbacks_model_loaded:
+        callback(sd_model)
+
+
+def ui_tabs_callback():
+    res = []
+
+    for callback in callbacks_ui_tabs:
+        res += callback() or []
+
+    return res
+
+
+def ui_settings_callback():
+    for callback in callbacks_ui_settings:
+        callback()
+
+
+def on_model_loaded(callback):
+    """register a function to be called when the stable diffusion model is created; the model is
+    passed as an argument"""
+    callbacks_model_loaded.append(callback)
+
+
+def on_ui_tabs(callback):
+    """register a function to be called when the UI is creating new tabs.
+    The function must either return a None, which means no new tabs to be added, or a list, where
+    each element is a tuple:
+        (gradio_component, title, elem_id)
+
+    gradio_component is a gradio component to be used for contents of the tab (usually gr.Blocks)
+    title is tab text displayed to user in the UI
+    elem_id is HTML id for the tab
+    """
+    callbacks_ui_tabs.append(callback)
+
+
+def on_ui_settings(callback):
+    """register a function to be called before UI settings are populated; add your settings
+    by using shared.opts.add_option(shared.OptionInfo(...)) """
+    callbacks_ui_settings.append(callback)
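This small registry is what the new `extensions` directory hooks into. A hypothetical extension script (file name and tab contents invented) that registers a tab through it might look like:

```python
# extensions/my-extension/scripts/my_tab.py -- hypothetical extension file
import gradio as gr

from modules import script_callbacks

def add_tab():
    with gr.Blocks() as tab:
        gr.Markdown("Hello from an extension tab")
    # one (gradio_component, title, elem_id) tuple per tab, per the docstring above
    return [(tab, "My Extension", "my_extension_tab")]

script_callbacks.on_ui_tabs(add_tab)
```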
@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
+
 def apply_optimizations():
     undo_optimizations()
@@ -167,11 +168,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     remade_tokens = remade_tokens[:last_comma]
                     length = len(remade_tokens)
 
                     rem = int(math.ceil(length / 75)) * 75 - length
                     remade_tokens += [id_end] * rem + reloc_tokens
                     multipliers = multipliers[:last_comma] + [1.0] * rem + reloc_mults
 
             if embedding is None:
                 remade_tokens.append(token)
                 multipliers.append(weight)
@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
 
     def process_text_old(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
@@ -280,7 +280,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             token_count = len(remade_tokens)
             remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
-            remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
+            remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
             cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
 
         multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -290,7 +290,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             hijack_fixes.append(fixes)
             batch_multipliers.append(multipliers)
 
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
 
     def forward(self, text):
         use_old = opts.use_old_emphasis_implementation
         if use_old:
@@ -302,11 +302,11 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
 
         if use_old:
             self.hijack.fixes = hijack_fixes
             return self.process_tokens(remade_batch_tokens, batch_multipliers)
 
         z = None
         i = 0
         while max(map(len, remade_batch_tokens)) != 0:
@@ -320,7 +320,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 if fix[0] == i:
                     fixes.append(fix[1])
             self.hijack.fixes.append(fixes)
 
             tokens = []
             multipliers = []
             for j in range(len(remade_batch_tokens)):
@@ -333,19 +333,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             z1 = self.process_tokens(tokens, multipliers)
             z = z1 if z is None else torch.cat((z, z1), axis=-2)
 
             remade_batch_tokens = rem_tokens
             batch_multipliers = rem_multipliers
             i += 1
 
         return z
 
     def process_tokens(self, remade_batch_tokens, batch_multipliers):
         if not opts.use_old_emphasis_implementation:
             remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
             batch_multipliers = [[1.0] + x[:75] + [1.0] for x in batch_multipliers]
 
         tokens = torch.asarray(remade_batch_tokens).to(device)
         outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
@@ -385,8 +384,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
                 emb = embedding.vec
-                emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
-                tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
+                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
+                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
 
             vecs.append(tensor)
...
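The reformatted `torch.cat` line is the splice that injects a textual-inversion embedding's vectors over the placeholder token positions. A toy version with small tensors:

```python
# Toy splice: overwrite the rows after `offset` with an embedding's vectors.
import torch

tensor = torch.zeros(6, 4)  # 6 token embeddings of width 4
emb = torch.ones(2, 4)      # a 2-vector embedding
offset = 1

emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
print(tensor[:, 0])  # tensor([0., 0., 1., 1., 0., 0.])
```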
@@ -7,8 +7,9 @@ from omegaconf import OmegaConf
 from ldm.util import instantiate_from_config
 
-from modules import shared, modelloader, devices
+from modules import shared, modelloader, devices, script_callbacks
 from modules.paths import models_path
+from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(models_path, model_dir))
@@ -20,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
 try:
     # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
-    from transformers import logging
+    from transformers import logging, CLIPModel
 
     logging.set_verbosity_error()
 except Exception:
@@ -154,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd
 
+
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
 def load_model_weights(model, checkpoint_info):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
@@ -185,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
             if os.path.exists(vae_file):
                 print(f"Loading VAE weights from: {vae_file}")
                 vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
-                vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+                vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
                 model.first_stage_model.load_state_dict(vae_dict)
 
     model.first_stage_model.to(devices.dtype_vae)
@@ -203,14 +207,26 @@ def load_model_weights(model, checkpoint_info):
     model.sd_checkpoint_info = checkpoint_info
 
-def load_model():
+def load_model(checkpoint_info=None):
     from modules import lowvram, sd_hijack
-    checkpoint_info = select_checkpoint()
+    checkpoint_info = checkpoint_info or select_checkpoint()
 
     if checkpoint_info.config != shared.cmd_opts.config:
         print(f"Loading config from: {checkpoint_info.config}")
 
     sd_config = OmegaConf.load(checkpoint_info.config)
+
+    if should_hijack_inpainting(checkpoint_info):
+        # Hardcoded config for now...
+        sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
+        sd_config.model.params.use_ema = False
+        sd_config.model.params.conditioning_key = "hybrid"
+        sd_config.model.params.unet_config.params.in_channels = 9
+
+        # Create a "fake" config with a different name so that we know to unload it when switching models.
+        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
+
+        do_inpainting_hijack()
+
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
@@ -222,6 +238,9 @@ def load_model():
     sd_hijack.model_hijack.hijack(sd_model)
 
     sd_model.eval()
+    shared.sd_model = sd_model
+
+    script_callbacks.model_loaded_callback(sd_model)
 
     print(f"Model loaded.")
     return sd_model
@@ -234,9 +253,9 @@ def reload_model_weights(sd_model, info=None):
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
 
-    if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
+    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
         checkpoints_loaded.clear()
-        shared.sd_model = load_model()
+        load_model(checkpoint_info)
         return shared.sd_model
 
     if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -249,6 +268,7 @@ def reload_model_weights(sd_model, info=None):
     load_model_weights(sd_model, checkpoint_info)
 
     sd_hijack.model_hijack.hijack(sd_model)
+    script_callbacks.model_loaded_callback(sd_model)
 
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
...
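The "fake" config trick relies on `CheckpointInfo` being a namedtuple (as it is elsewhere in this module), so `_replace` returns a modified copy rather than mutating shared state; a reduced illustration:

```python
# Reduced model of the _replace trick; the real CheckpointInfo has more fields.
from collections import namedtuple

CheckpointInfo = namedtuple("CheckpointInfo", ["filename", "config"])

info = CheckpointInfo("sd-v1-5-inpainting.ckpt", "v1-inference.yaml")
fake = info._replace(config=info.config.replace(".yaml", "-inpainting.yaml"))
print(fake.config)  # v1-inference-inpainting.yaml
```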
@@ -3,6 +3,7 @@ import datetime
 import json
 import os
 import sys
+from collections import OrderedDict
 import gradio as gr
 import tqdm
@@ -63,6 +64,7 @@ parser.add_argument("--port", type=int, help="launch gradio with given server po
 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
 parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
 parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
+parser.add_argument("--freeze-settings", action='store_true', help="disable editing settings", default=False)
 parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
 parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
 parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -79,6 +81,8 @@ parser.add_argument("--disable-safe-unpickle", action='store_true', help="disabl
 parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
 parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
 parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
+parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
+parser.add_argument("--browse-all-images", action='store_true', help="Allow browsing all images by Image Browser", default=False)
 cmd_opts = parser.parse_args()
 restricted_opts = [
@@ -163,13 +167,13 @@ def realesrgan_models_names():
 class OptionInfo:
-    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
+    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None):
         self.default = default
         self.label = label
         self.component = component
         self.component_args = component_args
         self.onchange = onchange
-        self.section = None
+        self.section = section
         self.refresh = refresh
@@ -250,7 +254,7 @@ options_templates.update(options_section(('system', "System"), {
 }))
 options_templates.update(options_section(('training', "Training"), {
-    "unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
+    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
@@ -292,6 +296,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
 options_templates.update(options_section(('ui', "User interface"), {
     "show_progressbar": OptionInfo(True, "Show progressbar"),
     "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -323,6 +328,15 @@ options_templates.update(options_section(('inspiration', "Inspiration"), {
     "inspiration_cols_num": OptionInfo(8, "Columns of inspiration interface frame", gr.Slider, {"minimum": 4, "maximum": 16, "step": 1}),
 }))
+options_templates.update(options_section(('images-history', "Images Browser"), {
+    #"images_history_reconstruct_directory": OptionInfo(False, "Reconstruct output directory structure.This can greatly improve the speed of loading , but will change the original output directory structure"),
+    "images_history_preload": OptionInfo(False, "Preload images at startup"),
+    "images_history_num_per_page": OptionInfo(36, "Number of pictures displayed on each page"),
+    "images_history_pages_num": OptionInfo(6, "Minimum number of pages per load "),
+    "images_history_grid_num": OptionInfo(6, "Number of grids in each row"),
+}))
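Each `options_section(...)` call above stamps its section identifier onto every `OptionInfo` it registers; together with the new `section=` constructor argument, that is what `Options.reorder()` below sorts on. A rough sketch of the helper's shape (simplified; the real one lives alongside `OptionInfo`):

```python
# Rough sketch, assuming the OptionInfo class shown above.
def options_section(section_identifier, options_dict):
    for v in options_dict.values():
        v.section = section_identifier  # e.g. ('ui', "User interface")
    return options_dict
```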
 class Options:
     data = None
     data_labels = options_templates
@@ -385,6 +399,20 @@ class Options:
         d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
         return json.dumps(d)
+    def add_option(self, key, info):
+        self.data_labels[key] = info
+
+    def reorder(self):
+        """reorder settings so that all items related to section always go together"""
+        section_ids = {}
+        settings_items = self.data_labels.items()
+        for k, item in settings_items:
+            if item.section not in section_ids:
+                section_ids[item.section] = len(section_ids)
+
+        self.data_labels = {k: v for k, v in sorted(settings_items, key=lambda x: section_ids[x[1].section])}
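A small worked example of the sort key in `reorder()`: sections are numbered by first appearance and Python's `sorted` is stable, so interleaved entries regroup without losing their relative order. Names below are hypothetical:

```python
class Item:
    def __init__(self, section):
        self.section = section

labels = {
    "a": Item(("ui", "User interface")),
    "b": Item(("training", "Training")),
    "c": Item(("ui", "User interface")),
}

# First-appearance order: ui -> 0, training -> 1.
section_ids = {}
for item in labels.values():
    if item.section not in section_ids:
        section_ids[item.section] = len(section_ids)

reordered = dict(sorted(labels.items(), key=lambda x: section_ids[x[1].section]))
print(list(reordered))  # ['a', 'c', 'b'] -- the two 'ui' items regroup
```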
 opts = Options()
 if os.path.exists(config_filename):
@@ -394,6 +422,8 @@ sd_upscalers = []
 sd_model = None
+clip_model = None
 progress_print_out = sys.stdout
@@ -83,7 +83,7 @@ class PersonalizedBase(Dataset):
             self.dataset.append(entry)
-        assert len(self.dataset) > 1, "No images have been found in the dataset."
+        assert len(self.dataset) > 0, "No images have been found in the dataset."
         self.length = len(self.dataset) * repeats // batch_size
         self.initial_indexes = np.arange(len(self.dataset))
@@ -91,7 +91,7 @@ class PersonalizedBase(Dataset):
         self.shuffle()
     def shuffle(self):
-        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+        self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
     def create_text(self, filename_text):
         text = random.choice(self.lines)
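The `.numpy()` added to `shuffle()` converts the torch permutation into a NumPy index array before it is used to fancy-index `initial_indexes`, keeping the indexing entirely within NumPy. A quick sketch of the fixed pattern:

```python
import numpy as np
import torch

initial_indexes = np.arange(6)

# Convert the torch permutation to NumPy before indexing the NumPy array,
# as the fix above does, instead of indexing with a torch.Tensor directly.
perm = torch.randperm(initial_indexes.shape[0]).numpy()
shuffled = initial_indexes[perm]
print(sorted(shuffled.tolist()) == list(range(6)))  # True: a permutation
```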
@@ -5,6 +5,7 @@ import zlib
 from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
 from fonts.ttf import Roboto
 import torch
+from modules.shared import opts
 class EmbeddingEncoder(json.JSONEncoder):
@@ -133,7 +134,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
     from math import cos
     image = srcimage.copy()
+    fontsize = 32
     if textfont is None:
         try:
             textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
@@ -150,7 +151,7 @@ def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, t
     image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
     draw = ImageDraw.Draw(image)
-    fontsize = 32
     font = ImageFont.truetype(textfont, fontsize)
     padding = 10
 import os
 from PIL import Image, ImageOps
+import math
 import platform
 import sys
 import tqdm
@@ -11,7 +12,7 @@ if cmd_opts.deepdanbooru:
     import modules.deepbooru as deepbooru
-def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
     try:
         if process_caption:
             shared.interrogator.load()
@@ -21,7 +22,7 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
             db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
             deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
-        preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
+        preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru, split_threshold, overlap_ratio)
     finally:
@@ -33,11 +34,13 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
-def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
+def preprocess_work(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2):
     width = process_width
     height = process_height
     src = os.path.abspath(process_src)
     dst = os.path.abspath(process_dst)
+    split_threshold = max(0.0, min(1.0, split_threshold))
+    overlap_ratio = max(0.0, min(0.9, overlap_ratio))
     assert src != dst, 'same directory specified as source and destination'
@@ -48,7 +51,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
     shared.state.textinfo = "Preprocessing..."
     shared.state.job_count = len(files)
-    def save_pic_with_caption(image, index):
+    def save_pic_with_caption(image, index, existing_caption=None):
         caption = ""
         if process_caption:
@@ -66,17 +69,49 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
         basename = f"{index:05}-{subindex[0]}-{filename_part}"
         image.save(os.path.join(dst, f"{basename}.png"))
+        if preprocess_txt_action == 'prepend' and existing_caption:
+            caption = existing_caption + ' ' + caption
+        elif preprocess_txt_action == 'append' and existing_caption:
+            caption = caption + ' ' + existing_caption
+        elif preprocess_txt_action == 'copy' and existing_caption:
+            caption = existing_caption
+        caption = caption.strip()
         if len(caption) > 0:
             with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
                 file.write(caption)
         subindex[0] += 1
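The new `preprocess_txt_action` branch merges a pre-existing `.txt` caption with the freshly generated one. The three modes in isolation (standalone sketch with hypothetical inputs; any other action value keeps only the generated caption):

```python
def merge_caption(action, generated, existing):
    # Mirrors the prepend/append/copy branch above.
    if action == 'prepend' and existing:
        generated = existing + ' ' + generated
    elif action == 'append' and existing:
        generated = generated + ' ' + existing
    elif action == 'copy' and existing:
        generated = existing
    return generated.strip()

print(merge_caption('prepend', "a photo of a cat", "tabby"))  # tabby a photo of a cat
```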
-    def save_pic(image, index):
-        save_pic_with_caption(image, index)
+    def save_pic(image, index, existing_caption=None):
+        save_pic_with_caption(image, index, existing_caption=existing_caption)
         if process_flip:
-            save_pic_with_caption(ImageOps.mirror(image), index)
+            save_pic_with_caption(ImageOps.mirror(image), index, existing_caption=existing_caption)
+
+    def split_pic(image, inverse_xy):
+        if inverse_xy:
+            from_w, from_h = image.height, image.width
+            to_w, to_h = height, width
+        else:
+            from_w, from_h = image.width, image.height
+            to_w, to_h = width, height
+        h = from_h * to_w // from_w
+        if inverse_xy:
+            image = image.resize((h, to_w))
+        else:
+            image = image.resize((to_w, h))
+        split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
+        y_step = (h - to_h) / (split_count - 1)
+        for i in range(split_count):
+            y = int(y_step * i)
+            if inverse_xy:
+                splitted = image.crop((y, 0, y + to_h, to_w))
+            else:
+                splitted = image.crop((0, y, to_w, y + to_h))
+            yield splitted
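The `split_count` formula in `split_pic` sizes the number of crops so that consecutive crops overlap by at least `overlap_ratio` of the crop height. A worked example with concrete (hypothetical) numbers:

```python
import math

h, to_h, overlap_ratio = 1000, 512, 0.2  # resized height, crop size, min overlap

split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
y_step = (h - to_h) / (split_count - 1)
offsets = [int(y_step * i) for i in range(split_count)]
print(split_count, offsets)  # 3 [0, 244, 488]

# Actual overlap between consecutive 512px crops stays above the 0.2 floor:
print(min((to_h - (offsets[i + 1] - offsets[i])) / to_h
          for i in range(split_count - 1)))  # 0.5234375
```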
     for index, imagefile in enumerate(tqdm.tqdm(files)):
         subindex = [0]
@@ -86,31 +121,27 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pro
         except Exception:
             continue
+        existing_caption = None
+        existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
+        if os.path.exists(existing_caption_filename):
+            with open(existing_caption_filename, 'r', encoding="utf8") as file:
+                existing_caption = file.read()
         if shared.state.interrupted:
             break
-        ratio = img.height / img.width
-        is_tall = ratio > 1.35
-        is_wide = ratio < 1 / 1.35
-        if process_split and is_tall:
-            img = img.resize((width, height * img.height // img.width))
-            top = img.crop((0, 0, width, height))
-            save_pic(top, index)
-            bot = img.crop((0, img.height - height, width, img.height))
-            save_pic(bot, index)
-        elif process_split and is_wide:
-            img = img.resize((width * img.width // img.height, height))
-            left = img.crop((0, 0, width, height))
-            save_pic(left, index)
-            right = img.crop((img.width - width, 0, img.width, height))
-            save_pic(right, index)
+        if img.height > img.width:
+            ratio = (img.width * height) / (img.height * width)
+            inverse_xy = False
+        else:
+            ratio = (img.height * width) / (img.width * height)
+            inverse_xy = True
+        if process_split and ratio < 1.0 and ratio <= split_threshold:
+            for splitted in split_pic(img, inverse_xy):
+                save_pic(splitted, index, existing_caption=existing_caption)
         else:
             img = images.resize_image(1, img, width, height)
-            save_pic(img, index)
+            save_pic(img, index, existing_caption=existing_caption)
         shared.state.nextjob()
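For the orientation test above: `ratio` is the image's aspect normalized by the target aspect (below 1 when the image is more elongated than the target), and `inverse_xy` records whether the splits should run along x rather than y. A quick check with hypothetical sizes:

```python
# A 512x1024 portrait source with 512x512 target crops.
img_w, img_h, width, height = 512, 1024, 512, 512

if img_h > img_w:
    ratio = (img_w * height) / (img_h * width)   # 0.5
    inverse_xy = False                           # portrait: split along y
else:
    ratio = (img_h * width) / (img_w * height)
    inverse_xy = True                            # landscape: split along x

split_threshold = 0.5
print(ratio, ratio < 1.0 and ratio <= split_threshold)  # 0.5 True -> will split
```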
@@ -153,7 +153,7 @@ class EmbeddingDatabase:
         return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
@@ -165,7 +165,8 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
-    assert not os.path.exists(fn), f"file {fn} already exists"
+    if not overwrite_old:
+        assert not os.path.exists(fn), f"file {fn} already exists"
     embedding = Embedding(vec, name)
     embedding.step = 0
@@ -275,6 +276,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         loss.backward()
         optimizer.step()
         epoch_num = embedding.step // len(ds)
         epoch_step = embedding.step - (epoch_num * len(ds)) + 1
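The epoch bookkeeping in the last hunk is plain integer arithmetic over the dataset length, for instance with a hypothetical 10-image dataset:

```python
len_ds = 10  # hypothetical dataset length
step = 23    # global training step

epoch_num = step // len_ds                    # 2 (zero-based epoch)
epoch_step = step - (epoch_num * len_ds) + 1  # 4 (one-based step within epoch)
print(epoch_num, epoch_step)
```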
@@ -7,8 +7,8 @@ import modules.textual_inversion.preprocess
 from modules import sd_hijack, shared
-def create_embedding(name, initialization_text, nvpt):
-    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+def create_embedding(name, initialization_text, nvpt, overwrite_old):
+    filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text)
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
 import modules.scripts
-from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
+from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
+    StableDiffusionProcessingImg2Img, process_images
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
 import modules.processing as processing
@@ -35,6 +36,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
         firstphase_height=firstphase_height if enable_hr else None,
     )
+    p.scripts = modules.scripts.scripts_txt2img
+    p.script_args = args
     if cmd_opts.enable_console_prompts:
         print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
@@ -53,4 +57,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
     processed.images = []
     return processed.images, generation_info_js, plaintext_to_html(processed.info)
@@ -34,6 +34,9 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
         sigma_in = torch.cat([sigmas[i] * s_in] * 2)
         cond_in = torch.cat([uncond, cond])
+        image_conditioning = torch.cat([p.image_conditioning] * 2)
+        cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
         c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
         t = dnw.sigma_to_t(sigma_in)
@@ -78,6 +81,9 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
         sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
         cond_in = torch.cat([uncond, cond])
+        image_conditioning = torch.cat([p.image_conditioning] * 2)
+        cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
         c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
         if i == 1:
@@ -194,7 +200,7 @@ class Script(scripts.Script):
             p.seed = p.seed + 1
-            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning)
+            return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
         p.sample = sample_extra
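The added `c_concat`/`c_crossattn` dictionary is the hybrid-conditioning format the inpainting UNet expects: image latents concatenated onto the input channels, text embeddings fed through cross-attention. A shape-level sketch (tensor sizes are illustrative):

```python
import torch

cond_in = torch.randn(2, 77, 768)               # text embeddings (uncond + cond)
image_conditioning = torch.randn(2, 5, 64, 64)  # masked-image latents + mask

# c_concat stacks onto the UNet input channels (4 noise + 5 here = the
# in_channels of 9 configured for inpainting models earlier in this commit);
# c_crossattn feeds the attention layers.
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
print(cond["c_concat"][0].shape, cond["c_crossattn"][0].shape)
```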
@@ -172,54 +172,54 @@ class Script(scripts.Script):
         if down > 0:
             down = target_h - init_img.height - up
-        init_image = p.init_images[0]
-        state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
-        def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
+        def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
             is_horiz = is_left or is_right
             is_vert = is_top or is_bottom
             pixels_horiz = expand_pixels if is_horiz else 0
             pixels_vert = expand_pixels if is_vert else 0
-            res_w = init.width + pixels_horiz
-            res_h = init.height + pixels_vert
-            process_res_w = math.ceil(res_w / 64) * 64
-            process_res_h = math.ceil(res_h / 64) * 64
-            img = Image.new("RGB", (process_res_w, process_res_h))
-            img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
-            mask = Image.new("RGB", (process_res_w, process_res_h), "white")
-            draw = ImageDraw.Draw(mask)
-            draw.rectangle((
-                expand_pixels + mask_blur if is_left else 0,
-                expand_pixels + mask_blur if is_top else 0,
-                mask.width - expand_pixels - mask_blur if is_right else res_w,
-                mask.height - expand_pixels - mask_blur if is_bottom else res_h,
-            ), fill="black")
-            np_image = (np.asarray(img) / 255.0).astype(np.float64)
-            np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
-            noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
-            out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
-            target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
-            target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
-            crop_region = (
-                0 if is_left else out.width - target_width,
-                0 if is_top else out.height - target_height,
-                target_width if is_left else out.width,
-                target_height if is_top else out.height,
-            )
-            image_to_process = out.crop(crop_region)
-            mask = mask.crop(crop_region)
-            p.width = target_width if is_horiz else img.width
-            p.height = target_height if is_vert else img.height
-            p.init_images = [image_to_process]
-            p.image_mask = mask
+            images_to_process = []
+            output_images = []
+            for n in range(count):
+                res_w = init[n].width + pixels_horiz
+                res_h = init[n].height + pixels_vert
+                process_res_w = math.ceil(res_w / 64) * 64
+                process_res_h = math.ceil(res_h / 64) * 64
+                img = Image.new("RGB", (process_res_w, process_res_h))
+                img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
+                mask = Image.new("RGB", (process_res_w, process_res_h), "white")
+                draw = ImageDraw.Draw(mask)
+                draw.rectangle((
+                    expand_pixels + mask_blur if is_left else 0,
+                    expand_pixels + mask_blur if is_top else 0,
+                    mask.width - expand_pixels - mask_blur if is_right else res_w,
+                    mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+                ), fill="black")
+                np_image = (np.asarray(img) / 255.0).astype(np.float64)
+                np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
+                noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
+                output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
+                target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
+                target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
+                p.width = target_width if is_horiz else img.width
+                p.height = target_height if is_vert else img.height
+                crop_region = (
+                    0 if is_left else output_images[n].width - target_width,
+                    0 if is_top else output_images[n].height - target_height,
+                    target_width if is_left else output_images[n].width,
+                    target_height if is_top else output_images[n].height,
+                )
+                mask = mask.crop(crop_region)
+                p.image_mask = mask
+                image_to_process = output_images[n].crop(crop_region)
+                images_to_process.append(image_to_process)
+            p.init_images = images_to_process
             latent_mask = Image.new("RGB", (p.width, p.height), "white")
             draw = ImageDraw.Draw(latent_mask)
@@ -232,31 +232,52 @@ class Script(scripts.Script):
             p.latent_mask = latent_mask
             proc = process_images(p)
-            proc_img = proc.images[0]
             if initial_seed_and_info[0] is None:
                 initial_seed_and_info[0] = proc.seed
                 initial_seed_and_info[1] = proc.info
-            out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
-            out = out.crop((0, 0, res_w, res_h))
-            return out
-        img = init_image
-        if left > 0:
-            img = expand(img, left, is_left=True)
-        if right > 0:
-            img = expand(img, right, is_right=True)
-        if up > 0:
-            img = expand(img, up, is_top=True)
-        if down > 0:
-            img = expand(img, down, is_bottom=True)
-        res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
+            for n in range(count):
+                output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
+                output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
+            return output_images
+        batch_count = p.n_iter
+        batch_size = p.batch_size
+        p.n_iter = 1
+        state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
+        all_processed_images = []
+        for i in range(batch_count):
+            imgs = [init_img] * batch_size
+            state.job = f"Batch {i + 1} out of {batch_count}"
+            if left > 0:
+                imgs = expand(imgs, batch_size, left, is_left=True)
+            if right > 0:
+                imgs = expand(imgs, batch_size, right, is_right=True)
+            if up > 0:
+                imgs = expand(imgs, batch_size, up, is_top=True)
+            if down > 0:
+                imgs = expand(imgs, batch_size, down, is_bottom=True)
+            all_processed_images += imgs
+        all_images = all_processed_images
+        combined_grid_image = images.image_grid(all_processed_images)
+        unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
+        if opts.return_grid and not unwanted_grid_because_of_img_count:
+            all_images = [combined_grid_image] + all_processed_images
+        res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
         if opts.samples_save:
-            images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
-        return res
+            for img in all_processed_images:
+                images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+        if opts.grid_save and not unwanted_grid_because_of_img_count:
+            images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
+        return res
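With the rework above, `expand` takes a list of `count` images and runs one batched `process_images` call per direction instead of one call per image, and the tail assembles an optional grid preview. The grid-assembly idea in isolation (PIL only; sizes and the 2x2 layout are hypothetical, and the real `images.image_grid` picks its own layout):

```python
from PIL import Image

processed = [Image.new("RGB", (128, 128), c) for c in ("red", "green", "blue", "white")]

cols = rows = 2
grid = Image.new("RGB", (cols * 128, rows * 128))
for i, img in enumerate(processed):
    grid.paste(img, ((i % cols) * 128, (i // cols) * 128))

all_images = [grid] + processed  # what return_grid prepends for the UI
print(len(all_images), grid.size)  # 5 (256, 256)
```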
@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
     if info is None:
         raise RuntimeError(f"Unknown checkpoint: {x}")
     modules.sd_models.reload_model_weights(shared.sd_model, info)
+    p.sd_model = shared.sd_model
 def confirm_checkpoints(p, xs):
@@ -71,8 +71,9 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
     return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
 def initialize():
-    modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
+    modules.scripts.load_scripts()
     if cmd_opts.ui_debug_mode:
         class enmpty():
             name = None
@@ -85,9 +86,7 @@ def initialize():
     shared.face_restorers.append(modules.face_restoration.FaceRestoration())
     modelloader.load_upscalers()
-    shared.sd_model = modules.sd_models.load_model()
+    modules.sd_models.load_model()
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
@@ -124,7 +123,8 @@ def api_only():
     api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-def webui(launch_api=False):
+def webui():
+    launch_api = cmd_opts.api
     initialize()
     while 1:
@@ -150,7 +150,7 @@ def webui(launch_api=False):
         sd_samplers.set_samplers()
         print('Reloading Custom Scripts')
-        modules.scripts.reload_scripts(os.path.join(script_path, "scripts"))
+        modules.scripts.reload_scripts()
         print('Reloading modules: modules.ui')
         importlib.reload(modules.ui)
         print('Refreshing Model List')
@@ -164,4 +164,4 @@ if __name__ == "__main__":
     if cmd_opts.nowebui:
         api_only()
     else:
-        webui(cmd_opts.api)
+        webui()