Commit f7c787eb authored by AUTOMATIC

make it possible to use hypernetworks without opt split attention

parent 97bc0b95
modules/hypernetwork.py

@@ -4,7 +4,12 @@ import sys
 import traceback
 
 import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
 
     return res
 
-def apply(self, x, context=None, mask=None, original=None):
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+    h = self.heads
+
+    q = self.to_q(x)
+    context = default(context, x)
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
     else:
         k = self.to_k(context)
         v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
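The hunk above moves the hypernetwork hook into a plain (non-split) CrossAttention forward: the selected hypernetwork supplies a pair of modules keyed by the conditioning width (context.shape[2]), and the pair transforms the context before the to_k/to_v projections; the query path and the rest of the attention math stay identical to the stock forward. Below is a minimal standalone sketch of that lookup-and-transform step, assuming a toy residual MLP for the pair (ToyHypernetworkModule and the 320-wide projections are illustrative placeholders, not the repo's classes):

import torch

class ToyHypernetworkModule(torch.nn.Module):
    # illustrative stand-in for a hypernetwork layer: a residual MLP that
    # keeps the conditioning shape unchanged
    def __init__(self, dim):
        super().__init__()
        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

    def forward(self, x):
        return x + self.linear2(self.linear1(x))

dim = 768                                   # e.g. a CLIP embedding width
context = torch.randn(1, 77, dim)           # (batch, tokens, dim) conditioning
layers = {dim: (ToyHypernetworkModule(dim), ToyHypernetworkModule(dim))}

pair = layers.get(context.shape[2])         # look up by embedding width, as in the diff
to_k = torch.nn.Linear(dim, 320, bias=False)
to_v = torch.nn.Linear(dim, 320, bias=False)

# only k and v see the hypernetwork-transformed context
k = to_k(pair[0](context)) if pair is not None else to_k(context)
v = to_v(pair[1](context)) if pair is not None else to_v(context)
print(k.shape, v.shape)                     # torch.Size([1, 77, 320]) twice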
modules/sd_hijack.py

@@ -8,7 +8,7 @@ from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention

@@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
 
 def apply_optimizations():
+    undo_optimizations()
+
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
     if cmd_opts.opt_split_attention_v1:

@@ -30,7 +32,7 @@ def apply_optimizations():
 
 
 def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
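The second file rewires the hijack logic: undo_optimizations() now installs hypernetwork.attention_CrossAttention_forward as the stock CrossAttention.forward, and apply_optimizations() calls undo_optimizations() first, so the hypernetwork-aware path is the fallback even when no split-attention flag is set. A toy sketch of that class-level patching pattern (class, function names, and return strings below are illustrative only, not the webui's code):

class CrossAttention:
    def forward(self, x):
        return "stock forward"

def hypernetwork_aware_forward(self, x):
    return "baseline forward with hypernetwork k/v hook"

def split_attention_forward(self, x):
    return "memory-efficient split-attention forward"

def undo_optimizations():
    # restore the hypernetwork-aware baseline at class level;
    # every existing and future instance picks it up immediately
    CrossAttention.forward = hypernetwork_aware_forward

def apply_optimizations(opt_split_attention_v1=False):
    undo_optimizations()                # always start from the baseline
    if opt_split_attention_v1:          # only then layer an optimization on top
        CrossAttention.forward = split_attention_forward

apply_optimizations(opt_split_attention_v1=False)
print(CrossAttention().forward(None))   # -> baseline forward with hypernetwork k/v hook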