Commit 6c6ae28b authored by AUTOMATIC's avatar AUTOMATIC

send all three of GFPGAN's and codeformer's models to CPU memory instead of just one for #1283

parent 556c36b9
......@@ -69,10 +69,14 @@ def setup_model(dirname):
self.net = net
self.face_helper = face_helper
self.net.to(devices.device_codeformer)
return net, face_helper
def send_model_to(self, device):
self.net.to(device)
self.face_helper.face_det.to(device)
self.face_helper.face_parse.to(device)
def restore(self, np_image, w=None):
np_image = np_image[:, :, ::-1]
......@@ -82,6 +86,8 @@ def setup_model(dirname):
if self.net is None or self.face_helper is None:
return np_image
self.send_model_to(devices.device_codeformer)
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
......@@ -113,8 +119,10 @@ def setup_model(dirname):
if original_resolution != restored_img.shape[0:2]:
restored_img = cv2.resize(restored_img, (0, 0), fx=original_resolution[1]/restored_img.shape[1], fy=original_resolution[0]/restored_img.shape[0], interpolation=cv2.INTER_LINEAR)
self.face_helper.clean_all()
if shared.opts.face_restoration_unload:
self.net.to(devices.cpu)
self.send_model_to(devices.cpu)
return restored_img
......
import contextlib
import torch
# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
......@@ -57,3 +59,11 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
def autocast():
from modules import shared
if dtype == torch.float32 or shared.cmd_opts.precision == "full":
return contextlib.nullcontext()
return torch.autocast("cuda")
......@@ -37,22 +37,32 @@ def gfpgann():
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
return model
def send_model_to(model, device):
model.gfpgan.to(device)
model.face_helper.face_det.to(device)
model.face_helper.face_parse.to(device)
def gfpgan_fix_faces(np_image):
model = gfpgann()
if model is None:
return np_image
send_model_to(model, devices.device)
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
model.face_helper.clean_all()
if shared.opts.face_restoration_unload:
model.gfpgan.to(devices.cpu)
send_model_to(model, devices.cpu)
return np_image
......
import contextlib
import json
import math
import os
......@@ -330,9 +329,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
with torch.no_grad(), precision_scope("cuda"), ema_scope():
with torch.no_grad():
p.init(all_prompts, all_seeds, all_subseeds)
if state.job_count == -1:
......@@ -351,8 +349,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
#uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
#c = p.sd_model.get_learned_conditioning(prompts)
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_learned_conditioning(prompts, p.steps)
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
......@@ -361,7 +360,9 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength)
with devices.autocast():
samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength).to(devices.dtype)
if state.interrupted:
# if we are interruped, sample returns just noise
......@@ -386,6 +387,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
x_sample = modules.face_restoration.restore_faces(x_sample)
devices.torch_gc()
image = Image.fromarray(x_sample)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment