Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
B
Basedformer
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Locked Files
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Security & Compliance
Security & Compliance
Dependency List
License Compliance
Packages
Packages
List
Container Registry
Analytics
Analytics
CI / CD
Code Review
Insights
Issues
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
novelai-storage
Basedformer
Commits
42870e7b
Commit
42870e7b
authored
May 09, 2022
by
novelailab
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
oops, fix
parent
41f39980
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
13 deletions
+3
-13
basedformer/gptj.py
basedformer/gptj.py
+1
-7
basedformer/lm_base.py
basedformer/lm_base.py
+2
-5
scripts/comparehf.py
scripts/comparehf.py
+0
-1
No files found.
basedformer/gptj.py
View file @
42870e7b
...
...
@@ -260,12 +260,6 @@ class GPTJConfig:
for
k
,
v
in
config_dict
.
items
():
setattr
(
self
,
k
,
v
)
class GPTJBaseLM(lm_base.BaseLM):
    """GPT-J language-model wrapper built on lm_base.BaseLM.

    NOTE(review): reconstructed from a token-split GitLab diff view; the
    byte-exact original formatting is not visible here.
    """

    def __init__(self, config=None, lm=None):
        # Calls nn.Module.__init__ and BaseLM.__init__ explicitly rather
        # than via super(); preserve this exact order — BaseLM.__init__
        # presumably assumes the nn.Module machinery is already set up.
        # TODO(review): confirm against lm_base.BaseLM.
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        # Class handle used by the BaseLM constructors (init/no_init/load)
        # to instantiate the underlying network — presumably; verify in
        # lm_base, which is only partially visible in this diff.
        self.model_class = GPTJModel
def
load_gpt_j
(
path
=
"models/6b"
,
state_dict
=
None
):
config
=
{
"n_layer"
:
28
,
...
...
@@ -275,5 +269,5 @@ def load_gpt_j(path="models/6b", state_dict=None):
"eps"
:
1e-5
}
config
=
GPTJConfig
(
**
config
)
model
=
GPTJBaseLM
.
load
(
config
,
path
,
state_dict
)
model
=
lm_base
.
load
(
GPTJModel
,
config
,
path
)
return
model
basedformer/lm_base.py
View file @
42870e7b
...
...
@@ -26,25 +26,22 @@ def init_weights(model, n_layer):
if
(
"ff2"
in
name
or
"out_proj"
in
name
)
and
"weight"
in
name
:
p
.
data
.
normal_
(
mean
=
0.0
,
std
=
(
0.02
/
math
.
sqrt
(
2
*
n_layer
)))
@classmethod
def init(model_class, config):
    """Construct a model_class(config) instance with freshly initialised weights.

    NOTE(review): this is a classmethod of BaseLM (class header not visible
    in this diff view), so the first parameter is the class object itself —
    conventionally named ``cls``; the original's ``model_class`` name is kept.
    """
    model = model_class(config)
    # init_weights is defined elsewhere on the model class; not visible here.
    model.init_weights()
    return model
@classmethod
def no_init(model_class, config):
    """Build a model_class(config) instance while skipping weight initialisation.

    Delegates construction to utils.no_init, which runs the constructor with
    parameter initialisation suppressed and returns the resulting instance.
    """
    return utils.no_init(lambda: model_class(config))
@
classmethod
def
load
(
config
,
model_class
,
path
=
None
,
state_dict
=
None
,
strict
=
False
):
def load(model_class, config, path=None, state_dict=None, strict=False):
    """Instantiate model_class(config) without weight init and load a checkpoint.

    Parameters:
        model_class: class to instantiate (called as ``model_class(config)``).
        config: configuration object forwarded to the constructor.
        path: optional checkpoint directory; when given, it overrides
            ``state_dict`` via utils.SplitCheckpoint.
        state_dict: optional pre-loaded state dict, used when ``path`` is falsy.
        strict: forwarded to ``load_state_dict``; defaults to lenient loading.

    Returns the model with the checkpoint weights applied.

    NOTE(review): reconstructed post-commit state — this commit changed
    ``model_class(**config)`` to ``model_class(config)`` and demoted ``load``
    from a BaseLM classmethod to this module-level function.
    NOTE(review): if neither ``path`` nor ``state_dict`` is supplied,
    ``load_state_dict(None)`` will raise — callers must pass one of them.
    """
    # I am kinda sad that we will not have a load function in lm object itself.
    # might be better to add load functions -- actually nope.
    if path:
        # Multi-file split checkpoint, loaded directly onto the GPU.
        state_dict = utils.SplitCheckpoint(path, device="cuda")
    # no_init skips random weight initialisation since the weights are
    # immediately overwritten by the checkpoint below.
    model = utils.no_init(lambda: model_class(config))
    model.load_state_dict(state_dict, strict=strict)
    return model
...
...
scripts/comparehf.py
View file @
42870e7b
...
...
@@ -68,7 +68,6 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
# Script fragment (scripts/comparehf.py): load the basedformer and the
# HuggingFace reference model for output comparison. no_grad() — inference
# only, no autograd bookkeeping needed.
#
# NOTE(review): per this commit's diff (+0 −1 for comparehf.py), the line
# ``based_model = based_model.lm`` was removed here; this reconstruction
# reflects the post-commit state.
with torch.no_grad():
    based_model = gptj.load_gpt_j().cuda().half().eval()
    print("Loaded based model")
    # no_init suppresses weight initialisation; the weights come entirely
    # from the pretrained checkpoint on disk.
    hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained(
        '/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
    print("Loaded hf model")
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment