Commit bb2e2c82 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #4233 from thesved/patch-1

Make DDIM and PLMS work on Mac OS
parents b8a2e387 86b7fc6e
import torch import torch
import modules.devices as devices
from einops import repeat from einops import repeat
from omegaconf import ListConfig from omegaconf import ListConfig
...@@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion): ...@@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion):
self.masked_image_key = masked_image_key self.masked_image_key = masked_image_key
assert self.masked_image_key in concat_keys assert self.masked_image_key in concat_keys
self.concat_keys = concat_keys self.concat_keys = concat_keys
# =================================================================================================
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
# =================================================================================================
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
optimal_type = devices.get_optimal_device()
if attr.device != optimal_type:
if getattr(torch, 'has_mps', False):
attr = attr.to(device="mps", dtype=torch.float32)
else:
attr = attr.to(optimal_type)
setattr(self, name, attr)
def should_hijack_inpainting(checkpoint_info): def should_hijack_inpainting(checkpoint_info):
...@@ -326,6 +341,8 @@ def do_inpainting_hijack(): ...@@ -326,6 +341,8 @@ def do_inpainting_hijack():
ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment