novelai-storage / Basedformer / Commits

Commit fb134d28
authored Jul 06, 2022 by novelailab

    zero2 works

parent 9d27a5cc

Showing 3 changed files with 41 additions and 33 deletions:

    basedformer/models/gptj.py   +13  -10
    basedformer/optimizer.py      +4   -0
    finetune.py                  +24  -23
basedformer/models/gptj.py

@@ -95,15 +95,18 @@ class SelfAttention(nn.Module):
         sin, cos = fixed_pos_embedding(dim=self.rotary_dim, seq_len=max_positions)
         self.register_buffer("sin", sin)
         self.register_buffer("cos", cos)
-        self.fused_softmax = FusedScaleMaskSoftmax(
-            input_in_fp16=False,
-            input_in_bf16=True,
-            mask_func=attention_mask_func,
-            scale=None,
-            softmax_in_fp32=False,
-            attn_mask_type="causal",
-            scaled_masked_softmax_fusion=True,
-        )
+        if self.config.masked_softmax_fusion:
+            self.fused_softmax = FusedScaleMaskSoftmax(
+                input_in_fp16=False,
+                input_in_bf16=True,
+                mask_func=attention_mask_func,
+                scale=None,
+                softmax_in_fp32=False,
+                attn_mask_type="causal",
+                scaled_masked_softmax_fusion=True,
+            )
+        else:
+            self.fused_softmax = None

     def forward(self, x, kv=None, cache=False):
         B, S, H = x.shape # batch, sequence, hidden_dim

@@ -242,7 +245,7 @@ class GPTJModel(base_lm.BaseModel):
             'activation': gelu_new,
             'SelfAttention': SelfAttention,
             'FeedForward': FeedForward,
-            'masked_softmax_fusion': False,
+            'masked_softmax_fusion': True,
         }
         base_lm.BaseModel.__init__(self, user_config, **kwargs)
         if self.config.masked_softmax_fusion:
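The change above gates Megatron-style FusedScaleMaskSoftmax behind a masked_softmax_fusion config flag, with the GPTJModel default flipped to True. The forward-path dispatch is not part of this hunk; the following is only a minimal sketch, assuming the attention module falls back to an eager softmax when the flag is off. The helper name, the causal_mask argument, and the fp32 fallback are illustrative assumptions, not code from this repository.

import torch
import torch.nn.functional as F

def causal_softmax(attn_scores, causal_mask, fused_softmax=None):
    # attn_scores: [batch, heads, q_len, k_len]; causal_mask: bool, True where attention is disallowed.
    if fused_softmax is not None:
        # Fused path: FusedScaleMaskSoftmax applies mask, optional scale and softmax in one
        # kernel (configured above for bf16 inputs and causal masking).
        return fused_softmax(attn_scores, causal_mask)
    # Eager fallback when masked_softmax_fusion is off: mask, then softmax in fp32 for stability.
    attn_scores = attn_scores.masked_fill(causal_mask, torch.finfo(attn_scores.dtype).min)
    return F.softmax(attn_scores.float(), dim=-1).to(attn_scores.dtype)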
basedformer/optimizer.py

@@ -77,6 +77,10 @@ class BasedOptimizer:
                 eps=self.eps,
             )

+        elif self.optimizer_name == "zero2":
+            from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
+            self.optimizer = DistributedFusedAdam(self.parameters, lr=0, weight_decay=self.weight_decay, betas=(self.beta1, self.beta2), eps=self.eps, grad_sync_dtype=torch.float32)
+
         elif self.optimizer_name == "adafactor":
             try:
                 from transformers.optimization import Adafactor
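The new "zero2" branch wraps apex's DistributedFusedAdam, which shards optimizer state and gradient reduction across data-parallel ranks in the style of ZeRO stage 2. Below is a minimal standalone sketch of that optimizer on a toy model, assuming apex is installed and the script is launched with one process per GPU (e.g. via torchrun); the toy Linear layer and the hyperparameter values are illustrative, not the repository's defaults.

import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

dist.init_process_group("nccl")            # one process per GPU
rank = dist.get_rank()
torch.cuda.set_device(rank)

model = torch.nn.Linear(1024, 1024).bfloat16().cuda()
opt = DistributedFusedAdam(
    model.parameters(),
    lr=6e-5, weight_decay=0.01, betas=(0.9, 0.95), eps=1e-8,
    grad_sync_dtype=torch.float32,         # reduce gradients in fp32, as in this commit
)

x = torch.randn(2, 1024).bfloat16().cuda()
model(x).square().mean().backward()        # gradients are bucketed and reduce-scattered during backward
opt.step()                                 # optimizer state is sharded; updated parameters are re-gathered across ranks
opt.zero_grad()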
finetune.py

@@ -5,7 +5,7 @@ import torch.cuda.amp as amp
 import torch.optim as optim
 from pathlib import Path
 from torch.utils import data
-from basedformer import optimizer, utils, lm_utils
+from basedformer import optimizer, utils, lm_utils, dataset
 import yaml
 import sys
 from tqdm import tqdm
@@ -16,17 +16,11 @@ import os
 from icecream import ic
 import torch.distributed as dist
 import torch.multiprocessing as mp
-from torch.nn.parallel import DistributedDataParallel as DDP
+#from torch.nn.parallel import DistributedDataParallel as DDP
+from apex.parallel.distributed import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
 from dotmap import DotMap
 import argparse
-from torch.distributed.fsdp import (
-    FullyShardedDataParallel,
-    CPUOffload,
-)
-from torch.distributed.fsdp.wrap import (
-    default_auto_wrap_policy,
-)

 def setup(rank, world_size):
     #os.environ['MASTER_ADDR'] = 'localhost'
@@ -97,14 +91,19 @@ def fsdp_train(args, model, train_loader, opt):
                 norm = norm.matmul(norm.transpose(-1, -2))
                 contrastive_loss = torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()
                 gas_loss += contrastive_loss * args.contrastive_loss

             if args["loss_scale"]:
-                scaler.scale(gas_loss).backward()
+                with opt.optimizer.no_sync():
+                    scaler.scale(gas_loss).backward()
             else:
-                gas_loss.backward()
+                with opt.optimizer.no_sync():
+                    gas_loss.backward()

             loss += gas_loss.item()

         loss = loss / gas
+        opt.optimizer.grad_sync()
         if args["loss_scale"]:
             scaler.unscale_(opt.optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
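With DistributedFusedAdam, gradient reduction normally overlaps with backward; the hunk above instead suppresses it for each micro-batch and triggers it once per optimizer step. A condensed sketch of that pattern follows, assuming scaler, gas, args, and the micro-batch loss computation come from the parts of fsdp_train not shown in this hunk.

# Accumulate over `gas` micro-batches without inter-rank communication, then sync once.
for _ in range(gas):
    gas_loss = compute_micro_batch_loss()      # placeholder for the loss computed above
    with opt.optimizer.no_sync():              # defer the gradient reduce-scatter
        if args["loss_scale"]:
            scaler.scale(gas_loss).backward()
        else:
            gas_loss.backward()

opt.optimizer.grad_sync()                      # one reduce-scatter for the accumulated gradients
# unscaling, gradient clipping and the optimizer step then follow as in the rest of fsdp_train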
@@ -116,10 +115,10 @@ def fsdp_train(args, model, train_loader, opt):
         if args["loss_scale"]:
             scaler.update()

-        # opt.zero_grad()
-        model.zero_grad(set_to_none=True)
+        opt.zero_grad()
+        # model.zero_grad(set_to_none=True)
         sec_per_step = (time.perf_counter() - timex)
-        flops = get_flops(args, model.module, sec_per_step)
+        flops = get_flops(args, model, sec_per_step)
         step_per_sec = (1. / sec_per_step)
         tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
         batch_size = bs * gas * world_size
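The bookkeeping above derives throughput from wall-clock time per step. Worked through with the bs and gas values from the train_config at the bottom of this file; the world_size and sec_per_step figures here are illustrative placeholders, not measurements.

bs, gas, world_size, ctx_len = 2, 8, 8, 2048       # bs and gas from train_config; world_size assumed
batch_size = bs * gas * world_size                 # 128 sequences per optimizer step
tokens_per_step = batch_size * ctx_len             # 262,144 tokens per step
sec_per_step = 2.0                                 # measured at runtime with time.perf_counter()
step_per_sec = 1. / sec_per_step
tokens_per_sec = (step_per_sec * ctx_len) * bs * gas * world_size   # same formula as above: 131,072 tokens/s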
@@ -153,15 +152,17 @@ def main(rank, global_rank, world_size, args):
     setup(rank, world_size)
     Path(args["save_path"]).mkdir(parents=True, exist_ok=True)

-    model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
-    fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/gpt-j-base").half().to(rank)
+    #fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    #fsdp_model = DDP(model)
+    fsdp_model = model
     utils.print_parameters(fsdp_model)
     ic("model loaded")
-    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero1")
+    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero2")
     # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
     print(opt.curr_step)
-    train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
+    train_dataset = dataset.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
     train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
     if global_rank == 0:
         wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
@@ -172,21 +173,21 @@ def main(rank, global_rank, world_size, args):

 if __name__ == "__main__":
     train_config = {
-        "data_path": "dataset/sigurd-1G.map",
+        "data_path": "/home/xuser/nvme1/dataset/sigurd-1G.map",
         "save_path": "models/gptj-sigurd-1G-vanilla",
-        "do_save": True,
+        "do_save": False,
         "run_name": "gptj-sigurd-1G-vanilla",
         "lr": 6e-5,
         "end_lr": 3e-5,
         "warmup_steps": 100,
         "anneal_steps": 7850,
         "bs": 2,
-        "gas": 2,
+        "gas": 8,
         "seed": 69,
         "save_every": 500,
-        "amp": True,
+        "amp": False,
         "loss_scale": True,
-        "cast_to": torch.float16,
+        "cast_to": torch.bfloat16,
         "contrastive_loss": False,
     }