Commit 05bbefb8 authored by novelailab's avatar novelailab

fix greedy sampling

parent a5470cba
...@@ -6,6 +6,13 @@ import functorch ...@@ -6,6 +6,13 @@ import functorch
import time import time
import sys import sys
# TODO: Write a streamer for the sampler so we can decouple tokens_to_generate over the batch as well
# a lot more work as you need to schedule the forwards. Then we need a batcher, to look over a queue
# and take in the next batch items, without waiting for too long and selecting requests with sequence lengths and if possible
# generation lengths close.
# TODO: make the padding work to generate (need to take the logit before the padding starts instead of the last logit.)
def print_top_k(logits, tokenizer, k): def print_top_k(logits, tokenizer, k):
topk_ind = logits.topk(k)[1] topk_ind = logits.topk(k)[1]
for x in range(topk_ind.shape[0]): for x in range(topk_ind.shape[0]):
...@@ -152,14 +159,13 @@ def func_multinomial(x): ...@@ -152,14 +159,13 @@ def func_multinomial(x):
@torch.no_grad() @torch.no_grad()
def generate_greedy(forward, prompt_tokens, tokens_to_generate=50): def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
in_tokens = prompt_tokens in_tokens = prompt_tokens
context = prompt_tokens generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
generated = torch.tensor([[]], dtype=torch.long).to(in_tokens.device)
kv = None kv = None
for _ in range(tokens_to_generate): for _ in range(tokens_to_generate):
logits, kv = forward(in_tokens, cache=True, kv=kv) logits, kv = forward(in_tokens, cache=True, kv=kv)
logits = logits[:, -1, :] #get the last token in the seq logits = logits[:, -1, :] #get the last token in the seq
logits = logits.argmax(dim=-1) logits = logits.argmax(dim=-1).unsqueeze(-1)
generated = torch.cat([generated, logits], dim=-1) generated = torch.cat([generated, logits], dim=-1)
in_tokens = logits in_tokens = logits
...@@ -302,7 +308,8 @@ def main(): ...@@ -302,7 +308,8 @@ def main():
} }
ops_list = [ops] * bsz ops_list = [ops] * bsz
tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list) #tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
tokens_generated = generate_greedy(model.forward, tokens, gen_len)
#tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops) #tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
print(tokens_generated.shape) print(tokens_generated.shape)
ic(prompt) ic(prompt)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment