Commit 09e4a19b authored by novelailab

refactor kinda done, yeet main

parent 840dd7f4
-from main import *
+from lm_arch import lm_class
+from lm_arch.utils import *
 import time
+import torch
 from time import perf_counter, perf_counter_ns
 import numpy as np
 from tqdm import tqdm
@@ -65,7 +67,8 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
 with torch.no_grad():
-    based_model = load_gpt_j().cuda().half().eval()
+    based_model = lm_class.load_gpt_j().cuda().half().eval()
+    based_model = based_model.lm
     print("Loaded based model")
     hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
     print("Loaded hf model")
...
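For context on the benchmark change above: load_gpt_j (added to lm_class later in this commit) returns a BaseLM wrapper rather than a bare torch module, which is why the script now unwraps it with .lm before timing. A minimal sketch of the intended call pattern, assuming the default checkpoint path:

import torch
from lm_arch import lm_class

with torch.no_grad():
    # BaseLM is itself an nn.Module, so .cuda()/.half()/.eval() cascade to the inner model
    based_model = lm_class.load_gpt_j().cuda().half().eval()
    # unwrap to the raw GPTJModel so the benchmark times the bare transformer
    based_model = based_model.lm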
@@ -79,7 +79,7 @@ def _attn(query, key, value, causal_mask, masked_bias,
 class SelfAttention(nn.Module):
     # Code copied from HF, might want to sanity check later.
     def __init__(self, hidden_dim, n_head, device, dtype):
-        super().__init__(self)
+        nn.Module.__init__(self)
         max_positions = 2049
         bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8, requires_grad=False)).view(
             1, 1, max_positions, max_positions).bool()
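The recurring one-line fix in this file replaces super().__init__(self) with nn.Module.__init__(self). The old form passes self twice (once bound by super(), once explicitly), and nn.Module.__init__ takes no extra positional arguments, so construction raises a TypeError. A minimal illustration; the zero-argument super().__init__() would be the more idiomatic equivalent:

import torch.nn as nn

class Broken(nn.Module):
    def __init__(self):
        super().__init__(self)   # TypeError: __init__() takes 1 positional argument but 2 were given

class Fixed(nn.Module):
    def __init__(self):
        nn.Module.__init__(self)  # works: explicit unbound call, self passed exactly once
        # super().__init__()      # equivalent and more idiomatic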
@@ -143,7 +143,7 @@ class SelfAttention(nn.Module):
 class FeedForward(nn.Module):
     def __init__(self, dim, hidden_dim, activation, device, dtype):
-        super().__init__(self)
+        nn.Module.__init__(self)
         self.ff1 = nn.Linear(dim, hidden_dim, device=device, dtype=dtype)
         self.ff2 = nn.Linear(hidden_dim, dim, device=device, dtype=dtype)
         self.activation = activation
@@ -159,7 +159,7 @@ class FeedForward(nn.Module):
 class GPTJLayer(nn.Module):
     def __init__(self, attn, ff, hidden_dim, n_head, eps, activation, device, dtype):
-        super().__init__(self)
+        nn.Module.__init__(self)
         self.hidden_dim = hidden_dim
         self.ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
         self.ff = ff(dim=hidden_dim, hidden_dim=hidden_dim*4, activation=activation, device=device, dtype=dtype)
@@ -204,8 +204,8 @@ class GPTJLayer(nn.Module):
         return x

 class GPTJModel(nn.Module):
-    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation, Layer, device, dtype):
-        super().__init__(self)
+    def __init__(self, hidden_dim, n_layer, n_head, vocab_dim, eps, activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16):
+        nn.Module.__init__(self)
         self.n_layer = n_layer
         self.hidden_dim = hidden_dim
         self.vocab_embed = nn.Embedding(vocab_dim, self.hidden_dim, device=device, dtype=dtype)
@@ -225,4 +225,4 @@ class GPTJModel(nn.Module):
         for layer_id, layer in enumerate(self.layers):
             x = layer(x, layer_id=layer_id, hypernetwork=hypernetwork, act_ck=act_ck)
         x = self.ln_final(x)
-        return x
\ No newline at end of file
+        return x
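The new GPTJModel signature bakes in GPT-J defaults (activation=gelu_new, Layer=GPTJLayer, device="cuda", dtype=torch.float16), so callers only supply the size hyperparameters. A sketch of what instantiation looks like after this change, using the 6B shapes that load_gpt_j passes further down in this commit:

model = GPTJModel(
    hidden_dim=4096,   # embedding width
    n_layer=28,        # number of transformer blocks
    n_head=16,         # attention heads per block
    vocab_dim=50400,   # GPT-J vocabulary size
    eps=1e-5,          # LayerNorm epsilon
)  # activation, Layer, device, and dtype fall back to the new defaults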
@@ -2,11 +2,13 @@ from lm_arch import utils
 import math
 import torch
 from torch import nn
+from lm_arch import gptj
 import os

 #Having common BaseLM functionality in this class instead of the torch LM itself makes sense.
 class BaseLM(nn.Module):
     def __init__(self, config=None, lm=None):
+        nn.Module.__init__(self)
         self.config = config
         self.lm = lm
@@ -44,13 +46,13 @@ class BaseLM(nn.Module):
         return model

     @classmethod
-    def load(cls, config, path=None, state_dict=None, strict=False):
+    def load(cls, model_class, config, path=None, state_dict=None, strict=False):
         # I am kinda sad that we will not have a load function in lm object itself.
         # might be better to add load functions to that as well but not sure.
         if path:
             state_dict = utils.SplitCheckpoint(path, device="cuda")
-        lm = config.model_class(**config)
+        lm = model_class(**config)
         model = cls(config, lm)
         model.lm.load_state_dict(state_dict, strict=strict)
         return model
@@ -59,11 +61,25 @@ class BaseLM(nn.Module):
         if self.lm is None:
             print("No LM object to save. Please first init a model.")
             return
         try: os.mkdir(path)
         except: pass
         checkpoint = {}
         for i, x in enumerate(self.lm.state_dict().items()):
             checkpoint[x[0]] = f"{path}/b{i}.pt"
             torch.save(x[1], f"{path}/b{i}.pt")
-        torch.save(checkpoint, f"{path}/m.pt")
\ No newline at end of file
+        torch.save(checkpoint, f"{path}/m.pt")
+
+def load_gpt_j(path="models/6b", state_dict=None):
+    config = {
+        "n_layer": 28,
+        "n_head": 16,
+        "hidden_dim": 4096,
+        "vocab_dim": 50400,
+        "eps": 1e-5,
+        "activation": gptj.gelu_new,
+        "Layer": gptj.GPTJLayer
+    }
+    model = BaseLM.load(gptj.GPTJModel, config, path, state_dict)
+    return model
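Taken together, save and the new load define a simple split-checkpoint layout: each tensor lives in its own b{i}.pt file, and m.pt maps state-dict keys to those files. utils.SplitCheckpoint is presumably the reader for this layout; read_split_checkpoint below is only an illustrative stand-in, not the actual implementation:

import torch

def read_split_checkpoint(path, device="cuda"):
    # m.pt maps each state-dict key to the file that holds its tensor
    index = torch.load(f"{path}/m.pt")
    return {key: torch.load(fname, map_location=device) for key, fname in index.items()}

# Usage mirroring load_gpt_j:
# state_dict = read_split_checkpoint("models/6b")
# model.lm.load_state_dict(state_dict, strict=False)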
+# run bash: -b
+# run command: default
+# kill: -k name
+# start pod: -s name
+# gpu: -g --gpu
+# amount: -n
+# cpu cores: -c
+# amount of ram: -r
+# image: -i
 from novelutils.novelfra import *
 from pyfra import *
 import argparse
 import sys

 parser = argparse.ArgumentParser(description='Novelfra utility tool for launching pods and deployments on kubernetes with pyfra.')
 parser.add_argument('name', nargs="?", type=str, help='Deployment name')
-# Make the default the last one we used.
+#TODO: Make the default the last one we used.
 parser.add_argument('--service', action="store_true", help="""Create a service with the deployment. If a service is not
 created you won't be able to access the pod outside from the kube network.""")
 parser.add_argument('-b', '--bash', action="store_true", help='Run bash instead of python3.')
...
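Going by the flag summary added at the top of this file, invocations look roughly like the following. The script name and argument values are hypothetical, since the commit only shows the comment block and a few of the parser.add_argument calls:

# python3 novelfra.py mypod --service -g --gpu -n 1 -c 4 -r 16 -i myimage   # launch a pod plus a service
# python3 novelfra.py -s mypod     # start an existing pod
# python3 novelfra.py -k mypod     # kill a pod by name
# python3 novelfra.py mypod -b     # run bash in the pod instead of python3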