1. 24 Dec, 2022 6 commits
  2. 21 Dec, 2022 1 commit
    • brkirch's avatar
      Use other MPS optimization for large q.shape[0] * q.shape[1] · 35b1775b
      brkirch authored
      Check if q.shape[0] * q.shape[1] is 2**18 or larger and use the lower memory usage MPS optimization if it is. This should prevent most crashes that were occurring at certain resolutions (e.g. 1024x1024, 2048x512, 512x2048).
      
      Also included is a change to check slice_size and prevent it from being divisible by 4096 which also results in a crash. Otherwise a crash can occur at 1024x512 or 512x1024 resolution.
      35b1775b
  3. 17 Dec, 2022 2 commits
    • brkirch's avatar
      Add attributes used by MPS · cca16373
      brkirch authored
      cca16373
    • brkirch's avatar
      Add numpy fix for MPS on PyTorch 1.12.1 · 16b4509f
      brkirch authored
      When saving training results with torch.save(), an exception is thrown:
      "RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead."
      
      So for MPS, check if Tensor.requires_grad and detach() if necessary.
      16b4509f
  4. 16 Dec, 2022 1 commit
  5. 11 Dec, 2022 1 commit
  6. 10 Dec, 2022 29 commits