Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
af758e97
Commit
af758e97
authored
Nov 01, 2022
by
Jairo Correa
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Unload sd_model before loading the other
parent
5c9b3625
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
34 additions
and
10 deletions
+34
-10
modules/lowvram.py
modules/lowvram.py
+13
-8
modules/processing.py
modules/processing.py
+3
-0
modules/sd_hijack.py
modules/sd_hijack.py
+4
-0
modules/sd_models.py
modules/sd_models.py
+13
-1
webui.py
webui.py
+1
-1
No files found.
modules/lowvram.py
View file @
af758e97
...
...
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
def
first_stage_model_encode_wrap
(
self
,
encoder
,
x
):
send_me_to_gpu
(
self
,
None
)
return
encoder
(
x
)
def
first_stage_model_decode_wrap
(
self
,
decoder
,
z
):
send_me_to_gpu
(
self
,
None
)
return
decoder
(
z
)
first_stage_model
=
sd_model
.
first_stage_model
first_stage_model_encode
=
sd_model
.
first_stage_model
.
encode
first_stage_model_decode
=
sd_model
.
first_stage_model
.
decode
def
first_stage_model_encode_wrap
(
x
):
send_me_to_gpu
(
first_stage_model
,
None
)
return
first_stage_model_encode
(
x
)
def
first_stage_model_decode_wrap
(
z
):
send_me_to_gpu
(
first_stage_model
,
None
)
return
first_stage_model_decode
(
z
)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
...
...
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for those the first two models
sd_model
.
cond_stage_model
.
transformer
.
register_forward_pre_hook
(
send_me_to_gpu
)
sd_model
.
first_stage_model
.
register_forward_pre_hook
(
send_me_to_gpu
)
sd_model
.
first_stage_model
.
encode
=
lambda
x
,
en
=
sd_model
.
first_stage_model
.
encode
:
first_stage_model_encode_wrap
(
sd_model
.
first_stage_model
,
en
,
x
)
sd_model
.
first_stage_model
.
decode
=
lambda
z
,
de
=
sd_model
.
first_stage_model
.
decode
:
first_stage_model_decode_wrap
(
sd_model
.
first_stage_model
,
de
,
z
)
sd_model
.
first_stage_model
.
encode
=
first_stage_model_encode_wrap
sd_model
.
first_stage_model
.
decode
=
first_stage_model_decode_wrap
parents
[
sd_model
.
cond_stage_model
.
transformer
]
=
sd_model
.
cond_stage_model
if
use_medvram
:
...
...
modules/processing.py
View file @
af758e97
...
...
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if
p
.
scripts
is
not
None
:
p
.
scripts
.
postprocess
(
p
,
res
)
p
.
sd_model
=
None
p
.
sampler
=
None
return
res
...
...
modules/sd_hijack.py
View file @
af758e97
...
...
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if
type
(
model_embeddings
.
token_embedding
)
==
EmbeddingsWithFixes
:
model_embeddings
.
token_embedding
=
model_embeddings
.
token_embedding
.
wrapped
self
.
layers
=
None
self
.
circular_enabled
=
False
self
.
clip
=
None
def
apply_circular
(
self
,
enable
):
if
self
.
circular_enabled
==
enable
:
return
...
...
modules/sd_models.py
View file @
af758e97
import
collections
import
os.path
import
sys
import
gc
from
collections
import
namedtuple
import
torch
import
re
...
...
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
if
checkpoint_info
.
config
!=
shared
.
cmd_opts
.
config
:
print
(
f
"Loading config from: {checkpoint_info.config}"
)
if
shared
.
sd_model
:
sd_hijack
.
model_hijack
.
undo_hijack
(
shared
.
sd_model
)
shared
.
sd_model
=
None
gc
.
collect
()
devices
.
torch_gc
()
sd_config
=
OmegaConf
.
load
(
checkpoint_info
.
config
)
if
should_hijack_inpainting
(
checkpoint_info
):
...
...
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
checkpoint_info
=
checkpoint_info
.
_replace
(
config
=
checkpoint_info
.
config
.
replace
(
".yaml"
,
"-inpainting.yaml"
))
do_inpainting_hijack
()
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
load_model_weights
(
sd_model
,
checkpoint_info
)
...
...
@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
return
sd_model
def
reload_model_weights
(
sd_model
,
info
=
None
):
def
reload_model_weights
(
sd_model
=
None
,
info
=
None
):
from
modules
import
lowvram
,
devices
,
sd_hijack
checkpoint_info
=
info
or
select_checkpoint
()
if
not
sd_model
:
sd_model
=
shared
.
sd_model
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
:
return
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
del
sd_model
checkpoints_loaded
.
clear
()
load_model
(
checkpoint_info
)
return
shared
.
sd_model
...
...
webui.py
View file @
af758e97
...
...
@@ -77,7 +77,7 @@ def initialize():
modules
.
scripts
.
load_scripts
()
modules
.
sd_models
.
load_model
()
shared
.
opts
.
onchange
(
"sd_model_checkpoint"
,
wrap_queued_call
(
lambda
:
modules
.
sd_models
.
reload_model_weights
(
shared
.
sd_model
)))
shared
.
opts
.
onchange
(
"sd_model_checkpoint"
,
wrap_queued_call
(
lambda
:
modules
.
sd_models
.
reload_model_weights
()))
shared
.
opts
.
onchange
(
"sd_hypernetwork"
,
wrap_queued_call
(
lambda
:
modules
.
hypernetworks
.
hypernetwork
.
load_hypernetwork
(
shared
.
opts
.
sd_hypernetwork
)))
shared
.
opts
.
onchange
(
"sd_hypernetwork_strength"
,
modules
.
hypernetworks
.
hypernetwork
.
apply_strength
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment