Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
f126986b
Commit
f126986b
authored
Nov 01, 2022
by
AUTOMATIC1111
Committed by
GitHub
Nov 01, 2022
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #4098 from jn-jairo/load-model
Unload sd_model before loading the other to solve issue #3449
parents
08744040
af758e97
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
34 additions
and
10 deletions
+34
-10
modules/lowvram.py
modules/lowvram.py
+13
-8
modules/processing.py
modules/processing.py
+3
-0
modules/sd_hijack.py
modules/sd_hijack.py
+4
-0
modules/sd_models.py
modules/sd_models.py
+13
-1
webui.py
webui.py
+1
-1
No files found.
modules/lowvram.py
View file @
f126986b
...
...
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
# see below for register_forward_pre_hook;
# first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
# useless here, and we just replace those methods
def
first_stage_model_encode_wrap
(
self
,
encoder
,
x
):
send_me_to_gpu
(
self
,
None
)
return
encoder
(
x
)
def
first_stage_model_decode_wrap
(
self
,
decoder
,
z
):
send_me_to_gpu
(
self
,
None
)
return
decoder
(
z
)
first_stage_model
=
sd_model
.
first_stage_model
first_stage_model_encode
=
sd_model
.
first_stage_model
.
encode
first_stage_model_decode
=
sd_model
.
first_stage_model
.
decode
def
first_stage_model_encode_wrap
(
x
):
send_me_to_gpu
(
first_stage_model
,
None
)
return
first_stage_model_encode
(
x
)
def
first_stage_model_decode_wrap
(
z
):
send_me_to_gpu
(
first_stage_model
,
None
)
return
first_stage_model_decode
(
z
)
# remove three big modules, cond, first_stage, and unet from the model and then
# send the model to GPU. Then put modules back. the modules will be in CPU.
...
...
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
# register hooks for the first two models
sd_model
.
cond_stage_model
.
transformer
.
register_forward_pre_hook
(
send_me_to_gpu
)
sd_model
.
first_stage_model
.
register_forward_pre_hook
(
send_me_to_gpu
)
sd_model
.
first_stage_model
.
encode
=
lambda
x
,
en
=
sd_model
.
first_stage_model
.
encode
:
first_stage_model_encode_wrap
(
sd_model
.
first_stage_model
,
en
,
x
)
sd_model
.
first_stage_model
.
decode
=
lambda
z
,
de
=
sd_model
.
first_stage_model
.
decode
:
first_stage_model_decode_wrap
(
sd_model
.
first_stage_model
,
de
,
z
)
sd_model
.
first_stage_model
.
encode
=
first_stage_model_encode_wrap
sd_model
.
first_stage_model
.
decode
=
first_stage_model_decode_wrap
parents
[
sd_model
.
cond_stage_model
.
transformer
]
=
sd_model
.
cond_stage_model
if
use_medvram
:
...
...
modules/processing.py
View file @
f126986b
...
...
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if
p
.
scripts
is
not
None
:
p
.
scripts
.
postprocess
(
p
,
res
)
p
.
sd_model
=
None
p
.
sampler
=
None
return
res
...
...
modules/sd_hijack.py
View file @
f126986b
...
...
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
if
type
(
model_embeddings
.
token_embedding
)
==
EmbeddingsWithFixes
:
model_embeddings
.
token_embedding
=
model_embeddings
.
token_embedding
.
wrapped
self
.
layers
=
None
self
.
circular_enabled
=
False
self
.
clip
=
None
def
apply_circular
(
self
,
enable
):
if
self
.
circular_enabled
==
enable
:
return
...
...
modules/sd_models.py
View file @
f126986b
import
collections
import
os.path
import
sys
import
gc
from
collections
import
namedtuple
import
torch
import
re
...
...
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
if
checkpoint_info
.
config
!=
shared
.
cmd_opts
.
config
:
print
(
f
"Loading config from: {checkpoint_info.config}"
)
if
shared
.
sd_model
:
sd_hijack
.
model_hijack
.
undo_hijack
(
shared
.
sd_model
)
shared
.
sd_model
=
None
gc
.
collect
()
devices
.
torch_gc
()
sd_config
=
OmegaConf
.
load
(
checkpoint_info
.
config
)
if
should_hijack_inpainting
(
checkpoint_info
):
...
...
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
checkpoint_info
=
checkpoint_info
.
_replace
(
config
=
checkpoint_info
.
config
.
replace
(
".yaml"
,
"-inpainting.yaml"
))
do_inpainting_hijack
()
sd_model
=
instantiate_from_config
(
sd_config
.
model
)
load_model_weights
(
sd_model
,
checkpoint_info
)
...
...
@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
return
sd_model
def
reload_model_weights
(
sd_model
,
info
=
None
):
def
reload_model_weights
(
sd_model
=
None
,
info
=
None
):
from
modules
import
lowvram
,
devices
,
sd_hijack
checkpoint_info
=
info
or
select_checkpoint
()
if
not
sd_model
:
sd_model
=
shared
.
sd_model
if
sd_model
.
sd_model_checkpoint
==
checkpoint_info
.
filename
:
return
if
sd_model
.
sd_checkpoint_info
.
config
!=
checkpoint_info
.
config
or
should_hijack_inpainting
(
checkpoint_info
)
!=
should_hijack_inpainting
(
sd_model
.
sd_checkpoint_info
):
del
sd_model
checkpoints_loaded
.
clear
()
load_model
(
checkpoint_info
)
return
shared
.
sd_model
...
...
webui.py
View file @
f126986b
...
...
@@ -78,7 +78,7 @@ def initialize():
modules
.
scripts
.
load_scripts
()
modules
.
sd_models
.
load_model
()
shared
.
opts
.
onchange
(
"sd_model_checkpoint"
,
wrap_queued_call
(
lambda
:
modules
.
sd_models
.
reload_model_weights
(
shared
.
sd_model
)))
shared
.
opts
.
onchange
(
"sd_model_checkpoint"
,
wrap_queued_call
(
lambda
:
modules
.
sd_models
.
reload_model_weights
()))
shared
.
opts
.
onchange
(
"sd_hypernetwork"
,
wrap_queued_call
(
lambda
:
modules
.
hypernetworks
.
hypernetwork
.
load_hypernetwork
(
shared
.
opts
.
sd_hypernetwork
)))
shared
.
opts
.
onchange
(
"sd_hypernetwork_strength"
,
modules
.
hypernetworks
.
hypernetwork
.
apply_strength
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment