novelai-storage / Basedformer

Commit 8073ccfc authored Jul 13, 2022 by Wes Brown
Add argument handling, a closured `hypernetwork_saver`, and save the final result.
parent 704947b4

Showing 1 changed file with 92 additions and 28 deletions

hypertrain.py (+92, -28)
@@ -8,6 +8,7 @@ from basedformer.utils import *
 from transformers import AutoTokenizer
 from basedformer import sampling
 from termcolor import colored
+import argparse

 gpu = "cuda"
 amp = torch.cuda.amp
@@ -15,8 +16,18 @@ if gpu != "cuda":
     amp = torch.amp

 scaler = torch.cuda.amp.GradScaler()

-prompts = ["<|endoftext|>"]
+prompts = ["<|endoftext|>",
+           "The year was",
+           "I grabbed my",
+           "She lifted the",
+           "He was known as the",
+           "The tavern was full again, so I ended up sharing a table with three very different creatures: a",
+           "I had been hiking in the wilderness when suddenly a",
+           "She spread her",
+           "The mercurial and beautiful woman laughed",
+           "[ Author:",
+           "[ Tags:",
+           "***"]

 def _init_weights(module):
     if isinstance(module, nn.Linear):
@@ -158,39 +169,91 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None, step=0):
     return data

 def report_wandb(data):
     columns = ["Step", "Prompt", "Generated Text", "Vanilla Model"]
     wandb.log({"Generations": wandb.Table(data=data, columns=columns)})

 def report_console(data):
-    for gen in data[3]:
+    for gen in data[2]:
         print(colored("======================================================", "red"))
         print(colored(gen, "green"))
         print(colored("======================================================", "red"))
+
+def make_hypernet_saver(train_config, hypernetwork):
+    def hypernet_saver(id: str):
+        save_folder = Path(train_config["save_path"]) / id
+        save_folder.mkdir(parents=True, exist_ok=True)
+        torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
+        opt.save(save_folder / "opt")
+    return hypernet_saver
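A note on the closure: hypernet_saver captures train_config and hypernetwork when make_hypernet_saver is called, so every later call site only has to supply a checkpoint id. opt is not captured here; it resolves from module scope at call time, which works as long as opt has been assigned before the first save. A standalone sketch of the same pattern, with illustrative names that are not from this repository:

    from pathlib import Path

    def make_saver(config: dict, get_state):
        # config and get_state are bound once, when the factory runs.
        def saver(tag: str):
            folder = Path(config["save_path"]) / tag
            folder.mkdir(parents=True, exist_ok=True)
            (folder / "state.txt").write_text(repr(get_state()))
        return saver

    save = make_saver({"save_path": "/tmp/demo-run"}, lambda: {"step": 300})
    save("step_300")  # writes /tmp/demo-run/step_300/state.txt
    save("final")     # writes /tmp/demo-run/final/state.txt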
+parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
+parser.add_argument('--run_name', type=str, help='the run name to use', required=True)
+parser.add_argument('--model', type=str, help='the model to train against', required=True)
+parser.add_argument('--dataset', type=str, help='pre-tokenized dataset to use', required=True)
+parser.add_argument("--output", type=str, help='output path', default='')
+parser.add_argument('--optimizer', type=str, help='the optimizer to use', default='adamw')
+parser.add_argument('--lr', type=float, help='learning rate', default=2e-4)
+parser.add_argument('--end_lr', type=float, help='end learning rate', default=2e-4)
+parser.add_argument('--warmup', type=int, help='warmup steps')
+parser.add_argument('--bs', type=int, help='batch size', default=4)
+parser.add_argument('--gas', type=int, help='gas', default=1)
+parser.add_argument('--seed', type=int, help="Random seed value", default=42)
+parser.add_argument("--save_steps", type=int, help='# of steps between checkpoint saves', default=300)
+parser.add_argument("--amp", type=bool, help='enable amp', default=False)
+parser.add_argument('--loss_scale', type=bool, help='whether to scale loss', default=False)
+parser.add_argument("--eval_every", type=int, help='evaluate hypernetwork every x steps', default=100)
+parser.add_argument('--output_path', type=str, help="Root path of all output", default="./")
+parser.add_argument('--no_resume', type=bool, default=False, help="Do not resume from last checkpoint")
+parser.add_argument("--context_size", type=int, help="Dataset context sizes", default=2048)
+parser.add_argument("--project_id", type=str, help="Project ID for reporting", default="hypernetwork-training")
+parser.add_argument("--logs", type=str, help="log directory location", default="./logs")
+parser.add_argument("--masked", type=bool, help="masked softmax fusion")
+parser.set_defaults(loss_scale=False, amp=False, no_resume=False, masked=False)
+args = parser.parse_args()
+
+if args.output == '':
+    args.output = f'./{args.run_name}'
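One caveat on the boolean options: argparse's type=bool simply calls bool() on the argument string, and any non-empty string, including "False", is truthy. The flags therefore behave as expected only when omitted, with set_defaults supplying the off state. A quick standalone demonstration, not from this commit:

    import argparse

    p = argparse.ArgumentParser()
    p.add_argument("--amp", type=bool, default=False)

    # bool("False") is True: any non-empty string switches the flag on.
    print(p.parse_args(["--amp", "False"]).amp)  # True
    print(p.parse_args([]).amp)                  # False (default applies)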
 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "dataset/cassandra.map",
-    "save_path": "models/sigurdv4-cassandra-hypernet2",
-    "lm_path": "pretrained/sigurdv4",
-    "optimizer": "adamw",
-    "masked_softmax_fusion": False,
-    "do_save": True,
-    "run_name": "sigurdv4-cassandra-6b-postln-bf16-2e-4-4bsz-every5layer",
-    "lr": 2e-4,
-    "end_lr": 2e-4,
-    "warmup_steps": 50,
-    "bs": 4,
-    "gas": 1,
-    "seed": 69,
-    "save_every": 300,
-    "amp": False,
-    "loss_scale": False,
-    "eval_every": 100,
+    "data_path": args.dataset,
+    "save_path": args.model,
+    "lm_path": args.model,
+    "optimizer": args.optimizer,
+    "masked_softmax_fusion": args.masked,
+    "do_save": args.save_steps != 0,
+    "run_name": args.run_name,
+    "lr": args.lr,
+    "end_lr": args.end_lr,
+    "warmup_steps": args.warmup,
+    "bs": args.bs,
+    "gas": args.gas,
+    "seed": args.seed,
+    "save_every": args.save_steps,
+    "amp": args.amp,
+    "loss_scale": args.loss_scale,
+    "eval_every": args.eval_every,
 }

 torch.manual_seed(train_config["seed"])
 bs = train_config["bs"]
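With the hard-coded config replaced by CLI arguments, a run might be launched roughly as follows; the model and dataset paths are illustrative, borrowed from the defaults the old dictionary carried:

    python hypertrain.py \
        --run_name cassandra-hypernet \
        --model pretrained/sigurdv4 \
        --dataset dataset/cassandra.map \
        --bs 4 --lr 2e-4 --save_steps 300 --eval_every 100

Note that within this diff both "save_path" and "lm_path" come from --model, so checkpoints land under the model path, while --output and --output_path are parsed but not yet wired into train_config.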
@@ -209,6 +272,7 @@ for name, p in model.named_parameters():
 hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
 for param in hypernetwork.parameters():
     param.requires_grad = True

+hypernetwork_saver = make_hypernet_saver(train_config, hypernetwork)
 cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
@@ -216,7 +280,7 @@ last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
 print(last_cp)

-if last_cp:
+if last_cp and not args.no_resume:
     print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
     hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
     opt = optimizer.BasedOptimizer.load(hypernetwork.parameters(),
@@ -303,12 +367,10 @@ for input_ids, labels in t:
             }, step=curr_step)

-    if train_config["do_save"] and curr_step % train_config["save_every"] == 0 and curr_step != 0:
-        save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
-        save_folder.mkdir(parents=True, exist_ok=True)
-        torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
-        opt.save(save_folder / "opt")
+    if train_config["do_save"] and \
+       curr_step % train_config["save_every"] == 0 and \
+       curr_step != 0:
+        hypernetwork_saver(f"step_{curr_step}")
+        print(f"\nSaved model at step {curr_step}")

     if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
@@ -320,3 +382,5 @@ for input_ids, labels in t:
             report_wandb(sample_data)

     curr_step += 1
+
+hypernetwork_saver("final")
\ No newline at end of file
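One interaction worth flagging: hypernetwork_saver("final") writes a final directory into train_config["save_path"], the same tree the startup resume scan sorts with int(x.split("_")[-1]), and int("final") raises ValueError. Assuming the final checkpoint stays in that directory, a subsequent resume scan would fail; a minimal standalone reproduction:

    # The resume scan's sort key assumes every checkpoint directory
    # name ends in an integer step number.
    dirs = ["step_100", "step_200", "final"]
    try:
        sorted(dirs, key=lambda x: int(x.split("_")[-1]))
    except ValueError as err:
        print("resume scan would fail:", err)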