Commit f261a4a5 authored by AUTOMATIC's avatar AUTOMATIC

use selected device instead of always cuda for UniPC sampler

parent a11ce2b9
......@@ -3,7 +3,8 @@
import torch
from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
from modules import shared
from modules import shared, devices
class UniPCSampler(object):
def __init__(self, model, **kwargs):
......@@ -16,8 +17,8 @@ class UniPCSampler(object):
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != devices.device:
attr = attr.to(devices.device)
setattr(self, name, attr)
def set_hooks(self, before_sample, after_sample, after_update):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment