Commit 8111b556 authored by brkirch

Add support for PyTorch nightly and local builds

parent 3bd73776
@@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs):
     return orig_tensor_numpy(self, *args, **kwargs)
 
-# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
-if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
-    torch.Tensor.to = tensor_to_fix
-    torch.nn.functional.layer_norm = layer_norm_fix
-    torch.Tensor.numpy = numpy_fix
+# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
+orig_cumsum = torch.cumsum
+orig_Tensor_cumsum = torch.Tensor.cumsum
+def cumsum_fix(input, cumsum_func, *args, **kwargs):
+    if input.device.type == 'mps':
+        output_dtype = kwargs.get('dtype', input.dtype)
+        if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]):
+            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
+    return cumsum_func(input, *args, **kwargs)
+
+
+if has_mps():
+    if version.parse(torch.__version__) < version.parse("1.13"):
+        # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+        torch.Tensor.to = tensor_to_fix
+        torch.nn.functional.layer_norm = layer_norm_fix
+        torch.Tensor.numpy = numpy_fix
+    elif version.parse(torch.__version__) > version.parse("1.13.1"):
+        if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)):
+            torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
+            torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
+        orig_narrow = torch.narrow
+        torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
 
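
For context: the equal() check above computes cumsum([1, 1], dtype=torch.int16) on the MPS device and compares it against the expected [1, 2]; only if the kernel is still broken are torch.cumsum and torch.Tensor.cumsum rerouted through cumsum_fix. A rough sketch of the patched behavior, assuming this module has been imported (so the patch is active) on an Apple-silicon machine with an affected PyTorch build; the tensor values are illustrative only:

    import torch

    if torch.backends.mps.is_available():
        t = torch.ones(3, device="mps")
        # With the patch active, the int64 cumsum runs on the CPU and the
        # result is copied back to MPS, so this prints tensor([1, 2, 3],
        # device='mps:0') instead of the incorrect values produced by the
        # broken kernel tracked in pytorch/pytorch#89784.
        print(t.cumsum(0, dtype=torch.int64))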
@@ -4,7 +4,7 @@ import threading
 import time
 import importlib
 import signal
-import threading
+import re
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
@@ -13,6 +13,11 @@ from modules import import_hook, errors
 from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
 from modules.paths import script_path
 
+import torch
+
+# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
+if ".dev" in torch.__version__ or "+git" in torch.__version__:
+    torch.__version__ = re.search(r'[\d.]+', torch.__version__).group(0)
 from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
 import modules.codeformer_model as codeformer
 import modules.extras
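
For context: the truncation keeps only the leading run of digits and dots from the version string. A quick illustration with hypothetical nightly/local build strings (not taken from the commit):

    import re

    for v in ["1.14.0.dev20221231+git1f32d1d", "2.0.0a0+git4d07ad7"]:
        print(re.search(r'[\d.]+', v).group(0))
    # Prints "1.14.0." (the regex also grabs the dot before "dev") and
    # "2.0.0"; the nightly/local suffix that trips up strict version
    # parsing is dropped.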