Commit 688c4a91 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into 1404-script-reload-without-restart

parents a634c322 852fd90c
......@@ -25,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
/textual_inversion
......@@ -11,12 +11,12 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
- Prompt
- Stable Diffusion upscale
- Prompt Matrix
- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
- a man in a ((txuedo)) - will pay more attentinoto tuxedo
- a man in a (txuedo:1.21) - alternative syntax
- Loopback, run img2img procvessing multiple times
- a man in a ((tuxedo)) - will pay more attention to tuxedo
- a man in a (tuxedo:1.21) - alternative syntax
- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
- have as many embeddings as you want and use any names you like for them
......@@ -35,15 +35,15 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- 4GB video card support (also reports of 2GB working)
- Correct seeds for batches
- Prompt length validation
- get length of prompt in tokensas you type
- get a warning after geenration if some text was truncated
- get length of prompt in tokens as you type
- get a warning after generation if some text was truncated
- Generation parameters
- parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings
- Settings page
- Running arbitrary python code from UI (must run with commandline flag to enable)
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
......
......@@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
......
function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
return args_to_array(arguments)
}
......@@ -199,12 +199,18 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
function update_txt2img_tokens(...args) {
update_token_counter("txt2img_token_button")
if (args.length == 2)
return args[0]
return args;
}
function update_img2img_tokens(...args) {
update_token_counter("img2img_token_button")
if (args.length == 2)
return args[0]
return args;
}
function update_token_counter(button_id) {
......@@ -213,6 +219,14 @@ function update_token_counter(button_id) {
token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
}
function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000)
......
......@@ -15,6 +15,7 @@ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
......@@ -111,6 +112,9 @@ if not skip_torch_cuda_test:
if not is_installed("gfpgan"):
run_pip(f"install {gfpgan_package}", "gfpgan")
if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")
os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
......
......@@ -32,10 +32,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device()
device_codeformer = cpu if has_mps else device
dtype = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
......
......@@ -73,8 +73,8 @@ def fix_model_layers(crt_model, pretrained_net):
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
self.model_name = "ESRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth"
self.model_name = "ESRGAN_4x"
self.scalers = []
self.user_path = dirname
self.model_path = os.path.join(models_path, self.name)
......
......@@ -311,7 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
#currently disabled if using the save button, will work otherwise
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
if hasattr(p, "styles"):
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
......
......@@ -103,7 +103,9 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpaint_full_res_padding=inpaint_full_res_padding,
inpainting_mask_invert=inpainting_mask_invert,
)
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
p.extra_generation_params["Mask blur"] = mask_blur
......
......@@ -5,7 +5,6 @@ import importlib
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
......@@ -43,7 +42,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
for place in places:
if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True):
full_path = os.path.join(place, file)
full_path = file
if os.path.isdir(full_path):
continue
if len(ext_filter) != 0:
......@@ -121,16 +120,30 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers():
sd = shared.script_path
# We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__
modules_dir = os.path.join(sd, "modules")
for file in os.listdir(modules_dir):
if "_model.py" in file:
model_name = file.replace("_model.py", "")
full_model = f"modules.{model_name}_model"
try:
importlib.import_module(full_model)
except:
pass
datas = []
c_o = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__():
name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
opt_string = None
try:
opt_string = shared.opts.__getattr__(cmd_name)
if cmd_name in c_o:
opt_string = c_o[cmd_name]
except:
pass
scaler = class_(opt_string)
......
......@@ -56,7 +56,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
self.styles: str = styles
self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
......@@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
......@@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
fix_seed(p)
os.makedirs(p.outpath_samples, exist_ok=True)
os.makedirs(p.outpath_grids, exist_ok=True)
if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True)
if p.outpath_grids is not None:
os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
......@@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
......
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import shared, modelloader
from modules.paths import models_path
from modules.scunet_model_arch import SCUNet as net
class UpscalerScuNET(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "ScuNET"
self.model_path = os.path.join(models_path, self.name)
self.model_name = "ScuNET GAN"
self.model_name2 = "ScuNET PSNR"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
self.model_url2 = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_psnr.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pth"])
scalers = []
add_model2 = True
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
if name == self.model_name2 or file == self.model_url2:
add_model2 = False
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading ScuNET model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if add_model2:
scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self)
scalers.append(scaler_data2)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
device = shared.device
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device)
img = img.to(device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
device = shared.device
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
return None
model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.to(device)
return model
# -*- coding: utf-8 -*-
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_, DropPath
class WMSA(nn.Module):
""" Self-attention module in Swin Transformer
"""
def __init__(self, input_dim, output_dim, head_dim, window_size, type):
super(WMSA, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.head_dim = head_dim
self.scale = self.head_dim ** -0.5
self.n_heads = input_dim // head_dim
self.window_size = window_size
self.type = type
self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
self.relative_position_params = nn.Parameter(
torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads))
self.linear = nn.Linear(self.input_dim, self.output_dim)
trunc_normal_(self.relative_position_params, std=.02)
self.relative_position_params = torch.nn.Parameter(
self.relative_position_params.view(2 * window_size - 1, 2 * window_size - 1, self.n_heads).transpose(1,
2).transpose(
0, 1))
def generate_mask(self, h, w, p, shift):
""" generating the mask of SW-MSA
Args:
shift: shift parameters in CyclicShift.
Returns:
attn_mask: should be (1 1 w p p),
"""
# supporting sqaure.
attn_mask = torch.zeros(h, w, p, p, p, p, dtype=torch.bool, device=self.relative_position_params.device)
if self.type == 'W':
return attn_mask
s = p - shift
attn_mask[-1, :, :s, :, s:, :] = True
attn_mask[-1, :, s:, :, :s, :] = True
attn_mask[:, -1, :, :s, :, s:] = True
attn_mask[:, -1, :, s:, :, :s] = True
attn_mask = rearrange(attn_mask, 'w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)')
return attn_mask
def forward(self, x):
""" Forward pass of Window Multi-head Self-attention module.
Args:
x: input tensor with shape of [b h w c];
attn_mask: attention mask, fill -inf where the value is True;
Returns:
output: tensor shape [b h w c]
"""
if self.type != 'W': x = torch.roll(x, shifts=(-(self.window_size // 2), -(self.window_size // 2)), dims=(1, 2))
x = rearrange(x, 'b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c', p1=self.window_size, p2=self.window_size)
h_windows = x.size(1)
w_windows = x.size(2)
# sqaure validation
# assert h_windows == w_windows
x = rearrange(x, 'b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c', p1=self.window_size, p2=self.window_size)
qkv = self.embedding_layer(x)
q, k, v = rearrange(qkv, 'b nw np (threeh c) -> threeh b nw np c', c=self.head_dim).chunk(3, dim=0)
sim = torch.einsum('hbwpc,hbwqc->hbwpq', q, k) * self.scale
# Adding learnable relative embedding
sim = sim + rearrange(self.relative_embedding(), 'h p q -> h 1 1 p q')
# Using Attn Mask to distinguish different subwindows.
if self.type != 'W':
attn_mask = self.generate_mask(h_windows, w_windows, self.window_size, shift=self.window_size // 2)
sim = sim.masked_fill_(attn_mask, float("-inf"))
probs = nn.functional.softmax(sim, dim=-1)
output = torch.einsum('hbwij,hbwjc->hbwic', probs, v)
output = rearrange(output, 'h b w p c -> b w p (h c)')
output = self.linear(output)
output = rearrange(output, 'b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c', w1=h_windows, p1=self.window_size)
if self.type != 'W': output = torch.roll(output, shifts=(self.window_size // 2, self.window_size // 2),
dims=(1, 2))
return output
def relative_embedding(self):
cord = torch.tensor(np.array([[i, j] for i in range(self.window_size) for j in range(self.window_size)]))
relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
# negative is allowed
return self.relative_position_params[:, relation[:, :, 0].long(), relation[:, :, 1].long()]
class Block(nn.Module):
def __init__(self, input_dim, output_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer Block
"""
super(Block, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
assert type in ['W', 'SW']
self.type = type
if input_resolution <= window_size:
self.type = 'W'
self.ln1 = nn.LayerNorm(input_dim)
self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.ln2 = nn.LayerNorm(input_dim)
self.mlp = nn.Sequential(
nn.Linear(input_dim, 4 * input_dim),
nn.GELU(),
nn.Linear(4 * input_dim, output_dim),
)
def forward(self, x):
x = x + self.drop_path(self.msa(self.ln1(x)))
x = x + self.drop_path(self.mlp(self.ln2(x)))
return x
class ConvTransBlock(nn.Module):
def __init__(self, conv_dim, trans_dim, head_dim, window_size, drop_path, type='W', input_resolution=None):
""" SwinTransformer and Conv Block
"""
super(ConvTransBlock, self).__init__()
self.conv_dim = conv_dim
self.trans_dim = trans_dim
self.head_dim = head_dim
self.window_size = window_size
self.drop_path = drop_path
self.type = type
self.input_resolution = input_resolution
assert self.type in ['W', 'SW']
if self.input_resolution <= self.window_size:
self.type = 'W'
self.trans_block = Block(self.trans_dim, self.trans_dim, self.head_dim, self.window_size, self.drop_path,
self.type, self.input_resolution)
self.conv1_1 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv1_2 = nn.Conv2d(self.conv_dim + self.trans_dim, self.conv_dim + self.trans_dim, 1, 1, 0, bias=True)
self.conv_block = nn.Sequential(
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
nn.ReLU(True),
nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False)
)
def forward(self, x):
conv_x, trans_x = torch.split(self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1)
conv_x = self.conv_block(conv_x) + conv_x
trans_x = Rearrange('b c h w -> b h w c')(trans_x)
trans_x = self.trans_block(trans_x)
trans_x = Rearrange('b h w c -> b c h w')(trans_x)
res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
x = x + res
return x
class SCUNet(nn.Module):
# def __init__(self, in_nc=3, config=[2, 2, 2, 2, 2, 2, 2], dim=64, drop_path_rate=0.0, input_resolution=256):
def __init__(self, in_nc=3, config=None, dim=64, drop_path_rate=0.0, input_resolution=256):
super(SCUNet, self).__init__()
if config is None:
config = [2, 2, 2, 2, 2, 2, 2]
self.config = config
self.dim = dim
self.head_dim = 32
self.window_size = 8
# drop path rate for each layer
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
begin = 0
self.m_down1 = [ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[0])] + \
[nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
begin += config[0]
self.m_down2 = [ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[1])] + \
[nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
begin += config[1]
self.m_down3 = [ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[2])] + \
[nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
begin += config[2]
self.m_body = [ConvTransBlock(4 * dim, 4 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 8)
for i in range(config[3])]
begin += config[3]
self.m_up3 = [nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(2 * dim, 2 * dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 4)
for i in range(config[4])]
begin += config[4]
self.m_up2 = [nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim, dim, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution // 2)
for i in range(config[5])]
begin += config[5]
self.m_up1 = [nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False), ] + \
[ConvTransBlock(dim // 2, dim // 2, self.head_dim, self.window_size, dpr[i + begin],
'W' if not i % 2 else 'SW', input_resolution)
for i in range(config[6])]
self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
self.m_head = nn.Sequential(*self.m_head)
self.m_down1 = nn.Sequential(*self.m_down1)
self.m_down2 = nn.Sequential(*self.m_down2)
self.m_down3 = nn.Sequential(*self.m_down3)
self.m_body = nn.Sequential(*self.m_body)
self.m_up3 = nn.Sequential(*self.m_up3)
self.m_up2 = nn.Sequential(*self.m_up2)
self.m_up1 = nn.Sequential(*self.m_up1)
self.m_tail = nn.Sequential(*self.m_tail)
# self.apply(self._init_weights)
def forward(self, x0):
h, w = x0.size()[-2:]
paddingBottom = int(np.ceil(h / 64) * 64 - h)
paddingRight = int(np.ceil(w / 64) * 64 - w)
x0 = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x0)
x1 = self.m_head(x0)
x2 = self.m_down1(x1)
x3 = self.m_down2(x2)
x4 = self.m_down3(x3)
x = self.m_body(x4)
x = self.m_up3(x + x4)
x = self.m_up2(x + x3)
x = self.m_up1(x + x2)
x = self.m_tail(x + x1)
x = x[..., :h, :w]
return x
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
\ No newline at end of file
......@@ -6,244 +6,41 @@ import torch
import numpy as np
from torch import einsum
from modules import prompt_parser
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
def apply_optimizations():
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
class StableDiffusionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = None
comments = []
dir_mtime = None
layers = None
circular_enabled = False
clip = None
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
tokenizer = model.cond_stage_model.tokenizer
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
self.word_embeddings[name] = emb.detach().to(device)
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dirname):
try:
fullfn = os.path.join(dirname, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
......@@ -253,12 +50,7 @@ class StableDiffusionModelHijack:
self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
apply_optimizations()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
......@@ -296,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
......@@ -317,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
......@@ -339,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if possible_matches is None:
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
found = False
for ids, word in possible_matches:
if tokens[i:i + len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
......@@ -431,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
elif possible_matches is None:
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
......@@ -464,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
......@@ -484,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
......@@ -517,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids)
if batch_fixes is not None:
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes:
emb = self.embeddings.word_embeddings[word]
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
return inputs_embeds
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
vecs.append(tensor)
return inputs_embeds
return torch.stack(vecs)
def add_circular_option_to_conv_2d():
......
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
......@@ -8,14 +8,11 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import shared, modelloader
from modules import shared, modelloader, devices
from modules.paths import models_path
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
......@@ -30,12 +27,10 @@ except Exception:
pass
def setup_model(dirname):
global user_dir
user_dir = dirname
def setup_model():
if not os.path.exists(model_path):
os.makedirs(model_path)
checkpoints_list.clear()
list_models()
......@@ -45,13 +40,13 @@ def checkpoint_tiles():
def list_models():
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
if user_dir is not None and abspath.startswith(user_dir):
name = abspath.replace(user_dir, '')
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
......@@ -69,7 +64,7 @@ def list_models():
h = model_hash(cmd_ckpt)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.sd_model_checkpoint = title
shared.opts.data['sd_model_checkpoint'] = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
......@@ -106,8 +101,11 @@ def select_checkpoint():
if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
if shared.cmd_opts.ckpt is not None:
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
print(f" - directory {model_path}", file=sys.stderr)
if shared.cmd_opts.ckpt_dir is not None:
print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)
......@@ -134,6 +132,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
if not shared.cmd_opts.no_half:
model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file
......
......@@ -77,7 +77,9 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
state.sampling_step = 0
for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
if state.interrupted:
break
......@@ -207,7 +209,9 @@ def extended_trange(sampler, count, *args, **kwargs):
state.sampling_steps = count
state.sampling_step = 0
for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
if state.interrupted:
break
......
......@@ -40,6 +40,7 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(model_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
......@@ -57,6 +58,9 @@ parser.add_argument("--opt-channelslast", action='store_true', help="change memo
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
cmd_opts = parser.parse_args()
device = get_optimal_device()
......@@ -78,6 +82,7 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
textinfo = None
def interrupt(self):
self.interrupted = True
......@@ -88,7 +93,7 @@ class State:
self.current_image_sampling_step = 0
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
state = State()
......@@ -318,14 +323,14 @@ class TotalTQDM:
)
def update(self):
if not opts.multiple_tqdm:
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
self._tqdm.update()
def updateTotal(self, new_total):
if not opts.multiple_tqdm:
if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars:
return
if self._tqdm is None:
self.reset()
......
......@@ -5,6 +5,7 @@ import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader
from modules.paths import models_path
......@@ -122,18 +123,20 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device)
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
import tqdm
class PersonalizedBase(Dataset):
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
image = Image.open(path)
image = image.convert('RGB')
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
filename_tokens = [token for token in filename_tokens if token.isalpha()]
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
self.dataset.append((init_latent, filename_tokens))
self.length = len(self.dataset) * repeats
self.initial_indexes = np.arange(self.length) % len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def __len__(self):
return self.length
def __getitem__(self, i):
if i % len(self.dataset) == 0:
self.shuffle()
index = self.indexes[i % len(self.indexes)]
x, filename_tokens = self.dataset[index]
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens))
return x, text
import os
import sys
import traceback
import torch
import tqdm
import html
import datetime
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
def save(self, filename):
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
"sd_checkpoint": self.sd_checkpoint,
"sd_checkpoint_name": self.sd_checkpoint_name,
}
torch.save(embedding_data, filename)
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
return embedding
def load_textual_inversion_embeddings(self):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('hash', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
try:
fullfn = os.path.join(self.embeddings_dir, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
return None, None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
return embedding, len(ids)
return None, None
def create_embedding(name, num_vectors_per_token, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
embedding.save(fn)
return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)
else:
embedding_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
if embedding.step > steps:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = cond_model([text])
loss = shared.sd_model(x.unsqueeze(0), c)[0]
losses[embedding.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=text,
steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
checkpoint = sd_models.select_checkpoint()
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
embedding.save(filename)
return embedding, filename
import html
import gradio as gr
import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared
def create_embedding(name, initialization_text, nvpt):
filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
def train_embedding(*args):
try:
sd_hijack.undo_optimizations()
embedding, filename = ti.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
return res, ""
except Exception:
raise
finally:
sd_hijack.apply_optimizations()
......@@ -34,7 +34,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
denoising_strength=denoising_strength if enable_hr else None,
)
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None:
......
......@@ -11,6 +11,7 @@ import time
import traceback
import platform
import subprocess as sp
from functools import reduce
import numpy as np
import torch
......@@ -21,6 +22,7 @@ import gradio as gr
import gradio.utils
import gradio.routes
from modules import sd_hijack
from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
......@@ -32,6 +34,9 @@ import modules.gfpgan_model
import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
from modules.images import apply_filename_pattern, get_next_sequence_number
import modules.textual_inversion.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
......@@ -94,14 +99,31 @@ def send_gradio_gallery_to_image(x):
def save_files(js_data, images, index):
import csv
os.makedirs(opts.outdir_save, exist_ok=True)
import csv
filenames = []
#quick dictionary to class object conversion. Its neccesary due apply_filename_pattern requiring it
class MyObject:
def __init__(self, d=None):
if d is not None:
for key, value in d.items():
setattr(self, key, value)
data = json.loads(js_data)
p = MyObject(data)
path = opts.outdir_save
save_to_dirs = opts.save_to_dirs
if save_to_dirs:
dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, p.seed, p.prompt)
path = os.path.join(opts.outdir_save, dirname)
os.makedirs(path, exist_ok=True)
if index > -1 and opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only
images = [images[index]]
infotexts = [data["infotexts"][index]]
else:
......@@ -113,11 +135,20 @@ def save_files(js_data, images, index):
if at_start:
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
filename_base = str(int(time.time() * 1000))
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
if file_decoration != "":
file_decoration = "-" + file_decoration.lower()
file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt)
truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration
filename_base = truncated
extension = opts.samples_format.lower()
basecount = get_next_sequence_number(path, "")
for i, filedata in enumerate(images):
filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
filepath = os.path.join(opts.outdir_save, filename)
file_number = f"{basecount+i:05}"
filename = file_number + filename_base + f".{extension}"
filepath = os.path.join(path, filename)
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
......@@ -142,8 +173,8 @@ def save_files(js_data, images, index):
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
def wrap_gradio_call(func):
def f(*args, **kwargs):
def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
......@@ -159,7 +190,10 @@ def wrap_gradio_call(func):
shared.state.job = ""
shared.state.job_count = 0
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
......@@ -179,6 +213,7 @@ def wrap_gradio_call(func):
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
shared.state.interrupted = False
shared.state.job_count = 0
return tuple(res)
......@@ -187,7 +222,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part):
if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False)
return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0
......@@ -219,13 +254,19 @@ def check_progress_call(id_part):
else:
preview_visibility = gr_show(True)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
if shared.state.textinfo is not None:
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
else:
textinfo_result = gr_show(False)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part):
shared.state.job_count = -1
shared.state.current_latent = None
shared.state.current_image = None
shared.state.textinfo = None
return check_progress_call(id_part)
......@@ -345,8 +386,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component]
)
def update_token_counter(text):
tokens, token_count, max_length = model_hijack.tokenize(text)
def update_token_counter(text, steps):
prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step,prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>"
......@@ -364,8 +408,7 @@ def create_toprow(is_img2img):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
......@@ -396,16 +439,19 @@ def create_toprow(is_img2img):
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create")
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
if textinfo is None:
textinfo = gr.HTML(visible=False)
def setup_progressbar(progressbar, preview, id_part):
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click(
fn=lambda: check_progress_call(id_part),
show_progress=False,
inputs=[],
outputs=[progressbar, preview, preview],
outputs=[progressbar, preview, preview, textinfo],
)
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
......@@ -413,13 +459,16 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part),
show_progress=False,
inputs=[],
outputs=[progressbar, preview, preview],
outputs=[progressbar, preview, preview, textinfo],
)
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
with gr.Row(elem_id='txt2img_progress_row'):
......@@ -483,7 +532,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
fn=txt2img,
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
txt2img_prompt,
......@@ -539,6 +588,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click(
fn=roll_artist,
_js="update_txt2img_tokens",
inputs=[
txt2img_prompt,
],
......@@ -567,9 +617,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
]
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True)
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'):
with gr.Column(scale=1):
......@@ -675,7 +726,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
)
img2img_args = dict(
fn=img2img,
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img",
inputs=[
dummy_component,
......@@ -743,6 +794,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
roll.click(
fn=roll_artist,
_js="update_img2img_tokens",
inputs=[
img2img_prompt,
],
......@@ -753,6 +805,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)]
style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
button.click(
......@@ -764,9 +817,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2],
)
for button, (prompt, negative_prompt), (style1, style2) in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns):
for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
button.click(
fn=apply_styles,
_js=js_func,
inputs=[prompt, negative_prompt, style1, style2],
outputs=[prompt, negative_prompt, style1, style2],
)
......@@ -789,6 +843,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(denoising_strength, "Denoising strength"),
]
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False):
......@@ -828,7 +883,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
fn=run_extras,
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
......@@ -878,7 +933,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change(
fn=wrap_gradio_call(run_pnginfo),
fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
......@@ -887,7 +942,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
......@@ -896,10 +951,98 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks() as textual_inversion_interface:
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
new_embedding_name = gr.Textbox(label="Name")
initialization_text = gr.Textbox(label="Initialization text", value="*")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
with gr.Row():
with gr.Column(scale=2):
gr.HTML(value="")
with gr.Column():
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_embedding = gr.Button(value="Train", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
initialization_text,
nvpt,
],
outputs=[
train_embedding_name,
ti_output,
ti_outcome,
]
)
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
learn_rate,
dataset_directory,
log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
......@@ -1036,6 +1179,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(textual_inversion_interface, "Textual inversion", "ti"),
(settings_interface, "Settings", "settings"),
]
......@@ -1071,11 +1215,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def modelmerger(*args):
try:
results = run_modelmerger(*args)
results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() #To remove the potentially missing models from the list
modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results
......
......@@ -13,14 +13,12 @@ Pillow
pytorch_lightning
realesrgan
scikit-image>=0.19
git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379
timm==0.4.12
transformers==4.19.2
torch
einops
jsonmerge
clean-fid
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right
torchdiffeq
kornia
......@@ -18,7 +18,6 @@ piexif==1.1.3
einops==0.4.1
jsonmerge==1.8.0
clean-fid==0.1.29
git+https://github.com/openai/CLIP@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
resize-right==0.0.2
torchdiffeq==0.2.3
kornia==0.6.7
......@@ -157,7 +157,7 @@ button{
max-width: 10em;
}
#txt2img_preview, #img2img_preview{
#txt2img_preview, #img2img_preview, #ti_preview{
position: absolute;
width: 320px;
left: 0;
......@@ -172,18 +172,18 @@ button{
}
@media screen and (min-width: 768px) {
#txt2img_preview, #img2img_preview {
#txt2img_preview, #img2img_preview, #ti_preview {
position: absolute;
}
}
@media screen and (max-width: 767px) {
#txt2img_preview, #img2img_preview {
#txt2img_preview, #img2img_preview, #ti_preview {
position: relative;
}
}
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
display: none;
}
......@@ -247,7 +247,7 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{
}
#txt2img_progressbar, #img2img_progressbar{
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute;
z-index: 1000;
right: 0;
......
a painting, art by [name]
a rendering, art by [name]
a cropped painting, art by [name]
the painting, art by [name]
a clean painting, art by [name]
a dirty painting, art by [name]
a dark painting, art by [name]
a picture, art by [name]
a cool painting, art by [name]
a close-up painting, art by [name]
a bright painting, art by [name]
a cropped painting, art by [name]
a good painting, art by [name]
a close-up painting, art by [name]
a rendition, art by [name]
a nice painting, art by [name]
a small painting, art by [name]
a weird painting, art by [name]
a large painting, art by [name]
a painting of [filewords], art by [name]
a rendering of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
the painting of [filewords], art by [name]
a clean painting of [filewords], art by [name]
a dirty painting of [filewords], art by [name]
a dark painting of [filewords], art by [name]
a picture of [filewords], art by [name]
a cool painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a bright painting of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
a good painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a rendition of [filewords], art by [name]
a nice painting of [filewords], art by [name]
a small painting of [filewords], art by [name]
a weird painting of [filewords], art by [name]
a large painting of [filewords], art by [name]
a photo of a [name]
a rendering of a [name]
a cropped photo of the [name]
the photo of a [name]
a photo of a clean [name]
a photo of a dirty [name]
a dark photo of the [name]
a photo of my [name]
a photo of the cool [name]
a close-up photo of a [name]
a bright photo of the [name]
a cropped photo of a [name]
a photo of the [name]
a good photo of the [name]
a photo of one [name]
a close-up photo of the [name]
a rendition of the [name]
a photo of the clean [name]
a rendition of a [name]
a photo of a nice [name]
a good photo of a [name]
a photo of the nice [name]
a photo of the small [name]
a photo of the weird [name]
a photo of the large [name]
a photo of a cool [name]
a photo of a small [name]
a photo of a [name], [filewords]
a rendering of a [name], [filewords]
a cropped photo of the [name], [filewords]
the photo of a [name], [filewords]
a photo of a clean [name], [filewords]
a photo of a dirty [name], [filewords]
a dark photo of the [name], [filewords]
a photo of my [name], [filewords]
a photo of the cool [name], [filewords]
a close-up photo of a [name], [filewords]
a bright photo of the [name], [filewords]
a cropped photo of a [name], [filewords]
a photo of the [name], [filewords]
a good photo of the [name], [filewords]
a photo of one [name], [filewords]
a close-up photo of the [name], [filewords]
a rendition of the [name], [filewords]
a photo of the clean [name], [filewords]
a rendition of a [name], [filewords]
a photo of a nice [name], [filewords]
a good photo of a [name], [filewords]
a photo of the nice [name], [filewords]
a photo of the small [name], [filewords]
a photo of the weird [name], [filewords]
a photo of the large [name], [filewords]
a photo of a cool [name], [filewords]
a photo of a small [name], [filewords]
......@@ -6,30 +6,29 @@ from modules import devices
from modules.paths import script_path
import signal
import threading
import modules.paths
import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan
import modules.bsrgan_model as bsrgan
import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
import modules.ldsr_model as ldsr
import modules.lowvram
import modules.realesrgan_model as realesrgan
import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.shared as shared
import modules.swinir_model as swinir
import modules.txt2img
import modules.ui
from modules import devices
from modules import modelloader
from modules.paths import script_path
from modules.shared import cmd_opts
modelloader.cleanup_models()
modules.sd_models.setup_model(cmd_opts.ckpt_dir)
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
......@@ -47,7 +46,7 @@ def wrap_queued_call(func):
return f
def wrap_gradio_gpu_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
devices.torch_gc()
......@@ -59,6 +58,7 @@ def wrap_gradio_gpu_call(func):
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.interrupted = False
shared.state.textinfo = None
with queue_lock:
res = func(*args, **kwargs)
......@@ -70,7 +70,7 @@ def wrap_gradio_gpu_call(func):
return res
return modules.ui.wrap_gradio_call(f)
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
......@@ -89,15 +89,8 @@ def webui():
while 1:
demo = modules.ui.create_ui(
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo,
run_modelmerger=modules.extras.run_modelmerger
)
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
demo.launch(
share=cmd_opts.share,
server_name="0.0.0.0" if cmd_opts.listen else None,
......@@ -123,5 +116,6 @@ def webui():
print('Restarting Gradio')
if __name__ == "__main__":
webui()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment