Commit 51e0a839 authored by evshiron's avatar evshiron

Merge branch 'master' into fix/progress-api

parents 1a4ff2de 55688c48
function extensions_apply(_, _){
disable = []
update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7))
if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substr(7))
})
restart_reload()
return [JSON.stringify(disable), JSON.stringify(update)]
}
function extensions_check(){
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
x.innerHTML = "Loading..."
})
return []
}
function install_extension_from_index(button, url){
button.disabled = "disabled"
button.value = "Installing..."
textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
textarea.dispatchEvent(new Event("input", { bubbles: true }))
gradioApp().querySelector('#install_extension_button').click()
}
...@@ -7,6 +7,7 @@ import shlex ...@@ -7,6 +7,7 @@ import shlex
import platform import platform
dir_repos = "repositories" dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable python = sys.executable
git = os.environ.get('GIT', "git") git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "") index_url = os.environ.get('INDEX_URL', "")
...@@ -16,11 +17,11 @@ def extract_arg(args, name): ...@@ -16,11 +17,11 @@ def extract_arg(args, name):
return [x for x in args if x != name], name in args return [x for x in args if x != name], name in args
def run(command, desc=None, errdesc=None): def run(command, desc=None, errdesc=None, custom_env=None):
if desc is not None: if desc is not None:
print(desc) print(desc)
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
if result.returncode != 0: if result.returncode != 0:
...@@ -101,9 +102,27 @@ def version_check(commit): ...@@ -101,9 +102,27 @@ def version_check(commit):
else: else:
print("Not a git clone, can't perform version check.") print("Not a git clone, can't perform version check.")
except Exception as e: except Exception as e:
print("versipm check failed",e) print("version check failed", e)
def run_extensions_installers():
if not os.path.isdir(dir_extensions):
return
for dirname_extension in os.listdir(dir_extensions):
path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
if not os.path.isfile(path_installer):
continue
try:
env = os.environ.copy()
env['PYTHONPATH'] = os.path.abspath(".")
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}", custom_env=env))
except Exception as e:
print(e, file=sys.stderr)
def prepare_enviroment(): def prepare_enviroment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
...@@ -189,6 +208,8 @@ def prepare_enviroment(): ...@@ -189,6 +208,8 @@ def prepare_enviroment():
run_pip(f"install -r {requirements_file}", "requirements for Web UI") run_pip(f"install -r {requirements_file}", "requirements for Web UI")
run_extensions_installers()
if update_check: if update_check:
version_check(commit) version_check(commit)
......
...@@ -70,7 +70,7 @@ ...@@ -70,7 +70,7 @@
"None": "Nichts", "None": "Nichts",
"Prompt matrix": "Promptmatrix", "Prompt matrix": "Promptmatrix",
"Prompts from file or textbox": "Prompts aus Datei oder Textfeld", "Prompts from file or textbox": "Prompts aus Datei oder Textfeld",
"X/Y plot": "X/Y Graf", "X/Y plot": "X/Y Graph",
"Put variable parts at start of prompt": "Variable teile am start des Prompt setzen", "Put variable parts at start of prompt": "Variable teile am start des Prompt setzen",
"Iterate seed every line": "Iterate seed every line", "Iterate seed every line": "Iterate seed every line",
"List of prompt inputs": "List of prompt inputs", "List of prompt inputs": "List of prompt inputs",
...@@ -455,4 +455,4 @@ ...@@ -455,4 +455,4 @@
"Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.": "Gilt nur für Inpainting-Modelle. Legt fest, wie stark das Originalbild für Inpainting und img2img maskiert werden soll. 1.0 bedeutet vollständig maskiert, was das Standardverhalten ist. 0.0 bedeutet eine vollständig unmaskierte Konditionierung. Niedrigere Werte tragen dazu bei, die Gesamtkomposition des Bildes zu erhalten, sind aber bei großen Änderungen problematisch.", "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.": "Gilt nur für Inpainting-Modelle. Legt fest, wie stark das Originalbild für Inpainting und img2img maskiert werden soll. 1.0 bedeutet vollständig maskiert, was das Standardverhalten ist. 0.0 bedeutet eine vollständig unmaskierte Konditionierung. Niedrigere Werte tragen dazu bei, die Gesamtkomposition des Bildes zu erhalten, sind aber bei großen Änderungen problematisch.",
"List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.": "Liste von Einstellungsnamen, getrennt durch Kommas, für Einstellungen, die in der Schnellzugriffsleiste oben erscheinen sollen, anstatt in dem üblichen Einstellungs-Tab. Siehe modules/shared.py für Einstellungsnamen. Erfordert einen Neustart zur Anwendung.", "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply.": "Liste von Einstellungsnamen, getrennt durch Kommas, für Einstellungen, die in der Schnellzugriffsleiste oben erscheinen sollen, anstatt in dem üblichen Einstellungs-Tab. Siehe modules/shared.py für Einstellungsnamen. Erfordert einen Neustart zur Anwendung.",
"If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.": "Wenn dieser Wert ungleich Null ist, wird er zum Seed addiert und zur Initialisierung des RNG für Noise bei der Verwendung von Samplern mit Eta verwendet. Dies kann verwendet werden, um noch mehr Variationen von Bildern zu erzeugen, oder um Bilder von anderer Software zu erzeugen, wenn Sie wissen, was Sie tun." "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.": "Wenn dieser Wert ungleich Null ist, wird er zum Seed addiert und zur Initialisierung des RNG für Noise bei der Verwendung von Samplern mit Eta verwendet. Dies kann verwendet werden, um noch mehr Variationen von Bildern zu erzeugen, oder um Bilder von anderer Software zu erzeugen, wenn Sie wissen, was Sie tun."
} }
\ No newline at end of file
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
import base64
import io
import time import time
import uvicorn import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image from gradio.processing_utils import decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared import modules.shared as shared
from modules.api.models import * from modules.api.models import *
...@@ -28,6 +30,12 @@ def setUpscalers(req: dict): ...@@ -28,6 +30,12 @@ def setUpscalers(req: dict):
return reqDict return reqDict
def encode_pil_to_base64(image):
buffer = io.BytesIO()
image.save(buffer, format="png")
return base64.b64encode(buffer.getvalue())
class Api: class Api:
def __init__(self, app, queue_lock): def __init__(self, app, queue_lock):
self.router = APIRouter() self.router = APIRouter()
...@@ -39,6 +47,7 @@ class Api: ...@@ -39,6 +47,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
self.app.add_api_route("/sdapi/v1/interrupt", self.interruptapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index) sampler_index = sampler_to_index(txt2imgreq.sampler_index)
...@@ -185,6 +194,11 @@ class Api: ...@@ -185,6 +194,11 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image) return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def interruptapi(self):
shared.state.interrupt()
return {}
def launch(self, server_name, port): def launch(self, server_name, port):
self.app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port) uvicorn.run(self.app, host=server_name, port=port)
import os
import sys
import traceback
import git
from modules import paths, shared
extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions")
def active():
return [x for x in extensions if x.enabled]
class Extension:
def __init__(self, name, path, enabled=True):
self.name = name
self.path = path
self.enabled = enabled
self.status = ''
self.can_update = False
repo = None
try:
if os.path.exists(os.path.join(path, ".git")):
repo = git.Repo(path)
except Exception:
print(f"Error reading github repository info from {path}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if repo is None or repo.bare:
self.remote = None
else:
self.remote = next(repo.remote().urls, None)
self.status = 'unknown'
def list_files(self, subdir, extension):
from modules import scripts
dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath):
return []
res = []
for filename in sorted(os.listdir(dirpath)):
res.append(scripts.ScriptFile(self.path, filename, os.path.join(dirpath, filename)))
res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
return res
def check_updates(self):
repo = git.Repo(self.path)
for fetch in repo.remote().fetch("--dry-run"):
if fetch.flags != fetch.HEAD_UPTODATE:
self.can_update = True
self.status = "behind"
return
self.can_update = False
self.status = "latest"
def pull(self):
repo = git.Repo(self.path)
repo.remotes.origin.pull()
def list_extensions():
extensions.clear()
if not os.path.isdir(extensions_dir):
return
for dirname in sorted(os.listdir(extensions_dir)):
path = os.path.join(extensions_dir, dirname)
if not os.path.isdir(path):
continue
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
extensions.append(extension)
...@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -141,7 +141,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscaling_resize_w, upscaling_resize_h, upscaling_crop) upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()), cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
info_hash=hash(info), info_hash=hash(info),
args_hash=hash(upscale_args)) args_hash=hash((upscale_args, upscale_first)))
cached_entry = cached_images.get(cache_key) cached_entry = cached_images.get(cache_key)
if cached_entry is None: if cached_entry is None:
res = upscale(image, *upscale_args) res = upscale(image, *upscale_args)
......
...@@ -17,6 +17,11 @@ paste_fields = {} ...@@ -17,6 +17,11 @@ paste_fields = {}
bind_list = [] bind_list = []
def reset():
paste_fields.clear()
bind_list.clear()
def quote(text): def quote(text):
if ',' not in str(text): if ',' not in str(text):
return text return text
......
...@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -510,8 +510,9 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if extension.lower() == '.png': if extension.lower() == '.png':
pnginfo_data = PngImagePlugin.PngInfo() pnginfo_data = PngImagePlugin.PngInfo()
for k, v in params.pnginfo.items(): if opts.enable_pnginfo:
pnginfo_data.add_text(k, str(v)) for k, v in params.pnginfo.items():
pnginfo_data.add_text(k, str(v))
image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data) image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)
......
...@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -55,6 +55,7 @@ def process_batch(p, input_dir, output_dir, args):
filename = f"{left}-{n}{right}" filename = f"{left}-{n}{right}"
if not save_normally: if not save_normally:
os.makedirs(output_dir, exist_ok=True)
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
......
...@@ -56,9 +56,9 @@ class InterrogateModels: ...@@ -56,9 +56,9 @@ class InterrogateModels:
import clip import clip
if self.running_on_cpu: if self.running_on_cpu:
model, preprocess = clip.load(clip_model_name, device="cpu") model, preprocess = clip.load(clip_model_name, device="cpu", download_root=shared.cmd_opts.clip_models_path)
else: else:
model, preprocess = clip.load(clip_model_name) model, preprocess = clip.load(clip_model_name, download_root=shared.cmd_opts.clip_models_path)
model.eval() model.eval()
model = model.to(devices.device_interrogate) model = model.to(devices.device_interrogate)
......
...@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook; # see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods # useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z): first_stage_model = sd_model.first_stage_model
send_me_to_gpu(self, None) first_stage_model_encode = sd_model.first_stage_model.encode
return decoder(z) first_stage_model_decode = sd_model.first_stage_model.decode
def first_stage_model_encode_wrap(x):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_encode(x)
def first_stage_model_decode_wrap(z):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then # remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU. # send the model to GPU. Then put modules back. the modules will be in CPU.
...@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models # register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x) sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z) sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram: if use_medvram:
......
...@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)
p.sd_model = None
p.sampler = None
return res return res
......
...@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler): ...@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
return getattr(collections, name) return getattr(collections, name)
if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']: if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
return getattr(torch._utils, name) return getattr(torch._utils, name)
if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']: if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
return getattr(torch, name) return getattr(torch, name)
if module == 'torch.nn.modules.container' and name in ['ParameterDict']: if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
return getattr(torch.nn.modules.container, name) return getattr(torch.nn.modules.container, name)
......
...@@ -3,6 +3,8 @@ import traceback ...@@ -3,6 +3,8 @@ import traceback
from collections import namedtuple from collections import namedtuple
import inspect import inspect
from fastapi import FastAPI
from gradio import Blocks
def report_exception(c, job): def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr) print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
...@@ -25,6 +27,7 @@ class ImageSaveParams: ...@@ -25,6 +27,7 @@ class ImageSaveParams:
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = [] callbacks_model_loaded = []
callbacks_ui_tabs = [] callbacks_ui_tabs = []
callbacks_ui_settings = [] callbacks_ui_settings = []
...@@ -40,6 +43,14 @@ def clear_callbacks(): ...@@ -40,6 +43,14 @@ def clear_callbacks():
callbacks_image_saved.clear() callbacks_image_saved.clear()
def app_started_callback(demo: Blocks, app: FastAPI):
for c in callbacks_app_started:
try:
c.callback(demo, app)
except Exception:
report_exception(c, 'app_started_callback')
def model_loaded_callback(sd_model): def model_loaded_callback(sd_model):
for c in callbacks_model_loaded: for c in callbacks_model_loaded:
try: try:
...@@ -69,7 +80,7 @@ def ui_settings_callback(): ...@@ -69,7 +80,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams): def before_image_saved_callback(params: ImageSaveParams):
for c in callbacks_image_saved: for c in callbacks_before_image_saved:
try: try:
c.callback(params) c.callback(params)
except Exception: except Exception:
...@@ -91,6 +102,12 @@ def add_callback(callbacks, fun): ...@@ -91,6 +102,12 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun)) callbacks.append(ScriptCallback(filename, fun))
def on_app_started(callback):
"""register a function to be called when the webui started, the gradio `Block` component and
fastapi `FastAPI` object are passed as the arguments"""
add_callback(callbacks_app_started, callback)
def on_model_loaded(callback): def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is """register a function to be called when the stable diffusion model is created; the model is
passed as an argument""" passed as an argument"""
......
...@@ -7,7 +7,7 @@ import modules.ui as ui ...@@ -7,7 +7,7 @@ import modules.ui as ui
import gradio as gr import gradio as gr
from modules.processing import StableDiffusionProcessing from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks from modules import shared, paths, script_callbacks, extensions
AlwaysVisible = object() AlwaysVisible = object()
...@@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension): ...@@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension):
for filename in sorted(os.listdir(basedir)): for filename in sorted(os.listdir(basedir)):
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
extdir = os.path.join(paths.script_path, "extensions") for ext in extensions.active():
if os.path.exists(extdir): scripts_list += ext.list_files(scriptdirname, extension)
for dirname in sorted(os.listdir(extdir)):
dirpath = os.path.join(extdir, dirname)
scriptdirpath = os.path.join(dirpath, scriptdirname)
if not os.path.isdir(scriptdirpath):
continue
for filename in sorted(os.listdir(scriptdirpath)):
scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename)))
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
...@@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension): ...@@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension):
def list_files_with_name(filename): def list_files_with_name(filename):
res = [] res = []
dirs = [paths.script_path] dirs = [paths.script_path] + [ext.path for ext in extensions.active()]
extdir = os.path.join(paths.script_path, "extensions")
if os.path.exists(extdir):
dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))]
for dirpath in dirs: for dirpath in dirs:
if not os.path.isdir(dirpath): if not os.path.isdir(dirpath):
......
...@@ -94,6 +94,10 @@ class StableDiffusionModelHijack: ...@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
self.layers = None
self.circular_enabled = False
self.clip = None
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:
return return
......
import collections import collections
import os.path import os.path
import sys import sys
import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
...@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None): ...@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config: if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}") print(f"Loading config from: {checkpoint_info.config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
gc.collect()
devices.torch_gc()
sd_config = OmegaConf.load(checkpoint_info.config) sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info): if should_hijack_inpainting(checkpoint_info):
...@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None): ...@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack() do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
...@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None): ...@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
return sd_model return sd_model
def reload_model_weights(sd_model, info=None): def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info) load_model(checkpoint_info)
return shared.sd_model return shared.sd_model
......
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
from math import floor
import torch import torch
import tqdm import tqdm
from PIL import Image from PIL import Image
...@@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler: ...@@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None
def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
if valid_step == floor(valid_step):
return int(valid_step) + 1
return num_steps
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps) steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p) self.initialize(p)
# existing code fails with certain step counts, like 9 self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
try:
self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x self.init_latent = x
...@@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler: ...@@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler:
self.last_latent = x self.last_latent = x
self.step = 0 self.step = 0
steps = steps or p.steps steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model # Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None: if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
# existing code fails with certain step counts, like 9 samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
except Exception:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim return samples_ddim
......
...@@ -41,7 +41,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion ...@@ -41,7 +41,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site")
parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None)
parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer'))
...@@ -52,6 +52,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director ...@@ -52,6 +52,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET')) parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR')) parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator") parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
...@@ -98,6 +99,8 @@ restricted_opts = { ...@@ -98,6 +99,8 @@ restricted_opts = {
"outdir_save", "outdir_save",
} }
cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'])
...@@ -134,6 +137,7 @@ class State: ...@@ -134,6 +137,7 @@ class State:
current_image_sampling_step = 0 current_image_sampling_step = 0
textinfo = None textinfo = None
time_start = None time_start = None
need_restart = False
def skip(self): def skip(self):
self.skipped = True self.skipped = True
...@@ -288,11 +292,12 @@ options_templates.update(options_section(('system', "System"), { ...@@ -288,11 +292,12 @@ options_templates.update(options_section(('system', "System"), {
})) }))
options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."), "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"), "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"), "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
"training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
"training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
...@@ -357,6 +362,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" ...@@ -357,6 +362,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters"
'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}),
})) }))
options_templates.update(options_section((None, "Hidden options"), {
"disabled_extensions": OptionInfo([], "Disable those extensions"),
}))
options_templates.update()
class Options: class Options:
data = None data = None
...@@ -368,8 +379,9 @@ class Options: ...@@ -368,8 +379,9 @@ class Options:
def __setattr__(self, key, value): def __setattr__(self, key, value):
if self.data is not None: if self.data is not None:
if key in self.data: if key in self.data or key in self.data_labels:
self.data[key] = value self.data[key] = value
return
return super(Options, self).__setattr__(key, value) return super(Options, self).__setattr__(key, value)
......
...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name) log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
unload = shared.opts.unload_models_when_training
if save_embedding_every > 0: if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings") embedding_dir = os.path.join(log_directory, "embeddings")
...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if images_dir is not None and steps_done % create_image_every == 0: if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}' forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename) last_saved_image = os.path.join(images_dir, forced_filename)
shared.sd_model.first_stage_model.to(devices.device)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model, sd_model=shared.sd_model,
do_not_save_grid=True, do_not_save_grid=True,
...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc ...@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
processed = processing.process_images(p) processed = processing.process_images(p)
image = processed.images[0] image = processed.images[0]
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded: if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> ...@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True) save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
shared.sd_model.first_stage_model.to(devices.device)
return embedding, filename return embedding, filename
......
...@@ -25,8 +25,10 @@ def train_embedding(*args): ...@@ -25,8 +25,10 @@ def train_embedding(*args):
assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
apply_optimizations = shared.opts.training_xattention_optimizations
try: try:
sd_hijack.undo_optimizations() if not apply_optimizations:
sd_hijack.undo_optimizations()
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)} ...@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
except Exception: except Exception:
raise raise
finally: finally:
sd_hijack.apply_optimizations() if not apply_optimizations:
sd_hijack.apply_optimizations()
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
from modules import sd_hijack, sd_models, localization, script_callbacks from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions
from modules.paths import script_path from modules.paths import script_path
from modules.shared import opts, cmd_opts, restricted_opts from modules.shared import opts, cmd_opts, restricted_opts
...@@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call):
import modules.img2img import modules.img2img
import modules.txt2img import modules.txt2img
parameters_copypaste.reset()
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
...@@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call):
column = None column = None
with gr.Row(elem_id="settings").style(equal_height=False): with gr.Row(elem_id="settings").style(equal_height=False):
for i, (k, item) in enumerate(opts.data_labels.items()): for i, (k, item) in enumerate(opts.data_labels.items()):
section_must_be_skipped = item.section[0] is None
if previous_section != item.section: if previous_section != item.section and not section_must_be_skipped:
if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None):
if column is not None: if column is not None:
column.__exit__() column.__exit__()
...@@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call):
if k in quicksettings_names and not shared.cmd_opts.freeze_settings: if k in quicksettings_names and not shared.cmd_opts.freeze_settings:
quicksettings_list.append((i, k, item)) quicksettings_list.append((i, k, item))
components.append(dummy_component) components.append(dummy_component)
elif section_must_be_skipped:
components.append(dummy_component)
else: else:
component = create_setting_component(k) component = create_setting_component(k)
component_dict[k] = component component_dict[k] = component
...@@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call):
def request_restart(): def request_restart():
shared.state.interrupt() shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True shared.state.need_restart = True
restart_gradio.click( restart_gradio.click(
fn=request_restart, fn=request_restart,
inputs=[], inputs=[],
outputs=[], outputs=[],
...@@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call):
interfaces += script_callbacks.ui_tabs_callback() interfaces += script_callbacks.ui_tabs_callback()
interfaces += [(settings_interface, "Settings", "settings")] interfaces += [(settings_interface, "Settings", "settings")]
extensions_interface = ui_extensions.create_ui()
interfaces += [(extensions_interface, "Extensions", "extensions")]
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"): with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list: for i, k, item in quicksettings_list:
component = create_setting_component(k, is_quicksettings=True) component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component component_dict[k] = component
settings_interface.gradio_ref = demo
parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.integrate_settings_paste_fields(component_dict)
parameters_copypaste.run_bind() parameters_copypaste.run_bind()
......
import json
import os.path
import shutil
import sys
import time
import traceback
import git
import gradio as gr
import html
from modules import extensions, shared, paths
available_extensions = {"extensions": []}
def check_access():
assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags"
def apply_and_restart(disable_list, update_list):
check_access()
disabled = json.loads(disable_list)
assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}"
update = json.loads(update_list)
assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}"
update = set(update)
for ext in extensions.extensions:
if ext.name not in update:
continue
try:
ext.pull()
except Exception:
print(f"Error pulling updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
shared.opts.disabled_extensions = disabled
shared.opts.save(shared.config_filename)
shared.state.interrupt()
shared.state.need_restart = True
def check_updates():
check_access()
for ext in extensions.extensions:
if ext.remote is None:
continue
try:
ext.check_updates()
except Exception:
print(f"Error checking updates for {ext.name}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return extension_table()
def extension_table():
code = f"""<!-- {time.time()} -->
<table id="extensions">
<thead>
<tr>
<th><abbr title="Use checkbox to enable the extension; it will be enabled or disabled when you click apply button">Extension</abbr></th>
<th>URL</th>
<th><abbr title="Use checkbox to mark the extension for update; it will be updated when you click apply button">Update</abbr></th>
</tr>
</thead>
<tbody>
"""
for ext in extensions.extensions:
if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
else:
ext_status = ext.status
code += f"""
<tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}">{html.escape(ext.remote or '')}</a></td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr>
"""
code += """
</tbody>
</table>
"""
return code
def normalize_git_url(url):
if url is None:
return ""
url = url.replace(".git", "")
return url
def install_extension_from_url(dirname, url):
check_access()
assert url, 'No URL specified'
if dirname is None or dirname == "":
*parts, last_part = url.split('/')
last_part = normalize_git_url(last_part)
dirname = last_part
target_dir = os.path.join(extensions.extensions_dir, dirname)
assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}'
normalized_url = normalize_git_url(url)
assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed'
tmpdir = os.path.join(paths.script_path, "tmp", dirname)
try:
shutil.rmtree(tmpdir, True)
repo = git.Repo.clone_from(url, tmpdir)
repo.remote().fetch()
os.rename(tmpdir, target_dir)
extensions.list_extensions()
return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")]
finally:
shutil.rmtree(tmpdir, True)
def install_extension_from_index(url):
ext_table, message = install_extension_from_url(None, url)
return refresh_available_extensions_from_data(), ext_table, message
def refresh_available_extensions(url):
global available_extensions
import urllib.request
with urllib.request.urlopen(url) as response:
text = response.read()
available_extensions = json.loads(text)
return url, refresh_available_extensions_from_data(), ''
def refresh_available_extensions_from_data():
extlist = available_extensions["extensions"]
installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions}
code = f"""<!-- {time.time()} -->
<table id="available_extensions">
<thead>
<tr>
<th>Extension</th>
<th>Description</th>
<th>Action</th>
</tr>
</thead>
<tbody>
"""
for ext in extlist:
name = ext.get("name", "noname")
url = ext.get("url", None)
description = ext.get("description", "")
if url is None:
continue
existing = installed_extension_urls.get(normalize_git_url(url), None)
install_code = f"""<input onclick="install_extension_from_index(this, '{html.escape(url)}')" type="button" value="{"Install" if not existing else "Installed"}" {"disabled=disabled" if existing else ""} class="gr-button gr-button-lg gr-button-secondary">"""
code += f"""
<tr>
<td><a href="{html.escape(url)}">{html.escape(name)}</a></td>
<td>{html.escape(description)}</td>
<td>{install_code}</td>
</tr>
"""
code += """
</tbody>
</table>
"""
return code
def create_ui():
import modules.ui
with gr.Blocks(analytics_enabled=False) as ui:
with gr.Tabs(elem_id="tabs_extensions") as tabs:
with gr.TabItem("Installed"):
with gr.Row():
apply = gr.Button(value="Apply and restart UI", variant="primary")
check = gr.Button(value="Check for updates")
extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False)
extensions_table = gr.HTML(lambda: extension_table())
apply.click(
fn=apply_and_restart,
_js="extensions_apply",
inputs=[extensions_disabled_list, extensions_update_list],
outputs=[],
)
check.click(
fn=check_updates,
_js="extensions_check",
inputs=[],
outputs=[extensions_table],
)
with gr.TabItem("Available"):
with gr.Row():
refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
available_extensions_index = gr.Text(value="https://raw.githubusercontent.com/wiki/AUTOMATIC1111/stable-diffusion-webui/Extensions-index.md", label="Extension index URL").style(container=False)
extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
install_result = gr.HTML()
available_extensions_table = gr.HTML()
refresh_available_extensions_button.click(
fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]),
inputs=[available_extensions_index],
outputs=[available_extensions_index, available_extensions_table, install_result],
)
install_extension_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]),
inputs=[extension_to_install],
outputs=[available_extensions_table, extensions_table, install_result],
)
with gr.TabItem("Install from URL"):
install_url = gr.Text(label="URL for extension's git repository")
install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto")
install_button = gr.Button(value="Install", variant="primary")
install_result = gr.HTML(elem_id="extension_install_result")
install_button.click(
fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]),
inputs=[install_dirname, install_url],
outputs=[extensions_table, install_result],
)
return ui
...@@ -12,7 +12,7 @@ opencv-python ...@@ -12,7 +12,7 @@ opencv-python
requests requests
piexif piexif
Pillow Pillow
pytorch_lightning pytorch_lightning==1.7.7
realesrgan realesrgan
scikit-image>=0.19 scikit-image>=0.19
timm==0.4.12 timm==0.4.12
...@@ -26,3 +26,4 @@ torchdiffeq ...@@ -26,3 +26,4 @@ torchdiffeq
kornia kornia
lark lark
inflection inflection
GitPython
...@@ -23,3 +23,4 @@ torchdiffeq==0.2.3 ...@@ -23,3 +23,4 @@ torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1 inflection==0.5.1
GitPython==3.1.27
...@@ -96,6 +96,7 @@ class Script(scripts.Script): ...@@ -96,6 +96,7 @@ class Script(scripts.Script):
def ui(self, is_img2img): def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False) checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False)
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False)
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1) prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1)
file = gr.File(label="Upload prompt inputs", type='bytes') file = gr.File(label="Upload prompt inputs", type='bytes')
...@@ -106,9 +107,9 @@ class Script(scripts.Script): ...@@ -106,9 +107,9 @@ class Script(scripts.Script):
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may # We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed. # be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt]) prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
return [checkbox_iterate, file, prompt_txt] return [checkbox_iterate, checkbox_iterate_batch, file, prompt_txt]
def run(self, p, checkbox_iterate, file, prompt_txt: str): def run(self, p, checkbox_iterate, checkbox_iterate_batch, file, prompt_txt: str):
lines = [x.strip() for x in prompt_txt.splitlines()] lines = [x.strip() for x in prompt_txt.splitlines()]
lines = [x for x in lines if len(x) > 0] lines = [x for x in lines if len(x) > 0]
...@@ -137,7 +138,7 @@ class Script(scripts.Script): ...@@ -137,7 +138,7 @@ class Script(scripts.Script):
jobs.append(args) jobs.append(args)
print(f"Will process {len(lines)} lines in {job_count} jobs.") print(f"Will process {len(lines)} lines in {job_count} jobs.")
if (checkbox_iterate and p.seed == -1): if (checkbox_iterate or checkbox_iterate_batch) and p.seed == -1:
p.seed = int(random.randrange(4294967294)) p.seed = int(random.randrange(4294967294))
state.job_count = job_count state.job_count = job_count
...@@ -153,7 +154,7 @@ class Script(scripts.Script): ...@@ -153,7 +154,7 @@ class Script(scripts.Script):
proc = process_images(copy_p) proc = process_images(copy_p)
images += proc.images images += proc.images
if (checkbox_iterate): if checkbox_iterate:
p.seed = p.seed + (p.batch_size * p.n_iter) p.seed = p.seed + (p.batch_size * p.n_iter)
......
...@@ -530,6 +530,29 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h ...@@ -530,6 +530,29 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h
min-height: 480px !important; min-height: 480px !important;
} }
/* Extensions */
#tab_extensions table{
border-collapse: collapse;
}
#tab_extensions table td, #tab_extensions table th{
border: 1px solid #ccc;
padding: 0.25em 0.5em;
}
#tab_extensions table input[type="checkbox"]{
margin-right: 0.5em;
}
#tab_extensions button{
max-width: 16em;
}
#tab_extensions input[disabled="disabled"]{
opacity: 0.5;
}
/* The following handles localization for right-to-left (RTL) languages like Arabic. /* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js. The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running If you change anything above, you need to make sure it is RTL compliant by just running
...@@ -607,4 +630,4 @@ Then, you will need to add the RTL counterpart only if needed in the rtl section ...@@ -607,4 +630,4 @@ Then, you will need to add the RTL counterpart only if needed in the rtl section
right: unset; right: unset;
left: 0.5em; left: 0.5em;
} }
} }
\ No newline at end of file
...@@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware ...@@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
from modules import devices, sd_samplers, upscaler from modules import devices, sd_samplers, upscaler, extensions
import modules.codeformer_model as codeformer import modules.codeformer_model as codeformer
import modules.extras import modules.extras
import modules.face_restoration import modules.face_restoration
...@@ -23,6 +23,7 @@ import modules.sd_hijack ...@@ -23,6 +23,7 @@ import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.shared as shared import modules.shared as shared
import modules.txt2img import modules.txt2img
import modules.script_callbacks
import modules.ui import modules.ui
from modules import devices from modules import devices
...@@ -60,6 +61,8 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): ...@@ -60,6 +61,8 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
def initialize(): def initialize():
extensions.list_extensions()
if cmd_opts.ui_debug_mode: if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
modules.scripts.load_scripts() modules.scripts.load_scripts()
...@@ -75,7 +78,7 @@ def initialize(): ...@@ -75,7 +78,7 @@ def initialize():
modules.scripts.load_scripts() modules.scripts.load_scripts()
modules.sd_models.load_model() modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
...@@ -92,15 +95,18 @@ def create_api(app): ...@@ -92,15 +95,18 @@ def create_api(app):
api = Api(app, queue_lock) api = Api(app, queue_lock)
return api return api
def wait_on_server(demo=None): def wait_on_server(demo=None):
while 1: while 1:
time.sleep(0.5) time.sleep(0.5)
if demo and getattr(demo, 'do_restart', False): if shared.state.need_restart:
shared.state.need_restart = False
time.sleep(0.5) time.sleep(0.5)
demo.close() demo.close()
time.sleep(0.5) time.sleep(0.5)
break break
def api_only(): def api_only():
initialize() initialize()
...@@ -132,14 +138,18 @@ def webui(): ...@@ -132,14 +138,18 @@ def webui():
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
if (launch_api): if launch_api:
create_api(app) create_api(app)
modules.script_callbacks.app_started_callback(demo, app)
wait_on_server(demo) wait_on_server(demo)
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading Custom Scripts') print('Reloading extensions')
extensions.list_extensions()
print('Reloading custom scripts')
modules.scripts.reload_scripts() modules.scripts.reload_scripts()
print('Reloading modules: modules.ui') print('Reloading modules: modules.ui')
importlib.reload(modules.ui) importlib.reload(modules.ui)
...@@ -148,8 +158,6 @@ def webui(): ...@@ -148,8 +158,6 @@ def webui():
print('Restarting Gradio') print('Restarting Gradio')
task = []
if __name__ == "__main__": if __name__ == "__main__":
if cmd_opts.nowebui: if cmd_opts.nowebui:
api_only() api_only()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment