Commit aaea8b44 authored by C43H66N12O12S2, committed by GitHub

Update cross attention to the newest version

parent a5a760a7
@@ -67,8 +67,9 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     mem_free_total = mem_free_cuda + mem_free_torch
 
     gb = 1024 ** 3
-    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
-    mem_required = tensor_size * 2.5
+    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
+    modifier = 3 if q.element_size() == 2 else 2.5
+    mem_required = tensor_size * modifier
     steps = 1
 
     if mem_required > mem_free_total:
@@ -86,7 +87,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     for i in range(0, q.shape[1], slice_size):
         end = i + slice_size
         s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-        s2 = s1.softmax(dim=-1)
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
         del s1
 
         r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
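For context: the first hunk makes the VRAM estimate dtype-aware. The attention score matrix has shape (batch, q_len, k_len), so its size in bytes is that product times `q.element_size()` (2 for fp16, 4 for fp32) rather than a hardcoded 4, and fp16 gets a larger safety multiplier (3 instead of 2.5), presumably to leave headroom for intermediate buffers. The second hunk pins the sliced softmax to `q.dtype`, which keeps `s2` the size the estimate assumed (under CUDA autocast, softmax may otherwise be promoted to fp32). Below is a minimal sketch of the estimation logic in isolation; the helper name `estimate_attention_steps` and the power-of-two rounding of `steps` are illustrative (the rounding mirrors what the surrounding function does with the estimate), not part of this patch.

```python
import math

import torch


def estimate_attention_steps(q: torch.Tensor, k: torch.Tensor,
                             mem_free_total: int) -> int:
    """Illustrative helper: how many slices to split attention into
    so the score matrix fits in free VRAM."""
    # Bytes for the full (batch, q_len, k_len) score matrix;
    # q.element_size() is 2 for fp16 and 4 for fp32.
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    # Larger safety margin for fp16 (assumed rationale: headroom for
    # intermediate buffers alongside s1/s2).
    modifier = 3 if q.element_size() == 2 else 2.5
    mem_required = tensor_size * modifier
    steps = 1
    if mem_required > mem_free_total:
        # Round up to a power of two so q.shape[1] splits into equal slices.
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    return steps
```

As a rough worked example: at fp16 with q of shape (16, 4096, 40) and k of length 4096, the score matrix alone is 16 · 4096 · 4096 · 2 bytes ≈ 0.5 GiB, so mem_required ≈ 1.5 GiB; with 1 GiB free this yields steps = 2, i.e. the loop in the second hunk runs over two half-length slices of q.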