Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
B
Basedformer
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Basedformer
Commits
5ca754a6
Commit
5ca754a6
authored
Jul 03, 2022
by
novelailab
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
init retro
parent
62baf4ad
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
225 additions
and
0 deletions
+225
-0
basedformer/models/retro.py
basedformer/models/retro.py
+225
-0
No files found.
basedformer/models/retro.py
0 → 100644
View file @
5ca754a6
from
typing
import
Callable
,
KeysView
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
basedformer.utils
import
*
from
torch.utils.checkpoint
import
checkpoint
as
ck
from
einops
import
rearrange
,
repeat
try
:
from
collections.abc
import
MutableMapping
except
ImportError
:
from
collections
import
MutableMapping
import
os
from
pathlib
import
Path
import
math
from
basedformer.models
import
base_lm
def
fixed_pos_embedding
(
dim
=
None
,
seq_len
=
None
,
x
=
None
):
if
x
is
None
:
x
=
torch
.
empty
(
0
)
inv_freq
=
1.
/
(
10000
**
(
torch
.
arange
(
0
,
dim
,
2
)
/
dim
))
.
to
(
x
.
dtype
)
.
to
(
x
.
device
)
sinusoid_inp
=
torch
.
einsum
(
'i , j -> i j'
,
torch
.
arange
(
seq_len
)
.
to
(
x
.
device
),
inv_freq
)
.
float
()
return
torch
.
sin
(
sinusoid_inp
),
torch
.
cos
(
sinusoid_inp
)
def
rotate_every_two
(
x
):
x1
=
x
[:,
:,
:,
::
2
]
x2
=
x
[:,
:,
:,
1
::
2
]
x
=
torch
.
stack
((
-
x2
,
x1
),
dim
=-
1
)
return
rearrange
(
x
,
'... d j -> ... (d j)'
)
def
apply_rotary_pos_emb
(
x
,
sincos
,
offset
=
0
):
sin
,
cos
=
map
(
lambda
t
:
repeat
(
t
[
offset
:
x
.
shape
[
1
]
+
offset
,:],
"n d -> () n () (d j)"
,
j
=
2
),
sincos
)
return
(
x
*
cos
)
+
(
rotate_every_two
(
x
)
*
sin
)
def
_attn
(
query
,
key
,
value
,
causal_mask
,
masked_bias
,
attention_mask
=
None
,
scale_attn
=
None
):
attn_weights
=
torch
.
matmul
(
query
,
key
.
transpose
(
-
1
,
-
2
))
attn_weights
=
torch
.
where
(
causal_mask
,
attn_weights
,
masked_bias
.
to
(
attn_weights
.
dtype
))
attn_weights
=
attn_weights
/
scale_attn
if
attention_mask
is
not
None
:
attn_weights
=
attn_weights
+
attention_mask
attn_weights
=
F
.
softmax
(
attn_weights
,
dim
=-
1
)
attn_weights
=
attn_weights
.
to
(
value
.
dtype
)
attn_output
=
torch
.
matmul
(
attn_weights
,
value
)
.
to
(
value
.
dtype
)
return
attn_output
class
Attention
(
nn
.
Module
):
# Code copied from HF, might want to sanity check later.
def
__init__
(
self
,
config
,
causal
=
True
,
null_kv
=
False
):
nn
.
Module
.
__init__
(
self
)
max_positions
=
2049
self
.
head_dim
=
config
.
hidden_dim
//
config
.
n_head
self
.
rotary_dim
=
self
.
head_dim
//
4
self
.
hidden_dim
=
config
.
hidden_dim
self
.
n_head
=
config
.
n_head
self
.
q_only
=
config
.
q_only
self
.
causal
=
causal
self
.
register_buffer
(
"scale_attn"
,
torch
.
sqrt
(
torch
.
tensor
(
self
.
head_dim
,
requires_grad
=
False
)
.
float
()))
if
self
.
causal
:
bias
=
torch
.
tril
(
torch
.
ones
((
max_positions
,
max_positions
),
dtype
=
torch
.
uint8
,
requires_grad
=
False
))
.
view
(
1
,
1
,
max_positions
,
max_positions
)
.
bool
()
self
.
register_buffer
(
"bias"
,
bias
)
self
.
register_buffer
(
"masked_bias"
,
torch
.
tensor
(
-
1e9
,
requires_grad
=
False
))
#-1e10 is what mtj uses.
attn_bias
=
False
if
config
.
q_only
:
self
.
k_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
head_dim
,
bias
=
attn_bias
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
v_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
head_dim
,
bias
=
attn_bias
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
else
:
self
.
k_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
v_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
q_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
out_proj
=
nn
.
Linear
(
self
.
hidden_dim
,
self
.
hidden_dim
,
bias
=
attn_bias
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
sin
,
cos
=
fixed_pos_embedding
(
dim
=
self
.
rotary_dim
,
seq_len
=
max_positions
)
self
.
register_buffer
(
"sin"
,
sin
)
self
.
register_buffer
(
"cos"
,
cos
)
# allowing for attending to nothing (null function)
# and to save attention from breaking if all retrieved chunks are padded out
self
.
null_k
=
nn
.
Parameter
(
torch
.
randn
(
self
.
hidden_dim
))
if
null_kv
else
None
self
.
null_v
=
nn
.
Parameter
(
torch
.
randn
(
self
.
hidden_dim
))
if
null_kv
else
None
def
forward
(
self
,
x
,
kv
=
None
,
cache
=
False
):
B
,
S
,
H
=
x
.
shape
# batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
# transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
query
=
self
.
q_proj
(
x
)
.
view
(
B
,
S
,
self
.
n_head
,
self
.
head_dim
)
if
self
.
q_only
:
key
=
self
.
k_proj
(
x
)
.
view
(
B
,
S
,
1
,
self
.
head_dim
)
value
=
self
.
v_proj
(
x
)
.
view
(
B
,
S
,
1
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
else
:
key
=
self
.
k_proj
(
x
)
.
view
(
B
,
S
,
self
.
n_head
,
self
.
head_dim
)
value
=
self
.
v_proj
(
x
)
.
view
(
B
,
S
,
self
.
n_head
,
self
.
head_dim
)
.
transpose
(
1
,
2
)
if
kv
:
offset
=
kv
[
0
]
.
shape
[
-
2
]
else
:
offset
=
0
if
self
.
rotary_dim
<
self
.
head_dim
:
k_rot
=
key
[:,
:,
:,
:
self
.
rotary_dim
]
k_pass
=
key
[:,
:,
:,
self
.
rotary_dim
:]
q_rot
=
query
[:,
:,
:,
:
self
.
rotary_dim
]
q_pass
=
query
[:,
:,
:,
self
.
rotary_dim
:]
k_rot
=
apply_rotary_pos_emb
(
k_rot
,
(
self
.
sin
,
self
.
cos
),
offset
=
offset
)
.
to
(
k_rot
.
dtype
)
q_rot
=
apply_rotary_pos_emb
(
q_rot
,
(
self
.
sin
,
self
.
cos
),
offset
=
offset
)
.
to
(
q_rot
.
dtype
)
key
=
torch
.
cat
([
k_rot
,
k_pass
],
dim
=-
1
)
query
=
torch
.
cat
([
q_rot
,
q_pass
],
dim
=-
1
)
else
:
key
=
apply_rotary_pos_emb
(
key
,
(
self
.
sin
,
self
.
cos
),
offset
=
offset
)
.
to
(
key
.
dtype
)
query
=
apply_rotary_pos_emb
(
query
,
(
self
.
sin
,
self
.
cos
),
offset
=
offset
)
.
to
(
query
.
dtype
)
query
=
query
.
transpose
(
1
,
2
)
key
=
key
.
transpose
(
1
,
2
)
if
kv
:
k
,
v
=
kv
# cat key and value (get the whole sequence, other than the last added token all are cached),
# so query can attend to it.
key
=
torch
.
cat
([
k
,
key
],
dim
=-
2
)
# cat key
value
=
torch
.
cat
([
v
,
value
],
dim
=-
2
)
# cat value
query_length
,
key_length
=
query
.
size
(
-
2
),
key
.
size
(
-
2
)
#causal mask with generation in mind
causal_mask
=
self
.
bias
[:,
:,
key_length
-
query_length
:
key_length
,
:
key_length
]
x
=
_attn
(
query
,
key
,
value
,
causal_mask
,
self
.
masked_bias
,
None
,
self
.
scale_attn
)
x
=
x
.
transpose
(
1
,
2
)
.
contiguous
()
.
view
(
B
,
S
,
H
)
x
=
self
.
out_proj
(
x
)
if
cache
:
return
x
,
[
key
,
value
]
else
:
return
x
,
None
class
FeedForward
(
nn
.
Module
):
def
__init__
(
self
,
config
):
nn
.
Module
.
__init__
(
self
)
self
.
ff1
=
nn
.
Linear
(
config
.
hidden_dim
,
config
.
hidden_dim
*
4
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
ff2
=
nn
.
Linear
(
config
.
hidden_dim
*
4
,
config
.
hidden_dim
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
self
.
activation
=
config
.
activation
def
forward
(
self
,
x
,
act_ck
=
False
):
x
=
self
.
ff1
(
x
)
if
act_ck
:
x
=
ck
(
self
.
activation
,
x
)
else
:
x
=
self
.
activation
(
x
)
x
=
self
.
ff2
(
x
)
return
x
class
GPTJLayer
(
nn
.
Module
):
def
__init__
(
self
,
attn
,
ff
,
config
):
nn
.
Module
.
__init__
(
self
)
self
.
ln_preattn
=
nn
.
LayerNorm
(
config
.
hidden_dim
,
eps
=
config
.
eps
,
device
=
config
.
device
,
dtype
=
config
.
dtype
)
#self.ln_preattn = nn.LogSoftmax(dim=-2)
self
.
ff
=
ff
(
config
)
self
.
attn
=
attn
(
config
)
self
.
tick
=
True
def
forward
(
self
,
x
,
layer_id
=
None
,
hypernetwork
=
None
,
act_ck
=
False
,
diff_hypernets
=
False
,
interleaving_layers
=
False
,
every_n
=
5
,
cache
=
False
,
kv
=
None
):
residual
=
x
if
act_ck
:
x
=
ck
(
self
.
ln_preattn
,
x
)
attn_out
,
kv
=
ck
(
self
.
attn
,
x
,
kv
,
cache
)
#attn_out, kv = self.attn(x, kv=kv, cache=cache)
else
:
x
=
self
.
ln_preattn
(
x
)
attn_out
,
kv
=
self
.
attn
(
x
,
kv
=
kv
,
cache
=
cache
)
if
hypernetwork
:
if
diff_hypernets
:
if
interleaving_layers
and
layer_id
%
every_n
==
0
:
if
self
.
tick
:
hyper_out
=
hypernetwork
[
0
](
x
)
self
.
tick
=
False
else
:
hyper_out
=
hypernetwork
[
1
](
x
)
self
.
tick
=
True
elif
layer_id
%
every_n
==
0
:
hyper_out
=
hypernetwork
[(
layer_id
//
every_n
)
-
1
](
x
)
else
:
if
layer_id
%
every_n
==
0
:
hyper_out
=
hypernetwork
(
x
)
ff_out
=
self
.
ff
(
x
,
act_ck
)
#order of addition matters, i had no idea... fixed a bug here.
x
=
attn_out
+
ff_out
+
residual
#x = residual + attn_out + ff_out -> doesn't match.
if
hypernetwork
and
layer_id
%
every_n
==
0
:
x
=
x
+
hyper_out
return
x
,
kv
class
GPTJModel
(
base_lm
.
BaseModel
):
def
__init__
(
self
,
user_config
,
**
kwargs
):
self
.
default_config
=
{
'n_layer'
:
6
,
'n_head'
:
8
,
'n_tokens'
:
2048
,
'hidden_dim'
:
512
,
'vocab_dim'
:
50400
,
'eps'
:
1e-5
,
'device'
:
torch
.
device
(
'cuda'
),
'dtype'
:
torch
.
float16
,
'Layer'
:
GPTJLayer
,
'activation'
:
gelu_new
,
'SelfAttention'
:
SelfAttention
,
'FeedForward'
:
FeedForward
,
}
base_lm
.
BaseModel
.
__init__
(
self
,
user_config
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment