Commit 558c30dc authored by Arda Cihaner

Removing attention mask from encoder

parent e2a4d6b1
@@ -5,11 +5,8 @@ from basedformer.utils import *
 from basedformer.models import base_image
 import einops
-def _attn(query, key, value, causal_mask, masked_bias,
-          attention_mask=None, scale_attn=None):
+def _attn(query, key, value, attention_mask=None, scale_attn=None):
     attn_weights = torch.matmul(query, key.transpose(-1, -2))
-    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
     attn_weights = attn_weights / scale_attn
     if attention_mask is not None:
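With the causal mask and masked-bias fill gone, `_attn` reduces to plain bidirectional scaled dot-product attention, which is the right behaviour for a ViT encoder where every patch token may attend to every other one. The sketch below is illustrative only: `_attn_sketch` and the tensor shapes are made up, and the softmax/value step is assumed to follow the standard formulation, since the hunk cuts off at the `attention_mask` check.

import torch
import torch.nn.functional as F

def _attn_sketch(query, key, value, attention_mask=None, scale_attn=None):
    # (B, n_head, S, head_dim) @ (B, n_head, head_dim, S) -> (B, n_head, S, S)
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
    attn_weights = attn_weights / scale_attn
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask  # optional additive mask
    attn_weights = F.softmax(attn_weights, dim=-1)    # assumed continuation of _attn
    return torch.matmul(attn_weights, value)          # (B, n_head, S, head_dim)

# Assumed shapes: 2 images, 12 heads, 196 patch tokens ((224/16)^2), head_dim 64.
q = k = v = torch.randn(2, 12, 196, 64)
out = _attn_sketch(q, k, v, scale_attn=torch.sqrt(torch.tensor(64.0)))
assert out.shape == (2, 12, 196, 64)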
@@ -26,9 +23,6 @@ class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
     def __init__(self, config):
         nn.Module.__init__(self)
-        max_positions = 2049
-        bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
-            1, 1, max_positions, max_positions).bool()
         self.head_dim = config.hidden_dim // config.n_head
         self.rotary_dim = self.head_dim // 4
         self.hidden_dim = config.hidden_dim
@@ -37,8 +31,6 @@ class SelfAttention(nn.Module):
         dtype = config.dtype
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
-        self.register_buffer("bias", bias)
-        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
         self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
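The two deleted buffers existed only to support causal masking: `bias` held a 2049 x 2049 lower-triangular boolean matrix and `masked_bias` the large negative value written into masked positions. Neither is meaningful for an image encoder, and dropping them also keeps them out of the module's state_dict. For reference, this standalone sketch reproduces what the removed `bias` tensor contained; it is not part of the patched file.

import torch

max_positions = 2049  # value hard-coded in the removed lines
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
    1, 1, max_positions, max_positions
).bool()

# In a decoder this restricts position i to attend only to positions <= i:
print(bias[0, 0, :4, :4])
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])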
@@ -59,11 +51,8 @@ class SelfAttention(nn.Module):
             torch.cat([k, key], dim=-2) # cat key
             torch.cat([v, value], dim=-2) # cat value
         query_length, key_length = query.size(-2), key.size(-2) # seq_len, seq_len
-        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]
         x = _attn(
-            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
+            query, key, value, None, self.scale_attn
         )
         x = x.transpose(1, 2).contiguous().view(B, S, H)
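Without the mask, the caller no longer slices `self.bias` to the current query/key lengths; it simply passes `None` for `attention_mask` along with the scale. The final line merges the per-head outputs back into the hidden dimension, which this sketch illustrates with assumed placeholder shapes (the concrete values are not taken from this file).

import torch

B, S, n_head, head_dim = 2, 196, 12, 64
H = n_head * head_dim                      # hidden_dim, 768 here

x = torch.randn(B, n_head, S, head_dim)    # per-head attention output
merged = x.transpose(1, 2).contiguous().view(B, S, H)
assert merged.shape == (B, S, H)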
@@ -135,11 +124,11 @@ class VisionTransformer(base_image.BaseVisionModel):
             'patch_size': 16,
             'hidden_dim': 768,
             'n_classes' : 1000,
-            'activation': gelu_new,
+            'activation': F.gelu,
             'image_size': (224, 224),
             'eps': 1e-5,
-            'device': torch.device('cuda'),
-            'dtype': torch.float16,
+            'device': torch.device('cpu'),
+            'dtype': torch.float32,
         }
         super().__init__(self.default_config)
         self.embed = ViTEmbeds(self.config)
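The default config also swaps `gelu_new` for PyTorch's built-in `F.gelu` and falls back to CPU/float32. Assuming `gelu_new` here is the usual tanh-approximated GELU from GPT-style codebases, the difference from the exact `F.gelu` is tiny; recent PyTorch (>= 1.12) can reproduce the approximation directly if it is ever needed again.

import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 7)
exact = F.gelu(x)                             # new default activation
tanh_approx = F.gelu(x, approximate='tanh')   # what gelu_new is assumed to compute
print((exact - tanh_approx).abs().max())      # small but nonzero difference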
@@ -151,9 +140,7 @@ class VisionTransformer(base_image.BaseVisionModel):
     def forward(self, x):
         p_size = self.config.patch_size
         patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=p_size, s2=p_size)
-        print(patches.shape)
         patches = self.embed(patches)
-        print(patches.shape)
         for encoder in self.encoder_layers:
             patches = encoder(patches)
         return self.mlp_head(patches)
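The deleted `print` calls were shape debugging for the patch extraction. The einops pattern tiles each image into non-overlapping `patch_size x patch_size` squares and flattens each tile, channels included, into one token. With the default 224 x 224 input and patch size 16 that gives 14 x 14 = 196 tokens of 16 * 16 * 3 = 768 values each, as this standalone check shows.

import torch
import einops

x = torch.randn(2, 3, 224, 224)   # (B, C, H, W)
p = 16
patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=p, s2=p)
assert patches.shape == (2, 196, 768)   # 14*14 tokens, each 16*16*3 values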