Commit 0d5dc9a6 authored by AUTOMATIC1111's avatar AUTOMATIC1111

rework RNG to use generators instead of generating noises beforehand

parent d81d3fa8
...@@ -3,7 +3,7 @@ import contextlib ...@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache from functools import lru_cache
import torch import torch
from modules import errors, rng_philox from modules import errors
if sys.platform == "darwin": if sys.platform == "darwin":
from modules import mac_specific from modules import mac_specific
...@@ -96,84 +96,6 @@ def cond_cast_float(input): ...@@ -96,84 +96,6 @@ def cond_cast_float(input):
nv_rng = None nv_rng = None
def randn(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
from modules.shared import opts
manual_seed(seed)
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
from modules.shared import opts
if opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=device)
local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(device)
def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
if opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=cpu).to(x.device)
return torch.randn_like(x)
def randn_without_seed(shape):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
from modules.shared import opts
if opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(shape), device=device)
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
return torch.randn(shape, device=device)
def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
from modules.shared import opts
if opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return
torch.manual_seed(seed)
def autocast(disable=False): def autocast(disable=False):
from modules import shared from modules import shared
...@@ -236,3 +158,4 @@ def first_time_calculation(): ...@@ -236,3 +158,4 @@ def first_time_calculation():
x = torch.zeros((1, 1, 3, 3)).to(device, dtype) x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x) conv2d(x)
...@@ -14,7 +14,7 @@ from skimage import exposure ...@@ -14,7 +14,7 @@ from skimage import exposure
from typing import Any, Dict, List from typing import Any, Dict, List
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
from modules.sd_hijack import model_hijack from modules.sd_hijack import model_hijack
from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
...@@ -186,6 +186,7 @@ class StableDiffusionProcessing: ...@@ -186,6 +186,7 @@ class StableDiffusionProcessing:
self.cached_c = StableDiffusionProcessing.cached_c self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None self.uc = None
self.c = None self.c = None
self.rng: rng.ImageRNG = None
self.user = None self.user = None
...@@ -475,82 +476,9 @@ class Processed: ...@@ -475,82 +476,9 @@ class Processed:
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
low_norm = low/torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True)
dot = (low_norm*high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
return res
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
eta_noise_seed_delta = opts.eta_noise_seed_delta or 0 g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
xs = [] return g.next()
# if we have multiple seeds, this means we are working with batch size>1; this then
# enables the generation of additional tensors with noise that the sampler will use during its processing.
# Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
# produce the same images as with two batches [100], [101].
if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
else:
sampler_noises = None
for i, seed in enumerate(seeds):
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
subnoise = None
if subseeds is not None and subseed_strength != 0:
subseed = 0 if i >= len(subseeds) else subseeds[i]
subnoise = devices.randn(subseed, noise_shape)
# randn results depend on device; gpu and cpu get different results for same seed;
# the way I see it, it's better to do this on CPU, so that everyone gets same result;
# but the original script had it like this, so I do not dare change it for now because
# it will break everyone's seeds.
noise = devices.randn(seed, noise_shape)
if subnoise is not None:
noise = slerp(subseed_strength, noise, subnoise)
if noise_shape != shape:
x = devices.randn(seed, shape)
dx = (shape[2] - noise_shape[2]) // 2
dy = (shape[1] - noise_shape[1]) // 2
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
tx = 0 if dx < 0 else dx
ty = 0 if dy < 0 else dy
dx = max(-dx, 0)
dy = max(-dy, 0)
x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
noise = x
if sampler_noises is not None:
cnt = p.sampler.number_of_needed_noises(p)
if eta_noise_seed_delta > 0:
devices.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
xs.append(noise)
if sampler_noises is not None:
p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
x = torch.stack(xs).to(shared.device)
return x
class DecodedSamples(list): class DecodedSamples(list):
...@@ -769,6 +697,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: ...@@ -769,6 +697,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
if p.scripts is not None: if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
...@@ -1072,7 +1002,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1072,7 +1002,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = self.rng.next()
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x del x
...@@ -1160,7 +1090,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1160,7 +1090,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
noise = self.rng.next()
# GC now before running the next img2img to prevent running out of memory # GC now before running the next img2img to prevent running out of memory
devices.torch_gc() devices.torch_gc()
...@@ -1418,7 +1349,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -1418,7 +1349,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask) self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) x = self.rng.next()
if self.initial_noise_multiplier != 1.0: if self.initial_noise_multiplier != 1.0:
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
......
import torch
from modules import devices, rng_philox, shared
def randn(seed, shape, generator=None):
"""Generate a tensor with random numbers from a normal distribution using seed.
Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
manual_seed(seed)
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
return torch.randn(shape, device=devices.device, generator=generator)
def randn_local(seed, shape):
"""Generate a tensor with random numbers from a normal distribution using seed.
Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
if shared.opts.randn_source == "NV":
rng = rng_philox.Generator(seed)
return torch.asarray(rng.randn(shape), device=devices.device)
local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
local_generator = torch.Generator(local_device).manual_seed(int(seed))
return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device)
def randn_like(x):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
if shared.opts.randn_source == "NV":
return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
if shared.opts.randn_source == "CPU" or x.device.type == 'mps':
return torch.randn_like(x, device=devices.cpu).to(x.device)
return torch.randn_like(x)
def randn_without_seed(shape, generator=None):
"""Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
Use either randn() or manual_seed() to initialize the generator."""
if shared.opts.randn_source == "NV":
return torch.asarray((generator or nv_rng).randn(shape), device=devices.device)
if shared.opts.randn_source == "CPU" or devices.device.type == 'mps':
return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device)
return torch.randn(shape, device=devices.device, generator=generator)
def manual_seed(seed):
"""Set up a global random number generator using the specified seed."""
from modules.shared import opts
if opts.randn_source == "NV":
global nv_rng
nv_rng = rng_philox.Generator(seed)
return
torch.manual_seed(seed)
def create_generator(seed):
if shared.opts.randn_source == "NV":
return rng_philox.Generator(seed)
device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device
generator = torch.Generator(device).manual_seed(int(seed))
return generator
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
low_norm = low/torch.norm(low, dim=1, keepdim=True)
high_norm = high/torch.norm(high, dim=1, keepdim=True)
dot = (low_norm*high_norm).sum(1)
if dot.mean() > 0.9995:
return low * val + high * (1 - val)
omega = torch.acos(dot)
so = torch.sin(omega)
res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
return res
class ImageRNG:
def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
self.shape = shape
self.seeds = seeds
self.subseeds = subseeds
self.subseed_strength = subseed_strength
self.seed_resize_from_h = seed_resize_from_h
self.seed_resize_from_w = seed_resize_from_w
self.generators = [create_generator(seed) for seed in seeds]
self.is_first = True
def first(self):
noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8)
xs = []
for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
subnoise = None
if self.subseeds is not None and self.subseed_strength != 0:
subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
subnoise = randn(subseed, noise_shape)
if noise_shape != self.shape:
noise = randn(seed, noise_shape)
else:
noise = randn(seed, self.shape, generator=generator)
if subnoise is not None:
noise = slerp(self.subseed_strength, noise, subnoise)
if noise_shape != self.shape:
x = randn(seed, self.shape, generator=generator)
dx = (self.shape[2] - noise_shape[2]) // 2
dy = (self.shape[1] - noise_shape[1]) // 2
w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
tx = 0 if dx < 0 else dx
ty = 0 if dy < 0 else dy
dx = max(-dx, 0)
dy = max(-dy, 0)
x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
noise = x
xs.append(noise)
eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
if eta_noise_seed_delta:
self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]
return torch.stack(xs).to(shared.device)
def next(self):
if self.is_first:
self.is_first = False
return self.first()
xs = []
for generator in self.generators:
x = randn_without_seed(self.shape, generator=generator)
xs.append(x)
return torch.stack(xs).to(shared.device)
devices.randn = randn
devices.randn_local = randn_local
devices.randn_like = randn_like
devices.randn_without_seed = randn_without_seed
devices.manual_seed = manual_seed
import inspect import inspect
from collections import namedtuple, deque from collections import namedtuple
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
...@@ -132,10 +132,15 @@ replace_torchsde_browinan() ...@@ -132,10 +132,15 @@ replace_torchsde_browinan()
class TorchHijack: class TorchHijack:
def __init__(self, sampler_noises): """This is here to replace torch.randn_like of k-diffusion.
# Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
# implementation. k-diffusion has random_sampler argument for most samplers, but not for all, so
self.sampler_noises = deque(sampler_noises) this is needed to properly replace every use of torch.randn_like.
We need to replace to make images generated in batches to be same as images generated individually."""
def __init__(self, p):
self.rng = p.rng
def __getattr__(self, item): def __getattr__(self, item):
if item == 'randn_like': if item == 'randn_like':
...@@ -147,12 +152,7 @@ class TorchHijack: ...@@ -147,12 +152,7 @@ class TorchHijack:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x): def randn_like(self, x):
if self.sampler_noises: return self.rng.next()
noise = self.sampler_noises.popleft()
if noise.shape == x.shape:
return noise
return devices.randn_like(x)
class Sampler: class Sampler:
...@@ -215,7 +215,7 @@ class Sampler: ...@@ -215,7 +215,7 @@ class Sampler:
self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) k_diffusion.sampling.torch = TorchHijack(p)
extra_params_kwargs = {} extra_params_kwargs = {}
for param_name in self.extra_params: for param_name in self.extra_params:
......
...@@ -16,7 +16,7 @@ import modules.interrogate ...@@ -16,7 +16,7 @@ import modules.interrogate
import modules.memmon import modules.memmon
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args, rng # noqa: F401
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401
from ldm.models.diffusion.ddpm import LatentDiffusion from ldm.models.diffusion.ddpm import LatentDiffusion
from typing import Optional from typing import Optional
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment