Commit d6bd6a42 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into master

parents d23a46ce c23f666d
......@@ -70,6 +70,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML.
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
......
......@@ -39,9 +39,12 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
if input_dir == '':
return outputs, "Please select an input directory.", ''
image_list = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
image_list = [file for file in [os.path.join(input_dir, x) for x in sorted(os.listdir(input_dir))] if os.path.isfile(file)]
for img in image_list:
image = Image.open(img)
try:
image = Image.open(img)
except Exception:
continue
imageArr.append(image)
imageNameArr.append(img)
else:
......@@ -118,10 +121,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
images.save_image(image, path=outpath, basename="", seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
forced_filename=image_name if opts.use_original_name_batch else None)
if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0]
else:
basename = ''
images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)
if opts.enable_pnginfo:
image.info = existing_pnginfo
......
......@@ -22,16 +22,26 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False, activation_func=None):
super().__init__()
assert layer_structure is not None, "layer_structure mut not be None"
assert layer_structure is not None, "layer_structure must not be None"
assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
linears = []
for i in range(len(layer_structure) - 1):
linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))
if activation_func == "relu":
linears.append(torch.nn.ReLU())
elif activation_func == "leakyrelu":
linears.append(torch.nn.LeakyReLU())
elif activation_func == 'linear' or activation_func is None:
pass
else:
raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')
if add_layer_norm:
linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))
......@@ -42,8 +52,9 @@ class HypernetworkModule(torch.nn.Module):
self.load_state_dict(state_dict)
else:
for layer in self.linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
if type(layer) == torch.nn.Linear:
layer.weight.data.normal_(mean=0.0, std=0.01)
layer.bias.data.zero_()
self.to(devices.device)
......@@ -69,7 +80,8 @@ class HypernetworkModule(torch.nn.Module):
def trainables(self):
layer_structure = []
for layer in self.linear:
layer_structure += [layer.weight, layer.bias]
if type(layer) == torch.nn.Linear:
layer_structure += [layer.weight, layer.bias]
return layer_structure
......@@ -81,7 +93,7 @@ class Hypernetwork:
filename = None
name = None
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False, activation_func=None):
self.filename = None
self.name = name
self.layers = {}
......@@ -90,11 +102,12 @@ class Hypernetwork:
self.sd_checkpoint_name = None
self.layer_structure = layer_structure
self.add_layer_norm = add_layer_norm
self.activation_func = activation_func
for size in enable_sizes or []:
self.layers[size] = (
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm, self.activation_func),
)
def weights(self):
......@@ -117,6 +130,7 @@ class Hypernetwork:
state_dict['name'] = self.name
state_dict['layer_structure'] = self.layer_structure
state_dict['is_layer_norm'] = self.add_layer_norm
state_dict['activation_func'] = self.activation_func
state_dict['sd_checkpoint'] = self.sd_checkpoint
state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
......@@ -131,12 +145,13 @@ class Hypernetwork:
self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
self.add_layer_norm = state_dict.get('is_layer_norm', False)
self.activation_func = state_dict.get('activation_func', None)
for size, sd in state_dict.items():
if type(size) == int:
self.layers[size] = (
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm, self.activation_func),
HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm, self.activation_func),
)
self.name = state_dict.get('name', self.name)
......
......@@ -10,7 +10,7 @@ from modules import sd_hijack, shared, devices
from modules.hypernetworks import hypernetwork
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False):
def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm=False, activation_func=None):
fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
......@@ -22,6 +22,7 @@ def create_hypernetwork(name, enable_sizes, layer_structure=None, add_layer_norm
enable_sizes=[int(x) for x in enable_sizes],
layer_structure=layer_structure,
add_layer_norm=add_layer_norm,
activation_func=activation_func,
)
hypernet.save(fn)
......
......@@ -540,17 +540,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
......@@ -587,7 +607,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
return samples
......@@ -613,6 +633,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.inpainting_mask_invert = inpainting_mask_invert
self.mask = None
self.nmask = None
self.image_conditioning = None
def init(self, all_prompts, all_seeds, all_subseeds):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers_for_img2img, self.sampler_index, self.sd_model)
......@@ -714,10 +735,39 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
self.init_latent.shape[0], 5, 1, 1,
dtype=self.init_latent.dtype,
device=self.init_latent.device
)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
samples = samples * self.nmask + self.init_latent * self.mask
......
import torch
from einops import repeat
from omegaconf import ListConfig
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from ldm.models.diffusion.ddpm import LatentDiffusion
from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ddim import DDIMSampler, noise_like
# =================================================================================================
# Monkey patch DDIMSampler methods from RunwayML repo directly.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py
# =================================================================================================
@torch.no_grad()
def sample_ddim(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
# =================================================================================================
# Monkey patch PLMSSampler methods.
# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes.
# Adapted from:
# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py
# =================================================================================================
@torch.no_grad()
def sample_plms(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
**kwargs
):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
else:
if conditioning.shape[0] != batch_size:
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(conditioning, size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask, x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
)
return samples, intermediates
@torch.no_grad()
def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat([unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == "eps"
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t
# =================================================================================================
# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config.
# Adapted from:
# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py
# =================================================================================================
@torch.no_grad()
def get_unconditional_conditioning(self, batch_size, null_label=None):
if null_label is not None:
xc = null_label
if isinstance(xc, ListConfig):
xc = list(xc)
if isinstance(xc, dict) or isinstance(xc, list):
c = self.get_learned_conditioning(xc)
else:
if hasattr(xc, "to"):
xc = xc.to(self.device)
c = self.get_learned_conditioning(xc)
else:
# todo: get null label from cond_stage_model
raise NotImplementedError()
c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device)
return c
class LatentInpaintDiffusion(LatentDiffusion):
def __init__(
self,
concat_keys=("mask", "masked_image"),
masked_image_key="masked_image",
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
def should_hijack_inpainting(checkpoint_info):
return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml")
def do_inpainting_hijack():
ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning
ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
......@@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config
from modules import shared, modelloader, devices
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
......@@ -203,14 +204,26 @@ def load_model_weights(model, checkpoint_info):
model.sd_checkpoint_info = checkpoint_info
def load_model():
def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
checkpoint_info = select_checkpoint()
checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
sd_config = OmegaConf.load(checkpoint_info.config)
if should_hijack_inpainting(checkpoint_info):
# Hardcoded config for now...
sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion"
sd_config.model.params.use_ema = False
sd_config.model.params.conditioning_key = "hybrid"
sd_config.model.params.unet_config.params.in_channels = 9
# Create a "fake" config with a different name so that we know to unload it when switching models.
checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
do_inpainting_hijack()
sd_model = instantiate_from_config(sd_config.model)
load_model_weights(sd_model, checkpoint_info)
......@@ -234,9 +247,9 @@ def reload_model_weights(sd_model, info=None):
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
return
if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
shared.sd_model = load_model()
shared.sd_model = load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
......
......@@ -117,6 +117,8 @@ class VanillaStableDiffusionSampler:
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def number_of_needed_noises(self, p):
return 0
......@@ -136,6 +138,12 @@ class VanillaStableDiffusionSampler:
if self.stop_at is not None and self.step > self.stop_at:
raise InterruptedException
# Have to unwrap the inpainting conditioning here to perform pre-processing
image_conditioning = None
if isinstance(cond, dict):
image_conditioning = cond["c_concat"][0]
cond = cond["c_crossattn"][0]
unconditional_conditioning = unconditional_conditioning["c_crossattn"][0]
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
......@@ -157,6 +165,12 @@ class VanillaStableDiffusionSampler:
img_orig = self.sampler.model.q_sample(self.init_latent, ts)
x_dec = img_orig * self.mask + self.nmask * x_dec
# Wrap the image conditioning back up since the DDIM code can accept the dict directly.
# Note that they need to be lists because it just concatenates them later.
if image_conditioning is not None:
cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
if self.mask is not None:
......@@ -182,7 +196,7 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
self.initialize(p)
......@@ -199,11 +213,17 @@ class VanillaStableDiffusionSampler:
self.last_latent = x
self.step = 0
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
samples = self.launch_sampling(steps, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
self.initialize(p)
self.init_latent = None
......@@ -212,6 +232,11 @@ class VanillaStableDiffusionSampler:
steps = steps or p.steps
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
# existing code fails with certain step counts, like 9
try:
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
......@@ -230,7 +255,7 @@ class CFGDenoiser(torch.nn.Module):
self.init_latent = None
self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise InterruptedException
......@@ -241,28 +266,29 @@ class CFGDenoiser(torch.nn.Module):
repeats = [len(conds_list[i]) for i in range(batch_size)]
x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
if tensor.shape[1] == uncond.shape[1]:
cond_in = torch.cat([tensor, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond=cond_in)
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
else:
x_out = torch.zeros_like(x_in)
for batch_offset in range(0, x_out.shape[0], batch_size):
a = batch_offset
b = a + batch_size
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]})
else:
x_out = torch.zeros_like(x_in)
batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b])
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond)
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
denoised_uncond = x_out[-uncond.shape[0]:]
denoised = torch.clone(denoised_uncond)
......@@ -308,6 +334,8 @@ class KDiffusionSampler:
self.config = None
self.last_latent = None
self.conditioning_key = sd_model.model.conditioning_key
def callback_state(self, d):
step = d['i']
latent = d["denoised"]
......@@ -363,7 +391,7 @@ class KDiffusionSampler:
return extra_params_kwargs
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
if p.sampler_noise_scheduler_override:
......@@ -392,11 +420,16 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
steps = steps or p.steps
if p.sampler_noise_scheduler_override:
......@@ -418,7 +451,12 @@ class KDiffusionSampler:
extra_params_kwargs['sigmas'] = sigmas
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs))
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
......@@ -1224,6 +1224,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "320", "640", "1280"])
new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'")
new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization")
new_hypernetwork_activation_func = gr.Dropdown(value="relu", label="Select activation function of hypernetwork", choices=["linear", "relu", "leakyrelu"])
with gr.Row():
with gr.Column(scale=3):
......@@ -1308,6 +1309,7 @@ def create_ui(wrap_gradio_gpu_call):
new_hypernetwork_sizes,
new_hypernetwork_layer_structure,
new_hypernetwork_add_layer_norm,
new_hypernetwork_activation_func,
],
outputs=[
train_hypernetwork_name,
......
......@@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs):
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment