novelai-storage / Basedformer · Commits · c58dfef8

Commit c58dfef8, authored Jun 13, 2022 by FIRST_NAME LAST_NAME
Commit message: push
Parent: a28f0299

Showing 8 changed files with 323 additions and 39 deletions (+323 / -39)
basedformer/lm_utils.py          +35    -9
basedformer/models/base_lm.py    +21    -1
basedformer/utils.py              +1    -1
contrastiverun.py                +29    -0
convertfp16.py                    +5    -0
dist_run.sh                       +5    -0
finetune.py                      +47   -28
finetunedeepspeed.py            +180    -0
basedformer/lm_utils.py

```diff
@@ -2,7 +2,7 @@ from basedformer import utils
 from basedformer import models
 import math
 import torch
-from torch import nn
+from torch import nn, distributed
 import os
 import json
 from dataclasses import dataclass
@@ -36,14 +36,41 @@ def no_init(config):
     model = utils.no_init(lambda: model_class(config))
     return model
 
-def save(model, path):
-    try:
-        os.mkdir(path)
-    except:
-        pass
+def serialize_config(config):
+    serialized_dict = {
+        "model_class": "gptj",
+        "model_path": ".",
+        'model_config': {
+            'n_layer': config.n_layer,
+            'n_head': config.n_head,
+            'n_tokens': config.n_tokens,
+            'hidden_dim': config.hidden_dim,
+            'vocab_dim': config.vocab_dim,
+            'eps': config.eps,
+        }
+    }
+
+    return serialized_dict
+
+def save(model, path, save_fp16=True):
+    if distributed.is_initialized() and distributed.get_rank() != 0:
+        return
+
+    if save_fp16:
+        model = model.half()
+
+    path = Path(path)
+    lm_path = path / "lm"
+    # make folder
+    lm_path.mkdir(parents=True, exist_ok=True)
     checkpoint = {}
     for i, x in enumerate(model.state_dict().items()):
-        checkpoint[x[0]] = f"{path}/b{i}.pt"
-        torch.save(x[1], f"{path}/b{i}.pt")
-    torch.save(checkpoint, f"{path}/m.pt")
+        checkpoint[x[0]] = lm_path / f"b{i}.pt"
+        torch.save(x[1], lm_path / f"b{i}.pt")
+    torch.save(checkpoint, lm_path / "m.pt")
+
+    # write model.config to config.json inside path
+    with open(path / "config.json", "w") as f:
+        json.dump(serialize_config(model.config), f)
 
 def load_from_path(config_folder=None, strict=False):
     config_folder = Path(config_folder)
@@ -51,13 +78,12 @@ def load_from_path(config_folder=None, strict=False):
     model_class = models.get_model(config["model_class"])
     model_path = config["model_path"]
     model_config = config["model_config"]
-    print(model_config)
 
     if model_path == ".":
         # model_path is the config_folder directory.
        model_path = config_folder
 
-    model_path = Path(model_path) / "lm"
+    model_path = str(Path(model_path) / "lm")
     model = _load_dict_model(model_class, model_config, model_path, strict=strict)
     return model
```
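The new save() shards the state dict into one file per tensor under `<path>/lm`, with `m.pt` acting as an index from parameter names to tensor files, and writes a `config.json` next to it so load_from_path() can rebuild the model. A minimal read-back sketch, assuming only the on-disk layout visible in the diff above (load_split_state_dict is a hypothetical helper, not part of the repo):

```python
import torch
from pathlib import Path

def load_split_state_dict(path):
    """Hypothetical helper: rebuild a state dict from the split layout."""
    lm_path = Path(path) / "lm"
    index = torch.load(lm_path / "m.pt")  # {param_name: per-tensor file}
    state_dict = {}
    for name, tensor_file in index.items():
        # The index now stores Path objects rather than strings, which is
        # why SplitCheckpoint.__getitem__ gains a str() call below.
        state_dict[name] = torch.load(str(tensor_file), map_location="cpu")
    return state_dict
```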
basedformer/models/base_lm.py

```diff
@@ -44,7 +44,7 @@ class BaseModel(nn.Module):
         full_config = DotMap(full_config)
         return full_config
 
-    def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
+    def forward_with_hidden_states(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
         x, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
         x = self.lm_head(x)
         if target:
@@ -64,6 +64,26 @@ class BaseModel(nn.Module):
         else:
             return x.float()
 
+    def forward(self, x, target=None, hypernetwork=None, act_ck=False, kv=None, cache=False):
+        hidden_states, kv = self.get_embeds(x, hypernetwork=hypernetwork, act_ck=act_ck, kv=kv, cache=cache)
+        x = self.lm_head(hidden_states)
+        if target:
+            logits = x.view(-1, x.shape[-1])
+            labels = target.view(-1)
+            loss = F.cross_entropy(logits, labels)
+
+        # clean this mess later
+        if cache:
+            if target:
+                return loss, x.float(), kv
+            else:
+                return x.float(), kv
+        else:
+            if target:
+                return loss, x.float()
+            else:
+                return x.float(), hidden_states
+
     def get_embeds(self, x, hypernetwork=None, act_ck=False, kv=None, cache=False):
         if kv is None:
             kv = [None] * self.n_layer
```
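The default forward() now returns the final hidden states alongside the logits when no target is given, which is what the contrastive loss in finetune.py consumes; the previous forward() body survives under the new name forward_with_hidden_states. A rough sketch of the resulting call surfaces, assuming a constructed BaseModel subclass `model` and a token tensor `x`:

```python
logits, hidden_states = model(x)            # plain forward: logits plus final hidden states
logits, kv = model(x, cache=True)           # incremental decoding: logits plus KV cache
logits_only = model.forward_with_hidden_states(x)  # old forward body, kept under a new name
```

When a target is passed, the cross-entropy loss is computed inside forward and returned as the first element of the tuple.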
basedformer/utils.py

```diff
@@ -98,7 +98,7 @@ class SplitCheckpoint(MutableMapping):
     def __len__(self):
         return len(self.checkpoint)
     def __getitem__(self, key):
-        name = self.checkpoint[key]
+        name = str(self.checkpoint[key])
         if type(name) is tuple:
             return self._load(name[0].split('/')[-1], name[1], map_location=self.device)
         else:
```
contrastiverun.py (new file, mode 100644)

```python
import os
import torch
from dotmap import DotMap
from finetune import main

if __name__ == "__main__":
    train_config = {
        "data_path": "dataset/sigurd-1G.map",
        "save_path": "models/gptj-sigurd-1G-contrastive-0.3weight",
        "do_save": True,
        "run_name": "gptj-sigurd-1G-contrastive0.3weight",
        "lr": 6e-5,
        "end_lr": 3e-5,
        "warmup_steps": 100,
        "anneal_steps": 7850,
        "bs": 2,
        "gas": 2,
        "seed": 69,
        "save_every": 500,
        "amp": True,
        "loss_scale": True,
        "cast_to": torch.float16,
        "contrastive_loss": 0.3,
    }

    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    torch.cuda.set_device(rank)
    main(rank, global_rank, world_size, DotMap(train_config))
```
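contrastiverun.py is just a config wrapper: it imports main from finetune.py and enables the contrastive term with weight 0.3. Since it reads WORLD_SIZE, LOCAL_RANK, and RANK from the environment, it is presumably launched through torchrun in the same way that dist_run.sh below launches finetune.py.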
convertfp16.py (new file, mode 100644)

```python
from basedformer import lm_utils as lmu
import torch

model = lmu.load_from_path("models/gptj-sigurd-1G-vanilla/final")
lmu.save(model, "models/gptj-sigurd-1G-vanilla/final_fp16")
```
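Because the new lm_utils.save() defaults to save_fp16=True, this one-off script is enough to rewrite an existing checkpoint in half precision, presumably to shrink checkpoints saved as fp32 before this change.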
dist_run.sh (new file, mode 100644)

```bash
NCCL_DEBUG=INFO torchrun \
    --nnodes=2 \
    --nproc_per_node=8 \
    --rdzv_endpoint=10.0.155.233:29300 \
    finetune.py
```
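This launches finetune.py across 2 nodes with 8 processes each. torchrun's rendezvous at 10.0.155.233:29300 coordinates the 16 workers and sets the WORLD_SIZE, RANK, and LOCAL_RANK environment variables that the training scripts read; NCCL_DEBUG=INFO turns on verbose NCCL logging, which helps when debugging multi-node communication.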
finetune.py

```diff
@@ -18,6 +18,7 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.data.distributed import DistributedSampler
+from dotmap import DotMap
 import argparse
 from torch.distributed.fsdp import (
     FullyShardedDataParallel,
```
```diff
@@ -51,19 +52,20 @@ def get_world():
 def get_flops(args, model, iter_time_s):
     ff = model.total_params * 6
-    attn = 2048 * args.hidden_size * args.n_layers * 60
+    attn = 2048 * model.config.hidden_dim * model.config.n_layer * 60
     flops = (args.bs * args.gas * 2048 * (ff + attn) / (iter_time_s))
-    return flops
+    return flops / 1e12
 
 def fsdp_train(args, model, train_loader, opt):
     bs = args["bs"]
     gas = args["gas"]
-    rank = get_rank()
+    global_rank = get_rank()
+    rank = int(os.environ["LOCAL_RANK"])
     world_size = get_world()
     model.train()
     ddp_loss = torch.zeros(1).cuda()
```
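As a sanity check on get_flops, a rough worked example assuming GPT-J-6B shapes (about 6.05e9 parameters, hidden_dim 4096, n_layer 28 — the shapes are an assumption, not stated in the diff) and the bs=2, gas=2 settings from the config below:

```python
ff = 6.05e9 * 6                   # feed-forward/projection FLOPs per token
attn = 2048 * 4096 * 28 * 60      # attention FLOPs per token at 2048 context
iter_time_s = 5.0                 # hypothetical measured step time
flops = 2 * 2 * 2048 * (ff + attn) / iter_time_s
print(flops / 1e12)               # ~82.6 TFLOP/s, the value logged as train/flops
```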
```diff
@@ -73,19 +75,28 @@ def fsdp_train(args, model, train_loader, opt):
         t = train_loader
 
     scaler = torch.cuda.amp.GradScaler()
+    counter = 0
     for input_ids, labels in t:
         timex = time.perf_counter()
         input_ids = input_ids.to(rank)
         labels = labels.to(rank)
         loss = 0
         for x in range(args["gas"]):
-            with torch.cuda.amp.autocast(enabled=args["amp"], dtype=torch.float16):
-                logits = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)
+            with torch.cuda.amp.autocast(enabled=args["amp"], dtype=args["cast_to"]):
+                logits, hidden_states = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)
 
             logits = logits.view(-1, logits.shape[-1])
             gas_labels = labels[x*bs:(x+1)*bs, :2048].contiguous()
             gas_labels = gas_labels.view(-1)
             gas_loss = F.cross_entropy(logits, gas_labels)
+            if args.contrastive_loss:
+                #print("contrastive enabled")
+                with torch.no_grad():
+                    max = hidden_states.abs().amax().detach()
+                hs = hidden_states.div(max)
+                norm = hs.norm(dim=-1, keepdim=True)
+                norm = norm.matmul(norm.transpose(-1, -2))
+                contrastive_loss = torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()
+                gas_loss += contrastive_loss * args.contrastive_loss
 
             if args["loss_scale"]:
                 scaler.scale(gas_loss).backward()
             else:
```
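The contrastive term penalises hidden states of different positions for pointing in similar directions. A standalone sketch of the same computation on a dummy activation tensor, assuming GPT-J-like shapes of (batch, seq, hidden):

```python
import torch

hidden_states = torch.randn(2, 2048, 4096)       # (batch, seq, hidden)
with torch.no_grad():
    scale = hidden_states.abs().amax().detach()  # global max, detached for stability
hs = hidden_states.div(scale)
norm = hs.norm(dim=-1, keepdim=True)             # (batch, seq, 1)
norm = norm.matmul(norm.transpose(-1, -2))       # pairwise norm products: (batch, seq, seq)
# mean absolute cosine similarity over all position pairs
penalty = torch.matmul(hs, hs.transpose(-2, -1)).div(norm).abs().mean()
```

The penalty is added to the LM loss weighted by args.contrastive_loss (0.3 in contrastiverun.py). Note the mean includes the diagonal self-pairs, whose cosine similarity is always 1.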
```diff
@@ -108,12 +119,13 @@ def fsdp_train(args, model, train_loader, opt):
         #opt.zero_grad()
         model.zero_grad(set_to_none=True)
         sec_per_step = (time.perf_counter() - timex)
+        flops = get_flops(args, model.module, sec_per_step)
         step_per_sec = (1. / sec_per_step)
         tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
         batch_size = bs * gas * world_size
         ddp_loss[0] = loss
         dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
-        if rank == 0:
+        if global_rank == 0:
             wandb.log({
                 "train_loss": ddp_loss[0] / world_size,
@@ -122,57 +134,64 @@ def fsdp_train(args, model, train_loader, opt):
                 "train/step_per_sec": step_per_sec,
                 "train/lr": opt.curr_lr,
                 "train/batch_size": batch_size,
-                "train/loss_scale": scaler.get_scale()
+                "train/loss_scale": scaler.get_scale(),
+                "train/flops": flops,
             })
 
+        if counter != 0 and counter % args["save_every"] == 0:
+            if global_rank == 0:
+                lm_utils.save(model.module, Path(args["save_path"]) / f"step_{str(counter)}")
+            dist.barrier()
+        counter += 1
+
 # we need 250 batch size to train the small GPT.
-def main(rank, world_size, args):
+def main(rank, global_rank, world_size, args):
     bs = args["bs"]
     gas = args["gas"]
-    torch.manual_seed(train_config["seed"])
+    torch.manual_seed(args["seed"])
     setup(rank, world_size)
-    Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
+    Path(args["save_path"]).mkdir(parents=True, exist_ok=True)
     model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
 
     #fsdp_model = FullyShardedDataParallel(
     #model,
     #fsdp_auto_wrap_policy=default_auto_wrap_policy,
     #cpu_offload=CPUOffload(offload_params=True),
     #)
 
-    fsdp_model = DDP(model, device_ids=[rank], gradient_as_bucket_view=True)
+    fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
     utils.print_parameters(fsdp_model)
     ic("model loaded")
-    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), train_config, "zero1")
+    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), args, "zero1")
 
     # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
     print(opt.curr_step)
-    train_dataset = utils.ShardedDataset(2049, train_config["data_path"], world_size=world_size, rank=rank)
+    train_dataset = utils.ShardedDataset(2049, args["data_path"], world_size=world_size, rank=global_rank)
     train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
-    if rank == 0:
-        wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})
+    if global_rank == 0:
+        wandb.init(project="basedformer-tests", name=args["run_name"], config={**args, **model.config})
 
     fsdp_train(args, fsdp_model, train_loader, opt)
     lm_utils.save(fsdp_model.module, Path(args["save_path"]) / "final")
     dist.barrier()
     cleanup()
 
 if __name__ == "__main__":
     train_config = {
         "data_path": "dataset/sigurd-1G.map",
-        "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-sigurd-1G-vanilla",
-        "do_save": False,
+        "save_path": "models/gptj-sigurd-1G-vanilla",
+        "do_save": True,
         "run_name": "gptj-sigurd-1G-vanilla",
         "lr": 6e-5,
-        "end_lr": 2e-5,
+        "end_lr": 3e-5,
         "warmup_steps": 100,
-        "anneal_steps": 10000,
+        "anneal_steps": 7850,
         "bs": 2,
-        "gas": 1,
+        "gas": 2,
         "seed": 69,
         "save_every": 500,
         "amp": True,
         "loss_scale": True,
         "cast_to": torch.float16,
+        "contrastive_loss": False,
     }
 
     world_size = int(os.environ["WORLD_SIZE"])
     rank = int(os.environ["LOCAL_RANK"])
     global_rank = int(os.environ["RANK"])
     torch.cuda.set_device(rank)
-    main(rank, world_size, train_config)
\ No newline at end of file
+    main(rank, global_rank, world_size, DotMap(train_config))
\ No newline at end of file
```
finetunedeepspeed.py (new file, mode 100644)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from torch.utils import data
from basedformer import optimizer, utils, lm_utils
import yaml
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
import os
from icecream import ic
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from dotmap import DotMap
import argparse
from torch.distributed.fsdp import (
    FullyShardedDataParallel,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import (
    default_auto_wrap_policy,
)

def setup(rank, world_size):
    #os.environ['MASTER_ADDR'] = 'localhost'
    #os.environ['MASTER_PORT'] = '12355'
    # initialize the process group
    dist.init_process_group(backend="nccl")
    if dist.is_initialized():
        print("Initialized process group")
    else:
        print("Failed to initialize process group")

def cleanup():
    dist.destroy_process_group()

def get_rank():
    if dist.is_initialized():
        return dist.get_rank()

def get_world():
    if dist.is_initialized():
        return dist.get_world_size()

def get_flops(args, model, iter_time_s):
    ff = model.total_params * 6
    attn = 2048 * model.config.hidden_dim * model.config.n_layer * 60
    flops = (args.bs * args.gas * 2048 * (ff + attn) / (iter_time_s))
    return flops / 1e12

def fsdp_train(args, model, train_loader, opt):
    bs = args["bs"]
    gas = args["gas"]
    rank = get_rank()
    world_size = get_world()
    model.train()
    ddp_loss = torch.zeros(1).cuda()
    if rank == 0:
        t = tqdm(train_loader)
    else:
        t = train_loader

    scaler = torch.cuda.amp.GradScaler()
    for input_ids, labels in t:
        timex = time.perf_counter()
        input_ids = input_ids.to(rank)
        labels = labels.to(rank)
        loss = 0
        for x in range(args["gas"]):
            with torch.cuda.amp.autocast(enabled=args["amp"], dtype=args["cast_to"]):
                logits = model(input_ids[x*bs:(x+1)*bs, :2048].to(rank), act_ck=True)

            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[x*bs:(x+1)*bs, :2048].contiguous()
            gas_labels = gas_labels.view(-1)
            gas_loss = F.cross_entropy(logits, gas_labels)
            if args["loss_scale"]:
                scaler.scale(gas_loss).backward()
            else:
                gas_loss.backward()

            loss += gas_loss.item()

        loss = loss / gas
        if args["loss_scale"]:
            scaler.unscale_(opt.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        if args["loss_scale"]:
            opt.step(scaler=scaler)
        else:
            opt.step()

        if args["loss_scale"]:
            scaler.update()

        #opt.zero_grad()
        model.zero_grad(set_to_none=True)
        sec_per_step = (time.perf_counter() - timex)
        flops = get_flops(args, model.module, sec_per_step)
        step_per_sec = (1. / sec_per_step)
        tokens_per_sec = (step_per_sec * 2048) * bs * gas * world_size
        batch_size = bs * gas * world_size
        ddp_loss[0] = loss
        dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
        if rank == 0:
            wandb.log({
                "train_loss": ddp_loss[0] / world_size,
                "train/tokens_per_sec": tokens_per_sec,
                "train/sec_per_step": sec_per_step,
                "train/step_per_sec": step_per_sec,
                "train/lr": opt.curr_lr,
                "train/batch_size": batch_size,
                "train/loss_scale": scaler.get_scale(),
                "train/flops": flops,
            })

# we need 250 batch size to train the small GPT.
def main(rank, global_rank, world_size, args):
    bs = args["bs"]
    gas = args["gas"]
    torch.manual_seed(train_config["seed"])
    setup(rank, world_size)
    Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
    model = lm_utils.load_from_path("pretrained/gpt-j-base").float().to(rank)
    fsdp_model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
    utils.print_parameters(fsdp_model)
    ic("model loaded")
    opt = optimizer.BasedOptimizer(fsdp_model.parameters(), train_config, "zero1")
    # TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
    print(opt.curr_step)
    train_dataset = utils.ShardedDataset(2049, train_config["data_path"], world_size=world_size, rank=global_rank)
    train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0)
    if rank == 0:
        wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})

    fsdp_train(args, fsdp_model, train_loader, opt)
    dist.barrier()
    if rank == 0:
        fsdp_model.module
    cleanup()

if __name__ == "__main__":
    train_config = {
        "data_path": "dataset/sigurd-1G.map",
        "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-sigurd-1G-vanilla",
        "do_save": False,
        "run_name": "gptj-sigurd-1G-vanilla",
        "lr": 6e-5,
        "end_lr": 3e-5,
        "warmup_steps": 100,
        "anneal_steps": 7861,
        "bs": 2,
        "gas": 2,
        "seed": 69,
        "save_every": 500,
        "amp": True,
        "loss_scale": False,
        "cast_to": torch.bfloat16,
    }

    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["LOCAL_RANK"])
    global_rank = int(os.environ["RANK"])
    torch.cuda.set_device(rank)
    main(rank, global_rank, world_size, DotMap(train_config))
```
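Despite its name, finetunedeepspeed.py as committed does not import or use DeepSpeed: it is essentially the pre-contrastive finetune.py wrapped in DDP, with loss_scale disabled and cast_to set to torch.bfloat16. That combination is internally consistent, since bf16 shares fp32's exponent range and generally does not need GradScaler-style loss scaling.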