novelai-storage / Basedformer

Commit 24459438, authored May 13, 2022 by novelailab
Parent: 92b7b187

fairseq works, start neo impl

Showing 1 changed file with 120 additions and 90 deletions.

basedformer/models/gptneo.py  +120 -90
```diff
@@ -11,39 +11,20 @@ except ImportError:
 import os
 from pathlib import Path
 import math
 from basedformer import lm_base
+from basedformer.models import base_lm
 from typing import Optional, Any

-def shift_tokens(x, amt, eps=1e-5):
-    n, device = x.shape[1], x.device
-    cumsum = x.cumsum(dim=1)
-    *x, x_pass = x.chunk(amt + 1, dim=-1)
-    *x_cumsum, _ = cumsum.chunk(amt + 1, dim=-1)
-    amts = 2 ** torch.arange(amt)
-    amts = amts.tolist()
-    shifts = []
-    denom = torch.arange(n, device=device)
-    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
-        shifted_chunk = shift(x_cumsum_chunk, amt, dim=-2) - shift(x_cumsum_chunk, 2 * amt, dim=-2)
-        shifted_denom = shift(denom, amt, dim=-1) - shift(denom, 2 * amt, dim=-1)
-        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
-        normed_shifted_x = shifted_chunk / (shifted_denom + eps)
-        shifts.append(normed_shifted_x)
-    return torch.cat((*shifts, x_pass), dim=-1)
-
-def shift(x, amt, dim=-1):
-    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value=0.)

 def _attn(query, key, value, causal_mask, masked_bias,
-          attention_mask=None, scale_attn=None):
+          attention_mask=None, scale_attn=None, fp32_attn=True):
-    attn_weights = torch.matmul(query, key.transpose(-1, -2))
+    if fp32_attn:
+        attn_weights = torch.matmul(query.float(), key.transpose(-1, -2).float())
+    else:
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
     attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
-    attn_weights = attn_weights / scale_attn
+    attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)

     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
```
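The `fp32_attn` switch added in this hunk computes the raw attention scores in float32 even when the projections run in fp16, which is what fairseq-trained checkpoints expect. Below is a minimal, self-contained sketch of that pattern, not the repo's class: the softmax and the final matmul with `value` fall outside this hunk and are added only to make the example runnable, and the tensor shapes are made up.

```python
import torch

def scaled_causal_attn(query, key, value, fp32_attn=True):
    # Sketch of the masking/scaling pattern used by _attn above.
    # Scores optionally go through fp32 even if q/k/v are half precision.
    if fp32_attn:
        attn_weights = torch.matmul(query.float(), key.transpose(-1, -2).float())
    else:
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

    q_len, k_len = query.shape[-2], key.shape[-2]
    causal_mask = torch.ones(q_len, k_len).tril().bool()
    masked_bias = torch.tensor(-1e9)                        # same fill value as the diff
    scale_attn = torch.sqrt(torch.tensor(query.shape[-1]).float())

    attn_weights = torch.where(causal_mask, attn_weights, masked_bias.to(attn_weights.dtype))
    attn_weights = attn_weights / scale_attn.to(attn_weights.dtype)
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return torch.matmul(attn_weights.to(value.dtype), value)

# Made-up shapes: (batch, heads, seq, head_dim). On GPU the model would feed fp16 here;
# the scores still pass through fp32 when fp32_attn=True.
q = k = v = torch.randn(1, 4, 8, 16)
out = scaled_causal_attn(q, k, v)
print(out.shape)   # torch.Size([1, 4, 8, 16])
```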
```diff
@@ -57,19 +38,34 @@ def _attn(query, key, value, causal_mask, masked_bias,
 class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
-    def __init__(self, hidden_dim, n_head, device, dtype):
+    def __init__(self, config, attention_type):
         nn.Module.__init__(self)
+        self.config = config
         max_positions = 2049
         bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(1, 1, max_positions, max_positions).bool()
-        self.head_dim = hidden_dim // n_head
+        if attention_type == "local":
+            self.register_buffer(
+                "bias",
+                bias ^ torch.tril(bias, -config.window_size),
+            )
+        else:
+            self.register_buffer(
+                "bias",
+                bias,
+            )
+        self.head_dim = config.hidden_dim // config.n_head
         self.rotary_dim = self.head_dim // 4
-        self.hidden_dim = hidden_dim
-        self.n_head = n_head
+        self.hidden_dim = config.hidden_dim
+        self.n_head = config.n_head
+        device = config.device
+        dtype = config.dtype
         self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
-        self.register_buffer("bias", bias)
         self.register_buffer("masked_bias", torch.tensor(-1e9, requires_grad=False))  #-1e10 is what mtj uses.
-        attn_bias = False
+        attn_bias = True  #fairseq has attn_bias
         self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
         self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
```
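For the "local" attention type, the registered `bias` buffer is the full causal mask XOR-ed with the same mask shifted down by `window_size`, which leaves a causal band of `window_size` positions per row. A toy-sized sketch of that construction follows; `max_positions=8` and `window_size=3` are illustrative values, not the model's defaults (2049 and 256 in this commit).

```python
import torch

max_positions, window_size = 8, 3

ones = torch.ones((max_positions, max_positions), dtype=torch.uint8)
full = torch.tril(ones).bool()                          # ordinary causal mask
local = full ^ torch.tril(ones, -window_size).bool()    # keep only the last window_size positions per row

print(local.int())
# tensor([[1, 0, 0, 0, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0, 0, 0, 0],
#         [1, 1, 1, 0, 0, 0, 0, 0],
#         [0, 1, 1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 1, 1, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1, 0, 0],
#         [0, 0, 0, 0, 1, 1, 1, 0],
#         [0, 0, 0, 0, 0, 1, 1, 1]])
```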
```diff
@@ -93,7 +89,7 @@ class SelfAttention(nn.Module):
         causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length]

         x = _attn(
-            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn
+            query, key, value, causal_mask, self.masked_bias, None, self.scale_attn, self.config.fp32_attn
         )

         x = x.transpose(1, 2).contiguous().view(B, S, H)
```
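The `causal_mask` slice in the context lines above keeps only the mask rows that correspond to the current query tokens, which matters once cached keys make `key_length` larger than `query_length`. A small sketch with made-up lengths (5 cached positions, 2 new tokens):

```python
import torch

max_positions = 8
bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)) \
    .view(1, 1, max_positions, max_positions).bool()

query_length, key_length = 2, 7        # key_length = 5 cached + 2 new tokens
causal_mask = bias[:, :, key_length - query_length:key_length, :key_length]

print(causal_mask.shape)               # torch.Size([1, 1, 2, 7])
print(causal_mask[0, 0].int())
# tensor([[1, 1, 1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1, 1, 1]])
```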
```diff
@@ -101,14 +97,14 @@ class SelfAttention(nn.Module):
         if cache:
             return x, (key, value)
         else:
-            return x
+            return x, None

 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim, activation, device, dtype):
+    def __init__(self, config):
         nn.Module.__init__(self)
-        self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
-        self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
-        self.activation = activation
+        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim * 4, device=config.device, dtype=config.dtype)
+        self.ff2 = nn.Linear(config.hidden_dim * 4, config.hidden_dim, device=config.device, dtype=config.dtype)
+        self.activation = config.activation

     def forward(self, x, act_ck=False):
         x = self.ff1(x)
```
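The refactor replaces loose constructor arguments with a single config object (produced elsewhere by `configure_model()`, which this diff does not show). Below is a rough usage sketch with a stand-in `SimpleNamespace` config; the names and values are invented for illustration, and the `act_ck` checkpointing argument of the real `forward` is omitted.

```python
import types
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in config for illustration only; the real object comes from base_lm / configure_model().
config = types.SimpleNamespace(
    hidden_dim=512,
    device=torch.device("cpu"),
    dtype=torch.float32,
    activation=F.gelu,
)

class FeedForward(nn.Module):
    # Same shape as the refactored class above: a 4x-expansion MLP driven entirely by config.
    def __init__(self, config):
        super().__init__()
        self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim * 4, device=config.device, dtype=config.dtype)
        self.ff2 = nn.Linear(config.hidden_dim * 4, config.hidden_dim, device=config.device, dtype=config.dtype)
        self.activation = config.activation

    def forward(self, x):
        return self.ff2(self.activation(self.ff1(x)))

ff = FeedForward(config)
y = ff(torch.randn(1, 16, config.hidden_dim))   # (batch, seq, hidden_dim)
print(y.shape)                                   # torch.Size([1, 16, 512])
```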
```diff
@@ -120,70 +116,104 @@ class FeedForward(nn.Module):
         return x

 class GPTNeoLayer(nn.Module):
-    def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
+    def __init__(self, attn, ff, config, layer_idx):
         nn.Module.__init__(self)
-        self.hidden_dim = hidden_dim
-        self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
-        self.ln_postattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
-        self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
-        self.attn = attn(hidden_dim=hidden_dim, n_head=n_head, device=device, dtype=dtype)
-        self.tick = True
-
-    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False):
+        self.hidden_dim = config.hidden_dim
+        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
+        self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
+        self.ff = ff(config)
+        if layer_idx % 2 == 0:
+            attn_type = "global"
+        else:
+            attn_type = "local"
+        self.attn = attn(config, attn_type)
+
+    def forward(self, x, layer_id=None, hypernetwork=None, act_ck=False, cache=False, kv=None):
         residual = x

         if act_ck:
             x = ck(self.ln_preattn, x)
-            attn_out = ck(self.attn, x)
+            attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
         else:
             x = self.ln_preattn(x)
-            attn_out = self.attn(x)
+            attn_out, kv = self.attn(x, kv=kv, cache=cache)

-        residual = residual + attn_out
+        x = residual + attn_out
+        residual = x
         x = self.ln_postattn(x)
         ff_out = self.ff(x, act_ck)
         x = residual + ff_out
-        return x
+        return x, kv

-class GPTNeoModel(nn.Module):
-    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTNeoLayer, device="cuda", dtype=torch.float16, **kwargs):
-        nn.Module.__init__(self)
-        self.n_layer = n_layer
-        self.hidden_dim = hidden_dim
-        self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
-        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=eps, device=device, dtype=dtype)
-        self.layers = nn.ModuleList([])
-        self.lm_head = nn.Linear(hidden_dim, vocab_dim, bias=True)
-        for _ in range(n_layer):
-            self.layers.append(Layer(attn=SelfAttention, ff=FeedForward, hidden_dim=hidden_dim, n_head=n_head, eps=eps, activation=activation, device=device, dtype=dtype))
-
-    def forward(self, x, hypernetwork=None, act_ck=False):
-        x = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck)
-        x = self.lm_head(x)
-        return x.float()
-
-    def get_embeds(self, x, hypernetwork=None, act_ck=False):
+class GPTNeoModel(base_lm.BaseModel):
+    def __init__(self, user_config, **kwargs):
+        self.default_config = {
+            'n_layer': 6,
+            'n_head': 8,
+            'n_tokens': 2049,
+            'hidden_dim': 512,
+            'vocab_dim': 50400,
+            'fp32_attn': True,  #fairseq models are trained with fp32 attn
+            'eps': 1e-5,
+            'device': torch.device('cuda'),
+            'dtype': torch.float16,
+            'Layer': GPTNeoLayer,
+            'activation': gelu_new,
+            'SelfAttention': SelfAttention,
+            'FeedForward': FeedForward,
+            'window_size': 256,
+        }
+
+    def __init__(self, user_config, **kwargs):
+        nn.Module.__init__(self)
+        #configuration
+        self.user_config = user_config
+        self.config = self.configure_model()
+        config = self.config
+        #modeling
+        self.n_layer = config.n_layer
+        self.hidden_dim = config.hidden_dim
+        self.vocab_embed = nn.Embedding(config.vocab_dim, self.hidden_dim, device=config.device, dtype=config.dtype)
+        self.ln_final = nn.LayerNorm(self.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
+        self.layers = nn.ModuleList([])
+        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_dim, bias=True)
+        for _ in range(config.n_layer):
+            self.layers.append(
+                config.Layer(
+                    attn=config.SelfAttention,
+                    ff=config.FeedForward,
+                    config=config,
+                )
+            )
+        # returns sinusoidal embeddings of shape: (1, n_tokens, 768)
+        self.register_buffer("embed_scale", torch.sqrt(torch.tensor(self.config.hidden_dim, requires_grad=False)))
+        self.pos_embed = nn.Embedding(self.config.n_tokens, self.config.hidden_dim)
+        self.lm_head = nn.Linear(self.config.hidden_dim, self.config.vocab_dim, bias=False)  #bias=False for fairseq models
+
+    def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
+        if kv is None:
+            kv = [None] * self.n_layer
+            past_length = 0
+        else:
+            past_length = kv[0][0].size(-2)  #get sequence dim of key
+
+        kv_new = []
+        position_ids = torch.arange(past_length, x[-1] + past_length, dtype=torch.long, device=x.device)
+        position_ids = position_ids.unsqueeze(0).view(-1, x[-1])
         x = self.vocab_embed(x)
         x = x + self.pos_embed(position_ids)
         for layer_id, layer in enumerate(self.layers):
-            x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
-        x = self.ln_final(x)
-        return x
+            x, kvi = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv[layer_id], cache=cache)
+            kv_new.append(kvi)

-class GPTNeoBaseLM(lm_base.BaseLM):
-    def __init__(self, config=None, lm=None):
-        nn.Module.__init__(self)
-        lm_base.BaseLM.__init__(self, config, lm)
-        self.model_class = GPTNeoModel
-
-def load_gpt_j(path="models/6b", state_dict=None):
-    config = {
-        "n_layer": 28,
-        "n_head": 16,
-        "hidden_dim": 4096,
-        "vocab_dim": 50400,
-        "eps": 1e-5
-    }
-    model = GPTNeoBaseLM.load(config, path, state_dict)
-    return model
+        x = self.ln_final(x)
+        if cache:
+            return x, kv_new
+        else:
+            return x, None
\ No newline at end of file
```
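The new `get_embeds` threads an optional per-layer KV cache: when a cache is passed in, `past_length` is read off the cached key's sequence dimension and the position ids for the new tokens start after it. Below is a sketch of just that bookkeeping with invented shapes; the diff reads the new-token count as `x[-1]`, while the sketch uses an explicit `seq_len`.

```python
import torch

# Invented shapes: a fake per-layer cache holding 5 past key/value positions for 2 layers.
n_layer, batch, n_head, head_dim = 2, 1, 8, 64
kv = [
    (torch.zeros(batch, n_head, 5, head_dim),    # cached keys
     torch.zeros(batch, n_head, 5, head_dim))    # cached values
    for _ in range(n_layer)
]

seq_len = 3                                                 # new tokens fed this step
past_length = kv[0][0].size(-2) if kv is not None else 0    # 5 here; 0 with no cache

# New tokens are positioned after the cached ones, mirroring get_embeds above.
position_ids = torch.arange(past_length, seq_len + past_length, dtype=torch.long)
position_ids = position_ids.unsqueeze(0).view(-1, seq_len)
print(position_ids)                                         # tensor([[5, 6, 7]])
```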