Commit 26d08193 authored by Chris OBryan's avatar Chris OBryan

extras: Add option to run upscaling before face fixing

Face restoration can look much better if ran after upscaling, as it
allows the restoration to fix upscaling artifacts. This patch adds
an option to choose which order to run upscaling/face fixing in.
parent 737eb28f
...@@ -7,6 +7,10 @@ from PIL import Image ...@@ -7,6 +7,10 @@ from PIL import Image
import torch import torch
import tqdm import tqdm
from typing import Callable, List, Tuple
from functools import partial
from dataclasses import dataclass
from modules import processing, shared, images, devices, sd_models from modules import processing, shared, images, devices, sd_models
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
...@@ -20,7 +24,7 @@ import gradio as gr ...@@ -20,7 +24,7 @@ import gradio as gr
cached_images = {} cached_images = {}
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
devices.torch_gc() devices.torch_gc()
imageArr = [] imageArr = []
...@@ -56,68 +60,109 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -56,68 +60,109 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
else: else:
outpath = opts.outdir_samples or opts.outdir_extras_samples outpath = opts.outdir_samples or opts.outdir_extras_samples
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
image = image.convert("RGB") # Extra operation definitions
info = "" def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
if gfpgan_visibility < 1.0:
res = Image.blend(image, res, gfpgan_visibility)
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
return (res, info)
if gfpgan_visibility > 0: def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8)) restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img) res = Image.fromarray(restored_img)
if gfpgan_visibility < 1.0: if codeformer_visibility < 1.0:
res = Image.blend(image, res, gfpgan_visibility) res = Image.blend(image, res, codeformer_visibility)
info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n" info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res return (res, info)
if codeformer_visibility > 0:
restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
res = Image.fromarray(restored_img)
if codeformer_visibility < 1.0: def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
res = Image.blend(image, res, codeformer_visibility) small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight,
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n" c = cached_images.get(key)
image = res if c is None:
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
c = cropped
cached_images[key] = c
return c
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
nonlocal upscaling_resize
if resize_mode == 1: if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height) upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else "" crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n" info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
return (image, info)
@dataclass
class UpscaleParams:
upscaler_idx: int
blend_alpha: float
def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None
for upscaler in params:
res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
if blended_result is None:
blended_result = res
else:
blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
return (blended_result, info)
# Build a list of operations to run
facefix_ops: List[Callable] = []
if gfpgan_visibility > 0:
facefix_ops.append(run_gfpgan)
if codeformer_visibility > 0:
facefix_ops.append(run_codeformer)
upscale_ops: List[Callable] = []
if resize_mode == 1:
upscale_ops.append(run_prepare_crop)
if upscaling_resize != 0:
step_params: List[UpscaleParams] = []
step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_1, blend_alpha=1.0 ))
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility ) )
upscale_ops.append( partial(run_upscalers_blend, step_params) )
extras_ops: List[Callable] = []
if upscale_first:
extras_ops = upscale_ops + facefix_ops
else:
extras_ops = facefix_ops + upscale_ops
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
if upscaling_resize != 1.0: image = image.convert("RGB")
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): info = ""
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) # Run each operation on each image
pixels = tuple(np.array(small).flatten().tolist()) for op in extras_ops:
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, image, info = op(image, info)
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels
c = cached_images.get(key)
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
c = cropped
cached_images[key] = c
return c
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
res = Image.blend(res, res2, extras_upscaler_2_visibility)
image = res
while len(cached_images) > 2: while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))] del cached_images[next(iter(cached_images.keys()))]
......
...@@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call):
codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer) codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)
with gr.Group():
upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary') submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
...@@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call):
extras_upscaler_1, extras_upscaler_1,
extras_upscaler_2, extras_upscaler_2,
extras_upscaler_2_visibility, extras_upscaler_2_visibility,
upscale_before_face_fix,
], ],
outputs=[ outputs=[
result_images, result_images,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment