Commit 861db783 authored by brkirch's avatar brkirch Committed by AUTOMATIC1111

Use apply_hypernetwork function

parent 574c8e55
......@@ -202,16 +202,10 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
hypernetwork = shared.loaded_hypernetwork
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
if hypernetwork_layers is not None:
k = self.to_k(hypernetwork_layers[0](context)) * self.scale
v = self.to_v(hypernetwork_layers[1](context))
else:
k = self.to_k(context) * self.scale
v = self.to_v(context)
del context, x
context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
k = self.to_k(context_k) * self.scale
v = self.to_v(context_v)
del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = einsum_op(q, k, v)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment