novelai-storage / Basedformer, commit 8b26deda
Authored Jul 13, 2022 by Wes Brown
Revert mostly to `x=` assignment form.
parent 8073ccfc
Showing 1 changed file with 27 additions and 18 deletions:

hypertrain.py (+27, -18)
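The commit title refers to unwinding nested single-expression forward passes into sequential `x =` assignments. A minimal self-contained sketch of the pattern (the layers here are hypothetical stand-ins, not taken from hypertrain.py):

import torch
import torch.nn as nn

h = nn.Linear(8, 8)
g = nn.GELU()
f = nn.Linear(8, 8)

x = torch.randn(2, 8)

# Nested single-expression form (what the previous commit used):
y_nested = f(g(h(x)))

# Step-by-step `x =` assignment form (what this commit reverts to);
# the computation is identical, but every intermediate is a named,
# inspectable tensor.
x = h(x)
x = g(x)
x = f(x)
assert torch.allclose(x, y_nested)

The assignment form also makes it easier to set breakpoints and print shapes between stages, which is the usual motivation for this style in training scripts.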
hypertrain.py @ 8b26deda
@@ -29,6 +29,7 @@ prompts = ["<|endoftext|>",
            "[ Tags:",
            "***"]

 def _init_weights(module):
     if isinstance(module, nn.Linear):
         module.weight.data.normal_(mean=0.0, std=0.02)
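Note: the hunks here never show where `_init_weights` is attached. In PyTorch such helpers are conventionally wired up with `Module.apply`, which recurses over all submodules; a minimal sketch under that assumption (`demo` is a hypothetical module, not from this script):

import torch.nn as nn

def _init_weights(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=0.02)

demo = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
demo.apply(_init_weights)  # re-initializes every nn.Linear in the module tree

The per-class initializers in the hunks below instead use std = 0.02 / sqrt(2 * n_layer), the GPT-2-style scaling that shrinks residual-branch weights as the host model gets deeper.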
@@ -72,13 +73,19 @@ class HyperNetworkGRU(nn.Module):
             param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
+        self.linear_gru = nn.Sequential(self.linear1, self.gru)
+        self.layernorm_linear = nn.Sequential(self.ln_1, self.linear2)

     def forward(self, x):
-        return ck(self.activation, self.linear2(self.ln_1(self.gru(self.linear1(x.float()))[0]))).bfloat16()
+        x = x.float()
+        x = self.linear_gru.forward(x)[0]
+        x = ck(self.activation, self.layernorm_linear.forward(x))
+        return x.bfloat16()

 class HyperNetwork(nn.Module):
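`ck` is not defined in the hunks shown, but its call shape, `ck(self.activation, x)`, matches `torch.utils.checkpoint.checkpoint(fn, *args)`, so it is presumably an alias for activation checkpointing: the wrapped call's activations are recomputed during backward instead of being kept in memory. A sketch under that assumption:

import torch
from torch.utils.checkpoint import checkpoint as ck  # assumed alias, not confirmed by the diff

act = torch.nn.GELU()
x = torch.randn(4, 16, requires_grad=True)
y = ck(act, x, use_reentrant=False)  # same value as act(x); saves memory on backward
y.sum().backward()

Splitting the forward into `x =` steps keeps the checkpointed call isolated on its own line, which makes it obvious which part of the graph is recomputed.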
@@ -96,11 +103,12 @@ class HyperNetwork(nn.Module):
                 std=(0.02 / math.sqrt(2 * config["n_layer"])))

     def forward(self, x):
-        x = self.linear2(ck(self.activation, self.linear(x.float())))
-        return x.mul(torch.sigmoid(x)).bfloat16()
+        x = x.float()
+        x = self.linear(x)
+        x = ck(self.activation, x)
+        x = self.linear2(x)
+        x = x.mul(torch.sigmoid(x))
+        return x.bfloat16()

 class HyperNetworkSingle(nn.Module):
     def __init__(self, config):
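The `x.mul(torch.sigmoid(x))` tail in these forward passes is the SiLU (swish) activation written out by hand. A quick equivalence check against the built-in:

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)
assert torch.allclose(x.mul(torch.sigmoid(x)), F.silu(x))

Writing it out keeps the gating explicit as its own `x =` step rather than hiding it behind `F.silu`.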
@@ -115,14 +123,12 @@ class HyperNetworkSingle(nn.Module):
         for param in self.linear.parameters():
             param.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * config["n_layer"])))
-        # state = self.state_dict()
-        # for k in state:
-        #     state[k] = state[k] * 1 / math.sqrt(2 * config["n_layer"])
-        # self.load_state_dict(state)

     def forward(self, x):
-        x = self.linear(x.float())
-        return x.mul(torch.sigmoid(x)).bfloat16()
+        x = x.float()
+        x = self.linear(x)
+        x = x.mul(torch.sigmoid(x))
+        return x.bfloat16()

 tokenizer = AutoTokenizer.from_pretrained('gpt2')
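All three forward passes share the same dtype discipline: upcast the incoming bfloat16 activations to float32 for the linear layers and gating, then downcast the result on the way out. A minimal sketch of that boundary:

import torch

linear = torch.nn.Linear(16, 16)              # fp32 weights, as in the hypernetwork layers
x = torch.randn(4, 16, dtype=torch.bfloat16)
y = linear(x.float()).bfloat16()              # compute in fp32, hand back bf16
assert y.dtype == torch.bfloat16

Doing the matmul and sigmoid in fp32 avoids bf16 precision loss inside the hypernetwork while keeping its inputs and outputs compatible with a bf16 host model.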
@@ -183,14 +189,17 @@ def report_console(data):
     print(colored("======================================================", "red"))

 def make_hypernet_saver(train_config, hypernetwork):
     def hypernet_saver(id: str):
         save_folder = Path(train_config["save_path"]) / id
         save_folder.mkdir(parents=True, exist_ok=True)
         torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
         opt.save(save_folder / "opt")
     return hypernet_saver

 parser = argparse.ArgumentParser(description='Hypernetwork Finetuner')
 parser.add_argument('--run_name', type=str, help='the run name to use', required=True)
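For reference, a self-contained sketch of how the saver factory is used; `_OptStub` is a hypothetical stand-in for the script's module-level `opt` optimizer wrapper, which `hypernet_saver` closes over rather than taking as a parameter:

import torch
from pathlib import Path

class _OptStub:                        # hypothetical stand-in for `opt`
    def save(self, path):
        Path(path).mkdir(parents=True, exist_ok=True)

opt = _OptStub()
train_config = {"save_path": "checkpoints"}
hypernetwork = torch.nn.Linear(8, 8)   # stand-in for the real hypernetwork

def make_hypernet_saver(train_config, hypernetwork):
    def hypernet_saver(id: str):
        save_folder = Path(train_config["save_path"]) / id
        save_folder.mkdir(parents=True, exist_ok=True)
        torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
        opt.save(save_folder / "opt")
    return hypernet_saver

saver = make_hypernet_saver(train_config, hypernetwork)
saver("step_1000")   # writes checkpoints/step_1000/hyper.pt and .../opt

Because `opt` is resolved at call time from the enclosing module scope, the saver silently depends on it existing; passing it in explicitly would make the dependency visible.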
@@ -383,4 +392,4 @@ for input_ids, labels in t:
         curr_step += 1

-hypernetwork_saver("final")
\ No newline at end of file
+hypernetwork_saver("final")