Commit b39e6d1b authored by novelailab

update

parent 40e90836
 from typing import Callable, KeysView
-from regex import D
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -156,7 +155,8 @@ class GPTJLayer(nn.Module):
         residual = x
         if act_ck:
             x = ck(self.ln_preattn, x)
-            attn_out, kv = ck(self.attn, x, kv=kv, cache=cache)
+            attn_out, kv = ck(self.attn, x, kv, cache)
+            #attn_out, kv = self.attn(x, kv=kv, cache=cache)
         else:
             x = self.ln_preattn(x)
...
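For context on the switch from `kv=kv, cache=cache` to positional arguments above: `torch.utils.checkpoint.checkpoint` forwards only positional arguments to the wrapped callable on the PyTorch version this repo pins (1.10); keyword arguments are reserved for checkpoint's own options such as `preserve_rng_state`. A minimal runnable sketch of the constraint, with a stand-in `attn` function:

```python
import torch
from torch.utils.checkpoint import checkpoint as ck

def attn(x, kv, cache):
    # stand-in for the real attention module; shapes are illustrative
    return x * 2, kv

x = torch.randn(1, 4, requires_grad=True)
out, kv = ck(attn, x, None, False)            # positional arguments pass through
#out, kv = ck(attn, x, kv=None, cache=False)  # rejected ("unexpected keyword arguments") on older PyTorch
```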
@@ -121,5 +121,9 @@ class BasedOptimizer:
             metadata = pickle.load(f)
         based_optimizer = cls(parameters, metadata, metadata["optimizer_name"])
-        based_optimizer.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
+        try:
+            based_optimizer.optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
+        except:
+            print("Couldn't load the optimizer, initializing the optimizer states. Honk!!!")
+            pass
         return based_optimizer
\ No newline at end of file
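The bare `except:` added above also swallows things like `KeyboardInterrupt` and CUDA out-of-memory errors during loading. A sketch of a narrower variant (the helper name is illustrative, not the project's API):

```python
import torch
from pathlib import Path

def load_opt_states(optimizer, path: Path):
    # hypothetical helper mirroring the try/except above, with explicit exceptions
    try:
        optimizer.load_state_dict(torch.load(path / "opt_states.pt"))
    except (FileNotFoundError, RuntimeError, KeyError) as e:
        print(f"Couldn't load the optimizer states ({e!r}), initializing fresh. Honk!!!")
```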
@@ -156,13 +156,13 @@ def func_multinomial(x):
     return torch.multinomial(x, 1)

 @torch.no_grad()
-def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
+def generate_greedy(forward, prompt_tokens, tokens_to_generate=50, hypernetwork=None):
     in_tokens = prompt_tokens
     padding_token = 50256
     generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
     kv = None
     for i in range(tokens_to_generate):
-        logits, kv = forward(in_tokens, cache=True, kv=kv)
+        logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
         logits = logits[:, -1, :] #get the last token in the seq
         # get the token before the padding_token in the seq
         logits = logits.argmax(dim=-1).unsqueeze(-1)
@@ -173,7 +173,7 @@ def generate_greedy(forward, prompt_tokens, tokens_to_generate=50):
     return generated

 @torch.no_grad()
-def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
+def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}], hypernetwork=None):
     in_tokens = prompt_tokens
     context = prompt_tokens
     generated = torch.zeros(prompt_tokens.shape[0], 0, dtype=torch.long).to(in_tokens.device)
@@ -190,7 +190,7 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
     }
     for _ in range(tokens_to_generate):
-        logits, kv = forward(in_tokens, cache=True, kv=kv)
+        logits, kv = forward(in_tokens, cache=True, kv=kv, hypernetwork=hypernetwork)
         logits = logits[:, -1, :] #get the last token in the seq
         logits = torch.log_softmax(logits, dim=-1)
         #can save one softmax here by not applying softmax for the first op,
@@ -222,7 +222,7 @@ def generate(forward, prompt_tokens, tokens_to_generate=50, ops_list=[{"temp": 0.9}]):
             logits = torch.cat(logit_list, dim=0)
         else:
-            torch.manual_seed(69)
+            #torch.manual_seed(69)
             logits = torch.multinomial(logits, 1)
         generated = torch.cat([generated, logits], dim=-1)
...
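Both samplers now simply thread the new `hypernetwork` argument through to `forward`, so any callable with the right signature works. A self-contained sketch of that contract (the dummy model and vocab size are illustrative, not the project's API):

```python
import torch

def dummy_forward(in_tokens, cache=False, kv=None, hypernetwork=None):
    # fake logits; a real model would actually apply the hypernetwork
    bsz, seq = in_tokens.shape
    return torch.randn(bsz, seq, 50400), kv

tokens = torch.randint(0, 50256, (1, 8))
logits, kv = dummy_forward(tokens, cache=True, kv=None, hypernetwork=None)
next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy pick, as in generate_greedy
```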
@@ -54,6 +54,11 @@ def no_init(loading_code):
 def count_parameters(model, only_trainable=False):
     return sum(p.numel() for p in model.parameters() if p.requires_grad or not only_trainable)

+def print_parameters(model, only_trainable=False):
+    params = sum(p.numel() for p in model.parameters() if p.requires_grad or not only_trainable)
+    params = params / 1e6
+    print(f"{params:.2f}M parameters")

 SPLIT_WEIGHTS_NAME = "m.pt"
 class SplitCheckpoint(MutableMapping):
     def __init__(self, name_or_path, device="cpu", subfolder=None):
...
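A quick check of what the new helper prints, inlined here so it runs standalone:

```python
import torch.nn as nn

def print_parameters(model, only_trainable=False):
    # copy of the helper added in the hunk above
    params = sum(p.numel() for p in model.parameters() if p.requires_grad or not only_trainable)
    print(f"{params / 1e6:.2f}M parameters")

print_parameters(nn.Linear(4096, 4096))  # 4096*4096 weights + 4096 biases -> "16.78M parameters"
```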
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp
import torch.optim as optim
from pathlib import Path
from torch.utils import data
from basedformer import optimizer, utils, lm_utils
import yaml
import sys
from tqdm import tqdm
import time
import wandb
import numpy as np
import os
from icecream import ic
# we need 250 batch size to train the small GPT.
train_config = {
    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
    #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
    #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/gptj-finetune",
    "do_save": False,
    "run_name": "gptj-finetune",
    "lr": 2e-5,
    "end_lr": 6e-5,
    "warmup_steps": 100,
    "anneal_steps": 10000,
    "bs": 1,
    "gas": 4,
    "seed": 69,
    "save_every": 500,
    "amp": True,
    "loss_scale": True,
}
torch.manual_seed(train_config["seed"])
bs = train_config["bs"]
gas = train_config["gas"]
Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
#model = GPTModel.gpt2_init(model_config).cuda().float()
model = lm_utils.load_from_path("pretrained/gptj-6b").cuda().float()
model.train()
ic("model loaded")
cp_list = sorted(os.listdir(train_config["save_path"]), key=lambda x: int(x.split("_")[-1]))
last_cp = Path(train_config["save_path"]) / cp_list[-1] if len(cp_list) > 0 else None
print(last_cp)
trainable_params = [
    model.vocab_embed,
    model.lm_head,
    model.layers[10],
]
for param in model.parameters():
    param.requires_grad = False

for name, p in model.named_parameters():
    if ("ln" in name or "vocab_embed" in name):
        p.requires_grad = True

for module in trainable_params:
    for p in module.parameters():
        p.requires_grad = True
'''
if last_cp:
    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
    model.load(model_config, last_cp / "lm", strict=True)
    opt = optimizer.BasedOptimizer.load(model.parameters(), last_cp / "opt")
else:
    opt = optimizer.BasedOptimizer(model.parameters(), train_config, "adamw")
'''
opt = optimizer.BasedOptimizer(model.layers[10].parameters(), train_config, "adamw")
# TODO: Add load, add evals, add FP16 AMP, and Data Parallel, outputting hidden states from the get_logits function.
print(opt.curr_step)
train_dataset = utils.FbDataset(2049, train_config["data_path"])
if last_cp:
    train_dataset.skip = opt.curr_step * bs * gas
train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )
wandb.init(project="basedformer-tests", name=train_config["run_name"], config={**train_config, **model.config})
if last_cp:
    curr_step = opt.curr_step
else:
    curr_step = 0
t = tqdm(train_loader, initial=curr_step)
scaler = torch.cuda.amp.GradScaler()
for input_ids, labels in t:
    timex = time.perf_counter()
    input_ids = input_ids.cuda()
    labels = labels.cuda()
    loss = 0
    for x in range(train_config["gas"]):
        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
            logits = model(input_ids[x*bs:(x+1)*bs, :2048].cuda(), act_ck=True)
            #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
            #roll down the sequence
            logits = logits.view(-1, logits.shape[-1])
            gas_labels = labels[x*bs:(x+1)*bs, :2048].contiguous()
            gas_labels = gas_labels.view(-1)
            gas_loss = F.cross_entropy(logits, gas_labels)

        if train_config["loss_scale"]:
            scaler.scale(gas_loss).backward()
        else:
            gas_loss.backward()

        loss += gas_loss.item()

    loss = loss / gas

    if train_config["loss_scale"]:
        scaler.unscale_(opt.optimizer)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

    if train_config["loss_scale"]:
        opt.step(scaler=scaler)
    else:
        opt.step()

    if train_config["loss_scale"]:
        scaler.update()

    opt.zero_grad()

    sec_per_step = (time.perf_counter() - timex)
    step_per_sec = (1. / sec_per_step)
    tokens_per_sec = (step_per_sec * 2048) * bs * gas
    t.set_description(f"{step_per_sec:.2f} steps/s, {sec_per_step:.2f}s/step, {tokens_per_sec:.2f}tokens/s, loss={loss:.4f}")
    wandb.log(
        {
            "train/loss": loss,
            "train/tokens_per_sec": tokens_per_sec,
            "train/sec_per_step": sec_per_step,
            "train/step_per_sec": step_per_sec,
            "train/lr": opt.curr_lr,
            "train/loss_scale": scaler.get_scale()
        },
        step=curr_step)

    if train_config["do_save"]:
        if curr_step % train_config["save_every"] == 0:
            save_folder = Path(train_config["save_path"]) / f"step_{curr_step}"
            save_folder.mkdir(parents=True, exist_ok=True)
            model.save(save_folder / "lm")
            opt.save(save_folder / "opt")
            print(f"Saved model at step {curr_step}")

    curr_step += 1
\ No newline at end of file
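The loop above interleaves `GradScaler` with the custom `BasedOptimizer`, but the underlying ordering is the standard AMP recipe. A minimal sketch with a plain torch optimizer (requires a CUDA device; `opt.step(scaler=scaler)` in the script presumably wraps the same `scaler.step` call):

```python
import torch

model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=2e-5)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 8, device="cuda")
with torch.cuda.amp.autocast(dtype=torch.float16):
    loss = model(x).square().mean()
scaler.scale(loss).backward()                          # 1. scale the loss, then backward
scaler.unscale_(opt)                                   # 2. unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
scaler.step(opt)                                       # 3. skips the step on inf/nan grads
scaler.update()                                        # 4. adjust the loss scale
opt.zero_grad()
```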
@@ -12,9 +12,10 @@ import wandb
 import numpy as np
 from torch.utils.checkpoint import checkpoint as ck
 from math import log2, ceil
-from basedformer import gptj, optimizer
+from basedformer import gptj, optimizer, lm_utils
 from basedformer.utils import *
 import glob
+from icecream import ic

 def _init_weights(module):
     if isinstance(module, nn.Linear):
@@ -146,27 +147,20 @@ class HyperNetworkSingle(nn.Module):
         x = x.mul(torch.sigmoid(x))
         return x.bfloat16()

-model_config = {
-    "n_layer": 28,
-    "n_head": 16,
-    "hidden_dim": 4096,
-    "vocab_dim": 50400,
-    "eps": 1e-5,
-}

 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
+    #"data_path": "/home/xuser/diffusionstorage/datasets/enwik9-gpt2-2049.map",
+    "data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v7_infilling.map",
     #"data_path": "/home/xuser/diffusionstorage/datasets/OWT2-gpt2-full.map",
     #"data_path": "/home/xuser/diffusionstorage/datasets/sigurd/map/sigurd_v5_fs_2049.map",
-    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs16-save",
+    "save_path": "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling",
     "do_save": True,
-    "run_name": "gpt-j-enwik9-6b-postln-bf16-2e-4-4bsz-every5layersavetest",
+    "run_name": "gpt-j-6b-2e-4-infilling",
     "lr": 2e-4,
     "end_lr": 2e-4,
     "warmup_steps": 50,
-    "bs": 1,
-    "gas": 4,
+    "bs": 4,
+    "gas": 1,
     "seed": 69,
     "save_every": 300,
     "amp": False,
@@ -179,7 +173,7 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)

 #model = GPTModel.gpt2_init(model_config).cuda().float()
-model = gptj.load_gpt_j().lm.cuda().bfloat16()
+model = lm_utils.load_from_path("pretrained/gptj-6b").cuda().bfloat16()
 for param in model.parameters():
     param.requires_grad = False
@@ -187,7 +181,7 @@ for name, p in model.named_parameters():
     if ("ln" in name or "vocab_embed" in name):
         p.requires_grad = True

-hypernetwork = HyperNetworkSingle(model_config).cuda().float()
+hypernetwork = HyperNetworkSingle(model.config).cuda().float()
 #hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
 #hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(2)])
 for param in hypernetwork.parameters():
@@ -212,7 +206,7 @@ if last_cp:
     train_dataset.skip = opt.curr_step * bs * gas

 train_loader = data.DataLoader(train_dataset, batch_size=bs*gas, shuffle=False, num_workers=0, )

-wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model_config})
+wandb.init(project="hypernetwork-tests", name=train_config["run_name"], config={**train_config, **model.config})
 if last_cp:
     curr_step = opt.curr_step
@@ -279,7 +273,5 @@ for input_ids, labels in t:
             torch.save(hypernetwork.state_dict(), save_folder / "hyper.pt")
             opt.save(save_folder / "opt")
             print(f"Saved model at step {curr_step}")
-            sys.exit(0)
     curr_step += 1
\ No newline at end of file
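The recurring change in this commit: model loading goes through `lm_utils.load_from_path`, and the hand-maintained `model_config` dict is replaced by the loaded model's own `config`. In short (paths as in the scripts above; `HyperNetworkSingle` is the class defined in this script):

```python
from basedformer import lm_utils

model = lm_utils.load_from_path("pretrained/gptj-6b").cuda().bfloat16()
# model.config now supplies what the deleted model_config dict used to hard-code
hypernetwork = HyperNetworkSingle(model.config).cuda().float()
```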
@@ -51,9 +51,9 @@ path = env1.path('/home/xuser/diffusionstorage/workspace/kuru/basedformer')
 #env1.sh('pip3 install git+https://github.com/pytorch/fairseq')
 env1.sh('pip3 install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
-with always_rerun():
-    env1.sh('pip uninstall transformers')
-    env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
+#with always_rerun():
+    #env1.sh('pip uninstall transformers')
+    #env1.sh('pip install /home/xuser/diffusionstorage/workspace/finetune/pokepls/transformers-repo')
 with always_rerun():
...
@@ -7,14 +7,14 @@ import argparse
 # run command: default
 # kill

-name = 'pyfra-basedformer'
+name = 'basedformer'
 dry = False
 bash = False

 config_obj = KubeConfig()
 config_obj.set_name(name)
-config_obj.set_gpu(gpu_name=GPU.RTX_A6000, amount=1)
-config_obj.set_ram(64)
+config_obj.set_gpu(gpu_name=GPU.A100_NVLINK, amount=1)
+config_obj.set_ram(24)
 config_obj.set_cpu(4)
 config_obj.dry_run(dry)
 config_obj.print_information()
@@ -27,7 +27,7 @@ env1 = remote.env('noname', python_version=None)
 path = env1.path('/home/xuser/diffusionstorage/workspace/kuru/basedformer')

-if False:
+if True:
     env1.sh('pip install /home/xuser/hugessd/pytorch/torch-1.10.1+cu113-cp38-cp38-linux_x86_64.whl')
     env1.sh('pip install einops numpy')
     env1.sh('pip install tqdm')
@@ -37,7 +37,7 @@ if False:
     env1.sh('pip3 install dotmap icecream')
     path.sh("pip3 install --editable .")

 with always_rerun():
-    if True:
+    if False:
         #env1.sh('pip3 install transformers')
         #path.sh('pip3 install --editable ../lm-evaluation-harness/.')
         #env1.sh('pip3 install pytest')
...
from basedformer import lm_utils as lmu
from basedformer.models import hypernet
from basedformer import sampling
import os
from pathlib import Path
from basedformer.utils import *
from transformers import AutoTokenizer
from icecream import ic
import time
import sys
def main():
    #save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-enwik9-bs4-2e-4-catchup"
    save_path = "/home/xuser/diffusionstorage/workspace/kuru/basedformer/models/hypernetwork-gptj-2048-infilling"

    cp_list = sorted(os.listdir(save_path), key=lambda x: int(x.split("_")[-1]))
    last_cp = Path(save_path) / cp_list[-1] if len(cp_list) > 0 else None
    print(last_cp)

    bsz = 1
    gen_len = 400
    #torch.manual_seed(69)
    tokenizer = AutoTokenizer.from_pretrained('gpt2')

    mask = "████████"
    prompt = "You hated the elves enough that if you seen one of them in the forest you would just slice their throats."
    prompt = """'''Kurumuz''' is the founder of tech company [["""
    promptnomask = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window,{mask} holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""
    prompt = f"""The room was lit now by a dozen candles. The door had been locked, and the windows barred; but there were still some faint glimmers of moonlight on the floor outside. For a moment the figure stood motionless in its doorway to look about it with an air of keen and nervous expectancy. Then he came forward into the chamber and moved{mask}, where he remained standing for an instant upon his toes like one listening intently before starting to rummage among the books and papers. He selected a large volume from among them and turned back to the window, holding it between himself and the rest of the room until he could feel the warm breath of the night creeping through the curtains.{mask}"""

    tokens = tokenizer.encode(promptnomask)
    print(tokens)
    print("Prompt:")
    for x in range(len(tokens)):
        print(tokenizer.decode([tokens[x]]), end=" | ")

    print("\n Generation:")
    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
    tokens = [tokens] * bsz
    #tokens = torch.cat([tokens, tokens], dim=0)
    tokens = torch.cat(tokens, dim=0)

    t = time.perf_counter()
    model = lmu.load_from_path('pretrained/gptj-6b').cuda().bfloat16().eval()
    hypernetwork = hypernet.HyperNetworkSingle(model.config).cuda().float()
    print("Loading from step {}".format(cp_list[-1].split("_")[-1]))
    hypernetwork.load_state_dict(torch.load(last_cp / "hyper.pt"))
    ic(time.perf_counter() - t)

    rep_pen = {
        "penalty": 5,
    }
    ops = {
        "rep_pen": rep_pen,
        "tfs": 0.86,
        "temp": 0.8,
    }
    ops_list = [ops] * bsz
    #tokens_generated = sampling.generate(model.forward, tokens, gen_len, ops_list=ops_list, hypernetwork=hypernetwork)
    tokens_generated = sampling.generate_greedy(model.forward, tokens, gen_len, hypernetwork=hypernetwork)
    #tokens_generated_batched = generate_real_batched(model.forward, tokens, gen_len, ops=ops)
    print(tokens_generated.shape)
    tokens_generated[tokens_generated == 48585] = 35625
    ic(prompt)
    tokens_generated = tokenizer.batch_decode(tokens_generated.cpu().numpy())
    for gen in tokens_generated:
        print(str(gen.split("*****")[0]))
        print("++++++++++++")
        print(str(gen.split("*****")[1]))
        print("===========================================================")

    #ic(tokenizer.batch_decode(tokens_generated_batched.cpu().numpy()))
    #timeit(lambda: generate(model.forward, tokens, 30, ops_list=ops_list), n=30)
    #timeit(lambda: generate_real_batched(model.forward, tokens, 30, ops=ops), n=30)

if __name__ == "__main__":
    main()
\ No newline at end of file
from basedformer import models, utils
import torch
config = {
    "n_layer": 6,
    "n_head": 8,
    "hidden_dim": 4096,
}
#init param matched GPT
gpt = models.gptj.GPTJModel(config).cuda().float()
utils.print_parameters(gpt)
#init param matched LSTM
lstm = torch.nn.LSTM(batch_first=True, input_size=4096, hidden_size=4096, num_layers=12).cuda().float()
utils.print_parameters(lstm)
x = torch.randint(0, 50256, (1, 1)).long().cuda()
y = torch.rand(1, 1, 4096).cuda().float()
with torch.no_grad():
    print("GPT:")
    utils.timeit(func=lambda: gpt(x), r=10, n=10)
    print("LSTM:")
    utils.timeit(func=lambda: lstm(y), r=10, n=10)
\ No newline at end of file
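Rough arithmetic behind the "param matched" layer counts above (back-of-envelope, ignoring biases, layer norms, and embeddings): a GPT-J block carries about 12·d² weights (4·d² for attention q/k/v/out plus 8·d² for the 4d-wide MLP), while an LSTM layer carries about 8·d² (four gates, each with d×d input and recurrent matrices), so the two stacks land in the same ballpark rather than matching exactly; `utils.print_parameters` in the script reports the true counts:

```python
d = 4096
gptj_block = 4 * d * d + 8 * d * d   # attention (q, k, v, out) + MLP (up to 4d, back down)
lstm_layer = 4 * (d * d + d * d)     # 4 gates x (input weights + recurrent weights)
print(f"6 GPT-J blocks: {6 * gptj_block / 1e6:.0f}M")   # ~1208M
print(f"12 LSTM layers: {12 * lstm_layer / 1e6:.0f}M")  # ~1611M
```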