Commit af758e97 authored by Jairo Correa's avatar Jairo Correa

Unload sd_model before loading the other

parent 5c9b3625
...@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook; # see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods # useless here, and we just replace those methods
def first_stage_model_encode_wrap(self, encoder, x):
send_me_to_gpu(self, None)
return encoder(x)
def first_stage_model_decode_wrap(self, decoder, z): first_stage_model = sd_model.first_stage_model
send_me_to_gpu(self, None) first_stage_model_encode = sd_model.first_stage_model.encode
return decoder(z) first_stage_model_decode = sd_model.first_stage_model.decode
def first_stage_model_encode_wrap(x):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_encode(x)
def first_stage_model_decode_wrap(z):
send_me_to_gpu(first_stage_model, None)
return first_stage_model_decode(z)
# remove three big modules, cond, first_stage, and unet from the model and then # remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU. # send the model to GPU. Then put modules back. the modules will be in CPU.
...@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram): ...@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models # register hooks for those the first two models
sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x) sd_model.first_stage_model.encode = first_stage_model_encode_wrap
sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z) sd_model.first_stage_model.decode = first_stage_model_decode_wrap
parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
if use_medvram: if use_medvram:
......
...@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None: if p.scripts is not None:
p.scripts.postprocess(p, res) p.scripts.postprocess(p, res)
p.sd_model = None
p.sampler = None
return res return res
......
...@@ -94,6 +94,10 @@ class StableDiffusionModelHijack: ...@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes: if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
self.layers = None
self.circular_enabled = False
self.clip = None
def apply_circular(self, enable): def apply_circular(self, enable):
if self.circular_enabled == enable: if self.circular_enabled == enable:
return return
......
import collections import collections
import os.path import os.path
import sys import sys
import gc
from collections import namedtuple from collections import namedtuple
import torch import torch
import re import re
...@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None): ...@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
if checkpoint_info.config != shared.cmd_opts.config: if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}") print(f"Loading config from: {checkpoint_info.config}")
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
gc.collect()
devices.torch_gc()
sd_config = OmegaConf.load(checkpoint_info.config) sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info): if should_hijack_inpainting(checkpoint_info):
...@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None): ...@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack() do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model) sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info) load_model_weights(sd_model, checkpoint_info)
...@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None): ...@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
return sd_model return sd_model
def reload_model_weights(sd_model, info=None): def reload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint() checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model
if sd_model.sd_model_checkpoint == checkpoint_info.filename: if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
del sd_model
checkpoints_loaded.clear() checkpoints_loaded.clear()
load_model(checkpoint_info) load_model(checkpoint_info)
return shared.sd_model return shared.sd_model
......
...@@ -77,7 +77,7 @@ def initialize(): ...@@ -77,7 +77,7 @@ def initialize():
modules.scripts.load_scripts() modules.scripts.load_scripts()
modules.sd_models.load_model() modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment