novelai-storage / Basedformer

Commit 3c7bd057
Authored Apr 14, 2022 by novelailab

    clean gptj attention

Parent: 24e93cbf
Showing 1 changed file with 23 additions and 7 deletions (+23 -7)

basedformer/gptj.py
@@ -70,13 +70,20 @@ class SelfAttention(nn.Module):
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)

-    def forward(self, x):
+    def forward(self, x, kv=None, cache=False):
         B, S, H = x.shape # batch, sequence, hidden_dim
         # split heads into: [batch, head, sequence, head_dim]
-        # other than v because some rotary bs?
-        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
-        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
-        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim)
+        # transpose q, k after rotary as rotary code accepts [b, s, h, h_d]
+        query = self.q_proj(x).view(B, S, self.n_head, self.head_dim)
+        key = self.k_proj(x).view(B, S, self.n_head, self.head_dim)
+        value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
+
+        if kv:
+            k, v = kv
+            # cat key and value (get the whole sequence, other than the last added token all are cached),
+            # so query can attend to it.
+            torch.cat([k, key], dim=-2) # cat key
+            torch.cat([v, value], dim=-2) # cat value

         offset = 0
         if self.rotary_dim < self.head_dim:
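A note on the new kv branch above: torch.cat is not in-place, so as written both concatenation results are discarded and the cached sequence never reaches the attention computation. Also, at this point key is still laid out as [batch, seq, head, head_dim], where dim=-2 is the head axis rather than the sequence axis; the already-transposed value does have the sequence axis at dim=-2. A drop-in sketch of what the branch presumably intends (hypothetical fix, not part of this commit):

    if kv:
        k, v = kv
        # torch.cat returns a new tensor; assign it back so the cached
        # positions are actually attended to. dim=1 is the sequence axis
        # in the [batch, seq, head, head_dim] layout key still has here,
        # while dim=-2 is the sequence axis for the transposed value.
        key = torch.cat([k, key], dim=1)
        value = torch.cat([v, value], dim=-2)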
@@ -95,7 +102,13 @@ class SelfAttention(nn.Module):
         else:
             key = apply_rotary_pos_emb(key, (self.sin, self.cos), offset=offset).to(key.dtype)
             query = apply_rotary_pos_emb(query, (self.sin, self.cos), offset=offset).to(query.dtype)

+        if cache:
+            # doing this to avoid transposing key again after loading it as transposed.
+            cache = (key,)
+
+        query = query.transpose(1, 2)
+        key = key.transpose(1, 2)
+
         query_length, key_length = query.size(-2), key.size(-2)
         #causal mask with generation in mind
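apply_rotary_pos_emb itself is outside this diff. For reference, a minimal GPT-J-style sketch consistent with the call above and with the comment that the rotary code accepts [b, s, h, h_d]; the sin/cos buffer shapes here are assumptions, not taken from this repo:

    import torch

    def rotate_every_two(x):
        # interleaved GPT-J rotation: pair up even/odd channels of the last axis
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    def apply_rotary_pos_emb_sketch(x, sincos, offset=0):
        # x: [batch, seq, head, head_dim]; sin/cos assumed [max_seq, head_dim // 2].
        # offset skips the positions already covered by the cache, so new tokens
        # are rotated as if they sat at the end of the full sequence.
        sin, cos = (t[None, offset : offset + x.shape[1], None, :].repeat_interleave(2, dim=3)
                    for t in sincos)
        return (x * cos) + (rotate_every_two(x) * sin)

The offset parameter is what makes this layout choice matter for caching: rotary is applied per absolute position, so cached keys keep their rotation and only the newly appended positions need to be rotated.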
@@ -108,7 +121,10 @@ class SelfAttention(nn.Module):
         x = x.transpose(1, 2).contiguous().view(B, S, H)
         x = self.out_proj(x)
-        return x
+        if cache:
+            return x, (cache[0], value)
+        else:
+            return x

 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, activation, device, dtype):
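Taken together, the new signature enables incremental decoding: pass cache=True to get the (key, value) pair back alongside the output, then feed it in as kv on the next step. A hypothetical driver, assuming a constructed SelfAttention module and that the concatenation issue noted above is fixed:

    import torch
    from basedformer.gptj import SelfAttention

    attn = SelfAttention(...)          # constructor args are not shown in this diff
    x_prompt = torch.randn(1, 8, 16)   # hypothetical shapes: [batch, seq, hidden]
    x_next = torch.randn(1, 1, 16)     # one newly generated position

    out, kv = attn(x_prompt, cache=True)       # prime the cache on the full prompt
    out, kv = attn(x_next, kv=kv, cache=True)  # later steps reuse cached keys/values

Note that the returned key, cache[0], is captured after rotary but before the transpose, matching the in-code comment about avoiding a second transpose when the cached tensor is loaded again.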