Commit a1b9e387 authored by novelailab

eval stuff

parent 05bbefb8
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -200,13 +201,25 @@ class GPTJModel(nn.Module):
for _ in range(n_layer):
self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
x = self.lm_head(x)
if kv:
return x.float(), kv
if target is not None:
logits = x.view(-1, x.shape[-1])
labels = target.view(-1)
loss = F.cross_entropy(logits, labels)
#clean this mess later
if cache:
if target is not None:
return loss, x.float(), kv
else:
return x.float(), kv
else:
return x.float()
if target is not None:
return loss, x.float()
else:
return x.float()
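Note that the loss above compares logits and labels at the same positions. GPT-style training usually shifts them so position i predicts token i+1. A minimal sketch of that convention, assuming a hypothetical helper name lm_loss and the usual ignore_index default (neither is part of this commit):

import torch.nn.functional as F

def lm_loss(logits, target, ignore_index=-100):
    # Hypothetical helper, not from this commit: next-token objective.
    # Drop the last logit and the first label so position i predicts i+1.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = target[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )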
def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
if kv is None:
@@ -239,5 +252,6 @@ def load_gpt_j(path="models/6b", state_dict=None):
"vocab_dim": 50400,
"eps": 1e-5
}
config = DotMap(config)
model = GPTJBaseLM.load(config, path, state_dict)
return model
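A minimal usage sketch for the loader above; the checkpoint path, device placement, and output shape are assumptions inferred from the config, not code from the commit:

import torch

model = load_gpt_j(path="models/6b")  # assumed checkpoint layout
model = model.cuda().eval()
tokens = torch.randint(0, 50400, (1, 32), device="cuda")
with torch.no_grad():
    logits = model(tokens)  # assuming forward returns logits: (1, 32, 50400)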
@@ -159,12 +159,13 @@ def func_multinomial(x):
@torch.no_grad()
def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
in_tokens = prompt_tokens
padding_token = 50256
generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
kv = None
for _ in range(tokens_to_generate):
for i in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv)
logits = logits[:, -1, :] #get the last token in the seq
# get the token before the padding_token in the seq
logits = logits.argmax(dim=-1).unsqueeze(-1)
generated = torch.cat([generated, logits], dim=-1)
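The rest of the loop body is elided here. With cache=True, the usual pattern is to feed only the newest token after the first step and let the returned kv carry the earlier context. A sketch of the full function under that assumption, not the commit's exact code:

import torch

@torch.no_grad()
def generate_greedy_sketch(forward, prompt_tokens, tokens_to_generate=50):
    in_tokens = prompt_tokens
    generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long,
                            device=prompt_tokens.device)
    kv = None
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        in_tokens = next_token  # kv cache already holds earlier positions
    return torch.cat([prompt_tokens, generated], dim=-1)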
@@ -308,8 +309,8 @@ def main():
}
ops_list = [ops] * bsz
#tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
tokens_generated = generate_greedy(model.forward, tokens, gen_len)
tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
#tokens_generated = generate_greedy(model.forward, tokens, gen_len)
#tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
print(tokens_generated.shape)
ic(prompt)
......
@@ -17,7 +17,7 @@ import torch.nn.functional as F
from lm_eval.models.gpt2 import GPT2LM
from lm_eval import tasks, evaluator, utils, base
from basedformer import optimizer, utils, gptj, noemblm, gpt2  # note: this `utils` shadows `lm_eval.utils` imported above
class EvalHarnessAdapter(GPT2LM):
"""
@@ -36,7 +36,7 @@ class EvalHarnessAdapter(GPT2LM):
self.neox_args = neox_args
self.tokenizer = neox_args.tokenizer
self._device = torch.device(f"cuda:{neox_args.local_rank}")
self._eot_token_id = neox_args.tokenizer.eod_id
self._eot_token_id = 50256
self._max_length = neox_args.max_position_embeddings // 2
self._max_gen_toks = 128
self._vocab_size = neox_args.padded_vocab_size
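The hardcoded 50256 is the GPT-2/GPT-J <|endoftext|> id, replacing the tokenizer lookup on the line above. A small guard like the following (hypothetical, not in the commit) would catch tokenizers where that assumption breaks:

EOT_TOKEN_ID = 50256  # GPT-2/GPT-J <|endoftext|>

def checked_eot_id(tokenizer, expected=EOT_TOKEN_ID):
    # Hypothetical guard: cross-check the hardcoded id against the
    # tokenizer's own eod_id when the tokenizer exposes one.
    eot = getattr(tokenizer, "eod_id", None)
    if eot is not None and eot != expected:
        raise ValueError(f"tokenizer eot id {eot} != expected {expected}")
    return expected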
@@ -416,7 +416,6 @@ def run_eval_harness(
num_fewshot=0,
bootstrap_iters=2,
):
print_rank_0("Running evaluation harness...")
adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size)
return adapter.run_eval(
eval_tasks=eval_tasks, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters
......
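A hedged sketch of how the run_eval_harness entry point above might be invoked. The task names follow lm-eval-harness conventions, the batch_size keyword is assumed, and model, forward_step_fn, and neox_args come from the surrounding training code:

results = run_eval_harness(
    model,
    forward_step_fn,
    neox_args,
    batch_size=8,                         # assumed keyword
    eval_tasks=["lambada", "hellaswag"],  # assumed task names
    num_fewshot=0,
    bootstrap_iters=2,
)
print(results)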
(Two source diffs omitted here: too large to display.)
@@ -89,6 +89,7 @@ for input_ids, labels in t:
with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
#print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
#roll down the sequence
logits = logits.view(-1, logits.shape[-1])
gas_labels = labels[x*bs:(x+1)*bs, :512].contiguous()
gas_labels = gas_labels.view(-1)
......
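The micro-batch above flattens logits for a cross-entropy loss. A self-contained sketch of one gradient-accumulation step under that scheme, including the shift hinted at by the "roll down the sequence" comment; the gas divisor, the scaler usage, and the label slicing are assumptions, not code from the commit:

import torch
import torch.nn.functional as F

def micro_step(model, scaler, input_ids, labels, x, bs, gas, amp=True):
    # Hypothetical micro-step mirroring the snippet's slicing; the label
    # shift and the 1/gas loss scaling are assumptions.
    with torch.cuda.amp.autocast(enabled=amp, dtype=torch.float16):
        logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
        # "roll down the sequence": position i predicts token i+1
        shift_logits = logits[:, :-1, :].reshape(-1, logits.shape[-1])
        shift_labels = labels[x*bs:(x+1)*bs, 1:512].reshape(-1).cuda()
        loss = F.cross_entropy(shift_logits, shift_labels) / gas
    scaler.scale(loss).backward()
    return loss.detach() * gas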