Commit ffea9b15 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #3414 from discus0434/master

[Hypernetworks] Add a feature to use dropout / more activation functions
parents e3862501 6a4fa73a
import csv
import datetime import datetime
import glob import glob
import html import html
import os import os
import sys import sys
import traceback import traceback
import tqdm
import csv
import modules.textual_inversion.dataset
import torch import torch
import tqdm
from ldm.util import default
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat from einops import rearrange, repeat
import modules.textual_inversion.dataset from ldm.util import default
from modules import devices, processing, sd_models, shared
from modules.textual_inversion import textual_inversion from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0 multiplier = 1.0
activation_dict = {
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None): "relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
"swish": torch.nn.Hardswish,
}
def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
super().__init__() super().__init__()
assert layer_structure is not None, "layer_structure must not be None" assert layer_structure is not None, "layer_structure must not be None"
...@@ -31,20 +35,26 @@ class HypernetworkModule(torch.nn.Module): ...@@ -31,20 +35,26 @@ class HypernetworkModule(torch.nn.Module):
linears = [] linears = []
for i in range(len(layer_structure) - 1): for i in range(len(layer_structure) - 1):
# Add a fully-connected layer
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1]))) linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu": # Add an activation func
linears.append(torch.nn.ReLU()) if activation_func == "linear" or activation_func is None:
elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
elif activation_func == 'linear' or activation_func is None:
pass pass
elif activation_func in self.activation_dict:
linears.append(self.activation_dict[activation_func]())
else: else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}') raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
# Add layer normalization
if add_layer_norm: if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1]))) linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
# Add dropout expect last layer
if use_dropout and i < len(layer_structure) - 3:
linears.append(torch.nn.Dropout(p=0.3))
self.linear = torch.nn.Sequential(*linears) self.linear = torch.nn.Sequential(*linears)
if state_dict is not None: if state_dict is not None:
...@@ -93,7 +103,7 @@ class Hypernetwork: ...@@ -93,7 +103,7 @@ class Hypernetwork:
filename = None filename = None
name = None name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None): def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
self.filename = None self.filename = None
self.name = name self.name = name
self.layers = {} self.layers = {}
...@@ -101,13 +111,14 @@ class Hypernetwork: ...@@ -101,13 +111,14 @@ class Hypernetwork:
self.sd_checkpoint = None self.sd_checkpoint = None
self.sd_checkpoint_name = None self.sd_checkpoint_name = None
self.layer_structure = layer_structure self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
self.activation_func = activation_func self.activation_func = activation_func
self.add_layer_norm = add_layer_norm
self.use_dropout = use_dropout
for size in enable_sizes or []: for size in enable_sizes or []:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
) )
def weights(self): def weights(self):
...@@ -129,8 +140,9 @@ class Hypernetwork: ...@@ -129,8 +140,9 @@ class Hypernetwork:
state_dict['step'] = self.step state_dict['step'] = self.step
state_dict['name'] = self.name state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func state_dict['activation_func'] = self.activation_func
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['use_dropout'] = self.use_dropout
state_dict['sd_checkpoint'] = self.sd_checkpoint state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
...@@ -144,14 +156,15 @@ class Hypernetwork: ...@@ -144,14 +156,15 @@ class Hypernetwork:
state_dict = torch.load(filename, map_location='cpu') state_dict = torch.load(filename, map_location='cpu')
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1]) self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None) self.activation_func = state_dict.get('activation_func', None)
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.use_dropout = state_dict.get('use_dropout', False)
for size, sd in state_dict.items(): for size, sd in state_dict.items():
if type(size) == int: if type(size) == int:
self.layers[size] = ( self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func), HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.add_layer_norm, self.use_dropout),
) )
self.name = state_dict.get('name', self.name) self.name = state_dict.get('name', self.name)
...@@ -308,6 +321,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -308,6 +321,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
return hypernetwork, filename return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step) scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate) optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0 steps_without_grad = 0
......
...@@ -3,14 +3,13 @@ import os ...@@ -3,14 +3,13 @@ import os
import re import re
import gradio as gr import gradio as gr
import modules.textual_inversion.textual_inversion
import modules.textual_inversion.preprocess import modules.textual_inversion.preprocess
from modules import sd_hijack, shared, devices import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, add_layer_norm=False, activation_func=None): def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name. # Remove illegal characters from name.
name = "".join( x for x in name if (x.isalnum() or x in "._- ")) name = "".join( x for x in name if (x.isalnum() or x in "._- "))
...@@ -25,8 +24,9 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, ...@@ -25,8 +24,9 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
name=name, name=name,
enable_sizes=[int(x) for x in enable_sizes], enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure, layer_structure=layer_structure,
add_layer_norm=add_layer_norm,
activation_func=activation_func, activation_func=activation_func,
add_layer_norm=add_layer_norm,
use_dropout=use_dropout,
) )
hypernet.save(fn) hypernet.save(fn)
......
...@@ -5,19 +5,22 @@ import json ...@@ -5,19 +5,22 @@ import json
import math import math
import mimetypes import mimetypes
import os import os
import platform
import random import random
import subprocess as sp
import sys import sys
import tempfile import tempfile
import time import time
import traceback import traceback
import platform
import subprocess as sp
from functools import partial, reduce from functools import partial, reduce
import gradio as gr
import gradio.routes
import gradio.utils
import numpy as np import numpy as np
import piexif
import torch import torch
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
import piexif
import gradio as gr import gradio as gr
import gradio.utils import gradio.utils
...@@ -30,17 +33,21 @@ from modules.shared import opts, cmd_opts, restricted_opts ...@@ -30,17 +33,21 @@ from modules.shared import opts, cmd_opts, restricted_opts
if cmd_opts.deepdanbooru: if cmd_opts.deepdanbooru:
from modules.deepbooru import get_deepbooru_tags from modules.deepbooru import get_deepbooru_tags
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img import modules.codeformer_model
from modules.sd_hijack import model_hijack import modules.generation_parameters_copypaste
import modules.gfpgan_model
import modules.hypernetworks.ui
import modules.images_history as img_his
import modules.ldsr_model import modules.ldsr_model
import modules.scripts import modules.scripts
import modules.gfpgan_model import modules.shared as shared
import modules.codeformer_model
import modules.styles import modules.styles
import modules.generation_parameters_copypaste import modules.textual_inversion.ui
from modules import prompt_parser from modules import prompt_parser
from modules.images import save_image from modules.images import save_image
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
import modules.textual_inversion.ui import modules.textual_inversion.ui
import modules.hypernetworks.ui import modules.hypernetworks.ui
...@@ -1233,9 +1240,10 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1233,9 +1240,10 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_name = gr.Textbox(label="Name") new_hypernetwork_name = gr.Textbox(label="Name")
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"]) new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu", "elu", "swish"])
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout")
overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork") overwrite_old_hypernetwork = gr.Checkbox(value=False, label="Overwrite Old Hypernetwork")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row(): with gr.Row():
with gr.Column(scale=3): with gr.Column(scale=3):
...@@ -1285,7 +1293,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1285,7 +1293,7 @@ def create_ui(wrap_gradio_gpu_call):
with gr.Row(): with gr.Row():
embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005") embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005")
hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001") hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001")
batch_size = gr.Number(label='Batch size', value=1, precision=0) batch_size = gr.Number(label='Batch size', value=1, precision=0)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images") dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion") log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
...@@ -1335,8 +1343,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1335,8 +1343,9 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes, new_hypernetwork_sizes,
overwrite_old_hypernetwork, overwrite_old_hypernetwork,
new_hypernetwork_layer_structure, new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func, new_hypernetwork_activation_func,
new_hypernetwork_add_layer_norm,
new_hypernetwork_use_dropout
], ],
outputs=[ outputs=[
train_hypernetwork_name, train_hypernetwork_name,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment