Commit 53154ba1 authored by victorca25's avatar victorca25 Committed by GitHub

Merge branch 'master' into esrgan_mod

parents ad4de819 9d1138e2
...@@ -49,15 +49,18 @@ def list_hypernetworks(path): ...@@ -49,15 +49,18 @@ def list_hypernetworks(path):
def load_hypernetwork(filename): def load_hypernetwork(filename):
print(f"Loading hypernetwork {filename}")
path = shared.hypernetworks.get(filename, None) path = shared.hypernetworks.get(filename, None)
if (path is not None): if path is not None:
print(f"Loading hypernetwork {filename}")
try: try:
shared.loaded_hypernetwork = Hypernetwork(path) shared.loaded_hypernetwork = Hypernetwork(path)
except Exception: except Exception:
print(f"Error loading hypernetwork {path}", file=sys.stderr) print(f"Error loading hypernetwork {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
else: else:
if shared.loaded_hypernetwork is not None:
print(f"Unloading hypernetwork")
shared.loaded_hypernetwork = None shared.loaded_hypernetwork = None
......
...@@ -284,6 +284,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -284,6 +284,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Face restoration": (opts.face_restoration_model if p.restore_faces else None), "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
"Size": f"{p.width}x{p.height}", "Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
"Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch), "Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
......
...@@ -5,7 +5,6 @@ from collections import namedtuple ...@@ -5,7 +5,6 @@ from collections import namedtuple
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices from modules import shared, modelloader, devices
......
...@@ -242,6 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -242,6 +242,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"return_grid": OptionInfo(True, "Show grid in results for web"), "return_grid": OptionInfo(True, "Show grid in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
......
...@@ -37,7 +37,7 @@ class Upscaler: ...@@ -37,7 +37,7 @@ class Upscaler:
self.pre_pad = 0 self.pre_pad = 0
self.mod_scale = None self.mod_scale = None
if self.model_path is not None and self.name: if self.model_path is None and self.name:
self.model_path = os.path.join(models_path, self.name) self.model_path = os.path.join(models_path, self.name)
if self.model_path and create_dirs: if self.model_path and create_dirs:
os.makedirs(self.model_path, exist_ok=True) os.makedirs(self.model_path, exist_ok=True)
......
...@@ -10,7 +10,7 @@ import numpy as np ...@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts import modules.scripts as scripts
import gradio as gr import gradio as gr
from modules import images from modules import images, hypernetwork
from modules.processing import process_images, Processed, get_correct_sampler from modules.processing import process_images, Processed, get_correct_sampler
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
import modules.shared as shared import modules.shared as shared
...@@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs): ...@@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs):
def apply_hypernetwork(p, x, xs): def apply_hypernetwork(p, x, xs):
hn = shared.hypernetworks.get(x, None) hypernetwork.load_hypernetwork(x)
opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None'
def format_value_add_label(p, opt, x): def format_value_add_label(p, opt, x):
...@@ -203,8 +202,6 @@ class Script(scripts.Script): ...@@ -203,8 +202,6 @@ class Script(scripts.Script):
p.batch_size = 1 p.batch_size = 1
initial_hn = opts.sd_hypernetwork
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing': if opt.label == 'Nothing':
return [0] return [0]
...@@ -262,6 +259,7 @@ class Script(scripts.Script): ...@@ -262,6 +259,7 @@ class Script(scripts.Script):
# Confirm options are valid before starting # Confirm options are valid before starting
if opt.label == "Sampler": if opt.label == "Sampler":
samplers_dict = build_samplers_dict(p)
for sampler_val in valslist: for sampler_val in valslist:
if sampler_val.lower() not in samplers_dict.keys(): if sampler_val.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {sampler_val}") raise RuntimeError(f"Unknown sampler: {sampler_val}")
...@@ -321,6 +319,6 @@ class Script(scripts.Script): ...@@ -321,6 +319,6 @@ class Script(scripts.Script):
# restore checkpoint in case it was changed by axes # restore checkpoint in case it was changed by axes
modules.sd_models.reload_model_weights(shared.sd_model) modules.sd_models.reload_model_weights(shared.sd_model)
opts.data["sd_hypernetwork"] = initial_hn hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
return processed return processed
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment