Commit d81c5039 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #8367 from pamparamm/scaled-dot-product-attention

Add scaled dot product attention
parents 1ace16e7 8d7fa2f6
This diff is collapsed.
......@@ -37,11 +37,23 @@ def apply_optimizations():
optimization_method = None
can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention")) # not everyone has torch 2.x to use sdp
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
optimization_method = 'xformers'
elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
optimization_method = 'sdp-no-mem'
elif cmd_opts.opt_sdp_attention and can_use_sdp:
print("Applying scaled dot product cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
optimization_method = 'sdp'
elif cmd_opts.opt_sub_quad_attention:
print("Applying sub-quadratic cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
......
......@@ -346,6 +346,52 @@ def xformers_attention_forward(self, x, context=None, mask=None):
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
batch_size, sequence_length, inner_dim = x.shape
if mask is not None:
mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
k_in = self.to_k(context_k)
v_in = self.to_v(context_v)
head_dim = inner_dim // h
q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
del q_in, k_in, v_in
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
hidden_states = hidden_states.to(dtype)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return scaled_dot_product_attention_forward(self, x, context, mask)
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
......@@ -427,6 +473,30 @@ def xformers_attnblock_forward(self, x):
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
def sdp_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, h, w = q.shape
q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.to(dtype)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
def sdp_no_mem_attnblock_forward(self, x):
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
return sdp_attnblock_forward(self, x)
def sub_quad_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
......
......@@ -69,6 +69,8 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment