Commit a1556cf7 authored by novelailab

push

parent 1f8d04b2
import torch
import sys
@@ -13,6 +13,82 @@ def print_top_k(logits, tokenizer, k):
        for token in topk_ind[x, y, :].tolist():
            print(tokenizer.decode([token]), end=" | ")
def apply_top_k(logits, k):
    # keep only the top-k logits per row and set the rest to -inf
    top_k_values = logits.topk(k)[0]
    # compare each row against its k-th largest value
    remove_mask = logits < top_k_values[:, -1].unsqueeze(-1)
    logits[remove_mask] = -float("inf")
    return logits
def apply_top_p(logits, p):
    # nucleus sampling: drop tokens outside the smallest set whose
    # cumulative probability exceeds p
    sorted_logits, indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(sorted_logits.softmax(dim=-1), dim=-1)
    # map the cumulative probabilities back to the original token order
    cumulative_probs = cumulative_probs.scatter(dim=-1, index=indices, src=cumulative_probs)
    remove_mask = cumulative_probs > p
    logits[remove_mask] = -float("inf")
    return logits
def apply_tfs(logits, tfs):
    # tail free sampling: cut the tail where the second derivative of the
    # sorted probability curve flattens out
    sorted_logits, indices = torch.sort(logits, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    # second finite difference of the sorted probabilities
    d = probs[:, 1:] - probs[:, :-1]
    d = d[:, 1:] - d[:, :-1]
    d = d.abs()
    d = d / d.sum(dim=-1, keepdim=True)
    cumulative_probs = torch.cumsum(d, dim=-1)
    sorted_remove = cumulative_probs > tfs
    # the double difference loses two positions: always keep the first
    # token and always drop the last one
    sorted_remove = torch.cat(
        [torch.zeros_like(sorted_remove[:, :1]), sorted_remove, torch.ones_like(sorted_remove[:, :1])], dim=-1)
    # map the removal mask back to the original token order
    remove_mask = sorted_remove.scatter(dim=-1, index=indices, src=sorted_remove)
    logits[remove_mask] = -float("inf")
    return logits
def apply_temperature(logits, temperature):
    # scale the logits by the sampling temperature
    logits = logits / temperature
    return logits
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
    with torch.no_grad():
        in_tokens = prompt_tokens
        kv = None
        tokens_generated = []
        # ops that expect re-normalized (log-)probabilities as input
        soft_required = ["top_k", "top_p"]
        op_map = {
            "top_k": apply_top_k,
            "top_p": apply_top_p,
            "temp": apply_temperature,
            "tfs": apply_tfs,
        }
        for _ in range(tokens_to_generate):
            logits, kv = forward(in_tokens, cache=True, kv=kv)
            # take the logits for the last token in the sequence
            logits = logits[:, -1, :]
            # always work on log-softmaxed logits to make sure all models
            # behave similarly, as raw logprob scales can be quite different
            # TODO: can break compatibility with novelai presets.
            logits = torch.log_softmax(logits, dim=-1)
            # can save one softmax here by not applying softmax for the first op,
            # need to take the softmax out of the necessary functions though
            batch = []
            # ops_list holds one dict of sampling ops per batch element
            for i, ops in enumerate(ops_list):
                item = logits[i:i + 1, :]
                for op, value in ops.items():
                    if op in soft_required:
                        item = torch.log_softmax(item, dim=-1)
                    item = op_map[op](item, value)
                batch.append(item)
            logits = torch.cat(batch, dim=0)
            probs = torch.softmax(logits, dim=-1)
            # sample the next token and feed it back in for the next step
            in_tokens = torch.multinomial(probs, 1)
            tokens_generated.append(in_tokens)
        tokens_generated = torch.cat(tokens_generated, dim=-1)
        return tokens_generated
def main():
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    prompt = """I fucked her with my huge donut, when she seen my donut she went"""
@@ -28,20 +104,16 @@ def main():
    ic(time.perf_counter() - t)
    with torch.no_grad():
        kv = None
        tokens_to_generate = 50
        in_tokens = tokens
        accum_tokens = []
        for x in range(tokens_to_generate):
            logits, kv = model(in_tokens, cache=True, kv=kv)
            # greedy decoding: feed the argmax token back in
            in_tokens = logits[:, -1, :].topk(1)[1]
            #in_tokens = torch.cat([in_tokens, logits[:, -1, :].topk(1)[1]], dim=1)
            print(tokenizer.decode(in_tokens.squeeze(1).tolist()[-1]), end=" | ")
        #accum_tokens = torch.cat(accum_tokens, dim=1)
        #accum_tokens = accum_tokens.squeeze(0).tolist()
        #print("\n Final token list")
        #print(tokenizer.decode(accum_tokens))
    ops = {
        "top_k": 40,
        "top_p": 0.9,
        "temp": 0.9,
    }
    tokens_generated = generate(model.forward, tokens, tokens_to_generate=40, ops_list=[ops])
    tokens_generated = tokenizer.decode(tokens_generated.squeeze().tolist())
    ic(prompt)
    ic(tokens_generated)
if __name__ == "__main__":
    main()
\ No newline at end of file
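For context, a minimal sketch of how the new generate() helper can be driven end to end. Everything below is illustrative and not part of this commit: the demo function, the toy logits, the prompt, and the sampling values are placeholders, and it assumes a model object whose forward(tokens, cache=True, kv=kv) returns (logits, kv) as in the diff above, plus the GPT-2 tokenizer already used in main().

import torch
from transformers import AutoTokenizer

def demo(model):
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    tokens = tokenizer("The quick brown fox", return_tensors="pt").input_ids  # shape [1, seq_len]

    # quick sanity check of a filter on toy logits: with k=2 only the two
    # largest entries stay finite, the rest become -inf
    toy = torch.log_softmax(torch.tensor([[2.0, 1.0, 0.5, -1.0]]), dim=-1)
    print(apply_top_k(toy.clone(), 2))

    # one ops dict per batch row; keys map to the op_map entries in generate()
    ops = {
        "temp": 0.8,
        "tfs": 0.95,
        "top_k": 40,
        "top_p": 0.9,
    }
    out = generate(model.forward, tokens, tokens_to_generate=32, ops_list=[ops])
    print(tokenizer.decode(out.squeeze().tolist()))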