novelai-storage / Basedformer / commit c20db765

Authored Aug 26, 2022 by kurumuz
Parent: 6ccb1e1f

    make DSinference work with basedformer

Showing 5 changed files, with 120 additions and 5 deletions (+120 / -5):

  basedformer/models/__init__.py   +1   -0
  basedformer/models/base_lm.py    +42  -2
  basedformer/models/ds_strats.py  +70  -0  (new file)
  basedformer/models/gptj.py       +2   -1
  basedformer/sampling.py          +5   -2
basedformer/models/__init__.py

@@ -6,6 +6,7 @@ from . import alibi
 from . import vit
 from . import resnet
 from . import fast
+from . import ds_strats

 MODEL_MAP = {
     "gptj": gptj.GPTJModel,
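The one-line addition here matters for the rest of the commit: importing ds_strats at package load makes models.ds_strats a real attribute, which BaseModel.convert_to_ds() (added below in base_lm.py) relies on when it looks up models.ds_strats.model_map.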
basedformer/models/base_lm.py

@@ -3,6 +3,13 @@ import torch.nn.functional as F
 from dataclasses import dataclass
 from dotmap import DotMap
 import math
+from basedformer import models
+
+class ConfigClass:
+    def __init__(self, config):
+        # set every key/value pair in config as an attribute of this class
+        for key, value in config.items():
+            setattr(self, key, value)

 class BaseModel(nn.Module):
     def __init__(self, user_config, **kwargs):
@@ -61,7 +68,8 @@ class BaseModel(nn.Module):
         for k, v in self.user_config.items():
             full_config[k] = v
-        full_config = DotMap(full_config)
+        #full_config = DotMap(full_config)
+        full_config = ConfigClass(full_config)
         return full_config

     def forward_with_hidden_states(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
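A likely motivation for swapping DotMap out for the new ConfigClass (an inference, not stated in the commit): DotMap fabricates empty child maps on attribute access, so hasattr-style probes on the config always succeed, whereas a plain attribute container fails loudly on missing keys. A minimal sketch of the behavioral difference:

    from dotmap import DotMap

    class ConfigClass:
        def __init__(self, config):
            for key, value in config.items():
                setattr(self, key, value)

    raw = {"n_layer": 28, "hidden_dim": 4096}

    dm = DotMap(raw)
    print(hasattr(dm, "rotary_dim"))  # True -- DotMap silently creates an empty DotMap
    print(bool(dm.rotary_dim))        # False, and no AttributeError is ever raised

    cc = ConfigClass(raw)
    print(hasattr(cc, "rotary_dim"))  # False -- missing keys stay missing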
@@ -119,4 +127,36 @@ class BaseModel(nn.Module):
         if cache:
             return x, kv_new
         else:
-            return x, None
\ No newline at end of file
+            return x, None
+
+    def get_embeds_ds(self, x, past_key_values=None, use_cache=True):
+        if past_key_values is None:
+            past_key_values = [None] * self.n_layer
+
+        kv_new = []
+        x = self.vocab_embed(x)
+        for layer_id, layer in enumerate(self.layers):
+            x = layer(x, layer_past=past_key_values[layer_id], use_cache=use_cache)
+            kv_new.append(x[1])
+            x = x[0]
+
+        x = self.ln_final(x)
+        if use_cache:
+            return x, kv_new
+        else:
+            return x, None
+
+    def forward_ds(self, x, past_key_values=None, use_cache=True):
+        x, kv = self.get_embeds_ds(x, past_key_values=past_key_values, use_cache=use_cache)
+        x = self.lm_head(x)
+        return x, kv
+
+    def convert_to_ds(self):
+        convert_func = models.ds_strats.model_map[self.config.Layer]
+        model = convert_func(self)
+        return model
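Taken together, get_embeds_ds/forward_ds give the model a DeepSpeed-style calling convention (past_key_values/use_cache instead of basedformer's kv/cache), and convert_to_ds() dispatches on the model's layer class to find the right transform. A minimal usage sketch; the construction line and config field names are assumptions, since the real loading path is not part of this commit:

    import torch
    from basedformer import models

    # Hypothetical config/construction -- field names are illustrative only.
    user_config = {"n_layer": 28, "n_head": 16, "hidden_dim": 4096}
    model = models.gptj.GPTJModel(user_config).half().cuda().eval()

    # Looks up models.ds_strats.model_map[self.config.Layer]; for GPT-J this
    # resolves to GPTJTransform, which wraps the model via deepspeed.init_inference.
    ds_model = model.convert_to_ds()

    tokens = torch.randint(0, 50257, (1, 16), device="cuda")
    logits, kv = ds_model(tokens, past_key_values=None, use_cache=True)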
basedformer/models/ds_strats.py (new file, mode 100644)

from deepspeed.module_inject import DSPolicy
import torch
from torch.nn.parameter import Parameter

from basedformer import models

class BasedformerGPTJLayerPolicy(DSPolicy):
    # Can't set an original layer class here, because in the transformer fork
    # all models share a single layer class. Needs some values from
    # model.config, including: rotary_dim, layer_norm_epsilon.
    _orig_layer_class = None

    def __init__(self, client_module, inference=True):
        super().__init__(inference, scale_attention=True)
        self.client_module = client_module

    def get_hidden_heads(self):
        return self.client_module.attn.q_proj.weight.shape[1], \
               self.client_module.attn.n_head

    def attention(self):
        qw = self.client_module.attn.q_proj.weight
        kw = self.client_module.attn.k_proj.weight
        vw = self.client_module.attn.v_proj.weight
        qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)
        return self.linear_layer, \
               qkvw, \
               None, \
               self.client_module.attn.out_proj.weight, \
               None, \
               self.scale_attention, \
               self.is_megatron_v2

    def mlp(self):
        return self.linear_layer, \
               self.client_module.ff.ff1.weight, \
               self.client_module.ff.ff1.bias, \
               self.client_module.ff.ff2.weight, \
               self.client_module.ff.ff2.bias

    def layerNorm(self):
        return None, \
               None, \
               self.client_module.ln_preattn.weight, \
               self.client_module.ln_preattn.bias

def GPTJTransform(model):
    model.config.rotary_dim = model.layers[0].attn.rotary_dim
    model.config.layer_norm_epsilon = 1e-5
    model.forward = model.forward_ds
    model.get_embeds = model.get_embeds_ds

    import deepspeed
    model = deepspeed.init_inference(
        model,
        mp_size=1,
        dtype=torch.float16,
        replace_method="auto",
        injection_policy={models.gptj.GPTJLayer: BasedformerGPTJLayerPolicy},
        replace_with_kernel_inject=True,
        enable_cuda_graph=True,
    )
    return model

model_map = {
    models.gptj.GPTJLayer: GPTJTransform,
}
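One detail worth verifying at a glance: attention() hands DeepSpeed a single fused QKV weight built by concatenating the three separate projections along the output dimension. A toy shape check (sizes are illustrative, not the real model's):

    import torch
    from torch.nn.parameter import Parameter

    hidden = 8  # toy hidden size; GPT-J-6B would be 4096
    qw, kw, vw = (torch.randn(hidden, hidden) for _ in range(3))

    # Same fusion as BasedformerGPTJLayerPolicy.attention()
    qkvw = Parameter(torch.cat((qw, kw, vw), dim=0), requires_grad=False)

    assert qkvw.shape == (3 * hidden, hidden)
    assert torch.equal(qkvw[:hidden], qw)      # first block is q
    assert torch.equal(qkvw[2 * hidden:], vw)  # last block is v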
basedformer/models/gptj.py

@@ -245,7 +245,8 @@ class GPTJModel(base_lm.BaseModel):
             'activation': gelu_new,
             'SelfAttention': SelfAttention,
             'FeedForward': FeedForward,
-            'masked_softmax_fusion': True,
+            'q_only': False,
+            'masked_softmax_fusion': False,
         }
         base_lm.BaseModel.__init__(self, user_config, **kwargs)
         if self.config.masked_softmax_fusion:
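Two default changes ride along here: a new q_only flag (defaulting to False) is introduced, and masked_softmax_fusion is flipped from True to False, presumably so the stock configuration no longer depends on basedformer's own fused-softmax path once DeepSpeed's injected kernels take over attention.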
basedformer/sampling.py

@@ -174,7 +174,7 @@ def generate_greedy(forward, prompt_tokens, tokens_to_generate=50, hypernetwork=
     return generated

 @torch.no_grad()
-def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
+def generate(forward, prompt_tokens, tokens_to_generate=50, ds=False, ops_list=[{"temp": 0.9}], hypernetwork=None, non_deterministic=False, fully_deterministic=False):
     in_tokens = prompt_tokens
     context = prompt_tokens
     generated = torch.zeros(len(ops_list), 0, dtype=torch.long).to(in_tokens.device)
@@ -192,7 +192,10 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0
         }

     for _ in range(tokens_to_generate):
-        logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
+        if ds:
+            logits, kv = forward(in_tokens, past_key_values=kv, use_cache=True)
+        else:
+            logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
         logits = logits[:, -1, :]  # get the last token in the seq
         logits = torch.log_softmax(logits, dim=-1)

         #if kv[0][0].shape[0] == 1 and (kv[0][0].shape[0] != len(ops_list)):
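With the ds flag threaded through, one sampling loop can drive either calling convention. A hedged end-to-end sketch; the prompt tensor is made up, and ds_model is assumed to be the result of model.convert_to_ds() from the earlier sketch, neither of which is part of this commit:

    import torch
    from basedformer import sampling

    # prompt_tokens is a (batch, seq) LongTensor of token ids from the
    # model's tokenizer; the values below are placeholders.
    prompt_tokens = torch.tensor([[818, 257, 995]], device="cuda")

    generated = sampling.generate(
        ds_model.forward,   # forward_ds signature: (x, past_key_values=None, use_cache=True)
        prompt_tokens,
        tokens_to_generate=50,
        ds=True,            # route through past_key_values/use_cache instead of kv/cache
        ops_list=[{"temp": 0.9}],
    )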