Commit a5470cba authored by novelailab's avatar novelailab

sampler on its own file and greedy sampling

parent a1ab899d
import torch
from basedformer import gptj
from basedformer.utils import *
from transformers import AutoTokenizer
from icecream import ic
import functorch
import time
import sys
def print_top_k(logits, tokenizer, k):
topk_ind = logits.topk(k)[1]
for x in range(topk_ind.shape[0]):
for y in range(topk_ind.shape[1]):
print("\nToken " + str(y))
for token in topk_ind[x, y, :].tolist():
print(tokenizer.decode([token]), end=" | ")
def apply_top_k(logits, k):
# filter the logits that are not in the top-k to -inf
# keep top_k_ind and filter the rest
top_k_values = logits.topk(k)[0]
remove_mask = logits < top_k_values[:, -1].unsqueeze(-1)
logits[remove_mask == True] = -float("inf")
return logits
def apply_top_p(logits, p):
logits = torch.softmax(logits, dim=-1)
sorted, indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(sorted, dim=-1)
mask_tensor = cumulative_probs > p
# Shift the indices to the right to keep also the first token above the threshold
mask_tensor[..., 1:] = mask_tensor[..., :-1].clone()
mask_tensor[..., 0] = 0
mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
logits[mask_tensor == True] = -float("inf")
return logits
def apply_tfs(logits, tfs):
logits = torch.softmax(logits, dim=-1)
sorted, indices = torch.sort(logits, descending=True)
d = sorted
d = d[:, 1:] - d[:, :-1]
d = d[:, 1:] - d[:, :-1]
d = d.abs()
d = d / d.sum(dim=-1).view(1, -1).T
cumulative_probs = torch.cumsum(d, dim=-1)
mask_tensor = torch.empty(indices.shape).cuda()
mask_tensor[:, 1:-1] = (cumulative_probs > tfs)[:, :]
# Always remove last token
mask_tensor[:, -1:] = True
# Always keep the first token
mask_tensor[:, 0] = False
mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
logits[mask_tensor == True] = -float("inf")
return logits
def apply_typical(logits, mass=0.9):
scores = logits
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, -float("inf"))
return scores
def apply_temp(logits, temperature):
logits = logits / temperature
return logits
def rep_pen(input_ids, scores, penalty, m=3.33, penalize_last=250,
alpha_frequency=None, alpha_presence=None, whitelist=None,
):
scores = torch.log_softmax(scores, dim=-1)
penalty = 1.0 if penalty < 1.0 else penalty
raw_penalty = penalty
penalize_last = None
if not m is None and not penalize_last is None and penalize_last >= 1:
penalty = (torch.arange(penalize_last)/(penalize_last - 1)) * 2. - 1
penalty = (m * penalty) / (1 + torch.abs(penalty) * (m - 1))
penalty = 1 + ((penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
penalize_last = penalize_last
alpha_enable = alpha_frequency is not None or alpha_presence is not None
whitelist = None
whitelist_list = None
if whitelist is not None:
whitelist_list = whitelist
##########
if whitelist is None and whitelist_list is not None:
whitelist_list = list(filter(lambda x: x >= 0 and x < scores.shape[1], whitelist_list))
if len(whitelist_list) > 0:
whitelist = torch.tensor(whitelist_list).long().sort()[0]
whitelist = whitelist.to(input_ids.device)
if whitelist is not None:
unpenalized = scores.gather(1, whitelist.view(1, -1))
if raw_penalty > 1.0:
if not penalize_last is None:
penality_len = min(input_ids.shape[1], penalize_last)
input_ids = input_ids[:, -penality_len:]
score = torch.gather(scores, 1, input_ids)
# if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
if not penalize_last is None:
penalty = penalty.type(score.dtype).to(score.device)
score = torch.where(score < 0, score * penalty[:, -penality_len:], score / penalty[:, -penality_len:])
else:
score = torch.where(score < 0, score * penalty, score / penalty)
scores.scatter_(1, input_ids, score)
if alpha_enable:
c = torch.zeros(scores.shape).long().to(input_ids.device)
# unique only returns counts for first item in batch, so manually iterate
for i in range(input_ids.shape[0]):
if penalize_last is not None:
token_input_ids, counts = torch.unique(input_ids[i,-penalize_last:], sorted=True, return_counts=True, dim=-1)
else:
token_input_ids, counts = torch.unique(input_ids[i], sorted=True, return_counts=True, dim=-1)
c[i].scatter_(0, token_input_ids, counts)
if alpha_frequency:
scores -= c * alpha_frequency
if alpha_presence:
scores[c > 0] -= alpha_presence
if whitelist is not None:
scores.scatter_(1, whitelist.view(1, -1), unpenalized)
return scores
def func_multinomial(x):
torch.manual_seed(69)
return torch.multinomial(x, 1)
@torch.no_grad()
def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
in_tokens = prompt_tokens
context = prompt_tokens
generated = torch.tensor([[]], dtype=torch.long).to(in_tokens.device)
kv = None
for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv)
logits = logits[:, -1, :] #get the last token in the seq
logits = logits.argmax(dim=-1)
generated = torch.cat([generated, logits], dim=-1)
in_tokens = logits
return generated
@torch.no_grad()
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
in_tokens = prompt_tokens
context = prompt_tokens
generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
kv = None
fully_deterministic = False
#soft_required = ["top_k", "top_p"]
op_map = {
"top_k": apply_top_k,
"top_p": apply_top_p,
"typical": apply_typical,
"temp": apply_temp,
"tfs": apply_tfs,
"rep_pen": rep_pen,
}
funcnomial = functorch.vmap(func_multinomial, randomness="different")
for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
#can save one softmax here by not applying softmax for the first op,
#need to take the softmax out of the necessary functions though
batch = []
for i, ops in enumerate(ops_list):
item = logits[i, ...].unsqueeze(0)
ctx = context[i, ...].unsqueeze(0)
for op, value in ops.items():
if op == "rep_pen":
item = op_map[op](ctx, item, **value)
else:
item = op_map[op](item, value)
batch.append(item)
logits = torch.cat(batch, dim=0)
logits = torch.softmax(logits, dim=-1)
#fully_deterministic makes it deterministic across the batch
if fully_deterministic:
logits = logits.split(1, dim=0)
logit_list = []
for logit in logits:
torch.manual_seed(69)
logit_list.append(torch.multinomial(logit, 1))
logits = torch.cat(logit_list, dim=0)
else:
torch.manual_seed(69)
logits = torch.multinomial(logits, 1)
generated = torch.cat([generated, logits], dim=-1)
context = torch.cat([context, logits], dim=-1)
in_tokens = logits
return generated
def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops={"temp": 0.9}):
with torch.no_grad():
in_tokens = prompt_tokens
kv = None
fully_deterministic = False
tokens_generated = []
op_map = {
"top_k": apply_top_k,
"top_p": apply_top_p,
"typical": apply_typical,
"temp": apply_temp,
"tfs": apply_tfs
}
for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv)
logits = logits[:, -1, :] #get the last token in the seq
logits = torch.log_softmax(logits, dim=-1)
for op, value in ops.items():
logits = op_map[op](logits, value).float()
logits = torch.softmax(logits, dim=-1).float()
if fully_deterministic:
logits = logits.split(1, dim=0)
logit_list = []
for logit in logits:
torch.manual_seed(69)
logit_list.append(torch.multinomial(logit, 1))
logits = torch.cat(logit_list, dim=0)
else:
torch.manual_seed(69)
logits = torch.multinomial(logits, 1)
in_tokens = logits
tokens_generated.append(logits)
tokens_generated = torch.cat(tokens_generated, dim=-1)
return tokens_generated
def main():
bsz = 4
gen_len = 250
torch.manual_seed(69)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
prompt = """I fucked her with my huge donut, when she seen my donut she went"""
prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
tokens = tokenizer.encode(prompt)
print("Prompt:")
for x in range(len(tokens)):
print(tokenizer.decode([tokens[x]]), end=" | ")
print("\n Generation:")
tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
tokens = [tokens] * bsz
#tokens = torch.cat([tokens, tokens], dim=0)
tokens = torch.cat(tokens, dim=0)
t = time.perf_counter()
model = gptj.load_gpt_j().cuda().half().eval()
model = model.lm
ic(time.perf_counter() - t)
rep_pen = {
"penalty": 3,
}
ops = {
"rep_pen": rep_pen,
"top_k": 50,
"temp": 0.8,
}
ops_list = [ops] * bsz
tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
#tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
print(tokens_generated.shape)
ic(prompt)
tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
for gen in tokens_generated:
print(str(gen))
print("===========================================================")
#ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
#timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
#timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)
if __name__ == "__main__":
main()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment