Commit fe4e3c26 authored by AUTOMATIC's avatar AUTOMATIC

fix for PLMS live previews in txt2img

parent ca3861e0
...@@ -87,7 +87,7 @@ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdq ...@@ -87,7 +87,7 @@ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdq
class VanillaStableDiffusionSampler: class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model): def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model) self.sampler = constructor(sd_model)
self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
...@@ -113,7 +113,9 @@ class VanillaStableDiffusionSampler: ...@@ -113,7 +113,9 @@ class VanillaStableDiffusionSampler:
return samples return samples
def sample(self, p, x, conditioning, unconditional_conditioning): def sample(self, p, x, conditioning, unconditional_conditioning):
self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs) for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
self.mask = None self.mask = None
self.nmask = None self.nmask = None
self.init_latent = None self.init_latent = None
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment