Commit 469c992f authored by AUTOMATIC

Merge remote-tracking branch 'deggua/master' into new-ui

parents bb2faa5f 35ac3a66
@@ -18,6 +18,7 @@ import html
import time
import json
import traceback
from datetime import datetime
import k_diffusion.sampling
from ldm.util import instantiate_from_config
@@ -28,6 +29,7 @@ try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging
logging.set_verbosity_error()
except Exception:
pass
@@ -45,8 +47,8 @@ invalid_filename_chars = '<>:"/\\|?*\n'
config_filename = "config.json"
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-inference.yaml", help="path to config which constructs model", )
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model", )
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware accleration in browser)")
@@ -72,8 +74,8 @@ css_hide_progressbar = """
SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
samplers = [
*[SamplerData(x[0], lambda funcname=x[1]: KDiffusionSampler(funcname)) for x in [
('Euler ancestral', 'sample_euler_ancestral'),
('Euler', 'sample_euler'),
('Euler ancestral', 'sample_euler_ancestral'),
('LMS', 'sample_lms'),
('Heun', 'sample_heun'),
('DPM 2', 'sample_dpm_2'),
@@ -118,11 +120,10 @@ except Exception:
sd_upscalers = {
"RealESRGAN": lambda img: upscale_with_realesrgan(img, 2, 0),
"Lanczos": lambda img: img.resize((img.width*2, img.height*2), resample=LANCZOS),
"Lanczos": lambda img: img.resize((img.width * 2, img.height * 2), resample=LANCZOS),
"None": lambda img: img
}
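# Each upscaler maps a PIL image to its enlarged copy: Real-ESRGAN and Lanczos
# both produce 2x output here, while "None" returns the image unchanged.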
have_gfpgan = False
if os.path.exists(cmd_opts.gfpgan_dir):
try:
@@ -139,7 +140,7 @@ def gfpgan():
model_name = 'GFPGANv1.3'
model_path = os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models', model_name + '.pth')
if not os.path.isfile(model_path):
raise Exception("GFPGAN model not found at path "+model_path)
raise Exception("GFPGAN model not found at path " + model_path)
return GFPGANer(model_path=model_path, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
@@ -307,7 +308,6 @@ def torch_gc():
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False):
if short_filename or prompt is None or seed is None:
filename = f"{basename}"
else:
@@ -336,16 +336,18 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
image.save(os.path.join(path, f"{filename}.jpg"), quality=opts.jpeg_quality, pnginfo=pnginfo)
def sanitize_filename_part(text):
return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
def plaintext_to_html(text):
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
def plaintext_to_html(text, klass=None):
if klass is None:
text = "".join([f"<p>{html.escape(x)}</p>\n" for x in text.split('\n')])
else:
text = "".join([f"<p class=\"{klass}\">{html.escape(x)}</p>\n" for x in text.split('\n')])
return text
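# Illustrative: plaintext_to_html("a\nb", "params-info") yields
# '<p class="params-info">a</p>\n<p class="params-info">b</p>\n', letting the
# stylesheet target prompt, params and comments separately.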
def image_grid(imgs, batch_size=1, rows=None):
if rows is None:
if opts.n_rows > 0:
@@ -392,7 +394,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
for col in range(cols):
x = col * now
if x+tile_w >= w:
if x + tile_w >= w:
x = w - tile_w
tile = image.crop((x, y, x + tile_w, y + tile_h))
@@ -457,7 +459,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing
@@ -538,7 +540,6 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
for x in xs:
res.append(cell(x, y))
grid = image_grid(res, rows=len(ys))
grid = draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
@@ -628,7 +629,7 @@ class StableDiffusionModelHijack:
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1].reshape(768)
self.word_embeddings[name] = emb
self.word_embeddings_checksums[name] = f'{const_hash(emb)&0xffff:04x}'
self.word_embeddings_checksums[name] = f'{const_hash(emb) & 0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
@@ -717,7 +718,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
if tokens[i:i + len(ids)] == ids:
fixes.append((len(remade_tokens), word))
remade_tokens.append(777)
multipliers.append(mult)
@@ -741,7 +742,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
@@ -789,7 +790,7 @@ class EmbeddingsWithFixes(nn.Module):
class StableDiffusionProcessing:
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None):
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, strength_GFPGAN=1.0, extra_generation_params=None, overlay_images=None):
self.outpath: str = outpath
self.prompt: str = prompt
self.seed: int = seed
@@ -804,6 +805,7 @@ class StableDiffusionProcessing:
self.use_GFPGAN: bool = use_GFPGAN
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
self.strength_GFPGAN: float = strength_GFPGAN
self.extra_generation_params: dict = extra_generation_params
self.overlay_images = overlay_images
@@ -851,7 +853,24 @@ class KDiffusionSampler:
return samples_ddim
Processed = namedtuple('Processed', ['images','seed', 'info'])
Processed = namedtuple('Processed', ['images', 'seed', 'info'])
class OutputInfo:
def __init__(self, prompt: str, params: str, comments: str):
self.prompt = prompt.strip()
self.params = params.strip()
self.comments = comments.strip()
def __str__(self):
return '\n'.join([self.prompt, self.params, self.comments])
def html(self) -> str:
return f'''
{plaintext_to_html(self.prompt, "prompt-info")}<br>
{plaintext_to_html(self.params, "params-info")}
{plaintext_to_html(self.comments, "comments-info")}
'''
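# Illustrative only: OutputInfo keeps the three pieces of generation metadata separate
# so the UI can style each one, e.g.
#   info = OutputInfo('corgi in a top hat', 'Steps: 20, CFG: 7.5, Seed: 42', '')
#   str(info)   # plain text for PNG metadata and log files
#   info.html() # <p class="prompt-info">/<p class="params-info"> markup for the gallery pane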
def process_images(p: StableDiffusionProcessing) -> Processed:
@@ -897,7 +916,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
generation_params = {
"Steps": p.steps,
"Sampler": samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale,
"CFG": p.cfg_scale,
"Seed": seed,
"GFPGAN": ("GFPGAN" if p.use_GFPGAN else None)
}
@@ -908,7 +927,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
def infotext():
return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
return OutputInfo(prompt, generation_params_text, "".join(["\n\n" + x for x in comments]))
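# infotext() now returns an OutputInfo instead of a flat string; call sites that
# embed metadata in saved files wrap it in str(...) (see the save_image calls below).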
if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, model)
@@ -939,17 +958,22 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.prompt_matrix or opts.samples_save or opts.grid_save:
for i, x_sample in enumerate(x_samples_ddim):
# TODO: convert to BGR colorspace?
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
if p.use_GFPGAN:
if p.use_GFPGAN and have_gfpgan and p.strength_GFPGAN > 0.0:
torch_gc()
gfpgan_model = gfpgan()
cropped_faces, restored_faces, restored_img = gfpgan_model.enhance(x_sample, has_aligned=False, only_center_face=False, paste_back=True)
x_sample = restored_img
x_sample_bgr = x_sample[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan_model.enhance(x_sample_bgr, has_aligned=False, only_center_face=False, paste_back=True)
image = Image.fromarray(gfpgan_output_bgr[:, :, ::-1])
image = Image.fromarray(x_sample)
if p.strength_GFPGAN < 1.0:
image = Image.blend(Image.fromarray(x_sample), image, p.strength_GFPGAN)
else:
image = Image.fromarray(x_sample)
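# GFPGAN (cv2-based) expects BGR input, hence the [:, :, ::-1] flips into and out of
# enhance(); strength_GFPGAN < 1.0 linearly blends the restored faces back over the
# original sample via Image.blend.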
if p.overlay_images is not None and i < len(p.overlay_images):
image = image.convert('RGBA')
@@ -957,7 +981,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
image = image.convert('RGB')
if not p.do_not_save_samples:
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=str(infotext()))
output_images.append(image)
base_count += 1
@@ -967,7 +991,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
return_grid = opts.return_grid
if p.prompt_matrix:
grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))
grid = image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts) - 1) // 2))
try:
grid = draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
@@ -983,7 +1007,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if return_grid:
output_images.insert(0, grid)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
save_image(grid, p.outpath, f"grid-{grid_count:04}", seed, prompt, opts.grid_format, info=str(infotext()), short_filename=not opts.grid_extended_filename)
grid_count += 1
torch_gc()
@@ -1000,7 +1024,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples_ddim
def txt2img(prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, code: str):
def txt2img(prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, strength_GFPGAN: float, prompt_matrix: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, code: str):
outpath = opts.outdir or "outputs/txt2img-samples"
p = StableDiffusionProcessingTxt2Img(
@@ -1015,14 +1040,16 @@ def txt2img(prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, promp
width=width,
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN
use_GFPGAN=use_GFPGAN,
strength_GFPGAN=strength_GFPGAN
)
if code != '' and cmd_opts.allow_code:
p.do_not_save_grid = True
p.do_not_save_samples = True
display_result_data = [[], -1, ""]
display_result_data = [[], -1, OutputInfo('', '', '')]
def display(imgs, s=display_result_data[1], i=display_result_data[2]):
display_result_data[0] = imgs
display_result_data[1] = s
@@ -1040,7 +1067,7 @@ def txt2img(prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, promp
else:
processed = process_images(p)
return processed.images, processed.seed, plaintext_to_html(processed.info)
return processed.images, processed.info.html()
class Flagging(gr.FlaggingCallback):
@@ -1054,7 +1081,7 @@ class Flagging(gr.FlaggingCallback):
os.makedirs("log/images", exist_ok=True)
# those must match the "txt2img" function
prompt, steps, sampler_index, use_gfpgan, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data
prompt, steps, sampler_index, use_gfpgan, gfpgan_strength, prompt_matrix, n_iter, batch_size, cfg_scale, seed, height, width, code, images, seed, comment = flag_data
filenames = []
@@ -1069,7 +1096,7 @@ class Flagging(gr.FlaggingCallback):
filename_base = str(int(time.time() * 1000))
for i, filedata in enumerate(images):
filename = "log/images/"+filename_base + ("" if len(images) == 1 else "-"+str(i+1)) + ".png"
filename = "log/images/" + filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png"
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
@@ -1084,31 +1111,6 @@ class Flagging(gr.FlaggingCallback):
print("Logged:", filenames[0])
txt2img_interface = gr.Interface(
wrap_gradio_call(txt2img),
inputs=[
gr.Textbox(label="Prompt", placeholder="A corgi wearing a top hat as an oil painting.", lines=1),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.5),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
gr.Textbox(label="Python script", visible=cmd_opts.allow_code, lines=1)
],
outputs=[
gr.Gallery(label="Images"),
gr.Number(label='Seed'),
gr.HTML(),
],
title="Stable Diffusion Text-to-Image",
flagging_callback=Flagging()
)
def fill(image, mask):
image_mod = Image.new('RGBA', (image.width, image.height))
@@ -1149,6 +1151,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.overlay_images = []
imgs = []
if not self.init_images or None in self.init_images:
raise Exception('No input image provided for Image-to-Image')
for img in self.init_images:
image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height)
@@ -1195,8 +1201,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
def sample(self, x, conditioning, unconditional_conditioning):
t_enc = int(min(self.denoising_strength, 0.999) * self.steps)
@@ -1223,7 +1227,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
return samples_ddim
def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, strength_GFPGAN: float, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
if init_img_with_mask is not None:
@@ -1248,13 +1252,14 @@ def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_
height=height,
prompt_matrix=prompt_matrix,
use_GFPGAN=use_GFPGAN,
strength_GFPGAN=strength_GFPGAN,
init_images=[image],
mask=mask,
mask_blur=mask_blur,
inpainting_fill=inpainting_fill,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
extra_generation_params={"Denoising Strength": denoising_strength}
extra_generation_params={"DNS": denoising_strength}
)
if loopback:
@@ -1282,7 +1287,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_
grid_count = len(os.listdir(outpath)) - 1
grid = image_grid(history, batch_size, rows=1)
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
save_image(grid, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=str(info), short_filename=not opts.grid_extended_filename)
processed = Processed(history, initial_seed, initial_info)
@@ -1312,7 +1317,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_
print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")
for i in range(batch_count):
p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
p.init_images = work[i * p.batch_size:(i + 1) * p.batch_size]
processed = process_images(p)
@@ -1332,49 +1337,14 @@ def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_
combined_image = combine_grid(grid)
grid_count = len(os.listdir(outpath)) - 1
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=initial_info, short_filename=not opts.grid_extended_filename)
save_image(combined_image, outpath, f"grid-{grid_count:04}", initial_seed, prompt, opts.grid_format, info=str(initial_info), short_filename=not opts.grid_extended_filename)
processed = Processed([combined_image], initial_seed, initial_info)
else:
processed = process_images(p)
return processed.images, processed.seed, plaintext_to_html(processed.info)
sample_img2img = "assets/stable-samples/img2img/sketch-mountains-input.jpg"
sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
img2img_interface = gr.Interface(
wrap_gradio_call(img2img),
inputs=[
gr.Textbox(placeholder="A fantasy landscape, trending on artstation.", lines=1),
gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil"),
gr.Image(label="Image for inpainting with mask", source="upload", interactive=True, type="pil", tool="sketch"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
gr.Slider(label='Inpainting: mask blur', minimum=0, maximum=64, step=1, value=4),
gr.Radio(label='Inpainting: masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),
gr.Checkbox(label='Stable Diffusion upscale', value=False),
gr.Slider(minimum=1, maximum=cmd_opts.max_batch_count, step=1, label='Batch count (how many batches of images to generate)', value=1),
gr.Slider(minimum=1, maximum=8, step=1, label='Batch size (how many images are in a batch; memory-hungry)', value=1),
gr.Slider(minimum=1.0, maximum=15.0, step=0.5, label='Classifier Free Guidance Scale (how strongly the image should follow the prompt)', value=7.0),
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength', value=0.75),
gr.Number(label='Seed', value=-1),
gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512),
gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512),
gr.Radio(label="Resize mode", choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
],
outputs=[
gr.Gallery(),
gr.Number(label='Seed'),
gr.HTML(),
],
allow_flagging="never",
)
return processed.images, processed.info.html()
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
@@ -1397,14 +1367,19 @@ def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index)
def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_index):
torch_gc()
if not image:
raise Exception('No input image provided for Post-Processing')
image = image.convert("RGB")
outpath = opts.outdir or "outputs/extras-samples"
if have_gfpgan is not None and GFPGAN_strength > 0:
if have_gfpgan and GFPGAN_strength > 0:
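# (have_gfpgan is a plain bool, so the previous `is not None` test was always true;
# checking truthiness actually skips GFPGAN when it failed to load)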
gfpgan_model = gfpgan()
cropped_faces, restored_faces, restored_img = gfpgan_model.enhance(np.array(image, dtype=np.uint8), has_aligned=False, only_center_face=False, paste_back=True)
res = Image.fromarray(restored_img)
img_data_bgr = np.array(image, dtype=np.uint8)[:, :, ::-1]
cropped_faces, restored_faces, restored_img = gfpgan_model.enhance(img_data_bgr, has_aligned=False, only_center_face=False, paste_back=True)
img_data_rgb = restored_img[:, :, ::-1]
res = Image.fromarray(img_data_rgb)
if GFPGAN_strength < 1.0:
res = Image.blend(image, res, GFPGAN_strength)
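# Same BGR round-trip as in process_images: flip to BGR for GFPGAN, flip the result
# back to RGB, then optionally blend with the untouched input.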
@@ -1417,24 +1392,7 @@ def run_extras(image, GFPGAN_strength, RealESRGAN_upscaling, RealESRGAN_model_in
base_count = len(os.listdir(outpath))
save_image(image, outpath, f"{base_count:05}", None, '', opts.samples_format, short_filename=True)
return image, 0, ''
extras_interface = gr.Interface(
wrap_gradio_call(run_extras),
inputs=[
gr.Image(label="Source", source="upload", interactive=True, type="pil"),
gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=1, interactive=have_gfpgan),
gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Real-ESRGAN upscaling", value=2, interactive=have_realesrgan),
gr.Radio(label='Real-ESRGAN model', choices=[x.name for x in realesrgan_models], value=realesrgan_models[0].name, type="index", interactive=have_realesrgan),
],
outputs=[
gr.Image(label="Result"),
gr.Number(label='Seed', visible=False),
gr.HTML(),
],
allow_flagging="never",
)
return [image], ''
def run_pnginfo(image):
@@ -1445,7 +1403,7 @@ def run_pnginfo(image):
<p><b>{plaintext_to_html(str(key))}</b></p>
<p>{plaintext_to_html(str(text))}</p>
</div>
""".strip()+"\n"
""".strip() + "\n"
if len(info) == 0:
message = "Nothing found in the image."
@@ -1465,22 +1423,18 @@ pnginfo_interface = gr.Interface(
allow_flagging="never",
)
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
def run_settings(*args):
up = []
for key, value, comp in zip(opts.data_labels.keys(), args, settings_interface.input_components):
for key, value in zip(opts.data_labels.keys(), args):
opts.data[key] = value
up.append(comp.update(value=value))
opts.save(config_filename)
return 'Settings saved.', ''
return plaintext_to_html(f'Settings saved @ {datetime.now().strftime("%I:%M:%S")}')
def create_setting_component(key):
@@ -1504,26 +1458,6 @@ def create_setting_component(key):
return item
settings_interface = gr.Interface(
run_settings,
inputs=[create_setting_component(key) for key in opts.data_labels.keys()],
outputs=[
gr.Textbox(label='Result'),
gr.HTML(),
],
title=None,
description=None,
allow_flagging="never",
)
interfaces = [
(txt2img_interface, "txt2img"),
(img2img_interface, "img2img"),
(extras_interface, "Extras"),
(pnginfo_interface, "PNG Info"),
(settings_interface, "Settings"),
]
sd_config = OmegaConf.load(cmd_opts.config)
sd_model = load_model_from_config(sd_config, cmd_opts.ckpt)
sd_model = (sd_model if cmd_opts.no_half else sd_model.half())
@@ -1537,14 +1471,381 @@ else:
model_hijack = StableDiffusionModelHijack()
model_hijack.hijack(sd_model)
demo = gr.TabbedInterface(
interface_list=[x[0] for x in interfaces],
tab_names=[x[1] for x in interfaces],
css=("" if cmd_opts.no_progressbar_hiding else css_hide_progressbar) + """
.output-html p {margin: 0 0.5em;}
.performance { font-size: 0.85em; color: #444; }
"""
)
def do_generate(
mode: str,
prompt: str,
cfg: float,
denoise: float,
sampler_index: int,
sampler_steps: int,
batch_count: int,
batch_size: int,
input_img,
resize_mode,
image_height: int,
image_width: int,
use_input_seed: bool,
input_seed: int,
facefix: bool,
facefix_strength: float,
prompt_matrix: bool,
loopback: bool,
upscale: bool):
if mode == 'Text-to-Image':
return txt2img(
prompt=prompt,
steps=sampler_steps,
sampler_index=sampler_index,
use_GFPGAN=facefix,
strength_GFPGAN=facefix_strength,
prompt_matrix=prompt_matrix,
n_iter=batch_count,
batch_size=batch_size,
cfg_scale=cfg,
seed=input_seed if use_input_seed else -1,
height=image_height,
width=image_width,
code=''
)
elif mode == 'Image-to-Image':
return img2img(
prompt=prompt,
init_img=input_img,
init_img_with_mask=None,
ddim_steps=sampler_steps,
sampler_index=sampler_index,
mask_blur=0,
inpainting_fill=0,
use_GFPGAN=facefix,
strength_GFPGAN=facefix_strength,
prompt_matrix=prompt_matrix,
loopback=loopback,
sd_upscale=upscale,
n_iter=batch_count,
batch_size=batch_size,
cfg_scale=cfg,
denoising_strength=denoise,
seed=input_seed if use_input_seed else -1,
height=image_height,
width=image_width,
resize_mode=resize_mode,
)
elif mode == 'Post-Processing':
return run_extras(
image=input_img,
GFPGAN_strength=facefix_strength,
RealESRGAN_upscaling=1.0,
RealESRGAN_model_index=0
)
raise Exception('Invalid mode selected')
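# do_generate is the single handler wired to both the Generate button and the prompt
# textbox; it dispatches to txt2img, img2img or run_extras based on the mode dropdown.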
css_hide_progressbar = \
"""
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
main_css = \
"""
.output-html p { margin: 0 0.5em; }
.performance, .params-info, .comments-info { font-size: 0.85em; color: #666; }
"""
# [data-testid="image"] {min-height: 512px !important}
# #generate{width: 100%;}
custom_css = \
"""
/* hide scrollbars, better scaling for gallery, small padding for main image */
::-webkit-scrollbar { display: none }
#output_gallery {
min-height: 50vh !important;
scrollbar-width: none;
}
#output_gallery > div > img {
padding-top: 0.5rem;
padding-right: 0.5rem;
padding-left: 0.5rem;
}
/* remove excess padding around prompt textbox, increase font size */
#prompt_row input { font-size: 16px }
#prompt_input {
padding-top: 0.25rem !important;
padding-bottom: 0rem !important;
padding-left: 0rem !important;
padding-right: 0rem !important;
border-style: none !important;
}
/* remove excess padding from mode dropdown, change appear to a button */
#sd_mode {
padding-top: 0 !important;
padding-bottom: 0 !important;
padding-left: 0 !important;
padding-right: 0 !important;
border-style: none !important;
}
#sd_mode > label > select {
font-weight: 600;
min-height: 42px;
max-height: 42px;
text-align: center;
font-size: 1rem;
appearance: none;
-webkit-appearance: none;
background-position: left;
background-size: contain;
padding-right: 0;
border-color: rgb(75 85 99 / var(--tw-border-opacity));
}
/* custom column scaling (odd = right/left, even = center) */
#body>.col:nth-child(odd) {
max-width: 450px;
min-width: 300px;
}
#body>.col:nth-child(even) {
width:250%;
}
/* better overall scaling + limits */
.container {
max-width: min(1600px, 95%);
}
/* hide increment/decrement buttons on number inputs */
input[type="number"]::-webkit-outer-spin-button,
input[type="number"]::-webkit-inner-spin-button {
-webkit-appearance: none;
margin: 0;
}
input[type="number"] {
-moz-appearance: textfield;
}
"""
full_css = main_css + css_hide_progressbar + custom_css
with gr.Blocks(css=full_css, analytics_enabled=False, title='Stable Diffusion WebUI') as demo:
with gr.Tabs(elem_id='tabs'):
with gr.TabItem('Stable Diffusion', id='sd_tab'):
with gr.Row(elem_id='prompt_row'):
sd_prompt = gr.Textbox(elem_id='prompt_input', placeholder='A corgi wearing a top hat as an oil painting.', lines=1, max_lines=1, show_label=False)
with gr.Row(elem_id='body').style(equal_height=False):
# Left Column
with gr.Column():
sd_mode = \
gr.Dropdown(show_label=False, value='Text-to-Image', choices=['Text-to-Image', 'Image-to-Image', 'Post-Processing'], elem_id='sd_mode')
with gr.Row():
sd_image_height = \
gr.Number(label="Image height", value=512, precision=0, elem_id='img_height')
sd_image_width = \
gr.Number(label="Image width", value=512, precision=0, elem_id='img_width')
with gr.Row():
sd_batch_count = \
gr.Number(label='Batch count', precision=0, value=1)
sd_batch_size = \
gr.Number(label='Images per batch', precision=0, value=1)
with gr.Group():
sd_input_image = \
gr.Image(label='Input Image', source="upload", interactive=True, type="pil", show_label=True, visible=False)
sd_resize_mode = \
gr.Dropdown(label="Resize mode", choices=["Stretch", "Scale and crop", "Scale and fill"], type="index", value="Stretch", visible=False)
# Center Column
with gr.Column():
sd_output_image = \
gr.Gallery(show_label=False, elem_id='output_gallery').style(grid=3)
sd_output_html = \
gr.HTML()
# Right Column
with gr.Column():
sd_generate = \
gr.Button('Generate', variant='primary').style(full_width=True)
with gr.Row():
sd_sampling_method = \
gr.Dropdown(label='Sampling method', choices=[x.name for x in samplers], value=samplers[0].name, type="index")
sd_sampling_steps = \
gr.Slider(label="Sampling steps", value=30, minimum=5, maximum=100, step=5)
with gr.Group():
sd_cfg = \
gr.Slider(label='Prompt similarity (CFG)', value=8.0, minimum=1.0, maximum=15.0, step=0.5)
sd_denoise = \
gr.Slider(label='Denoising strength (DNS)', value=0.75, minimum=0.0, maximum=1.0, step=0.01, visible=False)
sd_facefix = \
gr.Checkbox(label='GFPGAN', value=False, visible=have_gfpgan)
sd_facefix_strength = \
gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Strength", value=1, interactive=have_gfpgan, visible=False)
sd_use_input_seed = \
gr.Checkbox(label='Custom seed')
sd_input_seed = \
gr.Number(value=-1, visible=False, show_label=False)
# TODO: Change to 'Enable syntactic prompts'
sd_matrix = \
gr.Checkbox(label='Create prompt matrix', value=False)
sd_loopback = \
gr.Checkbox(label='Output loopback', value=False, visible=False)
sd_upscale = \
gr.Checkbox(label='Super resolution upscale', value=False, visible=False)
with gr.TabItem('Settings', id='settings_tab'):
# TODO: Add HTML output to indicate settings saved
sd_settings = [create_setting_component(key) for key in opts.data_labels.keys()]
sd_save_settings = \
gr.Button('Save')
sd_confirm_settings = \
gr.HTML()
def mode_change(mode: str, facefix: bool, custom_seed: bool):
is_img2img = (mode == 'Image-to-Image')
is_txt2img = (mode == 'Text-to-Image')
is_pp = (mode == 'Post-Processing')
return {
sd_cfg: gr.update(visible=is_img2img or is_txt2img),
sd_denoise: gr.update(visible=is_img2img),
sd_sampling_method: gr.update(visible=is_img2img or is_txt2img),
sd_sampling_steps: gr.update(visible=is_img2img or is_txt2img),
sd_batch_count: gr.update(visible=is_img2img or is_txt2img),
sd_batch_size: gr.update(visible=is_img2img or is_txt2img),
sd_input_image: gr.update(visible=is_img2img or is_pp),
sd_resize_mode: gr.update(visible=is_img2img),
sd_image_height: gr.update(visible=is_img2img or is_txt2img),
sd_image_width: gr.update(visible=is_img2img or is_txt2img),
sd_use_input_seed: gr.update(visible=is_img2img or is_txt2img),
# TODO: can we handle this by updating use_input_seed and having its callback handle it?
sd_input_seed: gr.update(visible=(is_img2img or is_txt2img) and custom_seed),
sd_facefix: gr.update(visible=True),
# TODO: see above, but for facefix
sd_facefix_strength: gr.update(visible=facefix),
sd_matrix: gr.update(visible=is_img2img or is_txt2img),
sd_loopback: gr.update(visible=is_img2img),
sd_upscale: gr.update(visible=is_img2img)
}
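# mode_change returns a dict keyed by component; gr.update(visible=...) shows or hides
# each control, so a single set of widgets serves all three modes in place of the three
# gr.Interface tabs this commit removes.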
sd_mode.change(
fn=mode_change,
inputs=[
sd_mode,
sd_facefix,
sd_use_input_seed
],
outputs=[
sd_cfg,
sd_denoise,
sd_sampling_method,
sd_sampling_steps,
sd_batch_count,
sd_batch_size,
sd_input_image,
sd_resize_mode,
sd_image_height,
sd_image_width,
sd_use_input_seed,
sd_input_seed,
sd_facefix,
sd_facefix_strength,
sd_matrix,
sd_loopback,
sd_upscale
]
)
do_generate_args = dict(
fn=wrap_gradio_call(do_generate),
inputs=[
sd_mode,
sd_prompt,
sd_cfg,
sd_denoise,
sd_sampling_method,
sd_sampling_steps,
sd_batch_count,
sd_batch_size,
sd_input_image,
sd_resize_mode,
sd_image_height,
sd_image_width,
sd_use_input_seed,
sd_input_seed,
sd_facefix,
sd_facefix_strength,
sd_matrix,
sd_loopback,
sd_upscale
],
outputs=[
sd_output_image,
sd_output_html
]
)
sd_prompt.submit(**do_generate_args)
sd_generate.click(**do_generate_args)
sd_use_input_seed.change(
lambda checked: gr.update(visible=checked),
inputs=sd_use_input_seed,
outputs=sd_input_seed
)
sd_image_height.submit(
lambda value: 64 * ((value + 63) // 64) if value > 0 else 512,
inputs=sd_image_height,
outputs=sd_image_height
)
sd_image_width.submit(
lambda value: 64 * ((value + 63) // 64) if value > 0 else 512,
inputs=sd_image_width,
outputs=sd_image_width
)
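# Submitting a typed dimension snaps it up to the next multiple of 64 (e.g. 500 -> 512),
# matching the 64-pixel granularity of the old sliders; non-positive values fall back
# to the 512 default.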
sd_batch_count.submit(
lambda value: value if value > 0 else 1,
inputs=sd_batch_count,
outputs=sd_batch_count
)
sd_batch_size.submit(
lambda value: value if value > 0 else 1,
inputs=sd_batch_size,
outputs=sd_batch_size
)
sd_facefix.change(
lambda checked: gr.update(visible=checked),
inputs=sd_facefix,
outputs=sd_facefix_strength
)
sd_save_settings.click(
fn=run_settings,
inputs=sd_settings,
outputs=sd_confirm_settings
)
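# concurrency_count=1 queues requests so only one generation job runs on the GPU at a time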
demo.queue(concurrency_count=1)
demo.launch()