Commit 84b6fcd0 authored by AUTOMATIC1111's avatar AUTOMATIC1111

add NV option for Random number generator source setting, which allows to...

add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia  videocards.
parent ccb92339
......@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
from modules import errors
from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
......@@ -90,23 +90,58 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
nv_rng = None
def randn(seed, shape):
from modules.shared import opts
torch.manual_seed(seed)
manual_seed(seed)
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_like(x):
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=cpu).to(x.device)
return torch.randn_like(x)
def randn_without_seed(shape):
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def manual_seed(seed):
from modules.shared import opts
if opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return
torch.manual_seed(seed)
def autocast(disable=False):
from modules import shared
......
......@@ -492,7 +492,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
subnoise = None
if subseeds is not None:
if subseeds is not None and subseed_strength != 0:
subseed = 0 if i >= len(subseeds) else subseeds[i]
subnoise = devices.randn(subseed, noise_shape)
......@@ -524,7 +524,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
cnt = p.sampler.number_of_needed_noises(p)
if eta_noise_seed_delta > 0:
torch.manual_seed(seed + eta_noise_seed_delta)
devices.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
......@@ -636,7 +636,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
......
"""RNG imitiating torch cuda randn on CPU. You are welcome.
Usage:
```
g = Generator(seed=0)
print(g.randn(shape=(3, 4)))
```
Expected output:
```
[[-0.92466259 -0.42534415 -2.6438457 0.14518388]
[-0.12086647 -0.57972564 -0.62285122 -0.32838709]
[-1.07454231 -0.36314407 -1.67105067 2.26550497]]
```
"""
import numpy as np
philox_m = [0xD2511F53, 0xCD9E8D57]
philox_w = [0x9E3779B9, 0xBB67AE85]
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
def uint32(x):
"""Converts (N,) np.uint64 array into (2, N) np.unit32 array."""
return np.moveaxis(x.view(np.uint32).reshape(-1, 2), 0, 1)
def philox4_round(counter, key):
"""A single round of the Philox 4x32 random number generator."""
v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
counter[0] = v2[1] ^ counter[1] ^ key[0]
counter[1] = v2[0]
counter[2] = v1[1] ^ counter[3] ^ key[1]
counter[3] = v1[0]
def philox4_32(counter, key, rounds=10):
"""Generates 32-bit random numbers using the Philox 4x32 random number generator.
Parameters:
counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
rounds (int): The number of rounds to perform.
Returns:
numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
"""
for _ in range(rounds - 1):
philox4_round(counter, key)
key[0] = key[0] + philox_w[0]
key[1] = key[1] + philox_w[1]
philox4_round(counter, key)
return counter
def box_muller(x, y):
"""Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
u = x.astype(np.float32) * two_pow32_inv + two_pow32_inv / 2
v = y.astype(np.float32) * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
s = np.sqrt(-2.0 * np.log(u))
r1 = s * np.sin(v)
return r1.astype(np.float32)
class Generator:
"""RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
def __init__(self, seed):
self.seed = seed
self.offset = 0
def randn(self, shape):
"""Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
n = 1
for x in shape:
n *= x
counter = np.zeros((4, n), dtype=np.uint32)
counter[0] = self.offset
counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
self.offset += 1
key = uint32(np.array([[self.seed] * n], dtype=np.uint64))
g = philox4_32(counter, key)
return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3]
......@@ -260,10 +260,7 @@ class TorchHijack:
if noise.shape == x.shape:
return noise
if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
else:
return torch.randn_like(x)
return devices.randn_like(x)
class KDiffusionSampler:
......
......@@ -428,7 +428,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
"upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"),
"randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
}))
options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment