Commit 7eebf8ad authored by novelailab

add more stuff

parent 4a13de3a
from main import *  # provides SplitCheckpoint, among other repo utilities
import os
import torch
state_dict = SplitCheckpoint("/home/xuser/models/j6b_ckpt_14001", device="cpu")
# ORIGINAL
'''
transformer.ln_f.weight
transformer.ln_f.bias
lm_head.weight
lm_head.bias
transformer.h.9.ln_1.weight
transformer.h.9.ln_1.bias
transformer.h.9.mlp.c_proj.weight
transformer.h.9.mlp.c_proj.bias
transformer.h.9.mlp.c_fc.weight
transformer.h.9.mlp.c_fc.bias
transformer.h.9.attn.attention.out_proj.weight
transformer.h.9.attn.attention.k_proj.weight
transformer.h.9.attn.attention.v_proj.weight
transformer.h.9.attn.attention.q_proj.weight
transformer.wte.weight
'''
new_state_dict = {}
module_map = {
    "ln_1": "ln_preattn",
    "mlp.c_proj": "ff.ff2",
    "mlp.c_fc": "ff.ff1",
    "attn.attention.out_proj": "attn.out_proj",
    "attn.attention.k_proj": "attn.k_proj",
    "attn.attention.v_proj": "attn.v_proj",
    "attn.attention.q_proj": "attn.q_proj",
    "wte": "vocab_embed",
    "ln_f": "ln_final",
    "lm_head": "lm_head",
}
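# For example, a per-layer key maps as
#   transformer.h.9.mlp.c_fc.weight -> layers.9.ff.ff1.weight
# and a top-level key as
#   transformer.ln_f.weight -> ln_final.weight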
print(type(state_dict))
for key in state_dict.keys():
    dotlist = key.split('.')
    if len(dotlist) > 3:
        # Per-layer key, e.g. transformer.h.9.ln_1.weight; the layer index is dotlist[2].
        layer = dotlist[2]
        # Substring matching; the map entries are disjoint for these key names.
        for x in module_map:
            if x in key:
                new_state_dict[f"layers.{layer}.{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
                print(f"{key} -> layers.{layer}.{module_map[x]}.{dotlist[-1]}")
    else:
        # Top-level key, e.g. transformer.ln_f.weight or lm_head.weight.
        for x in module_map:
            if x in key:
                new_state_dict[f"{module_map[x]}.{dotlist[-1]}"] = state_dict[key]
                print(f"{key} -> {module_map[x]}.{dotlist[-1]}")
#print(new_state_dict)
def save(state_dict, path):
    # Write each tensor to its own file plus an m.pt index mapping name -> file.
    try:
        os.mkdir(path)
    except FileExistsError:
        pass
    checkpoint = {}
    for i, x in enumerate(state_dict.items()):
        checkpoint[x[0]] = f"{path}/b{i}.pt"
        torch.save(x[1], f"{path}/b{i}.pt")
    torch.save(checkpoint, f"{path}/m.pt")
save(new_state_dict, "models/6b")
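
# A minimal sketch of reading the converted checkpoint back, assuming only the
# m.pt index layout written by save() above (name -> tensor file); presumably
# this is the layout SplitCheckpoint consumes, but that is not verified here.
def load(path):
    index = torch.load(f"{path}/m.pt")  # {param name: path to b{i}.pt}
    return {name: torch.load(file, map_location="cpu") for name, file in index.items()}

#loaded = load("models/6b")  # hypothetical usage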
from main import *
import time
state_dict = SplitCheckpoint("/home/xuser/models/j6b_ckpt_14001", device="cuda")
for x in state_dict:
    print(x)
from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
# Replicates IPython's %timeit magic: r runs of n loops each, reporting the
# best run as mean ± std. dev. of its loop times.
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    precision = 'ns'
    r_arr = np.empty([2, r])  # [0] = mean, [1] = std
    if function:
        func.__name__ = function.__name__
    for i in tqdm(range(r)) if do_tqdm else range(r):
        n_arr = np.empty(n)
        for k in range(n):
            start = perf_counter_ns()
            func()
            n_arr[k] = perf_counter_ns() - start
        if not first:
            # Drop the first loop of each run (warmup, e.g. lazy CUDA initialization).
            n_arr = np.delete(n_arr, 0)
        r_arr[0, i] = np.mean(n_arr)
        r_arr[1, i] = np.std(n_arr)
    best = r_arr[:, np.argmin(r_arr[0])]  # [0] = mean, [1] = std
    # Scale the nanosecond timings to the most readable unit.
    if best[0] < 1e3:
        precision = 'ns'
    elif best[0] >= 1e9:
        best = best * 1e-9
        precision = 's'
    elif best[0] >= 1e6:
        best = best * 1e-6
        precision = 'ms'
    elif best[0] >= 1e3:
        best = best * 1e-3
        precision = 'μs'
    if not quiet:
        decimals = {'ns': 0, 'μs': 2, 'ms': 2, 's': 4}[precision]
        print(f"{func.__name__}: {best[0]:.{decimals}f}{precision} ± {best[1]:.{decimals}f}{precision} "
              f"per loop (mean ± std. dev. of {r} runs, {n} loops each)")
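# Example (illustrative numbers, not from a real run):
#   timeit(lambda: np.sort(np.random.rand(10**5)), r=3, n=10)
#   -> <lambda>: 4.21ms ± 0.13ms per loop (mean ± std. dev. of 3 runs, 10 loops each)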
with torch.no_grad():
    # The whole benchmark below runs without autograd.
    model = init_125m().cuda().half()
    '''
    timeit(lambda: model(torch.zeros((1, 2048)).long().cuda()), n=20, first=False)
    module = torch.jit.trace(model, torch.zeros((1, 2048)).long().cuda())
    module = torch.jit.optimize_for_inference(module)
    timeit(lambda: module(torch.zeros((1, 2048)).long().cuda()), n=20, first=False)
    timeit(lambda: model(torch.zeros((1, 1000)).long().cuda()), n=20, first=False)
    module = torch.jit.trace(model, torch.zeros((1, 1000)).long().cuda())
    module = torch.jit.optimize_for_inference(module)
    timeit(lambda: module(torch.zeros((1, 1000)).long().cuda()), n=20, first=False)
    '''
    module = torch.jit.trace(model, torch.zeros((1, 2048)).long().cuda())
    # optimize_for_inference returns the optimized module; keep the result.
    module = torch.jit.optimize_for_inference(module)
    static_input = torch.zeros((1, 2048), device='cuda').long()
    static_out = torch.randn((1, 2048, 2048), device='cuda').half()
    timeit(lambda: module(static_input), n=20, first=False)
    # Warm up on a side stream before capture, as CUDA graphs require.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for i in range(3):
            output = module(torch.randint(0, 50000, (1, 2048), device='cuda').long())
    torch.cuda.current_stream().wait_stream(s)
    # Capture one forward pass; replays reuse the static input/output buffers.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, stream=s):
        static_out = module(static_input)
    real_inputs = [torch.randint(0, 50000, (1, 2048), device='cuda').long() for _ in range(100)]
    t = perf_counter()
    for data in real_inputs:
        #print(data[0, :20])
        # Copy each batch into the captured input buffer, then replay the graph;
        # the result appears in static_out.
        static_input.copy_(data)
        #timeit(lambda: g.replay(), n=100, first=True)
        g.replay()
        #print(static_out[0, 0, :20])
    torch.cuda.synchronize()
    print(f"{perf_counter() - t}s")
    #for data in real_inputs:
    #    print(model(data)[0, 0, :20])
@@ -7,14 +7,14 @@ dry = False
 config_obj = KubeConfig()
 config_obj.set_name(name)
-config_obj.set_gpu(gpu_name=GPU.A40, amount=1)
+config_obj.set_gpu(gpu_name=GPU.RTX_A5000, amount=1)
 config_obj.set_ram(16)
 config_obj.set_cpu(4)
 #config_obj.set_cpu_only()
 config_obj.dry_run(dry)
 config_obj.print_information()
-config_obj.create_deployment(overwrite=True)
-config_obj.create_service(overwrite=True)
+#config_obj.create_deployment(overwrite=True)
+#config_obj.create_service(overwrite=True)
 remote = config_obj.get_pyfra_remote()
 env1 = remote.env('noname', python_version=None)
@@ -25,5 +25,6 @@ models = {'6b': '/home/xuser/models/j6b_ckpt_14001', '20b': '/home/xuser/diffusi
 path = env1.path('/home/xuser/diffusionstorage/workspace/kuru/basedformer')
 env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
 env1.sh('pip install einops numpy')
+env1.sh('pip install tqdm')
 with always_rerun():
     path.sh(f'python3 test.py')
\ No newline at end of file