Commit 5a0db84b authored by AUTOMATIC1111's avatar AUTOMATIC1111

add infotext

add proper support for recalculating conds in k-diffusion samplers
remove support for compvis samplers
parent 956e69bf
...@@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [ ...@@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [
('Pad conds', 'pad_cond_uncond'), ('Pad conds', 'pad_cond_uncond'),
('VAE Encoder', 'sd_vae_encode_method'), ('VAE Encoder', 'sd_vae_encode_method'),
('VAE Decoder', 'sd_vae_decode_method'), ('VAE Decoder', 'sd_vae_decode_method'),
('Refiner', 'sd_refiner_checkpoint'),
('Refiner switch at', 'sd_refiner_switch_at'),
] ]
......
...@@ -370,6 +370,9 @@ class StableDiffusionProcessing: ...@@ -370,6 +370,9 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
def get_conds(self):
    """Return the (positive, negative) conditionings computed by setup_conds."""
    conds = (self.c, self.uc)
    return conds
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
...@@ -1251,6 +1254,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1251,6 +1254,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast(): with devices.autocast():
extra_networks.activate(self, self.extra_network_data) extra_networks.activate(self, self.extra_network_data)
def get_conds(self):
    """Return hires-fix conditionings during the hires pass, base conds otherwise."""
    if not self.is_hr_pass:
        return super().get_conds()
    return self.hr_c, self.hr_uc
def parse_extra_network_prompts(self): def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts() res = super().parse_extra_network_prompts()
......
...@@ -131,16 +131,27 @@ replace_torchsde_browinan() ...@@ -131,16 +131,27 @@ replace_torchsde_browinan()
def apply_refiner(sampler):
    """Swap the active model for the configured refiner checkpoint when due.

    Checks how far sampling has progressed; once past the configured switch
    ratio (and if the refiner is not already active), reloads model weights,
    records the refiner info in the generation parameters, and refreshes the
    sampler's conditionings for the new model.

    Returns True when a switch happened, False otherwise.
    """
    progress = sampler.step / sampler.steps

    # Guard: not yet at the switch point.
    if progress <= shared.opts.sd_refiner_switch_at:
        return False

    # Guard: the refiner checkpoint is already loaded.
    if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
        return False

    # NOTE(review): 'get_closet_checkpoint_match' is the helper's actual
    # (typo'd) name in this project — do not "fix" the spelling here.
    checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
    if checkpoint_info is None:
        raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')

    # Record refiner usage in the image's infotext.
    sampler.p.extra_generation_params['Refiner'] = checkpoint_info.short_title
    sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at

    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=checkpoint_info)

    devices.torch_gc()
    # Recompute conds for the new model, then let the sampler pick them up;
    # update_inner_model reads what setup_conds produced, so order matters.
    sampler.p.setup_conds()
    sampler.update_inner_model()

    return True
...@@ -71,8 +71,6 @@ class VanillaStableDiffusionSampler: ...@@ -71,8 +71,6 @@ class VanillaStableDiffusionSampler:
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException raise sd_samplers_common.InterruptedException
sd_samplers_common.apply_refiner(self)
if self.stop_at is not None and self.step > self.stop_at: if self.stop_at is not None and self.step > self.stop_at:
raise sd_samplers_common.InterruptedException raise sd_samplers_common.InterruptedException
......
...@@ -87,8 +87,9 @@ class CFGDenoiser(torch.nn.Module): ...@@ -87,8 +87,9 @@ class CFGDenoiser(torch.nn.Module):
negative prompt. negative prompt.
""" """
def __init__(self): def __init__(self, sampler):
super().__init__() super().__init__()
self.sampler = sampler
self.model_wrap = None self.model_wrap = None
self.mask = None self.mask = None
self.nmask = None self.nmask = None
...@@ -126,11 +127,17 @@ class CFGDenoiser(torch.nn.Module): ...@@ -126,11 +127,17 @@ class CFGDenoiser(torch.nn.Module):
def update_inner_model(self):
    """Invalidate the cached wrapped model and refresh cond/uncond after a checkpoint swap."""
    self.model_wrap = None

    cond, uncond = self.p.get_conds()
    extra_args = self.sampler.sampler_extra_args
    extra_args['cond'] = cond
    extra_args['uncond'] = uncond
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped: if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException raise sd_samplers_common.InterruptedException
sd_samplers_common.apply_refiner(self) if sd_samplers_common.apply_refiner(self):
cond = self.sampler.sampler_extra_args['cond']
uncond = self.sampler.sampler_extra_args['uncond']
# at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
# so is_edit_model is set to False to support AND composition. # so is_edit_model is set to False to support AND composition.
...@@ -282,12 +289,12 @@ class TorchHijack: ...@@ -282,12 +289,12 @@ class TorchHijack:
class KDiffusionSampler: class KDiffusionSampler:
def __init__(self, funcname, sd_model): def __init__(self, funcname, sd_model):
self.p = None self.p = None
self.funcname = funcname self.funcname = funcname
self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, []) self.extra_params = sampler_extra_params.get(funcname, [])
self.model_wrap_cfg = CFGDenoiser() self.sampler_extra_args = {}
self.model_wrap_cfg = CFGDenoiser(self)
self.model_wrap = self.model_wrap_cfg.inner_model self.model_wrap = self.model_wrap_cfg.inner_model
self.sampler_noises = None self.sampler_noises = None
self.stop_at = None self.stop_at = None
...@@ -476,7 +483,7 @@ class KDiffusionSampler: ...@@ -476,7 +483,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.init_latent = x self.model_wrap_cfg.init_latent = x
self.last_latent = x self.last_latent = x
extra_args = { self.sampler_extra_args = {
'cond': conditioning, 'cond': conditioning,
'image_cond': image_conditioning, 'image_cond': image_conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
...@@ -484,7 +491,7 @@ class KDiffusionSampler: ...@@ -484,7 +491,7 @@ class KDiffusionSampler:
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
} }
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond: if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True p.extra_generation_params["Pad conds"] = True
...@@ -514,13 +521,14 @@ class KDiffusionSampler: ...@@ -514,13 +521,14 @@ class KDiffusionSampler:
extra_params_kwargs['noise_sampler'] = noise_sampler extra_params_kwargs['noise_sampler'] = noise_sampler
self.last_latent = x self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ self.sampler_extra_args = {
'cond': conditioning, 'cond': conditioning,
'image_cond': image_conditioning, 'image_cond': image_conditioning,
'uncond': unconditional_conditioning, 'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale, 'cond_scale': p.cfg_scale,
's_min_uncond': self.s_min_uncond 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs)) }
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
if self.model_wrap_cfg.padded_cond_uncond: if self.model_wrap_cfg.padded_cond_uncond:
p.extra_generation_params["Pad conds"] = True p.extra_generation_params["Pad conds"] = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment