Commit 34502809 authored by AUTOMATIC's avatar AUTOMATIC

split codebase into multiple files; to anyone this affects negatively: sorry

parent d7b67d9b
import os
import sys
import traceback
from modules.paths import script_path
from modules.shared import cmd_opts
def gfpgan_model_path():
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
found = [x for x in files if os.path.exists(x)]
if len(found) == 0:
raise Exception("GFPGAN model not found in paths: " + ", ".join(files))
return found[0]
loaded_gfpgan_model = None
def gfpgan():
global loaded_gfpgan_model
if loaded_gfpgan_model is None and gfpgan_constructor is not None:
loaded_gfpgan_model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
return loaded_gfpgan_model
def gfpgan_fix_faces(np_image):
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
have_gfpgan = False
gfpgan_constructor = None
def setup_gfpgan():
try:
gfpgan_model_path()
if os.path.exists(cmd_opts.gfpgan_dir):
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
from gfpgan import GFPGANer
global have_gfpgan
have_gfpgan = True
global gfpgan_constructor
gfpgan_constructor = GFPGANer
except Exception:
print("Error setting up GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
This diff is collapsed.
import math
from PIL import Image
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_name: str, upscale_overlap: int, inpaint_full_res: bool):
is_inpaint = mode == 1
is_loopback = mode == 2
is_upscale = mode == 3
if is_inpaint:
image = init_img_with_mask['image']
mask = init_img_with_mask['mask']
else:
image = init_img
mask = None
assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
p = StableDiffusionProcessingImg2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
prompt=prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
init_images=[image],
mask=mask,
mask_blur=mask_blur,
inpainting_fill=inpainting_fill,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
inpaint_full_res=inpaint_full_res,
extra_generation_params={"Denoising Strength": denoising_strength}
)
if is_loopback:
output_images, info = None, None
history = []
initial_seed = None
initial_info = None
for i in range(n_iter):
p.n_iter = 1
p.batch_size = 1
p.do_not_save_grid = True
state.job = f"Batch {i + 1} out of {n_iter}"
processed = process_images(p)
if initial_seed is None:
initial_seed = processed.seed
initial_info = processed.info
p.init_images = [processed.images[0]]
p.seed = processed.seed + 1
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
history.append(processed.images[0])
grid = images.image_grid(history, batch_size, rows=1)
images.save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
processed = Processed(p, history, initial_seed, initial_info)
elif is_upscale:
initial_seed = None
initial_info = None
upscaler = shared.sd_upscalers.get(upscaler_name, next(iter(shared.sd_upscalers.values())))
img = upscaler(init_img)
processing.torch_gc()
grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)
p.n_iter = 1
p.do_not_save_grid = True
p.do_not_save_samples = True
work = []
work_results = []
for y, h, row in grid.tiles:
for tiledata in row:
work.append(tiledata[2])
batch_count = math.ceil(len(work) / p.batch_size)
print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")
for i in range(batch_count):
p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
state.job = f"Batch {i + 1} out of {batch_count}"
processed = process_images(p)
if initial_seed is None:
initial_seed = processed.seed
initial_info = processed.info
p.seed = processed.seed + 1
work_results += processed.images
image_index = 0
for y, h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
image_index += 1
combined_image = images.combine_grid(grid)
if opts.samples_save:
images.save_image(combined_image, p.outpath_samples, "", initial_seed, prompt, opts.grid_format, info=initial_info)
processed = Processed(p, [combined_image], initial_seed, initial_info)
else:
processed = process_images(p)
return processed.images, processed.js(), plaintext_to_html(processed.info)
import torch
module_in_gpu = None
cpu = torch.device("cpu")
gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu
def setup_for_low_vram(sd_model, use_medvram):
parents = {}
def send_me_to_gpu(module, _):
"""send this module to GPU; send whatever tracked module was previous in GPU to CPU;
we add this as forward_pre_hook to a lot of modules and this way all but one of them will
be in CPU
"""
global module_in_gpu
module = parents.get(module, module)
if module_in_gpu == module:
return
if module_in_gpu is not None:
module_in_gpu.to(cpu)
module.to(gpu)
module_in_gpu = module
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z):
send_me_to_gpu(self, None)
return decoder(z)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
sd_model.to(device)
sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
# register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram:
sd_model.model.register_forward_pre_hook(send_me_to_gpu)
else:
diff_model = sd_model.model.diffusion_model
# the third remaining model is still too big for 4 GB, so we also do the same for its submodules
# so that only one of them is in GPU at a time
stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
sd_model.model.to(device)
diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
# install hooks for bits of third model
diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.input_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
for block in diff_model.output_blocks:
block.register_forward_pre_hook(send_me_to_gpu)
import argparse
import os
import sys
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(0, script_path)
# use current directory as SD dir if it has related files, otherwise parent dir of script as stated in guide
sd_path = os.path.abspath('.') if os.path.exists('./ldm/models/diffusion/ddpm.py') else os.path.dirname(script_path)
# add parent directory to path; this is where Stable diffusion repo should be
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers')
]
for d, must_exist, what in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
else:
sys.path.append(os.path.join(script_path, d))
This diff is collapsed.
import sys
import traceback
from collections import namedtuple
import numpy as np
from PIL import Image
from modules.shared import cmd_opts
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
realesrgan_models = []
have_realesrgan = False
RealESRGANer_constructor = None
def setup_realesrgan():
global realesrgan_models
global have_realesrgan
global RealESRGANer_constructor
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
realesrgan_models = [
RealesrganModelInfo(
name="Real-ESRGAN 4x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
),
RealesrganModelInfo(
name="Real-ESRGAN 4x plus anime 6B",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
netscale=4, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
),
RealesrganModelInfo(
name="Real-ESRGAN 2x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
netscale=2, model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
),
]
have_realesrgan = True
RealESRGANer_constructor = RealESRGANer
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
have_realesrgan = False
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
if not have_realesrgan or RealESRGANer_constructor is None:
return image
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
upsampler = RealESRGANer_constructor(
scale=info.netscale,
model_path=info.location,
model=model,
half=not cmd_opts.no_half
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
image = Image.fromarray(upsampled)
return image
import os
import sys
import traceback
import gradio as gr
class Script:
filename = None
def title(self):
raise NotImplementedError()
scripts = []
def load_scripts(basedir, globs):
for filename in os.listdir(basedir):
path = os.path.join(basedir, filename)
if not os.path.isfile(path):
continue
with open(path, "r", encoding="utf8") as file:
text = file.read()
from types import ModuleType
compiled = compile(text, path, 'exec')
module = ModuleType(filename)
module.__dict__.update(globs)
exec(compiled, module.__dict__)
for key, item in module.__dict__.items():
if type(item) == type and issubclass(item, Script):
item.filename = path
scripts.append(item)
def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
try:
res = func()
return res
except Exception:
print(f"Error calling: {filename/funcname}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return default
def setup_ui():
titles = [wrap_call(script.title, script.filename, "title") for script in scripts]
gr.Dropdown(options=[""] + titles, value="", type="index")
import os
import sys
import traceback
import torch
import numpy as np
from modules.shared import opts, device
class StableDiffusionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = None
comments = []
dir_mtime = None
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
tokenizer = model.cond_stage_model.tokenizer
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path)
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
self.word_embeddings[name] = emb.detach()
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1))&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dirname):
try:
process_file(os.path.join(dirname, fn), fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
for text, ident in tokens_with_parens:
mult = 1.0
for c in text:
if c == '[':
mult /= 1.1
if c == ']':
mult *= 1.1
if c == '(':
mult *= 1.1
if c == ')':
mult /= 1.1
if mult != 1.0:
self.token_mults[ident] = mult
def forward(self, text):
self.hijack.fixes = []
self.hijack.comments = []
remade_batch_tokens = []
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length - 2
used_custom_terms = []
cache = {}
batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
batch_multipliers = []
for tokens in batch_tokens:
tuple_tokens = tuple(tokens)
if tuple_tokens in cache:
remade_tokens, fixes, multipliers = cache[tuple_tokens]
else:
fixes = []
remade_tokens = []
multipliers = []
mult = 1.0
i = 0
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
elif possible_matches is None:
remade_tokens.append(token)
multipliers.append(mult)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
remade_batch_tokens.append(remade_tokens)
self.hijack.fixes.append(fixes)
batch_multipliers.append(multipliers)
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
tokens = torch.asarray(remade_batch_tokens).to(device)
outputs = self.wrapped.transformer(input_ids=tokens)
z = outputs.last_hidden_state
# restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
batch_multipliers = torch.asarray(np.array(batch_multipliers)).to(device)
original_mean = z.mean()
z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
new_mean = z.mean()
z *= original_mean / new_mean
return z
class EmbeddingsWithFixes(torch.nn.Module):
def __init__(self, wrapped, embeddings):
super().__init__()
self.wrapped = wrapped
self.embeddings = embeddings
def forward(self, input_ids):
batch_fixes = self.embeddings.fixes
self.embeddings.fixes = None
inputs_embeds = self.wrapped(input_ids)
if batch_fixes is not None:
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes:
emb = self.embeddings.word_embeddings[word]
emb_len = min(tensor.shape[0]-offset, emb.shape[0])
tensor[offset:offset+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
return inputs_embeds
model_hijack = StableDiffusionModelHijack()
from collections import namedtuple
import torch
import tqdm
import k_diffusion.sampling
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
*[SamplerData(x[0], lambda model, funcname=x[1]: KDiffusionSampler(funcname, model)) for x in [
('Euler a', 'sample_euler_ancestral'),
('Euler', 'sample_euler'),
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('DPM2', 'sample_dpm_2'),
('DPM2 a', 'sample_dpm_2_ancestral'),
] if hasattr(k_diffusion.sampling, x[1])],
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSample, model)),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model)),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
if sampler_wrapper.mask is not None:
img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None
self.mask = None
self.nmask = None
self.init_latent = None
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
# existing code fails with cetin step counts, like 9
try:
self.sampler.make_schedule(ddim_num_steps=p.steps, verbose=False)
except Exception:
self.sampler.make_schedule(ddim_num_steps=p.steps+1, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
return samples
def sample(self, p, x, conditioning, unconditional_conditioning):
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
return samples_ddim
class CFGDenoiser(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.mask = None
self.nmask = None
self.init_latent = None
def forward(self, x, sigma, uncond, cond, cond_scale):
if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
cond_in = torch.cat([uncond, cond])
uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
denoised = uncond + (cond - uncond) * cond_scale
else:
uncond = self.inner_model(x, sigma, cond=uncond)
cond = self.inner_model(x, sigma, cond=cond)
denoised = uncond + (cond - uncond) * cond_scale
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
return denoised
def extended_trange(*args, **kwargs):
for x in tqdm.trange(*args, desc=state.job, **kwargs):
if state.interrupted:
break
yield x
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
sigmas = self.model_wrap.get_sigmas(p.steps)
noise = noise * sigmas[p.steps - t_enc - 1]
xi = x + noise
sigma_sched = sigmas[p.steps - t_enc - 1:]
self.model_wrap_cfg.mask = p.mask
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
def sample(self, p, x, conditioning, unconditional_conditioning):
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
return samples_ddim
import argparse
import json
import os
import gradio as gr
import torch
from modules.paths import script_path, sd_path
config_filename = "config.json"
sd_model_file = os.path.join(script_path, 'model.ckpt')
if not os.path.exists(sd_model_file):
sd_model_file = "models/ldm/stable-diffusion-v1/model.ckpt"
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=os.path.join(sd_path, sd_model_file), help="path to checkpoint of model",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default='GFPGANv1.3.pth')
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default='embeddings', help="embeddings dirtectory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrficing a lot of speed for very low VRM usage")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a workaround test; may help with speed in you use --lowvram")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
cmd_opts = parser.parse_args()
cpu = torch.device("cpu")
gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
class State:
interrupted = False
job = ""
def interrupt(self):
self.interrupted = True
state = State()
class Options:
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None):
self.default = default
self.label = label
self.component = component
self.component_args = component_args
data = None
data_labels = {
"outdir_samples": OptionInfo("", "Output dictectory for images; if empty, defaults to two directories below"),
"outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output dictectory for txt2img images'),
"outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output dictectory for img2img images'),
"outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output dictectory for images from extras tab'),
"outdir_grids": OptionInfo("", "Output dictectory for grids; if empty, defaults to two directories below"),
"outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output dictectory for txt2img grids'),
"outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output dictectory for img2img grids'),
"save_to_dirs": OptionInfo(False, "When writing images/grids, create a directory with name derived from the prompt"),
"save_to_dirs_prompt_len": OptionInfo(10, "When using above, how many words from prompt to put into directory name", gr.Slider, {"minimum": 1, "maximum": 32, "step": 1}),
"outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button"),
"samples_save": OptionInfo(True, "Save indiviual samples"),
"samples_format": OptionInfo('png', 'File format for indiviual samples'),
"grid_save": OptionInfo(True, "Save image grids"),
"return_grid": OptionInfo(True, "Show grid in results for web"),
"grid_format": OptionInfo('png', 'File format for grids'),
"grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
"grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
"n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),
"jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"font": OptionInfo("arial.ttf", "Font for image grids that have text"),
"prompt_matrix_add_to_start": OptionInfo(True, "In prompt matrix, add the variable combination of text to the start of the prompt, rather than the end"),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
}
def __init__(self):
self.data = {k: v.default for k, v in self.data_labels.items()}
def __setattr__(self, key, value):
if self.data is not None:
if key in self.data:
self.data[key] = value
return super(Options, self).__setattr__(key, value)
def __getattr__(self, item):
if self.data is not None:
if item in self.data:
return self.data[item]
if item in self.data_labels:
return self.data_labels[item].default
return super(Options, self).__getattribute__(item)
def save(self, filename):
with open(filename, "w", encoding="utf8") as file:
json.dump(self.data, file)
def load(self, filename):
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
sd_upscalers = {}
sd_model = None
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, code: str):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids,
prompt=prompt,
negative_prompt=negative_prompt,
seed=seed,
sampler_index=sampler_index,
batch_size=batch_size,
n_iter=n_iter,
steps=steps,
cfg_scale=cfg_scale,
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
)
if code != '' and cmd_opts.allow_code:
p.do_not_save_grid = True
p.do_not_save_samples = True
display_result_data = [[], -1, ""]
def display(imgs, s=display_result_data[1], i=display_result_data[2]):
display_result_data[0] = imgs
display_result_data[1] = s
display_result_data[2] = i
from types import ModuleType
compiled = compile(code, '', 'exec')
module = ModuleType("testmodule")
module.__dict__.update(globals())
module.p = p
module.display = display
exec(compiled, module.__dict__)
processed = Processed(p, *display_result_data)
else:
processed = process_images(p)
return processed.images, processed.js(), plaintext_to_html(processed.info)
This diff is collapsed.
......@@ -9,11 +9,6 @@ button{
align-self: stretch !important;
}
#img2img_mode{
padding: 0 0 1em 0;
border: none !important;
}
#img2img_prompt, #txt2img_prompt{
padding: 0;
border: none !important;
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment