Commit d6bd6a42 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into master

parents d23a46ce c23f666d
......@@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
......
......@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list:
image = Image.open(img)
try:
image = Image.open(img)
except Exception:
continue
imageArr.append(image)
imageNameArr.append(img)
else:
......@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
forced_filename=image_name if opts.use_original_name_batch else None)
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
basename = ''
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
......
......@@ -22,16 +22,26 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__()
assert layer_structure is not None, "layer_structure mut not be None"
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu":
linears.append(torch.nn.ReLU())
elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
elif activation_func == 'linear' or activation_func is None:
pass
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
......@@ -42,8 +52,9 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
if type(layer) == torch.nn.Linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device)
......@@ -69,7 +80,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
layer_structure += [layer.weight, layer.bias]
if type(layer) == torch.nn.Linear:
layer_structure += [layer.weight, layer.bias]
return layer_structure
......@@ -81,7 +93,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
self.filename = None
self.name = name
self.layers = {}
......@@ -90,11 +102,12 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
)
def weights(self):
......@@ -117,6 +130,7 @@ class Hypernetwork:
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
......@@ -131,12 +145,13 @@ class Hypernetwork:
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
)
self.name = state_dict.get('name', self.name)
......
......@@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
......@@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
add_layer_norm=add_layer_norm,
activation_func=activation_func,
)
hypernet.save(fn)
......
......@@ -540,17 +540,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
......@@ -587,7 +607,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
......@@ -613,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
......@@ -714,10 +735,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
self.init_latent.shape[0], 5, 1, 1,
dtype=self.init_latent.dtype,
device=self.init_latent.device
)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
......
This diff is collapsed.
......@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
......@@ -203,14 +204,26 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info
def load_model():
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint()
checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
......@@ -234,9 +247,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
shared.sd_model = load_model()
shared.sd_model = load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
......
......@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def number_of_needed_noises(self, p):
return 0
......@@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
# Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
......@@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
......@@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
......@@ -199,11 +213,17 @@ class VanillaStableDiffusionSampler:
self.last_latent = x
self.step = 0
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
......@@ -212,6 +232,11 @@ class VanillaStableDiffusionSampler:
steps = steps or p.steps
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
......@@ -230,7 +255,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
......@@ -241,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
......@@ -308,6 +334,8 @@ class KDiffusionSampler:
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
......@@ -363,7 +391,7 @@ class KDiffusionSampler:
return extra_params_kwargs
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
......@@ -392,11 +420,16 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
......@@ -418,7 +451,12 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigmas
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
......@@ -1224,6 +1224,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
......@@ -1308,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
],
outputs=[
train_hypernetwork_name,
......
......@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment