Commit 3a52a64a authored by novelailab

noemblm attention cleanup

parent eaba913c
@@ -105,29 +105,33 @@ class SelfAttention(nn.Module):
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
 
-    def forward(self, x):
-        query = self.q_proj(x)
-        key = self.k_proj(x)
-        value = self.v_proj(x)
-
-        query = _split_heads(query, self.n_head, self.head_dim, True)
-        key = _split_heads(key, self.n_head, self.head_dim, True)
-        value = _split_heads(value, self.n_head, self.head_dim, False)
-
-        key = key.permute(0, 2, 1, 3)
-        query = query.permute(0, 2, 1, 3)
-
-        query_length, key_length = query.size(-2), key.size(-2)
-        causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+    def forward(self, x, kv=None, cache=False):
+        B, S, H = x.shape  # batch, sequence, hidden_dim
+        # split heads into: [batch, head, sequence, head_dim]
+        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+
+        if kv:
+            k, v = kv
+            # concatenate the cached keys/values with the new ones along the sequence
+            # axis (everything except the newly added token is cached), so the query
+            # can attend to the whole sequence
+            key = torch.cat([k, key], dim=-2)
+            value = torch.cat([v, value], dim=-2)
+
+        query_length, key_length = query.size(-2), key.size(-2)  # seq_len, seq_len
+        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
 
         x = _attn(
             query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
         )
 
-        x = _merge_heads(x, self.n_head, self.head_dim)
+        # merge heads back into [batch, sequence, hidden_dim]
+        x = x.transpose(1, 2).contiguous().view(B, S, H)
         x = self.out_proj(x)
-
-        return x
+        if cache:
+            return x, (key, value)
+        else:
+            return x
 
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, activation, device, dtype):
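The new kv/cache arguments make the module usable for incremental decoding: call it once over the prompt with cache=True to get the per-layer (key, value) tensors back, then feed only the newest token together with kv=... on later calls. Below is a minimal sketch of that calling pattern, assuming attn is an already-constructed SelfAttention instance and the two inputs are shaped [batch, seq, hidden] and [batch, 1, hidden]; the helper name decode_two_steps is illustrative and not part of this commit.

    def decode_two_steps(attn, prompt, next_embedding):
        # First call: project the whole prompt and keep the (key, value) cache.
        out, kv = attn(prompt, cache=True)

        # Later calls: only the newest token is projected; forward() concatenates
        # the cached keys/values, so the query still attends over the full prefix
        # (the causal-mask slice handles query_length < key_length).
        out_step, kv = attn(next_embedding, kv=kv, cache=True)
        return out_step, kv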