Commit 56c3f94b authored by AUTOMATIC1111, committed by GitHub

Merge branch 'dev' into dev

parents 952effa8 073c0ebb
@@ -41,6 +41,7 @@ jobs:
          --skip-prepare-environment
          --skip-torch-cuda-test
          --test-server
          --do-not-download-clip
          --no-half
          --disable-opt-split-attention
          --use-cpu all
## 1.5.1
### Minor:
* support parsing text encoder blocks in some new LoRAs
* delete scale checker script due to user demand
### Extensions and API:
* add postprocess_batch_list script callback
### Bug Fixes:
* fix TI training for SD1
* fix reload altclip model error
* prepend the pythonpath instead of overriding it
* fix typo in SD_WEBUI_RESTARTING
* if txt2img/img2img raises an exception, finally call state.end()
* fix composable diffusion weight parsing
* restyle Startup profile for black users
* fix webui not launching with --nowebui
* catch exception for non git extensions
* fix some options missing from /sdapi/v1/options
* fix for extension update status always saying "unknown"
* fix display of extra network cards that have `<>` in the name
* update lora extension to work with python 3.8
## 1.5.0
### Features:
* SD XL support
* user metadata system for custom networks
* extended Lora metadata editor: set activation text, default weight, view tags, training info
* Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)
* show github stars for extensions
* img2img batch mode can read extra stuff from png info
* img2img batch works with subdirectories
@@ -11,6 +37,9 @@
* restyle time taken/VRAM display
* add textual inversion hashes to infotext
* optimization: cache git extension repo information
* move generate button next to the generated picture for mobile clients
* hide cards for networks of incompatible Stable Diffusion version in Lora extra networks interface
* skip installing packages with pip if they all are already installed - startup speedup of about 2 seconds
### Minor:
* checkbox to check/uncheck all extensions in the Installed tab
@@ -25,6 +54,8 @@
* speedup extra networks listing
* added `[none]` filename token.
* removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs)
* add always_discard_next_to_last_sigma option to XYZ plot
* automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag.
### Extensions and API:
* api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop
@@ -53,9 +84,8 @@
* fix: check fill size none zero when resize (fixes #11425)
* use submit and blur for quick settings textbox
* save img2img batch with images.save_image()
* prevent running preload.py for disabled extensions
* fix: previously, model name was added together with directory name to infotext and to [model_name] filename pattern; directory name is now not included
## 1.4.1
@@ -88,7 +88,7 @@ A browser interface based on Gradio library for Stable Diffusion.
- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
- Now without any bad letters!
- Load checkpoints in safetensors format
- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
- Now with a license!
- Reorder elements in the UI from settings screen
@@ -168,5 +168,7 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- Security advice - RyotaK
- UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
from modules import extra_networks, shared
import networks

class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -9,24 +9,38 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
    def activate(self, p, params_list):
        additional = shared.opts.sd_lora

        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
            p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

        names = []
        te_multipliers = []
        unet_multipliers = []
        dyn_dims = []
        for params in params_list:
            assert params.items

            names.append(params.positional[0])

            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
            te_multiplier = float(params.named.get("te", te_multiplier))

            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
            unet_multiplier = float(params.named.get("unet", unet_multiplier))

            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
            dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim

            te_multipliers.append(te_multiplier)
            unet_multipliers.append(unet_multiplier)
            dyn_dims.append(dyn_dim)

        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)

        if shared.opts.lora_add_hashes_to_infotext:
            network_hashes = []
            for item in networks.loaded_networks:
                shorthash = item.network_on_disk.shorthash
                if not shorthash:
                    continue
@@ -36,10 +50,10 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
                alias = alias.replace(":", "").replace(",", "")

                network_hashes.append(f"{alias}: {shorthash}")

            if network_hashes:
                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

    def deactivate(self, p):
        pass
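`activate` now reads both positional and named arguments from the prompt tag: `<lora:mynet:0.8>` applies 0.8 to both the text encoder and the Unet, while `te=`, `unet=` and `dyn=` override individual values. A minimal standalone sketch of that mapping (the helper name and example values are illustrative, not part of the extension's API):

# illustrative sketch mirroring the argument handling above; not webui code
def parse_network_tag(items):
    positional = [x for x in items if "=" not in str(x)]
    named = dict(str(x).split("=", 1) for x in items if "=" in str(x))

    te = float(positional[1]) if len(positional) > 1 else 1.0
    te = float(named.get("te", te))
    unet = float(positional[2]) if len(positional) > 2 else te
    unet = float(named.get("unet", unet))
    dyn = int(positional[3]) if len(positional) > 3 else None
    dyn = int(named["dyn"]) if "dyn" in named else dyn
    return positional[0], te, unet, dyn

# <lora:mynet:0.8:unet=0.4:dyn=16> arrives as:
print(parse_network_tag(["mynet", "0.8", "unet=0.4", "dyn=16"]))  # ('mynet', 0.8, 0.4, 16)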
import torch

def make_weight_cp(t, wa, wb):
    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)
    return torch.einsum('i j k l, i r -> r j k l', temp, wa)

def rebuild_conventional(up, down, shape, dyn_dim=None):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)

    if dyn_dim is not None:
        up = up[:, :dyn_dim]
        down = down[:dyn_dim, :]

    return (up @ down).reshape(shape)

def rebuild_cp_decomposition(up, down, mid):
    up = up.reshape(up.size(0), -1)
    down = down.reshape(down.size(0), -1)

    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
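These helpers rebuild a full weight delta from factored tensors: `rebuild_conventional` is the plain low-rank product, while `make_weight_cp` and `rebuild_cp_decomposition` handle the CP-decomposed convolution case. A quick sanity check of `rebuild_conventional` (shapes invented; assumes torch is available): the result equals `up @ down`, and `dyn_dim` keeps only the first r rank channels.

# illustrative check, not part of the commit
up = torch.randn(8, 4)     # out_features x rank
down = torch.randn(4, 16)  # rank x in_features

assert torch.allclose(rebuild_conventional(up, down, (8, 16)), up @ down)
assert torch.allclose(rebuild_conventional(up, down, (8, 16), dyn_dim=2), up[:, :2] @ down[:2, :])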
from __future__ import annotations
import os
from collections import namedtuple
import enum

from modules import sd_models, cache, errors, hashes, shared

NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

class SdVersion(enum.Enum):
    Unknown = 1
    SD1 = 2
    SD2 = 3
    SDXL = 4

class NetworkOnDisk:
    def __init__(self, name, filename):
        self.name = name
        self.filename = filename
        self.metadata = {}
        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

        def read_metadata():
            metadata = sd_models.read_metadata_from_safetensors(filename)
            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text

            return metadata

        if self.is_safetensors:
            try:
                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
            except Exception as e:
                errors.display(e, f"reading lora {filename}")

        if self.metadata:
            m = {}
            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
                m[k] = v

            self.metadata = m

        self.alias = self.metadata.get('ss_output_name', self.name)

        self.hash = None
        self.shorthash = None
        self.set_hash(
            self.metadata.get('sshs_model_hash') or
            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
            ''
        )

        self.sd_version = self.detect_version()

    def detect_version(self):
        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):
            return SdVersion.SDXL
        elif str(self.metadata.get('ss_v2', "")) == "True":
            return SdVersion.SD2
        elif len(self.metadata):
            return SdVersion.SD1

        return SdVersion.Unknown

    def set_hash(self, v):
        self.hash = v
        self.shorthash = self.hash[0:12]

        if self.shorthash:
            import networks
            networks.available_network_hash_lookup[self.shorthash] = self

    def read_hash(self):
        if not self.hash:
            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

    def get_alias(self):
        import networks
        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:
            return self.name
        else:
            return self.alias

class Network:  # LoraModule
    def __init__(self, name, network_on_disk: NetworkOnDisk):
        self.name = name
        self.network_on_disk = network_on_disk
        self.te_multiplier = 1.0
        self.unet_multiplier = 1.0
        self.dyn_dim = None
        self.modules = {}
        self.mtime = None

        self.mentioned_name = None
        """the text that was used to add the network to prompt - can be either name or an alias"""

class ModuleType:
    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:
        return None

class NetworkModule:
    def __init__(self, net: Network, weights: NetworkWeights):
        self.network = net
        self.network_key = weights.network_key
        self.sd_key = weights.sd_key
        self.sd_module = weights.sd_module

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.dim = None
        self.bias = weights.w.get("bias")
        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None
        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

    def multiplier(self):
        if 'transformer' in self.sd_key[:20]:
            return self.network.te_multiplier
        else:
            return self.network.unet_multiplier

    def calc_scale(self):
        if self.scale is not None:
            return self.scale
        if self.dim is not None and self.alpha is not None:
            return self.alpha / self.dim

        return 1.0

    def finalize_updown(self, updown, orig_weight, output_shape):
        if self.bias is not None:
            updown = updown.reshape(self.bias.shape)
            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = updown.reshape(output_shape)

        if len(output_shape) == 4:
            updown = updown.reshape(output_shape)

        if orig_weight.size().numel() == updown.size().numel():
            updown = updown.reshape(orig_weight.shape)

        return updown * self.calc_scale() * self.multiplier()

    def calc_updown(self, target):
        raise NotImplementedError()

    def forward(self, x, y):
        raise NotImplementedError()
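`calc_scale` and `multiplier` compose into the final strength of the delta returned by `finalize_updown`: the trained alpha/dim scale times the per-network multiplier from the prompt. Worked numbers (illustrative, not from the commit):

# a module trained with alpha=4 at rank dim=16 is scaled by alpha/dim = 0.25
alpha, dim = 4.0, 16
scale = alpha / dim            # what calc_scale() returns: 0.25
te_multiplier = 0.8            # from a prompt tag like <lora:name:0.8>
print(scale * te_multiplier)   # 0.2, the factor applied in finalize_updown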
import network

class ModuleTypeFull(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["diff"]):
            return NetworkModuleFull(net, weights)

        return None

class NetworkModuleFull(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.weight = weights.w.get("diff")

    def calc_updown(self, orig_weight):
        output_shape = self.weight.shape
        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)

        return self.finalize_updown(updown, orig_weight, output_shape)
import lyco_helpers
import network

class ModuleTypeHada(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):
            return NetworkModuleHada(net, weights)

        return None

class NetworkModuleHada(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        if hasattr(self.sd_module, 'weight'):
            self.shape = self.sd_module.weight.shape

        self.w1a = weights.w["hada_w1_a"]
        self.w1b = weights.w["hada_w1_b"]
        self.dim = self.w1b.shape[0]
        self.w2a = weights.w["hada_w2_a"]
        self.w2b = weights.w["hada_w2_b"]

        self.t1 = weights.w.get("hada_t1")
        self.t2 = weights.w.get("hada_t2")

    def calc_updown(self, orig_weight):
        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [w1a.size(0), w1b.size(1)]

        if self.t1 is not None:
            output_shape = [w1a.size(1), w1b.size(1)]
            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)
            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)
            output_shape += t1.shape[2:]
        else:
            if len(w1b.shape) == 4:
                output_shape += w1b.shape[2:]
            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

        if self.t2 is not None:
            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)
        else:
            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

        updown = updown1 * updown2

        return self.finalize_updown(updown, orig_weight, output_shape)
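The update computed here is the elementwise (Hadamard) product of two low-rank products, which is what gives LoHa its name. A toy version of the branch of `calc_updown` where `t1`/`t2` are absent (shapes invented; assumes torch):

import torch

w1a, w1b = torch.randn(8, 4), torch.randn(4, 16)  # first factor pair, rank 4
w2a, w2b = torch.randn(8, 4), torch.randn(4, 16)  # second factor pair
updown = (w1a @ w1b) * (w2a @ w2b)                # elementwise product of two rank-4 matrices
print(updown.shape)                               # torch.Size([8, 16])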
import network

class ModuleTypeIa3(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["weight"]):
            return NetworkModuleIa3(net, weights)

        return None

class NetworkModuleIa3(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w = weights.w["weight"]
        self.on_input = weights.w["on_input"].item()

    def calc_updown(self, orig_weight):
        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [w.size(0), orig_weight.size(1)]
        if self.on_input:
            output_shape.reverse()
        else:
            w = w.reshape(-1, 1)

        updown = orig_weight * w

        return self.finalize_updown(updown, orig_weight, output_shape)
import torch

import lyco_helpers
import network

class ModuleTypeLokr(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)
        has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)
        if has_1 and has_2:
            return NetworkModuleLokr(net, weights)

        return None

def make_kron(orig_shape, w1, w2):
    if len(w2.shape) == 4:
        w1 = w1.unsqueeze(2).unsqueeze(2)
    w2 = w2.contiguous()
    return torch.kron(w1, w2).reshape(orig_shape)

class NetworkModuleLokr(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.w1 = weights.w.get("lokr_w1")
        self.w1a = weights.w.get("lokr_w1_a")
        self.w1b = weights.w.get("lokr_w1_b")
        self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim
        self.w2 = weights.w.get("lokr_w2")
        self.w2a = weights.w.get("lokr_w2_a")
        self.w2b = weights.w.get("lokr_w2_b")
        self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim
        self.t2 = weights.w.get("lokr_t2")

    def calc_updown(self, orig_weight):
        if self.w1 is not None:
            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)
        else:
            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
            w1 = w1a @ w1b

        if self.w2 is not None:
            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)
        elif self.t2 is None:
            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
            w2 = w2a @ w2b
        else:
            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)
            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]
        if len(orig_weight.shape) == 4:
            output_shape = orig_weight.shape

        updown = make_kron(output_shape, w1, w2)

        return self.finalize_updown(updown, orig_weight, output_shape)
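`make_kron` expands two small factors into one large update via the Kronecker product, so a big layer is covered with few parameters. A toy check (shapes invented; assumes torch): a 2x3 `w1` and a 4x5 `w2` produce an 8x15 matrix.

import torch

w1 = torch.randn(2, 3)
w2 = torch.randn(4, 5)
updown = make_kron((2 * 4, 3 * 5), w1, w2)  # Kronecker product, reshaped to the target
print(updown.shape)                         # torch.Size([8, 15])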
import torch

import lyco_helpers
import network
from modules import devices

class ModuleTypeLora(network.ModuleType):
    def create_module(self, net: network.Network, weights: network.NetworkWeights):
        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):
            return NetworkModuleLora(net, weights)

        return None

class NetworkModuleLora(network.NetworkModule):
    def __init__(self, net: network.Network, weights: network.NetworkWeights):
        super().__init__(net, weights)

        self.up_model = self.create_module(weights.w, "lora_up.weight")
        self.down_model = self.create_module(weights.w, "lora_down.weight")
        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)

        self.dim = weights.w["lora_down.weight"].shape[0]

    def create_module(self, weights, key, none_ok=False):
        weight = weights.get(key)

        if weight is None and none_ok:
            return None

        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]
        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

        if is_linear:
            weight = weight.reshape(weight.shape[0], -1)
            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
        elif is_conv and key == "lora_down.weight" or key == "dyn_up":
            if len(weight.shape) == 2:
                weight = weight.reshape(weight.shape[0], -1, 1, 1)

            if weight.shape[2] != 1 or weight.shape[3] != 1:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
            else:
                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        elif is_conv and key == "lora_mid.weight":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)
        elif is_conv and key == "lora_up.weight" or key == "dyn_down":
            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
        else:
            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

        with torch.no_grad():
            if weight.shape != module.weight.shape:
                weight = weight.reshape(module.weight.shape)
            module.weight.copy_(weight)

        module.to(device=devices.cpu, dtype=devices.dtype)
        module.weight.requires_grad_(False)

        return module

    def calc_updown(self, orig_weight):
        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)

        output_shape = [up.size(0), down.size(1)]
        if self.mid_model is not None:
            # cp-decomposition
            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)
            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)
            output_shape += mid.shape[2:]
        else:
            if len(down.shape) == 4:
                output_shape += down.shape[2:]
            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)

        return self.finalize_updown(updown, orig_weight, output_shape)

    def forward(self, x, y):
        self.up_model.to(device=devices.device)
        self.down_model.to(device=devices.device)

        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
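The `forward` path above is the classic LoRA formulation: the frozen layer's output plus a low-rank correction `up(down(x))`, scaled by the prompt multiplier and `alpha/dim`. A self-contained toy version (shapes, rank and values invented; assumes torch):

import torch

base = torch.nn.Linear(16, 8, bias=False)   # frozen original layer
down = torch.nn.Linear(16, 4, bias=False)   # rank-4 down projection
up = torch.nn.Linear(4, 8, bias=False)      # rank-4 up projection

x = torch.randn(1, 16)
y = base(x) + up(down(x)) * 0.8 * (4.0 / 4)  # multiplier * calc_scale()
print(y.shape)                               # torch.Size([1, 8])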
@@ -4,3 +4,4 @@ from modules import paths
def preload(parser):
    parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))
    parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backwards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))
@@ -4,69 +4,76 @@ import torch
import gradio as gr
from fastapi import FastAPI

import network
import networks
import lora  # noqa:F401
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared

def unload():
    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network

def before_ui():
    ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())

    extra_network = extra_networks_lora.ExtraNetworkLora()
    extra_networks.register_extra_network(extra_network)
    extra_networks.register_extra_network_alias(extra_network, "lyco")

if not hasattr(torch.nn, 'Linear_forward_before_network'):
    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward

if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict

if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward

if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward

if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict

torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict

script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(networks.infotext_pasted)

shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks),
    "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
    "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
    "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
    "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
}))

shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
    "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))

def create_lora_json(obj: network.NetworkOnDisk):
    return {
        "name": obj.name,
        "alias": obj.alias,
@@ -75,17 +82,17 @@ def create_lora_json(obj: lora.LoraOnDisk):
    }

def api_networks(_: gr.Blocks, app: FastAPI):
    @app.get("/sdapi/v1/loras")
    async def get_loras():
        return [create_lora_json(obj) for obj in networks.available_networks.values()]

    @app.post("/sdapi/v1/refresh-loras")
    async def refresh_loras():
        return networks.list_available_networks()

script_callbacks.on_app_started(api_networks)

re_lora = re.compile("<lora:([^:]+):")
@@ -98,19 +105,19 @@ def infotext_pasted(infotext, d):
    hashes = [x.strip().split(':', 1) for x in hashes.split(",")]
    hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}

    def network_replacement(m):
        alias = m.group(1)
        shorthash = hashes.get(alias)
        if shorthash is None:
            return m.group(0)

        network_on_disk = networks.available_network_hash_lookup.get(shorthash)
        if network_on_disk is None:
            return m.group(0)

        return f'<lora:{network_on_disk.get_alias()}:'

    d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])

script_callbacks.on_infotext_pasted(infotext_pasted)
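The `hasattr` guards above matter because the script can be reloaded in-process: without them, a second load would save the already-patched function as the "original", and `unload()` could no longer restore torch's real implementations. A minimal sketch of the save-once/patch/restore pattern (attribute and function names invented, not webui code):

import torch

if not hasattr(torch.nn, 'Linear_forward_saved_sketch'):           # guard: save only once
    torch.nn.Linear_forward_saved_sketch = torch.nn.Linear.forward

def patched_forward(self, x):
    # extra behavior would go here
    return torch.nn.Linear_forward_saved_sketch(self, x)           # delegate to the saved original

torch.nn.Linear.forward = patched_forward

# restoring, as unload() does:
torch.nn.Linear.forward = torch.nn.Linear_forward_saved_sketch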
import datetime
import html
import random
@@ -46,14 +47,17 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
    def __init__(self, ui, tabname, page):
        super().__init__(ui, tabname, page)

        self.select_sd_version = None

        self.taginfo = None
        self.edit_activation_text = None
        self.slider_preferred_weight = None
        self.edit_notes = None

    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):
        user_metadata = self.get_user_metadata(name)
        user_metadata["description"] = desc
        user_metadata["sd version"] = sd_version
        user_metadata["activation text"] = activation_text
        user_metadata["preferred weight"] = preferred_weight
        user_metadata["notes"] = notes
@@ -68,6 +72,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
        keys = {
            'ss_sd_model_name': "Model:",
            'ss_clip_skip': "Clip skip:",
            'ss_network_module': "Kohya module:",
        }

        for key, label in keys.items():
@@ -75,6 +80,10 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
            if value is not None and str(value) != "None":
                table.append((label, html.escape(value)))

        ss_training_started_at = metadata.get('ss_training_started_at')
        if ss_training_started_at:
            table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))

        ss_bucket_info = metadata.get("ss_bucket_info")
        if ss_bucket_info and "buckets" in ss_bucket_info:
            resolutions = {}
@@ -112,11 +121,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
        gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]

        return [
            *values[0:5],
            item.get("sd_version", "Unknown"),
            gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),
            user_metadata.get('activation text', ''),
            float(user_metadata.get('preferred weight', 0.0)),
            gr.update(visible=True if tags else False),
            gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),
        ]
@@ -141,10 +150,15 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
        return ", ".join(sorted(res))

    def create_extra_default_items_in_left_column(self):
        # this would be a lot better as gr.Radio but I can't make it work
        self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)

    def create_editor(self):
        self.create_default_editor_elems()

        self.taginfo = gr.HighlightedText(label="Training dataset tags")
        self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")
        self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)
@@ -153,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
                random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)

            with gr.Column(scale=1, min_width=120):
                generate_random_prompt = gr.Button('Generate', size="lg", scale=1)

        self.edit_notes = gr.TextArea(label='Notes', lines=4)
@@ -178,10 +192,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
            self.edit_description,
            self.html_filedata,
            self.html_preview,
            self.edit_notes,
            self.select_sd_version,
            self.taginfo,
            self.edit_activation_text,
            self.slider_preferred_weight,
            row_random_prompt,
            random_prompt,
        ]
@@ -192,6 +207,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
        edited_components = [
            self.edit_description,
            self.select_sd_version,
            self.edit_activation_text,
            self.slider_preferred_weight,
            self.edit_notes,
import os

import network
import networks

from modules import shared, ui_extra_networks
from modules.ui_extra_networks import quote_js
@@ -11,16 +13,15 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
        super().__init__('Lora')

    def refresh(self):
        networks.list_available_networks()

    def create_item(self, name, index=None, enable_filter=True):
        lora_on_disk = networks.available_networks.get(name)

        path, ext = os.path.splitext(lora_on_disk.filename)
        alias = lora_on_disk.get_alias()

        item = {
            "name": name,
            "filename": lora_on_disk.filename,
@@ -30,6 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
            "local_preview": f"{path}.{shared.opts.samples_format}",
            "metadata": lora_on_disk.metadata,
            "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},
            "sd_version": lora_on_disk.sd_version.name,
        }

        self.read_user_metadata(item)
@@ -40,15 +42,37 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
        if activation_text:
            item["prompt"] += " + " + quote_js(" " + activation_text)

        sd_version = item["user_metadata"].get("sd version")
        if sd_version in network.SdVersion.__members__:
            item["sd_version"] = sd_version
            sd_version = network.SdVersion[sd_version]
        else:
            sd_version = lora_on_disk.sd_version

        if shared.opts.lora_show_all or not enable_filter:
            pass
        elif sd_version == network.SdVersion.Unknown:
            model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1
            if model_version.name in shared.opts.lora_hide_unknown_for_versions:
                return None
        elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:
            return None
        elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:
            return None
        elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:
            return None

        return item

    def list_items(self):
        for index, name in enumerate(networks.available_networks):
            item = self.create_item(name, index)

            if item is not None:
                yield item

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat]

    def create_user_metadata_editor(self, ui, tabname):
        return LoraUserMetadataEditor(ui, tabname, self)
<div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
    {background_image}
    <div class="button-row">
        {metadata_button}
        {edit_button}
    </div>
    <div class='actions'>
        <div class='additional'>
(function() {
    var ignore = localStorage.getItem("bad-scale-ignore-it") == "ignore-it";

    function getScale() {
        var ratio = 0,
            screen = window.screen,
            ua = navigator.userAgent.toLowerCase();

        if (window.devicePixelRatio !== undefined) {
            ratio = window.devicePixelRatio;
        } else if (~ua.indexOf('msie')) {
            if (screen.deviceXDPI && screen.logicalXDPI) {
                ratio = screen.deviceXDPI / screen.logicalXDPI;
            }
        } else if (window.outerWidth !== undefined && window.innerWidth !== undefined) {
            ratio = window.outerWidth / window.innerWidth;
        }

        return ratio == 0 ? 0 : Math.round(ratio * 100);
    }

    var showing = false;

    var div = document.createElement("div");
    div.style.position = "fixed";
    div.style.top = "0px";
    div.style.left = "0px";
    div.style.width = "100vw";
    div.style.backgroundColor = "firebrick";
    div.style.textAlign = "center";
    div.style.zIndex = 99;

    var b = document.createElement("b");
    b.innerHTML = 'Bad Scale: ??% ';

    div.appendChild(b);

    var note1 = document.createElement("p");
    note1.innerHTML = "Change your browser or your computer settings!";
    note1.title = 'Just make sure "computer-scale" * "browser-scale" = 100% ,\n' +
        "you can keep your computer-scale and only change this page's scale,\n" +
        "for example: your computer-scale is 125%, just use [\"CTRL\"+\"-\"] to make your browser-scale of this page to 80%.";
    div.appendChild(note1);

    var note2 = document.createElement("p");
    note2.innerHTML = " Otherwise, it will cause this page to not function properly!";
    note2.title = "When you click \"Copy image to: [inpaint sketch]\" in some img2img's tab,\n" +
        "if scale<100% the canvas will be invisible,\n" +
        "else if scale>100% this page will take large amount of memory and CPU performance.";
    div.appendChild(note2);

    var btn = document.createElement("button");
    btn.innerHTML = "Click here to ignore";

    div.appendChild(btn);

    function tryShowTopBar(scale) {
        if (showing) return;

        b.innerHTML = 'Bad Scale: ' + scale + '% ';

        var updateScaleTimer = setInterval(function() {
            var newScale = getScale();
            b.innerHTML = 'Bad Scale: ' + newScale + '% ';
            if (newScale == 100) {
                var p = div.parentNode;
                if (p != null) p.removeChild(div);
                showing = false;
                clearInterval(updateScaleTimer);
                check();
            }
        }, 999);

        btn.onclick = function() {
            clearInterval(updateScaleTimer);
            var p = div.parentNode;
            if (p != null) p.removeChild(div);
            ignore = true;
            showing = false;
            localStorage.setItem("bad-scale-ignore-it", "ignore-it");
        };

        document.body.appendChild(div);
    }

    function check() {
        if (!ignore) {
            var timer = setInterval(function() {
                var scale = getScale();
                if (scale != 100 && !ignore) {
                    tryShowTopBar(scale);
                    clearInterval(timer);
                }
                if (ignore) {
                    clearInterval(timer);
                }
            }, 999);
        }
    }

    if (document.readyState != "complete") {
        document.onreadystatechange = function() {
            if (document.readyState != "complete") check();
        };
    } else {
        check();
    }
})();
@@ -213,7 +213,7 @@ function popup(contents) {
        globalPopupInner.classList.add('global-popup-inner');
        globalPopup.appendChild(globalPopupInner);

        gradioApp().querySelector('.main').appendChild(globalPopup);
    }

    globalPopupInner.innerHTML = '';
@@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
        tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
    }
});

onUiLoaded(function() {
    for (var comp of window.gradio_config.components) {
        if (comp.props.webui_tooltip && comp.props.elem_id) {
            var elem = gradioApp().getElementById(comp.props.elem_id);
            if (elem) {
                elem.title = comp.props.webui_tooltip;
            }
        }
    }
});
@@ -152,7 +152,11 @@ function submit() {
    showSubmitButtons('txt2img', false);

    var id = randomId();
    try {
        localStorage.setItem("txt2img_task_id", id);
    } catch (e) {
        console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
    }

    requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
        showSubmitButtons('txt2img', true);
@@ -171,7 +175,11 @@ function submit_img2img() {
    showSubmitButtons('img2img', false);

    var id = randomId();
    try {
        localStorage.setItem("img2img_task_id", id);
    } catch (e) {
        console.warn(`Failed to save img2img task id to localStorage: ${e}`);
    }

    requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
        showSubmitButtons('img2img', true);
@@ -191,8 +199,6 @@ function restoreProgressTxt2img() {
    showRestoreProgressButton("txt2img", false);

    var id = localStorage.getItem("txt2img_task_id");

    if (id) {
        requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
            showSubmitButtons('txt2img', true);
from modules import launch_utils

args = launch_utils.args
python = launch_utils.python
git = launch_utils.git
@@ -18,6 +17,7 @@ run_pip = launch_utils.run_pip
check_run_python = launch_utils.check_run_python
git_clone = launch_utils.git_clone
git_pull_recursive = launch_utils.git_pull_recursive
list_extensions = launch_utils.list_extensions
run_extension_installer = launch_utils.run_extension_installer
prepare_environment = launch_utils.prepare_environment
configure_for_tests = launch_utils.configure_for_tests
@@ -25,8 +25,11 @@ start = launch_utils.start

def main():
    launch_utils.startup_timer.record("initial startup")

    with launch_utils.startup_timer.subcategory("prepare environment"):
        if not args.skip_prepare_environment:
            prepare_environment()

    if args.test_server:
        configure_for_tests()
@@ -15,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -197,6 +197,7 @@ class Api:
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
        self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
@@ -333,14 +334,17 @@ class Api:
            p.outpath_grids = opts.outdir_txt2img_grids
            p.outpath_samples = opts.outdir_txt2img_samples

            try:
                shared.state.begin(job="scripts_txt2img")
                if selectable_scripts is not None:
                    p.script_args = script_args
                    processed = scripts.scripts_txt2img.run(p, *p.script_args)  # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args)  # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                shared.state.end()
                shared.total_tqdm.clear()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -390,14 +394,17 @@ class Api:
            p.outpath_grids = opts.outdir_img2img_grids
            p.outpath_samples = opts.outdir_img2img_samples

            try:
                shared.state.begin(job="scripts_img2img")
                if selectable_scripts is not None:
                    p.script_args = script_args
                    processed = scripts.scripts_img2img.run(p, *p.script_args)  # Need to pass args as list here
                else:
                    p.script_args = tuple(script_args)  # Need to pass args as tuple here
                    processed = process_images(p)
            finally:
                shared.state.end()
                shared.total_tqdm.clear()

        b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -604,6 +611,10 @@ class Api:
        with self.queue_lock:
            shared.refresh_checkpoints()

    def refresh_vae(self):
        with self.queue_lock:
            shared_items.refresh_vae_list()

    def create_embedding(self, args: dict):
        try:
            shared.state.begin(job="create_embedding")
@@ -720,9 +731,9 @@ class Api:
            cuda = {'error': f'{err}'}
        return models.MemoryResponse(ram=ram, cuda=cuda)

    def launch(self, server_name, port, root_path):
        self.app.include_router(self.router)
        uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path)

    def kill_webui(self):
        restart.stop_program()
import inspect

from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal

@@ -207,11 +208,10 @@ class PreprocessResponse(BaseModel):
fields = {}
for key, metadata in opts.data_labels.items():
    value = opts.data.get(key)
-    optType = opts.typemap.get(type(metadata.default), type(value))
+    optType = opts.typemap.get(type(metadata.default), type(metadata.default)) if metadata.default else Any

-    if (metadata is not None):
-        fields.update({key: (Optional[optType], Field(
-            default=metadata.default ,description=metadata.label))})
+    if metadata is not None:
+        fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))})
    else:
        fields.update({key: (Optional[optType], Field())})
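The hunk above builds the options schema dynamically; the fix makes the field type fall back to `Any` when an option has no usable default instead of guessing from the current value. A minimal sketch of the same `create_model` pattern with made-up fields:

```python
# Sketch of the dynamic-model pattern above with made-up fields: each
# entry maps a name to a (type, Field) tuple, and create_model turns the
# dict into a concrete pydantic model.
from typing import Optional
from pydantic import Field, create_model

fields = {
    "sd_model_checkpoint": (Optional[str], Field(default=None, description="checkpoint name")),
    "CLIP_stop_at_last_layers": (Optional[int], Field(default=1, description="clip skip")),
}

OptionsModel = create_model("OptionsModel", **fields)
print(OptionsModel())  # OptionsModel(sd_model_checkpoint=None, CLIP_stop_at_last_layers=1)
```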
......
@@ -3,7 +3,7 @@ import html
import threading
import time

-from modules import shared, progress, errors
+from modules import shared, progress, errors, devices

queue_lock = threading.Lock()

@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
                error_message = f'{type(e).__name__}: {e}'
                res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]

        devices.torch_gc()

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0
......
@@ -13,8 +13,10 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py
parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")
parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")
parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup")
parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation")
parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")
parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)

@@ -65,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre
parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)

@@ -109,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False)
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache

import torch
-from modules import errors
+from modules import errors, rng_philox

if sys.platform == "darwin":
    from modules import mac_specific

@@ -71,14 +71,17 @@ def enable_tf32():
        torch.backends.cudnn.allow_tf32 = True

errors.run(enable_tf32, "Enabling TF32")

-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16

unet_needs_upcast = False

@@ -90,23 +93,87 @@ def cond_cast_float(input):
    return input.float() if unet_needs_upcast else input
nv_rng = None


def randn(seed, shape):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""

    from modules.shared import opts

-    torch.manual_seed(seed)
+    manual_seed(seed)

    if opts.randn_source == "NV":
        return torch.asarray(nv_rng.randn(shape), device=device)

    if opts.randn_source == "CPU" or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)


def randn_local(seed, shape):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""

    from modules.shared import opts

    if opts.randn_source == "NV":
        rng = rng_philox.Generator(seed)
        return torch.asarray(rng.randn(shape), device=device)

    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
    local_generator = torch.Generator(local_device).manual_seed(int(seed))
    return torch.randn(shape, device=local_device, generator=local_generator).to(device)


def randn_like(x):
    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.

    Use either randn() or manual_seed() to initialize the generator."""

    from modules.shared import opts

    if opts.randn_source == "NV":
        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)

    if opts.randn_source == "CPU" or x.device.type == 'mps':
        return torch.randn_like(x, device=cpu).to(x.device)

    return torch.randn_like(x)


def randn_without_seed(shape):
    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.

    Use either randn() or manual_seed() to initialize the generator."""

    from modules.shared import opts

    if opts.randn_source == "NV":
        return torch.asarray(nv_rng.randn(shape), device=device)

    if opts.randn_source == "CPU" or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)

    return torch.randn(shape, device=device)


def manual_seed(seed):
    """Set up a global random number generator using the specified seed."""
    from modules.shared import opts

    if opts.randn_source == "NV":
        global nv_rng
        nv_rng = rng_philox.Generator(seed)
        return

    torch.manual_seed(seed)
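Taken together: `manual_seed` decides once which RNG backend a seed goes to (the Philox generator for the "NV" source, torch's global generator otherwise), and the `randn*` helpers then draw from whichever was initialized. A rough, self-contained sketch of that dispatch; `StubPhilox` and `randn_source` stand in for `rng_philox.Generator` and `opts.randn_source`:

```python
# Rough sketch of the seed/draw split used above. Seeding happens once in
# manual_seed(); later draws go through whichever generator was set up.
# StubPhilox and randn_source are stand-ins, not webui code.
import torch

randn_source = "NV"   # or "CPU"/"GPU"
nv_rng = None

class StubPhilox:
    def __init__(self, seed):
        self.gen = torch.Generator("cpu").manual_seed(int(seed))

    def randn(self, shape):
        return torch.randn(shape, generator=self.gen)

def manual_seed(seed):
    """Route the seed to the right backend exactly once."""
    global nv_rng
    if randn_source == "NV":
        nv_rng = StubPhilox(seed)
        return
    torch.manual_seed(seed)

def randn_without_seed(shape):
    """Draw from whichever generator manual_seed() initialized."""
    if randn_source == "NV":
        return torch.asarray(nv_rng.randn(shape))
    return torch.randn(shape)

manual_seed(0)
print(randn_without_seed((2, 2)))
```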
def autocast(disable=False):
    from modules import shared
......
@@ -14,7 +14,8 @@ def record_exception():
    if exception_records and exception_records[-1] == e:
        return

-    exception_records.append((e, tb))
+    from modules import sysinfo
+    exception_records.append(sysinfo.format_exception(e, tb))

    if len(exception_records) > 5:
        exception_records.pop(0)

@@ -83,3 +84,53 @@ def run(code, task):
        code()
    except Exception as e:
        display(task, e)
def check_versions():
from packaging import version
from modules import shared
import torch
import gradio
expected_torch_version = "2.0.0"
expected_xformers_version = "0.0.20"
expected_gradio_version = "3.39.0"
if version.parse(torch.__version__) < version.parse(expected_torch_version):
print_error_explanation(f"""
You are running torch {torch.__version__}.
The program is tested to work with torch {expected_torch_version}.
To reinstall the desired version, run with commandline flag --reinstall-torch.
Beware that this will cause a lot of large files to be downloaded, and note that
there are reports of issues with the training tab on the latest version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
if shared.xformers_available:
import xformers
if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
print_error_explanation(f"""
You are running xformers {xformers.__version__}.
The program is tested to work with xformers {expected_xformers_version}.
To reinstall the desired version, run with commandline flag --reinstall-xformers.
Use --skip-version-check commandline argument to disable this check.
""".strip())
if gradio.__version__ != expected_gradio_version:
print_error_explanation(f"""
You are running gradio {gradio.__version__}.
The program is designed to work with gradio {expected_gradio_version}.
Using a different version of gradio is extremely likely to break the program.
Reasons why you have the mismatched gradio version can be:
- you use --skip-install flag.
- you use webui.py to start the program instead of launch.py.
- an extension installs the incompatible gradio version.
Use --skip-version-check commandline argument to disable this check.
""".strip())
@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
def active():
-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
        return []
-    elif shared.opts.disable_all_extensions == "extra":
+    elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
        return [x for x in extensions if x.enabled and x.is_builtin]
    else:
        return [x for x in extensions if x.enabled]
@@ -56,10 +56,12 @@ class Extension:
            self.do_read_info_from_repo()
            return self.to_dict()

-        d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
-        self.from_dict(d)
-        self.status = 'unknown'
+        try:
+            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
+            self.from_dict(d)
+        except FileNotFoundError:
+            pass
+        self.status = 'unknown' if self.status == '' else self.status

    def do_read_info_from_repo(self):
        repo = None
@@ -139,8 +141,12 @@ def list_extensions():
    if not os.path.isdir(extensions_dir):
        return

-    if shared.opts.disable_all_extensions == "all":
+    if shared.cmd_opts.disable_all_extensions:
+        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+    elif shared.opts.disable_all_extensions == "all":
        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+    elif shared.cmd_opts.disable_extra_extensions:
+        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
......
@@ -4,16 +4,22 @@ from collections import defaultdict
from modules import errors

extra_network_registry = {}
extra_network_aliases = {}


def initialize():
    extra_network_registry.clear()
    extra_network_aliases.clear()


def register_extra_network(extra_network):
    extra_network_registry[extra_network.name] = extra_network


def register_extra_network_alias(extra_network, alias):
    extra_network_aliases[alias] = extra_network


def register_default_extra_networks():
    from modules.extra_networks_hypernet import ExtraNetworkHypernet
    register_extra_network(ExtraNetworkHypernet())

@@ -82,20 +88,26 @@ def activate(p, extra_network_data):
    """call activate for extra networks in extra_network_data in specified order, then call
    activate for all remaining registered networks with an empty argument list"""

    activated = []

    for extra_network_name, extra_network_args in extra_network_data.items():
        extra_network = extra_network_registry.get(extra_network_name, None)

        if extra_network is None:
            extra_network = extra_network_aliases.get(extra_network_name, None)

        if extra_network is None:
            print(f"Skipping unknown extra network: {extra_network_name}")
            continue

        try:
            extra_network.activate(p, extra_network_args)
            activated.append(extra_network)
        except Exception as e:
            errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")

    for extra_network_name, extra_network in extra_network_registry.items():
-        args = extra_network_data.get(extra_network_name, None)
-        if args is not None:
+        if extra_network in activated:
            continue

        try:
......
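The new alias table lets one handler answer to more than one prompt syntax; the reworked Lora extension uses this so that a second tag resolves to the same handler as `<lora:...>`. A sketch of the registration, assuming the webui modules are importable; the stub class and names are illustrative:

```python
# Sketch (assumes the webui modules are importable): one handler
# registered under its own name plus an alias, so both <mynet:...> and
# <mynet2:...> resolve to it. The stub class is illustrative.
from modules import extra_networks

class ExtraNetworkStub(extra_networks.ExtraNetwork):
    def __init__(self):
        super().__init__('mynet')

    def activate(self, p, params_list):
        pass

    def deactivate(self, p):
        pass

network = ExtraNetworkStub()
extra_networks.register_extra_network(network)                  # <mynet:...>
extra_networks.register_extra_network_alias(network, 'mynet2')  # <mynet2:...>
```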
@@ -7,7 +7,7 @@ import json
import torch
import tqdm

-from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch

@@ -72,7 +72,20 @@ def to_half(tensor, enable):
    return tensor

-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
metadata = {}
for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
if checkpoint_info is None:
continue
metadata.update(checkpoint_info.metadata)
return json.dumps(metadata, indent=4, ensure_ascii=False)
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
shared.state.begin(job="model-merge") shared.state.begin(job="model-merge")
def fail(message): def fail(message):
...@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ ...@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.textinfo = "Saving" shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
metadata = None metadata = {}
if save_metadata and copy_metadata_fields:
if primary_model_info:
metadata.update(primary_model_info.metadata)
if secondary_model_info:
metadata.update(secondary_model_info.metadata)
if tertiary_model_info:
metadata.update(tertiary_model_info.metadata)
    if save_metadata:
-        metadata = {"format": "pt"}
+        try:
+            metadata.update(json.loads(metadata_json))
+        except Exception as e:
+            errors.display(e, "reading metadata from json")
+
+        metadata["format"] = "pt"

    if save_metadata and add_merge_recipe:
        merge_recipe = {
            "type": "webui", # indicate this model was merged with webui's built-in merger
            "primary_model_hash": primary_model_info.sha256,

@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
            "is_inpainting": result_is_inpainting_model,
            "is_instruct_pix2pix": result_is_instruct_pix2pix_model
        }
-        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        sd_merge_models = {}
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)

        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
        metadata["sd_merge_models"] = json.dumps(sd_merge_models)

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
-        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
+        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
    else:
        torch.save(theta_0, output_modelname)
......
import gradio as gr
from modules import scripts
def add_classes_to_gradio_component(comp):
"""
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
"""
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')
def IOComponent_init(self, *args, **kwargs):
self.webui_tooltip = kwargs.pop('tooltip', None)
if scripts.scripts_current is not None:
scripts.scripts_current.before_component(self, **kwargs)
scripts.script_callbacks.before_component_callback(self, **kwargs)
res = original_IOComponent_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
scripts.script_callbacks.after_component_callback(self, **kwargs)
if scripts.scripts_current is not None:
scripts.scripts_current.after_component(self, **kwargs)
return res
def Block_get_config(self):
config = original_Block_get_config(self)
webui_tooltip = getattr(self, 'webui_tooltip', None)
if webui_tooltip:
config["webui_tooltip"] = webui_tooltip
return config
def BlockContext_init(self, *args, **kwargs):
res = original_BlockContext_init(self, *args, **kwargs)
add_classes_to_gradio_component(self)
return res
original_IOComponent_init = gr.components.IOComponent.__init__
original_Block_get_config = gr.blocks.Block.get_config
original_BlockContext_init = gr.blocks.BlockContext.__init__
gr.components.IOComponent.__init__ = IOComponent_init
gr.blocks.Block.get_config = Block_get_config
gr.blocks.BlockContext.__init__ = BlockContext_init
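The net effect of this new module: every gradio IOComponent gains a webui-only `tooltip` keyword, which is popped before gradio's own `__init__` runs and shipped to the frontend as `webui_tooltip` in the block config. A sketch under the assumption that the module above has been imported (so the hijack is active); the label text is made up:

```python
# Sketch: with the hijack above active (i.e. the module above imported
# before the UI is built), IOComponents accept a webui-only `tooltip`
# kwarg that stock gradio does not know about. The label text is made up.
import gradio as gr

with gr.Blocks():
    btn = gr.Button("Generate", tooltip="Start generating images")

# The tooltip travels to the frontend inside the block config:
print(btn.get_config().get("webui_tooltip"))  # Start generating images
```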
@@ -363,7 +363,7 @@ class FilenameGenerator:
    'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
    'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),
    'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
-    'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.model_name, replace_spaces=False),
+    'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False),
    'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
    'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
    'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
......
@@ -10,7 +10,6 @@ from modules import sd_samplers, images as imgutil
from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
-from modules.images import save_image
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html

@@ -18,9 +17,10 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):
    output_dir = output_dir.strip()
    processing.fix_seed(p)

-    images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))
+    images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp", ".tif", ".tiff")))

    is_inpaint_batch = False
    if inpaint_mask_dir:
@@ -32,11 +32,6 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
    print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

-    save_normally = output_dir == ''
-
-    p.do_not_save_grid = True
-    p.do_not_save_samples = not save_normally

    state.job_count = len(images) * p.n_iter

    # extract "default" params to use in case getting png info fails
@@ -111,21 +106,14 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal
        proc = modules.scripts.scripts_img2img.run(p, *args)

        if proc is None:
-            proc = process_images(p)
-
-        for n, processed_image in enumerate(proc.images):
-            filename = image_path.stem
-            infotext = proc.infotext(p, n)
-            relpath = os.path.dirname(os.path.relpath(image, input_dir))
-
-            if n > 0:
-                filename += f"-{n}"
-
-            if not save_normally:
-                os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)
-                if processed_image.mode == 'RGBA':
-                    processed_image = processed_image.convert("RGB")
-                save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False)
+            if output_dir:
+                p.outpath_samples = output_dir
+                p.override_settings['save_to_dirs'] = False
+                if p.n_iter > 1 or p.batch_size > 1:
+                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]'
+                else:
+                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'
+            process_images(p)


def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):
......
# this script installs necessary requirements and launches main program in webui.py
import re
import subprocess
import os
import sys

@@ -9,6 +10,7 @@ from functools import lru_cache
from modules import cmd_args, errors
from modules.paths_internal import script_path, extensions_dir
from modules.timer import startup_timer
args, _ = cmd_args.parser.parse_known_args()

@@ -192,7 +194,7 @@ def run_extension_installer(extension_dir):
    try:
        env = os.environ.copy()
-        env['PYTHONPATH'] = os.path.abspath(".")
+        env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"

        print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
    except Exception as e:
@@ -222,8 +224,51 @@ def run_extensions_installers(settings_file):
    if not os.path.isdir(extensions_dir):
        return

-    for dirname_extension in list_extensions(settings_file):
-        run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+    with startup_timer.subcategory("run extensions installers"):
+        for dirname_extension in list_extensions(settings_file):
+            path = os.path.join(extensions_dir, dirname_extension)
+
+            if os.path.isdir(path):
+                run_extension_installer(path)
+
+            startup_timer.record(dirname_extension)
re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
def requirements_met(requirements_file):
"""
Does a simple parse of a requirements.txt file to determine if all rerqirements in it
are already installed. Returns True if so, False if not installed or parsing fails.
"""
import importlib.metadata
import packaging.version
with open(requirements_file, "r", encoding="utf8") as file:
for line in file:
if line.strip() == "":
continue
m = re.match(re_requirement, line)
if m is None:
return False
package = m.group(1).strip()
version_required = (m.group(2) or "").strip()
if version_required == "":
continue
try:
version_installed = importlib.metadata.version(package)
except Exception:
return False
if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
return False
return True
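`requirements_met` is deliberately conservative: unpinned lines are skipped outright, and any unparsable line or version mismatch returns False so that pip still runs. A small illustration with the function above in scope and a made-up requirements file:

```python
# Illustration with requirements_met (defined above) in scope: unpinned
# lines are skipped, pinned lines must match the installed version
# exactly. The file contents are made up.
import os
import tempfile

with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as f:
    f.write("numpy\npackaging==23.1\n")

# True only if the installed packaging version is exactly 23.1;
# the unpinned numpy line is not checked at all.
print(requirements_met(f.name))
os.unlink(f.name)
```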
def prepare_environment():

@@ -251,15 +296,18 @@ def prepare_environment():
    try:
        # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
        os.remove(os.path.join(script_path, "tmp", "restart"))
-        os.environ.setdefault('SD_WEBUI_RESTARTING ', '1')
+        os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
    except OSError:
        pass

    if not args.skip_python_version_check:
        check_python_version()
startup_timer.record("checks")
    commit = commit_hash()
    tag = git_tag()
startup_timer.record("git version info")
print(f"Python {sys.version}") print(f"Python {sys.version}")
print(f"Version: {tag}") print(f"Version: {tag}")
...@@ -267,21 +315,27 @@ def prepare_environment(): ...@@ -267,21 +315,27 @@ def prepare_environment():
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
startup_timer.record("install torch")
    if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
        raise RuntimeError(
            'Torch is not able to use GPU; '
            'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
        )
startup_timer.record("torch GPU test")
if not is_installed("gfpgan"): if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan") run_pip(f"install {gfpgan_package}", "gfpgan")
startup_timer.record("install gfpgan")
if not is_installed("clip"): if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip") run_pip(f"install {clip_package}", "clip")
startup_timer.record("install clip")
if not is_installed("open_clip"): if not is_installed("open_clip"):
run_pip(f"install {openclip_package}", "open_clip") run_pip(f"install {openclip_package}", "open_clip")
startup_timer.record("install open_clip")
if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers: if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
if platform.system() == "Windows": if platform.system() == "Windows":
@@ -295,8 +349,11 @@ def prepare_environment():
        elif platform.system() == "Linux":
            run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
startup_timer.record("install xformers")
if not is_installed("ngrok") and args.ngrok: if not is_installed("ngrok") and args.ngrok:
run_pip("install ngrok", "ngrok") run_pip("install ngrok", "ngrok")
startup_timer.record("install ngrok")
    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

@@ -306,20 +363,28 @@ def prepare_environment():
    git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
    git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
startup_timer.record("clone repositores")
if not is_installed("lpips"): if not is_installed("lpips"):
run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
startup_timer.record("install CodeFormer requirements")
    if not os.path.isfile(requirements_file):
        requirements_file = os.path.join(script_path, requirements_file)

-    run_pip(f"install -r \"{requirements_file}\"", "requirements")
+    if not requirements_met(requirements_file):
+        run_pip(f"install -r \"{requirements_file}\"", "requirements")
startup_timer.record("install requirements")
    run_extensions_installers(settings_file=args.ui_settings_file)

    if args.update_check:
        version_check(commit)
startup_timer.record("check version")
    if args.update_all_extensions:
        git_pull_recursive(extensions_dir)
startup_timer.record("update extensions")
if "--exit" in sys.argv: if "--exit" in sys.argv:
print("Exiting because of --exit argument") print("Exiting because of --exit argument")
......
@@ -90,8 +90,12 @@ def setup_for_low_vram(sd_model, use_medvram):
        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
    elif is_sd2:
        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.model] = sd_model.cond_stage_model
parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
    else:
        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
    sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
    sd_model.first_stage_model.encode = first_stage_model_encode_wrap

@@ -101,9 +105,6 @@ def setup_for_low_vram(sd_model, use_medvram):
    if sd_model.embedder:
        sd_model.embedder.register_forward_pre_hook(send_me_to_gpu)
-    if hasattr(sd_model, 'cond_stage_model'):
-        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
    if use_medvram:
        sd_model.model.register_forward_pre_hook(send_me_to_gpu)
    else:
......
@@ -19,7 +19,7 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
!emphasized: "(" prompt ")"
        | "(" prompt ":" prompt ")"
        | "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
alternate: "[" prompt ("|" prompt)+ "]"
WHITESPACE: /\s+/
plain: /([^\\\[\]():|]|\\.)+/
@@ -60,11 +60,11 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
    class CollectSteps(lark.Visitor):
        def scheduled(self, tree):
-            tree.children[-1] = float(tree.children[-1])
-            if tree.children[-1] < 1:
-                tree.children[-1] *= steps
-            tree.children[-1] = min(steps, int(tree.children[-1]))
-            res.append(tree.children[-1])
+            tree.children[-2] = float(tree.children[-2])
+            if tree.children[-2] < 1:
+                tree.children[-2] *= steps
+            tree.children[-2] = min(steps, int(tree.children[-2]))
+            res.append(tree.children[-2])

        def alternate(self, tree):
            res.extend(range(1, steps+1))
@@ -75,7 +75,7 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
    def at_step(step, tree):
        class AtStep(lark.Transformer):
            def scheduled(self, args):
-                before, after, _, when = args
+                before, after, _, when, _ = args
                yield before or () if step <= when else after

            def alternate(self, args):
                yield next(args[(step - 1)%len(args)])
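The grammar now tolerates whitespace between the number and the closing bracket of a scheduled prompt, so `[dog:cat:0.5 ]` parses; since `NUMBER` is no longer necessarily the last child of the node, the visitor and transformer index `children[-2]` instead of `children[-1]`. For illustration, the schedule builder can be exercised directly (assuming the webui modules are importable):

```python
# Illustration (assumes the webui modules are importable): a scheduled
# prompt with a trailing space inside the brackets now parses; at 10
# steps, 0.5 switches the prompt after step 5.
from modules.prompt_parser import get_learned_conditioning_prompt_schedules

schedules = get_learned_conditioning_prompt_schedules(["a [dog:cat:0.5 ] photo"], 10)
print(schedules[0])  # roughly: [[5, 'a dog photo'], [10, 'a cat photo']]
```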
@@ -178,7 +178,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
re_AND = re.compile(r"\bAND\b")
-re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")
+re_weight = re.compile(r"^((?:\s|.)*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$")


def get_multicond_prompt_list(prompts: SdConditioning | list[str]):
@@ -333,7 +333,7 @@ re_attention = re.compile(r"""
\\|
\(|
\[|
-:([+-]?[.\d]+)\)|
+:\s*([+-]?[.\d]+)\s*\)|
\)|
]|
[^\\()\[\]:]+|
......
"""RNG imitiating torch cuda randn on CPU. You are welcome.
Usage:
```
g = Generator(seed=0)
print(g.randn(shape=(3, 4)))
```
Expected output:
```
[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
[-0.12086647 -0.57972564 -0.62285122 -0.32838709]
[-1.07454231 -0.36314407 -1.67105067 2.26550497]]
```
"""
import numpy as np
philox_m = [0xD2511F53, 0xCD9E8D57]
philox_w = [0x9E3779B9, 0xBB67AE85]
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
def uint32(x):
"""Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
def philox4_round(counter, key):
"""A single round of the Philox 4x32 random number generator."""
v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
counter[0] = v2[1] ^ counter[1] ^ key[0]
counter[1] = v2[0]
counter[2] = v1[1] ^ counter[3] ^ key[1]
counter[3] = v1[0]
def philox4_32(counter, key, rounds=10):
"""Generates 32-bit random numbers using the Philox 4x32 random number generator.
Parameters:
counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
rounds (int): The number of rounds to perform.
Returns:
numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
"""
for _ in range(rounds - 1):
philox4_round(counter, key)
key[0] = key[0] + philox_w[0]
key[1] = key[1] + philox_w[1]
philox4_round(counter, key)
return counter
def box_muller(x, y):
"""Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
u = x * two_pow32_inv + two_pow32_inv / 2
v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
s = np.sqrt(-2.0 * np.log(u))
r1 = s * np.sin(v)
return r1.astype(np.float32)
class Generator:
"""RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
def __init__(self, seed):
self.seed = seed
self.offset = 0
def randn(self, shape):
"""Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
n = 1
for x in shape:
n *= x
counter = np.zeros((4, n), dtype=np.uint32)
counter[0] = self.offset
counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
self.offset += 1
key = np.empty(n, dtype=np.uint64)
key.fill(self.seed)
key = uint32(key)
g = philox4_32(counter, key)
return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
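Because the key encodes the seed and the counter encodes the draw offset, the generator is stateless apart from `self.offset`: equal seeds replay identical sequences, and each call advances to a fresh batch. A quick self-check using the class above:

```python
# Quick self-check of the Generator above: equal seeds replay the same
# sequence, and the offset advances to a fresh batch on every call.
a = Generator(seed=0)
b = Generator(seed=0)

x1 = a.randn((2, 3))
x2 = a.randn((2, 3))   # offset=1, different numbers
y1 = b.randn((2, 3))

assert (x1 == y1).all()
assert not (x1 == x2).all()
```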
@@ -12,11 +12,12 @@ def load_module(path):
    return module


-def preload_extensions(extensions_dir, parser):
+def preload_extensions(extensions_dir, parser, extension_list=None):
    if not os.path.isdir(extensions_dir):
        return

-    for dirname in sorted(os.listdir(extensions_dir)):
+    extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
+
+    for dirname in sorted(extensions):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue
......
@@ -16,6 +16,11 @@ class PostprocessImageArgs:
        self.image = image
class PostprocessBatchListArgs:
def __init__(self, images):
self.images = images
class Script:
    name = None
    """script's internal name derived from title"""

@@ -119,7 +124,7 @@ class Script:
    def after_extra_networks_activate(self, p, *args, **kwargs):
        """
-        Calledafter extra networks activation, before conds calculation
+        Called after extra networks activation, before conds calculation
        allow modification of the network after extra networks activation has been applied
        won't be called if p.disable_extra_networks

@@ -156,6 +161,25 @@ class Script:
        pass
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
"""
Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
This is useful when you want to update the entire batch instead of individual images.
You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
If the number of images is different from the batch size when returning,
then the script has the responsibility to also update the following attributes in the processing object (p):
- p.prompts
- p.negative_prompts
- p.seeds
- p.subseeds
**kwargs will have same items as process_batch, and also:
- batch_number - index of current batch, from 0 to number of batches-1
"""
pass
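A hedged sketch of an always-on script using the new hook: it duplicates every image in the batch, so it must also duplicate the per-image bookkeeping lists on `p`. The class and its behavior are made up for illustration:

```python
# Illustrative always-on script using the new hook; the class and its
# behavior are made up. Duplicating images means the per-image lists on
# the processing object must be duplicated too.
from modules import scripts

class DuplicateBatchExample(scripts.Script):
    def title(self):
        return "Duplicate batch (example)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def postprocess_batch_list(self, p, pp, *args, **kwargs):
        pp.images.extend(list(pp.images))   # batch is now twice as large
        p.prompts = p.prompts * 2           # keep bookkeeping consistent
        p.negative_prompts = p.negative_prompts * 2
        p.seeds = p.seeds * 2
        p.subseeds = p.subseeds * 2
```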
    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.

@@ -536,6 +560,14 @@ class ScriptRunner:
            except Exception:
                errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
for script in self.alwayson_scripts:
try:
script_args = p.script_args[script.args_from:script.args_to]
script.postprocess_batch_list(p, pp, *script_args, **kwargs)
except Exception:
errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try:

@@ -599,49 +631,3 @@ def reload_script_body_only():

reload_scripts = load_scripts  # compatibility alias
-def add_classes_to_gradio_component(comp):
-    """
-    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
-    """
-
-    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
-
-    if getattr(comp, 'multiselect', False):
-        comp.elem_classes.append('multiselect')
-
-
-def IOComponent_init(self, *args, **kwargs):
-    if scripts_current is not None:
-        scripts_current.before_component(self, **kwargs)
-
-    script_callbacks.before_component_callback(self, **kwargs)
-
-    res = original_IOComponent_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    script_callbacks.after_component_callback(self, **kwargs)
-
-    if scripts_current is not None:
-        scripts_current.after_component(self, **kwargs)
-
-    return res
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-gr.components.IOComponent.__init__ = IOComponent_init
-
-
-def BlockContext_init(self, *args, **kwargs):
-    res = original_BlockContext_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    return res
-
-
-original_BlockContext_init = gr.blocks.BlockContext.__init__
-gr.blocks.BlockContext.__init__ = BlockContext_init
@@ -3,8 +3,31 @@ import open_clip
import torch
import transformers.utils.hub
from modules import shared
-class DisableInitialization:
class ReplaceHelper:
def __init__(self):
self.replaced = []
def replace(self, obj, field, func):
original = getattr(obj, field, None)
if original is None:
return None
self.replaced.append((obj, field, original))
setattr(obj, field, func)
return original
def restore(self):
for obj, field, original in self.replaced:
setattr(obj, field, original)
self.replaced.clear()
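`ReplaceHelper` is a small monkey-patching bookkeeper: `replace` swaps an attribute out and remembers the original, `restore` puts everything back. A self-contained sketch of the pattern using the class above:

```python
# Self-contained sketch of the ReplaceHelper pattern: patch an attribute,
# use it, then restore the original.
import math

helper = ReplaceHelper()
helper.replace(math, "sqrt", lambda x: 42)

print(math.sqrt(9))   # 42 while the patch is in place
helper.restore()
print(math.sqrt(9))   # 3.0 again
```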
+class DisableInitialization(ReplaceHelper):
""" """
When an object of this class enters a `with` block, it starts: When an object of this class enters a `with` block, it starts:
- preventing torch's layer initialization functions from working - preventing torch's layer initialization functions from working
...@@ -21,7 +44,7 @@ class DisableInitialization: ...@@ -21,7 +44,7 @@ class DisableInitialization:
""" """
def __init__(self, disable_clip=True): def __init__(self, disable_clip=True):
self.replaced = [] super().__init__()
self.disable_clip = disable_clip self.disable_clip = disable_clip
def replace(self, obj, field, func): def replace(self, obj, field, func):
...@@ -86,8 +109,81 @@ class DisableInitialization: ...@@ -86,8 +109,81 @@ class DisableInitialization:
self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
for obj, field, original in self.replaced: self.restore()
setattr(obj, field, original)
self.replaced.clear()
class InitializeOnMeta(ReplaceHelper):
"""
Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,
which results in those parameters having no values and taking no memory. model.to() will be broken and
will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
Usage:
```
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
```
"""
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return
def set_device(x):
x["device"] = "meta"
return x
linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
class LoadStateDictOnMeta(ReplaceHelper):
"""
    Context manager that allows reading parameters from state_dict into a model that has some of its parameters on the meta device.
As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
Meant to be used together with InitializeOnMeta above.
Usage:
```
with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
model.load_state_dict(state_dict, strict=False)
```
"""
def __init__(self, state_dict, device):
super().__init__()
self.state_dict = state_dict
self.device = device
def __enter__(self):
if shared.cmd_opts.disable_model_loading_ram_optimization:
return
sd = self.state_dict
device = self.device
def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
for name, param in params:
if param.is_meta:
self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
original(self, state_dict, prefix, *args, **kwargs)
for name, _ in params:
key = prefix + name
if key in sd:
del sd[key]
linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
def __exit__(self, exc_type, exc_val, exc_tb):
self.restore()
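A minimal stand-alone sketch of the mechanism above, outside the context-manager plumbing: meta parameters are swapped for real zero tensors just before the stock loader copies values in, and consumed keys are dropped to release the state dict's memory.

```
import torch

sd = {"weight": torch.randn(8, 8), "bias": torch.randn(8)}  # stand-in state dict

layer = torch.nn.Linear(8, 8, device="meta")
for name, param in list(layer._parameters.items()):
    if param is not None and param.is_meta:
        # Replace the storage-less meta parameter with a real zero tensor.
        layer._parameters[name] = torch.nn.Parameter(
            torch.zeros_like(param, device="cpu"), requires_grad=param.requires_grad)

layer.load_state_dict(sd)  # copies the checkpoint values into the zero tensors
for key in list(sd):       # free memory as the real code does, key by key
    del sd[key]
```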
...@@ -197,7 +197,7 @@ class StableDiffusionModelHijack: ...@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i]) text_cond_models.append(conditioner.embedders[i])
if typename == 'FrozenOpenCLIPEmbedder2': if typename == 'FrozenOpenCLIPEmbedder2':
embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
text_cond_models.append(conditioner.embedders[i]) text_cond_models.append(conditioner.embedders[i])
...@@ -243,7 +243,7 @@ class StableDiffusionModelHijack: ...@@ -243,7 +243,7 @@ class StableDiffusionModelHijack:
ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
...@@ -292,10 +292,11 @@ class StableDiffusionModelHijack: ...@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:
class EmbeddingsWithFixes(torch.nn.Module): class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings): def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
super().__init__() super().__init__()
self.wrapped = wrapped self.wrapped = wrapped
self.embeddings = embeddings self.embeddings = embeddings
self.textual_inversion_key = textual_inversion_key
def forward(self, input_ids): def forward(self, input_ids):
batch_fixes = self.embeddings.fixes batch_fixes = self.embeddings.fixes
...@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module): ...@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
vecs = [] vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds): for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes: for offset, embedding in fixes:
emb = devices.cond_cast_unet(embedding.vec) vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
emb = devices.cond_cast_unet(vec)
emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
......
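A hedged sketch of the dict-vs-tensor handling introduced above: SDXL textual inversion embeddings carry one tensor per text encoder (keyed 'clip_l'/'clip_g'), while SD1/SD2 embeddings are a bare tensor. Dimensions below are the usual CLIP widths, used only for illustration.

```
import torch

def pick_vec(vec, textual_inversion_key='clip_l'):
    # Mirrors the selection logic in EmbeddingsWithFixes.forward above.
    return vec[textual_inversion_key] if isinstance(vec, dict) else vec

sd1_vec = torch.zeros(2, 768)                                   # plain tensor
sdxl_vec = {'clip_l': torch.zeros(2, 768), 'clip_g': torch.zeros(2, 1280)}
assert pick_vec(sd1_vec).shape[-1] == 768
assert pick_vec(sdxl_vec, 'clip_g').shape[-1] == 1280
```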
...@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): ...@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
position += 1 position += 1
continue continue
emb_len = int(embedding.vec.shape[0]) emb_len = int(embedding.vectors)
if len(chunk.tokens) + emb_len > self.chunk_length: if len(chunk.tokens) + emb_len > self.chunk_length:
next_chunk() next_chunk()
...@@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): ...@@ -245,6 +245,8 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
hashes.append(f"{name}: {shorthash}") hashes.append(f"{name}: {shorthash}")
if hashes: if hashes:
if self.hijack.extra_generation_params.get("TI hashes"):
hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)
if getattr(self.wrapped, 'return_pooled', False): if getattr(self.wrapped, 'return_pooled', False):
...@@ -270,12 +272,17 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): ...@@ -270,12 +272,17 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
z = self.encode_with_transformers(tokens) z = self.encode_with_transformers(tokens)
pooled = getattr(z, 'pooled', None)
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
original_mean = z.mean() original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean() new_mean = z.mean()
z *= (original_mean / new_mean) z = z * (original_mean / new_mean)
if pooled is not None:
z.pooled = pooled
return z return z
......
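A stand-alone sketch of the mean-restoration heuristic above (the in-code comment itself concedes it is likely not strictly correct): scale per-token embeddings by their attention multipliers, then rescale so the tensor's overall mean is unchanged.

```
import torch

z = torch.randn(1, 77, 768)                    # token embeddings
batch_multipliers = torch.full((1, 77), 1.3)   # per-token attention weights

original_mean = z.mean()
z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
z = z * (original_mean / z.mean())             # restore the original mean

assert torch.isclose(z.mean(), original_mean)
```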
...@@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit ...@@ -32,7 +32,7 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit
def encode_embedding_init_text(self, init_text, nvpt): def encode_embedding_init_text(self, init_text, nvpt):
ids = tokenizer.encode(init_text) ids = tokenizer.encode(init_text)
ids = torch.asarray([ids], device=devices.device, dtype=torch.int) ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
return embedded return embedded
......
...@@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs): ...@@ -256,9 +256,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] slice_size = q.shape[1] // steps
for i in range(0, q.shape[1], slice_size): for i in range(0, q.shape[1], slice_size):
end = i + slice_size end = min(i + slice_size, q.shape[1])
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype) s2 = s1.softmax(dim=-1, dtype=q.dtype)
......
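The fix above matters when q.shape[1] is not divisible by steps: the old code then fell back to a single full-width slice, defeating the memory split entirely. A stand-alone check of the clamped slicing:

```
def slices(length, steps):
    # Always use the quotient as the slice size and clamp the final slice.
    slice_size = length // steps
    for i in range(0, length, slice_size):
        yield i, min(i + slice_size, length)

print(list(slices(10, 4)))  # [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10)]
```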
...@@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): ...@@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
if isinstance(cond, dict): if isinstance(cond, dict):
for y in cond.keys(): for y in cond.keys():
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] if isinstance(cond[y], list):
cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
else:
cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]
with devices.autocast(): with devices.autocast():
return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
...@@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi ...@@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
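A hedged stand-alone sketch of the cond-dict handling added above: SD1-style conditioning maps keys to lists of tensors, while SDXL can map keys (such as the vector conditioning 'y') directly to tensors; both now get cast to the UNet dtype.

```
import torch

def cast_cond(cond, dtype=torch.float16):
    for y in cond.keys():
        if isinstance(cond[y], list):
            cond[y] = [x.to(dtype) if isinstance(x, torch.Tensor) else x for x in cond[y]]
        else:
            cond[y] = cond[y].to(dtype) if isinstance(cond[y], torch.Tensor) else cond[y]
    return cond

cond = {'c_crossattn': [torch.zeros(1, 77, 768)], 'y': torch.zeros(1, 2816)}
cond = cast_cond(cond)
print(cond['c_crossattn'][0].dtype, cond['y'].dtype)  # torch.float16 torch.float16
```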
...@@ -14,7 +14,7 @@ import ldm.modules.midas as midas ...@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer from modules.timer import Timer
import tomesd import tomesd
...@@ -33,6 +33,8 @@ class CheckpointInfo: ...@@ -33,6 +33,8 @@ class CheckpointInfo:
self.filename = filename self.filename = filename
abspath = os.path.abspath(filename) abspath = os.path.abspath(filename)
self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '') name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path): elif abspath.startswith(model_path):
...@@ -43,6 +45,19 @@ class CheckpointInfo: ...@@ -43,6 +45,19 @@ class CheckpointInfo:
if name.startswith("\\") or name.startswith("/"): if name.startswith("\\") or name.startswith("/"):
name = name[1:] name = name[1:]
def read_metadata():
metadata = read_metadata_from_safetensors(filename)
self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
return metadata
self.metadata = {}
if self.is_safetensors:
try:
self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
except Exception as e:
errors.display(e, f"reading metadata for {filename}")
self.name = name self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
...@@ -53,16 +68,7 @@ class CheckpointInfo: ...@@ -53,16 +68,7 @@ class CheckpointInfo:
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
self.metadata = {}
_, ext = os.path.splitext(self.filename)
if ext.lower() == ".safetensors":
try:
self.metadata = read_metadata_from_safetensors(filename)
except Exception as e:
errors.display(e, f"reading checkpoint metadata: {filename}")
def register(self): def register(self):
checkpoints_list[self.title] = self checkpoints_list[self.title] = self
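The metadata read above now goes through cache.cached_data_for_file, so safetensors headers are parsed once per file state rather than on every listing. A hedged, generic sketch of that caching idea (illustrative only; the repo's real helper also persists entries to disk):

```
import os

_cache = {}

def cached_data_for_file(subsection, title, filename, func):
    """Return func()'s result, recomputing only when the file's mtime changes."""
    key = (subsection, title)
    mtime = os.path.getmtime(filename)
    entry = _cache.get(key)
    if entry is None or entry[0] != mtime:
        entry = (mtime, func())   # cache miss or stale entry: recompute
        _cache[key] = entry
    return entry[1]
```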
...@@ -79,7 +85,7 @@ class CheckpointInfo: ...@@ -79,7 +85,7 @@ class CheckpointInfo:
if self.shorthash not in self.ids: if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
checkpoints_list.pop(self.title) checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]' self.title = f'{self.name} [{self.shorthash}]'
self.register() self.register()
...@@ -290,6 +296,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer ...@@ -290,6 +296,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
state_dict = get_checkpoint_state_dict(checkpoint_info, timer) state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
model.is_sdxl = hasattr(model, 'conditioner') model.is_sdxl = hasattr(model, 'conditioner')
model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
model.is_sd1 = not model.is_sdxl and not model.is_sd2
if model.is_sdxl: if model.is_sdxl:
sd_models_xl.extend_sdxl(model) sd_models_xl.extend_sdxl(model)
...@@ -323,7 +332,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer ...@@ -323,7 +332,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
timer.record("apply half()") timer.record("apply half()")
devices.dtype_unet = model.model.diffusion_model.dtype devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16
model.first_stage_model.to(devices.dtype_vae) model.first_stage_model.to(devices.dtype_vae)
...@@ -457,7 +466,6 @@ def get_empty_cond(sd_model): ...@@ -457,7 +466,6 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""]) return sd_model.cond_stage_model([""])
def load_model(checkpoint_info=None, already_loaded_state_dict=None): def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint() checkpoint_info = checkpoint_info or select_checkpoint()
...@@ -491,20 +499,25 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): ...@@ -491,20 +499,25 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
sd_model = None sd_model = None
try: try:
with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
sd_model = instantiate_from_config(sd_config.model) with sd_disable_initialization.InitializeOnMeta():
except Exception: sd_model = instantiate_from_config(sd_config.model)
pass
except Exception as e:
errors.display(e, "creating model quickly", full_traceback=True)
if sd_model is None: if sd_model is None:
print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
sd_model = instantiate_from_config(sd_config.model)
with sd_disable_initialization.InitializeOnMeta():
sd_model = instantiate_from_config(sd_config.model)
sd_model.used_config = checkpoint_config sd_model.used_config = checkpoint_config
timer.record("create model") timer.record("create model")
load_model_weights(sd_model, checkpoint_info, state_dict, timer) with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
......
...@@ -12,8 +12,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: ...@@ -12,8 +12,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
for embedder in self.conditioner.embedders: for embedder in self.conditioner.embedders:
embedder.ucg_rate = 0.0 embedder.ucg_rate = 0.0
width = getattr(self, 'target_width', 1024) width = getattr(batch, 'width', 1024)
height = getattr(self, 'target_height', 1024) height = getattr(batch, 'height', 1024)
is_negative_prompt = getattr(batch, 'is_negative_prompt', False) is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
...@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, ...@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
return torch.cat(res, dim=1) return torch.cat(res, dim=1)
def tokenize(self: sgm.modules.GeneralConditioner, texts):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
return embedder.tokenize(texts)
raise AssertionError('no tokenizer available')
def process_texts(self, texts): def process_texts(self, texts):
for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
return embedder.process_texts(texts) return embedder.process_texts(texts)
...@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count): ...@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):
# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
sgm.modules.GeneralConditioner.tokenize = tokenize
sgm.modules.GeneralConditioner.process_texts = process_texts sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
......
...@@ -2,10 +2,8 @@ from collections import namedtuple ...@@ -2,10 +2,8 @@ from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
...@@ -85,11 +83,13 @@ class InterruptedException(BaseException): ...@@ -85,11 +83,13 @@ class InterruptedException(BaseException):
pass pass
if opts.randn_source == "CPU": def replace_torchsde_brownian():
import torchsde._brownian.brownian_interval import torchsde._brownian.brownian_interval
def torchsde_randn(size, dtype, device, seed): def torchsde_randn(size, dtype, device, seed):
generator = torch.Generator(devices.cpu).manual_seed(int(seed)) return devices.randn_local(seed, size).to(device=device, dtype=dtype)
return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
torchsde._brownian.brownian_interval._randn = torchsde_randn torchsde._brownian.brownian_interval._randn = torchsde_randn
replace_torchsde_brownian()
import torch
import tqdm
import k_diffusion.sampling
@torch.no_grad()
def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
"""Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
"""
extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
step_id = 0
from k_diffusion.sampling import to_d, get_sigmas_karras
def heun_step(x, old_sigma, new_sigma, second_order=True):
nonlocal step_id
denoised = model(x, old_sigma * s_in, **extra_args)
d = to_d(x, old_sigma, denoised)
if callback is not None:
callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
dt = new_sigma - old_sigma
if new_sigma == 0 or not second_order:
# Euler method
x = x + d * dt
else:
# Heun's method
x_2 = x + d * dt
denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
d_2 = to_d(x_2, new_sigma, denoised_2)
d_prime = (d + d_2) / 2
x = x + d_prime * dt
step_id += 1
return x
steps = sigmas.shape[0] - 1
if restart_list is None:
if steps >= 20:
restart_steps = 9
restart_times = 1
if steps >= 36:
restart_steps = steps // 4
restart_times = 2
sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
else:
restart_list = {}
restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
step_list = []
for i in range(len(sigmas) - 1):
step_list.append((sigmas[i], sigmas[i + 1]))
if i + 1 in restart_list:
restart_steps, restart_times, restart_max = restart_list[i + 1]
min_idx = i + 1
max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
if max_idx < min_idx:
sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
while restart_times > 0:
restart_times -= 1
step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
last_sigma = None
for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
if last_sigma is None:
last_sigma = old_sigma
elif last_sigma < old_sigma:
x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
x = heun_step(x, old_sigma, new_sigma)
last_sigma = new_sigma
return x
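A toy invocation of the sampler above (the constant "denoiser" exists purely to show the calling convention; real use goes through the KDiffusionSampler registration in the next file):

```
import torch
from k_diffusion.sampling import get_sigmas_karras

def toy_model(x, sigma, **kwargs):
    return torch.zeros_like(x)  # pretend everything denoises to pure black

x = torch.randn(1, 4, 8, 8) * 10.0
sigmas = get_sigmas_karras(20, 0.03, 10.0)   # 21 values, ending in 0
out = restart_sampler(toy_model, x, sigmas)
print(out.abs().max())                        # ~0: the toy model collapses the noise
```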
...@@ -2,7 +2,7 @@ from collections import deque ...@@ -2,7 +2,7 @@ from collections import deque
import torch import torch
import inspect import inspect
import k_diffusion.sampling import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra
from modules.shared import opts, state from modules.shared import opts, state
import modules.shared as shared import modules.shared as shared
...@@ -30,12 +30,14 @@ samplers_k_diffusion = [ ...@@ -30,12 +30,14 @@ samplers_k_diffusion = [
('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
] ]
samplers_data_k_diffusion = [ samplers_data_k_diffusion = [
sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
for label, funcname, aliases, options in samplers_k_diffusion for label, funcname, aliases, options in samplers_k_diffusion
if hasattr(k_diffusion.sampling, funcname) if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
] ]
sampler_extra_params = { sampler_extra_params = {
...@@ -258,10 +260,7 @@ class TorchHijack: ...@@ -258,10 +260,7 @@ class TorchHijack:
if noise.shape == x.shape: if noise.shape == x.shape:
return noise return noise
if opts.randn_source == "CPU" or x.device.type == 'mps': return devices.randn_like(x)
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
class KDiffusionSampler: class KDiffusionSampler:
...@@ -270,7 +269,7 @@ class KDiffusionSampler: ...@@ -270,7 +269,7 @@ class KDiffusionSampler:
self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, []) self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None self.sampler_noises = None
......
...@@ -11,6 +11,7 @@ import gradio as gr ...@@ -11,6 +11,7 @@ import gradio as gr
import torch import torch
import tqdm import tqdm
import launch
import modules.interrogate import modules.interrogate
import modules.memmon import modules.memmon
import modules.styles import modules.styles
...@@ -26,7 +27,7 @@ demo = None ...@@ -26,7 +27,7 @@ demo = None
parser = cmd_args.parser parser = cmd_args.parser
script_loading.preload_extensions(extensions_dir, parser) script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser) script_loading.preload_extensions(extensions_builtin_dir, parser)
if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None: if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
...@@ -384,7 +385,8 @@ options_templates.update(options_section(('face-restoration', "Face restoration" ...@@ -384,7 +385,8 @@ options_templates.update(options_section(('face-restoration', "Face restoration"
})) }))
options_templates.update(options_section(('system', "System"), { options_templates.update(options_section(('system', "System"), {
"show_warnings": OptionInfo(False, "Show warnings in console."), "show_warnings": OptionInfo(False, "Show warnings in console.").needs_restart(),
"show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_restart(),
"memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
"samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
"multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
...@@ -426,7 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { ...@@ -426,7 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
})) }))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
......
...@@ -106,10 +106,7 @@ class StyleDatabase: ...@@ -106,10 +106,7 @@ class StyleDatabase:
if os.path.exists(path): if os.path.exists(path):
shutil.copy(path, f"{path}.bak") shutil.copy(path, f"{path}.bak")
fd = os.open(path, os.O_RDWR | os.O_CREAT) with open(path, "w", encoding="utf-8-sig", newline='') as file:
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
# _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
# and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader() writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items()) writer.writerows(style._asdict() for k, style in self.styles.items())
......
...@@ -109,11 +109,15 @@ def format_traceback(tb): ...@@ -109,11 +109,15 @@ def format_traceback(tb):
return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
def format_exception(e, tb):
return {"exception": str(e), "traceback": format_traceback(tb)}
def get_exceptions(): def get_exceptions():
try: try:
from modules import errors from modules import errors
return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)] return list(reversed(errors.exception_records))
except Exception as e: except Exception as e:
return str(e) return str(e)
......
...@@ -181,29 +181,38 @@ class EmbeddingDatabase: ...@@ -181,29 +181,38 @@ class EmbeddingDatabase:
else: else:
return return
# textual inversion embeddings # textual inversion embeddings
if 'string_to_param' in data: if 'string_to_param' in data:
param_dict = data['string_to_param'] param_dict = data['string_to_param']
param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it' assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1] emb = next(iter(param_dict.items()))[1]
# diffuser concepts vec = emb.detach().to(devices.device, dtype=torch.float32)
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: shape = vec.shape[-1]
vectors = vec.shape[0]
elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding
vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
vectors = data['clip_g'].shape[0]
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
assert len(data.keys()) == 1, 'embedding file has multiple terms in it' assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values())) emb = next(iter(data.values()))
if len(emb.shape) == 1: if len(emb.shape) == 1:
emb = emb.unsqueeze(0) emb = emb.unsqueeze(0)
vec = emb.detach().to(devices.device, dtype=torch.float32)
shape = vec.shape[-1]
vectors = vec.shape[0]
else: else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name) embedding = Embedding(vec, name)
embedding.step = data.get('step', None) embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
embedding.vectors = vec.shape[0] embedding.vectors = vectors
embedding.shape = vec.shape[-1] embedding.shape = shape
embedding.filename = path embedding.filename = path
embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
......
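A hedged sketch of the three on-disk layouts the loader above now distinguishes (dummy tensors; dimensions are the usual CLIP widths):

```
import torch

sd_data = {'string_to_param': {'*': torch.zeros(2, 768)}}   # classic SD1/SD2 TI file
sdxl_data = {'clip_g': torch.zeros(2, 1280), 'clip_l': torch.zeros(2, 768)}  # SDXL: one tensor per encoder
concept_data = {'<token>': torch.zeros(768)}                # diffusers concept (1-D, gets unsqueezed)

# For the SDXL case the reported shape is the two encoder widths combined:
print(sdxl_data['clip_g'].shape[-1] + sdxl_data['clip_l'].shape[-1])  # 2048
```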
import time import time
import argparse
class TimerSubcategory: class TimerSubcategory:
...@@ -11,20 +12,27 @@ class TimerSubcategory: ...@@ -11,20 +12,27 @@ class TimerSubcategory:
def __enter__(self): def __enter__(self):
self.start = time.time() self.start = time.time()
self.timer.base_category = self.original_base_category + self.category + "/" self.timer.base_category = self.original_base_category + self.category + "/"
self.timer.subcategory_level += 1
if self.timer.print_log:
print(f"{' ' * self.timer.subcategory_level}{self.category}:")
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
elapsed_for_subcategory = time.time() - self.start elapsed_for_subcategory = time.time() - self.start
self.timer.base_category = self.original_base_category self.timer.base_category = self.original_base_category
self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategory) self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategory)
self.timer.record(self.category) self.timer.subcategory_level -= 1
self.timer.record(self.category, disable_log=True)
class Timer: class Timer:
def __init__(self): def __init__(self, print_log=False):
self.start = time.time() self.start = time.time()
self.records = {} self.records = {}
self.total = 0 self.total = 0
self.base_category = '' self.base_category = ''
self.print_log = print_log
self.subcategory_level = 0
def elapsed(self): def elapsed(self):
end = time.time() end = time.time()
...@@ -38,13 +46,16 @@ class Timer: ...@@ -38,13 +46,16 @@ class Timer:
self.records[category] += amount self.records[category] += amount
def record(self, category, extra_time=0): def record(self, category, extra_time=0, disable_log=False):
e = self.elapsed() e = self.elapsed()
self.add_time_to_record(self.base_category + category, e + extra_time) self.add_time_to_record(self.base_category + category, e + extra_time)
self.total += e + extra_time self.total += e + extra_time
if self.print_log and not disable_log:
print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
def subcategory(self, name): def subcategory(self, name):
self.elapsed() self.elapsed()
...@@ -71,6 +82,10 @@ class Timer: ...@@ -71,6 +82,10 @@ class Timer:
self.__init__() self.__init__()
startup_timer = Timer() parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
args = parser.parse_known_args()[0]
startup_timer = Timer(print_log=args.log_startup)
startup_record = None startup_record = None
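A hedged usage sketch for the Timer changes above (assumes the repo's modules are importable; with --log-startup, record() and subcategory() print an indented log roughly like the comments):

```
import time
from modules.timer import Timer

timer = Timer(print_log=True)
time.sleep(0.01)
timer.record("imports")                  # imports: done in ~0.010s
with timer.subcategory("load model"):    # load model:
    time.sleep(0.01)
    timer.record("read state dict")      #   read state dict: done in ~0.010s
print(timer.records)
```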
import gradio as gr
from modules import sd_models, sd_vae, errors, extras, call_queue
from modules.ui_components import FormRow
from modules.ui_common import create_refresh_button
def update_interp_description(value):
interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
interp_descriptions = {
"No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
"Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
"Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
}
return interp_descriptions[value]
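A quick numeric check of the two interpolation formulas described above, applied per weight tensor (M is the multiplier slider):

```
a, b, c, m = 1.0, 3.0, 2.0, 0.5

weighted_sum = a * (1 - m) + b * m     # 2.0
add_difference = a + (b - c) * m       # 1.5
print(weighted_sum, add_difference)
```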
def modelmerger(*args):
try:
results = extras.run_modelmerger(*args)
except Exception as e:
errors.report("Error loading/saving model file", exc_info=True)
sd_models.list_models() # to remove the potentially missing models from the list
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
return results
class UiCheckpointMerger:
def __init__(self):
with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
with gr.Row(equal_height=False):
with gr.Column(variant='compact'):
self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
with FormRow(elem_id="modelmerger_models"):
self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
with FormRow():
self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
with FormRow():
with gr.Column():
self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
with gr.Column():
with FormRow():
self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
with FormRow():
self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
with gr.Accordion("Metadata", open=False) as metadata_editor:
with FormRow():
self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
self.read_metadata = gr.Button("Read metadata from selected checkpoints")
with FormRow():
self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
with gr.Group(elem_id="modelmerger_results_panel"):
self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
self.metadata_editor = metadata_editor
self.blocks = modelmerger_interface
def setup_ui(self, dummy_component, sd_model_checkpoint_component):
self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
self.modelmerger_merge.click(
fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
_js='modelmerger',
inputs=[
dummy_component,
self.primary_model_name,
self.secondary_model_name,
self.tertiary_model_name,
self.interp_method,
self.interp_amount,
self.save_as_half,
self.custom_name,
self.checkpoint_format,
self.config_source,
self.bake_in_vae,
self.discard_weights,
self.save_metadata,
self.add_merge_recipe,
self.copy_metadata_fields,
self.metadata_json,
],
outputs=[
self.primary_model_name,
self.secondary_model_name,
self.tertiary_model_name,
sd_model_checkpoint_component,
self.modelmerger_result,
]
)
# Required as a workaround for change() event not triggering when loading values from ui-config.json
self.interp_description.value = update_interp_description(self.interp_method.value)
...@@ -134,7 +134,7 @@ Requested path was: {f} ...@@ -134,7 +134,7 @@ Requested path was: {f}
with gr.Column(variant='panel', elem_id=f"{tabname}_results"): with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
with gr.Group(elem_id=f"{tabname}_gallery_container"): with gr.Group(elem_id=f"{tabname}_gallery_container"):
result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(columns=4) result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery", columns=4)
generation_info = None generation_info = None
with gr.Column(): with gr.Column():
...@@ -223,20 +223,44 @@ Requested path was: {f} ...@@ -223,20 +223,44 @@ Requested path was: {f}
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
label = None
for comp in refresh_components:
label = getattr(comp, 'label', None)
if label is not None:
break
def refresh(): def refresh():
refresh_method() refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items(): for k, v in args.items():
setattr(refresh_component, k, v) for comp in refresh_components:
setattr(comp, k, v)
return gr.update(**(args or {})) return [gr.update(**(args or {})) for _ in refresh_components]
refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
refresh_button.click( refresh_button.click(
fn=refresh, fn=refresh,
inputs=[], inputs=[],
outputs=[refresh_component] outputs=[*refresh_components]
) )
return refresh_button return refresh_button
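With the rework above, create_refresh_button also accepts a list of components, so one button can refresh several dropdowns at once. A hedged sketch (assumes the repo's modules are importable; component names are illustrative):

```
import gradio as gr
from modules import sd_models
from modules.ui_common import create_refresh_button

with gr.Blocks():
    model_a = gr.Dropdown(sd_models.checkpoint_tiles(), label="Primary model (A)")
    model_b = gr.Dropdown(sd_models.checkpoint_tiles(), label="Secondary model (B)")
    # One button now refreshes both dropdowns; each receives the same update.
    create_refresh_button(
        [model_a, model_b],
        sd_models.list_models,
        lambda: {"choices": sd_models.checkpoint_tiles()},
        "refresh_checkpoints_ab",
    )
```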
def setup_dialog(button_show, dialog, *, button_close=None):
"""Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
dialog.visible = False
button_show.click(
fn=lambda: gr.update(visible=True),
inputs=[],
outputs=[dialog],
).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
if button_close:
button_close.click(fn=None, _js="closePopup")
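A hedged usage sketch for setup_dialog (elem_id is required because the popup JS looks the element up by id; names are illustrative):

```
import gradio as gr

with gr.Blocks() as demo:
    show = gr.Button("Edit metadata")
    with gr.Box(elem_id="metadata_dialog") as dialog:   # hidden by setup_dialog
        gr.HTML("dialog body goes here")
        close = gr.Button("Close")
    setup_dialog(show, dialog, button_close=close)
```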