Commit 07d1bd42 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into roy.add_simple_interrogate_api

parents 3f3d14af 6e4de5b4
* @AUTOMATIC1111 * @AUTOMATIC1111
/localizations/ar_AR.json @xmodar @blackneoo
/localizations/de_DE.json @LunixWasTaken # if you were managing a localization and were removed from this file, this is because
/localizations/es_ES.json @innovaciones # the intended way to do localizations now is via extensions. See:
/localizations/fr_FR.json @tumbly # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions
/localizations/it_IT.json @EugenioBuffo # Make a repo with your localization and since you are still listed as a collaborator
/localizations/ja_JP.json @yuuki76 # you can add it to the wiki page yourself. This change is because some people complained
/localizations/ko_KR.json @36DB # the git commit log is cluttered with things unrelated to almost everyone and
/localizations/pt_BR.json @M-art-ucci # because I believe this is the best overall for the project to handle localizations almost
/localizations/ru_RU.json @kabachuha # entirely without my oversight.
/localizations/tr_TR.json @camenduru
/localizations/zh_CN.json @dtlnor @bgluminous
/localizations/zh_TW.json @benlisquare
...@@ -155,14 +155,15 @@ The documentation was moved from this README over to the project's [wiki](https: ...@@ -155,14 +155,15 @@ The documentation was moved from this README over to the project's [wiki](https:
- Swin2SR - https://github.com/mv-lab/swin2sr - Swin2SR - https://github.com/mv-lab/swin2sr
- LDSR - https://github.com/Hafiidz/latent-diffusion - LDSR - https://github.com/Hafiidz/latent-diffusion
- Ideas for optimizations - https://github.com/basujindal/stable-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion
- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. - Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing.
- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) - Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion)
- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). - Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas).
- Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
- Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch
- xformers - https://github.com/facebookresearch/xformers - xformers - https://github.com/facebookresearch/xformers
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- Security advice - RyotaK
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You) - (You)
addEventListener('keydown', (event) => { addEventListener('keydown', (event) => {
let target = event.originalTarget || event.composedPath()[0]; let target = event.originalTarget || event.composedPath()[0];
if (!target.hasAttribute("placeholder")) return; if (!target.matches("#toprow textarea.gr-text-input[placeholder]")) return;
if (!target.placeholder.toLowerCase().includes("prompt")) return;
if (! (event.metaKey || event.ctrlKey)) return; if (! (event.metaKey || event.ctrlKey)) return;
......
...@@ -22,3 +22,14 @@ function extensions_check(){ ...@@ -22,3 +22,14 @@ function extensions_check(){
return [] return []
} }
function install_extension_from_index(button, url){
button.disabled = "disabled"
button.value = "Installing..."
textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
textarea.dispatchEvent(new Event("input", { bubbles: true }))
gradioApp().querySelector('#install_extension_button').click()
}
...@@ -3,8 +3,21 @@ global_progressbars = {} ...@@ -3,8 +3,21 @@ global_progressbars = {}
galleries = {} galleries = {}
galleryObservers = {} galleryObservers = {}
// this tracks laumnches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running
timeoutIds = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar) // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id
// every time you use gr.HTML(elem_id='xxx'), so we handle this here
var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar)
var progressbarParent
if(progressbar){
progressbarParent = gradioApp().querySelector("#"+id_progressbar)
} else{
progressbar = gradioApp().getElementById(id_progressbar)
progressbarParent = null
}
var skip = id_skip ? gradioApp().getElementById(id_skip) : null var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt) var interrupt = gradioApp().getElementById(id_interrupt)
...@@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -26,18 +39,26 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
global_progressbars[id_progressbar] = progressbar global_progressbars[id_progressbar] = progressbar
var mutationObserver = new MutationObserver(function(m){ var mutationObserver = new MutationObserver(function(m){
if(timeoutIds[id_part]) return;
preview = gradioApp().getElementById(id_preview) preview = gradioApp().getElementById(id_preview)
gallery = gradioApp().getElementById(id_gallery) gallery = gradioApp().getElementById(id_gallery)
if(preview != null && gallery != null){ if(preview != null && gallery != null){
preview.style.width = gallery.clientWidth + "px" preview.style.width = gallery.clientWidth + "px"
preview.style.height = gallery.clientHeight + "px" preview.style.height = gallery.clientHeight + "px"
if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px"
//only watch gallery if there is a generation process going on //only watch gallery if there is a generation process going on
check_gallery(id_gallery); check_gallery(id_gallery);
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){ if(progressDiv){
timeoutIds[id_part] = window.setTimeout(function() {
timeoutIds[id_part] = null
requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt)
}, 500)
} else{
if (skip) { if (skip) {
skip.style.display = "none" skip.style.display = "none"
} }
...@@ -49,11 +70,8 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip ...@@ -49,11 +70,8 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip
galleries[id_gallery] = null; galleries[id_gallery] = null;
} }
} }
} }
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
}); });
mutationObserver.observe( progressbar, { childList:true, subtree:true }) mutationObserver.observe( progressbar, { childList:true, subtree:true })
} }
......
...@@ -208,4 +208,6 @@ function update_token_counter(button_id) { ...@@ -208,4 +208,6 @@ function update_token_counter(button_id) {
function restart_reload(){ function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>'; document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000) setTimeout(function(){location.reload()},2000)
return []
} }
...@@ -7,6 +7,7 @@ import shlex ...@@ -7,6 +7,7 @@ import shlex
import platform import platform
dir_repos = "repositories" dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
...@@ -16,11 +17,11 @@ def extract_arg(args, name): ...@@ -16,11 +17,11 @@ def extract_arg(args, name):
return [x for x in args if x != name], name in args return [x for x in args if x != name], name in args
def run(command, desc=None, errdesc=None): def run(command, desc=None, errdesc=None, custom_env=None):
if desc is not None: if desc is not None:
print(desc) print(desc)
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0: if result.returncode != 0:
...@@ -101,7 +102,25 @@ def version_check(commit): ...@@ -101,7 +102,25 @@ def version_check(commit):
else: else:
print("Not a git clone, can't perform version check.") print("Not a git clone, can't perform version check.")
except Exception as e: except Exception as e:
print("versipm check failed",e) print("version check failed", e)
def run_extensions_installers():
if not os.path.isdir(dir_extensions):
return
for dirname_extension in os.listdir(dir_extensions):
path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
if not os.path.isfile(path_installer):
continue
try:
env = os.environ.copy()
env['PYTHONPATH'] = os.path.abspath(".")
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}", custom_env=env))
except Exception as e:
print(e, file=sys.stderr)
def prepare_enviroment(): def prepare_enviroment():
...@@ -123,7 +142,7 @@ def prepare_enviroment(): ...@@ -123,7 +142,7 @@ def prepare_enviroment():
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "60e5042ca0da89c14d1dd59d73883280f8fce991")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
...@@ -189,6 +208,8 @@ def prepare_enviroment(): ...@@ -189,6 +208,8 @@ def prepare_enviroment():
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
run_extensions_installers()
if update_check: if update_check:
version_check(commit) version_check(commit)
...@@ -217,12 +238,15 @@ def tests(argv): ...@@ -217,12 +238,15 @@ def tests(argv):
proc.kill() proc.kill()
def start_webui(): def start():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
import webui import webui
if '--nowebui' in sys.argv:
webui.api_only()
else:
webui.webui() webui.webui()
if __name__ == "__main__": if __name__ == "__main__":
prepare_enviroment() prepare_enviroment()
start_webui() start()
...@@ -70,7 +70,7 @@ ...@@ -70,7 +70,7 @@
"None": "Nichts", "None": "Nichts",
"Prompt matrix": "Promptmatrix", "Prompt matrix": "Promptmatrix",
"Prompts from file or textbox": "Prompts aus Datei oder Textfeld", "Prompts from file or textbox": "Prompts aus Datei oder Textfeld",
"X/Y plot": "X/Y Graf", "X/Y plot": "X/Y Graph",
"Put variable parts at start of prompt": "Variable teile am start des Prompt setzen", "Put variable parts at start of prompt": "Variable teile am start des Prompt setzen",
"Iterate seed every line": "Iterate seed every line", "Iterate seed every line": "Iterate seed every line",
"List of prompt inputs": "List of prompt inputs", "List of prompt inputs": "List of prompt inputs",
......
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import base64
import io
import time import time
import uvicorn import uvicorn
from threading import Lock
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, FastAPI, HTTPException
import modules.shared as shared import modules.shared as shared
from modules import devices
from modules.api.models import * from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo from modules.extras import run_extras, run_pnginfo
from PIL import PngImagePlugin
from modules.sd_models import checkpoints_list
from modules.realesrgan_model import get_realesrgan_models
from typing import List
def upscaler_to_index(name: str): def upscaler_to_index(name: str):
try: try:
...@@ -29,8 +34,26 @@ def setUpscalers(req: dict): ...@@ -29,8 +34,26 @@ def setUpscalers(req: dict):
return reqDict return reqDict
def encode_pil_to_base64(image):
with io.BytesIO() as output_bytes:
# Copy any text-only metadata
use_metadata = False
metadata = PngImagePlugin.PngInfo()
for key, value in image.info.items():
if isinstance(key, str) and isinstance(value, str):
metadata.add_text(key, value)
use_metadata = True
image.save(
output_bytes, "PNG", pnginfo=(metadata if use_metadata else None)
)
bytes_data = output_bytes.getvalue()
return base64.b64encode(bytes_data)
class Api: class Api:
def __init__(self, app, queue_lock): def __init__(self, app: FastAPI, queue_lock: Lock):
self.router = APIRouter() self.router = APIRouter()
self.app = app self.app = app
self.queue_lock = queue_lock self.queue_lock = queue_lock
...@@ -41,6 +64,19 @@ class Api: ...@@ -41,6 +64,19 @@ class Api:
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/interrogate", self.interrogateapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/options", self.get_config, methods=["GET"], response_model=OptionsModel)
self.app.add_api_route("/sdapi/v1/options", self.set_config, methods=["POST"])
self.app.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=FlagsModel)
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
self.app.add_api_route("/sdapi/v1/prompt-styles", self.get_promp_styles, methods=["GET"], response_model=List[PromptStyleItem])
self.app.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
self.app.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
...@@ -171,6 +207,8 @@ class Api: ...@@ -171,6 +207,8 @@ class Api:
progress = min(progress, 1) progress = min(progress, 1)
shared.state.set_current_image()
current_image = None current_image = None
if shared.state.current_image and not req.skip_current_image: if shared.state.current_image and not req.skip_current_image:
current_image = encode_pil_to_base64(shared.state.current_image) current_image = encode_pil_to_base64(shared.state.current_image)
...@@ -190,6 +228,75 @@ class Api: ...@@ -190,6 +228,75 @@ class Api:
return InterrogateResponse(caption=processed) return InterrogateResponse(caption=processed)
def interruptapi(self):
shared.state.interrupt()
return {}
def get_config(self):
options = {}
for key in shared.opts.data.keys():
metadata = shared.opts.data_labels.get(key)
if(metadata is not None):
options.update({key: shared.opts.data.get(key, shared.opts.data_labels.get(key).default)})
else:
options.update({key: shared.opts.data.get(key, None)})
return options
def set_config(self, req: OptionsModel):
# currently req has all options fields even if you send a dict like { "send_seed": false }, which means it will
# overwrite all options with default values.
raise RuntimeError('Setting options via API is not supported')
reqDict = vars(req)
for o in reqDict:
setattr(shared.opts, o, reqDict[o])
shared.opts.save(shared.config_filename)
return
def get_cmd_flags(self):
return vars(shared.cmd_opts)
def get_samplers(self):
return [{"name":sampler[0], "aliases":sampler[2], "options":sampler[3]} for sampler in all_samplers]
def get_upscalers(self):
upscalers = []
for upscaler in shared.sd_upscalers:
u = upscaler.scaler
upscalers.append({"name":u.name, "model_name":u.model_name, "model_path":u.model_path, "model_url":u.model_url})
return upscalers
def get_sd_models(self):
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
def get_face_restorers(self):
return [{"name":x.name(), "cmd_dir": getattr(x, "cmd_dir", None)} for x in shared.face_restorers]
def get_realesrgan_models(self):
return [{"name":x.name,"path":x.data_path, "scale":x.scale} for x in get_realesrgan_models(None)]
def get_promp_styles(self):
styleList = []
for k in shared.prompt_styles.styles:
style = shared.prompt_styles.styles[k]
styleList.append({"name":style[0], "prompt": style[1], "negative_prompr": style[2]})
return styleList
def get_artists_categories(self):
return shared.artist_db.cats
def get_artists(self):
return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port) uvicorn.run(self.app, host=server_name, port=port)
import inspect import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model from pydantic import BaseModel, Field, create_model
from typing import Any, Optional from typing import Any, Optional
from typing_extensions import Literal from typing_extensions import Literal
from inflection import underscore from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
from modules.shared import sd_upscalers from modules.shared import sd_upscalers, opts, parser
from typing import Dict, List
API_NOT_ALLOWED = [ API_NOT_ALLOWED = [
"self", "self",
...@@ -110,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator( ...@@ -110,12 +110,12 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
).generate_model() ).generate_model()
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict parameters: dict
info: str info: str
class ImageToImageResponse(BaseModel): class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict parameters: dict
info: str info: str
...@@ -132,6 +132,7 @@ class ExtrasBaseRequest(BaseModel): ...@@ -132,6 +132,7 @@ class ExtrasBaseRequest(BaseModel):
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}") upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.") extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
upscale_first: bool = Field(default=False, title="Upscale first", description="Should the upscaler run before restoring faces?")
class ExtraBaseResponse(BaseModel): class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.") html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
...@@ -147,10 +148,10 @@ class FileData(BaseModel): ...@@ -147,10 +148,10 @@ class FileData(BaseModel):
name: str = Field(title="File name") name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest): class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings") imageList: List[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse): class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.") images: List[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel): class PNGInfoRequest(BaseModel):
image: str = Field(title="Image", description="The base64 encoded PNG image") image: str = Field(title="Image", description="The base64 encoded PNG image")
...@@ -172,3 +173,69 @@ class InterrogateRequest(BaseModel): ...@@ -172,3 +173,69 @@ class InterrogateRequest(BaseModel):
class InterrogateResponse(BaseModel): class InterrogateResponse(BaseModel):
caption: str = Field(default=None, title="Caption", description="The generated caption for the image.") caption: str = Field(default=None, title="Caption", description="The generated caption for the image.")
fields = {}
for key, value in opts.data.items():
metadata = opts.data_labels.get(key)
optType = opts.typemap.get(type(value), type(value))
if (metadata is not None):
fields.update({key: (Optional[optType], Field(
default=metadata.default ,description=metadata.label))})
else:
fields.update({key: (Optional[optType], Field())})
OptionsModel = create_model("Options", **fields)
flags = {}
_options = vars(parser)['_option_string_actions']
for key in _options:
if(_options[key].dest != 'help'):
flag = _options[key]
_type = str
if _options[key].default is not None: _type = type(_options[key].default)
flags.update({flag.dest: (_type,Field(default=flag.default, description=flag.help))})
FlagsModel = create_model("Flags", **flags)
class SamplerItem(BaseModel):
name: str = Field(title="Name")
aliases: List[str] = Field(title="Aliases")
options: Dict[str, str] = Field(title="Options")
class UpscalerItem(BaseModel):
name: str = Field(title="Name")
model_name: Optional[str] = Field(title="Model Name")
model_path: Optional[str] = Field(title="Path")
model_url: Optional[str] = Field(title="URL")
class SDModelItem(BaseModel):
title: str = Field(title="Title")
model_name: str = Field(title="Model Name")
hash: str = Field(title="Hash")
filename: str = Field(title="Filename")
config: str = Field(title="Config file")
class HypernetworkItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
class FaceRestorerItem(BaseModel):
name: str = Field(title="Name")
cmd_dir: Optional[str] = Field(title="Path")
class RealesrganItem(BaseModel):
name: str = Field(title="Name")
path: Optional[str] = Field(title="Path")
scale: Optional[int] = Field(title="Scale")
class PromptStyleItem(BaseModel):
name: str = Field(title="Name")
prompt: Optional[str] = Field(title="Prompt")
negative_prompt: Optional[str] = Field(title="Negative Prompt")
class ArtistItem(BaseModel):
name: str = Field(title="Name")
score: float = Field(title="Score")
category: str = Field(title="Category")
...@@ -50,6 +50,7 @@ def mod2normal(state_dict): ...@@ -50,6 +50,7 @@ def mod2normal(state_dict):
def resrgan2normal(state_dict, nb=23): def resrgan2normal(state_dict, nb=23):
# this code is copied from https://github.com/victorca25/iNNfer # this code is copied from https://github.com/victorca25/iNNfer
if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict:
re8x = 0
crt_net = {} crt_net = {}
items = [] items = []
for k, v in state_dict.items(): for k, v in state_dict.items():
...@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23): ...@@ -75,10 +76,18 @@ def resrgan2normal(state_dict, nb=23):
crt_net['model.3.bias'] = state_dict['conv_up1.bias'] crt_net['model.3.bias'] = state_dict['conv_up1.bias']
crt_net['model.6.weight'] = state_dict['conv_up2.weight'] crt_net['model.6.weight'] = state_dict['conv_up2.weight']
crt_net['model.6.bias'] = state_dict['conv_up2.bias'] crt_net['model.6.bias'] = state_dict['conv_up2.bias']
crt_net['model.8.weight'] = state_dict['conv_hr.weight']
crt_net['model.8.bias'] = state_dict['conv_hr.bias'] if 'conv_up3.weight' in state_dict:
crt_net['model.10.weight'] = state_dict['conv_last.weight'] # modification supporting: https://github.com/ai-forever/Real-ESRGAN/blob/main/RealESRGAN/rrdbnet_arch.py
crt_net['model.10.bias'] = state_dict['conv_last.bias'] re8x = 3
crt_net['model.9.weight'] = state_dict['conv_up3.weight']
crt_net['model.9.bias'] = state_dict['conv_up3.bias']
crt_net[f'model.{8+re8x}.weight'] = state_dict['conv_hr.weight']
crt_net[f'model.{8+re8x}.bias'] = state_dict['conv_hr.bias']
crt_net[f'model.{10+re8x}.weight'] = state_dict['conv_last.weight']
crt_net[f'model.{10+re8x}.bias'] = state_dict['conv_last.bias']
state_dict = crt_net state_dict = crt_net
return state_dict return state_dict
......
...@@ -34,8 +34,11 @@ class Extension: ...@@ -34,8 +34,11 @@ class Extension:
if repo is None or repo.bare: if repo is None or repo.bare:
self.remote = None self.remote = None
else: else:
try:
self.remote = next(repo.remote().urls, None) self.remote = next(repo.remote().urls, None)
self.status = 'unknown' self.status = 'unknown'
except Exception:
self.remote = None
def list_files(self, subdir, extension): def list_files(self, subdir, extension):
from modules import scripts from modules import scripts
...@@ -46,7 +49,7 @@ class Extension: ...@@ -46,7 +49,7 @@ class Extension:
res = [] res = []
for filename in sorted(os.listdir(dirpath)): for filename in sorted(os.listdir(dirpath)):
res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename))) res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
......
...@@ -136,10 +136,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -136,10 +136,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None blended_result: Image.Image = None
image_hash: str = hash(np.array(image.getdata()).tobytes())
for upscaler in params: for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
upscaling_resize_w, upscaling_resize_h, upscaling_crop) upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()), cache_key = LruCache.Key(image_hash=image_hash,
info_hash=hash(info), info_hash=hash(info),
args_hash=hash(upscale_args)) args_hash=hash(upscale_args))
cached_entry = cached_images.get(cache_key) cached_entry = cached_images.get(cache_key)
......
This diff is collapsed.
...@@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared ...@@ -9,7 +9,7 @@ from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
not_available = ["hardswish", "multiheadattention"] not_available = ["hardswish", "multiheadattention"]
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) keys = list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name. # Remove illegal characters from name.
......
...@@ -510,6 +510,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -510,6 +510,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if extension.lower() == '.png': if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo() pnginfo_data = PngImagePlugin.PngInfo()
if opts.enable_pnginfo:
for k, v in params.pnginfo.items(): for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v)) pnginfo_data.add_text(k, str(v))
......
...@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
filename = f"{left}-{n}{right}" filename = f"{left}-{n}{right}"
if not save_normally: if not save_normally:
os.makedirs(output_dir, exist_ok=True)
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
...@@ -80,6 +81,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -80,6 +81,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
mask = None mask = None
# Use the EXIF orientation of photos taken by smartphones. # Use the EXIF orientation of photos taken by smartphones.
if image is not None:
image = ImageOps.exif_transpose(image) image = ImageOps.exif_transpose(image)
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]' assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
...@@ -136,6 +138,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -136,6 +138,8 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
if processed is None: if processed is None:
processed = process_images(p) processed = process_images(p)
p.close()
shared.total_tqdm.clear() shared.total_tqdm.clear()
generation_info_js = processed.js() generation_info_js = processed.js()
......
...@@ -56,9 +56,9 @@ class InterrogateModels: ...@@ -56,9 +56,9 @@ class InterrogateModels:
import clip import clip
if self.running_on_cpu: if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu") model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else: else:
model, preprocess = clip.load(clip_model_name) model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
model.eval() model.eval()
model = model.to(devices.device_interrogate) model = model.to(devices.device_interrogate)
......
...@@ -101,8 +101,8 @@ class LDSR: ...@@ -101,8 +101,8 @@ class LDSR:
down_sample_rate = target_scale / 4 down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate wd = width_og * down_sample_rate
hd = height_og * down_sample_rate hd = height_og * down_sample_rate
width_downsampled_pre = int(wd) width_downsampled_pre = int(np.ceil(wd))
height_downsampled_pre = int(hd) height_downsampled_pre = int(np.ceil(hd))
if down_sample_rate != 1: if down_sample_rate != 1:
print( print(
...@@ -110,7 +110,12 @@ class LDSR: ...@@ -110,7 +110,12 @@ class LDSR:
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS) im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else: else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)") print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
logs = self.run(model["model"], im_og, diffusion_steps, eta)
# pad width and height to multiples of 64, pads with the edge values of image to avoid artifacts
pad_w, pad_h = np.max(((2, 2), np.ceil(np.array(im_og.size) / 64).astype(int)), axis=0) * 64 - im_og.size
im_padded = Image.fromarray(np.pad(np.array(im_og), ((0, pad_h), (0, pad_w), (0, 0)), mode='edge'))
logs = self.run(model["model"], im_padded, diffusion_steps, eta)
sample = logs["sample"] sample = logs["sample"]
sample = sample.detach().cpu() sample = sample.detach().cpu()
...@@ -120,6 +125,9 @@ class LDSR: ...@@ -120,6 +125,9 @@ class LDSR:
sample = np.transpose(sample, (0, 2, 3, 1)) sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0]) a = Image.fromarray(sample[0])
# remove padding
a = a.crop((0, 0) + tuple(np.array(im_og.size) * 4))
del model del model
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
......
...@@ -3,6 +3,7 @@ import os ...@@ -3,6 +3,7 @@ import os
import sys import sys
import traceback import traceback
localizations = {} localizations = {}
...@@ -16,6 +17,11 @@ def list_localizations(dirname): ...@@ -16,6 +17,11 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file) localizations[fn] = os.path.join(dirname, file)
from modules import scripts
for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename)
localizations[fn] = file.path
def localization_js(current_localization_name): def localization_js(current_localization_name):
fn = localizations.get(current_localization_name, None) fn = localizations.get(current_localization_name, None)
......
...@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook; # see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods # useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z): first_stage_model = sd_model.first_stage_model
send_me_to_gpu(self, None) first_stage_model_encode = sd_model.first_stage_model.encode
return decoder(z) first_stage_model_decode = sd_model.first_stage_model.decode
def first_stage_model_encode_wrap(x):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_encode(x)
def first_stage_model_decode_wrap(z):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then # remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU. # send the model to GPU. Then put modules back. the modules will be in CPU.
...@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models # register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x) sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z) sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram: if use_medvram:
......
...@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w ...@@ -49,7 +49,7 @@ def expand_crop_region(crop_region, processing_width, processing_height, image_w
ratio_processing = processing_width / processing_height ratio_processing = processing_width / processing_height
if ratio_crop_region > ratio_processing: if ratio_crop_region > ratio_processing:
desired_height = (x2 - x1) * ratio_processing desired_height = (x2 - x1) / ratio_processing
desired_height_diff = int(desired_height - (y2-y1)) desired_height_diff = int(desired_height - (y2-y1))
y1 -= desired_height_diff//2 y1 -= desired_height_diff//2
y2 += desired_height_diff - desired_height_diff//2 y2 += desired_height_diff - desired_height_diff//2
......
...@@ -85,6 +85,9 @@ def cleanup_models(): ...@@ -85,6 +85,9 @@ def cleanup_models():
src_path = os.path.join(root_path, "ESRGAN") src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN") dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path) move_files(src_path, dest_path)
src_path = os.path.join(models_path, "BSRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path, ".pth")
src_path = os.path.join(root_path, "gfpgan") src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN") dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path) move_files(src_path, dest_path)
......
...@@ -134,11 +134,7 @@ class StableDiffusionProcessing(): ...@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model. # Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call. # Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return torch.zeros( return x.new_zeros(x.shape[0], 5, 1, 1)
x.shape[0], 5, 1, 1,
dtype=x.dtype,
device=x.device
)
height = height or self.height height = height or self.height
width = width or self.width width = width or self.width
...@@ -156,11 +152,7 @@ class StableDiffusionProcessing(): ...@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None): def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}: if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model. # Dummy zero conditioning if we're not using inpainting model.
return torch.zeros( return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
latent_image.shape[0], 5, 1, 1,
dtype=latent_image.dtype,
device=latent_image.device
)
# Handle the different mask inputs # Handle the different mask inputs
if image_mask is not None: if image_mask is not None:
...@@ -174,11 +166,11 @@ class StableDiffusionProcessing(): ...@@ -174,11 +166,11 @@ class StableDiffusionProcessing():
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask) conditioning_mask = torch.round(conditioning_mask)
else: else:
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:]) conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input. # Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter. # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
conditioning_mask = conditioning_mask.to(source_image.device) conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_image = torch.lerp( conditioning_image = torch.lerp(
source_image, source_image,
source_image * (1.0 - conditioning_mask), source_image * (1.0 - conditioning_mask),
...@@ -199,9 +191,13 @@ class StableDiffusionProcessing(): ...@@ -199,9 +191,13 @@ class StableDiffusionProcessing():
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
pass pass
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError() raise NotImplementedError()
def close(self):
self.sd_model = None
self.sampler = None
class Processed: class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None): def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None):
...@@ -422,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -422,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try: try:
for k, v in p.override_settings.items(): for k, v in p.override_settings.items():
opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p) res = process_images_inner(p)
finally: finally:
for k, v in stored_opts.items(): for k, v in stored_opts.items():
opts.data[k] = v setattr(opts, k, v)
return res return res
...@@ -505,6 +501,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -505,6 +501,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0: if len(prompts) == 0:
break break
if p.scripts is not None:
p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
with devices.autocast(): with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
...@@ -517,7 +516,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -517,7 +516,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}" shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.autocast(): with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength) samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
samples_ddim = samples_ddim.to(devices.dtype_vae) samples_ddim = samples_ddim.to(devices.dtype_vae)
x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim) x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
...@@ -645,7 +644,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -645,7 +644,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr: if not self.enable_hr:
...@@ -658,9 +657,28 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -658,9 +657,28 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
"""saves image before applying hires fix, if enabled in options; takes as an arguyment either an image or batch with latent space images"""
def save_intermediate(image, index):
if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
image = sd_samplers.sample_to_image(image, index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix: if opts.use_scale_latent_for_hires_fix:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
else:
image_conditioning = self.txt2img_image_conditioning(samples)
else: else:
decoded_samples = decode_first_stage(self.sd_model, samples) decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
...@@ -670,6 +688,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -670,6 +688,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8) x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
save_intermediate(image, i)
image = images.resize_image(0, image, self.width, self.height) image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0) image = np.moveaxis(image, 2, 0)
...@@ -681,14 +702,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -681,14 +702,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
shared.state.nextjob() shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
image_conditioning = self.txt2img_image_conditioning(x)
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
x = None x = None
devices.torch_gc() devices.torch_gc()
...@@ -827,8 +848,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -827,8 +848,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask) self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
......
...@@ -23,16 +23,23 @@ def encode(*args): ...@@ -23,16 +23,23 @@ def encode(*args):
class RestrictedUnpickler(pickle.Unpickler): class RestrictedUnpickler(pickle.Unpickler):
extra_handler = None
def persistent_load(self, saved_id): def persistent_load(self, saved_id):
assert saved_id[0] == 'storage' assert saved_id[0] == 'storage'
return TypedStorage() return TypedStorage()
def find_class(self, module, name): def find_class(self, module, name):
if self.extra_handler is not None:
res = self.extra_handler(module, name)
if res is not None:
return res
if module == 'collections' and name == 'OrderedDict': if module == 'collections' and name == 'OrderedDict':
return getattr(collections, name) return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name) return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']: if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
return getattr(torch, name) return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']: if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name) return getattr(torch.nn.modules.container, name)
...@@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler): ...@@ -52,7 +59,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return set return set
# Forbid everything else. # Forbid everything else.
raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden") raise Exception(f"global '{module}/{name}' is forbidden")
allowed_zip_names = ["archive/data.pkl", "archive/version"] allowed_zip_names = ["archive/data.pkl", "archive/version"]
...@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names): ...@@ -69,7 +76,7 @@ def check_zip_filenames(filename, names):
raise Exception(f"bad file inside {filename}: {name}") raise Exception(f"bad file inside {filename}: {name}")
def check_pt(filename): def check_pt(filename, extra_handler):
try: try:
# new pytorch format is a zip file # new pytorch format is a zip file
...@@ -78,6 +85,7 @@ def check_pt(filename): ...@@ -78,6 +85,7 @@ def check_pt(filename):
with z.open('archive/data.pkl') as file: with z.open('archive/data.pkl') as file:
unpickler = RestrictedUnpickler(file) unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
unpickler.load() unpickler.load()
except zipfile.BadZipfile: except zipfile.BadZipfile:
...@@ -85,16 +93,42 @@ def check_pt(filename): ...@@ -85,16 +93,42 @@ def check_pt(filename):
# if it's not a zip file, it's an olf pytorch format, with five objects written to pickle # if it's not a zip file, it's an olf pytorch format, with five objects written to pickle
with open(filename, "rb") as file: with open(filename, "rb") as file:
unpickler = RestrictedUnpickler(file) unpickler = RestrictedUnpickler(file)
unpickler.extra_handler = extra_handler
for i in range(5): for i in range(5):
unpickler.load() unpickler.load()
def load(filename, *args, **kwargs): def load(filename, *args, **kwargs):
return load_with_extra(filename, *args, **kwargs)
def load_with_extra(filename, extra_handler=None, *args, **kwargs):
"""
this functon is intended to be used by extensions that want to load models with
some extra classes in them that the usual unpickler would find suspicious.
Use the extra_handler argument to specify a function that takes module and field name as text,
and returns that field's value:
```python
def extra(module, name):
if module == 'collections' and name == 'OrderedDict':
return collections.OrderedDict
return None
safe.load_with_extra('model.pt', extra_handler=extra)
```
The alternative to this is just to use safe.unsafe_torch_load('model.pt'), which as the name implies is
definitely unsafe.
"""
from modules import shared from modules import shared
try: try:
if not shared.cmd_opts.disable_safe_unpickle: if not shared.cmd_opts.disable_safe_unpickle:
check_pt(filename) check_pt(filename, extra_handler)
except pickle.UnpicklingError: except pickle.UnpicklingError:
print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
......
...@@ -2,7 +2,10 @@ import sys ...@@ -2,7 +2,10 @@ import sys
import traceback import traceback
from collections import namedtuple from collections import namedtuple
import inspect import inspect
from typing import Optional
from fastapi import FastAPI
from gradio import Blocks
def report_exception(c, job): def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr) print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
...@@ -24,24 +27,50 @@ class ImageSaveParams: ...@@ -24,24 +27,50 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
class CFGDenoiserParams:
def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
self.x = x
"""Latent image representation in the process of being denoised"""
self.image_cond = image_cond
"""Conditioning image"""
self.sigma = sigma
"""Current sigma noise step value"""
self.sampling_step = sampling_step
"""Current Sampling step number"""
self.total_sampling_steps = total_sampling_steps
"""Total number of sampling steps planned"""
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_model_loaded = [] callback_map = dict(
callbacks_ui_tabs = [] callbacks_app_started=[],
callbacks_ui_settings = [] callbacks_model_loaded=[],
callbacks_before_image_saved = [] callbacks_ui_tabs=[],
callbacks_image_saved = [] callbacks_ui_settings=[],
callbacks_before_image_saved=[],
callbacks_image_saved=[],
callbacks_cfg_denoiser=[]
)
def clear_callbacks(): def clear_callbacks():
callbacks_model_loaded.clear() for callback_list in callback_map.values():
callbacks_ui_tabs.clear() callback_list.clear()
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear() def app_started_callback(demo: Optional[Blocks], app: FastAPI):
callbacks_image_saved.clear() for c in callback_map['callbacks_app_started']:
try:
c.callback(demo, app)
except Exception:
report_exception(c, 'app_started_callback')
def model_loaded_callback(sd_model): def model_loaded_callback(sd_model):
for c in callbacks_model_loaded: for c in callback_map['callbacks_model_loaded']:
try: try:
c.callback(sd_model) c.callback(sd_model)
except Exception: except Exception:
...@@ -51,7 +80,7 @@ def model_loaded_callback(sd_model): ...@@ -51,7 +80,7 @@ def model_loaded_callback(sd_model):
def ui_tabs_callback(): def ui_tabs_callback():
res = [] res = []
for c in callbacks_ui_tabs: for c in callback_map['callbacks_ui_tabs']:
try: try:
res += c.callback() or [] res += c.callback() or []
except Exception: except Exception:
...@@ -61,7 +90,7 @@ def ui_tabs_callback(): ...@@ -61,7 +90,7 @@ def ui_tabs_callback():
def ui_settings_callback(): def ui_settings_callback():
for c in callbacks_ui_settings: for c in callback_map['callbacks_ui_settings']:
try: try:
c.callback() c.callback()
except Exception: except Exception:
...@@ -69,7 +98,7 @@ def ui_settings_callback(): ...@@ -69,7 +98,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams): def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved: for c in callback_map['callbacks_before_image_saved']:
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -77,13 +106,21 @@ def before_image_saved_callback(params: ImageSaveParams): ...@@ -77,13 +106,21 @@ def before_image_saved_callback(params: ImageSaveParams):
def image_saved_callback(params: ImageSaveParams): def image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved: for c in callback_map['callbacks_image_saved']:
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
report_exception(c, 'image_saved_callback') report_exception(c, 'image_saved_callback')
def cfg_denoiser_callback(params: CFGDenoiserParams):
for c in callback_map['callbacks_cfg_denoiser']:
try:
c.callback(params)
except Exception:
report_exception(c, 'cfg_denoiser_callback')
def add_callback(callbacks, fun): def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__] stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file' filename = stack[0].filename if len(stack) > 0 else 'unknown file'
...@@ -91,10 +128,32 @@ def add_callback(callbacks, fun): ...@@ -91,10 +128,32 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun)) callbacks.append(ScriptCallback(filename, fun))
def remove_current_script_callbacks():
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
if filename == 'unknown file':
return
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.script == filename]:
callback_list.remove(callback_to_remove)
def remove_callbacks_for_function(callback_func):
for callback_list in callback_map.values():
for callback_to_remove in [cb for cb in callback_list if cb.callback == callback_func]:
callback_list.remove(callback_to_remove)
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
add_callback(callback_map['callbacks_app_started'], callback)
def on_model_loaded(callback): def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is """register a function to be called when the stable diffusion model is created; the model is
passed as an argument""" passed as an argument"""
add_callback(callbacks_model_loaded, callback) add_callback(callback_map['callbacks_model_loaded'], callback)
def on_ui_tabs(callback): def on_ui_tabs(callback):
...@@ -107,13 +166,13 @@ def on_ui_tabs(callback): ...@@ -107,13 +166,13 @@ def on_ui_tabs(callback):
title is tab text displayed to user in the UI title is tab text displayed to user in the UI
elem_id is HTML id for the tab elem_id is HTML id for the tab
""" """
add_callback(callbacks_ui_tabs, callback) add_callback(callback_map['callbacks_ui_tabs'], callback)
def on_ui_settings(callback): def on_ui_settings(callback):
"""register a function to be called before UI settings are populated; add your settings """register a function to be called before UI settings are populated; add your settings
by using shared.opts.add_option(shared.OptionInfo(...)) """ by using shared.opts.add_option(shared.OptionInfo(...)) """
add_callback(callbacks_ui_settings, callback) add_callback(callback_map['callbacks_ui_settings'], callback)
def on_before_image_saved(callback): def on_before_image_saved(callback):
...@@ -121,7 +180,7 @@ def on_before_image_saved(callback): ...@@ -121,7 +180,7 @@ def on_before_image_saved(callback):
The callback is called with one argument: The callback is called with one argument:
- params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object. - params: ImageSaveParams - parameters the image is to be saved with. You can change fields in this object.
""" """
add_callback(callbacks_before_image_saved, callback) add_callback(callback_map['callbacks_before_image_saved'], callback)
def on_image_saved(callback): def on_image_saved(callback):
...@@ -129,4 +188,12 @@ def on_image_saved(callback): ...@@ -129,4 +188,12 @@ def on_image_saved(callback):
The callback is called with one argument: The callback is called with one argument:
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
""" """
add_callback(callbacks_image_saved, callback) add_callback(callback_map['callbacks_image_saved'], callback)
def on_cfg_denoiser(callback):
"""register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
The callback is called with one argument:
- params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
"""
add_callback(callback_map['callbacks_cfg_denoiser'], callback)
...@@ -3,7 +3,6 @@ import sys ...@@ -3,7 +3,6 @@ import sys
import traceback import traceback
from collections import namedtuple from collections import namedtuple
import modules.ui as ui
import gradio as gr import gradio as gr
from modules.processing import StableDiffusionProcessing from modules.processing import StableDiffusionProcessing
...@@ -18,6 +17,9 @@ class Script: ...@@ -18,6 +17,9 @@ class Script:
args_to = None args_to = None
alwayson = False alwayson = False
"""A gr.Group component that has all script's UI inside it"""
group = None
infotext_fields = None infotext_fields = None
"""if set in ui(), this is a list of pairs of gradio component + text; the text will be used when """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when
parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
...@@ -70,6 +72,19 @@ class Script: ...@@ -70,6 +72,19 @@ class Script:
pass pass
def process_batch(self, p, *args, **kwargs):
"""
Same as process(), but called for every batch.
**kwargs will have those items:
- batch_number - index of current batch, from 0 to number of batches-1
- prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things
- seeds - list of seeds for current batch
- subseeds - list of subseeds for current batch
"""
pass
def postprocess(self, p, processed, *args): def postprocess(self, p, processed, *args):
""" """
This function is called after processing ends for AlwaysVisible scripts. This function is called after processing ends for AlwaysVisible scripts.
...@@ -218,8 +233,6 @@ class ScriptRunner: ...@@ -218,8 +233,6 @@ class ScriptRunner:
for control in controls: for control in controls:
control.custom_script_source = os.path.basename(script.filename) control.custom_script_source = os.path.basename(script.filename)
if not script.alwayson:
control.visible = False
if script.infotext_fields is not None: if script.infotext_fields is not None:
self.infotext_fields += script.infotext_fields self.infotext_fields += script.infotext_fields
...@@ -229,40 +242,41 @@ class ScriptRunner: ...@@ -229,40 +242,41 @@ class ScriptRunner:
script.args_to = len(inputs) script.args_to = len(inputs)
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
with gr.Group(): with gr.Group() as group:
create_script_ui(script, inputs, inputs_alwayson) create_script_ui(script, inputs, inputs_alwayson)
script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True dropdown.save_to_config = True
inputs[0] = dropdown inputs[0] = dropdown
for script in self.selectable_scripts: for script in self.selectable_scripts:
with gr.Group(visible=False) as group:
create_script_ui(script, inputs, inputs_alwayson) create_script_ui(script, inputs, inputs_alwayson)
script.group = group
def select_script(script_index): def select_script(script_index):
if 0 < script_index <= len(self.selectable_scripts): selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None
script = self.selectable_scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
return [ui.gr_show(True if i == 0 else args_from <= i < args_to or is_alwayson) for i, is_alwayson in enumerate(inputs_alwayson)] return [gr.update(visible=selected_script == s) for s in self.selectable_scripts]
def init_field(title): def init_field(title):
"""called when an initial value is set from ui-config.json to show script's UI components"""
if title == 'None': if title == 'None':
return return
script_index = self.titles.index(title) script_index = self.titles.index(title)
script = self.selectable_scripts[script_index] self.selectable_scripts[script_index].group.visible = True
for i in range(script.args_from, script.args_to):
inputs[i].visible = True
dropdown.init_field = init_field dropdown.init_field = init_field
dropdown.change( dropdown.change(
fn=select_script, fn=select_script,
inputs=[dropdown], inputs=[dropdown],
outputs=inputs outputs=[script.group for script in self.selectable_scripts]
) )
return inputs return inputs
...@@ -294,6 +308,15 @@ class ScriptRunner: ...@@ -294,6 +308,15 @@ class ScriptRunner:
print(f"Error running process: {script.filename}", file=sys.stderr) print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
def process_batch(self, p, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.process_batch(p, *script_args, **kwargs)
except Exception:
print(f"Error running process_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess(self, p, processed): def postprocess(self, p, processed):
for script in self.alwayson_scripts: for script in self.alwayson_scripts:
try: try:
......
...@@ -94,6 +94,10 @@ class StableDiffusionModelHijack: ...@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
self.layers = None
self.circular_enabled = False
self.clip = None
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:
return return
......
import collections import collections
import os.path import os.path
import sys import sys
import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
...@@ -8,7 +9,7 @@ from omegaconf import OmegaConf ...@@ -8,7 +9,7 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices, script_callbacks from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules.paths import models_path from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
...@@ -158,13 +159,16 @@ def get_state_dict_from_checkpoint(pl_sd): ...@@ -158,13 +159,16 @@ def get_state_dict_from_checkpoint(pl_sd):
return pl_sd return pl_sd
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} def load_model_weights(model, checkpoint_info, vae_file="auto"):
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash sd_model_hash = checkpoint_info.hash
if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
sd_vae.restore_base_vae(model)
checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
if checkpoint_info not in checkpoints_loaded: if checkpoint_info not in checkpoints_loaded:
print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
...@@ -181,37 +185,36 @@ def load_model_weights(model, checkpoint_info): ...@@ -181,37 +185,36 @@ def load_model_weights(model, checkpoint_info):
model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half: if not shared.cmd_opts.no_half:
vae = model.first_stage_model
# with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
if shared.cmd_opts.no_half_vae:
model.first_stage_model = None
model.half() model.half()
model.first_stage_model = vae
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
vae_file = os.path.splitext(checkpoint_file)[0] + ".vae.pt"
if not os.path.exists(vae_file) and shared.cmd_opts.vae_path is not None:
vae_file = shared.cmd_opts.vae_path
if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
else:
vae_name = sd_vae.get_filename(vae_file) if vae_file else None
vae_message = f" with {vae_name} VAE" if vae_name else ""
print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
model.load_state_dict(checkpoints_loaded[checkpoint_info])
if shared.opts.sd_checkpoint_cache > 0: if shared.opts.sd_checkpoint_cache > 0:
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)
model.load_state_dict(checkpoints_loaded[checkpoint_info])
model.sd_model_hash = sd_model_hash model.sd_model_hash = sd_model_hash
model.sd_model_checkpoint = checkpoint_file model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info model.sd_checkpoint_info = checkpoint_info
sd_vae.load_vae(model, vae_file)
def load_model(checkpoint_info=None): def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
...@@ -220,6 +223,12 @@ def load_model(checkpoint_info=None): ...@@ -220,6 +223,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config: if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}") print(f"Loading config from: {checkpoint_info.config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
gc.collect()
devices.torch_gc()
sd_config = OmegaConf.load(checkpoint_info.config) sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info): if should_hijack_inpainting(checkpoint_info):
...@@ -233,6 +242,7 @@ def load_model(checkpoint_info=None): ...@@ -233,6 +242,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack() do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
...@@ -252,14 +262,18 @@ def load_model(checkpoint_info=None): ...@@ -252,14 +262,18 @@ def load_model(checkpoint_info=None):
return sd_model return sd_model
def reload_model_weights(sd_model, info=None): def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info) load_model(checkpoint_info)
return shared.sd_model return shared.sd_model
......
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
from math import floor
import torch import torch
import tqdm import tqdm
from PIL import Image from PIL import Image
...@@ -11,6 +12,7 @@ from modules import prompt_parser, devices, processing, images ...@@ -11,6 +12,7 @@ from modules import prompt_parser, devices, processing, images
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
...@@ -22,11 +24,15 @@ samplers_k_diffusion = [ ...@@ -22,11 +24,15 @@ samplers_k_diffusion = [
('Heun', 'sample_heun', ['k_heun'], {}), ('Heun', 'sample_heun', ['k_heun'], {}),
('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}), ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {}),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}), ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {}),
('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {}),
('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {}),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {}),
('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}),
('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}), ('DPM2 Karras', 'sample_dpm_2', ['k_dpm_2_ka'], {'scheduler': 'karras'}),
('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}), ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras'}),
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
] ]
samplers_data_k_diffusion = [ samplers_data_k_diffusion = [
...@@ -91,8 +97,8 @@ def single_sample_to_image(sample): ...@@ -91,8 +97,8 @@ def single_sample_to_image(sample):
return Image.fromarray(x_sample) return Image.fromarray(x_sample)
def sample_to_image(samples): def sample_to_image(samples, index=0):
return single_sample_to_image(samples[0]) return single_sample_to_image(samples[index])
def samples_to_image_grid(samples): def samples_to_image_grid(samples):
...@@ -205,17 +211,22 @@ class VanillaStableDiffusionSampler: ...@@ -205,17 +211,22 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p) self.initialize(p)
# existing code fails with certain step counts, like 9
try:
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x self.init_latent = x
...@@ -239,18 +250,14 @@ class VanillaStableDiffusionSampler: ...@@ -239,18 +250,14 @@ class VanillaStableDiffusionSampler:
self.last_latent = x self.last_latent = x
self.step = 0 self.step = 0
steps = steps or p.steps steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model # Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None: if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
except Exception:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim return samples_ddim
...@@ -278,6 +285,12 @@ class CFGDenoiser(torch.nn.Module): ...@@ -278,6 +285,12 @@ class CFGDenoiser(torch.nn.Module):
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
x_in = denoiser_params.x
image_cond_in = denoiser_params.image_cond
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]: if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond]) cond_in = torch.cat([tensor, uncond])
......
import torch
import os
from collections import namedtuple
from modules import shared, devices, script_callbacks
from modules.paths import models_path
import glob
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_list = ["auto", "None"]
default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True
base_vae = None
loaded_vae_file = None
checkpoint_info = None
def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
return base_vae
return None
def store_base_vae(model):
global base_vae, checkpoint_info
if checkpoint_info != model.sd_checkpoint_info:
base_vae = model.first_stage_model.state_dict().copy()
checkpoint_info = model.sd_checkpoint_info
def delete_base_vae():
global base_vae, checkpoint_info
base_vae = None
checkpoint_info = None
def restore_base_vae(model):
global base_vae, checkpoint_info
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
load_vae_dict(model, base_vae)
delete_base_vae()
def get_filename(filepath):
return os.path.splitext(os.path.basename(filepath))[0]
def refresh_vae_list(vae_path=vae_path, model_path=model_path):
global vae_dict, vae_list
res = {}
candidates = [
*glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
*glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
*glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True)
]
if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
candidates.append(shared.cmd_opts.vae_path)
for filepath in candidates:
name = get_filename(filepath)
res[name] = filepath
vae_list.clear()
vae_list.extend(default_vae_list)
vae_list.extend(list(res.keys()))
vae_dict.clear()
vae_dict.update(res)
vae_dict.update(default_vae_dict)
return vae_list
def resolve_vae(checkpoint_file, vae_file="auto"):
global first_load, vae_dict, vae_list
# if vae_file argument is provided, it takes priority, but not saved
if vae_file and vae_file not in default_vae_list:
if not os.path.isfile(vae_file):
vae_file = "auto"
print("VAE provided as function argument doesn't exist")
# for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
if first_load and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
shared.opts.data['sd_vae'] = get_filename(vae_file)
else:
print("VAE provided as command line argument doesn't exist")
# else, we load from settings
if vae_file == "auto" and shared.opts.sd_vae is not None:
# if saved VAE settings isn't recognized, fallback to auto
vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
# if VAE selected but not found, fallback to auto
if vae_file not in default_vae_values and not os.path.isfile(vae_file):
vae_file = "auto"
print("Selected VAE doesn't exist")
# vae-path cmd arg takes priority for auto
if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
if os.path.isfile(shared.cmd_opts.vae_path):
vae_file = shared.cmd_opts.vae_path
print("Using VAE provided as command line argument")
# if still not found, try look for ".vae.pt" beside model
model_path = os.path.splitext(checkpoint_file)[0]
if vae_file == "auto":
vae_file_try = model_path + ".vae.pt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# if still not found, try look for ".vae.ckpt" beside model
if vae_file == "auto":
vae_file_try = model_path + ".vae.ckpt"
if os.path.isfile(vae_file_try):
vae_file = vae_file_try
print("Using VAE found beside selected model")
# No more fallbacks for auto
if vae_file == "auto":
vae_file = None
# Last check, just because
if vae_file and not os.path.exists(vae_file):
vae_file = None
return vae_file
def load_vae(model, vae_file=None):
global first_load, vae_dict, vae_list, loaded_vae_file
# save_settings = False
if vae_file:
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
load_vae_dict(model, vae_dict_1)
# If vae used is not in dict, update it
# It will be removed on refresh though
vae_opt = get_filename(vae_file)
if vae_opt not in vae_dict:
vae_dict[vae_opt] = vae_file
vae_list.append(vae_opt)
loaded_vae_file = vae_file
"""
# Save current VAE to VAE settings, maybe? will it work?
if save_settings:
if vae_file is None:
vae_opt = "None"
# shared.opts.sd_vae = vae_opt
"""
first_load = False
# don't call this from outside
def load_vae_dict(model, vae_dict_1=None):
if vae_dict_1:
store_base_vae(model)
model.first_stage_model.load_state_dict(vae_dict_1)
else:
restore_base_vae()
model.first_stage_model.to(devices.dtype_vae)
def reload_vae_weights(sd_model=None, vae_file="auto"):
from modules import lowvram, devices, sd_hijack
if not sd_model:
sd_model = shared.sd_model
checkpoint_info = sd_model.sd_checkpoint_info
checkpoint_file = checkpoint_info.filename
vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)
if loaded_vae_file == vae_file:
return
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_vae(sd_model, vae_file)
sd_hijack.model_hijack.hijack(sd_model)
script_callbacks.model_loaded_callback(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
print(f"VAE Weights loaded.")
return sd_model
This diff is collapsed.
...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
unload = shared.opts.unload_models_when_training
if save_embedding_every > 0: if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings") embedding_dir = os.path.join(log_directory, "embeddings")
...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0: if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}' forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename) last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
do_not_save_grid=True, do_not_save_grid=True,
...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0]
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename return embedding, filename
......
...@@ -25,7 +25,9 @@ def train_embedding(*args): ...@@ -25,7 +25,9 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
apply_optimizations = shared.opts.training_xattention_optimizations
try: try:
if not apply_optimizations:
sd_hijack.undo_optimizations() sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)} ...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception: except Exception:
raise raise
finally: finally:
if not apply_optimizations:
sd_hijack.apply_optimizations() sd_hijack.apply_optimizations()
...@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -47,6 +47,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
if processed is None: if processed is None:
processed = process_images(p) processed = process_images(p)
p.close()
shared.total_tqdm.clear() shared.total_tqdm.clear()
generation_info_js = processed.js() generation_info_js = processed.js()
......
...@@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None): ...@@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None):
gr.processing_utils.save_pil_to_file = save_pil_to_file gr.processing_utils.save_pil_to_file = save_pil_to_file
def wrap_gradio_call(func, extra_outputs=None): def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
def f(*args, extra_outputs_array=extra_outputs, **kwargs): def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats
if run_memmon: if run_memmon:
shared.mem_mon.monitor() shared.mem_mon.monitor()
t = time.perf_counter() t = time.perf_counter()
...@@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None): ...@@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None):
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"] res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
if not add_stats:
return tuple(res)
elapsed = time.perf_counter() - t elapsed = time.perf_counter() - t
elapsed_m = int(elapsed // 60) elapsed_m = int(elapsed // 60)
elapsed_s = elapsed % 60 elapsed_s = elapsed % 60
elapsed_text = f"{elapsed_s:.2f}s" elapsed_text = f"{elapsed_s:.2f}s"
if (elapsed_m > 0): if elapsed_m > 0:
elapsed_text = f"{elapsed_m}m "+elapsed_text elapsed_text = f"{elapsed_m}m "+elapsed_text
if run_memmon: if run_memmon:
...@@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None): ...@@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None):
# last item is always HTML # last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>" res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
return tuple(res) return tuple(res)
return f return f
...@@ -276,16 +279,8 @@ def check_progress_call(id_part): ...@@ -276,16 +279,8 @@ def check_progress_call(id_part):
image = gr_show(False) image = gr_show(False)
preview_visibility = gr_show(False) preview_visibility = gr_show(False)
if opts.show_progress_every_n_steps > 0: if opts.show_progress_every_n_steps != 0:
if shared.parallel_processing_allowed: shared.state.set_current_image()
if shared.state.sampling_step - shared.state.current_image_sampling_step >= opts.show_progress_every_n_steps and shared.state.current_latent is not None:
if opts.show_progress_grid:
shared.state.current_image = modules.sd_samplers.samples_to_image_grid(shared.state.current_latent)
else:
shared.state.current_image = modules.sd_samplers.sample_to_image(shared.state.current_latent)
shared.state.current_image_sampling_step = shared.state.sampling_step
image = shared.state.current_image image = shared.state.current_image
if image is None: if image is None:
...@@ -671,6 +666,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -671,6 +666,8 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
reload_javascript()
parameters_copypaste.reset() parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
...@@ -1058,9 +1055,11 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1058,9 +1055,11 @@ def create_ui(wrap_gradio_gpu_call):
extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.") extras_batch_output_dir = gr.Textbox(label="Output directory", **shared.hide_dirs, placeholder="Leave blank to save images to the default path.")
show_extras_results = gr.Checkbox(label='Show result images', value=True) show_extras_results = gr.Checkbox(label='Show result images', value=True)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Tabs(elem_id="extras_resize_mode"): with gr.Tabs(elem_id="extras_resize_mode"):
with gr.TabItem('Scale by'): with gr.TabItem('Scale by'):
upscaling_resize = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Resize", value=2) upscaling_resize = gr.Slider(minimum=1.0, maximum=8.0, step=0.05, label="Resize", value=4)
with gr.TabItem('Scale to'): with gr.TabItem('Scale to'):
with gr.Group(): with gr.Group():
with gr.Row(): with gr.Row():
...@@ -1085,8 +1084,6 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1085,8 +1084,6 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Group(): with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False) upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples) result_images, html_info_x, html_info = create_output_panel("extras", opts.outdir_extras_samples)
submit.click( submit.click(
...@@ -1188,8 +1185,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1188,8 +1185,8 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork", choices=modules.hypernetworks.ui.keys) new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys)
new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"]) new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
...@@ -1442,50 +1439,41 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1442,50 +1439,41 @@ def create_ui(wrap_gradio_gpu_call):
opts.reorder() opts.reorder()
def run_settings(*args): def run_settings(*args):
changed = 0 changed = []
assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
for key, value, comp in zip(opts.data_labels.keys(), args, components): for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp != dummy_component and not opts.same_type(value, opts.data_labels[key].default): assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}"
return f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}", opts.dumpjson()
for key, value, comp in zip(opts.data_labels.keys(), args, components): for key, value, comp in zip(opts.data_labels.keys(), args, components):
if comp == dummy_component: if comp == dummy_component:
continue continue
comp_args = opts.data_labels[key].component_args
if comp_args and isinstance(comp_args, dict) and comp_args.get('visible') is False:
continue
if cmd_opts.hide_ui_dir_config and key in restricted_opts:
continue
oldval = opts.data.get(key, None) oldval = opts.data.get(key, None)
opts.data[key] = value try:
setattr(opts, key, value)
except RuntimeError:
continue
if oldval != value: if oldval != value:
if opts.data_labels[key].onchange is not None: if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange() opts.data_labels[key].onchange()
changed += 1 changed.append(key)
try:
opts.save(shared.config_filename) opts.save(shared.config_filename)
except RuntimeError:
return f'{changed} settings changed.', opts.dumpjson() return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.'
return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.'
def run_settings_single(value, key): def run_settings_single(value, key):
assert not shared.cmd_opts.freeze_settings, "changing settings is disabled"
if not opts.same_type(value, opts.data_labels[key].default): if not opts.same_type(value, opts.data_labels[key].default):
return gr.update(visible=True), opts.dumpjson() return gr.update(visible=True), opts.dumpjson()
oldval = opts.data.get(key, None) oldval = opts.data.get(key, None)
if cmd_opts.hide_ui_dir_config and key in restricted_opts: try:
setattr(opts, key, value)
except Exception:
return gr.update(value=oldval), opts.dumpjson() return gr.update(value=oldval), opts.dumpjson()
opts.data[key] = value
if oldval != value: if oldval != value:
if opts.data_labels[key].onchange is not None: if opts.data_labels[key].onchange is not None:
opts.data_labels[key].onchange() opts.data_labels[key].onchange()
...@@ -1570,8 +1558,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1570,8 +1558,7 @@ def create_ui(wrap_gradio_gpu_call):
reload_script_bodies.click( reload_script_bodies.click(
fn=reload_scripts, fn=reload_scripts,
inputs=[], inputs=[],
outputs=[], outputs=[]
_js='function(){}'
) )
def request_restart(): def request_restart():
...@@ -1579,11 +1566,10 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1579,11 +1566,10 @@ def create_ui(wrap_gradio_gpu_call):
shared.state.need_restart = True shared.state.need_restart = True
restart_gradio.click( restart_gradio.click(
fn=request_restart, fn=request_restart,
_js='restart_reload',
inputs=[], inputs=[],
outputs=[], outputs=[],
_js='function(){restart_reload()}'
) )
if column is not None: if column is not None:
...@@ -1639,9 +1625,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1639,9 +1625,9 @@ def create_ui(wrap_gradio_gpu_call):
text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
settings_submit.click( settings_submit.click(
fn=run_settings, fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]),
inputs=components, inputs=components,
outputs=[result, text_settings], outputs=[text_settings, result],
) )
for i, k, item in quicksettings_list: for i, k, item in quicksettings_list:
...@@ -1782,4 +1768,3 @@ def load_javascript(raw_response): ...@@ -1782,4 +1768,3 @@ def load_javascript(raw_response):
reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse)
reload_javascript()
...@@ -13,6 +13,9 @@ import html ...@@ -13,6 +13,9 @@ import html
from modules import extensions, shared, paths from modules import extensions, shared, paths
available_extensions = {"extensions": []}
def check_access(): def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags" assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
...@@ -83,7 +86,7 @@ def extension_table(): ...@@ -83,7 +86,7 @@ def extension_table():
code += f""" code += f"""
<tr> <tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td> <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td> <td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td> <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr> </tr>
""" """
...@@ -96,6 +99,14 @@ def extension_table(): ...@@ -96,6 +99,14 @@ def extension_table():
return code return code
def normalize_git_url(url):
if url is None:
return ""
url = url.replace(".git", "")
return url
def install_extension_from_url(dirname, url): def install_extension_from_url(dirname, url):
check_access() check_access()
...@@ -103,14 +114,15 @@ def install_extension_from_url(dirname, url): ...@@ -103,14 +114,15 @@ def install_extension_from_url(dirname, url):
if dirname is None or dirname == "": if dirname is None or dirname == "":
*parts, last_part = url.split('/') *parts, last_part = url.split('/')
last_part = last_part.replace(".git", "") last_part = normalize_git_url(last_part)
dirname = last_part dirname = last_part
target_dir = os.path.join(extensions.extensions_dir, dirname) target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed' normalized_url = normalize_git_url(url)
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname) tmpdir = os.path.join(paths.script_path, "tmp", dirname)
...@@ -128,18 +140,104 @@ def install_extension_from_url(dirname, url): ...@@ -128,18 +140,104 @@ def install_extension_from_url(dirname, url):
shutil.rmtree(tmpdir, True) shutil.rmtree(tmpdir, True)
def install_extension_from_index(url, hide_tags):
ext_table, message = install_extension_from_url(None, url)
code, _ = refresh_available_extensions_from_data(hide_tags)
return code, ext_table, message
def refresh_available_extensions(url, hide_tags):
global available_extensions
import urllib.request
with urllib.request.urlopen(url) as response:
text = response.read()
available_extensions = json.loads(text)
code, tags = refresh_available_extensions_from_data(hide_tags)
return url, code, gr.CheckboxGroup.update(choices=tags), ''
def refresh_available_extensions_for_tags(hide_tags):
code, _ = refresh_available_extensions_from_data(hide_tags)
return code, ''
def refresh_available_extensions_from_data(hide_tags):
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
tags = available_extensions.get("tags", {})
tags_to_hide = set(hide_tags)
hidden = 0
code = f"""<!-- {time.time()} -->
<table id="available_extensions">
<thead>
<tr>
<th>Extension</th>
<th>Description</th>
<th>Action</th>
</tr>
</thead>
<tbody>
"""
for ext in extlist:
name = ext.get("name", "noname")
url = ext.get("url", None)
description = ext.get("description", "")
extension_tags = ext.get("tags", [])
if url is None:
continue
if len([x for x in extension_tags if x in tags_to_hide]) > 0:
hidden += 1
continue
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
tags_text = ", ".join([f"<span class='extension-tag' title='{tags.get(x, '')}'>{x}</span>" for x in extension_tags])
code += f"""
<tr>
<td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
"""
code += """
</tbody>
</table>
"""
if hidden > 0:
code += f"<p>Extension hidden: {hidden}</p>"
return code, list(tags)
def create_ui(): def create_ui():
import modules.ui import modules.ui
with gr.Blocks(analytics_enabled=False) as ui: with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"): with gr.TabItem("Installed"):
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False)
with gr.Row(): with gr.Row():
apply = gr.Button(value="Apply and restart UI", variant="primary") apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates") check = gr.Button(value="Check for updates")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
extensions_table = gr.HTML(lambda: extension_table()) extensions_table = gr.HTML(lambda: extension_table())
...@@ -157,16 +255,47 @@ def create_ui(): ...@@ -157,16 +255,47 @@ def create_ui():
outputs=[extensions_table], outputs=[extensions_table],
) )
with gr.TabItem("Available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
with gr.Row():
hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"])
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]),
inputs=[available_extensions_index, hide_tags],
outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install, hide_tags],
outputs=[available_extensions_table, extensions_table, install_result],
)
hide_tags.change(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]),
inputs=[hide_tags],
outputs=[available_extensions_table, install_result]
)
with gr.TabItem("Install from URL"): with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository") install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
intall_button = gr.Button(value="Install", variant="primary") install_button = gr.Button(value="Install", variant="primary")
intall_result = gr.HTML(elem_id="extension_install_result") install_result = gr.HTML(elem_id="extension_install_result")
intall_button.click( install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url], inputs=[install_dirname, install_url],
outputs=[extensions_table, intall_result], outputs=[extensions_table, install_result],
) )
return ui return ui
...@@ -10,6 +10,7 @@ import modules.shared ...@@ -10,6 +10,7 @@ import modules.shared
from modules import modelloader, shared from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
NEAREST = (Image.Resampling.NEAREST if hasattr(Image, 'Resampling') else Image.NEAREST)
from modules.paths import models_path from modules.paths import models_path
...@@ -56,10 +57,18 @@ class Upscaler: ...@@ -56,10 +57,18 @@ class Upscaler:
self.scale = scale self.scale = scale
dest_w = img.width * scale dest_w = img.width * scale
dest_h = img.height * scale dest_h = img.height * scale
for i in range(3): for i in range(3):
shape = (img.width, img.height)
img = self.do_upscale(img, selected_model)
if shape == (img.width, img.height):
break
if img.width >= dest_w and img.height >= dest_h: if img.width >= dest_w and img.height >= dest_h:
break break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h: if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS) img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
...@@ -120,3 +129,17 @@ class UpscalerLanczos(Upscaler): ...@@ -120,3 +129,17 @@ class UpscalerLanczos(Upscaler):
self.name = "Lanczos" self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)] self.scalers = [UpscalerData("Lanczos", None, self)]
class UpscalerNearest(Upscaler):
scalers = []
def do_upscale(self, img, selected_model=None):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=NEAREST)
def load_model(self, _):
pass
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Nearest"
self.scalers = [UpscalerData("Nearest", None, self)]
\ No newline at end of file
...@@ -4,7 +4,7 @@ fairscale==0.4.4 ...@@ -4,7 +4,7 @@ fairscale==0.4.4
fonts fonts
font-roboto font-roboto
gfpgan gfpgan
gradio==3.5 gradio==3.8
invisible-watermark invisible-watermark
numpy numpy
omegaconf omegaconf
...@@ -12,7 +12,7 @@ opencv-python ...@@ -12,7 +12,7 @@ opencv-python
requests requests
piexif piexif
Pillow Pillow
pytorch_lightning pytorch_lightning==1.7.7
realesrgan realesrgan
scikit-image>=0.19 scikit-image>=0.19
timm==0.4.12 timm==0.4.12
......
...@@ -2,7 +2,7 @@ transformers==4.19.2 ...@@ -2,7 +2,7 @@ transformers==4.19.2
diffusers==0.3.0 diffusers==0.3.0
basicsr==1.4.2 basicsr==1.4.2
gfpgan==1.3.8 gfpgan==1.3.8
gradio==3.5 gradio==3.8
numpy==1.23.3 numpy==1.23.3
Pillow==9.2.0 Pillow==9.2.0
realesrgan==0.3.0 realesrgan==0.3.0
......
...@@ -14,7 +14,7 @@ class Script(scripts.Script): ...@@ -14,7 +14,7 @@ class Script(scripts.Script):
return cmd_opts.allow_code return cmd_opts.allow_code
def ui(self, is_img2img): def ui(self, is_img2img):
code = gr.Textbox(label="Python code", visible=False, lines=1) code = gr.Textbox(label="Python code", lines=1)
return [code] return [code]
......
...@@ -166,8 +166,7 @@ class Script(scripts.Script): ...@@ -166,8 +166,7 @@ class Script(scripts.Script):
if override_strength: if override_strength:
p.denoising_strength = 1.0 p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
lat = (p.init_latent.cpu().numpy() * 10).astype(int) lat = (p.init_latent.cpu().numpy() * 10).astype(int)
same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \ same_params = self.cache is not None and self.cache.cfg_scale == cfg and self.cache.steps == st \
......
...@@ -132,7 +132,7 @@ class Script(scripts.Script): ...@@ -132,7 +132,7 @@ class Script(scripts.Script):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>") info = gr.HTML("<p style=\"margin-bottom:0.75em\">Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8</p>")
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, visible=False) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8)
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0) noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0)
color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05) color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05)
......
...@@ -22,8 +22,8 @@ class Script(scripts.Script): ...@@ -22,8 +22,8 @@ class Script(scripts.Script):
return None return None
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False) mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4)
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False) inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'])
return [pixels, mask_blur, inpainting_fill, direction] return [pixels, mask_blur, inpainting_fill, direction]
......
This diff is collapsed.
...@@ -18,8 +18,8 @@ class Script(scripts.Script): ...@@ -18,8 +18,8 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>") info = gr.HTML("<p style=\"margin-bottom:0.75em\">Will upscale the image to twice the dimensions; use width and height sliders to set tile size</p>")
overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False) overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64)
upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", visible=False) upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index")
return [info, overlap, upscaler_index] return [info, overlap, upscaler_index]
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment