Commit 3c7bd057 authored by novelailab's avatar novelailab

clean gptj attention

parent 24e93cbf
......@@ -70,13 +70,20 @@ class SelfAttention(nn.Module):
self.register_buffer("sin", sin)
self.register_buffer("cos", cos)
def forward(self, x):
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# other than v because some rotary bs?
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim)
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
if kv:
k, v = kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
torch.cat([k, key], dim=-2) # cat key
torch.cat([v, value], dim=-2) # cat value
offset = 0
if self.rotary_dim < self.head_dim:
......@@ -95,7 +102,13 @@ class SelfAttention(nn.Module):
else:
key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)
if cache:
# doing this to avoid transposing key again after loading it as transposed.
cache = (key, )
query = query.transpose(1, 2)
key = key.transpose(1, 2)
query_length, key_length = query.size(-2), key.size(-2)
#causal mask with generation in mind
......@@ -108,7 +121,10 @@ class SelfAttention(nn.Module):
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
return x
if cache:
return x, (cache[0], value)
else:
return x
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, activation, device, dtype):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment