novelai-storage / Basedformer · Commits · 05bbefb8

Commit 05bbefb8, authored Apr 29, 2022 by novelailab
Parent: a5470cba

fix greedy sampling

Showing 1 changed file with 11 additions and 4 deletions.

basedformer/sampling.py (+11, -4)
@@ -6,6 +6,13 @@ import functorch
 import time
 import sys
+# TODO: Write a streamer for the sampler so we can decouple tokens_to_generate over the batch as well.
+# That is a lot more work, as you need to schedule the forwards. Then we need a batcher that looks over
+# a queue and takes in the next batch items without waiting too long, selecting requests whose sequence
+# lengths (and, if possible, generation lengths) are close together.
+# TODO: make padding work for generation (need to take the logit from before the padding starts instead of the last logit).
 def print_top_k(logits, tokenizer, k):
     topk_ind = logits.topk(k)[1]
     for x in range(topk_ind.shape[0]):
 ...
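The body of print_top_k is truncated in this diff view. As illustration only, here is a minimal self-contained sketch of how such a debug helper could be completed, assuming a HuggingFace-style tokenizer exposing decode(list[int]) -> str; the loop body and FakeTokenizer below are assumptions, not the repository's actual code:

import torch

def print_top_k(logits, tokenizer, k):
    # logits: [batch, vocab]; topk()[1] keeps the indices of the k largest logits
    topk_ind = logits.topk(k)[1]
    for x in range(topk_ind.shape[0]):
        # Assumed body: decode each candidate token so it can be inspected while sampling.
        pieces = [tokenizer.decode([t]) for t in topk_ind[x].tolist()]
        print(f"row {x}: {pieces}")

class FakeTokenizer:
    # Minimal stand-in with a HuggingFace-style decode method.
    def decode(self, ids):
        return f"<tok{ids[0]}>"

print_top_k(torch.randn(2, 50), FakeTokenizer(), k=5)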
@@ -152,14 +159,13 @@ def func_multinomial(x):
 @torch.no_grad()
 def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
     in_tokens = prompt_tokens
     context = prompt_tokens
-    generated = torch.tensor([[]], dtype=torch.long).to(in_tokens.device)
+    generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
     kv = None
     for _ in range(tokens_to_generate):
         logits, kv = forward(in_tokens, cache=True, kv=kv)
         logits = logits[:, -1, :]  # get the logits at the last position in the sequence
-        logits = logits.argmax(dim=-1)
+        logits = logits.argmax(dim=-1).unsqueeze(-1)
         generated = torch.cat([generated, logits], dim=-1)
         in_tokens = logits
 ...
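This hunk is the heart of the greedy-sampling fix: the accumulator is now created with an explicit batch dimension (torch.zeros(batch, 0)), and the per-step argmax is unsqueezed from [batch] to [batch, 1], so torch.cat along dim=-1 works for any batch size and in_tokens keeps the [batch, 1] shape the cached forward expects. A standalone shape check of the fixed pattern (not code from the repository):

import torch

bsz, vocab = 3, 10
logits = torch.randn(bsz, vocab)                   # logits at the last position

generated = torch.zeros(bsz, 0, dtype=torch.long)  # fixed accumulator: [3, 0]
next_tok = logits.argmax(dim=-1).unsqueeze(-1)     # [3] -> [3, 1]
generated = torch.cat([generated, next_tok], dim=-1)
print(generated.shape)                             # torch.Size([3, 1])

# The previous torch.tensor([[]], dtype=torch.long) accumulator has shape
# [1, 0], and the un-unsqueezed argmax is 1-D, so the old concatenation
# did not line up for batched prompts.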
@@ -302,7 +308,8 @@ def main():
     }
     ops_list = [ops] * bsz
-    tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
+    #tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
+    tokens_generated = generate_greedy(model.forward, tokens, gen_len)
     #tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
     print(tokens_generated.shape)
     ic(prompt)
 ...
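From the diff, generate_greedy only assumes that forward(in_tokens, cache=True, kv=kv) returns (logits, kv) with logits shaped [batch, seq, vocab]. Below is a self-contained sketch exercising the fixed loop against a dummy forward; the dummy model and the trailing return are assumptions for illustration (main() above suggests the function returns the generated tokens, since it prints tokens_generated.shape):

import torch

@torch.no_grad()
def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
    # The fixed loop from this commit, with an assumed `return generated` at the end.
    in_tokens = prompt_tokens
    generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
    kv = None
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv)
        logits = logits[:, -1, :]
        logits = logits.argmax(dim=-1).unsqueeze(-1)
        generated = torch.cat([generated, logits], dim=-1)
        in_tokens = logits
    return generated

def dummy_forward(in_tokens, cache=True, kv=None):
    # Stand-in for model.forward: random logits over a toy vocabulary plus an opaque cache.
    vocab = 100
    return torch.randn(in_tokens.shape[0], in_tokens.shape[1], vocab), kv

tokens = torch.randint(0, 100, (2, 8))     # a batch of two 8-token prompts
out = generate_greedy(dummy_forward, tokens, tokens_to_generate=5)
print(out.shape)                           # torch.Size([2, 5])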