Commit 2e8ba0fa authored by Greendayle

fix conflicts

parents 5f12e7ef 4f33289d
......@@ -16,7 +16,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- Attention, specify parts of text that the model should pay more attention to
- a man in a ((tuxedo)) - will pay more attention to tuxedo
- a man in a (tuxedo:1.21) - alternative syntax (see the parser sketch after this list)
- select text and press ctrl+up or ctrl+down to automatically adjust attention to selected text (code contributed by anonymous user)
- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
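As a minimal illustration of the emphasis syntax above, assuming only the single-level `(word)` and `(word:weight)` forms (the repository's real parser is `parse_prompt_attention` in `modules/prompt_parser.py`, which also handles nesting, `[de-emphasis]` and escapes):

import re

def parse_attention(prompt):
    # bare (word) raises attention by a factor of 1.1; (word:1.21) uses the explicit weight
    pattern = re.compile(r"\(([^():]+)(?::([\d.]+))?\)")
    pieces, pos = [], 0
    for m in pattern.finditer(prompt):
        if m.start() > pos:
            pieces.append((prompt[pos:m.start()], 1.0))
        pieces.append((m.group(1), float(m.group(2)) if m.group(2) else 1.1))
        pos = m.end()
    if pos < len(prompt):
        pieces.append((prompt[pos:], 1.0))
    return pieces

print(parse_attention("a man in a (tuxedo:1.21)"))
# [('a man in a ', 1.0), ('tuxedo', 1.21)]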
......@@ -65,6 +65,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- [Composable-Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/), a way to use multiple prompts at once
- separate prompts using uppercase `AND`
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2` (see the sketch after this list)
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
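A similar toy sketch for the `AND` syntax, assuming only the flat `text :weight` form shown above (the real splitting and weighting live in `modules/prompt_parser.py`):

import re

def split_composable(prompt):
    parts = []
    for sub in prompt.split(" AND "):
        m = re.fullmatch(r"(.*?)\s*:\s*([\d.]+)\s*", sub)
        if m:
            parts.append((m.group(1), float(m.group(2))))
        else:
            parts.append((sub, 1.0))  # an omitted weight defaults to 1
    return parts

print(split_composable("a cat :1.2 AND a dog AND a penguin :2.2"))
# [('a cat', 1.2), ('a dog', 1.0), ('a penguin', 2.2)]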
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
......
contextMenuInit = function(){
let eventListenerApplied=false;
let menuSpecs = new Map();
const uid = function(){
return Date.now().toString(36) + Math.random().toString(36).substr(2);
}
function showContextMenu(event,element,menuEntries){
let posx = event.clientX + document.body.scrollLeft + document.documentElement.scrollLeft;
let posy = event.clientY + document.body.scrollTop + document.documentElement.scrollTop;
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
oldMenu.remove()
}
let tabButton = gradioApp().querySelector('button')
let baseStyle = window.getComputedStyle(tabButton)
const contextMenu = document.createElement('nav')
contextMenu.id = "context-menu"
contextMenu.style.background = baseStyle.background
contextMenu.style.color = baseStyle.color
contextMenu.style.fontFamily = baseStyle.fontFamily
contextMenu.style.top = posy+'px'
contextMenu.style.left = posx+'px'
const contextMenuList = document.createElement('ul')
contextMenuList.className = 'context-menu-items';
contextMenu.append(contextMenuList);
menuEntries.forEach(function(entry){
let contextMenuEntry = document.createElement('a')
contextMenuEntry.innerHTML = entry['name']
contextMenuEntry.addEventListener("click", function(e) {
entry['func']();
})
contextMenuList.append(contextMenuEntry);
})
gradioApp().getRootNode().appendChild(contextMenu)
let menuWidth = contextMenu.offsetWidth + 4;
let menuHeight = contextMenu.offsetHeight + 4;
let windowWidth = window.innerWidth;
let windowHeight = window.innerHeight;
if ( (windowWidth - posx) < menuWidth ) {
contextMenu.style.left = windowWidth - menuWidth + "px";
}
if ( (windowHeight - posy) < menuHeight ) {
contextMenu.style.top = windowHeight - menuHeight + "px";
}
}
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
let currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){
currentItems = []
menuSpecs.set(targetElementSelector,currentItems);
}
let newItem = {'id':targetElementSelector+'_'+uid(),
'name':entryName,
'func':entryFunction,
'isNew':true}
currentItems.push(newItem)
return newItem['id']
}
function removeContextMenuOption(uid){
menuSpecs.forEach(function(v,k) {
let index = -1
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
if(index>=0){
v.splice(index, 1);
}
})
}
function addContextMenuEventListener(){
if(eventListenerApplied){
return;
}
gradioApp().addEventListener("click", function(e) {
let source = e.composedPath()[0]
if(source.id && source.id.indexOf('check_progress')>-1){
return
}
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
oldMenu.remove()
}
});
gradioApp().addEventListener("contextmenu", function(e) {
let oldMenu = gradioApp().querySelector('#context-menu')
if(oldMenu){
oldMenu.remove()
}
menuSpecs.forEach(function(v,k) {
if(e.composedPath()[0].matches(k)){
showContextMenu(e,e.composedPath()[0],v)
e.preventDefault()
return
}
})
});
eventListenerApplied=true
}
return [appendContextMenuOption, removeContextMenuOption, addContextMenuEventListener]
}
initResponse = contextMenuInit()
appendContextMenuOption = initResponse[0]
removeContextMenuOption = initResponse[1]
addContextMenuEventListener = initResponse[2]
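// Usage sketch (hypothetical selector and entry name): appendContextMenuOption
// returns an id that can later be handed to removeContextMenuOption:
//   let entryId = appendContextMenuOption('#my_button', 'My entry', function(){ /* ... */ });
//   removeContextMenuOption(entryId);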
//Start example Context Menu Items
generateOnRepeatId = appendContextMenuOption('#txt2img_generate','Generate forever',function(){
let genbutton = gradioApp().querySelector('#txt2img_generate');
let interruptbutton = gradioApp().querySelector('#txt2img_interrupt');
if(!interruptbutton.offsetParent){
genbutton.click();
}
clearInterval(window.generateOnRepeatInterval)
window.generateOnRepeatInterval = setInterval(function(){
if(!interruptbutton.offsetParent){
genbutton.click();
}
},
500)}
)
cancelGenerateForever = function(){
clearInterval(window.generateOnRepeatInterval)
let interruptbutton = gradioApp().querySelector('#txt2img_interrupt');
if(interruptbutton.offsetParent){
interruptbutton.click();
}
}
appendContextMenuOption('#txt2img_interrupt','Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#txt2img_generate', 'Cancel generate forever',cancelGenerateForever)
appendContextMenuOption('#roll','Roll three',
function(){
let rollbutton = gradioApp().querySelector('#roll');
setTimeout(function(){rollbutton.click()},100)
setTimeout(function(){rollbutton.click()},200)
setTimeout(function(){rollbutton.click()},300)
}
)
//End example Context Menu Items
onUiUpdate(function(){
addContextMenuEventListener()
});
addEventListener('keydown', (event) => {
let target = event.originalTarget || event.composedPath()[0];
if (!target.hasAttribute("placeholder")) return;
if (!target.placeholder.toLowerCase().includes("prompt")) return;
......
......@@ -35,6 +35,7 @@ titles = {
"Denoising strength": "Determines how little respect the algorithm should have for image's content. At 0, nothing will change, and at 1 you'll get an unrelated image. With values below 1.0, processing will take less steps than the Sampling Steps slider specifies.",
"Denoising strength change factor": "In loopback mode, on each loop the denoising strength is multiplied by this value. <1 means decreasing variety so your sequence will converge on a fixed picture. >1 means increasing variety so your sequence will become more and more chaotic.",
"Skip": "Stop processing current image and continue processing.",
"Interrupt": "Stop processing images and return any results accumulated so far.",
"Save": "Write image to a directory (default - log/images) and generation parameters into csv file.",
......
......@@ -86,6 +86,9 @@ function showGalleryImage(){
if(fullImg_preview != null){
fullImg_preview.forEach(function function_name(e) {
if (e.dataset.modded)
return;
e.dataset.modded = true;
if(e && e.parentElement.tagName == 'DIV'){
e.style.cursor='pointer'
......
// code related to showing and updating progressbar shown as the image is being made
global_progressbars = {}
function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){
var progressbar = gradioApp().getElementById(id_progressbar)
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt)
if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){
......@@ -32,30 +33,37 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
if(!progressDiv){
if (skip) {
skip.style.display = "none"
}
interrupt.style.display = "none"
}
}
window.setTimeout(function() { requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) }, 500)
});
mutationObserver.observe( progressbar, { childList:true, subtree:true })
}
}
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})
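// requestMoreProgress below polls the backend by clicking the hidden Gradio
// button "<id_part>_check_progress", then re-shows the Skip/Interrupt buttons
// while a progress span is still present in the DOM.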
function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){
btn = gradioApp().getElementById(id_part+"_check_progress");
if(btn==null) return;
btn.click();
var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0;
var skip = id_skip ? gradioApp().getElementById(id_skip) : null
var interrupt = gradioApp().getElementById(id_interrupt)
if(progressDiv && interrupt){
if (skip) {
skip.style.display = "block"
}
interrupt.style.display = "block"
}
}
......
......@@ -4,6 +4,7 @@ import os
import sys
import importlib.util
import shlex
import platform
dir_repos = "repositories"
dir_tmp = "tmp"
......@@ -31,6 +32,7 @@ def extract_arg(args, name):
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
xformers = '--xformers' in args
def repo_dir(name):
......@@ -124,6 +126,12 @@ if not is_installed("gfpgan"):
if not is_installed("clip"):
run_pip(f"install {clip_package}", "clip")
if not is_installed("xformers") and xformers:
if platform.system() == "Windows":
run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
elif platform.system() == "Linux":
run_pip("install xformers", "xformers")
os.makedirs(dir_repos, exist_ok=True)
git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
......
......@@ -111,7 +111,7 @@ class UpscalerESRGAN(Upscaler):
print("Unable to load %s from %s" % (self.model_path, filename))
return None
pretrained_net = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
......
......@@ -43,7 +43,7 @@ class Hypernetwork:
def load_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
try:
hn = Hypernetwork(filename)
res[hn.name] = hn
......
......@@ -32,6 +32,8 @@ def process_batch(p, input_dir, output_dir, args):
for i, image in enumerate(images):
    state.job = f"{i+1} out of {len(images)}"

    if state.skipped:
        state.skipped = False

    if state.interrupted:
        break
......
......@@ -141,6 +141,7 @@ class Processed:
self.all_subseeds = all_subseeds or [self.subseed]
self.infotexts = infotexts or [info]
def js(self):
obj = {
"prompt": self.prompt,
......@@ -312,6 +313,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
comments = {}
......@@ -349,6 +351,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
state.job_count = p.n_iter
for n in range(p.n_iter):
    if state.skipped:
        state.skipped = False

    if state.interrupted:
        break
......@@ -375,7 +380,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
if state.interrupted or state.skipped:

    # if we are interrupted, sample returns just noise
    # use the image collected previously in sampler loop
......
......@@ -239,6 +239,15 @@ def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
conds_list.append(conds_for_batch)
# if prompts have wildly different lengths above the limit we'll get tensors of different shapes
# and won't be able to torch.stack them. So this fixes that.
token_count = max([x.shape[0] for x in tensors])
for i in range(len(tensors)):
    if tensors[i].shape[0] != token_count:
        last_vector = tensors[i][-1:]
        last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
        tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
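# --- illustrative sketch, not part of the commit: the padding above run on
# dummy tensors of assumed shape [tokens, emb_dim]; the shorter conditioning
# is extended by repeating its last vector so torch.stack can batch them ---
import torch

demo = [torch.ones(77, 4), torch.ones(154, 4)]
longest = max(x.shape[0] for x in demo)
for i in range(len(demo)):
    if demo[i].shape[0] != longest:
        last = demo[i][-1:]
        demo[i] = torch.vstack([demo[i], last.repeat([longest - demo[i].shape[0], 1])])
print(torch.stack(demo).shape)  # torch.Size([2, 154, 4])
# --- end sketch ---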
......
......@@ -18,16 +18,17 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
def apply_optimizations():
undo_optimizations()
ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.xformers and shared.xformers_available and not torch.version.hip:
    ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
elif cmd_opts.opt_split_attention_v1:
    ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
    ldm.modules.attention_CrossAttention_forward = sd_hijack_optimizations.split_cross_attention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
......@@ -37,6 +38,13 @@ def undo_optimizations():
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
def get_target_prompt_token_count(token_count):
    if token_count < 75:
        return 75

    return math.ceil(token_count / 10) * 10
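# --- illustrative sketch, not part of the commit: the helper's implied
# contract checked on a few values; short prompts still pad to the old
# 75-token window, longer ones grow in steps of 10 ---
import math

def _demo_target(token_count):
    if token_count < 75:
        return 75
    return math.ceil(token_count / 10) * 10

assert _demo_target(20) == 75
assert _demo_target(75) == 80
assert _demo_target(150) == 150
# --- end sketch ---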
class StableDiffusionModelHijack:
fixes = None
comments = []
......@@ -82,10 +90,12 @@ class StableDiffusionModelHijack:
for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
layer.padding_mode = 'circular' if enable else 'zeros'
def clear_comments(self):
self.comments = []
def tokenize(self, text):
_, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
......@@ -94,7 +104,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.wrapped = wrapped
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
......@@ -116,7 +125,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
......@@ -148,19 +156,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
used_custom_terms.append((embedding.name, embedding.checksum()))
i += embedding_length_in_tokens
token_count = len(remade_tokens)
prompt_target_length = get_target_prompt_token_count(token_count)
tokens_to_add = prompt_target_length - len(remade_tokens) + 1

remade_tokens = [id_start] + remade_tokens + [id_end] * tokens_to_add
multipliers = [1.0] + multipliers + [1.0] * tokens_to_add
return remade_tokens, fixes, multipliers, token_count
......@@ -177,7 +178,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if line in cache:
    remade_tokens, fixes, multipliers = cache[line]
else:
    remade_tokens, fixes, multipliers, current_token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
    token_count = max(current_token_count, token_count)

    cache[line] = (remade_tokens, fixes, multipliers)
......@@ -191,7 +193,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def process_text_old(self, text):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length  # you get to stay at 77
used_custom_terms = []
remade_batch_tokens = []
overflowing_words = []
......@@ -263,17 +265,24 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
target_token_count = get_target_prompt_token_count(token_count) + 2

# position ids are clamped to CLIP's learned range: 0..75 for the body of the
# prompt (tokens past that reuse position 75), with 76 reserved for the final token
position_ids_array = [min(x, 75) for x in range(target_token_count-1)] + [76]
position_ids = torch.asarray(position_ids_array, device=devices.device).expand((1, -1))

remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)

outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids)
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
batch_multipliers = torch.asarray(batch_multipliers_of_same_length).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
......
import math
import sys
import traceback
import torch
from torch import einsum
......@@ -7,18 +10,37 @@ from einops import rearrange
from modules import shared
if shared.cmd_opts.xformers:
    try:
        import xformers.ops
        import functorch
        xformers._is_functorch_available = True
        shared.xformers_available = True
    except Exception:
        print("Cannot import xformers", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
hypernetwork = shared.selected_hypernetwork()
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

if hypernetwork_layers is not None:
    k_in = self.to_k(hypernetwork_layers[0](context))
    v_in = self.to_v(hypernetwork_layers[1](context))
else:
    k_in = self.to_k(context)
    v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
......@@ -31,6 +53,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
......@@ -105,6 +128,25 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
return self.to_out(r2)
def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)
    hypernetwork = shared.selected_hypernetwork()
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
    if hypernetwork_layers is not None:
        k_in = self.to_k(hypernetwork_layers[0](context))
        v_in = self.to_v(hypernetwork_layers[1](context))
    else:
        k_in = self.to_k(context)
        v_in = self.to_v(context)
    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
    del q_in, k_in, v_in
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)
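# --- illustrative sketch, not part of the commit: shape-level check of the
# call above; assumes a CUDA device and an installed xformers build, with
# q/k/v laid out [batch, tokens, heads, head_dim] as after the rearrange ---
import torch
import xformers.ops

q = k = v = torch.randn(1, 77, 8, 40, device="cuda", dtype=torch.float16)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
print(out.shape)  # torch.Size([1, 77, 8, 40])
# --- end sketch ---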
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
......@@ -167,3 +209,13 @@ def cross_attention_attnblock_forward(self, x):
h3 += x
return h3
def xformers_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_).contiguous()
k1 = self.k(h_).contiguous()
v = self.v(h_).contiguous()
out = xformers.ops.memory_efficient_attention(q1, k1, v)
out = self.proj_out(out)
return x+out
......@@ -122,7 +122,11 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
pl_sd = torch.load(checkpoint_file, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
sd = pl_sd
model.load_state_dict(sd, strict=False)
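# --- illustrative sketch, not part of the commit: the two checkpoint layouts
# the branch above accepts; some .ckpt files wrap weights under "state_dict"
# next to metadata, others are the bare state dict itself ---
wrapped = {"state_dict": {"w": 1}, "global_step": 1000}
bare = {"w": 1}
for pl_sd in (wrapped, bare):
    sd = pl_sd["state_dict"] if "state_dict" in pl_sd else pl_sd
    assert sd == {"w": 1}
# --- end sketch ---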
......
......@@ -106,7 +106,7 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs):
seq = sequence if cmd_opts.disable_console_progressbars else tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
    if state.interrupted or state.skipped:
        break

    yield x
......@@ -142,6 +142,16 @@ class VanillaStableDiffusionSampler:
assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers'
cond = tensor
# for DDIM, shapes must match, we can't just process cond and uncond independently;
# filling unconditional_conditioning with repeats of the last vector to match length is
# not 100% correct but should work well enough
if unconditional_conditioning.shape[1] < cond.shape[1]:
    last_vector = unconditional_conditioning[:, -1:]
    last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1])
    unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated])
elif unconditional_conditioning.shape[1] > cond.shape[1]:
    unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]]
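# --- illustrative sketch, not part of the commit: the same repeat-last-vector
# padding on dummy tensors, here along the token dimension via torch.hstack ---
import torch

uc = torch.randn(1, 77, 768)
cond_len = 154
uc = torch.hstack([uc, uc[:, -1:].repeat([1, cond_len - uc.shape[1], 1])])
print(uc.shape)  # torch.Size([1, 154, 768])
# --- end sketch ---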
if self.mask is not None:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
......@@ -221,18 +231,29 @@ class CFGDenoiser(torch.nn.Module):
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
    cond_in = torch.cat([tensor, uncond])

    if shared.batch_cond_uncond:
        x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
    else:
        x_out = torch.zeros_like(x_in)
        for batch_offset in range(0, x_out.shape[0], batch_size):
            a = batch_offset
            b = a + batch_size
            x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
else:
    x_out = torch.zeros_like(x_in)
    batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
    for batch_offset in range(0, tensor.shape[0], batch_size):
        a = batch_offset
        b = min(a + batch_size, tensor.shape[0])
        x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])

    x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)

denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
for i, conds in enumerate(conds_list):
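# --- illustrative sketch, not part of the commit: why the branch above is
# needed; with no token limit, cond and uncond may be padded to different
# lengths and cannot be concatenated into a single batch ---
import torch

cond_demo = torch.randn(2, 154, 768)   # prompt padded to 154 tokens
uncond_demo = torch.randn(2, 77, 768)  # negative prompt still at 77 tokens
try:
    torch.cat([cond_demo, uncond_demo])
except RuntimeError as err:
    print("processed separately:", err)
# --- end sketch ---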
......@@ -254,7 +275,7 @@ def extended_trange(sampler, count, *args, **kwargs):
seq = range(count) if cmd_opts.disable_console_progressbars else tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs)
for x in seq:
    if state.interrupted or state.skipped:
        break

    if sampler.stop_at is not None and x > sampler.stop_at:
......
......@@ -43,6 +43,7 @@ parser.add_argument("--realesrgan-models-path", type=str, help="Path to director
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
......@@ -73,7 +74,7 @@ device = devices.device
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
......@@ -84,6 +85,7 @@ def selected_hypernetwork():
class State:
skipped = False
interrupted = False
job = ""
job_no = 0
......@@ -96,6 +98,9 @@ class State:
current_image_sampling_step = 0
textinfo = None
def skip(self):
    self.skipped = True

def interrupt(self):
    self.interrupted = True
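# --- illustrative sketch, not part of the commit: a simplified model of how
# the two flags differ in the loops shown earlier (the real checks also run
# inside the samplers) ---
class ToyState:
    skipped = False
    interrupted = False

toy_state = ToyState()

def run(jobs):
    done = []
    for job in jobs:
        if toy_state.skipped:
            toy_state.skipped = False  # consume the skip, keep the batch going
            continue
        if toy_state.interrupted:
            break  # abandon everything that remains
        done.append(job)
    return done
# --- end sketch ---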
......@@ -118,8 +123,6 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
# This was moved to webui.py with the other model "setup" calls.
# modules.sd_models.list_models()
def realesrgan_models_names():
......
......@@ -192,6 +192,7 @@ def wrap_gradio_call(func, extra_outputs=None):
# last item is always HTML
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed_text}</p>{vram_html}</div>"
shared.state.skipped = False
shared.state.interrupted = False
shared.state.job_count = 0
......@@ -417,9 +418,16 @@ def create_toprow(is_img2img):
with gr.Column(scale=1):
with gr.Row():
skip = gr.Button('Skip', elem_id=f"{id_part}_skip")
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
skip.click(
fn=lambda: shared.state.skip(),
inputs=[],
outputs=[],
)
interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
......@@ -952,7 +960,7 @@ def create_ui(wrap_gradio_gpu_call):
custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
save_as_half = gr.Checkbox(value=False, label="Save as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
......
......@@ -25,3 +25,4 @@ lark==1.1.2
git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow]
tensorflow==2.10.0
tensorflow-io==0.27.0
functorch==0.2.1
......@@ -398,10 +398,20 @@ input[type="range"]{
#txt2img_interrupt, #img2img_interrupt{
position: absolute;
width: 50%;
height: 72px;
background: #b4c0cc;
border-radius: 0px;
display: none;
}
#txt2img_skip, #img2img_skip{
position: absolute;
width: 50%;
right: 0px;
height: 72px;
background: #b4c0cc;
border-radius: 0px;
display: none;
}
......@@ -415,4 +425,31 @@ input[type="range"]{
#img2img_image div.h-60{
height: 480px;
}
#context-menu{
z-index:9999;
position:absolute;
display:block;
padding:0px 0;
border:2px solid #a55000;
border-radius:8px;
box-shadow:1px 1px 2px #CE6400;
width: 200px;
}
.context-menu-items{
list-style: none;
margin: 0;
padding: 0;
}
.context-menu-items a{
display:block;
padding:5px;
cursor:pointer;
}
.context-menu-items a:hover{
background: #a55000;
}
......@@ -58,6 +58,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
shared.state.current_latent = None
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.skipped = False
shared.state.interrupted = False
shared.state.textinfo = None
......