Commit faed465a authored by brkirch's avatar brkirch Committed by AUTOMATIC1111

MPS Upscalers Fix

Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device.
parent 4c24347e
...@@ -81,3 +81,7 @@ def autocast(disable=False): ...@@ -81,3 +81,7 @@ def autocast(disable=False):
return contextlib.nullcontext() return contextlib.nullcontext()
return torch.autocast("cuda") return torch.autocast("cuda")
# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
...@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img): ...@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img):
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255 img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_esrgan) img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
......
...@@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler): ...@@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device) img = devices.mps_contiguous_to(img.unsqueeze(0), device)
img = img.to(device)
with torch.no_grad(): with torch.no_grad():
output = model(img) output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
......
...@@ -111,7 +111,7 @@ def upscale( ...@@ -111,7 +111,7 @@ def upscale(
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(devices.device_swinir) img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size() _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old h_pad = (h_old // window_size + 1) * window_size - h_old
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment