Commit fd8674a4 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #13948 from aria1th/hypertile-in-sample

support HyperTile optimization
parents 8aa51f68 97431f29
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warn : The patch works well only if the input image has a width and height that are multiples of 128
Author : @tfernd Github : https://github.com/tfernd/HyperTile
"""
from __future__ import annotations
from typing import Callable
from typing_extensions import Literal
import logging
from functools import wraps, cache
from contextlib import contextmanager
import math
import torch.nn as nn
import random
from einops import rearrange
# TODO add SD-XL layers
DEPTH_LAYERS = {
0: [
# SD 1.5 U-Net (diffusers)
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.1.1.transformer_blocks.0.attn1",
"input_blocks.2.1.transformer_blocks.0.attn1",
"output_blocks.9.1.transformer_blocks.0.attn1",
"output_blocks.10.1.transformer_blocks.0.attn1",
"output_blocks.11.1.transformer_blocks.0.attn1",
# SD 1.5 VAE
"decoder.mid_block.attentions.0",
"decoder.mid.attn_1",
],
1: [
# SD 1.5 U-Net (diffusers)
"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.0.attn1",
"input_blocks.5.1.transformer_blocks.0.attn1",
"output_blocks.6.1.transformer_blocks.0.attn1",
"output_blocks.7.1.transformer_blocks.0.attn1",
"output_blocks.8.1.transformer_blocks.0.attn1",
],
2: [
# SD 1.5 U-Net (diffusers)
"down_blocks.2.attentions.0.transformer_blocks.0.attn1",
"down_blocks.2.attentions.1.transformer_blocks.0.attn1",
"up_blocks.1.attentions.0.transformer_blocks.0.attn1",
"up_blocks.1.attentions.1.transformer_blocks.0.attn1",
"up_blocks.1.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.7.1.transformer_blocks.0.attn1",
"input_blocks.8.1.transformer_blocks.0.attn1",
"output_blocks.3.1.transformer_blocks.0.attn1",
"output_blocks.4.1.transformer_blocks.0.attn1",
"output_blocks.5.1.transformer_blocks.0.attn1",
],
3: [
# SD 1.5 U-Net (diffusers)
"mid_block.attentions.0.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"middle_block.1.transformer_blocks.0.attn1",
],
}
# XL layers, thanks for GitHub@gel-crabs for the help
DEPTH_LAYERS_XL = {
0: [
# SD 1.5 U-Net (diffusers)
"down_blocks.0.attentions.0.transformer_blocks.0.attn1",
"down_blocks.0.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.0.transformer_blocks.0.attn1",
"up_blocks.3.attentions.1.transformer_blocks.0.attn1",
"up_blocks.3.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.0.attn1",
"input_blocks.5.1.transformer_blocks.0.attn1",
"output_blocks.3.1.transformer_blocks.0.attn1",
"output_blocks.4.1.transformer_blocks.0.attn1",
"output_blocks.5.1.transformer_blocks.0.attn1",
# SD 1.5 VAE
"decoder.mid_block.attentions.0",
"decoder.mid.attn_1",
],
1: [
# SD 1.5 U-Net (diffusers)
#"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
#"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
#"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"input_blocks.4.1.transformer_blocks.1.attn1",
"input_blocks.5.1.transformer_blocks.1.attn1",
"output_blocks.3.1.transformer_blocks.1.attn1",
"output_blocks.4.1.transformer_blocks.1.attn1",
"output_blocks.5.1.transformer_blocks.1.attn1",
"input_blocks.7.1.transformer_blocks.0.attn1",
"input_blocks.8.1.transformer_blocks.0.attn1",
"output_blocks.0.1.transformer_blocks.0.attn1",
"output_blocks.1.1.transformer_blocks.0.attn1",
"output_blocks.2.1.transformer_blocks.0.attn1",
"input_blocks.7.1.transformer_blocks.1.attn1",
"input_blocks.8.1.transformer_blocks.1.attn1",
"output_blocks.0.1.transformer_blocks.1.attn1",
"output_blocks.1.1.transformer_blocks.1.attn1",
"output_blocks.2.1.transformer_blocks.1.attn1",
"input_blocks.7.1.transformer_blocks.2.attn1",
"input_blocks.8.1.transformer_blocks.2.attn1",
"output_blocks.0.1.transformer_blocks.2.attn1",
"output_blocks.1.1.transformer_blocks.2.attn1",
"output_blocks.2.1.transformer_blocks.2.attn1",
"input_blocks.7.1.transformer_blocks.3.attn1",
"input_blocks.8.1.transformer_blocks.3.attn1",
"output_blocks.0.1.transformer_blocks.3.attn1",
"output_blocks.1.1.transformer_blocks.3.attn1",
"output_blocks.2.1.transformer_blocks.3.attn1",
"input_blocks.7.1.transformer_blocks.4.attn1",
"input_blocks.8.1.transformer_blocks.4.attn1",
"output_blocks.0.1.transformer_blocks.4.attn1",
"output_blocks.1.1.transformer_blocks.4.attn1",
"output_blocks.2.1.transformer_blocks.4.attn1",
"input_blocks.7.1.transformer_blocks.5.attn1",
"input_blocks.8.1.transformer_blocks.5.attn1",
"output_blocks.0.1.transformer_blocks.5.attn1",
"output_blocks.1.1.transformer_blocks.5.attn1",
"output_blocks.2.1.transformer_blocks.5.attn1",
"input_blocks.7.1.transformer_blocks.6.attn1",
"input_blocks.8.1.transformer_blocks.6.attn1",
"output_blocks.0.1.transformer_blocks.6.attn1",
"output_blocks.1.1.transformer_blocks.6.attn1",
"output_blocks.2.1.transformer_blocks.6.attn1",
"input_blocks.7.1.transformer_blocks.7.attn1",
"input_blocks.8.1.transformer_blocks.7.attn1",
"output_blocks.0.1.transformer_blocks.7.attn1",
"output_blocks.1.1.transformer_blocks.7.attn1",
"output_blocks.2.1.transformer_blocks.7.attn1",
"input_blocks.7.1.transformer_blocks.8.attn1",
"input_blocks.8.1.transformer_blocks.8.attn1",
"output_blocks.0.1.transformer_blocks.8.attn1",
"output_blocks.1.1.transformer_blocks.8.attn1",
"output_blocks.2.1.transformer_blocks.8.attn1",
"input_blocks.7.1.transformer_blocks.9.attn1",
"input_blocks.8.1.transformer_blocks.9.attn1",
"output_blocks.0.1.transformer_blocks.9.attn1",
"output_blocks.1.1.transformer_blocks.9.attn1",
"output_blocks.2.1.transformer_blocks.9.attn1",
],
2: [
# SD 1.5 U-Net (diffusers)
"mid_block.attentions.0.transformer_blocks.0.attn1",
# SD 1.5 U-Net (ldm)
"middle_block.1.transformer_blocks.0.attn1",
"middle_block.1.transformer_blocks.1.attn1",
"middle_block.1.transformer_blocks.2.attn1",
"middle_block.1.transformer_blocks.3.attn1",
"middle_block.1.transformer_blocks.4.attn1",
"middle_block.1.transformer_blocks.5.attn1",
"middle_block.1.transformer_blocks.6.attn1",
"middle_block.1.transformer_blocks.7.attn1",
"middle_block.1.transformer_blocks.8.attn1",
"middle_block.1.transformer_blocks.9.attn1",
],
3 : [] # TODO - separate layers for SD-XL
}
RNG_INSTANCE = random.Random()
def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
"""
Returns a random divisor of value that
x * min_value <= value
if max_options is 1, the behavior is deterministic
"""
min_value = min(min_value, value)
# All big divisors of value (inclusive)
divisors = [i for i in range(min_value, value + 1) if value % i == 0] # divisors in small -> big order
ns = [value // i for i in divisors[:max_options]] # has at least 1 element # big -> small order
idx = RNG_INSTANCE.randint(0, len(ns) - 1)
return ns[idx]
def set_hypertile_seed(seed: int) -> None:
RNG_INSTANCE.seed(seed)
def largest_tile_size_available(width:int, height:int) -> int:
"""
Calculates the largest tile size available for a given width and height
Tile size is always a power of 2
"""
gcd = math.gcd(width, height)
largest_tile_size_available = 1
while gcd % (largest_tile_size_available * 2) == 0:
largest_tile_size_available *= 2
return largest_tile_size_available
def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
Finds h and w such that h*w = hw and h/w = aspect_ratio
We check all possible divisors of hw and return the closest to the aspect ratio
"""
divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
return closest_pair
@cache
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
"""
Finds h and w such that h*w = hw and h/w = aspect_ratio
"""
h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio))
# find h and w such that h*w = hw and h/w = aspect_ratio
if h * w != hw:
w_candidate = hw / h
# check if w is an integer
if not w_candidate.is_integer():
h_candidate = hw / w
# check if h is an integer
if not h_candidate.is_integer():
return iterative_closest_divisors(hw, aspect_ratio)
else:
h = int(h_candidate)
else:
w = int(w_candidate)
return h, w
@contextmanager
def split_attention(
layer: nn.Module,
/,
aspect_ratio: float, # width/height
tile_size: int = 128, # 128 for VAE
swap_size: int = 1, # 1 for VAE
*,
disable: bool = False,
max_depth: Literal[0, 1, 2, 3] = 0, # ! Try 0 or 1
scale_depth: bool = True, # scale the tile-size depending on the depth
is_sdxl: bool = False, # is the model SD-XL
):
# Hijacks AttnBlock from ldm and Attention from diffusers
if disable:
logging.info(f"Attention for {layer.__class__.__qualname__} not splitted")
yield
return
latent_tile_size = max(128, tile_size) // 8
def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable:
@wraps(forward)
def wrapper(*args, **kwargs):
x = args[0]
# VAE
if x.ndim == 4:
b, c, h, w = x.shape
nh = random_divisor(h, latent_tile_size, swap_size)
nw = random_divisor(w, latent_tile_size, swap_size)
if nh * nw > 1:
x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) # split into nh * nw tiles
out = forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
# U-Net
else:
hw: int = x.size(1)
h, w = find_hw_candidates(hw, aspect_ratio)
assert h * w == hw, f"Invalid aspect ratio {aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"
factor = 2**depth if scale_depth else 1
nh = random_divisor(h, latent_tile_size * factor, swap_size)
nw = random_divisor(w, latent_tile_size * factor, swap_size)
module._split_sizes_hypertile.append((nh, nw)) # type: ignore
if nh * nw > 1:
x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)
out = forward(x, *args[1:], **kwargs)
if nh * nw > 1:
out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)
return out
return wrapper
# Handle hijacking the forward method and recovering afterwards
try:
if is_sdxl:
layers = DEPTH_LAYERS_XL
else:
layers = DEPTH_LAYERS
for depth in range(max_depth + 1):
for layer_name, module in layer.named_modules():
if any(layer_name.endswith(try_name) for try_name in layers[depth]):
# print input shape for debugging
logging.debug(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}")
# hijack
module._original_forward_hypertile = module.forward
module.forward = self_attn_forward(module.forward, depth, layer_name, module)
module._split_sizes_hypertile = []
yield
finally:
for layer_name, module in layer.named_modules():
# remove hijack
if hasattr(module, "_original_forward_hypertile"):
if module._split_sizes_hypertile:
logging.debug(f"layer {layer_name} splitted with ({module._split_sizes_hypertile})")
# recover
module.forward = module._original_forward_hypertile
del module._original_forward_hypertile
del module._split_sizes_hypertile
def hypertile_context_vae(model:nn.Module, aspect_ratio:float, tile_size:int, opts):
"""
Returns context manager for VAE
"""
enabled = opts.hypertile_split_vae_attn
swap_size = opts.hypertile_swap_size_vae
max_depth = opts.hypertile_max_depth_vae
tile_size_max = opts.hypertile_max_tile_vae
return split_attention(
model,
aspect_ratio=aspect_ratio,
tile_size=min(tile_size, tile_size_max),
swap_size=swap_size,
disable=not enabled,
max_depth=max_depth,
is_sdxl=False,
)
def hypertile_context_unet(model:nn.Module, aspect_ratio:float, tile_size:int, opts, is_sdxl:bool):
"""
Returns context manager for U-Net
"""
enabled = opts.hypertile_split_unet_attn
swap_size = opts.hypertile_swap_size_unet
max_depth = opts.hypertile_max_depth_unet
tile_size_max = opts.hypertile_max_tile_unet
return split_attention(
model,
aspect_ratio=aspect_ratio,
tile_size=min(tile_size, tile_size_max),
swap_size=swap_size,
disable=not enabled,
max_depth=max_depth,
is_sdxl=is_sdxl,
)
......@@ -24,6 +24,7 @@ from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
from modules.hypertile import set_hypertile_seed, largest_tile_size_available, hypertile_context_unet, hypertile_context_vae
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
......@@ -799,7 +800,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
......@@ -861,7 +861,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
p.comment(comment)
p.extra_generation_params.update(model_hijack.extra_generation_params)
set_hypertile_seed(p.seed)
# add batch size + hypertile status to information to reproduce the run
if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
......@@ -873,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts):
x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
......@@ -1140,24 +1141,25 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
aspect_ratio = self.width / self.height
x = self.rng.next()
tile_size = largest_tile_size_available(self.width, self.height)
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
del x
if not self.enable_hr:
return samples
devices.torch_gc()
if self.latent_scale_mode is None:
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
else:
decoded_samples = None
with sd_models.SkipWritingToConfig():
sd_models.reload_model_weights(info=self.hr_checkpoint_info)
devices.torch_gc()
return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
......@@ -1165,7 +1167,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
return samples
self.is_hr_pass = True
target_width = self.hr_upscale_to_x
target_height = self.hr_upscale_to_y
......@@ -1243,18 +1244,20 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.scripts is not None:
self.scripts.before_hr(self)
tile_size = largest_tile_size_available(target_width, target_height)
aspect_ratio = self.width / self.height
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.sampler = None
devices.torch_gc()
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
self.is_hr_pass = False
return decoded_samples
def close(self):
......@@ -1529,7 +1532,10 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.initial_noise_multiplier != 1.0:
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
x *= self.initial_noise_multiplier
aspect_ratio = self.width / self.height
tile_size = largest_tile_size_available(self.width, self.height)
with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
if self.mask is not None:
......
......@@ -201,6 +201,14 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
"pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
"persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
"batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
"hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
"hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
"hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
"hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
"hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
"hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
"hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
"hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
}))
options_templates.update(options_section(('compatibility', "Compatibility"), {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment