novelai-storage / Basedformer

Commit c99ffa47, authored Mar 26, 2022 by novelailab
fix checkpointing
Parent: 7e65fc56
Showing 2 changed files with 46 additions and 4 deletions:

  hypertrain.py  +45 -3
  main.py        +1 -1
hypertrain.py (view file @ c99ffa47)

@@ -16,6 +16,8 @@ import wandb
 from lm_arch.gpt2 import GPT2Model
 import numpy as np
 from transformers import AutoTokenizer
+from torch.utils.checkpoint import checkpoint as ck
+from math import log2, ceil

 def _init_weights(module):
     """Initialize the weights."""
...
@@ -29,6 +31,43 @@ def _init_weights(module):
         module.bias.data.zero_()
         module.weight.data.fill_(1.0)
+def shift_tokens(x, amt, eps = 1e-5):
+    n, device = x.shape[1], x.device
+    cumsum = x.cumsum(dim = 1)
+    *x, x_pass = x.chunk(amt + 1, dim = -1)
+    *x_cumsum, _ = cumsum.chunk(amt + 1, dim = -1)
+    amts = 2 ** torch.arange(amt)
+    amts = amts.tolist()
+    shifts = []
+    denom = torch.arange(n, device = device)
+    for x_chunk, x_cumsum_chunk, amt in zip(x, x_cumsum, amts):
+        shifted_chunk = shift(x_cumsum_chunk, amt, dim = -2) - shift(x_cumsum_chunk, 2 * amt, dim = -2)
+        shifted_denom = shift(denom, amt, dim = -1) - shift(denom, 2 * amt, dim = -1)
+        shifted_denom = rearrange(shifted_denom, 'n -> () n ()')
+        normed_shifted_x = shifted_chunk / (shifted_denom + eps)
+        shifts.append(normed_shifted_x)
+    return torch.cat((*shifts, x_pass), dim = -1)
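For context, shift_tokens appears to follow the token-shift feed-forward from lucidrains' token-shift-gpt: each of the first amt feature chunks is replaced by an average over a window of past tokens whose size and offset double per chunk, while the last chunk passes through unchanged. A hedged usage sketch (sizes are illustrative; it assumes torch, torch.nn.functional as F, and einops.rearrange are in scope, as in this file):

import torch
x = torch.randn(2, 2048, 768)   # (batch, seq, embed) - illustrative sizes
out = shift_tokens(x, 10)       # 10 shifted feature chunks + 1 pass-through chunk
assert out.shape == x.shape     # only feature content is mixed; the shape is preserved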
+def discounted_cumsum(t, gamma):
+    try:
+        from torch_discounted_cumsum import discounted_cumsum_left
+    except ImportError:
+        print('unable to import torch_discounted_cumsum - please run `pip install torch-discounted-cumsum`')
+    b, n, d = t.shape
+    t = rearrange(t, 'b n d -> (b d) n')
+    t = discounted_cumsum_left(t, gamma)
+    t = rearrange(t, '(b d) n -> b n d', b = b)
+    return t
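As a sanity reference, and assuming discounted_cumsum_left implements the left-to-right scan y[t] = x[t] + gamma * y[t-1] that the torch-discounted-cumsum package documents, a plain-PyTorch equivalent of the kernel would be:

import torch

def discounted_cumsum_left_ref(t, gamma):
    # t: (rows, n); sequential O(n) reference - the CUDA kernel parallelizes this scan
    out = torch.empty_like(t)
    acc = torch.zeros_like(t[:, 0])
    for i in range(t.shape[1]):
        acc = t[:, i] + gamma * acc
        out[:, i] = acc
    return out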
+def shift(x, amt, dim = -1):
+    return F.pad(x, (*((0, 0) * (-dim - 1)), amt, -amt), value = 0.)
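shift() is a zero-padded slide along one dimension: it pads amt zeros on the leading side of dim and trims the same amount from the trailing side. A quick check on a 1-D tensor:

import torch
import torch.nn.functional as F

t = torch.tensor([1., 2., 3., 4., 5.])
shift(t, 2)     # tensor([0., 0., 1., 2., 3.]) - values move right, zeros fill in
shift(t, -1)    # tensor([2., 3., 4., 5., 0.]) - a negative amt shifts the other way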
 class HyperNetwork(nn.Module):
     def __init__(self, config):
         super().__init__()
...
@@ -36,6 +75,7 @@ class HyperNetwork(nn.Module):
         self.linear = nn.Linear(embed_dim, embed_dim // 4, bias=True)
         self.linear2 = nn.Linear(embed_dim // 4, embed_dim, bias=True)
         self.activation = gelu_new
+        self.num_shifts = ceil(log2(2048)) - 1
         #self.linear.weight.data.normal_(mean=0.0, std=0.02)
         for module in self.modules():
             _init_weights(module)
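Worked out: log2(2048) = 11, so num_shifts = ceil(11) - 1 = 10, and forward() below splits the feature dimension into 11 chunks (10 shifted at offsets 1, 2, 4, ..., 512 plus one pass-through), keeping even the widest shift window within the assumed 2048-token context.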
...
@@ -48,8 +88,10 @@ class HyperNetwork(nn.Module):
         #self.load_state_dict(state)

     def forward(self, x):
-        x = self.linear(x.float())
-        x = self.activation(x)
+        x = x.float()
+        x = shift_tokens(x, self.num_shifts)
+        x = self.linear(x)
+        x = ck(self.activation, x)
         x = self.linear2(x)
         x = x.mul(torch.sigmoid(x))
         return x.bfloat16()
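The new ck(self.activation, x) call wraps just the activation in torch.utils.checkpoint, dropping its intermediate buffers after the forward pass and recomputing them during backward. A minimal sketch of that pattern (shapes are illustrative; checkpoint() wants at least one input that requires grad, which holds here because x comes out of self.linear):

import torch
from torch.utils.checkpoint import checkpoint as ck

act = torch.nn.GELU()
x = torch.randn(4, 2048, 768, requires_grad=True)
y = ck(act, x)        # runs act but frees its intermediates after forward
y.sum().backward()    # act is re-executed here to rebuild what backward needs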
...
@@ -131,7 +173,7 @@ for input_ids, labels in t:
     loss = 0
     for x in range(train_config["gas"]):
         with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=False)
+            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
             #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             logits = logits.view(-1, logits.shape[-1])
             gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
...
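For orientation, the loop above is a gradient-accumulation pattern: train_config["gas"] micro-batches of bs rows each run under fp16 autocast before one optimizer step. A hedged skeleton of the same pattern (model, optimizer, loader, gas, and bs are stand-ins for this script's own objects, not code from the repo):

import torch
import torch.nn.functional as F

scaler = torch.cuda.amp.GradScaler()
for input_ids, labels in loader:                  # each batch holds gas * bs rows
    for x in range(gas):
        with torch.cuda.amp.autocast(dtype=torch.float16):
            logits = model(input_ids[x*bs:(x+1)*bs, :].cuda())
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]),
                                   labels[x*bs:(x+1)*bs, :].cuda().view(-1))
        scaler.scale(loss / gas).backward()       # grads accumulate across micro-batches
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()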
main.py (view file @ c99ffa47)

...
@@ -212,7 +212,7 @@ class FeedForward(nn.Module):
     def forward(self, x, act_ck=False):
         x = self.ff1(x)
         if act_ck:
-            ck(self.activation, x)
+            x = ck(self.activation, x)
         else:
             x = self.activation(x)
         x = self.ff2(x)
...
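This one-line change is the "fix checkpointing" of the commit title: torch.utils.checkpoint.checkpoint is functional, returning the wrapped function's output instead of mutating its input, so the old code computed the activation and silently discarded the result whenever act_ck was set. A minimal repro of the difference:

import torch
from torch.utils.checkpoint import checkpoint as ck

x = torch.randn(3, requires_grad=True)
ck(torch.relu, x)         # before the fix: return value thrown away, x unchanged
x = ck(torch.relu, x)     # after the fix: x actually holds the activated values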