Commit cc02ad48 authored by Wes Brown

Some cleanup, device agnosticism.

parent 91470a5a
@@ -4,11 +4,8 @@ import mmap
 import pickle
 import concurrent
 from torch.utils import data
-from simplejpeg import decode_jpeg
-import simplejpeg
 import pickle
 from pathlib import Path
-from PIL import Image
 from tqdm import tqdm
 from concurrent.futures import as_completed
 import requests
@@ -54,6 +51,9 @@ class ShardedDataset(data.Dataset):
 class ShardedImageDataset(data.Dataset):
     def __init__(self, dataset_path: str, index_path: str, metadata_path=None, threads=None, inner_transform=None,
                  outer_transform=None, skip=0, bsz=256, world_size=1, local_rank=0, global_rank=0, device="cpu"):
+        from simplejpeg import decode_jpeg
+        import simplejpeg
+        from PIL import Image
         self.skip = skip
         self.threads = threads
......
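This hunk moves the simplejpeg and PIL imports from module scope into ShardedImageDataset.__init__, so importing the module no longer requires the image-decoding stack to be installed. A minimal sketch of that lazy-import pattern, with hypothetical names (ImageDatasetSketch is not the repo's class):

```python
class ImageDatasetSketch:
    """Illustrates deferring optional heavy imports to instantiation."""

    def __init__(self, dataset_path: str):
        # Importing here means `import this_module` succeeds even on
        # machines without the image libraries; only constructing the
        # dataset actually needs them.
        import simplejpeg
        from PIL import Image
        self._decode_jpeg = simplejpeg.decode_jpeg
        self._Image = Image
        self.dataset_path = dataset_path

    def decode(self, raw: bytes):
        # simplejpeg.decode_jpeg turns a JPEG buffer into an HxWxC array.
        return self._decode_jpeg(raw)
```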
@@ -87,11 +87,13 @@ def load_from_path(config_folder=None, strict=False):
     model = _load_dict_model(model_class, model_config, model_path, strict=strict)
     return model
 
-def _load_dict_model(model_class, config, path=None, state_dict=None, strict=False):
+def _load_dict_model(model_class, config, path=None, state_dict=None,
+                     strict=False, device="cuda"):
     # I am kinda sad that we will not have a load function in lm object itself.
     # might be better to add load functions -- actually nope.
     if path:
-        state_dict = utils.SplitCheckpoint(path, device="cuda")
+        state_dict = utils.SplitCheckpoint(path, device=device)
+        state_dict.device = device
     model= utils.no_init(lambda: model_class(config))
     model.load_state_dict(state_dict, strict=strict)
......
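_load_dict_model now accepts a device argument instead of hardcoding "cuda" and threads it into utils.SplitCheckpoint. The same idea with plain torch.load, as a hedged sketch (SplitCheckpoint is the repo's sharded-checkpoint class; map_location is the vanilla single-file equivalent, and the names below are illustrative):

```python
import torch

def load_dict_model_sketch(model_class, config, path, strict=False, device="cuda"):
    # map_location materializes checkpoint tensors directly on the
    # requested device, so one code path serves "cpu" and "cuda" alike.
    state_dict = torch.load(path, map_location=device)
    model = model_class(config)
    model.load_state_dict(state_dict, strict=strict)
    return model.to(device)
```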
@@ -20,6 +20,12 @@ from basedformer import sampling
 from icecream import ic
 from termcolor import colored
 
+gpu = "cuda"
+amp = torch.cuda.amp
+if gpu != "cuda":
+    amp = torch.amp
+scaler = torch.cuda.amp.GradScaler()
+
 def _init_weights(module):
     if isinstance(module, nn.Linear):
         module.weight.data.normal_(mean=0.0, std=0.02)
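The training script now routes autocast through an `amp` alias: torch.cuda.amp on CUDA, torch.amp otherwise. The two modules differ slightly: torch.amp.autocast requires an explicit device_type argument, which torch.cuda.amp.autocast does not take. A small wrapper (hypothetical, not in the diff) smooths that over:

```python
import torch

gpu = "cuda" if torch.cuda.is_available() else "cpu"

def autocast_ctx(enabled: bool, dtype=torch.float16):
    # torch.cuda.amp.autocast takes no device argument, while the
    # generic torch.amp.autocast needs device_type, so this wrapper
    # keeps the call site identical for both backends.
    if gpu == "cuda":
        return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype)
    return torch.amp.autocast(device_type=gpu, enabled=enabled, dtype=dtype)
```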
@@ -158,7 +164,7 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
     #print("Prompt:")
     #for x in range(len(tokens)):
     #    print(tokenizer.decode([tokens[x]]), end=" | ")
-    tokens = torch.LongTensor(tokens).unsqueeze(0).cuda()
+    tokens = torch.LongTensor(tokens).unsqueeze(0).to(gpu)
     tokens = [tokens] * bsz
     tokens = torch.cat(tokens, dim=0)
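The sample path builds a (bsz, seq_len) batch by repeating one prompt and moving it to the device named in `gpu`. A one-call equivalent of the `[tokens] * bsz` plus torch.cat pattern (tile_prompt is a hypothetical helper):

```python
import torch

def tile_prompt(token_ids, bsz, device="cpu"):
    # (1, seq_len) -> (bsz, seq_len); same result as the script's
    # torch.cat([tokens] * bsz, dim=0).
    tokens = torch.LongTensor(token_ids).unsqueeze(0).to(device)
    return tokens.repeat(bsz, 1)
```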
@@ -190,9 +196,9 @@ def sample(prompt, n_tokens, bsz, hypernetwork=None):
 # we need 250 batch size to train the small GPT.
 train_config = {
-    "data_path": "/home/xuser/nvme1/dataset/enwik9-gpt2-2049.map",
-    "save_path": "/home/xuser/models/enwik9-sigurdv4-hypernet2",
-    "lm_path": "/home/xuser/nvme1/pretrained/sigurdv4",
+    "data_path": "dataset/enwik9-gpt2-2049.map",
+    "save_path": "models/enwik9-sigurdv4-hypernet2",
+    "lm_path": "pretrained/sigurdv4",
     "optimizer": "adamw",
     "masked_softmax_fusion": False,
     "do_save": True,
@@ -214,8 +220,8 @@ gas = train_config["gas"]
 Path(train_config["save_path"]).mkdir(parents=True, exist_ok=True)
 
 #model = GPTModel.gpt2_init(model_config).cuda().float()
-model = lm_utils.load_from_path("/home/xuser/nvme1/pretrained/sigurdv4").cuda().bfloat16()
+model = lm_utils.load_from_path("pretrained/sigurdv4").to(gpu).bfloat16()
 for param in model.parameters():
     param.requires_grad = False
@@ -223,9 +229,7 @@ for name, p in model.named_parameters():
     if ("ln" in name or "vocab_embed" in name):
         p.requires_grad = True
 
-hypernetwork = HyperNetworkSingle(model.config).cuda().float()
-#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(model_config["n_layer"] // 5)])
-#hypernetwork = nn.ModuleList([HyperNetworkSingle(model_config).cuda().float() for _ in range(2)])
+hypernetwork = HyperNetworkSingle(model.config).to(gpu).float()
 
 for param in hypernetwork.parameters():
     param.requires_grad = True
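The surrounding code freezes the base model, re-enables gradients only for layer norms and the vocab embedding, and trains the hypernetwork fully. That selective-freezing pattern as a hypothetical helper:

```python
import torch.nn as nn

def freeze_except(model: nn.Module, trainable=("ln", "vocab_embed")):
    # Freeze everything, then re-enable gradients for parameters whose
    # names contain one of the given substrings -- mirroring the
    # script's treatment of layer norms and the vocab embedding.
    for p in model.parameters():
        p.requires_grad = False
    for name, p in model.named_parameters():
        if any(s in name for s in trainable):
            p.requires_grad = True
```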
@@ -257,17 +261,17 @@ else:
 t = tqdm(train_loader, initial=curr_step)
 scaler = torch.cuda.amp.GradScaler()
 #sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
 for input_ids, labels in t:
     timex = time.perf_counter()
-    input_ids = input_ids.cuda()
-    labels = labels.cuda()
+    input_ids = input_ids.to(gpu)
+    labels = labels.to(gpu)
     loss = 0
     for x in range(train_config["gas"]):
-        with torch.cuda.amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
-            logits, _ = model(input_ids[x*bs:(x+1)*bs, :].cuda(), hypernetwork=hypernetwork, act_ck=True)
+        with amp.autocast(enabled=train_config["amp"], dtype=torch.float16):
+            logits, _ = model(input_ids[x*bs:(x+1)*bs, :].to(gpu), hypernetwork=hypernetwork, act_ck=True)
             #print(tokenizer.decode(input_ids[x*bs:(x+1)*bs, :][0]))
             logits = logits.view(-1, logits.shape[-1])
             gas_labels = labels[x*bs:(x+1)*bs, :].contiguous()
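The inner loop slices each large batch into `gas` micro-batches, runs each under autocast, and accumulates scaled gradients. A condensed, hypothetical version of that pattern (the real model also takes hypernetwork and act_ck arguments and returns a tuple):

```python
import torch
import torch.nn.functional as F

def accumulation_step(model, input_ids, labels, optimizer, scaler,
                      gas, bs, device="cuda"):
    # Gradient accumulation: each micro-batch contributes loss / gas,
    # so the summed gradients match one full-batch step.
    optimizer.zero_grad()
    total = 0.0
    for x in range(gas):
        mb = input_ids[x * bs:(x + 1) * bs].to(device)
        mb_labels = labels[x * bs:(x + 1) * bs].contiguous().to(device)
        with torch.autocast(device_type=device, dtype=torch.float16):
            logits = model(mb)
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]),
                                   mb_labels.view(-1)) / gas
        scaler.scale(loss).backward()
        total += loss.item()
    scaler.step(optimizer)
    scaler.update()
    return total
```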
@@ -317,6 +321,7 @@ for input_ids, labels in t:
         print(f"Saved model at step {curr_step}")
     if curr_step % train_config["eval_every"] == 0 and curr_step != 0:
         print("")
+        sample("<|endoftext|>", 500, 3, hypernetwork=hypernetwork)
     curr_step += 1
\ No newline at end of file