Commit 1ae860cf authored by Jonathan Beltran's avatar Jonathan Beltran

this fixes a bug with use_scale_latent_for_hires_fix

parent ccf95b0e
...@@ -644,6 +644,27 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -644,6 +644,27 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
...@@ -690,6 +711,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -690,6 +711,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None x = None
devices.torch_gc() devices.torch_gc()
if opts.use_scale_latent_for_hires_fix:
image_conditioning = self.create_dummy_mask(samples)
else:
image_conditioning = self.img2img_image_conditioning( image_conditioning = self.img2img_image_conditioning(
decoded_samples, decoded_samples,
samples, samples,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment