novelai-storage / Basedformer / Commits
Commit 41b51369, authored Apr 07, 2022 by novelailab

not too bad

parent 6cbab785
Showing 4 changed files with 53 additions and 24 deletions (+53 -24)
basedformer/gptj.py      +19 -1
basedformer/lm_base.py   +29 -19
hypertrain.py            +3 -2
scripts/comparehf.py     +2 -2
basedformer/gptj.py  View file @ 41b51369
@@ -11,6 +11,7 @@ except ImportError:
 import os
 from pathlib import Path
 import math
+from basedformer import lm_base

 def fixed_pos_embedding(dim=None, seq_len=None, x=None):
     if x is None:
@@ -224,3 +225,20 @@ class GPTJModel(nn.Module):
             x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
         x = self.ln_final(x)
         return x
+
+class GPTJBaseLM(lm_base.BaseLM):
+    def __init__(self, config=None, lm=None):
+        nn.Module.__init__(self)
+        lm_base.BaseLM.__init__(self, config, lm)
+        self.model_class = GPTJModel
+
+def load_gpt_j(path="models/6b", state_dict=None):
+    config = {
+        "n_layer": 28,
+        "n_head": 16,
+        "hidden_dim": 4096,
+        "vocab_dim": 50400,
+        "eps": 1e-5
+    }
+    model = GPTJBaseLM.load(config, path, state_dict)
+    return model
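
With this change, gptj.py owns its end-to-end loader: load_gpt_j builds the 6B config and delegates to GPTJBaseLM.load, inherited from lm_base.BaseLM. A minimal usage sketch, assuming a checkpoint exists at models/6b and mirroring the call pattern scripts/comparehf.py uses below (the .half().eval() chain is that script's choice, not a requirement):

    from basedformer import gptj

    # load_gpt_j assembles the config dict and calls GPTJBaseLM.load(config, path, state_dict)
    model = gptj.load_gpt_j(path="models/6b").cuda().half().eval()
    core = model.lm  # the wrapped GPTJModel, unwrapped the same way comparehf.py does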
basedformer/lm_base.py  View file @ 41b51369
@@ -5,6 +5,28 @@ from torch import nn
 from basedformer import gptj
 import os
 import json
+from dataclasses import dataclass
+
+'''
+BaseLM config dataclass:
+model_config = {
+    "model_class":
+    "n_layer": 28,
+    "n_head": 16,
+    "hidden_dim": 4096,
+    "vocab_dim": 50400,
+    "eps": 1e-5,
+}
+'''
+
+@dataclass
+class BaseLMConfig():
+    model_class: type
+    n_layer: int
+    n_head: int
+    hidden_dim: int
+    vocab_dim: int
+    eps: float
+
 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
@@ -12,6 +34,7 @@ class BaseLM(nn.Module):
         nn.Module.__init__(self)
         self.config = config
         self.lm = lm
+        self.model_class = None

     def init_weights(self):
         for module in self.lm.modules():
@@ -33,8 +56,8 @@ class BaseLM(nn.Module):
     @classmethod
     def init(cls, config):
-        lm = config["model_class"](**config)
-        model = cls(config, lm)
+        model = cls(config)
+        model.lm = model.model_class(**config)
         model.init_weights()
         #make this modular later
@@ -42,8 +65,8 @@ class BaseLM(nn.Module):
     @classmethod
     def no_init(cls, config):
-        lm = utils.no_init(lambda: config.model_class(**config))
-        model = cls(config, lm)
+        model = cls(config)
+        model.lm = utils.no_init(lambda: model.model_class(**config))
         return model

     @classmethod
@@ -53,8 +76,8 @@ class BaseLM(nn.Module):
         if path:
             state_dict = utils.SplitCheckpoint(path, device="cuda")
-        lm = config["model_class"](**config)
-        model = cls(config, lm)
+        model = cls(config)
+        model.lm = model.model_class(**config)
         model.lm.load_state_dict(state_dict, strict=strict)
         return model
@@ -70,16 +93,3 @@ class BaseLM(nn.Module):
             checkpoint[x[0]] = f"{path}/b{i}.pt"
             torch.save(x[1], f"{path}/b{i}.pt")
         torch.save(checkpoint, f"{path}/m.pt")
-
-def load_gpt_j(path="models/6b", state_dict=None):
-    config = {
-        "model_class": gptj.GPTJModel,
-        "n_layer": 28,
-        "n_head": 16,
-        "hidden_dim": 4096,
-        "vocab_dim": 50400,
-        "eps": 1e-5
-    }
-    model = BaseLM.load(config, path, state_dict)
-    return model
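
The common thread across the three classmethods (init, no_init, load): the torch module class now comes from the wrapper instance via model.model_class, set by each subclass's __init__, rather than being pulled out of config["model_class"]. A self-contained sketch of that pattern with simplified, hypothetical names (not the actual basedformer classes):

    from torch import nn

    class TinyModel(nn.Module):
        # stands in for GPTJModel; accepts config keys as keyword arguments
        def __init__(self, hidden_dim=8, **kwargs):
            super().__init__()
            self.proj = nn.Linear(hidden_dim, hidden_dim)

    class Base(nn.Module):
        def __init__(self, config=None, lm=None):
            super().__init__()
            self.config = config
            self.lm = lm
            self.model_class = None  # subclasses point this at their torch module

        @classmethod
        def init(cls, config):
            model = cls(config)                     # subclass __init__ fills model_class
            model.lm = model.model_class(**config)  # build the inner module from it
            return model

    class TinyBaseLM(Base):
        def __init__(self, config=None, lm=None):
            super().__init__(config, lm)
            self.model_class = TinyModel

    model = TinyBaseLM.init({"hidden_dim": 8})  # the config no longer names the class

One consequence of the design: the classmethods only work on subclasses that set model_class, since on BaseLM itself it stays None.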
hypertrain.py  View file @ 41b51369
@@ -146,6 +146,7 @@ class HyperNetworkSingle(nn.Module):
         return x.bfloat16()

 model_config = {
+    "model_class":
     "n_layer": 28,
     "n_head": 16,
     "hidden_dim": 4096,
@@ -178,7 +179,7 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

 #model = GPTModel.gpt2_init(model_config).cuda().float()
-model = load_gpt_j().cuda().bfloat16()
+model = lm_base.().cuda().bfloat16()
 for param in model.parameters():
     param.requires_grad = False
@@ -196,7 +197,7 @@ opt = optimizer.BasedOptimizer(hypernetwork.parameters(), train_config, "adamw")
 # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
-train_dataset = utils.FbDataset(2049, train_config["data_path"])
+train_dataset = FbDataset(2049, train_config["data_path"])
 train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
 wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
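
As the surrounding script shows, every parameter of the loaded GPT-J model is frozen and only the hypernetwork's parameters go to the optimizer. A minimal, self-contained sketch of that split, with generic modules standing in for the real model and HyperNetworkSingle:

    import torch
    from torch import nn

    base = nn.Linear(16, 16)    # stand-in for the frozen GPT-J model
    hyper = nn.Linear(16, 16)   # stand-in for HyperNetworkSingle

    for param in base.parameters():
        param.requires_grad = False  # same freeze loop as hypertrain.py

    opt = torch.optim.AdamW(hyper.parameters(), lr=1e-4)  # hypernetwork params only
    loss = hyper(base(torch.randn(2, 16))).sum()
    loss.backward()  # gradients reach hyper's weights; base stays untouched
    opt.step()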
scripts/comparehf.py  View file @ 41b51369
-from basedformer import lm_base
+from basedformer import gptj
 from basedformer.utils import *
 import time
@@ -67,7 +67,7 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 with torch.no_grad():
-    based_model = lm_base.load_gpt_j().cuda().half().eval()
+    based_model = gptj.load_gpt_j().cuda().half().eval()
     based_model = based_model.lm
 print("Loaded based model")
 hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
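
The rest of the script, not shown in this diff, compares the based model against the Hugging Face checkpoint. As a hedged illustration of the kind of parity check such a script typically performs (this helper is hypothetical, and it assumes both callables return vocabulary logits for a batch of token ids):

    import torch

    @torch.no_grad()
    def max_logit_gap(based_forward, hf_model, input_ids):
        # hypothetical helper: largest absolute difference between the two outputs
        based_logits = based_forward(input_ids)
        hf_logits = hf_model(input_ids).logits
        return (based_logits.float() - hf_logits.float()).abs().max().item()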