Commit f2a5cbe6 authored by AUTOMATIC

fix #3986 breaking --no-half-vae

parent 675b51eb
@@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
             model.to(memory_format=torch.channels_last)
 
         if not shared.cmd_opts.no_half:
+            vae = model.first_stage_model
+
+            # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+            if shared.cmd_opts.no_half_vae:
+                model.first_stage_model = None
+
             model.half()
+            model.first_stage_model = vae
 
         devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
         devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
 
+        model.first_stage_model.to(devices.dtype_vae)
+
         if shared.opts.sd_checkpoint_cache > 0:
             # if PR #4035 were to get merged, restore base VAE first before caching
             checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
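For illustration, here is a minimal, self-contained sketch of the workaround this diff applies: temporarily detach the VAE submodule so that model.half() skips it, reattach it, and then cast it to its own dtype. ToyModel, its unet attribute, and half_except_vae are hypothetical stand-ins for the actual Stable Diffusion model, not code from this repository.

    # Sketch of the detach-before-half() trick from the diff above, on a toy model.
    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.unet = nn.Linear(4, 4)               # stands in for the diffusion model
            self.first_stage_model = nn.Linear(4, 4)  # stands in for the VAE

    def half_except_vae(model: ToyModel, no_half_vae: bool) -> ToyModel:
        vae = model.first_stage_model
        # Detach the VAE so model.half() does not convert its weights to float16.
        # Setting a registered submodule to None is allowed; half() skips None children.
        if no_half_vae:
            model.first_stage_model = None
        model.half()
        model.first_stage_model = vae
        # Cast the VAE to the dtype it is supposed to run in.
        dtype_vae = torch.float32 if no_half_vae else torch.float16
        model.first_stage_model.to(dtype_vae)
        return model

    model = half_except_vae(ToyModel(), no_half_vae=True)
    print(model.unet.weight.dtype)               # torch.float16
    print(model.first_stage_model.weight.dtype)  # torch.float32

With no_half_vae=True the rest of the model ends up in float16 while the VAE weights stay float32, which is what the --no-half-vae flag is meant to guarantee.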