Commit 0f703af6 authored by Connum's avatar Connum

Merge branch 'master' into img2img-alt-force-method

parents 4f5f7865 ada901ed
...@@ -4,6 +4,7 @@ import numpy as np ...@@ -4,6 +4,7 @@ import numpy as np
from PIL import Image from PIL import Image
import torch import torch
import tqdm
from modules import processing, shared, images, devices from modules import processing, shared, images, devices
from modules.shared import opts from modules.shared import opts
...@@ -149,28 +150,45 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount): ...@@ -149,28 +150,45 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
alpha = alpha * alpha * (3 - (2 * alpha)) alpha = alpha * alpha * (3 - (2 * alpha))
return theta0 + ((theta1 - theta0) * alpha) return theta0 + ((theta1 - theta0) * alpha)
model_0 = torch.load('models/' + modelname_0 + '.ckpt') if os.path.exists(modelname_0):
model_1 = torch.load('models/' + modelname_1 + '.ckpt') model0_filename = modelname_0
modelname_0 = os.path.splitext(os.path.basename(modelname_0))[0]
else:
model0_filename = 'models/' + modelname_0 + '.ckpt'
if os.path.exists(modelname_1):
model1_filename = modelname_1
modelname_1 = os.path.splitext(os.path.basename(modelname_1))[0]
else:
model1_filename = 'models/' + modelname_1 + '.ckpt'
print(f"Loading {model0_filename}...")
model_0 = torch.load(model0_filename, map_location='cpu')
print(f"Loading {model1_filename}...")
model_1 = torch.load(model1_filename, map_location='cpu')
theta_0 = model_0['state_dict'] theta_0 = model_0['state_dict']
theta_1 = model_1['state_dict'] theta_1 = model_1['state_dict']
theta_func = weighted_sum
theta_funcs = {
if interp_method == "Weighted Sum": "Weighted Sum": weighted_sum,
theta_func = weighted_sum "Sigmoid": sigmoid,
if interp_method == "Sigmoid": }
theta_func = sigmoid theta_func = theta_funcs[interp_method]
for key in theta_0.keys(): print(f"Merging...")
for key in tqdm.tqdm(theta_0.keys()):
if 'model' in key and key in theta_1: if 'model' in key and key in theta_1:
theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount) theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
for key in theta_1.keys(): for key in theta_1.keys():
if 'model' in key and key not in theta_0: if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key] theta_0[key] = theta_1[key]
output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'; output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'
print(f"Saving to {output_modelname}...")
torch.save(model_0, output_modelname) torch.save(model_0, output_modelname)
return "<p>Model saved to " + output_modelname + "</p>" print(f"Checkpoint saved.")
return "Checkpoint saved to " + output_modelname
...@@ -49,6 +49,7 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None ...@@ -49,6 +49,7 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None
css_hide_progressbar = """ css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; } .wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; } .progress-bar { display:none!important; }
.meta-text { display:none!important; } .meta-text { display:none!important; }
""" """
...@@ -865,7 +866,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -865,7 +866,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary') submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
submit_result = gr.HTML(elem_id="modelmerger_result") submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
submit.click( submit.click(
fn=run_modelmerger, fn=run_modelmerger,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment