novelai-storage / Stable Diffusion Webui / Commits

Commit f01682ee
authored Aug 15, 2023 by AUTOMATIC1111

store patches for Lora in a specialized module
parent 7327be97

Showing 4 changed files with 118 additions and 61 deletions (+118 -61)

extensions-builtin/Lora/lora_patches.py          +31   -0
extensions-builtin/Lora/networks.py              +18  -14
extensions-builtin/Lora/scripts/lora_script.py    +5  -47
modules/patches.py                               +64   -0
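In short, the commit replaces the Lora extension's ad-hoc monkey-patching, where the original torch.nn methods were stashed as extra attributes on the torch.nn module, with a reusable patch registry in modules/patches.py; the extension's own patches are grouped into a LoraPatches object stored at networks.originals. A rough before/after sketch of the pattern, simplified from the diff below and using only names that appear in it:

import torch

import networks

# before this commit: the original method was parked on torch.nn itself
torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
torch.nn.Linear.forward = networks.network_Linear_forward

# after this commit: modules/patches.py records the original, keyed by caller and (obj, field)
from modules import patches

original = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
patches.undo(__name__, torch.nn.Linear, 'forward')  # cleanly restores torch.nn.Linear.forward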
extensions-builtin/Lora/lora_patches.py (new file, mode 100644)
import torch

import networks
from modules import patches


class LoraPatches:
    def __init__(self):
        self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
        self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
        self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
        self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
        self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
        self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
        self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
        self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
        self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
        self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)

    def undo(self):
        self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
        self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
        self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
        self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
        self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
        self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
        self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
        self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
        self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
        self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
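Constructing a LoraPatches instance applies all of these patches at once (each patches.patch call installs the replacement and returns the original, which is kept on the instance), and undo() restores every original torch.nn method. The wiring in lora_script.py further down in this diff amounts to:

import networks
import lora_patches

networks.originals = lora_patches.LoraPatches()  # apply the patches at extension load

# ... later, when the script is unloaded:
networks.originals.undo()  # restore the original torch.nn methods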
extensions-builtin/Lora/networks.py
@@ -2,6 +2,7 @@ import logging
 import os
 import re
 
+import lora_patches
 import network
 import network_lora
 import network_hada
@@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
 def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Linear_forward_before_network)
+        return network_forward(self, input, originals.Linear_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.Linear_forward_before_network(self, input)
+    return originals.Linear_forward(self, input)
 
 
 def network_Linear_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Linear_load_state_dict(self, *args, **kwargs)
 
 
 def network_Conv2d_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+        return network_forward(self, input, originals.Conv2d_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.Conv2d_forward_before_network(self, input)
+    return originals.Conv2d_forward(self, input)
 
 
 def network_Conv2d_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Conv2d_load_state_dict(self, *args, **kwargs)
 
 
 def network_GroupNorm_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+        return network_forward(self, input, originals.GroupNorm_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.GroupNorm_forward_before_network(self, input)
+    return originals.GroupNorm_forward(self, input)
 
 
 def network_GroupNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
 
 
 def network_LayerNorm_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+        return network_forward(self, input, originals.LayerNorm_forward)
 
     network_apply_weights(self)
 
-    return torch.nn.LayerNorm_forward_before_network(self, input)
+    return originals.LayerNorm_forward(self, input)
 
 
 def network_LayerNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
 
 
 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)
 
-    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_forward(self, *args, **kwargs)
 
 
 def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)
 
-    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
 
 
 def list_available_networks():
@@ -552,6 +553,9 @@ def infotext_pasted(infotext, params):
     if added:
         params["Prompt"] += "\n" + "".join(added)
 
 
+originals: lora_patches.LoraPatches = None
+
+
 extra_network_lora = None
 
 available_networks = {}
extensions-builtin/Lora/scripts/lora_script.py
@@ -7,17 +7,14 @@ from fastapi import FastAPI
 import network
 import networks
 import lora  # noqa:F401
+import lora_patches
 import extra_networks_lora
 import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches
 
 
 def unload():
-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+    networks.originals.undo()
 
 
 def before_ui():
@@ -28,46 +25,7 @@ def before_ui():
     extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
 
 
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
-    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
-    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
-    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
-    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
-    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
-
-if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
-    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
-    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
-
-if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
-    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
-    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
-    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
-torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
-torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
-torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
 
 script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
modules/patches.py (new file, mode 100644)
from collections import defaultdict


def patch(key, obj, field, replacement):
    """Replaces a function in a module or a class.

    Also stores the original function in this module, possible to be retrieved via original(key, obj, field).
    If the function is already replaced by this caller (key), an exception is raised -- use undo() before that.

    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string
        replacement: the new function

    Returns:
        the original function
    """

    patch_key = (obj, field)
    if patch_key in originals[key]:
        raise RuntimeError(f"patch for {field} is already applied")

    original_func = getattr(obj, field)
    originals[key][patch_key] = original_func

    setattr(obj, field, replacement)

    return original_func


def undo(key, obj, field):
    """Undoes the replacement by the patch().

    If the function is not replaced, raises an exception.

    Arguments:
        key: identifying information for who is doing the replacement. You can use __name__.
        obj: the module or the class
        field: name of the function as a string

    Returns:
        Always None
    """

    patch_key = (obj, field)

    if patch_key not in originals[key]:
        raise RuntimeError(f"there is no patch for {field} to undo")

    original_func = originals[key].pop(patch_key)
    setattr(obj, field, original_func)

    return None


def original(key, obj, field):
    """Returns the original function for the patch created by the patch() function"""

    patch_key = (obj, field)

    return originals[key].get(patch_key, None)


originals = defaultdict(dict)
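A small usage sketch of the new module (the traced_linear_forward example below is hypothetical, not part of the commit): a caller patches a method under its own key, can reach the saved original through patches.original() while the patch is active, and restores it with patches.undo().

import torch

from modules import patches


def traced_linear_forward(self, input):
    # hypothetical replacement: delegate to the saved original forward
    return patches.original(__name__, torch.nn.Linear, 'forward')(self, input)


# install the replacement; the original is returned and remembered under (__name__, (torch.nn.Linear, 'forward'))
original_forward = patches.patch(__name__, torch.nn.Linear, 'forward', traced_linear_forward)

# patching the same (obj, field) again under the same key raises RuntimeError until undo() is called
patches.undo(__name__, torch.nn.Linear, 'forward')  # puts original_forward back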