Commit a1b9e387 authored by novelailab

eval stuff

parent 05bbefb8
 from typing import KeysView
+from regex import D
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -200,13 +201,25 @@ class GPTJModel(nn.Module):
         for _ in range(n_layer):
             self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
 
-    def forward(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
+    def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
         x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
         x = self.lm_head(x)
-        if kv:
-            return x.float(), kv
-        else:
-            return x.float()
+        if target is not None:
+            logits = x.view(-1, x.shape[-1])
+            labels = target.view(-1)
+            loss = F.cross_entropy(logits, labels)
+
+        #clean this mess later
+        if cache:
+            if target is not None:
+                return loss, x.float(), kv
+            else:
+                return x.float(), kv
+        else:
+            if target is not None:
+                return loss, x.float()
+            else:
+                return x.float()
 
     def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
         if kv is None:
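The new target path turns forward() into a combined inference/eval entry point: when a label tensor is passed, the logits are flattened and scored with cross-entropy, and the loss is returned ahead of the float-cast logits (plus the KV cache when cache=True). A minimal usage sketch, assuming the reconstructed signature above and an already-loaded GPTJModel; the model handle, input sizes, and CUDA placement are illustrative:

import torch

# model is assumed to be a loaded GPTJModel instance (e.g. load_gpt_j(...).lm), on GPU
tokens = torch.randint(0, 50400, (1, 128), device="cuda")  # stand-in input ids
targets = tokens.clone()                                    # stand-in label tensor, same shape

# inference path: no target; cache=True also returns the KV cache for incremental decoding
logits, kv = model.forward(tokens, cache=True)

# eval path: passing target additionally returns a scalar cross-entropy loss
loss, logits = model.forward(tokens, target=targets)
print(loss.item(), logits.shape)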
@@ -239,5 +252,6 @@ def load_gpt_j(path="models/6b", state_dict=None):
         "vocab_dim": 50400,
         "eps": 1e-5
     }
+    config = DotMap(config)
     model = GPTJBaseLM.load(config, path, state_dict)
     return model
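The added DotMap(config) line wraps the plain config dict so downstream code can read hyperparameters with attribute access (config.vocab_dim) as well as item access. A small self-contained illustration using the dotmap package; the keys besides vocab_dim and eps are the usual GPT-J-6B values and are shown only for illustration:

from dotmap import DotMap

config = DotMap({
    "n_layer": 28,        # illustrative GPT-J-6B depth
    "hidden_dim": 4096,   # illustrative GPT-J-6B width
    "n_head": 16,         # illustrative GPT-J-6B head count
    "vocab_dim": 50400,
    "eps": 1e-5,
})

assert config.vocab_dim == 50400   # attribute-style access
assert config["eps"] == 1e-5       # dict-style access still works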
@@ -159,12 +159,13 @@ def func_multinomial(x):
 @torch.no_grad()
 def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
     in_tokens = prompt_tokens
+    padding_token = 50256
     generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
     kv = None
-    for _ in range(tokens_to_generate):
+    for i in range(tokens_to_generate):
         logits, kv = forward(in_tokens, cache=True, kv=kv)
         logits = logits[:, -1, :] #get the last token in the seq
+        # get the token before the padding_token in the seq
         logits = logits.argmax(dim=-1).unsqueeze(-1)
         generated = torch.cat([generated, logits], dim=-1)
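The hunk is truncated before the end of the loop, so the feedback step is not visible. A hedged sketch of the complete greedy decoder under that assumption: after the first pass over the full prompt, the returned kv cache carries the past attention states, so only the newly picked token needs to be fed back on later iterations (the final in_tokens update and the return value are assumptions, not part of the diff):

import torch

@torch.no_grad()
def generate_greedy_sketch(forward, prompt_tokens, tokens_to_generate=50):
    in_tokens = prompt_tokens
    generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long,
                            device=prompt_tokens.device)
    kv = None
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv)           # kv holds past keys/values
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # greedy pick at last position
        generated = torch.cat([generated, next_token], dim=-1)
        in_tokens = next_token  # with a warm cache, only the new token is passed in
    return generated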
@@ -308,8 +309,8 @@ def main():
     }
     ops_list = [ops] * bsz
-    #tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
-    tokens_generated = generate_greedy(model.forward, tokens, gen_len)
+    tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
+    #tokens_generated = generate_greedy(model.forward, tokens, gen_len)
     #tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
     print(tokens_generated.shape)
     ic(prompt)
@@ -17,7 +17,7 @@ import torch.nn.functional as F
 from lm_eval.models.gpt2 import GPT2LM
 from lm_eval import tasks, evaluator, utils, base
+from basedformer import optimizer, utils, gptj, noemblm, gpt2
 
 class EvalHarnessAdapter(GPT2LM):
     """
@@ -36,7 +36,7 @@ class EvalHarnessAdapter(GPT2LM):
         self.neox_args = neox_args
         self.tokenizer = neox_args.tokenizer
         self._device = torch.device(f"cuda:{neox_args.local_rank}")
-        self._eot_token_id = neox_args.tokenizer.eod_id
+        self._eot_token_id = 50256
         self._max_length = neox_args.max_position_embeddings // 2
         self._max_gen_toks = 128
         self._vocab_size = neox_args.padded_vocab_size
@@ -416,7 +416,6 @@ def run_eval_harness(
     num_fewshot=0,
     bootstrap_iters=2,
 ):
-    print_rank_0("Running evaluation harness...")
     adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size)
     return adapter.run_eval(
         eval_tasks=eval_tasks, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters
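Two small behaviour changes in this file: the end-of-text id is now hard-coded to 50256 (the GPT-2/GPT-J <|endoftext|> token) instead of being read from the NeoX tokenizer, and the startup print_rank_0 call is dropped. A hedged sketch of how the adapter is driven, using only the call shapes visible in the hunks above; model, forward_step_fn and neox_args are assumed to come from the caller, and the task names are illustrative lm_eval tasks:

batch_size = 16  # illustrative
adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size)
results = adapter.run_eval(
    eval_tasks=["lambada", "hellaswag"],  # illustrative lm_eval task names
    num_fewshot=0,
    bootstrap_iters=2,
)
print(results)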
Two further source diffs could not be displayed because they are too large. You can view the blobs instead.
@@ -89,6 +89,7 @@ for input_ids, labels in t:
     with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
         logits = model.lm(input_ids[x*bs:(x+1)*bs, :512].cuda(), act_ck=False)
         #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
+        #roll down the sequence
         logits = logits.view(-1, logits.shape[-1])
         gas_labels = labels[x*bs:(x+1)*bs, :512].contiguous()
         gas_labels = gas_labels.view(-1)
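For reference, the loss computed from these slices follows the standard flatten-and-cross-entropy pattern; the cross_entropy call itself sits below the truncated hunk, so its exact form here is an assumption. A self-contained sketch with stand-in tensors:

import torch
import torch.nn.functional as F

bs, seq, vocab = 4, 512, 50400                     # illustrative micro-batch shape
logits = torch.randn(bs, seq, vocab)               # stand-in for model.lm(...) output
gas_labels = torch.randint(0, vocab, (bs, seq))    # stand-in for the shifted label slice

flat_logits = logits.view(-1, logits.shape[-1])    # (bs*seq, vocab)
flat_labels = gas_labels.contiguous().view(-1)     # (bs*seq,)
loss = F.cross_entropy(flat_logits, flat_labels)   # mean token-level cross-entropy
print(loss.item())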