Commit 5bfef6e0 authored by AUTOMATIC1111, committed by GitHub

Merge pull request #4844 from R-N/vae-misc

Remove no longer necessary code from VAE selector, fix #4651
parents cdc8020d f1bdf2b1
@@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     cache_enabled = shared.opts.sd_checkpoint_cache > 0
-    if cache_enabled:
-        sd_vae.restore_base_vae(model)
-    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
     if cache_enabled and checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
-        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
-        vae_message = f" with {vae_name} VAE" if vae_name else ""
-        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+        print(f"Loading weights [{sd_model_hash}] from cache")
         model.load_state_dict(checkpoints_loaded[checkpoint_info])
     else:
         # load from file
@@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
     sd_vae.load_vae(model, vae_file)
...
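Read as one change, the hunks above move all VAE handling out of the cache branch of load_model_weights: the checkpoint weights (cached or loaded from disk) are applied first, and the VAE is resolved and loaded exactly once at the end. A condensed sketch of the resulting flow, assuming the module context shown in the diff (shared, sd_vae, checkpoints_loaded) and a checkpoint_info object with a filename attribute; an illustration, not the full function:

# Condensed sketch of the post-change control flow (illustrative only; assumes the
# module-level names from the diff: shared, sd_vae, checkpoints_loaded).
def load_model_weights_sketch(model, checkpoint_info, vae_file="auto"):
    checkpoint_file = checkpoint_info.filename  # attribute assumed for illustration
    cache_enabled = shared.opts.sd_checkpoint_cache > 0

    if cache_enabled and checkpoint_info in checkpoints_loaded:
        # cached path: no VAE bookkeeping happens here any more
        model.load_state_dict(checkpoints_loaded[checkpoint_info])
    else:
        ...  # load the state dict from checkpoint_file on disk

    # the VAE is resolved once, after the checkpoint weights are in place, then loaded
    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
    sd_vae.load_vae(model, vae_file)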
@@ -91,7 +91,7 @@ def get_vae_from_settings(vae_file="auto"):
         # if VAE selected but not found, fallback to auto
         if vae_file not in default_vae_values and not os.path.isfile(vae_file):
             vae_file = "auto"
-            print("Selected VAE doesn't exist")
+            print(f"Selected VAE doesn't exist: {vae_file}")
     return vae_file
@@ -101,15 +101,15 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
     # if vae_file argument is provided, it takes priority, but not saved
     if vae_file and vae_file not in default_vae_list:
         if not os.path.isfile(vae_file):
+            print(f"VAE provided as function argument doesn't exist: {vae_file}")
             vae_file = "auto"
-            print("VAE provided as function argument doesn't exist")
     # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
     if first_load and shared.cmd_opts.vae_path is not None:
         if os.path.isfile(shared.cmd_opts.vae_path):
             vae_file = shared.cmd_opts.vae_path
             shared.opts.data['sd_vae'] = get_filename(vae_file)
         else:
-            print("VAE provided as command line argument doesn't exist")
+            print(f"VAE provided as command line argument doesn't exist: {vae_file}")
     # fallback to selector in settings, if vae selector not set to act as default fallback
     if not shared.opts.sd_vae_as_default:
         vae_file = get_vae_from_settings(vae_file)
@@ -117,20 +117,20 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"):
     if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
         if os.path.isfile(shared.cmd_opts.vae_path):
             vae_file = shared.cmd_opts.vae_path
-            print("Using VAE provided as command line argument")
+            print(f"Using VAE provided as command line argument: {vae_file}")
     # if still not found, try look for ".vae.pt" beside model
     model_path = os.path.splitext(checkpoint_file)[0]
     if vae_file == "auto":
         vae_file_try = model_path + ".vae.pt"
         if os.path.isfile(vae_file_try):
             vae_file = vae_file_try
-            print("Using VAE found beside selected model")
+            print(f"Using VAE found similar to selected model: {vae_file}")
     # if still not found, try look for ".vae.ckpt" beside model
     if vae_file == "auto":
         vae_file_try = model_path + ".vae.ckpt"
         if os.path.isfile(vae_file_try):
             vae_file = vae_file_try
-            print("Using VAE found beside selected model")
+            print(f"Using VAE found similar to selected model: {vae_file}")
     # No more fallbacks for auto
     if vae_file == "auto":
         vae_file = None
@@ -146,6 +146,7 @@ def load_vae(model, vae_file=None):
     # save_settings = False
     if vae_file:
+        assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
         print(f"Loading VAE weights from: {vae_file}")
         vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
         vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
...
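The message changes above also make resolve_vae's fallback chain easier to follow from the console. A simplified, self-contained sketch of that chain (it folds away the first_load and sd_vae_as_default handling, and the settings_vae / cmd_vae_path parameters are illustrative stand-ins for the values the real function reads from shared):

import os

def resolve_vae_sketch(checkpoint_file, vae_file="auto", settings_vae="auto", cmd_vae_path=None):
    """Simplified illustration of the fallback chain above; not the project's code."""
    # 1. an explicit vae_file argument wins, but only if the file actually exists
    if vae_file not in ("auto", None) and not os.path.isfile(vae_file):
        print(f"VAE provided as function argument doesn't exist: {vae_file}")
        vae_file = "auto"
    # 2. the "SD VAE" settings dropdown
    if vae_file == "auto" and settings_vae not in ("auto", "None", None) and os.path.isfile(settings_vae):
        vae_file = settings_vae
    # 3. a --vae-path style command line override
    if vae_file == "auto" and cmd_vae_path is not None and os.path.isfile(cmd_vae_path):
        vae_file = cmd_vae_path
    # 4. files named after the checkpoint: model.vae.pt, then model.vae.ckpt
    base = os.path.splitext(checkpoint_file)[0]
    for ext in (".vae.pt", ".vae.ckpt"):
        if vae_file == "auto" and os.path.isfile(base + ext):
            vae_file = base + ext
    # 5. nothing found: None means "keep the VAE baked into the checkpoint"
    return None if vae_file == "auto" else vae_file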
@@ -334,7 +334,7 @@ options_templates.update(options_section(('training', "Training"), {
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
     "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
-    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list),
+    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
     "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
     "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
...
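For reference, the sd_vae entry above is an ordinary Gradio dropdown; the lambda only has to supply the selectable names (per the diff, sd_vae.vae_list, refreshed by sd_vae.refresh_vae_list). A minimal standalone Gradio sketch, with made-up choice values and not webui code:

import gradio as gr

# made-up example values; in the webui these come from sd_vae.vae_list
vae_choices = ["auto", "None", "vae-ft-mse-840000-ema-pruned.ckpt"]

with gr.Blocks() as demo:
    # comparable to the "SD VAE" setting: a dropdown over plain string choices
    sd_vae_dropdown = gr.Dropdown(choices=vae_choices, value="auto", label="SD VAE")

# demo.launch()  # uncomment to serve the small demo locally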