Commit d682444e authored by AUTOMATIC's avatar AUTOMATIC

add option to select hypernetwork modules when creating

parent 5ba23cb4
...@@ -42,7 +42,7 @@ class Hypernetwork: ...@@ -42,7 +42,7 @@ class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None): def __init__(self, name=None, enable_sizes=None):
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
...@@ -50,7 +50,7 @@ class Hypernetwork: ...@@ -50,7 +50,7 @@ class Hypernetwork:
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
for size in [320, 640, 768, 1280]: for size in enable_sizes or [320, 640, 768, 1280]:
self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size)) self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))
def weights(self): def weights(self):
......
...@@ -9,11 +9,11 @@ from modules import sd_hijack, shared ...@@ -9,11 +9,11 @@ from modules import sd_hijack, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
def create_hypernetwork(name): def create_hypernetwork(name, enable_sizes):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt") fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists" assert not os.path.exists(fn), f"file {fn} already exists"
hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name) hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(name=name, enable_sizes=[int(x) for x in enable_sizes])
hypernet.save(fn) hypernet.save(fn)
shared.reload_hypernetworks() shared.reload_hypernetworks()
......
...@@ -1037,6 +1037,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1037,6 +1037,7 @@ def create_ui(wrap_gradio_gpu_call):
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>") gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new hypernetwork</p>")
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
...@@ -1114,6 +1115,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1114,6 +1115,7 @@ def create_ui(wrap_gradio_gpu_call):
fn=modules.hypernetworks.ui.create_hypernetwork, fn=modules.hypernetworks.ui.create_hypernetwork,
inputs=[ inputs=[
new_hypernetwork_name, new_hypernetwork_name,
new_hypernetwork_sizes,
], ],
outputs=[ outputs=[
train_hypernetwork_name, train_hypernetwork_name,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment