Commit 94a35ca9 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #5191 from aliencaocao/enable_checkpoint_switching_in_override_settings

Support changing checkpoint and vae through override_settings
parents 713c48dd 9a8678f6
...@@ -20,6 +20,8 @@ import modules.shared as shared ...@@ -20,6 +20,8 @@ import modules.shared as shared
import modules.face_restoration import modules.face_restoration
import modules.images as images import modules.images as images
import modules.styles import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
import logging import logging
from ldm.data.util import AddMiDaS from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
...@@ -454,8 +456,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -454,8 +456,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try: try:
for k, v in p.override_settings.items(): for k, v in p.override_settings.items():
setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model impossible setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not if k == 'sd_hypernetwork': shared.reload_hypernetworks() # make onchange call for changing hypernet
if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model
if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE
res = process_images_inner(p) res = process_images_inner(p)
...@@ -463,6 +467,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -463,6 +467,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
for k, v in stored_opts.items(): for k, v in stored_opts.items():
setattr(opts, k, v) setattr(opts, k, v)
if k == 'sd_hypernetwork': shared.reload_hypernetworks() if k == 'sd_hypernetwork': shared.reload_hypernetworks()
if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
if k == 'sd_vae': sd_vae.reload_vae_weights()
return res return res
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment