Commit bb2e2c82 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #4233 from thesved/patch-1

Make DDIM and PLMS work on Mac OS
parents b8a2e387 86b7fc6e
import torch
import modules.devices as devices
from einops import repeat
from omegaconf import ListConfig
......@@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion):
self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys
# =================================================================================================
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
# =================================================================================================
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
optimal_type = devices.get_optimal_device()
if attr.device != optimal_type:
if getattr(torch, 'has_mps', False):
attr = attr.to(device="mps", dtype=torch.float32)
else:
attr = attr.to(optimal_type)
setattr(self, name, attr)
def should_hijack_inpainting(checkpoint_info):
......@@ -326,6 +341,8 @@ def do_inpainting_hijack():
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment