Commit 42870e7b authored by novelailab's avatar novelailab

oops, fix

parent 41f39980
......@@ -260,12 +260,6 @@ class GPTJConfig:
for k, v in config_dict.items():
setattr(self, k, v)
# Thin LM wrapper binding the concrete GPT-J model class into the shared
# BaseLM interface.  NOTE(review): scraped from a diff view — in this commit
# these lines were being *removed* in favor of the module-level lm_base.load.
class GPTJBaseLM(lm_base.BaseLM):
    def __init__(self, config=None, lm=None):
        # Explicit base-class __init__ calls instead of super() —
        # presumably nn.Module must be set up before BaseLM assigns
        # attributes/submodules; TODO confirm against lm_base.BaseLM.
        nn.Module.__init__(self)
        lm_base.BaseLM.__init__(self, config, lm)
        # Concrete model class used by the BaseLM factory helpers.
        self.model_class=GPTJModel
# Build the 6B GPT-J model from a split checkpoint directory.
# NOTE(review): diff fragment — the config dict's middle entries are elided
# by the hunk boundary below, and both the pre-fix and post-fix
# `model = ...` lines appear; only the second one is the commit's intent.
def load_gpt_j(path="models/6b", state_dict=None):
    config = {
        "n_layer": 28,
        # ......@@ -275,5 +269,5 @@ def load_gpt_j(path="models/6b", state_dict=None):
        # (additional hyperparameter keys elided by the diff view)
        "eps": 1e-5
    }
    config = GPTJConfig(**config)
    # Pre-fix line (removed by this commit):
    model = GPTJBaseLM.load(config, path, state_dict)
    # Post-fix line (added by this commit): module-level loader with the
    # corrected (model_class, config) argument order; overwrites the above.
    model = lm_base.load(GPTJModel, config, path)
    return model
......@@ -26,25 +26,22 @@ def init_weights(model, n_layer):
if ("ff2" in name or "out_proj" in name) and "weight" in name:
p.data.normal_(mean=0.0, std=(0.02 / math.sqrt(2 * n_layer)))
@classmethod
def init(model_class, config):
    """Construct a model from *config* and run its standard weight init."""
    instance = model_class(config)
    instance.init_weights()
    return instance
@classmethod
def no_init(model_class, config):
    """Instantiate the model while skipping weight initialization.

    Delegates to utils.no_init, which suppresses parameter init — useful
    when the weights will be overwritten by a checkpoint anyway.
    """
    return utils.no_init(lambda: model_class(config))
@classmethod
def load(model_class, config, path=None, state_dict=None, strict=False):
    """Build an uninitialized model and load checkpoint weights into it.

    Args:
        model_class: class to instantiate (bound via @classmethod).
        config: config object passed to the model constructor.
        path: optional split-checkpoint directory; when given, it is loaded
            onto CUDA and takes precedence over *state_dict*.
        state_dict: preloaded weights used when *path* is not given.
        strict: forwarded to load_state_dict.

    Returns:
        The model with weights loaded.
    """
    # I am kinda sad that we will not have a load function in lm object itself.
    # might be better to add load functions -- actually nope.
    if path:
        state_dict = utils.SplitCheckpoint(path, device="cuda")
    # Skip weight init — the state dict overwrites every parameter anyway.
    model = utils.no_init(lambda: model_class(config))
    model.load_state_dict(state_dict, strict=strict)
    return model
......
......@@ -68,7 +68,6 @@ def timeit(func, r=1, n=5, quiet=False, function=None, do_tqdm=False, first=True
# Parity/benchmark scaffold: load the project ("based") GPT-J and the
# HuggingFace checkpoint side by side for comparison.
# NOTE(review): diff fragment — the surrounding context (and possibly the
# rest of this with-block) is cut off by the hunk boundary.
with torch.no_grad():
    based_model = gptj.load_gpt_j().cuda().half().eval()
    # Unwrap to the raw LM module held by the wrapper — TODO confirm
    # against the BaseLM definition.
    based_model = based_model.lm
    print("Loaded based model")
    # no_init skips weight initialization; weights come from the checkpoint.
    hf_model = no_init(lambda: AutoModelForCausalLM.from_pretrained('/home/xuser/models/j6b_ckpt_14001')).cuda().half().eval()
    print("Loaded hf model")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment