Commit 843b2b64 authored by AUTOMATIC's avatar AUTOMATIC

Instance of CUDA out of memory on a low-res batch, even with...

 Instance of CUDA out of memory on a low-res batch, even with --opt-split-attention-v1 (found cause) #255
parent 535b25ad
......@@ -5,7 +5,7 @@ import traceback
import cv2
import torch
from modules import shared
from modules import shared, devices
from modules.paths import script_path
import modules.shared
import modules.face_restoration
......@@ -51,6 +51,7 @@ def setup_codeformer():
def create_models(self):
if self.net is not None and self.face_helper is not None:
self.net.to(shared.device)
return self.net, self.face_helper
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(shared.device)
......@@ -61,9 +62,9 @@ def setup_codeformer():
face_helper = FaceRestoreHelper(1, face_size=512, crop_ratio=(1, 1), det_model='retinaface_resnet50', save_ext='png', use_parse=True, device=shared.device)
if not cmd_opts.unload_gfpgan:
self.net = net
self.face_helper = face_helper
self.net = net
self.face_helper = face_helper
self.net.to(shared.device)
return net, face_helper
......@@ -72,20 +73,20 @@ def setup_codeformer():
original_resolution = np_image.shape[0:2]
net, face_helper = self.create_models()
face_helper.clean_all()
face_helper.read_image(np_image)
face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
face_helper.align_warp_face()
self.create_models()
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
self.face_helper.align_warp_face()
for idx, cropped_face in enumerate(face_helper.cropped_faces):
for idx, cropped_face in enumerate(self.face_helper.cropped_faces):
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
cropped_face_t = cropped_face_t.unsqueeze(0).to(shared.device)
try:
with torch.no_grad():
output = net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
output = self.net(cropped_face_t, w=w if w is not None else shared.opts.code_former_weight, adain=True)[0]
restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
del output
torch.cuda.empty_cache()
......@@ -94,16 +95,19 @@ def setup_codeformer():
restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
restored_face = restored_face.astype('uint8')
face_helper.add_restored_face(restored_face)
self.face_helper.add_restored_face(restored_face)
face_helper.get_inverse_affine(None)
self.face_helper.get_inverse_affine(None)
restored_img = face_helper.paste_faces_to_input_image()
restored_img = self.face_helper.paste_faces_to_input_image()
restored_img = restored_img[:, :, ::-1]
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
if shared.opts.face_restoration_unload:
self.net.to(devices.cpu)
return restored_img
global have_codeformer
......
......@@ -2,7 +2,7 @@ import os
import sys
import traceback
from modules import shared
from modules import shared, devices
from modules.shared import cmd_opts
from modules.paths import script_path
import modules.face_restoration
......@@ -28,24 +28,29 @@ def gfpgan():
global loaded_gfpgan_model
if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device)
return loaded_gfpgan_model
if gfpgan_constructor is None:
return None
model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
if not cmd_opts.unload_gfpgan:
loaded_gfpgan_model = model
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
return model
def gfpgan_fix_faces(np_image):
model = gfpgan()
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
if shared.opts.face_restoration_unload:
model.gfpgan.to(devices.cpu)
return np_image
......
......@@ -30,7 +30,7 @@ parser.add_argument("--allow-code", action='store_true', help="allow custom scri
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a workaround test; may help with speed if you use --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="unload GFPGAN every time after processing images. Warning: seems to cause memory leaks")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
......@@ -133,6 +133,7 @@ class Options:
"face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
"code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
"face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
......
......@@ -384,8 +384,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
switch_mode = gr.Radio(label='Mode', elem_id="img2img_mode", choices=['Redraw whole image', 'Inpaint a part of image', 'Loopback', 'SD upscale'], value='Redraw whole image', type="index", show_label=False)
init_img = gr.Image(label="Image for img2img", source="upload", interactive=True, type="pil")
init_img_with_mask = gr.Image(label="Image for inpainting with mask", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", visible=False, image_mode="RGBA")
init_img_with_mask_comment = gr.HTML(elem_id="mask_bug_info", value="<small>if the editor shows ERROR, switch to another tab and back, then to another img2img mode above and back</small>", visible=False)
init_mask = gr.Image(label="Mask", source="upload", interactive=True, type="pil", visible=False)
init_img_with_mask_comment = gr.HTML(elem_id="mask_bug_info", value="<small>if the editor shows ERROR, switch to another tab and back, then to another img2img mode above and back</small>", visible=False)
with gr.Row():
resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", show_label=False, choices=["Just resize", "Crop and resize", "Resize and fill"], type="index", value="Just resize")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment