Commit b6e5edd7 authored by AUTOMATIC's avatar AUTOMATIC

add built-in extension system

add support for adding upscalers in extensions
move LDSR, ScuNET and SwinIR to built-in extensions
parent 46b0d230
import os
from modules import paths
def preload(parser):
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR'))
...@@ -5,8 +5,8 @@ import traceback ...@@ -5,8 +5,8 @@ import traceback
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR from ldsr_model_arch import LDSR
from modules import shared from modules import shared, script_callbacks
class UpscalerLDSR(Upscaler): class UpscalerLDSR(Upscaler):
...@@ -52,3 +52,12 @@ class UpscalerLDSR(Upscaler): ...@@ -52,3 +52,12 @@ class UpscalerLDSR(Upscaler):
return img return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale) return ldsr.super_resolution(img, ddim_steps, self.scale)
def on_ui_settings():
import gradio as gr
shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
script_callbacks.on_ui_settings(on_ui_settings)
import os
from modules import paths
def preload(parser):
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET'))
...@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url ...@@ -9,7 +9,7 @@ from basicsr.utils.download_util import load_file_from_url
import modules.upscaler import modules.upscaler
from modules import devices, modelloader from modules import devices, modelloader
from modules.scunet_model_arch import SCUNet as net from scunet_model_arch import SCUNet as net
class UpscalerScuNET(modules.upscaler.Upscaler): class UpscalerScuNET(modules.upscaler.Upscaler):
...@@ -49,7 +49,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -49,7 +49,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
if model is None: if model is None:
return img return img
device = devices.device_scunet device = devices.get_device_for('scunet')
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
...@@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -66,7 +66,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
return PIL.Image.fromarray(output, 'RGB') return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str): def load_model(self, path: str):
device = devices.device_scunet device = devices.get_device_for('scunet')
if "http" in path: if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name, filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True) progress=True)
......
import os
from modules import paths
def preload(parser):
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR'))
...@@ -7,13 +7,16 @@ from PIL import Image ...@@ -7,13 +7,16 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm from tqdm import tqdm
from modules import modelloader, devices from modules import modelloader, devices, script_callbacks, shared
from modules.shared import cmd_opts, opts from modules.shared import cmd_opts, opts
from modules.swinir_model_arch import SwinIR as net from swinir_model_arch import SwinIR as net
from modules.swinir_model_arch_v2 import Swin2SR as net2 from swinir_model_arch_v2 import Swin2SR as net2
from modules.upscaler import Upscaler, UpscalerData from modules.upscaler import Upscaler, UpscalerData
device_swinir = devices.get_device_for('swinir')
class UpscalerSwinIR(Upscaler): class UpscalerSwinIR(Upscaler):
def __init__(self, dirname): def __init__(self, dirname):
self.name = "SwinIR" self.name = "SwinIR"
...@@ -38,7 +41,7 @@ class UpscalerSwinIR(Upscaler): ...@@ -38,7 +41,7 @@ class UpscalerSwinIR(Upscaler):
model = self.load_model(model_file) model = self.load_model(model_file)
if model is None: if model is None:
return img return img
model = model.to(devices.device_swinir) model = model.to(device_swinir, dtype=devices.dtype)
img = upscale(img, model) img = upscale(img, model)
try: try:
torch.cuda.empty_cache() torch.cuda.empty_cache()
...@@ -90,8 +93,6 @@ class UpscalerSwinIR(Upscaler): ...@@ -90,8 +93,6 @@ class UpscalerSwinIR(Upscaler):
model.load_state_dict(pretrained_model[params], strict=True) model.load_state_dict(pretrained_model[params], strict=True)
else: else:
model.load_state_dict(pretrained_model, strict=True) model.load_state_dict(pretrained_model, strict=True)
if not cmd_opts.no_half:
model = model.half()
return model return model
...@@ -107,7 +108,7 @@ def upscale( ...@@ -107,7 +108,7 @@ def upscale(
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_swinir) img = img.unsqueeze(0).to(device_swinir, dtype=devices.dtype)
with torch.no_grad(), devices.autocast(): with torch.no_grad(), devices.autocast():
_, _, h_old, w_old = img.size() _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old h_pad = (h_old // window_size + 1) * window_size - h_old
...@@ -135,8 +136,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -135,8 +136,8 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
stride = tile - tile_overlap stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile] h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=devices.device_swinir).type_as(img) E = torch.zeros(b, c, h * sf, w * sf, dtype=devices.dtype, device=device_swinir).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=devices.device_swinir) W = torch.zeros_like(E, dtype=devices.dtype, device=device_swinir)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar: with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list: for h_idx in h_idx_list:
...@@ -155,3 +156,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -155,3 +156,13 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
output = E.div_(W) output = E.div_(W)
return output return output
def on_ui_settings():
import gradio as gr
shared.opts.add_option("SWIN_tile", shared.OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}, section=('upscaling', "Upscaling")))
shared.opts.add_option("SWIN_tile_overlap", shared.OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}, section=('upscaling', "Upscaling")))
script_callbacks.on_ui_settings(on_ui_settings)
...@@ -44,6 +44,15 @@ def get_optimal_device(): ...@@ -44,6 +44,15 @@ def get_optimal_device():
return cpu return cpu
def get_device_for(task):
from modules import shared
if task in shared.cmd_opts.use_cpu:
return cpu
return get_optimal_device()
def torch_gc(): def torch_gc():
if torch.cuda.is_available(): if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()): with torch.cuda.device(get_cuda_device_string()):
...@@ -67,7 +76,7 @@ def enable_tf32(): ...@@ -67,7 +76,7 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32") errors.run(enable_tf32, "Enabling TF32")
cpu = torch.device("cpu") cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16 dtype = torch.float16
dtype_vae = torch.float16 dtype_vae = torch.float16
......
...@@ -8,6 +8,7 @@ from modules import paths, shared ...@@ -8,6 +8,7 @@ from modules import paths, shared
extensions = [] extensions = []
extensions_dir = os.path.join(paths.script_path, "extensions") extensions_dir = os.path.join(paths.script_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")
def active(): def active():
...@@ -15,12 +16,13 @@ def active(): ...@@ -15,12 +16,13 @@ def active():
class Extension: class Extension:
def __init__(self, name, path, enabled=True): def __init__(self, name, path, enabled=True, is_builtin=False):
self.name = name self.name = name
self.path = path self.path = path
self.enabled = enabled self.enabled = enabled
self.status = '' self.status = ''
self.can_update = False self.can_update = False
self.is_builtin = is_builtin
repo = None repo = None
try: try:
...@@ -79,11 +81,19 @@ def list_extensions(): ...@@ -79,11 +81,19 @@ def list_extensions():
if not os.path.isdir(extensions_dir): if not os.path.isdir(extensions_dir):
return return
for dirname in sorted(os.listdir(extensions_dir)): paths = []
path = os.path.join(extensions_dir, dirname) for dirname in [extensions_dir, extensions_builtin_dir]:
if not os.path.isdir(dirname):
return
for extension_dirname in sorted(os.listdir(dirname)):
path = os.path.join(dirname, extension_dirname)
if not os.path.isdir(path): if not os.path.isdir(path):
continue continue
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
for dirname, path, is_builtin in paths:
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
extensions.append(extension) extensions.append(extension)
...@@ -124,10 +124,9 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None): ...@@ -124,10 +124,9 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
def load_upscalers(): def load_upscalers():
sd = shared.script_path
# We can only do this 'magic' method to dynamically load upscalers if they are referenced, # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
# so we'll try to import any _model.py files before looking in __subclasses__ # so we'll try to import any _model.py files before looking in __subclasses__
modules_dir = os.path.join(sd, "modules") modules_dir = os.path.join(shared.script_path, "modules")
for file in os.listdir(modules_dir): for file in os.listdir(modules_dir):
if "_model.py" in file: if "_model.py" in file:
model_name = file.replace("_model.py", "") model_name = file.replace("_model.py", "")
...@@ -136,22 +135,13 @@ def load_upscalers(): ...@@ -136,22 +135,13 @@ def load_upscalers():
importlib.import_module(full_model) importlib.import_module(full_model)
except: except:
pass pass
datas = [] datas = []
c_o = vars(shared.cmd_opts) commandline_options = vars(shared.cmd_opts)
for cls in Upscaler.__subclasses__(): for cls in Upscaler.__subclasses__():
name = cls.__name__ name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}_models_path" cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
opt_string = None scaler = cls(commandline_options.get(cmd_name, None))
try: datas += scaler.scalers
if cmd_name in c_o:
opt_string = c_o[cmd_name]
except:
pass
scaler = class_(opt_string)
for child in scaler.scalers:
datas.append(child)
shared.sd_upscalers = datas shared.sd_upscalers = datas
...@@ -50,9 +50,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi ...@@ -50,9 +50,6 @@ parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory wi
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN')) parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(models_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN')) parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(models_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN')) parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(models_path, 'RealESRGAN'))
parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(models_path, 'ScuNET'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(models_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None) parser.add_argument("--clip-models-path", type=str, help="Path to directory with CLIP model file(s).", default=None)
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
...@@ -61,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en ...@@ -61,7 +58,7 @@ parser.add_argument("--opt-split-attention", action='store_true', help="force-en
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--use-cpu", nargs='+',choices=['all', 'sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer'], help="use CPU as torch device for specified modules", default=[], type=str.lower) parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
...@@ -95,6 +92,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req ...@@ -95,6 +92,7 @@ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, req
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)
script_loading.preload_extensions(extensions.extensions_dir, parser) script_loading.preload_extensions(extensions.extensions_dir, parser)
script_loading.preload_extensions(extensions.extensions_builtin_dir, parser)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
...@@ -112,8 +110,8 @@ restricted_opts = { ...@@ -112,8 +110,8 @@ restricted_opts = {
cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access
devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \
(devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer'])
device = devices.device device = devices.device
weight_load_location = None if cmd_opts.lowram else "cpu" weight_load_location = None if cmd_opts.lowram else "cpu"
...@@ -326,9 +324,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { ...@@ -326,9 +324,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}), "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}),
"use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"),
})) }))
......
...@@ -28,7 +28,6 @@ import modules.codeformer_model ...@@ -28,7 +28,6 @@ import modules.codeformer_model
import modules.generation_parameters_copypaste as parameters_copypaste import modules.generation_parameters_copypaste as parameters_copypaste
import modules.gfpgan_model import modules.gfpgan_model
import modules.hypernetworks.ui import modules.hypernetworks.ui
import modules.ldsr_model
import modules.scripts import modules.scripts
import modules.shared as shared import modules.shared as shared
import modules.styles import modules.styles
......
...@@ -78,6 +78,12 @@ def extension_table(): ...@@ -78,6 +78,12 @@ def extension_table():
""" """
for ext in extensions.extensions: for ext in extensions.extensions:
remote = ""
if ext.is_builtin:
remote = "built-in"
elif ext.remote:
remote = f"""<a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape("built-in" if ext.is_builtin else ext.remote or '')}</a>"""
if ext.can_update: if ext.can_update:
ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>""" ext_status = f"""<label><input class="gr-check-radio gr-checkbox" name="update_{html.escape(ext.name)}" checked="checked" type="checkbox">{html.escape(ext.status)}</label>"""
else: else:
...@@ -86,7 +92,7 @@ def extension_table(): ...@@ -86,7 +92,7 @@ def extension_table():
code += f""" code += f"""
<tr> <tr>
<td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td> <td><label><input class="gr-check-radio gr-checkbox" name="enable_{html.escape(ext.name)}" type="checkbox" {'checked="checked"' if ext.enabled else ''}>{html.escape(ext.name)}</label></td>
<td><a href="{html.escape(ext.remote or '')}" target="_blank">{html.escape(ext.remote or '')}</a></td> <td>{remote}</td>
<td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td> <td{' class="extension_status"' if ext.remote is not None else ''}>{ext_status}</td>
</tr> </tr>
""" """
......
...@@ -53,10 +53,11 @@ def initialize(): ...@@ -53,10 +53,11 @@ def initialize():
codeformer.setup_model(cmd_opts.codeformer_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
modules.scripts.load_scripts() modules.scripts.load_scripts()
modelloader.load_upscalers()
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
modules.sd_models.load_model() modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
...@@ -177,6 +178,8 @@ def webui(): ...@@ -177,6 +178,8 @@ def webui():
print('Reloading custom scripts') print('Reloading custom scripts')
modules.scripts.reload_scripts() modules.scripts.reload_scripts()
modelloader.load_upscalers()
print('Reloading modules: modules.ui') print('Reloading modules: modules.ui')
importlib.reload(modules.ui) importlib.reload(modules.ui)
print('Refreshing Model List') print('Refreshing Model List')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment