Commit 44751bc6 authored by novelailab

config almost done

parent 42870e7b
-from typing import KeysView
+from typing import Callable, KeysView
 from regex import D
 import torch
 import torch.nn as nn
@@ -148,7 +148,7 @@ class FeedForward(nn.Module):
 class GPTJLayer(nn.Module):
     def __init__(self, attn, ff, config):
         nn.Module.__init__(self)
-        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.type)
+        self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
         self.ff = ff(config)
         self.attn = attn(config)
         self.tick = True
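The hunk above swaps dtype=config.type for dtype=config.dtype, so the pre-attention LayerNorm is constructed directly with the configured device and precision instead of referencing an attribute the config does not define. Below is a minimal standalone sketch of that constructor call, using placeholder values in place of the config fields visible in the hunk and assuming a PyTorch version that accepts the device/dtype factory keyword arguments (1.9+):

import torch
import torch.nn as nn

# Placeholder values standing in for the config fields used in the hunk.
hidden_dim, eps = 4096, 1e-5
device, dtype = torch.device("cpu"), torch.float32

# nn.LayerNorm accepts device/dtype factory kwargs, so the affine weight and
# bias are allocated on the target device and in the target precision up
# front instead of being cast after construction.
ln_preattn = nn.LayerNorm(hidden_dim, eps=eps, device=device, dtype=dtype)
print(ln_preattn.weight.dtype)  # torch.float32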
@@ -253,12 +253,8 @@ class GPTJConfig:
     eps: float = 1e-5
     device: torch.device = torch.device('cuda')
     dtype: torch.dtype = torch.float16
-    Layer = GPTJLayer
-    activation = gelu_new
-    def from_dict(self, config_dict):
-        for k, v in config_dict.items():
-            setattr(self, k, v)
+    Layer: nn.Module = GPTJLayer
+    activation: Callable = gelu_new
 def load_gpt_j(path="models/6b", state_dict=None):
     config = {
......
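The config hunk drops the from_dict setattr helper and instead gives Layer and activation typed class-level defaults (Layer: nn.Module, activation: Callable). The full GPTJConfig is not shown in this diff, so the sketch below is only an assumption of how such a typed config could be declared and consumed by a loader: the dataclass decorator, the gelu_new body, and any field not visible above (and the values passed in load_gpt_j) are illustrative guesses, not the repository's actual code.

from dataclasses import dataclass
from typing import Callable, Optional

import torch


def gelu_new(x: torch.Tensor) -> torch.Tensor:
    # GPT-style tanh approximation of GELU, as used by GPT-2/GPT-J.
    return 0.5 * x * (1.0 + torch.tanh(
        0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))


@dataclass
class GPTJConfig:
    # Only the fields visible in the diff; the real config likely has more
    # (layer count, head count, vocab size, ...).
    hidden_dim: int = 4096
    eps: float = 1e-5
    device: torch.device = torch.device("cuda")
    dtype: torch.dtype = torch.float16
    Layer: Optional[type] = None        # e.g. GPTJLayer
    activation: Callable = gelu_new


def load_gpt_j(path: str = "models/6b", state_dict=None) -> GPTJConfig:
    # With typed defaults on the class, a loader only overrides the fields
    # that differ instead of pushing a raw dict through a from_dict() loop.
    return GPTJConfig(hidden_dim=4096, device=torch.device("cpu"),
                      dtype=torch.float32)


config = load_gpt_j()
print(config.activation is gelu_new)  # True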