Commit d4205e66 authored by AUTOMATIC's avatar AUTOMATIC

gfpgan: just download the damn model

parent d6fd71f3
import os import os
import sys import sys
import traceback import traceback
from glob import glob
from modules import shared, devices from modules import shared, devices
from modules.shared import cmd_opts from modules.shared import cmd_opts
...@@ -11,14 +12,20 @@ import modules.face_restoration ...@@ -11,14 +12,20 @@ import modules.face_restoration
def gfpgan_model_path(): def gfpgan_model_path():
from modules.shared import cmd_opts from modules.shared import cmd_opts
filemask = 'GFPGAN*.pth'
if cmd_opts.gfpgan_model is not None:
return cmd_opts.gfpgan_model
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')] places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
found = [x for x in files if os.path.exists(x)]
if len(found) == 0: filename = None
raise Exception("GFPGAN model not found in paths: " + ", ".join(files)) for place in places:
filename = next(iter(glob(os.path.join(place, filemask))), None)
if filename is not None:
break
return found[0] return filename
loaded_gfpgan_model = None loaded_gfpgan_model = None
...@@ -34,7 +41,7 @@ def gfpgan(): ...@@ -34,7 +41,7 @@ def gfpgan():
if gfpgan_constructor is None: if gfpgan_constructor is None:
return None return None
model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device) model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
......
...@@ -2,7 +2,6 @@ import sys ...@@ -2,7 +2,6 @@ import sys
import argparse import argparse
import json import json
import os import os
from glob import glob
import gradio as gr import gradio as gr
import tqdm import tqdm
...@@ -22,7 +21,7 @@ parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs ...@@ -22,7 +21,7 @@ parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",) parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN')) parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=next(iter(glob('GFPGAN*.pth')), '')) parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats") parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)") parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI") parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment