Commit 8382affa authored by novelailab

add lm head

parent d57cfcec
@@ -60,11 +60,11 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 with torch.no_grad():
     model = load_gpt_j().cuda().half()
-    x = torch.zeros(1, 2048).cuda().long()
+    x = torch.zeros(1, 1024).cuda().long()
     print(model(x).shape)
     print("PyTorch Eager")
     timeit(r=1, n=100, func=lambda: model(x), do_tqdm=False, first=False)
-    module = torch.jit.trace(model, torch.zeros((1, 2048)).long().cuda())
+    module = torch.jit.trace(model, torch.zeros((1, 1024)).long().cuda())
     torch.jit.optimize_for_inference(module)
     print("PyTorch JIT")
     timeit(r=1, n=100, func=lambda: module(x), do_tqdm=False, first=False)
\ No newline at end of file
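A note on the benchmark hunk above: only the signature of timeit() is visible (in the hunk header), so the sketch below is a hypothetical stand-in, not the repository's implementation. It shows what such a helper usually has to do for a fair eager-vs-JIT comparison on CUDA: warm up once, then synchronize around the timed region.

import time
import torch

def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    # Hypothetical re-implementation matching the signature in the hunk header above;
    # 'function' and 'do_tqdm' are kept only to mirror that signature.
    if first:
        func()  # warm-up so CUDA kernel launches / JIT specialization are not timed
    for _ in range(r):
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n):
            func()
        torch.cuda.synchronize()  # flush queued GPU work before reading the clock
        per_call = (time.perf_counter() - start) / n
        if not quiet:
            print(f"{per_call * 1000:.3f} ms per call ({n} calls)")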
@@ -75,6 +75,9 @@ class SplitCheckpoint(MutableMapping):
 #TODO: Might change with non einsum functions?
+def get_logits(x, embedding):
+    return embedding(x)
+
 def gelu_new(x):
     return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
@@ -152,7 +155,7 @@ class SelfAttention(nn.Module):
         self.n_head = n_head
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
         self.register_buffer("bias", bias)
-        self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))
+        self.register_buffer("masked_bias", torch.tensor(-1e10, requires_grad=False)) #-1e10 is what mtj uses.
         attn_bias = False
         self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
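For context on the masked_bias change above (-1e9 to -1e10, matching mesh-transformer-jax): the diff does not show the attention forward pass, so the sketch below is only an assumed illustration of how the scale_attn, bias (causal mask) and masked_bias buffers are typically combined in GPT-style attention, not this file's actual code.

import torch

def causal_attention(query, key, value, causal_bias, masked_bias, scale_attn):
    # Assumed GPT-style use of the buffers registered above; not the repo's forward().
    scores = torch.matmul(query, key.transpose(-1, -2)) / scale_attn
    q_len, k_len = query.size(-2), key.size(-2)
    mask = causal_bias[..., k_len - q_len:k_len, :k_len].bool()
    # Masked positions get a very negative value so softmax sends their weight to ~0.
    scores = torch.where(mask, scores, masked_bias.to(scores.dtype))
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)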
@@ -221,7 +224,7 @@ class GPTLayer(nn.Module):
         attn_out = self.attn(x)
         ff_out = self.ff(x)
-        x = residual + ff_out + attn_out + (hyper_out if hypernetwork is not None else 0)
+        x = residual + ff_out + attn_out# + (hyper_out if hypernetwork is not None else 0)
         return x
 
 # Can access and change every module from here, as both Layer class and ff and attn classes are passed from GPTModel.
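The residual line changed above reflects GPT-J's parallel block layout: attention and feed-forward both read the same (already layer-normed) input and their outputs are summed onto the residual, with the hypernetwork term commented out by this commit. A minimal sketch of that structure, assuming a single pre-LayerNorm as in GPT-J (the norm itself is outside this hunk, so its placement here is an assumption):

import torch.nn as nn

class ParallelBlockSketch(nn.Module):
    # Illustrative only: mirrors the residual arithmetic in GPTLayer above.
    def __init__(self, hidden_dim, attn, ff, eps=1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(hidden_dim, eps=eps)
        self.attn = attn  # any module mapping (batch, seq, hidden) -> same shape
        self.ff = ff

    def forward(self, x):
        residual = x
        x = self.ln(x)
        # Parallel, not sequential: both branches see the same normalized input.
        return residual + self.ff(x) + self.attn(x)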
@@ -232,6 +235,7 @@ class GPTModel(nn.Module):
         self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
         self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
         self.layers = nn.ModuleList([])
+        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
         for _ in range(n_layer):
             self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
         #TODO: Decouple more, maybe even init everything here, not sure. Not modular enough yet.
@@ -244,6 +248,11 @@ class GPTModel(nn.Module):
         x = self.ln_final(x)
         return x
 
+    def get_logits(self, x):
+        x = self.forward(x)
+        x = self.lm_head(x)
+        return x.float()
+
     @classmethod
     def load(cls, config, path=None, state_dict=None):
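The new lm_head plus GPTModel.get_logits gives a full-vocabulary, fp32 logits path. Note that, unlike the other submodules, lm_head is constructed without the device/dtype arguments, so it gets placed by the usual module-level .cuda()/.half() calls. A minimal usage sketch, assuming load_gpt_j() from later in this file and a placeholder token tensor (no tokenizer appears in this diff):

import torch

with torch.no_grad():
    model = load_gpt_j().cuda().half()            # loader defined further down in this file
    tokens = torch.zeros(1, 16).cuda().long()     # placeholder ids; a real prompt would come from a tokenizer
    logits = model.get_logits(tokens)             # (batch, seq, vocab), upcast to float32 by get_logits
    next_token = logits[:, -1, :].argmax(dim=-1)  # greedy choice for the next position
    print(logits.shape, next_token)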
@@ -271,12 +280,6 @@ class GPTModel(nn.Module):
 # TODO: Do we want to have the LM head as a seperate Class? Or just a function? I think we might be better off with a function here and maybe
 # also for the self attention, we can just write a function that gets fed in the q, k, v.
-class GPTLM(nn.Module):
-    def __init__(self):
-        return
-    def forward(self, x):
-        return
-
 def load_gpt_j(path="models/6b", state_dict=None):
     config = {
         "n_layer": 28,