Commit dcd0bba5 authored by Arda Cihaner's avatar Arda Cihaner

fixes

parent 8939ff4f
......@@ -72,11 +72,11 @@ class ResNet(base_image.BaseVisionModel):
nn.ReLU()
)
self.resblocks = nn.ModuleList()
network_config = network_config_dict[self.config.network_layers]
is_bottleneck = network_config[0]
self.network_config = network_config_dict[self.config.network_size]
is_bottleneck = self.network_config[0]
curr_chan = 64
prev_chan = curr_chan
for i in network_config[1]:
for i in self.network_config[1]:
for _ in range(i):
resblock = ResBlockBottleNeck(prev_chan, curr_chan) if is_bottleneck else ResBlock(prev_chan, curr_chan)
self.resblocks.append(resblock)
......@@ -84,7 +84,7 @@ class ResNet(base_image.BaseVisionModel):
curr_chan *= 2
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(prev_chan, self.config.out_size)
self.fc = nn.Linear(prev_chan, self.config.n_class)
def forward(self, x):
out = self.layerin(x)
......
......@@ -124,7 +124,7 @@ class VisionTransformer(base_image.BaseVisionModel):
'patch_size': 16,
'hidden_dim': 768,
'n_classes' : 1000,
'activation': F.gelu,
'activation': torch.nn.GELU(),
'image_size': (224, 224),
'eps': 1e-5,
'device': torch.device('cpu'),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment