novelai-storage / Basedformer / Commits

Commit fb25b47c
authored Mar 26, 2022 by novelailab
set seed, everything works
parent 4f87dce5

Showing 1 changed file with 5 additions and 1 deletion.
train.py (+5 -1)
@@ -13,6 +13,7 @@ from tqdm import tqdm
 import time
 import wandb
+from lm_arch.gpt2 import GPT2Model
 import numpy as np
 
 model_config = {
     "n_layer": 12,
@@ -38,12 +39,13 @@ train_config = {
     "save_every": 500,
     "amp": True,
 }
 
+torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
 gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
 
-model = GPTModel.gpt2_init(model_config).cuda().float()
+model = GPT2Model.gpt2_init(model_config).cuda().float()
 opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
 
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
@@ -65,6 +67,8 @@ for input_ids, labels in t:
     for x in range(train_config["gas"]):
+        if train_config["amp"]:
+            with torch.cuda.amp.autocast():
         #with torch.jit.fuser("fuser2"):
         #    module = torch.jit.trace(model, torch.randint(0, 50256, (12, 1024)).long().cuda())
         logits = model(input_ids[x*bs:(x+1)*bs, :1024].cuda(), hypernetwork=None, act_ck=False)
         logits = logits.view(-1, logits.shape[-1])
         gas_labels = labels[x*bs:(x+1)*bs, :1024].contiguous()
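Note on the seeding change: torch.manual_seed() seeds PyTorch's CPU generator (and, on current PyTorch, the CUDA generators), but Python's random module and NumPy keep their own RNG state, which this commit leaves unseeded. A minimal sketch of a fuller seeding helper, with a hypothetical seed_everything name that is not part of this commit:

import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    # Seed every RNG a training script commonly touches.
    random.seed(seed)                 # Python stdlib RNG
    np.random.seed(seed)              # NumPy global RNG
    torch.manual_seed(seed)           # torch CPU generator (also seeds CUDA)
    torch.cuda.manual_seed_all(seed)  # explicit, for multi-GPU clarity

seed_everything(42)  # e.g. train_config["seed"] in train.py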
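On the loop itself: step x of the gas gradient-accumulation steps slices rows x*bs to (x+1)*bs out of the fetched batch and truncates each row to the 1024-token context. A self-contained sketch of just that slicing pattern, with made-up sizes:

import torch

bs, gas, ctx = 4, 8, 1024  # micro-batch size, accumulation steps, context length
input_ids = torch.randint(0, 50256, (bs * gas, ctx + 1))  # one fetched macro-batch

for x in range(gas):
    # Rows [x*bs, (x+1)*bs) form micro-batch x; :ctx trims to the context window.
    micro = input_ids[x * bs:(x + 1) * bs, :ctx]
    assert micro.shape == (bs, ctx)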
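The new torch.cuda.amp.autocast() block runs the forward pass in mixed precision when train_config["amp"] is set. The hunk does not show the backward pass; the conventional pairing (an assumption here, not something this diff shows) adds a GradScaler so fp16 gradients do not underflow:

import torch
import torch.nn.functional as F

# Tiny stand-in model/optimizer; the real script builds GPT2Model and a BasedOptimizer.
model = torch.nn.Linear(1024, 50257).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 1024).cuda()
y = torch.randint(0, 50257, (4,)).cuda()

opt.zero_grad()
with torch.cuda.amp.autocast():
    logits = model(x)                  # forward runs in reduced precision where safe
    loss = F.cross_entropy(logits, y)
scaler.scale(loss).backward()          # scale the loss so fp16 grads don't underflow
scaler.step(opt)                       # unscales grads; skips the step on inf/nan
scaler.update()                        # adapts the scale factor for the next iteration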