Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
3cd4fd51
Commit
3cd4fd51
authored
Jun 27, 2023
by
AUTOMATIC1111
Committed by
GitHub
Jun 27, 2023
Browse files
Options
Browse Files
Download
Plain Diff
Merge pull request #10823 from akx/model-loady
Upscaler model loading cleanup
parents
d4f9250c
2667f47f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
101 additions
and
91 deletions
+101
-91
extensions-builtin/LDSR/scripts/ldsr_model.py
extensions-builtin/LDSR/scripts/ldsr_model.py
+8
-12
extensions-builtin/ScuNET/scripts/scunet_model.py
extensions-builtin/ScuNET/scripts/scunet_model.py
+11
-15
extensions-builtin/SwinIR/scripts/swinir_model.py
extensions-builtin/SwinIR/scripts/swinir_model.py
+29
-28
modules/esrgan_model.py
modules/esrgan_model.py
+10
-13
modules/gfpgan_model.py
modules/gfpgan_model.py
+1
-1
modules/modelloader.py
modules/modelloader.py
+27
-4
modules/realesrgan_model.py
modules/realesrgan_model.py
+15
-18
No files found.
extensions-builtin/LDSR/scripts/ldsr_model.py
View file @
3cd4fd51
import
os
from
basicsr.utils.download_util
import
load_file_from_url
from
modules.modelloader
import
load_file_from_url
from
modules.upscaler
import
Upscaler
,
UpscalerData
from
ldsr_model_arch
import
LDSR
from
modules
import
shared
,
script_callbacks
,
errors
...
...
@@ -43,20 +42,17 @@ class UpscalerLDSR(Upscaler):
if
local_safetensors_path
is
not
None
and
os
.
path
.
exists
(
local_safetensors_path
):
model
=
local_safetensors_path
else
:
model
=
local_ckpt_path
if
local_ckpt_path
is
not
None
else
load_file_from_url
(
url
=
self
.
model_url
,
model_dir
=
self
.
model_download_path
,
file_name
=
"model.ckpt"
,
progress
=
True
)
model
=
local_ckpt_path
or
load_file_from_url
(
self
.
model_url
,
model_dir
=
self
.
model_download_path
,
file_name
=
"model.ckpt"
)
yaml
=
local_yaml_path
if
local_yaml_path
is
not
None
else
load_file_from_url
(
url
=
self
.
yaml_url
,
model_dir
=
self
.
model_download_path
,
file_name
=
"project.yaml"
,
progress
=
True
)
yaml
=
local_yaml_path
or
load_file_from_url
(
self
.
yaml_url
,
model_dir
=
self
.
model_download_path
,
file_name
=
"project.yaml"
)
try
:
return
LDSR
(
model
,
yaml
)
except
Exception
:
errors
.
report
(
"Error importing LDSR"
,
exc_info
=
True
)
return
None
return
LDSR
(
model
,
yaml
)
def
do_upscale
(
self
,
img
,
path
):
ldsr
=
self
.
load_model
(
path
)
if
ldsr
is
None
:
print
(
"NO LDSR!"
)
try
:
ldsr
=
self
.
load_model
(
path
)
except
Exception
:
errors
.
report
(
f
"Failed loading LDSR model {path}"
,
exc_info
=
True
)
return
img
ddim_steps
=
shared
.
opts
.
ldsr_steps
return
ldsr
.
super_resolution
(
img
,
ddim_steps
,
self
.
scale
)
...
...
extensions-builtin/ScuNET/scripts/scunet_model.py
View file @
3cd4fd51
import
os.path
import
sys
import
PIL.Image
...
...
@@ -6,12 +5,11 @@ import numpy as np
import
torch
from
tqdm
import
tqdm
from
basicsr.utils.download_util
import
load_file_from_url
import
modules.upscaler
from
modules
import
devices
,
modelloader
,
script_callbacks
,
errors
from
scunet_model_arch
import
SCUNet
as
net
from
scunet_model_arch
import
SCUNet
from
modules.modelloader
import
load_file_from_url
from
modules.shared
import
opts
...
...
@@ -28,7 +26,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
scalers
=
[]
add_model2
=
True
for
file
in
model_paths
:
if
"http"
in
file
:
if
file
.
startswith
(
"http"
)
:
name
=
self
.
model_name
else
:
name
=
modelloader
.
friendly_name
(
file
)
...
...
@@ -89,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
torch
.
cuda
.
empty_cache
()
model
=
self
.
load_model
(
selected_file
)
if
model
is
None
:
print
(
f
"ScuNET: Unable to load model from {selected_file}"
,
file
=
sys
.
stderr
)
try
:
model
=
self
.
load_model
(
selected_file
)
except
Exception
as
e
:
print
(
f
"ScuNET: Unable to load model from {selected_file}: {e}"
,
file
=
sys
.
stderr
)
return
img
device
=
devices
.
get_device_for
(
'scunet'
)
...
...
@@ -119,15 +118,12 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
def
load_model
(
self
,
path
:
str
):
device
=
devices
.
get_device_for
(
'scunet'
)
if
"http"
in
path
:
filename
=
load_file_from_url
(
url
=
self
.
model_url
,
model_dir
=
self
.
model_download_path
,
file_name
=
"
%
s.pth"
%
self
.
name
,
progress
=
True
)
if
path
.
startswith
(
"http"
):
# TODO: this doesn't use `path` at all?
filename
=
load_file_from_url
(
self
.
model_url
,
model_dir
=
self
.
model_download_path
,
file_name
=
f
"{self.name}.pth"
)
else
:
filename
=
path
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
self
.
model_path
,
filename
))
or
filename
is
None
:
print
(
f
"ScuNET: Unable to load model from {filename}"
,
file
=
sys
.
stderr
)
return
None
model
=
net
(
in_nc
=
3
,
config
=
[
4
,
4
,
4
,
4
,
4
,
4
,
4
],
dim
=
64
)
model
=
SCUNet
(
in_nc
=
3
,
config
=
[
4
,
4
,
4
,
4
,
4
,
4
,
4
],
dim
=
64
)
model
.
load_state_dict
(
torch
.
load
(
filename
),
strict
=
True
)
model
.
eval
()
for
_
,
v
in
model
.
named_parameters
():
...
...
extensions-builtin/SwinIR/scripts/swinir_model.py
View file @
3cd4fd51
import
o
s
import
sy
s
import
numpy
as
np
import
torch
from
PIL
import
Image
from
basicsr.utils.download_util
import
load_file_from_url
from
tqdm
import
tqdm
from
modules
import
modelloader
,
devices
,
script_callbacks
,
shared
from
modules.shared
import
opts
,
state
from
swinir_model_arch
import
SwinIR
as
net
from
swinir_model_arch_v2
import
Swin2SR
as
net2
from
swinir_model_arch
import
SwinIR
from
swinir_model_arch_v2
import
Swin2SR
from
modules.upscaler
import
Upscaler
,
UpscalerData
SWINIR_MODEL_URL
=
"https://github.com/JingyunLiang/SwinIR/releases/download/v0.0/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR-L_x4_GAN.pth"
device_swinir
=
devices
.
get_device_for
(
'swinir'
)
...
...
@@ -19,16 +19,14 @@ device_swinir = devices.get_device_for('swinir')
class
UpscalerSwinIR
(
Upscaler
):
def
__init__
(
self
,
dirname
):
self
.
name
=
"SwinIR"
self
.
model_url
=
"https://github.com/JingyunLiang/SwinIR/releases/download/v0.0"
\
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR"
\
"-L_x4_GAN.pth "
self
.
model_url
=
SWINIR_MODEL_URL
self
.
model_name
=
"SwinIR 4x"
self
.
user_path
=
dirname
super
()
.
__init__
()
scalers
=
[]
model_files
=
self
.
find_models
(
ext_filter
=
[
".pt"
,
".pth"
])
for
model
in
model_files
:
if
"http"
in
model
:
if
model
.
startswith
(
"http"
)
:
name
=
self
.
model_name
else
:
name
=
modelloader
.
friendly_name
(
model
)
...
...
@@ -37,8 +35,10 @@ class UpscalerSwinIR(Upscaler):
self
.
scalers
=
scalers
def
do_upscale
(
self
,
img
,
model_file
):
model
=
self
.
load_model
(
model_file
)
if
model
is
None
:
try
:
model
=
self
.
load_model
(
model_file
)
except
Exception
as
e
:
print
(
f
"Failed loading SwinIR model {model_file}: {e}"
,
file
=
sys
.
stderr
)
return
img
model
=
model
.
to
(
device_swinir
,
dtype
=
devices
.
dtype
)
img
=
upscale
(
img
,
model
)
...
...
@@ -49,30 +49,31 @@ class UpscalerSwinIR(Upscaler):
return
img
def
load_model
(
self
,
path
,
scale
=
4
):
if
"http"
in
path
:
dl_name
=
"
%
s
%
s"
%
(
self
.
model_name
.
replace
(
" "
,
"_"
),
".pth"
)
filename
=
load_file_from_url
(
url
=
path
,
model_dir
=
self
.
model_download_path
,
file_name
=
dl_name
,
progress
=
True
)
if
path
.
startswith
(
"http"
):
filename
=
modelloader
.
load_file_from_url
(
url
=
path
,
model_dir
=
self
.
model_download_path
,
file_name
=
f
"{self.model_name.replace(' ', '_')}.pth"
,
)
else
:
filename
=
path
if
filename
is
None
or
not
os
.
path
.
exists
(
filename
):
return
None
if
filename
.
endswith
(
".v2.pth"
):
model
=
net2
(
upscale
=
scale
,
in_chans
=
3
,
img_size
=
64
,
window_size
=
8
,
img_range
=
1.0
,
depths
=
[
6
,
6
,
6
,
6
,
6
,
6
],
embed_dim
=
180
,
num_heads
=
[
6
,
6
,
6
,
6
,
6
,
6
],
mlp_ratio
=
2
,
upsampler
=
"nearest+conv"
,
resi_connection
=
"1conv"
,
model
=
Swin2SR
(
upscale
=
scale
,
in_chans
=
3
,
img_size
=
64
,
window_size
=
8
,
img_range
=
1.0
,
depths
=
[
6
,
6
,
6
,
6
,
6
,
6
],
embed_dim
=
180
,
num_heads
=
[
6
,
6
,
6
,
6
,
6
,
6
],
mlp_ratio
=
2
,
upsampler
=
"nearest+conv"
,
resi_connection
=
"1conv"
,
)
params
=
None
else
:
model
=
net
(
model
=
SwinIR
(
upscale
=
scale
,
in_chans
=
3
,
img_size
=
64
,
...
...
modules/esrgan_model.py
View file @
3cd4fd51
import
o
s
import
sy
s
import
numpy
as
np
import
torch
from
PIL
import
Image
from
basicsr.utils.download_util
import
load_file_from_url
import
modules.esrgan_model_arch
as
arch
from
modules
import
modelloader
,
images
,
devices
from
modules.upscaler
import
Upscaler
,
UpscalerData
from
modules.shared
import
opts
from
modules.upscaler
import
Upscaler
,
UpscalerData
def
mod2normal
(
state_dict
):
...
...
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler):
scaler_data
=
UpscalerData
(
self
.
model_name
,
self
.
model_url
,
self
,
4
)
scalers
.
append
(
scaler_data
)
for
file
in
model_paths
:
if
"http"
in
file
:
if
file
.
startswith
(
"http"
)
:
name
=
self
.
model_name
else
:
name
=
modelloader
.
friendly_name
(
file
)
...
...
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler):
self
.
scalers
.
append
(
scaler_data
)
def
do_upscale
(
self
,
img
,
selected_model
):
model
=
self
.
load_model
(
selected_model
)
if
model
is
None
:
try
:
model
=
self
.
load_model
(
selected_model
)
except
Exception
as
e
:
print
(
f
"Unable to load ESRGAN model {selected_model}: {e}"
,
file
=
sys
.
stderr
)
return
img
model
.
to
(
devices
.
device_esrgan
)
img
=
esrgan_upscale
(
model
,
img
)
return
img
def
load_model
(
self
,
path
:
str
):
if
"http"
in
path
:
filename
=
load_file_from_url
(
if
path
.
startswith
(
"http"
):
# TODO: this doesn't use `path` at all?
filename
=
modelloader
.
load_file_from_url
(
url
=
self
.
model_url
,
model_dir
=
self
.
model_download_path
,
file_name
=
f
"{self.model_name}.pth"
,
progress
=
True
,
)
else
:
filename
=
path
if
not
os
.
path
.
exists
(
filename
)
or
filename
is
None
:
print
(
f
"Unable to load {self.model_path} from {filename}"
)
return
None
state_dict
=
torch
.
load
(
filename
,
map_location
=
'cpu'
if
devices
.
device_esrgan
.
type
==
'mps'
else
None
)
...
...
modules/gfpgan_model.py
View file @
3cd4fd51
...
...
@@ -25,7 +25,7 @@ def gfpgann():
return
None
models
=
modelloader
.
load_models
(
model_path
,
model_url
,
user_path
,
ext_filter
=
"GFPGAN"
)
if
len
(
models
)
==
1
and
"http"
in
models
[
0
]
:
if
len
(
models
)
==
1
and
models
[
0
]
.
startswith
(
"http"
)
:
model_file
=
models
[
0
]
elif
len
(
models
)
!=
0
:
latest_file
=
max
(
models
,
key
=
os
.
path
.
getctime
)
...
...
modules/modelloader.py
View file @
3cd4fd51
from
__future__
import
annotations
import
os
import
shutil
import
importlib
...
...
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from
modules.paths
import
script_path
,
models_path
def
load_file_from_url
(
url
:
str
,
*
,
model_dir
:
str
,
progress
:
bool
=
True
,
file_name
:
str
|
None
=
None
,
)
->
str
:
"""Download a file from `url` into `model_dir`, using the file present if possible.
Returns the path to the downloaded file.
"""
os
.
makedirs
(
model_dir
,
exist_ok
=
True
)
if
not
file_name
:
parts
=
urlparse
(
url
)
file_name
=
os
.
path
.
basename
(
parts
.
path
)
cached_file
=
os
.
path
.
abspath
(
os
.
path
.
join
(
model_dir
,
file_name
))
if
not
os
.
path
.
exists
(
cached_file
):
print
(
f
'Downloading: "{url}" to {cached_file}
\n
'
)
from
torch.hub
import
download_url_to_file
download_url_to_file
(
url
,
cached_file
,
progress
=
progress
)
return
cached_file
def
load_models
(
model_path
:
str
,
model_url
:
str
=
None
,
command_path
:
str
=
None
,
ext_filter
=
None
,
download_name
=
None
,
ext_blacklist
=
None
)
->
list
:
"""
A one-and done loader to try finding the desired models in specified directories.
...
...
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if
model_url
is
not
None
and
len
(
output
)
==
0
:
if
download_name
is
not
None
:
from
basicsr.utils.download_util
import
load_file_from_url
dl
=
load_file_from_url
(
model_url
,
places
[
0
],
True
,
download_name
)
output
.
append
(
dl
)
output
.
append
(
load_file_from_url
(
model_url
,
model_dir
=
places
[
0
],
file_name
=
download_name
))
else
:
output
.
append
(
model_url
)
...
...
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def
friendly_name
(
file
:
str
):
if
"http"
in
file
:
if
file
.
startswith
(
"http"
)
:
file
=
urlparse
(
file
)
.
path
file
=
os
.
path
.
basename
(
file
)
...
...
modules/realesrgan_model.py
View file @
3cd4fd51
...
...
@@ -2,7 +2,6 @@ import os
import
numpy
as
np
from
PIL
import
Image
from
basicsr.utils.download_util
import
load_file_from_url
from
realesrgan
import
RealESRGANer
from
modules.upscaler
import
Upscaler
,
UpscalerData
...
...
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
if
not
self
.
enable
:
return
img
info
=
self
.
load_model
(
path
)
if
not
os
.
path
.
exists
(
info
.
local_data_path
):
print
(
f
"Unable to load RealESRGAN model: {info.name}"
)
try
:
info
=
self
.
load_model
(
path
)
except
Exception
:
errors
.
report
(
f
"Unable to load RealESRGAN model {path}"
,
exc_info
=
True
)
return
img
upsampler
=
RealESRGANer
(
...
...
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
return
image
def
load_model
(
self
,
path
):
try
:
info
=
next
(
iter
([
scaler
for
scaler
in
self
.
scalers
if
scaler
.
data_path
==
path
]),
None
)
if
info
is
None
:
print
(
f
"Unable to find model info: {path}"
)
return
None
if
info
.
local_data_path
.
startswith
(
"http"
):
info
.
local_data_path
=
load_file_from_url
(
url
=
info
.
data_path
,
model_dir
=
self
.
model_download_path
,
progress
=
True
)
return
info
except
Exception
:
errors
.
report
(
"Error making Real-ESRGAN models list"
,
exc_info
=
True
)
return
None
for
scaler
in
self
.
scalers
:
if
scaler
.
data_path
==
path
:
if
scaler
.
local_data_path
.
startswith
(
"http"
):
scaler
.
local_data_path
=
modelloader
.
load_file_from_url
(
scaler
.
data_path
,
model_dir
=
self
.
model_download_path
,
)
if
not
os
.
path
.
exists
(
scaler
.
local_data_path
):
raise
FileNotFoundError
(
f
"RealESRGAN data missing: {scaler.local_data_path}"
)
return
scaler
raise
ValueError
(
f
"Unable to find model info: {path}"
)
def
load_models
(
self
,
_
):
return
get_realesrgan_models
(
self
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment