Commit 92b7b187 authored by novelailab's avatar novelailab

fix names

parent fa2a771c
......@@ -147,7 +147,7 @@ class GPT2Layer(nn.Module):
return x
class GPTJModel(base_lm.BaseModel):
class GPT2Model(base_lm.BaseModel):
def __init__(self, user_config, **kwargs):
self.default_config = {
'n_layer': 6,
......@@ -158,7 +158,7 @@ class GPTJModel(base_lm.BaseModel):
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
'Layer': GPTJLayer,
'Layer': GPT2Layer,
'activation': gelu_new,
'SelfAttention': SelfAttention,
'FeedForward': FeedForward,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment