Commit 70a01cd4 authored by AUTOMATIC1111

Merge branch 'dev' into refiner

parents 1aefb502 070b034c
@@ -195,6 +195,15 @@ def load_network(name, network_on_disk):
return net
def purge_networks_from_memory():
while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
name = next(iter(networks_in_memory))
networks_in_memory.pop(name, None)
devices.torch_gc()
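For illustration only: networks_in_memory is a plain dict used as an insertion-ordered LRU cache. load_networks (below) pops and re-inserts an entry on each load to mark it most recently used, so the purge loop above always evicts the oldest entry first. A minimal standalone sketch of that policy, with lora_in_memory_limit inlined as a constant (in the real code it is the option added later in this diff):

networks_in_memory = {}
lora_in_memory_limit = 2  # stands in for shared.opts.lora_in_memory_limit

def touch(name, net):
    networks_in_memory.pop(name, None)  # re-inserting moves the entry to the newest position
    networks_in_memory[name] = net

def purge():
    while len(networks_in_memory) > lora_in_memory_limit and len(networks_in_memory) > 0:
        oldest = next(iter(networks_in_memory))  # dicts iterate in insertion order
        networks_in_memory.pop(oldest, None)

for n in ["a", "b", "c"]:
    touch(n, object())
touch("a", object())  # "a" becomes the newest entry again
purge()
print(list(networks_in_memory))  # ['c', 'a']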
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}
@@ -212,15 +221,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
failed_to_load_networks = []
for i, name in enumerate(names):
for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
net = already_loaded.get(name, None)
network_on_disk = networks_on_disk[i]
if network_on_disk is not None:
if net is None:
net = networks_in_memory.get(name)
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
try:
net = load_network(name, network_on_disk)
networks_in_memory.pop(name, None)
networks_in_memory[name] = net
except Exception as e:
errors.display(e, f"loading network {network_on_disk.filename}")
continue
@@ -242,6 +255,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No
if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
purge_networks_from_memory()
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
@@ -462,6 +477,7 @@ def infotext_pasted(infotext, params):
available_networks = {}
available_network_aliases = {}
loaded_networks = []
networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
...
@@ -65,6 +65,7 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra
"lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as being for an incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
"lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
}))
@@ -121,3 +122,5 @@ def infotext_pasted(infotext, d):
script_callbacks.on_infotext_pasted(infotext_pasted)
shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
@@ -42,6 +42,11 @@ onUiLoaded(async() => {
}
}
// Detect whether the element has a horizontal scroll bar
function hasHorizontalScrollbar(element) {
return element.scrollWidth > element.clientWidth;
}
// Function for defining the "Ctrl", "Shift" and "Alt" keys
function isModifierKey(event, key) {
switch (key) {
@@ -201,7 +206,8 @@ onUiLoaded(async() => {
canvas_hotkey_overlap: "KeyO",
canvas_disabled_functions: [],
canvas_show_tooltip: true,
canvas_blur_prompt: false
canvas_auto_expand: true,
canvas_blur_prompt: false,
};
const functionMap = {
@@ -648,8 +654,32 @@ onUiLoaded(async() => {
mouseY = e.offsetY;
}
// Simulates fitting a long image to the screen: if the image overflows its element
// horizontally, go fullscreen to reveal it, then shrink it back down to fit the element.
// The image is hidden while this happens and shown again once it is ready.
function autoExpand(e) {
const canvas = document.querySelector(`${elemId} canvas[key="interface"]`);
const isMainTab = activeElement === elementIDs.inpaint || activeElement === elementIDs.inpaintSketch || activeElement === elementIDs.sketch;
if (canvas && isMainTab) {
if (hasHorizontalScrollbar(targetElement)) {
targetElement.style.visibility = "hidden";
setTimeout(() => {
fitToScreen();
resetZoom();
targetElement.style.visibility = "visible";
}, 10);
}
}
}
targetElement.addEventListener("mousemove", getMousePosition);
// Apply auto expand if enabled
if (hotkeysConfig.canvas_auto_expand) {
targetElement.addEventListener("mousemove", autoExpand);
}
// Handle events only inside the targetElement
let isKeyDownHandlerAttached = false;
...
@@ -9,6 +9,7 @@ shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas
"canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas position"),
"canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, needed for testing"),
"canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"),
"canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"),
"canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"),
"canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable functions that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}),
}))
var observerAccordionOpen = new MutationObserver(function(mutations) {
mutations.forEach(function(mutationRecord) {
var elem = mutationRecord.target;
var open = elem.classList.contains('open');
var accordion = elem.parentNode;
accordion.classList.toggle('input-accordion-open', open);
var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input");
checkbox.checked = open;
updateInput(checkbox);
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
if (extra) {
extra.style.display = open ? "" : "none";
}
});
});
function inputAccordionChecked(id, checked) {
var label = gradioApp().querySelector('#' + id + " .label-wrap");
if (label.classList.contains('open') != checked) {
label.click();
}
}
onUiLoaded(function() {
for (var accordion of gradioApp().querySelectorAll('.input-accordion')) {
var labelWrap = accordion.querySelector('.label-wrap');
observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']});
var extra = gradioApp().querySelector('#' + accordion.id + "-extra");
if (extra) {
labelWrap.insertBefore(extra, labelWrap.lastElementChild);
}
}
});
@@ -16,6 +16,7 @@ parser.add_argument("--test-server", action='store_true', help="launch.py argume
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None)
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
......
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
from modules import errors, rng_philox
from modules import errors, shared
if sys.platform == "darwin":
from modules import mac_specific
@@ -17,8 +17,6 @@ def has_mps() -> bool:
def get_cuda_device_string():
from modules import shared
if shared.cmd_opts.device_id is not None:
return f"cuda:{shared.cmd_opts.device_id}"
@@ -40,8 +38,6 @@ def get_optimal_device():
def get_device_for(task):
from modules import shared
if task in shared.cmd_opts.use_cpu:
return cpu
@@ -96,87 +92,7 @@ def cond_cast_float(input):
nv_rng = None
def randn(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
from modules.shared import opts
manual_seed(seed)
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
from modules.shared import opts
if opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=device)
local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(device)
def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=cpu).to(x.device)
return torch.randn_like(x)
def randn_without_seed(shape):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
from modules.shared import opts
if opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return
torch.manual_seed(seed)
def autocast(disable=False):
from modules import shared
if disable:
return contextlib.nullcontext()
@@ -195,8 +111,6 @@ class NansException(Exception):
def test_for_nans(x, where):
from modules import shared
if shared.cmd_opts.disable_nan_check:
return
@@ -236,3 +150,4 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
import os
import threading
from modules import shared, errors, cache
from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
@@ -90,8 +90,6 @@ class Extension:
self.have_info_from_repo = True
def list_files(self, subdir, extension):
from modules import scripts
dirpath = os.path.join(self.path, subdir)
if not os.path.isdir(dirpath):
return []
...
@@ -6,7 +6,7 @@ import re
import gradio as gr
from modules.paths import data_path
from modules import shared, ui_tempdir, script_callbacks
from modules import shared, ui_tempdir, script_callbacks, processing
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)'
@@ -198,7 +198,6 @@ def restore_old_hires_fix_params(res):
height = int(res.get("Size-2", 512))
if firstpass_width == 0 or firstpass_height == 0:
from modules import processing
firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)
res['Size-1'] = firstpass_width
@@ -317,36 +316,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
infotext_to_setting_name_mapping = [
('Clip skip', 'CLIP_stop_at_last_layers', ),
]
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
Example content:
infotext_to_setting_name_mapping = [
('Conditional mask weight', 'inpainting_mask_weight'),
('Model hash', 'sd_model_checkpoint'),
('ENSD', 'eta_noise_seed_delta'),
('Schedule type', 'k_sched_type'),
('Schedule max sigma', 'sigma_max'),
('Schedule min sigma', 'sigma_min'),
('Schedule rho', 'rho'),
('Noise multiplier', 'initial_noise_multiplier'),
('Eta', 'eta_ancestral'),
('Eta DDIM', 'eta_ddim'),
('Sigma churn', 's_churn'),
('Sigma tmin', 's_tmin'),
('Sigma tmax', 's_tmax'),
('Sigma noise', 's_noise'),
('Discard penultimate sigma', 'always_discard_next_to_last_sigma'),
('UniPC variant', 'uni_pc_variant'),
('UniPC skip type', 'uni_pc_skip_type'),
('UniPC order', 'uni_pc_order'),
('UniPC lower order final', 'uni_pc_lower_order_final'),
('Token merging ratio', 'token_merging_ratio'),
('Token merging ratio hr', 'token_merging_ratio_hr'),
('RNG', 'randn_source'),
('NGMS', 's_min_uncond'),
('Pad conds', 'pad_cond_uncond'),
('VAE Encoder', 'sd_vae_encode_method'),
('VAE Decoder', 'sd_vae_decode_method'),
('Refiner', 'sd_refiner_checkpoint'),
('Refiner switch at', 'sd_refiner_switch_at'),
]
"""
def create_override_settings_dict(text_pairs):
@@ -367,7 +348,8 @@ def create_override_settings_dict(text_pairs):
params[k] = v.strip()
for param_name, setting_name in infotext_to_setting_name_mapping:
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
value = params.get(param_name, None)
if value is None:
@@ -421,7 +403,8 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
def paste_settings(params):
vals = {}
for param_name, setting_name in infotext_to_setting_name_mapping:
mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
if param_name in already_handled_fields:
continue
...
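As an aside on the infotext change above: the new mapping comprehension derives (infotext label, setting name) pairs from any option declared with OptionInfo(..., infotext='...'), which is why the hard-coded list could shrink to a backwards-compatibility stub. A stripped-down sketch; FakeOptionInfo is a stand-in for the real OptionInfo class that appears later in this diff:

class FakeOptionInfo:  # stand-in for modules.options.OptionInfo
    def __init__(self, default, label, infotext=None):
        self.default = default
        self.label = label
        self.infotext = infotext

data_labels = {
    "CLIP_stop_at_last_layers": FakeOptionInfo(1, "Clip skip", infotext="Clip skip"),
    "sd_model_checkpoint": FakeOptionInfo(None, "Checkpoint"),  # no infotext, so excluded
}

# the same comprehension this diff adds to create_override_settings_dict and paste_settings:
mapping = [(info.infotext, k) for k, info in data_labels.items() if info.infotext]
print(mapping)  # [('Clip skip', 'CLIP_stop_at_last_layers')]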
import gradio as gr
from modules import scripts
from modules import scripts, ui_tempdir
def add_classes_to_gradio_component(comp):
"""
@@ -58,3 +58,5 @@ original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.components.IOComponent.__init__ = IOComponent_init
gr.blocks.Block.get_config = Block_get_config
gr.blocks.BlockContext.__init__ = BlockContext_init
ui_tempdir.install_ui_tempdir_override()
@@ -21,8 +21,6 @@ from modules import sd_samplers, shared, script_callbacks, errors
from modules.paths_internal import roboto_ttf_file
from modules.shared import opts
import modules.sd_vae as sd_vae
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
@@ -342,16 +340,6 @@ def sanitize_filename_part(text, replace_spaces=True):
class FilenameGenerator:
def get_vae_filename(self): #get the name of the VAE file.
if sd_vae.loaded_vae_file is None:
return "NoneType"
file_name = os.path.basename(sd_vae.loaded_vae_file)
split_file_name = file_name.split('.')
if len(split_file_name) > 1 and split_file_name[0] == '':
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
else:
return split_file_name[0]
replacements = {
'seed': lambda self: self.seed if self.seed is not None else '',
'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0],
@@ -391,6 +379,22 @@ class FilenameGenerator:
self.image = image
self.zip = zip
def get_vae_filename(self):
"""Get the name of the VAE file."""
import modules.sd_vae as sd_vae
if sd_vae.loaded_vae_file is None:
return "NoneType"
file_name = os.path.basename(sd_vae.loaded_vae_file)
split_file_name = file_name.split('.')
if len(split_file_name) > 1 and split_file_name[0] == '':
return split_file_name[1] # if the first character of the filename is "." then [1] is obtained.
else:
return split_file_name[0]
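A standalone illustration of the splitting rule above, with hypothetical paths (the vae_name helper exists only in this sketch):

import os

def vae_name(path):  # hypothetical helper mirroring get_vae_filename's logic
    file_name = os.path.basename(path)
    split_file_name = file_name.split('.')
    if len(split_file_name) > 1 and split_file_name[0] == '':
        return split_file_name[1]  # hidden file such as ".vae.pt" yields "vae"
    return split_file_name[0]

print(vae_name("/models/VAE/anything-v4.0.vae.pt"))  # anything-v4
print(vae_name("/models/VAE/.vae.safetensors"))      # vae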
def hasprompt(self, *args):
lower = self.prompt.lower()
if self.p is None or self.prompt is None:
...
@@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
process_images(p)
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts)
is_batch = mode == 5
@@ -179,8 +179,6 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
cfg_scale=cfg_scale,
width=width,
height=height,
restore_faces=restore_faces,
tiling=tiling,
init_images=[image],
mask=mask,
mask_blur=mask_blur,
...
import importlib
import logging
import sys
import warnings
from threading import Thread
from modules.timer import startup_timer
def imports():
logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh...
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
import torch # noqa: F401
startup_timer.record("import torch")
import pytorch_lightning # noqa: F401
startup_timer.record("import pytorch_lightning")
warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")
import gradio # noqa: F401
startup_timer.record("import gradio")
from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
import sgm.modules.encoders.modules # noqa: F401
startup_timer.record("import sgm")
from modules import shared_init
shared_init.initialize()
startup_timer.record("initialize shared")
from modules import processing, gradio_extensons, ui # noqa: F401
startup_timer.record("other imports")
def check_versions():
from modules.shared_cmd_options import cmd_opts
if not cmd_opts.skip_version_check:
from modules import errors
errors.check_versions()
def initialize():
from modules import initialize_util
initialize_util.fix_torch_version()
initialize_util.fix_asyncio_event_loop_policy()
initialize_util.validate_tls_options()
initialize_util.configure_sigint_handler()
initialize_util.configure_opts_onchange()
from modules import modelloader
modelloader.cleanup_models()
from modules import sd_models
sd_models.setup_model()
startup_timer.record("setup SD model")
from modules.shared_cmd_options import cmd_opts
from modules import codeformer_model
warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor")
codeformer_model.setup_model(cmd_opts.codeformer_models_path)
startup_timer.record("setup codeformer")
from modules import gfpgan_model
gfpgan_model.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
initialize_rest(reload_script_modules=False)
def initialize_rest(*, reload_script_modules=False):
"""
Called both from initialize() and when reloading the webui.
"""
from modules.shared_cmd_options import cmd_opts
from modules import sd_samplers
sd_samplers.set_samplers()
startup_timer.record("set samplers")
from modules import extensions
extensions.list_extensions()
startup_timer.record("list extensions")
from modules import initialize_util
initialize_util.restore_config_state_file()
startup_timer.record("restore config state file")
from modules import shared, upscaler, scripts
if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
scripts.load_scripts()
return
from modules import sd_models
sd_models.list_models()
startup_timer.record("list SD models")
from modules import localization
localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list localizations")
with startup_timer.subcategory("load scripts"):
scripts.load_scripts()
if reload_script_modules:
for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
importlib.reload(module)
startup_timer.record("reload script modules")
from modules import modelloader
modelloader.load_upscalers()
startup_timer.record("load upscalers")
from modules import sd_vae
sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
from modules import textual_inversion
textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
from modules import script_callbacks, sd_hijack_optimizations, sd_hijack
script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers)
sd_hijack.list_optimizers()
startup_timer.record("scripts list_optimizers")
from modules import sd_unet
sd_unet.list_unets()
startup_timer.record("scripts list_unets")
def load_model():
"""
Accesses shared.sd_model property to load model.
After it's available, if it has been loaded before this access by some extension,
its optimization may be None because the list of optimizers has not been filled
by that time, so we apply optimization again.
"""
shared.sd_model # noqa: B018
if sd_hijack.current_optimizer is None:
sd_hijack.apply_optimizations()
from modules import devices
devices.first_time_calculation()
Thread(target=load_model).start()
from modules import shared_items
shared_items.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
from modules import ui_extra_networks
ui_extra_networks.initialize()
ui_extra_networks.register_default_pages()
from modules import extra_networks
extra_networks.initialize()
extra_networks.register_default_extra_networks()
startup_timer.record("initialize extra networks")
import json
import os
import signal
import sys
import re
from modules.timer import startup_timer
def gradio_server_name():
from modules.shared_cmd_options import cmd_opts
if cmd_opts.server_name:
return cmd_opts.server_name
else:
return "0.0.0.0" if cmd_opts.listen else None
def fix_torch_version():
import torch
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__:
torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
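The regex keeps only the leading dotted-digit run of the version string. A quick check with hypothetical version strings:

import re

for v in ["2.1.0.dev20230801+cu118", "2.0.1+git1234abc", "2.0.1"]:
    print(v, "->", re.search(r'[\d.]+[\d]', v).group(0))
# 2.1.0.dev20230801+cu118 -> 2.1.0
# 2.0.1+git1234abc -> 2.0.1
# 2.0.1 -> 2.0.1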
def fix_asyncio_event_loop_policy():
"""
The default `asyncio` event loop policy only automatically creates
event loops in the main threads. Other threads must create event
loops explicitly or `asyncio.get_event_loop` (and therefore
`.IOLoop.current`) will fail. Installing this policy allows event
loops to be created automatically on any thread, matching the
behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2).
"""
import asyncio
if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
# "Any thread" and "selector" should be orthogonal, but there's not a clean
# interface for composing policies so pick the right base.
_BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore
else:
_BasePolicy = asyncio.DefaultEventLoopPolicy
class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore
"""Event loop policy that allows loop creation on any thread.
Usage::
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
"""
def get_event_loop(self) -> asyncio.AbstractEventLoop:
try:
return super().get_event_loop()
except (RuntimeError, AssertionError):
# This was an AssertionError in python 3.4.2 (which ships with debian jessie)
# and changed to a RuntimeError in 3.4.3.
# "There is no current event loop in thread %r"
loop = self.new_event_loop()
self.set_event_loop(loop)
return loop
asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
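A quick check of the effect, assuming the function above is in scope; without the policy, asyncio.get_event_loop raises RuntimeError in a non-main thread that has no loop set:

import asyncio
import threading

def worker():
    loop = asyncio.get_event_loop()  # created on demand by AnyThreadEventLoopPolicy
    print("got loop:", loop is not None)

fix_asyncio_event_loop_policy()
t = threading.Thread(target=worker)
t.start()
t.join()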
def restore_config_state_file():
from modules import shared, config_states
config_state_file = shared.opts.restore_config_state_file
if config_state_file == "":
return
shared.opts.restore_config_state_file = ""
shared.opts.save(shared.config_filename)
if os.path.isfile(config_state_file):
print(f"*** About to restore extension state from file: {config_state_file}")
with open(config_state_file, "r", encoding="utf-8") as f:
config_state = json.load(f)
config_states.restore_extension_config(config_state)
startup_timer.record("restore extension config")
elif config_state_file:
print(f"!!! Config state backup not found: {config_state_file}")
def validate_tls_options():
from modules.shared_cmd_options import cmd_opts
if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile):
return
try:
if not os.path.exists(cmd_opts.tls_keyfile):
print("Invalid path to TLS keyfile given")
if not os.path.exists(cmd_opts.tls_certfile):
print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
except TypeError:
cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
print("TLS setup invalid, running webui without TLS")
else:
print("Running with TLS")
startup_timer.record("TLS")
def get_gradio_auth_creds():
"""
Convert the gradio_auth and gradio_auth_path commandline arguments into
an iterable of (username, password) tuples.
"""
from modules.shared_cmd_options import cmd_opts
def process_credential_line(s):
s = s.strip()
if not s:
return None
return tuple(s.split(':', 1))
if cmd_opts.gradio_auth:
for cred in cmd_opts.gradio_auth.split(','):
cred = process_credential_line(cred)
if cred:
yield cred
if cmd_opts.gradio_auth_path:
with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file:
for line in file.readlines():
for cred in line.strip().split(','):
cred = process_credential_line(cred)
if cred:
yield cred
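Note that process_credential_line splits only on the first colon, so passwords may themselves contain colons. A standalone illustration of the same rule as the nested helper above:

def process_credential_line(s):
    s = s.strip()
    if not s:
        return None
    return tuple(s.split(':', 1))  # maxsplit=1: split on the first colon only

print(process_credential_line("alice:s3cr:et"))  # ('alice', 's3cr:et')
print(process_credential_line("   "))            # None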
def configure_sigint_handler():
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
if not os.environ.get("COVERAGE_RUN"):
# Don't install the immediate-quit handler when running under coverage,
# as then the coverage report won't be generated.
signal.signal(signal.SIGINT, sigint_handler)
def configure_opts_onchange():
from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack
from modules.call_queue import wrap_queued_call
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
startup_timer.record("opts onchange")
def setup_middleware(app):
from starlette.middleware.gzip import GZipMiddleware
app.middleware_stack = None # reset current middleware to allow modifying user provided list
app.add_middleware(GZipMiddleware, minimum_size=1000)
configure_cors_middleware(app)
app.build_middleware_stack() # rebuild middleware stack on-the-fly
def configure_cors_middleware(app):
from starlette.middleware.cors import CORSMiddleware
from modules.shared_cmd_options import cmd_opts
cors_options = {
"allow_methods": ["*"],
"allow_headers": ["*"],
"allow_credentials": True,
}
if cmd_opts.cors_allow_origins:
cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',')
if cmd_opts.cors_allow_origins_regex:
cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex
app.add_middleware(CORSMiddleware, **cors_options)
# this script installs necessary requirements and launches the main program in webui.py
import logging
import re
import subprocess
import os
@@ -11,8 +12,10 @@ from functools import lru_cache
from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
from modules.timer import startup_timer
from modules import logging_config
args, _ = cmd_args.parser.parse_known_args()
logging_config.setup_logging(args.loglevel)
python = sys.executable
git = os.environ.get('GIT', "git")
@@ -249,6 +252,8 @@ def run_extensions_installers(settings_file):
with startup_timer.subcategory("run extensions installers"):
for dirname_extension in list_extensions(settings_file):
logging.debug(f"Installing {dirname_extension}")
path = os.path.join(extensions_dir, dirname_extension)
if os.path.isdir(path):
...
import json
import os
from modules import errors
from modules import errors, scripts
localizations = {}
@@ -16,7 +16,6 @@ def list_localizations(dirname):
localizations[fn] = os.path.join(dirname, file)
from modules import scripts
for file in scripts.list_scripts("localizations", ".json"):
fn, ext = os.path.splitext(file.filename)
localizations[fn] = file.path
...
import os
import logging
def setup_logging(loglevel):
if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")
if loglevel:
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
)
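Hypothetical usage, matching how launch_utils wires this up later in this diff: an explicit level wins, otherwise the SD_WEBUI_LOG_LEVEL environment variable is consulted, and if neither is set, logging stays unconfigured:

import logging
import os

os.environ.setdefault("SD_WEBUI_LOG_LEVEL", "INFO")
setup_logging(None)  # no --loglevel given, so the environment variable is used
logging.getLogger("webui").info("logging configured")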
@@ -4,6 +4,7 @@ import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
from modules import shared
log = logging.getLogger(__name__)
@@ -30,8 +31,7 @@ has_mps = check_for_mps()
def torch_mps_gc() -> None:
try:
from modules.shared import state
if shared.state.current_latent is not None:
if state.current_latent is not None:
log.debug("`current_latent` is set, skipping MPS garbage collection")
return
from torch.mps import empty_cache
...
import json
import sys
import gradio as gr
from modules import errors
from modules.shared_cmd_options import cmd_opts
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
self.onchange = onchange
self.section = section
self.refresh = refresh
self.do_not_save = False
self.comment_before = comment_before
"""HTML text that will be added after label in UI"""
self.comment_after = comment_after
"""HTML text that will be added before label in UI"""
self.infotext = infotext
def link(self, label, url):
self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]"
return self
def js(self, label, js_func):
self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]"
return self
def info(self, info):
self.comment_after += f"<span class='info'>({info})</span>"
return self
def html(self, html):
self.comment_after += html
return self
def needs_restart(self):
self.comment_after += " <span class='info'>(requires restart)</span>"
return self
def needs_reload_ui(self):
self.comment_after += " <span class='info'>(requires Reload UI)</span>"
return self
class OptionHTML(OptionInfo):
def __init__(self, text):
super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs))
self.do_not_save = True
def options_section(section_identifier, options_dict):
for v in options_dict.values():
v.section = section_identifier
return options_dict
options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"}
class Options:
typemap = {int: float}
def __init__(self, data_labels, restricted_opts):
self.data_labels = data_labels
self.data = {k: v.default for k, v in self.data_labels.items()}
self.restricted_opts = restricted_opts
def __setattr__(self, key, value):
if key in options_builtin_fields:
return super(Options, self).__setattr__(key, value)
if self.data is not None:
if key in self.data or key in self.data_labels:
assert not cmd_opts.freeze_settings, "changing settings is disabled"
info = self.data_labels.get(key, None)
if info and info.do_not_save:
return
comp_args = info.component_args if info else None
if isinstance(comp_args, dict) and comp_args.get('visible', True) is False:
raise RuntimeError(f"not possible to set {key} because it is restricted")
if cmd_opts.hide_ui_dir_config and key in self.restricted_opts:
raise RuntimeError(f"not possible to set {key} because it is restricted")
self.data[key] = value
return
return super(Options, self).__setattr__(key, value)
def __getattr__(self, item):
if item in options_builtin_fields:
return super(Options, self).__getattribute__(item)
if self.data is not None:
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
def set(self, key, value):
"""sets an option and calls its onchange callback, returning True if the option changed and False otherwise"""
oldval = self.data.get(key, None)
if oldval == value:
return False
if self.data_labels[key].do_not_save:
return False
try:
setattr(self, key, value)
except RuntimeError:
return False
if self.data_labels[key].onchange is not None:
try:
self.data_labels[key].onchange()
except Exception as e:
errors.display(e, f"changing setting {key} to {value}")
setattr(self, key, oldval)
return False
return True
def get_default(self, key):
"""returns the default value for the key"""
data_label = self.data_labels.get(key)
if data_label is None:
return None
return data_label.default
def save(self, filename):
assert not cmd_opts.freeze_settings, "saving settings is disabled"
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file, indent=4)
def same_type(self, x, y):
if x is None or y is None:
return True
type_x = self.typemap.get(type(x), type(x))
type_y = self.typemap.get(type(y), type(y))
return type_x == type_y
def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
# 1.6.0 VAE defaults
if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
# 1.4.0 ui_reorder
if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data:
self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')]
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
if info is not None and not self.same_type(info.default, v):
print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
bad_settings += 1
if bad_settings > 0:
print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)
def onchange(self, key, func, call=True):
item = self.data_labels.get(key)
item.onchange = func
if call:
func()
def dumpjson(self):
d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()}
d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None}
d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None}
return json.dumps(d)
def add_option(self, key, info):
self.data_labels[key] = info
def reorder(self):
"""reorder settings so that all items related to section always go together"""
section_ids = {}
settings_items = self.data_labels.items()
for _, item in settings_items:
if item.section not in section_ids:
section_ids[item.section] = len(section_ids)
self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section]))
def cast_value(self, key, value):
"""casts an arbitrary value to the same type as this setting's value with key
Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str)
"""
if value is None:
return None
default_value = self.data_labels[key].default
if default_value is None:
default_value = getattr(self, key, None)
if default_value is None:
return None
expected_type = type(default_value)
if expected_type == bool and value == "False":
value = False
else:
value = expected_type(value)
return value
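A hedged usage sketch of the Options machinery above; it assumes the webui modules are importable and settings are not frozen, and reuses the Lora cache-limit option added earlier in this diff:

opts = Options(
    data_labels={"lora_in_memory_limit": OptionInfo(0, "Number of Lora networks to keep cached in memory")},
    restricted_opts=set(),
)
opts.onchange("lora_in_memory_limit", lambda: print("limit changed"), call=False)
print(opts.set("lora_in_memory_limit", 2))  # prints "limit changed", then True
print(opts.lora_in_memory_limit)            # 2
print(opts.set("lora_in_memory_limit", 2))  # False: value unchanged, callback not called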
import torch
from modules import devices, rng_philox, shared
def randn(seed, shape, generator=None):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
manual_seed(seed)
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
return torch.randn(shape, device=devices.device, generator=generator)
def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
if shared.opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=devices.device)
local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
Use either randn() or manual_seed() to initialize the generator."""
if shared.opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
return torch.randn_like(x)
def randn_without_seed(shape, generator=None):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
Use either randn() or manual_seed() to initialize the generator."""
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
return torch.randn(shape, device=devices.device, generator=generator)
def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
if shared.opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return
torch.manual_seed(seed)
def create_generator(seed):
if shared.opts.randn_source == "NV":
return rng_philox.Generator(seed)
device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
generator = torch.Generator(device).manual_seed(int(seed))
return generator
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
low_norm = low/torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True)
dot = (low_norm*high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
return res
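A quick shape check of slerp above; inputs are (batch, dim) tensors, and the dot-product test falls back to plain linear interpolation when the vectors are nearly parallel:

low = torch.randn(2, 8)
high = torch.randn(2, 8)
mid = slerp(0.5, low, high)
print(mid.shape)                   # torch.Size([2, 8])
print(slerp(0.5, low, low).shape)  # parallel inputs take the linear fallback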
class ImageRNG:
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
self.shape = shape
self.seeds = seeds
self.subseeds = subseeds
self.subseed_strength = subseed_strength
self.seed_resize_from_h = seed_resize_from_h
self.seed_resize_from_w = seed_resize_from_w
self.generators = [create_generator(seed) for seed in seeds]
self.is_first = True
def first(self):
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
xs = []
for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
subnoise = None
if self.subseeds is not None and self.subseed_strength != 0:
subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
subnoise = randn(subseed, noise_shape)
if noise_shape != self.shape:
noise = randn(seed, noise_shape)
else:
noise = randn(seed, self.shape, generator=generator)
if subnoise is not None:
noise = slerp(self.subseed_strength, noise, subnoise)
if noise_shape != self.shape:
x = randn(seed, self.shape, generator=generator)
dx = (self.shape[2] - noise_shape[2]) // 2
dy = (self.shape[1] - noise_shape[1]) // 2
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
tx = 0 if dx < 0 else dx
ty = 0 if dy < 0 else dy
dx = max(-dx, 0)
dy = max(-dy, 0)
x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
noise = x
xs.append(noise)
eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
if eta_noise_seed_delta:
self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
return torch.stack(xs).to(shared.device)
def next(self):
if self.is_first:
self.is_first = False
return self.first()
xs = []
for generator in self.generators:
x = randn_without_seed(self.shape, generator=generator)
xs.append(x)
return torch.stack(xs).to(shared.device)
devices.randn = randn
devices.randn_local = randn_local
devices.randn_like = randn_like
devices.randn_without_seed = randn_without_seed
devices.manual_seed = manual_seed
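A hedged sketch of how processing code is expected to drive ImageRNG; the shape is (channels, height // 8, width // 8), and this assumes the webui modules above are importable with a device configured:

rng = ImageRNG((4, 64, 64), seeds=[1234], subseeds=[42], subseed_strength=0.5)
x0 = rng.next()  # first call dispatches to first() and returns the seeded noise
x1 = rng.next()  # later calls continue each per-seed generator
print(x0.shape, x1.shape)  # torch.Size([1, 4, 64, 64]) torch.Size([1, 4, 64, 64])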
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules.timer import Timer
import tomesd
@@ -68,7 +68,9 @@ class CheckpointInfo:
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]']
if self.shorthash:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
def register(self):
checkpoints_list[self.title] = self
@@ -80,10 +82,14 @@ class CheckpointInfo:
if self.sha256 is None:
return
self.shorthash = self.sha256[0:10]
shorthash = self.sha256[0:10]
if self.shorthash == self.sha256[0:10]:
return self.shorthash
self.shorthash = shorthash
if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
@@ -489,7 +495,6 @@ model_data = SdModelData()
def get_empty_cond(sd_model):
from modules import extra_networks, processing
p = processing.StableDiffusionProcessingTxt2Img()
extra_networks.activate(p, {})
@@ -502,8 +507,6 @@ def get_empty_cond(sd_model):
def send_model_to_cpu(m):
from modules import lowvram
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.send_everything_to_cpu()
else:
@@ -513,8 +516,6 @@ def send_model_to_cpu(m):
def send_model_to_device(m):
from modules import lowvram
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
else:
@@ -639,6 +640,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
timer.record("send model to device")
model_data.set_sd_model(already_loaded)
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
return model_data.sd_model return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
...@@ -658,7 +661,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): ...@@ -658,7 +661,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
def reload_model_weights(sd_model=None, info=None): def reload_model_weights(sd_model=None, info=None):
from modules import devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
timer = Timer() timer = Timer()
...@@ -721,7 +723,6 @@ def reload_model_weights(sd_model=None, info=None): ...@@ -721,7 +723,6 @@ def reload_model_weights(sd_model=None, info=None):
def unload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None):
from modules import devices, sd_hijack
timer = Timer() timer = Timer()
if model_data.sd_model: if model_data.sd_model:
......
...@@ -2,7 +2,7 @@ import os ...@@ -2,7 +2,7 @@ import os
import torch import torch
from modules import shared, paths, sd_disable_initialization from modules import shared, paths, sd_disable_initialization, devices
sd_configs_path = shared.sd_configs_path sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
...@@ -29,7 +29,6 @@ def is_using_v_parameterization_for_sd2(state_dict): ...@@ -29,7 +29,6 @@ def is_using_v_parameterization_for_sd2(state_dict):
""" """
import ldm.modules.diffusionmodules.openaimodel import ldm.modules.diffusionmodules.openaimodel
from modules import devices
device = devices.cpu device = devices.cpu
......
import inspect import inspect
from collections import namedtuple, deque from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
...@@ -161,10 +161,15 @@ def apply_refiner(sampler): ...@@ -161,10 +161,15 @@ def apply_refiner(sampler):
class TorchHijack: class TorchHijack:
def __init__(self, sampler_noises):
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation.
self.sampler_noises = deque(sampler_noises)
"""This is here to replace torch.randn_like of k-diffusion.
k-diffusion has a random_sampler argument for most samplers, but not for all, so
this is needed to properly replace every use of torch.randn_like.
We need the replacement so that images generated in batches are the same as images generated individually."""
def __init__(self, p):
self.rng = p.rng
def __getattr__(self, item): def __getattr__(self, item):
if item == 'randn_like': if item == 'randn_like':
...@@ -176,12 +181,7 @@ class TorchHijack: ...@@ -176,12 +181,7 @@ class TorchHijack:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x): def randn_like(self, x):
if self.sampler_noises: return self.rng.next()
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
return devices.randn_like(x)
class Sampler: class Sampler:
...@@ -248,7 +248,7 @@ class Sampler: ...@@ -248,7 +248,7 @@ class Sampler:
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) k_diffusion.sampling.torch = TorchHijack(p)
extra_params_kwargs = {} extra_params_kwargs = {}
for param_name in self.extra_params: for param_name in self.extra_params:
......
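The hunk above works because k-diffusion reaches torch through its own module-level `torch` name; assigning `k_diffusion.sampling.torch = TorchHijack(p)` therefore reroutes every `torch.randn_like` call inside the samplers without patching torch globally. A minimal sketch of that proxy pattern, with a hypothetical `FakeRng` standing in for `p.rng`:
```
import torch

class FakeRng:
    """Hypothetical stand-in for p.rng; the real object yields the batch's noise stream."""
    def next(self):
        return torch.zeros(2, 2)

class TorchProxy:
    def __init__(self, rng):
        self.rng = rng

    def randn_like(self, x):
        # the one call we want to intercept
        return self.rng.next()

    def __getattr__(self, item):
        # everything else falls through to the real torch module
        return getattr(torch, item)

hijacked = TorchProxy(FakeRng())
print(hijacked.randn_like(torch.ones(2, 2)))  # noise from our RNG, not torch's
print(hijacked.float32)                       # delegated attribute, same as torch.float32
```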
import torch import torch
import inspect import inspect
import k_diffusion.sampling import k_diffusion.sampling
from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules import sd_samplers_common, sd_samplers_extra
from modules.sd_samplers_cfg_denoiser import CFGDenoiser
from modules.shared import opts from modules.shared import opts
import modules.shared as shared import modules.shared as shared
......
import torch import torch
import inspect import inspect
import sys
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
from modules.sd_samplers_cfg_denoiser import CFGDenoiser from modules.sd_samplers_cfg_denoiser import CFGDenoiser
...@@ -152,3 +153,6 @@ class CompVisSampler(sd_samplers_common.Sampler): ...@@ -152,3 +153,6 @@ class CompVisSampler(sd_samplers_common.Sampler):
return samples return samples
sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__]
VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions
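The `sys.modules` assignment above is the standard trick for keeping an old import path alive after a rename: Python's import machinery consults `sys.modules` before any finder, so registering the current module under the retired name makes `import modules.sd_samplers_compvis` resolve to this file. A self-contained sketch with hypothetical module names:
```
import sys
import types

renamed = types.ModuleType("renamed_module")
renamed.Sampler = object  # whatever the module exports

sys.modules["renamed_module"] = renamed
sys.modules["old_module"] = renamed  # alias under the retired name

import old_module  # found in sys.modules, no file needed
assert old_module is renamed
```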
...@@ -2,7 +2,8 @@ import os ...@@ -2,7 +2,8 @@ import os
import collections import collections
from dataclasses import dataclass from dataclasses import dataclass
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes
import glob import glob
from copy import deepcopy from copy import deepcopy
...@@ -19,6 +20,20 @@ checkpoint_info = None ...@@ -19,6 +20,20 @@ checkpoint_info = None
checkpoints_loaded = collections.OrderedDict() checkpoints_loaded = collections.OrderedDict()
def get_loaded_vae_name():
if loaded_vae_file is None:
return None
return os.path.basename(loaded_vae_file)
def get_loaded_vae_hash():
if loaded_vae_file is None:
return None
return hashes.sha256(loaded_vae_file, 'vae')[0:10]
def get_base_vae(model): def get_base_vae(model):
if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
return base_vae return base_vae
...@@ -231,8 +246,6 @@ unspecified = object() ...@@ -231,8 +246,6 @@ unspecified = object()
def reload_vae_weights(sd_model=None, vae_file=unspecified): def reload_vae_weights(sd_model=None, vae_file=unspecified):
from modules import lowvram, devices, sd_hijack
if not sd_model: if not sd_model:
sd_model = shared.sd_model sd_model = shared.sd_model
......
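`get_loaded_vae_hash` above returns the first ten hex characters of the VAE file's sha256. Roughly what that computes, using plain hashlib (the real `modules.hashes.sha256` helper adds caching on top of this so files aren't rehashed on every call):
```
import hashlib

def short_file_hash(path: str, length: int = 10) -> str:
    """First `length` hex chars of a file's sha256, streamed in 1 MiB chunks."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()[:length]
```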
import os
import launch
from modules import cmd_args, script_loading
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
parser = cmd_args.parser
script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
cmd_opts = parser.parse_args()
else:
cmd_opts, _ = parser.parse_known_args()
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
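The `IGNORE_CMD_ARGS_ERRORS` branch above chooses between strict and tolerant argument parsing: `parse_args` aborts on unknown flags, while `parse_known_args` collects and discards them, which is what an embedding process wants. The pattern in isolation (the `--listen` flag is just an example):
```
import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--listen", action="store_true")

if os.environ.get("IGNORE_CMD_ARGS_ERRORS", None) is None:
    args = parser.parse_args()                 # unknown flags are a hard error
else:
    args, unknown = parser.parse_known_args()  # unknown flags are returned, not fatal
```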
import os
import gradio as gr
from modules import errors, shared
from modules.paths_internal import script_path
# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json
gradio_hf_hub_themes = [
"gradio/base",
"gradio/glass",
"gradio/monochrome",
"gradio/seafoam",
"gradio/soft",
"gradio/dracula_test",
"abidlabs/dracula_test",
"abidlabs/Lime",
"abidlabs/pakistan",
"Ama434/neutral-barlow",
"dawood/microsoft_windows",
"finlaymacklon/smooth_slate",
"Franklisi/darkmode",
"freddyaboulton/dracula_revamped",
"freddyaboulton/test-blue",
"gstaff/xkcd",
"Insuz/Mocha",
"Insuz/SimpleIndigo",
"JohnSmith9982/small_and_pretty",
"nota-ai/theme",
"nuttea/Softblue",
"ParityError/Anime",
"reilnuud/polite",
"remilia/Ghostly",
"rottenlittlecreature/Moon_Goblin",
"step-3-profit/Midnight-Deep",
"Taithrah/Minimal",
"ysharma/huggingface",
"ysharma/steampunk"
]
def reload_gradio_theme(theme_name=None):
if not theme_name:
theme_name = shared.opts.gradio_theme
default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'],
)
if theme_name == "Default":
shared.gradio_theme = gr.themes.Default(**default_theme_args)
else:
try:
theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes')
theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json')
if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path):
shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path)
else:
os.makedirs(theme_cache_dir, exist_ok=True)
shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
shared.gradio_theme.dump(theme_cache_path)
except Exception as e:
errors.display(e, "changing gradio theme")
shared.gradio_theme = gr.themes.Default(**default_theme_args)
import os
import torch
from modules import shared
from modules.shared import cmd_opts
def initialize():
"""Initializes fields inside the shared module in a controlled manner.
Should be called early because some other modules you can import might need these fields to be already set.
"""
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
from modules import options, shared_options
shared.options_templates = shared_options.options_templates
shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts)
shared.restricted_opts = shared_options.restricted_opts
if os.path.exists(shared.config_filename):
shared.opts.load(shared.config_filename)
from modules import devices
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16
shared.device = devices.device
shared.weight_load_location = None if cmd_opts.lowram else "cpu"
from modules import shared_state
shared.state = shared_state.State()
from modules import styles
shared.prompt_styles = styles.StyleDatabase(shared.styles_filename)
from modules import interrogate
shared.interrogator = interrogate.InterrogateModels("interrogate")
from modules import shared_total_tqdm
shared.total_tqdm = shared_total_tqdm.TotalTQDM()
from modules import memmon, devices
shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts)
shared.mem_mon.start()
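The one-line generator that fills the five `devices.*` fields above is dense; unrolled, it does this for each task name (the helper below is illustrative, not webui API):
```
def pick_device(task: str, use_cpu: list) -> str:
    # devices.cpu if the task (or 'all') was passed to --use-cpu, else the best device
    return "cpu" if any(y in use_cpu for y in [task, "all"]) else "cuda"

use_cpu = ["interrogate"]  # e.g. launched with: --use-cpu interrogate
for task in ["sd", "interrogate", "gfpgan", "esrgan", "codeformer"]:
    print(task, "->", pick_device(task, use_cpu))
```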
import sys
from modules.shared_cmd_options import cmd_opts
def realesrgan_models_names(): def realesrgan_models_names():
...@@ -41,6 +44,28 @@ def refresh_unet_list(): ...@@ -41,6 +44,28 @@ def refresh_unet_list():
modules.sd_unet.list_unets() modules.sd_unet.list_unets()
def list_checkpoint_tiles():
import modules.sd_models
return modules.sd_models.checkpoint_tiles()
def refresh_checkpoints():
import modules.sd_models
return modules.sd_models.list_models()
def list_samplers():
import modules.sd_samplers
return modules.sd_samplers.all_samplers
def reload_hypernetworks():
from modules.hypernetworks import hypernetwork
from modules import shared
shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
ui_reorder_categories_builtin_items = [ ui_reorder_categories_builtin_items = [
"inpaint", "inpaint",
"sampler", "sampler",
...@@ -67,3 +92,27 @@ def ui_reorder_categories(): ...@@ -67,3 +92,27 @@ def ui_reorder_categories():
yield from sections yield from sections
yield "scripts" yield "scripts"
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sys.modules['modules.shared'].__class__ = Shared
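The final line above swaps the class of the live `modules.shared` module object for `Shared`, which is how a plain attribute read like `shared.sd_model` can run property code and load the model on demand. The same trick in miniature:
```
import sys
import types

class LazyModule(types.ModuleType):
    _model = None

    @property
    def model(self):
        if self._model is None:
            self._model = "loaded!"  # expensive load deferred to first access
        return self._model

sys.modules[__name__].__class__ = LazyModule
print(sys.modules[__name__].model)  # property runs on ordinary attribute access
```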
import datetime
import logging
import threading
import time
from modules import errors, shared, devices
from typing import Optional
log = logging.getLogger(__name__)
class State:
skipped = False
interrupted = False
job = ""
job_no = 0
job_count = 0
processing_has_refined_job_count = False
job_timestamp = '0'
sampling_step = 0
sampling_steps = 0
current_latent = None
current_image = None
current_image_sampling_step = 0
id_live_preview = 0
textinfo = None
time_start = None
server_start = None
_server_command_signal = threading.Event()
_server_command: Optional[str] = None
def __init__(self):
self.server_start = time.time()
@property
def need_restart(self) -> bool:
# Compatibility getter for need_restart.
return self.server_command == "restart"
@need_restart.setter
def need_restart(self, value: bool) -> None:
# Compatibility setter for need_restart.
if value:
self.server_command = "restart"
@property
def server_command(self):
return self._server_command
@server_command.setter
def server_command(self, value: Optional[str]) -> None:
"""
Set the server command to `value` and signal that it's been set.
"""
self._server_command = value
self._server_command_signal.set()
def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]:
"""
Wait for server command to get set; return and clear the value and signal.
"""
if self._server_command_signal.wait(timeout):
self._server_command_signal.clear()
req = self._server_command
self._server_command = None
return req
return None
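`server_command` and `wait_for_server_command` form a one-slot handoff between the UI thread and the server loop, built on `threading.Event`. A standalone rendering of the same handshake (the `Commander` class is a hypothetical reduction of `State`):
```
import threading

class Commander:
    def __init__(self):
        self._signal = threading.Event()
        self._command = None

    def send(self, value):
        self._command = value
        self._signal.set()          # wake up whoever is waiting

    def wait(self, timeout=None):
        if self._signal.wait(timeout):
            self._signal.clear()    # consume the signal...
            cmd, self._command = self._command, None  # ...and the value
            return cmd
        return None

c = Commander()
threading.Timer(0.1, c.send, args=("restart",)).start()  # sender on another thread
print(c.wait(timeout=1.0))  # -> restart
```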
def request_restart(self) -> None:
self.interrupt()
self.server_command = "restart"
log.info("Received restart request")
def skip(self):
self.skipped = True
log.info("Received skip request")
def interrupt(self):
self.interrupted = True
log.info("Received interrupt request")
def nextjob(self):
if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1:
self.do_set_current_image()
self.job_no += 1
self.sampling_step = 0
self.current_image_sampling_step = 0
def dict(self):
obj = {
"skipped": self.skipped,
"interrupted": self.interrupted,
"job": self.job,
"job_count": self.job_count,
"job_timestamp": self.job_timestamp,
"job_no": self.job_no,
"sampling_step": self.sampling_step,
"sampling_steps": self.sampling_steps,
}
return obj
def begin(self, job: str = "(unknown)"):
self.sampling_step = 0
self.job_count = -1
self.processing_has_refined_job_count = False
self.job_no = 0
self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
self.current_latent = None
self.current_image = None
self.current_image_sampling_step = 0
self.id_live_preview = 0
self.skipped = False
self.interrupted = False
self.textinfo = None
self.time_start = time.time()
self.job = job
devices.torch_gc()
log.info("Starting job %s", job)
def end(self):
duration = time.time() - self.time_start
log.info("Ending job %s (%.2f seconds)", self.job, duration)
self.job = ""
self.job_count = 0
devices.torch_gc()
def set_current_image(self):
"""sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this"""
if not shared.parallel_processing_allowed:
return
if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1:
self.do_set_current_image()
def do_set_current_image(self):
if self.current_latent is None:
return
import modules.sd_samplers
try:
if shared.opts.show_progress_grid:
self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
else:
self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
self.current_image_sampling_step = self.sampling_step
except Exception:
# when switching models during generation, VAE would be on CPU, so creating an image will fail.
# we silently ignore this error
errors.record_exception()
def assign_current_image(self, image):
self.current_image = image
self.id_live_preview += 1
import tqdm
from modules import shared
class TotalTQDM:
def __init__(self):
self._tqdm = None
def reset(self):
self._tqdm = tqdm.tqdm(
desc="Total progress",
total=shared.state.job_count * shared.state.sampling_steps,
position=1,
file=shared.progress_print_out
)
def update(self):
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.update()
def updateTotal(self, new_total):
if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.total = new_total
def clear(self):
if self._tqdm is not None:
self._tqdm.refresh()
self._tqdm.close()
self._tqdm = None
...@@ -10,7 +10,7 @@ import psutil ...@@ -10,7 +10,7 @@ import psutil
import re import re
import launch import launch
from modules import paths_internal, timer from modules import paths_internal, timer, shared, extensions, errors
checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY" checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY"
environment_whitelist = { environment_whitelist = {
...@@ -115,8 +115,6 @@ def format_exception(e, tb): ...@@ -115,8 +115,6 @@ def format_exception(e, tb):
def get_exceptions(): def get_exceptions():
try: try:
from modules import errors
return list(reversed(errors.exception_records)) return list(reversed(errors.exception_records))
except Exception as e: except Exception as e:
return str(e) return str(e)
...@@ -142,8 +140,6 @@ def get_torch_sysinfo(): ...@@ -142,8 +140,6 @@ def get_torch_sysinfo():
def get_extensions(*, enabled): def get_extensions(*, enabled):
try: try:
from modules import extensions
def to_json(x: extensions.Extension): def to_json(x: extensions.Extension):
return { return {
"name": x.name, "name": x.name,
...@@ -160,7 +156,6 @@ def get_extensions(*, enabled): ...@@ -160,7 +156,6 @@ def get_extensions(*, enabled):
def get_config(): def get_config():
try: try:
from modules import shared
return shared.opts.data return shared.opts.data
except Exception as e: except Exception as e:
return str(e) return str(e)
...@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html ...@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import gradio as gr import gradio as gr
def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
override_settings = create_override_settings_dict(override_settings_texts) override_settings = create_override_settings_dict(override_settings_texts)
p = processing.StableDiffusionProcessingTxt2Img( p = processing.StableDiffusionProcessingTxt2Img(
...@@ -32,8 +32,6 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step ...@@ -32,8 +32,6 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
cfg_scale=cfg_scale, cfg_scale=cfg_scale,
width=width, width=width,
height=height, height=height,
restore_faces=restore_faces,
tiling=tiling,
enable_hr=enable_hr, enable_hr=enable_hr,
denoising_strength=denoising_strength if enable_hr else None, denoising_strength=denoising_strength if enable_hr else None,
hr_scale=hr_scale, hr_scale=hr_scale,
...@@ -42,7 +40,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step ...@@ -42,7 +40,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step
hr_resize_x=hr_resize_x, hr_resize_x=hr_resize_x,
hr_resize_y=hr_resize_y, hr_resize_y=hr_resize_y,
hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
hr_sampler_name=hr_sampler_name, hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name,
hr_prompt=hr_prompt, hr_prompt=hr_prompt,
hr_negative_prompt=hr_negative_prompt, hr_negative_prompt=hr_negative_prompt,
override_settings=override_settings, override_settings=override_settings,
......
...@@ -13,8 +13,8 @@ from PIL import Image, PngImagePlugin # noqa: F401 ...@@ -13,8 +13,8 @@ from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules import gradio_extensons # noqa: F401 from modules import gradio_extensons # noqa: F401
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers, processing, ui_extra_networks
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion
from modules.paths import script_path from modules.paths import script_path
from modules.ui_common import create_refresh_button from modules.ui_common import create_refresh_button
from modules.ui_gradio_extensions import reload_javascript from modules.ui_gradio_extensions import reload_javascript
...@@ -78,7 +78,6 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 ...@@ -78,7 +78,6 @@ extra_networks_symbol = '\U0001F3B4' # 🎴
switch_values_symbol = '\U000021C5' # ⇅ switch_values_symbol = '\U000021C5' # ⇅
restore_progress_symbol = '\U0001F300' # 🌀 restore_progress_symbol = '\U0001F300' # 🌀
detect_image_size_symbol = '\U0001F4D0' # 📐 detect_image_size_symbol = '\U0001F4D0' # 📐
up_down_symbol = '\u2195\ufe0f' # ↕️
plaintext_to_html = ui_common.plaintext_to_html plaintext_to_html = ui_common.plaintext_to_html
...@@ -91,17 +90,13 @@ def send_gradio_gallery_to_image(x): ...@@ -91,17 +90,13 @@ def send_gradio_gallery_to_image(x):
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
from modules import processing, devices
if not enable: if not enable:
return "" return ""
p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y)
p.calculate_target_resolution()
with devices.autocast(): return f"from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
p.init([""], [0], [0])
return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>"
def resize_from_to_html(width, height, scale_by): def resize_from_to_html(width, height, scale_by):
...@@ -149,7 +144,11 @@ def interrogate_deepbooru(image): ...@@ -149,7 +144,11 @@ def interrogate_deepbooru(image):
def create_seed_inputs(target_interface): def create_seed_inputs(target_interface):
with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"): with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed") if cmd_opts.use_textbox_seed:
seed = gr.Textbox(label='Seed', value="", elem_id=f"{target_interface}_seed")
else:
seed = gr.Number(label='Seed', value=-1, elem_id=f"{target_interface}_seed", precision=0)
random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed') random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed') reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
...@@ -160,7 +159,7 @@ def create_seed_inputs(target_interface): ...@@ -160,7 +159,7 @@ def create_seed_inputs(target_interface):
with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1: with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
seed_extras.append(seed_extra_row_1) seed_extras.append(seed_extra_row_1)
subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed") subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed", precision=0)
random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed") random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed") reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength") subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
...@@ -437,13 +436,13 @@ def create_ui(): ...@@ -437,13 +436,13 @@ def create_ui():
elif category == "checkboxes": elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"): with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") pass
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling")
enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr")
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False)
elif category == "hires_fix": elif category == "hires_fix":
with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: with InputAccordion(False, label="Hires. fix") as enable_hr:
with enable_hr.extra():
hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0)
with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"):
hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode)
hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps")
...@@ -520,8 +519,6 @@ def create_ui(): ...@@ -520,8 +519,6 @@ def create_ui():
toprow.ui_styles.dropdown, toprow.ui_styles.dropdown,
steps, steps,
sampler_name, sampler_name,
restore_faces,
tiling,
batch_count, batch_count,
batch_size, batch_size,
cfg_scale, cfg_scale,
...@@ -571,19 +568,11 @@ def create_ui(): ...@@ -571,19 +568,11 @@ def create_ui():
show_progress=False, show_progress=False,
) )
enable_hr.change(
fn=lambda x: gr_show(x),
inputs=[enable_hr],
outputs=[hr_options],
show_progress = False,
)
txt2img_paste_fields = [ txt2img_paste_fields = [
(toprow.prompt, "Prompt"), (toprow.prompt, "Prompt"),
(toprow.negative_prompt, "Negative prompt"), (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"), (steps, "Steps"),
(sampler_name, "Sampler"), (sampler_name, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"), (cfg_scale, "CFG scale"),
(seed, "Seed"), (seed, "Seed"),
(width, "Size-1"), (width, "Size-1"),
...@@ -597,7 +586,6 @@ def create_ui(): ...@@ -597,7 +586,6 @@ def create_ui():
(toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d))),
(hr_scale, "Hires upscale"), (hr_scale, "Hires upscale"),
(hr_upscaler, "Hires upscaler"), (hr_upscaler, "Hires upscaler"),
(hr_second_pass_steps, "Hires steps"), (hr_second_pass_steps, "Hires steps"),
...@@ -630,7 +618,6 @@ def create_ui(): ...@@ -630,7 +618,6 @@ def create_ui():
toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
from modules import ui_extra_networks
extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img') extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
...@@ -805,8 +792,7 @@ def create_ui(): ...@@ -805,8 +792,7 @@ def create_ui():
elif category == "checkboxes": elif category == "checkboxes":
with FormRow(elem_classes="checkboxes-row", variant="compact"): with FormRow(elem_classes="checkboxes-row", variant="compact"):
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="img2img_restore_faces") pass
tiling = gr.Checkbox(label='Tiling', value=False, elem_id="img2img_tiling")
elif category == "batch": elif category == "batch":
if not opts.dimensions_and_batch_together: if not opts.dimensions_and_batch_together:
...@@ -879,8 +865,6 @@ def create_ui(): ...@@ -879,8 +865,6 @@ def create_ui():
mask_blur, mask_blur,
mask_alpha, mask_alpha,
inpainting_fill, inpainting_fill,
restore_faces,
tiling,
batch_count, batch_count,
batch_size, batch_size,
cfg_scale, cfg_scale,
...@@ -972,7 +956,6 @@ def create_ui(): ...@@ -972,7 +956,6 @@ def create_ui():
(toprow.negative_prompt, "Negative prompt"), (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"), (steps, "Steps"),
(sampler_name, "Sampler"), (sampler_name, "Sampler"),
(restore_faces, "Face restoration"),
(cfg_scale, "CFG scale"), (cfg_scale, "CFG scale"),
(image_cfg_scale, "Image CFG scale"), (image_cfg_scale, "Image CFG scale"),
(seed, "Seed"), (seed, "Seed"),
...@@ -995,7 +978,6 @@ def create_ui(): ...@@ -995,7 +978,6 @@ def create_ui():
paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
)) ))
from modules import ui_extra_networks
extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img') extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
......
...@@ -11,7 +11,7 @@ from modules import call_queue, shared ...@@ -11,7 +11,7 @@ from modules import call_queue, shared
from modules.generation_parameters_copypaste import image_from_url_text from modules.generation_parameters_copypaste import image_from_url_text
import modules.images import modules.images
from modules.ui_components import ToolButton from modules.ui_components import ToolButton
import modules.generation_parameters_copypaste as parameters_copypaste
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄 refresh_symbol = '\U0001f504' # 🔄
...@@ -105,8 +105,6 @@ def save_files(js_data, images, do_make_zip, index): ...@@ -105,8 +105,6 @@ def save_files(js_data, images, do_make_zip, index):
def create_output_panel(tabname, outdir): def create_output_panel(tabname, outdir):
from modules import shared
import modules.generation_parameters_copypaste as parameters_copypaste
def open_folder(f): def open_folder(f):
if not os.path.exists(f): if not os.path.exists(f):
......
...@@ -72,3 +72,52 @@ class DropdownEditable(FormComponent, gr.Dropdown): ...@@ -72,3 +72,52 @@ class DropdownEditable(FormComponent, gr.Dropdown):
def get_block_name(self): def get_block_name(self):
return "dropdown" return "dropdown"
class InputAccordion(gr.Checkbox):
"""A gr.Accordion that can be used as an input - returns True if open, False if closed.
Actually just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox.
"""
global_index = 0
def __init__(self, value, **kwargs):
self.accordion_id = kwargs.get('elem_id')
if self.accordion_id is None:
self.accordion_id = f"input-accordion-{InputAccordion.global_index}"
InputAccordion.global_index += 1
kwargs['elem_id'] = self.accordion_id + "-checkbox"
kwargs['visible'] = False
super().__init__(value, **kwargs)
self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self])
self.accordion = gr.Accordion(kwargs.get('label', 'Accordion'), open=value, elem_id=self.accordion_id, elem_classes=['input-accordion'])
def extra(self):
"""Allows you to put something into the label of the accordion.
Use it like this:
```
with InputAccordion(False, label="Accordion") as acc:
with acc.extra():
FormHTML(value="hello", min_width=0)
...
```
"""
return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0)
def __enter__(self):
self.accordion.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.accordion.__exit__(exc_type, exc_val, exc_tb)
def get_block_name(self):
return "checkbox"
...@@ -4,7 +4,6 @@ from pathlib import Path ...@@ -4,7 +4,6 @@ from pathlib import Path
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules.images import read_info_from_image, save_image_with_geninfo from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr import gradio as gr
import json import json
import html import html
...@@ -348,6 +347,8 @@ def pages_in_preferred_order(pages): ...@@ -348,6 +347,8 @@ def pages_in_preferred_order(pages):
def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
from modules.ui import switch_values_symbol
ui = ExtraNetworksUi() ui = ExtraNetworksUi()
ui.pages = [] ui.pages = []
ui.pages_contents = [] ui.pages_contents = []
...@@ -373,7 +374,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ...@@ -373,7 +374,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname):
edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True)
dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order")
button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False)
checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False)
......
...@@ -36,8 +36,8 @@ class UserMetadataEditor: ...@@ -36,8 +36,8 @@ class UserMetadataEditor:
item = self.page.items.get(name, {}) item = self.page.items.get(name, {})
user_metadata = item.get('user_metadata', None) user_metadata = item.get('user_metadata', None)
if user_metadata is None: if not user_metadata:
user_metadata = {} user_metadata = {'description': item.get('description', '')}
item['user_metadata'] = user_metadata item['user_metadata'] = user_metadata
return user_metadata return user_metadata
......
...@@ -8,7 +8,7 @@ from modules.ui_components import ToolButton ...@@ -8,7 +8,7 @@ from modules.ui_components import ToolButton
class UiLoadsave: class UiLoadsave:
"""allows saving and restorig default values for gradio components""" """allows saving and restoring default values for gradio components"""
def __init__(self, filename): def __init__(self, filename):
self.filename = filename self.filename = filename
...@@ -48,6 +48,11 @@ class UiLoadsave: ...@@ -48,6 +48,11 @@ class UiLoadsave:
elif condition and not condition(saved_value): elif condition and not condition(saved_value):
pass pass
else: else:
if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies
saved_value = str(saved_value)
elif isinstance(x, gr.Number) and field == 'value':
saved_value = float(saved_value)
setattr(obj, field, saved_value) setattr(obj, field, saved_value)
if init_field is not None: if init_field is not None:
init_field(saved_value) init_field(saved_value)
......
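The new branches above exist because values round-tripped through ui-config.json lose their Gradio-facing types: JSON may hand back an int where `gr.Textbox` needs a str and where `gr.Number` stores a float. The fix in miniature (the key name is illustrative):
```
import json

saved = json.loads('{"txt2img/Seed/value": 20}')  # int after the round-trip
value = saved["txt2img/Seed/value"]

textbox_value = str(value)   # what a gr.Textbox must receive
number_value = float(value)  # what a gr.Number stores internally
print(repr(textbox_value), repr(number_value))  # '20' 20.0
```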
...@@ -57,8 +57,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"): ...@@ -57,8 +57,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"):
return file_obj.name return file_obj.name
# override save to file function so that it also writes PNG info def install_ui_tempdir_override():
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file """override save to file function so that it also writes PNG info"""
gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file
def on_tmpdir_changed(): def on_tmpdir_changed():
......
import os
import re
from modules import shared
from modules.paths_internal import script_path
def natural_sort_key(s, regex=re.compile('([0-9]+)')):
return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)]
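A quick demonstration of why `natural_sort_key` exists, using the function defined above:
```
files = ["img1.png", "img10.png", "img2.png"]
print(sorted(files))                         # ['img1.png', 'img10.png', 'img2.png']
print(sorted(files, key=natural_sort_key))   # ['img1.png', 'img2.png', 'img10.png']
```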
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
def html_path(filename):
return os.path.join(script_path, "html", filename)
def html(filename):
path = html_path(filename)
if os.path.exists(path):
with open(path, encoding="utf8") as file:
return file.read()
return ""
def walk_files(path, allowed_extensions=None):
if not os.path.exists(path):
return
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
items = list(os.walk(path, followlinks=True))
items = sorted(items, key=lambda x: natural_sort_key(x[0]))
for root, _, files in items:
for filename in sorted(files, key=natural_sort_key):
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
if not shared.opts.list_hidden_files and ("/." in root or "\\." in root):
continue
yield os.path.join(root, filename)
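And a typical call to `walk_files`, filtering model files by extension (the directory here is hypothetical):
```
for path in walk_files("models/Lora", allowed_extensions={".safetensors", ".pt"}):
    print(path)  # deterministic order thanks to natural_sort_key
```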
def ldm_print(*args, **kwargs):
if shared.opts.hide_ldm_prints:
return
print(*args, **kwargs)
...@@ -6,6 +6,7 @@ basicsr ...@@ -6,6 +6,7 @@ basicsr
blendmodes blendmodes
clean-fid clean-fid
einops einops
fastapi>=0.90.1
gfpgan gfpgan
gradio==3.39.0 gradio==3.39.0
inflection inflection
......
...@@ -43,13 +43,15 @@ div.form{ ...@@ -43,13 +43,15 @@ div.form{
.block.gradio-radio, .block.gradio-radio,
.block.gradio-checkboxgroup, .block.gradio-checkboxgroup,
.block.gradio-number, .block.gradio-number,
.block.gradio-colorpicker, .block.gradio-colorpicker {
div.gradio-group
{
border-width: 0 !important; border-width: 0 !important;
box-shadow: none !important; box-shadow: none !important;
} }
div.gradio-group, div.styler{
border-width: 0 !important;
background: none;
}
.gap.compact{ .gap.compact{
padding: 0; padding: 0;
gap: 0.2em 0; gap: 0.2em 0;
...@@ -135,12 +137,8 @@ a{ ...@@ -135,12 +137,8 @@ a{
cursor: pointer; cursor: pointer;
} }
div.styler{ /* gradio 3.39 puts a lot of overflow: hidden all over the place for an unknown reason. */
border: none; .block.gradio-textbox, div.gradio-group, div.gradio-group div, div.gradio-dropdown{
background: var(--background-fill-primary);
}
.block.gradio-textbox{
overflow: visible !important; overflow: visible !important;
} }
...@@ -194,6 +192,13 @@ button.custom-button{ ...@@ -194,6 +192,13 @@ button.custom-button{
text-align: center; text-align: center;
} }
div.gradio-accordion {
border: 1px solid var(--block-border-color) !important;
border-radius: 8px !important;
margin: 2px 0;
padding: 8px 8px;
}
/* txt2img/img2img specific */ /* txt2img/img2img specific */
...@@ -324,12 +329,6 @@ button.custom-button{ ...@@ -324,12 +329,6 @@ button.custom-button{
border-radius: 0 0.5rem 0.5rem 0; border-radius: 0 0.5rem 0.5rem 0;
} }
#txtimg_hr_finalres{
min-height: 0 !important;
padding: .625rem .75rem;
margin-left: -0.75em
}
#img2img_scale_resolution_preview.block{ #img2img_scale_resolution_preview.block{
display: flex; display: flex;
align-items: end; align-items: end;
...@@ -1011,3 +1010,12 @@ div.block.gradio-box.popup-dialog, .popup-dialog { ...@@ -1011,3 +1010,12 @@ div.block.gradio-box.popup-dialog, .popup-dialog {
div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{ div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
margin-top: 1em; margin-top: 1em;
} }
div.block.input-accordion{
margin-bottom: 0.4em;
}
.input-accordion-extra{
flex: 0 0 auto !important;
margin: 0 0.5em 0 auto;
}
import os import os
import pytest import pytest
from PIL import Image import base64
from gradio.processing_utils import encode_pil_to_base64
test_files_path = os.path.dirname(__file__) + "/test_files" test_files_path = os.path.dirname(__file__) + "/test_files"
def file_to_base64(filename):
with open(filename, "rb") as file:
data = file.read()
base64_str = str(base64.b64encode(data), "utf-8")
return "data:image/png;base64," + base64_str
@pytest.fixture(scope="session") # session so we don't read this over and over @pytest.fixture(scope="session") # session so we don't read this over and over
def img2img_basic_image_base64() -> str: def img2img_basic_image_base64() -> str:
return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "img2img_basic.png"))) return file_to_base64(os.path.join(test_files_path, "img2img_basic.png"))
@pytest.fixture(scope="session") # session so we don't read this over and over @pytest.fixture(scope="session") # session so we don't read this over and over
def mask_basic_image_base64() -> str: def mask_basic_image_base64() -> str:
return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "mask_basic.png"))) return file_to_base64(os.path.join(test_files_path, "mask_basic.png"))