Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
S
Stable Diffusion Webui
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Stable Diffusion Webui
Commits
07be13ca
Commit
07be13ca
authored
Aug 01, 2023
by
AUTOMATIC1111
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add metadata to checkpoint merger
parent
6d3a0c95
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
9 deletions
+52
-9
modules/extras.py
modules/extras.py
+33
-6
modules/sd_models.py
modules/sd_models.py
+1
-1
modules/ui_checkpoint_merger.py
modules/ui_checkpoint_merger.py
+18
-2
No files found.
modules/extras.py
View file @
07be13ca
...
...
@@ -7,7 +7,7 @@ import json
import
torch
import
tqdm
from
modules
import
shared
,
images
,
sd_models
,
sd_vae
,
sd_models_config
from
modules
import
shared
,
images
,
sd_models
,
sd_vae
,
sd_models_config
,
errors
from
modules.ui_common
import
plaintext_to_html
import
gradio
as
gr
import
safetensors.torch
...
...
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
return
tensor
def
run_modelmerger
(
id_task
,
primary_model_name
,
secondary_model_name
,
tertiary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
,
checkpoint_format
,
config_source
,
bake_in_vae
,
discard_weights
,
save_metadata
):
def
read_metadata
(
primary_model_name
,
secondary_model_name
,
tertiary_model_name
):
metadata
=
{}
for
checkpoint_name
in
[
primary_model_name
,
secondary_model_name
,
tertiary_model_name
]:
checkpoint_info
=
sd_models
.
checkpoints_list
.
get
(
checkpoint_name
,
None
)
if
checkpoint_info
is
None
:
continue
metadata
.
update
(
checkpoint_info
.
metadata
)
return
json
.
dumps
(
metadata
,
indent
=
4
,
ensure_ascii
=
False
)
def
run_modelmerger
(
id_task
,
primary_model_name
,
secondary_model_name
,
tertiary_model_name
,
interp_method
,
multiplier
,
save_as_half
,
custom_name
,
checkpoint_format
,
config_source
,
bake_in_vae
,
discard_weights
,
save_metadata
,
add_merge_recipe
,
copy_metadata_fields
,
metadata_json
):
shared
.
state
.
begin
(
job
=
"model-merge"
)
def
fail
(
message
):
...
...
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared
.
state
.
textinfo
=
"Saving"
print
(
f
"Saving to {output_modelname}..."
)
metadata
=
None
metadata
=
{}
if
save_metadata
and
copy_metadata_fields
:
if
primary_model_info
:
metadata
.
update
(
primary_model_info
.
metadata
)
if
secondary_model_info
:
metadata
.
update
(
secondary_model_info
.
metadata
)
if
tertiary_model_info
:
metadata
.
update
(
tertiary_model_info
.
metadata
)
if
save_metadata
:
metadata
=
{
"format"
:
"pt"
}
try
:
metadata
.
update
(
json
.
loads
(
metadata_json
))
except
Exception
as
e
:
errors
.
display
(
e
,
"readin metadata from json"
)
metadata
[
"format"
]
=
"pt"
if
save_metadata
and
add_merge_recipe
:
merge_recipe
=
{
"type"
:
"webui"
,
# indicate this model was merged with webui's built-in merger
"primary_model_hash"
:
primary_model_info
.
sha256
,
...
...
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
"is_inpainting"
:
result_is_inpainting_model
,
"is_instruct_pix2pix"
:
result_is_instruct_pix2pix_model
}
metadata
[
"sd_merge_recipe"
]
=
json
.
dumps
(
merge_recipe
)
sd_merge_models
=
{}
...
...
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
if
tertiary_model_info
:
add_model_metadata
(
tertiary_model_info
)
metadata
[
"sd_merge_recipe"
]
=
json
.
dumps
(
merge_recipe
)
metadata
[
"sd_merge_models"
]
=
json
.
dumps
(
sd_merge_models
)
_
,
extension
=
os
.
path
.
splitext
(
output_modelname
)
if
extension
.
lower
()
==
".safetensors"
:
safetensors
.
torch
.
save_file
(
theta_0
,
output_modelname
,
metadata
=
metadata
)
safetensors
.
torch
.
save_file
(
theta_0
,
output_modelname
,
metadata
=
metadata
if
len
(
metadata
)
>
0
else
None
)
else
:
torch
.
save
(
theta_0
,
output_modelname
)
...
...
modules/sd_models.py
View file @
07be13ca
...
...
@@ -85,7 +85,7 @@ class CheckpointInfo:
if
self
.
shorthash
not
in
self
.
ids
:
self
.
ids
+=
[
self
.
shorthash
,
self
.
sha256
,
f
'{self.name} [{self.shorthash}]'
]
checkpoints_list
.
pop
(
self
.
title
)
checkpoints_list
.
pop
(
self
.
title
,
None
)
self
.
title
=
f
'{self.name} [{self.shorthash}]'
self
.
register
()
...
...
modules/ui_checkpoint_merger.py
View file @
07be13ca
...
...
@@ -51,7 +51,6 @@ class UiCheckpointMerger:
with
FormRow
():
self
.
checkpoint_format
=
gr
.
Radio
(
choices
=
[
"ckpt"
,
"safetensors"
],
value
=
"safetensors"
,
label
=
"Checkpoint format"
,
elem_id
=
"modelmerger_checkpoint_format"
)
self
.
save_as_half
=
gr
.
Checkbox
(
value
=
False
,
label
=
"Save as float16"
,
elem_id
=
"modelmerger_save_as_half"
)
self
.
save_metadata
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Save metadata (.safetensors only)"
,
elem_id
=
"modelmerger_save_metadata"
)
with
FormRow
():
with
gr
.
Column
():
...
...
@@ -65,16 +64,30 @@ class UiCheckpointMerger:
with
FormRow
():
self
.
discard_weights
=
gr
.
Textbox
(
value
=
""
,
label
=
"Discard weights with matching name"
,
elem_id
=
"modelmerger_discard_weights"
)
with
gr
.
Row
():
with
gr
.
Accordion
(
"Metadata"
,
open
=
False
)
as
metadata_editor
:
with
FormRow
():
self
.
save_metadata
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Save metadata"
,
elem_id
=
"modelmerger_save_metadata"
)
self
.
add_merge_recipe
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Add merge recipe metadata"
,
elem_id
=
"modelmerger_add_recipe"
)
self
.
copy_metadata_fields
=
gr
.
Checkbox
(
value
=
True
,
label
=
"Copy metadata from merged models"
,
elem_id
=
"modelmerger_copy_metadata"
)
self
.
metadata_json
=
gr
.
TextArea
(
'{}'
,
label
=
"Metadata in JSON format"
)
self
.
read_metadata
=
gr
.
Button
(
"Read metadata from selected checkpoints"
)
with
FormRow
():
self
.
modelmerger_merge
=
gr
.
Button
(
elem_id
=
"modelmerger_merge"
,
value
=
"Merge"
,
variant
=
'primary'
)
with
gr
.
Column
(
variant
=
'compact'
,
elem_id
=
"modelmerger_results_container"
):
with
gr
.
Group
(
elem_id
=
"modelmerger_results_panel"
):
self
.
modelmerger_result
=
gr
.
HTML
(
elem_id
=
"modelmerger_result"
,
show_label
=
False
)
self
.
metadata_editor
=
metadata_editor
self
.
blocks
=
modelmerger_interface
def
setup_ui
(
self
,
dummy_component
,
sd_model_checkpoint_component
):
self
.
checkpoint_format
.
change
(
lambda
fmt
:
gr
.
update
(
visible
=
fmt
==
'safetensors'
),
inputs
=
[
self
.
checkpoint_format
],
outputs
=
[
self
.
metadata_editor
],
show_progress
=
False
)
self
.
read_metadata
.
click
(
extras
.
read_metadata
,
inputs
=
[
self
.
primary_model_name
,
self
.
secondary_model_name
,
self
.
tertiary_model_name
],
outputs
=
[
self
.
metadata_json
])
self
.
modelmerger_merge
.
click
(
fn
=
lambda
:
''
,
inputs
=
[],
outputs
=
[
self
.
modelmerger_result
])
self
.
modelmerger_merge
.
click
(
fn
=
call_queue
.
wrap_gradio_gpu_call
(
modelmerger
,
extra_outputs
=
lambda
:
[
gr
.
update
()
for
_
in
range
(
4
)]),
...
...
@@ -93,6 +106,9 @@ class UiCheckpointMerger:
self
.
bake_in_vae
,
self
.
discard_weights
,
self
.
save_metadata
,
self
.
add_merge_recipe
,
self
.
copy_metadata_fields
,
self
.
metadata_json
,
],
outputs
=
[
self
.
primary_model_name
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment