Commit 92b7b187 authored by novelailab's avatar novelailab

fix names

parent fa2a771c
...@@ -147,7 +147,7 @@ class GPT2Layer(nn.Module): ...@@ -147,7 +147,7 @@ class GPT2Layer(nn.Module):
return x return x
class GPTJModel(base_lm.BaseModel): class GPT2Model(base_lm.BaseModel):
def __init__(self, user_config, **kwargs): def __init__(self, user_config, **kwargs):
self.default_config = { self.default_config = {
'n_layer': 6, 'n_layer': 6,
...@@ -158,7 +158,7 @@ class GPTJModel(base_lm.BaseModel): ...@@ -158,7 +158,7 @@ class GPTJModel(base_lm.BaseModel):
'eps': 1e-5, 'eps': 1e-5,
'device': torch.device('cuda'), 'device': torch.device('cuda'),
'dtype': torch.float16, 'dtype': torch.float16,
'Layer': GPTJLayer, 'Layer': GPT2Layer,
'activation': gelu_new, 'activation': gelu_new,
'SelfAttention': SelfAttention, 'SelfAttention': SelfAttention,
'FeedForward': FeedForward, 'FeedForward': FeedForward,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment