Commit 89ceb109 authored by novelailab

Don't use no_init when we need to init

parent 9f01112b
from main import *
import time
from time import perf_counter, perf_counter_ns
import numpy as np
from tqdm import tqdm
from contextlib import contextmanager
#replicating timeit magic function of ipython
def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True):
    """Time ``func`` in the spirit of IPython's ``%timeit`` and print a summary.

    ``func`` is called ``n`` times in each of ``r`` runs and every call is
    timed with ``perf_counter_ns``. The run with the lowest mean is reported
    as "mean ± std. dev.", scaled to ns, μs, ms or s depending on magnitude.

    Args:
        func: Zero-argument callable to benchmark.
        r: Number of runs; must be at least 1.
        n: Number of timed calls per run; must be at least 1.
        quiet: When true, nothing is printed.
        function: Optional object whose ``__name__`` labels the output
            instead of ``func``'s own name (useful when ``func`` is a lambda).
        do_tqdm: When true, the run loop is wrapped in ``tqdm``.
        first: When false, the first sample of each run is discarded before
            the statistics are computed. If a run has only one sample it is
            kept, since dropping it would leave nothing to average.

    Returns:
        None. The result is only printed.

    Raises:
        ValueError: If ``r`` or ``n`` is smaller than 1.
    """
    if r < 1 or n < 1:
        raise ValueError("r and n must both be at least 1")

    # Resolve the display name locally instead of overwriting func.__name__,
    # so the caller's callable is left untouched.
    label = getattr(function, "__name__", None) if function else None
    if label is None:
        label = getattr(func, "__name__", repr(func))

    stats = np.empty([2, r])  # row 0 = mean, row 1 = std
    runs = tqdm(range(r)) if do_tqdm else range(r)
    for i in runs:
        samples = np.empty(n)
        for k in range(n):
            start = perf_counter_ns()
            func()
            samples[k] = perf_counter_ns() - start
        if not first and n > 1:
            # Drop the first sample; guarded so a single-sample run does not
            # end up empty (which would yield NaN statistics).
            samples = samples[1:]
        stats[0, i] = samples.mean()
        stats[1, i] = samples.std()

    if quiet:
        return

    # Report the run with the lowest mean.
    mean, std = stats[:, np.argmin(stats[0])]

    # Timings are in nanoseconds; pick the largest unit that keeps mean >= 1.
    if mean >= 1e9:
        scale, unit, digits = 1e-9, 's', 4
    elif mean >= 1e6:
        scale, unit, digits = 1e-6, 'ms', 2
    elif mean >= 1e3:
        scale, unit, digits = 1e-3, 'μs', 2
    else:
        scale, unit, digits = 1, 'ns', 0

    mean = mean * scale
    std = std * scale
    print(
        f"{label}: {mean:.{digits}f}{unit} ± {std:.{digits}f}{unit} per loop "
        f"(mean ± std. dev. of {r} runs, {n} loops each)"
    )
def test_thing(graph, input):
    """Copy ``input`` into the module-level ``static_input`` and replay ``graph``.

    ``torch.cuda.synchronize()`` is called before the copy and again after
    the replay, so a caller timing this function measures the whole
    copy-and-replay rather than just the launch.
    """
    sync = torch.cuda.synchronize

    sync()
    # static_input is the module-level tensor the graph was captured with.
    static_input.copy_(input)
    graph.replay()
    sync()
# Benchmark: eager forward pass vs. a traced module replayed via a CUDA graph.
with torch.no_grad():
    model = init_1_3b().cuda().half()
    shape = (1, 512)
    x = torch.zeros(shape).cuda().long()
    print(shape)

    print("PyTorch Eager")
    # NOTE(review): model(x) is timed without an explicit
    # torch.cuda.synchronize(), unlike test_thing below — confirm this is the
    # intended comparison.
    timeit(r=1, n=100, func=lambda: model(x), do_tqdm=False, first=False)

    print("PyTorch CUDAGraph+JIT")
    module = torch.jit.trace(model, torch.zeros(shape).long().cuda())
    # optimize_for_inference returns the optimized module; the result must be
    # rebound, otherwise the un-optimized traced module is what gets measured.
    module = torch.jit.optimize_for_inference(module)

    static_input = torch.randint(0, 50256, shape, device='cuda')
    fake_inputs = [torch.randint(0, 50256, shape, device="cuda") for _ in range(100)]
    # real_inputs is not used anywhere in this section.
    real_inputs = [torch.randint(0, 50256, shape, device="cuda") for _ in range(100)]

    # Run the module on the throwaway inputs on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for y in fake_inputs:
            static_output = module(y)
    torch.cuda.current_stream().wait_stream(s)

    # Capture one forward pass on static_input into a CUDA graph.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = module(static_input)

    timeit(func=lambda: test_thing(g, static_input), r=1, n=100, do_tqdm=False, first=False)
......@@ -59,12 +59,12 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
with torch.no_grad():
model = load_gpt_j().cuda().half()
x = torch.zeros(1, 1024).cuda().long()
model = init_6b().cuda().half()
x = torch.zeros(50, 1).cuda().long()
print(model(x).shape)
print("PyTorch Eager")
timeit(r=1, n=100, func=lambda: model(x), do_tqdm=False, first=False)
module = torch.jit.trace(model, torch.zeros((1, 1024)).long().cuda())
module = torch.jit.trace(model, torch.zeros((50, 1)).long().cuda())
torch.jit.optimize_for_inference(module)
print("PyTorch JIT")
timeit(r=1, n=100, func=lambda: module(x), do_tqdm=False, first=False)
\ No newline at end of file
......@@ -10,13 +10,6 @@ import os
from pathlib import Path
import math
def defaults():
# Easily accessible defaults
D_LAYER = GPTLayer
D_ATTN = SelfAttention
D_FF = FeedForward
D_ACT = gelu_new
def no_init(loading_code):
def dummy(self):
return
......@@ -278,7 +271,7 @@ class GPTModel(nn.Module):
@classmethod
def init(cls, config):
model = no_init(lambda: cls(**config))
model = cls(**config)
return model
def save(self, path):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment