Commit 6a9b33c8 authored by AUTOMATIC's avatar AUTOMATIC

codeformer support

parent 9cb3cc3a
This diff is collapsed.
This diff is collapsed.
import os
import sys
import traceback
import torch
from modules import shared
from modules.paths import script_path
import modules.shared
import modules.face_restoration
from importlib import reload
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
def setup_codeformer():
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
return
# both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
#stored_sys_path = sys.path
#sys.path = [path] + sys.path
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import imwrite, img2tensor, tensor2img
from facelib.utils.face_restoration_helper import FaceRestoreHelper
from modules.shared import cmd_opts
net_class = CodeFormer
class FaceRestorerCodeFormer(modules.face_restoration.FaceRestoration):
def name(self):
return "CodeFormer"
def __init__(self):
self.net = None
self.face_helper = None
def create_models(self):
if self.net is not None and self.face_helper is not None:
return self.net, self.face_helper
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(shared.device)
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)
net.eval()
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=shared.device)
if not cmd_opts.unload_gfpgan:
self.net = net
self.face_helper = face_helper
return net, face_helper
def restore(self, np_image):
np_image = np_image[:, :, ::-1]
net, face_helper = self.create_models()
face_helper.clean_all()
face_helper.read_image(np_image)
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(shared.device)
try:
with torch.no_grad():
output = net(cropped_face_t, w=shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
except Exception as error:
print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr)
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
return restored_img
global have_codeformer
have_codeformer = True
shared.face_restorers.append(FaceRestorerCodeFormer())
except Exception:
print("Error setting up CodeFormer:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
# sys.path = stored_sys_path
from modules import shared
class FaceRestoration:
def name(self):
return "None"
def restore(self, np_image):
return np_image
def restore_faces(np_image):
face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None]
if len(face_restorers) == 0:
return np_image
face_restorer = face_restorers[0]
return face_restorer.restore(np_image)
......@@ -2,12 +2,15 @@ import os
import sys
import traceback
from modules.paths import script_path
from modules import shared
from modules.shared import cmd_opts
import modules.shared
from modules.paths import script_path
import modules.face_restoration
def gfpgan_model_path():
from modules.shared import cmd_opts
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
found = [x for x in files if os.path.exists(x)]
......@@ -62,6 +65,19 @@ def setup_gfpgan():
global gfpgan_constructor
gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
def name(self):
return "GFPGAN"
def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
shared.face_restorers.append(FaceRestorerGFPGAN())
except Exception:
print("Error setting up GFPGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
......@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
import modules.images as images
import modules.scripts
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
is_inpaint = mode == 1
is_loopback = mode == 2
is_upscale = mode == 3
......@@ -36,7 +36,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
cfg_scale=cfg_scale,
width=width,
height=height,
use_GFPGAN=use_GFPGAN,
restore_faces=restore_faces,
tiling=tiling,
init_images=[image],
mask=mask,
......
......@@ -5,7 +5,7 @@ import sys
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.insert(0, script_path)
# use current directory as SD dir if it has related files, otherwise parent dir of script as stated in guide
# search for directory of stable diffsuion in following palces
sd_path = None
possible_sd_paths = ['.', os.path.dirname(script_path), os.path.join(script_path, 'repositories/stable-diffusion')]
for possible_sd_path in possible_sd_paths:
......@@ -14,14 +14,19 @@ for possible_sd_path in possible_sd_paths:
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + possible_sd_paths
# add parent directory to path; this is where Stable diffusion repo should be
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers')
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
]
paths = {}
for d, must_exist, what in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
else:
sys.path.append(os.path.join(script_path, d))
d = os.path.abspath(d)
sys.path.append(d)
paths[what] = d
......@@ -14,7 +14,7 @@ from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.gfpgan_model as gfpgan
import modules.face_restoration
import modules.images as images
# some of those options should not be changed at all because they would break the model, so I removed them from options.
......@@ -29,7 +29,7 @@ def torch_gc():
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, use_GFPGAN=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
......@@ -44,7 +44,7 @@ class StableDiffusionProcessing:
self.cfg_scale: float = cfg_scale
self.width: int = width
self.height: int = height
self.use_GFPGAN: bool = use_GFPGAN
self.restore_faces: bool = restore_faces
self.tiling: bool = tiling
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
......@@ -136,7 +136,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
"Sampler": samplers[p.sampler_index].name,
"CFG scale": p.cfg_scale,
"Seed": all_seeds[position_in_batch + iteration * p.batch_size],
"GFPGAN": ("GFPGAN" if p.use_GFPGAN else None),
"Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
}
......@@ -193,10 +193,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
if p.use_GFPGAN:
if p.restore_faces:
torch_gc()
x_sample = gfpgan.gfpgan_fix_faces(x_sample)
x_sample = modules.face_restoration.restore_faces(x_sample)
image = Image.fromarray(x_sample)
......
import argparse
import json
import os
import gradio as gr
import torch
import modules.artists
from modules.paths import script_path, sd_path
import modules.codeformer_model
config_filename = "config.json"
......@@ -40,6 +42,7 @@ device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
class State:
interrupted = False
job = ""
......@@ -65,6 +68,7 @@ state = State()
artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))
face_restorers = []
def find_any_font():
fonts = ['/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf']
......@@ -116,6 +120,8 @@ class Options:
"upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}),
"show_progressbar": OptionInfo(True, "Show progressbar"),
"show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
}
def __init__(self):
......
......@@ -6,7 +6,7 @@ import modules.processing as processing
from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, use_GFPGAN: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, height: int, width: int, *args):
p = StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples,
......@@ -21,7 +21,7 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
cfg_scale=cfg_scale,
width=width,
height=height,
use_GFPGAN=use_GFPGAN,
restore_faces=restore_faces,
tiling=tiling,
)
......
......@@ -206,7 +206,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
sampler_index = gr.Radio(label='Sampling method', elem_id="txt2img_sampling", choices=[x.name for x in samplers], value=samplers[0].name, type="index")
with gr.Row():
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
with gr.Row():
......@@ -253,7 +253,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
negative_prompt,
steps,
sampler_index,
use_gfpgan,
restore_faces,
tiling,
batch_count,
batch_size,
......@@ -335,7 +335,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False)
with gr.Row():
use_gfpgan = gr.Checkbox(label='GFPGAN', value=False, visible=gfpgan.have_gfpgan)
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False)
sd_upscale_overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, visible=False)
......@@ -425,7 +425,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
sampler_index,
mask_blur,
inpainting_fill,
use_gfpgan,
restore_faces,
tiling,
switch_mode,
batch_count,
......@@ -521,7 +521,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
extras_upscaler_2_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Upscaler 2 visibility", value=1)
with gr.Group():
gfpgan_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="GFPGAN strength", value=0, interactive=gfpgan.have_gfpgan)
face_restoration_blending = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Faces restoration visibility", value=0, interactive=len(shared.face_restorers) > 1)
submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')
......@@ -534,7 +534,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
fn=run_extras,
inputs=[
image,
gfpgan_strength,
face_restoration_blending,
upscaling_resize,
extras_upscaler_1,
extras_upscaler_2,
......@@ -585,7 +585,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
t = type(info.default)
if info.component is not None:
item = info.component(label=info.label, value=fun, **(info.component_args or {}))
args = info.component_args() if callable(info.component_args) else info.component_args
item = info.component(label=info.label, value=fun, **(args or {}))
elif t == str:
item = gr.Textbox(label=info.label, value=fun, lines=1)
elif t == int:
......
......@@ -92,6 +92,7 @@ echo Installing requirements...
%PYTHON% -m pip install -r %REQS_FILE% --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :update_numpy
goto :show_stdout_stderr
:update_numpy
%PYTHON% -m pip install -U numpy --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
......@@ -105,12 +106,28 @@ if %ERRORLEVEL% == 0 goto :clone_transformers
goto :show_stdout_stderr
:clone_transformers
if exist repositories\taming-transformers goto :check_model
if exist repositories\taming-transformers goto :clone_codeformer
echo Cloning Taming Transforming repository...
%GIT% clone https://github.com/CompVis/taming-transformers.git repositories\taming-transformers >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :clone_codeformer
goto :show_stdout_stderr
:clone_codeformer
if exist repositories\CodeFormer goto :install_codeformer_reqs
echo Cloning CodeFormer repository...
%GIT% clone https://github.com/sczhou/CodeFormer.git repositories\CodeFormer >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :install_codeformer_reqs
goto :show_stdout_stderr
:install_codeformer_reqs
%PYTHON% -c "import lpips" >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_model
echo Installing requirements for CodeFormer...
%PYTHON% -m pip install -r repositories\CodeFormer\requirements.txt --prefer-binary >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_model
goto :show_stdout_stderr
:check_model
dir model.ckpt >tmp/stdout.txt 2>tmp/stderr.txt
if %ERRORLEVEL% == 0 goto :check_gfpgan
......
......@@ -19,7 +19,9 @@ from modules.ui import plaintext_to_html
import modules.scripts
import modules.processing as processing
import modules.sd_hijack
import modules.gfpgan_model as gfpgan
import modules.codeformer_model
import modules.gfpgan_model
import modules.face_restoration
import modules.realesrgan_model as realesrgan
import modules.esrgan_model as esrgan
import modules.images as images
......@@ -28,10 +30,12 @@ import modules.txt2img
import modules.img2img
modules.codeformer_model.setup_codeformer()
modules.gfpgan_model.setup_gfpgan()
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
esrgan.load_models(cmd_opts.esrgan_models_path)
realesrgan.setup_realesrgan()
gfpgan.setup_gfpgan()
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
......@@ -54,19 +58,19 @@ def load_model_from_config(config, ckpt, verbose=False):
cached_images = {}
def run_extras(image, gfpgan_strength, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
def run_extras(image, face_restoration_blending, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
processing.torch_gc()
image = image.convert("RGB")
outpath = opts.outdir_samples or opts.outdir_extras_samples
if gfpgan.have_gfpgan is not None and gfpgan_strength > 0:
restored_img = gfpgan.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
if face_restoration_blending > 0:
restored_img = modules.face_restoration.restore_faces(np.array(image, dtype=np.uint8))
res = Image.fromarray(restored_img)
if gfpgan_strength < 1.0:
res = Image.blend(image, res, gfpgan_strength)
if face_restoration_blending < 1.0:
res = Image.blend(image, res, face_restoration_blending)
image = res
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment