Commit 91643f65 authored by William Moorehouse's avatar William Moorehouse

Add support for checkpoint merging

parent ca3e5519
...@@ -3,6 +3,8 @@ import os ...@@ -3,6 +3,8 @@ import os
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch
from modules import processing, shared, images, devices from modules import processing, shared, images, devices
from modules.shared import opts from modules.shared import opts
import modules.gfpgan_model import modules.gfpgan_model
...@@ -135,3 +137,25 @@ def run_pnginfo(image): ...@@ -135,3 +137,25 @@ def run_pnginfo(image):
info = f"<div><p>{message}<p></div>" info = f"<div><p>{message}<p></div>"
return '', geninfo, info return '', geninfo, info
def run_modelmerger(modelname_0, modelname_1, alpha):
model_0 = torch.load('models/' + modelname_0 + '.ckpt')
model_1 = torch.load('models/' + modelname_1 + '.ckpt')
theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict']
for key in theta_0.keys():
if 'model' in key and key in theta_1:
theta_0[key] = (1 - alpha) * theta_0[key] + alpha * theta_1[key]
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt';
torch.save(model_0, output_modelname)
return "<p>Model saved to " + output_modelname + "</p>"
...@@ -393,7 +393,7 @@ def setup_progressbar(progressbar, preview, id_part): ...@@ -393,7 +393,7 @@ def setup_progressbar(progressbar, preview, id_part):
) )
def create_ui(txt2img, img2img, run_extras, run_pnginfo): def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
...@@ -853,6 +853,31 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ...@@ -853,6 +853,31 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
outputs=[html, generation_info, html2], outputs=[html, generation_info, html2],
) )
with gr.Blocks() as modelmerger_interface:
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>/models</b> directory.</p>")
modelname_0 = gr.Textbox(elem_id="modelmerger_modelname_0", label="Model Name (to)")
modelname_1 = gr.Textbox(elem_id="modelmerger_modelname_1", label="Model Name (from)")
alpha = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Alpha', value=0.3)
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.HTML(elem_id="modelmerger_result")
submit.click(
fn=run_modelmerger,
inputs=[
modelname_0,
modelname_1,
alpha
],
outputs=[
submit_result,
]
)
def create_setting_component(key): def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
...@@ -950,6 +975,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): ...@@ -950,6 +975,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
(img2img_interface, "img2img", "img2img"), (img2img_interface, "img2img", "img2img"),
(extras_interface, "Extras", "extras"), (extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"), (pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(settings_interface, "Settings", "settings"), (settings_interface, "Settings", "settings"),
] ]
......
...@@ -85,7 +85,8 @@ def webui(): ...@@ -85,7 +85,8 @@ def webui():
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img), txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img), img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras), run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo run_pnginfo=modules.extras.run_pnginfo,
run_modelmerger=modules.extras.run_modelmerger
) )
demo.launch( demo.launch(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment