Commit fa2cfe52 authored by novelailab

fix attn more

parent 39568281
@@ -78,9 +78,9 @@ class SelfAttention(nn.Module):
     def forward(self, x, kv=None):
         B, S, H = x.shape  # batch, sequence, hidden_dim
         # split heads into: [batch, head, sequence, head_dim]
-        query = self.q_proj(x).view(B, self.n_head, S, self.head_dim)
-        key = self.k_proj(x).view(B, self.n_head, S, self.head_dim)
-        value = self.v_proj(x).view(B, self.n_head, S, self.head_dim)
+        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
         if kv:
             k, v = kv
@@ -89,17 +89,14 @@ class SelfAttention(nn.Module):
             torch.cat([k, key], dim=-2)  # cat key
             torch.cat([v, value], dim=-2)  # cat value
-        key = key.permute(0, 2, 1, 3)
-        query = query.permute(0, 2, 1, 3)
-        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+        query_length, key_length = query.size(-2), key.size(-2)  # seq_len, seq_len
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
         x = _attn(
             query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
         )
-        x = x.contiguous().view(B, S, H)
+        x = x.transpose(1, 2).contiguous().view(B, S, H)
         x = self.out_proj(x)
         return x
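For context, the core of the fix is how the projected tensor is split into heads: reshaping [B, S, H] directly to [B, n_head, S, head_dim] reinterprets memory across tokens, while reshaping to [B, S, n_head, head_dim] and transposing dims 1 and 2 keeps each token's features together; the same transpose has to be undone before merging the heads back into the hidden dimension. The standalone sketch below is not part of the commit (the concrete sizes are illustrative, only the axis names match the diff) and just demonstrates both points with plain tensors.

# Minimal sketch of the head split/merge behaviour the commit changes.
# Not from the repository; sizes are made up, names mirror the diff.
import torch

B, S, n_head, head_dim = 2, 4, 3, 5
H = n_head * head_dim
x = torch.randn(B, S, H)  # [batch, sequence, hidden_dim]

# New-style split: keep each token's hidden vector contiguous, then swap
# the sequence and head axes to get [batch, head, sequence, head_dim].
good = x.view(B, S, n_head, head_dim).transpose(1, 2)

# Old-style split: reinterprets the [S, H] block as [n_head, S, head_dim],
# mixing features from different tokens into one "head".
bad = x.view(B, n_head, S, head_dim)

# Head 0 of token 1 should be the first head_dim features of that token.
print(torch.equal(good[0, 0, 1], x[0, 1, :head_dim]))  # True
print(torch.equal(bad[0, 0, 1], x[0, 1, :head_dim]))   # False: this is token 0's second head

# Merging heads is the inverse: transpose heads back next to head_dim before
# flattening, otherwise view(B, S, H) interleaves the heads incorrectly.
attn_out = good                                   # stand-in for the attention output
merged = attn_out.transpose(1, 2).contiguous().view(B, S, H)
print(torch.equal(merged, x))                     # True: split then merge round-trips

That final transpose-back is what the new x.transpose(1, 2).contiguous().view(B, S, H) line does before out_proj.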