novelai-storage / Basedformer · Commits · a5470cba

Commit a5470cba, authored Apr 29, 2022 by novelailab
Parent: a1ab899d

    sampler on its own file and greedy sampling

Showing 1 changed file with 316 additions and 1 deletion.

basedformer/sampling.py (+316, -1)

import torch
from basedformer import gptj
from basedformer.utils import *
from transformers import AutoTokenizer
from icecream import ic
import functorch
import time
import sys

def print_top_k(logits, tokenizer, k):
    # Debug helper: decode and print the top-k candidate tokens at every
    # position of every sequence in the batch.
    topk_ind = logits.topk(k)[1]
    for x in range(topk_ind.shape[0]):
        for y in range(topk_ind.shape[1]):
            print("\nToken " + str(y))
            for token in topk_ind[x, y, :].tolist():
                print(tokenizer.decode([token]), end=" | ")

def apply_top_k(logits, k):
    # filter the logits that are not in the top-k to -inf
    # keep top_k_ind and filter the rest
    top_k_values = logits.topk(k)[0]
    remove_mask = logits < top_k_values[:, -1].unsqueeze(-1)
    logits[remove_mask] = -float("inf")
    return logits

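# A hedged toy example (illustrative values, not from the repository): with
# k=2, everything below the second-largest logit is forced to -inf.
#
#   apply_top_k(torch.tensor([[1.0, 3.0, 2.0, 0.5]]), k=2)
#   # -> tensor([[-inf, 3., 2., -inf]])
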
def apply_top_p(logits, p):
    # Note: this converts logits to probabilities first, so the returned
    # tensor holds probabilities with the filtered entries set to -inf.
    logits = torch.softmax(logits, dim=-1)
    sorted_probs, indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
    mask_tensor = cumulative_probs > p
    # Shift the indices to the right to keep also the first token above the threshold
    mask_tensor[..., 1:] = mask_tensor[..., :-1].clone()
    mask_tensor[..., 0] = False
    mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
    logits[mask_tensor] = -float("inf")
    return logits

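# A hedged toy example (illustrative values): for the distribution
# (0.5, 0.3, 0.15, 0.05) and p=0.7 the cumulative sums are
# (0.5, 0.8, 0.95, 1.0); the shift keeps the first token that crosses the
# threshold, so the top two survive.
#
#   apply_top_p(torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]])), p=0.7)
#   # -> tensor([[0.5000, 0.3000, -inf, -inf]])
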
def apply_tfs(logits, tfs):
    logits = torch.softmax(logits, dim=-1)
    sorted_probs, indices = torch.sort(logits, descending=True)
    # Second finite difference of the sorted probabilities: an estimate of
    # the curvature of the distribution's tail.
    d = sorted_probs
    d = d[:, 1:] - d[:, :-1]
    d = d[:, 1:] - d[:, :-1]
    d = d.abs()
    d = d / d.sum(dim=-1).view(1, -1).T
    cumulative_probs = torch.cumsum(d, dim=-1)
    mask_tensor = torch.empty(indices.shape, dtype=torch.bool, device=logits.device)
    mask_tensor[:, 1:-1] = cumulative_probs > tfs
    # Always remove last token
    mask_tensor[:, -1:] = True
    # Always keep the first token
    mask_tensor[:, 0] = False
    mask_tensor = mask_tensor.scatter(dim=-1, index=indices, src=mask_tensor)
    logits[mask_tensor] = -float("inf")
    return logits

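# Tail free sampling: the normalized second differences approximate where
# the sorted probability tail flattens out, and tokens past the point where
# their cumulative curvature mass exceeds `tfs` are cut. A hedged usage
# sketch (shapes illustrative, 50257 being the GPT-2 vocab size):
#
#   filtered = apply_tfs(torch.randn(1, 50257), tfs=0.95)
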
def apply_typical(logits, mass=0.9):
    scores = logits
    normalized = torch.nn.functional.log_softmax(scores, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)
    # shift and sort
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens with cumulative mass above the threshold
    last_ind = (cumulative_probs < mass).sum(dim=1)
    last_ind[last_ind < 0] = 0
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
    scores = scores.masked_fill(indices_to_remove, -float("inf"))
    return scores

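# This closely follows the locally typical sampling warper in Hugging Face
# transformers (Meister et al., 2022): keep the tokens whose surprisal
# |-log p(x) - H| is closest to the entropy H of the conditional
# distribution, until `mass` of the probability is covered. Hedged sketch:
#
#   filtered = apply_typical(torch.randn(1, 50257), mass=0.9)
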
def apply_temp(logits, temperature):
    logits = logits / temperature
    return logits

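# Temperature divides the logits, so T < 1 sharpens the softmax and T > 1
# flattens it. Illustrative values:
#
#   apply_temp(torch.tensor([[2.0, 1.0]]), 0.5)
#   # -> tensor([[4., 2.]]); probability mass shifts toward the argmax
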
def rep_pen(
    input_ids,
    scores,
    penalty,
    m=3.33,
    penalize_last=250,
    alpha_frequency=None,
    alpha_presence=None,
    whitelist=None,
):
    scores = torch.log_softmax(scores, dim=-1)
    penalty = 1.0 if penalty < 1.0 else penalty
    raw_penalty = penalty
    if m is not None and penalize_last is not None and penalize_last >= 1:
        # Ramp the penalty over the last `penalize_last` tokens with a
        # sigmoid-like slope controlled by m, rising from ~1 for the oldest
        # tokens up to raw_penalty for the most recent ones.
        penalty = (torch.arange(penalize_last) / (penalize_last - 1)) * 2.0 - 1
        penalty = (m * penalty) / (1 + torch.abs(penalty) * (m - 1))
        penalty = 1 + ((penalty + 1) / 2).unsqueeze(0) * (raw_penalty - 1)

    alpha_enable = alpha_frequency is not None or alpha_presence is not None

    # Whitelisted token ids are exempt from the penalty; their scores are
    # saved here and restored at the end.
    whitelist_list = whitelist
    whitelist = None
    if whitelist_list is not None:
        whitelist_list = list(filter(lambda x: 0 <= x < scores.shape[1], whitelist_list))
        if len(whitelist_list) > 0:
            whitelist = torch.tensor(whitelist_list).long().sort()[0]
            whitelist = whitelist.to(input_ids.device)
    if whitelist is not None:
        unpenalized = scores.gather(1, whitelist.view(1, -1))

    if raw_penalty > 1.0:
        if penalize_last is not None:
            penalty_len = min(input_ids.shape[1], penalize_last)
            input_ids = input_ids[:, -penalty_len:]
        score = torch.gather(scores, 1, input_ids)
        # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
        if penalize_last is not None:
            penalty = penalty.type(score.dtype).to(score.device)
            score = torch.where(
                score < 0,
                score * penalty[:, -penalty_len:],
                score / penalty[:, -penalty_len:],
            )
        else:
            score = torch.where(score < 0, score * penalty, score / penalty)
        scores.scatter_(1, input_ids, score)

    if alpha_enable:
        c = torch.zeros(scores.shape).long().to(input_ids.device)
        # unique only returns counts for first item in batch, so manually iterate
        for i in range(input_ids.shape[0]):
            if penalize_last is not None:
                token_input_ids, counts = torch.unique(
                    input_ids[i, -penalize_last:], sorted=True, return_counts=True, dim=-1
                )
            else:
                token_input_ids, counts = torch.unique(
                    input_ids[i], sorted=True, return_counts=True, dim=-1
                )
            c[i].scatter_(0, token_input_ids, counts)
        if alpha_frequency:
            scores -= c * alpha_frequency
        if alpha_presence:
            scores[c > 0] -= alpha_presence

    if whitelist is not None:
        scores.scatter_(1, whitelist.view(1, -1), unpenalized)

    return scores

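# Hedged usage sketch (shapes illustrative): penalize tokens already present
# in the context, optionally with presence/frequency penalties in the style
# of the OpenAI API parameters.
#
#   context = torch.randint(0, 50257, (1, 300))
#   logits = torch.randn(1, 50257)
#   logits = rep_pen(context, logits, penalty=3.0, alpha_presence=0.1)
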
def func_multinomial(x):
    # Fixed seed for reproducible sampling; wrapped with
    # functorch.vmap(..., randomness="different") in generate() below.
    torch.manual_seed(69)
    return torch.multinomial(x, 1)

@torch.no_grad()
def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
    in_tokens = prompt_tokens
    context = prompt_tokens
    generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
    kv = None
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv)
        logits = logits[:, -1, :]  # get the last token in the seq
        # keepdim preserves the batch dimension so the argmax ids can be
        # concatenated and fed back in as a (batch, 1) tensor
        logits = logits.argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, logits], dim=-1)
        in_tokens = logits
    return generated

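# Hedged usage sketch, reusing the model loading from main() below
# (`forward` is assumed to return (logits, kv) and accept cache=/kv=, as
# every call in this file does):
#
#   model = gptj.load_gpt_j().cuda().half().eval()
#   tokens = torch.LongTensor(tokenizer.encode("Hello")).unsqueeze(0).cuda()
#   out = generate_greedy(model.lm.forward, tokens, tokens_to_generate=20)
#   print(tokenizer.batch_decode(out.cpu().numpy()))
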
@torch.no_grad()
def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
    in_tokens = prompt_tokens
    context = prompt_tokens
    generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
    kv = None
    fully_deterministic = False
    #soft_required = ["top_k", "top_p"]
    op_map = {
        "top_k": apply_top_k,
        "top_p": apply_top_p,
        "typical": apply_typical,
        "temp": apply_temp,
        "tfs": apply_tfs,
        "rep_pen": rep_pen,
    }
    funcnomial = functorch.vmap(func_multinomial, randomness="different")  # currently unused
    for _ in range(tokens_to_generate):
        logits, kv = forward(in_tokens, cache=True, kv=kv)
        logits = logits[:, -1, :]  # get the last token in the seq
        logits = torch.log_softmax(logits, dim=-1)
        # can save one softmax here by not applying softmax for the first op,
        # need to take the softmax out of the necessary functions though
        batch = []
        for i, ops in enumerate(ops_list):
            item = logits[i, ...].unsqueeze(0)
            ctx = context[i, ...].unsqueeze(0)
            for op, value in ops.items():
                if op == "rep_pen":
                    item = op_map[op](ctx, item, **value)
                else:
                    item = op_map[op](item, value)
            batch.append(item)
        logits = torch.cat(batch, dim=0)
        logits = torch.softmax(logits, dim=-1)
        # fully_deterministic makes it deterministic across the batch
        if fully_deterministic:
            logits = logits.split(1, dim=0)
            logit_list = []
            for logit in logits:
                torch.manual_seed(69)
                logit_list.append(torch.multinomial(logit, 1))
            logits = torch.cat(logit_list, dim=0)
        else:
            torch.manual_seed(69)
            logits = torch.multinomial(logits, 1)
        generated = torch.cat([generated, logits], dim=-1)
        context = torch.cat([context, logits], dim=-1)
        in_tokens = logits
    return generated

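# Hedged usage sketch: ops_list carries one dict of sampling ops per batch
# row, applied in dict order, so each row can use different settings (this
# mirrors the call in main() below):
#
#   ops_list = [
#       {"rep_pen": {"penalty": 3}, "top_k": 50, "temp": 0.8},
#       {"tfs": 0.95, "temp": 0.7},
#   ]
#   out = generate(model.lm.forward, tokens, 50, ops_list=ops_list)
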
def generate_real_batched(forward, prompt_tokens, tokens_to_generate=50, ops={"temp": 0.9}):
    # Applies one shared `ops` dict to the whole batch in a single pass
    # (note: no rep_pen here, since that needs per-row context).
    with torch.no_grad():
        in_tokens = prompt_tokens
        kv = None
        fully_deterministic = False
        tokens_generated = []
        op_map = {
            "top_k": apply_top_k,
            "top_p": apply_top_p,
            "typical": apply_typical,
            "temp": apply_temp,
            "tfs": apply_tfs,
        }
        for _ in range(tokens_to_generate):
            logits, kv = forward(in_tokens, cache=True, kv=kv)
            logits = logits[:, -1, :]  # get the last token in the seq
            logits = torch.log_softmax(logits, dim=-1)
            for op, value in ops.items():
                logits = op_map[op](logits, value).float()
            logits = torch.softmax(logits, dim=-1).float()
            if fully_deterministic:
                logits = logits.split(1, dim=0)
                logit_list = []
                for logit in logits:
                    torch.manual_seed(69)
                    logit_list.append(torch.multinomial(logit, 1))
                logits = torch.cat(logit_list, dim=0)
            else:
                torch.manual_seed(69)
                logits = torch.multinomial(logits, 1)
            in_tokens = logits
            tokens_generated.append(logits)
        tokens_generated = torch.cat(tokens_generated, dim=-1)
        return tokens_generated

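# Compared with generate(): this variant trades per-row flexibility for a
# single fully batched pass, which is what the commented-out timeit lines in
# main() benchmark against each other. Hedged sketch:
#
#   out = generate_real_batched(model.lm.forward, tokens, 50,
#                               ops={"top_k": 50, "temp": 0.8})
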
def main():
    bsz = 4
    gen_len = 250
    torch.manual_seed(69)
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
    tokens = tokenizer.encode(prompt)
    print("Prompt:")
    for x in range(len(tokens)):
        print(tokenizer.decode([tokens[x]]), end=" | ")
    print("\nGeneration:")
    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
    tokens = [tokens] * bsz
    #tokens = torch.cat([tokens, tokens], dim=0)
    tokens = torch.cat(tokens, dim=0)
    t = time.perf_counter()
    model = gptj.load_gpt_j().cuda().half().eval()
    model = model.lm
    ic(time.perf_counter() - t)
    # renamed from `rep_pen` to avoid shadowing the rep_pen function above
    rep_pen_ops = {
        "penalty": 3,
    }
    ops = {
        "rep_pen": rep_pen_ops,
        "top_k": 50,
        "temp": 0.8,
    }
    ops_list = [ops] * bsz
    tokens_generated = generate(model.forward, tokens, gen_len, ops_list=ops_list)
    #tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
    print(tokens_generated.shape)
    ic(prompt)
    tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
    for gen in tokens_generated:
        print(str(gen))
        print("===========================================================")
    #ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
    #timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
    #timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)

if __name__ == "__main__":
    main()