Commit 188c98ed authored by Arda Cihaner's avatar Arda Cihaner

resnet config fix

parent dcd0bba5
...@@ -2,11 +2,10 @@ import torch.nn as nn ...@@ -2,11 +2,10 @@ import torch.nn as nn
from dotmap import DotMap from dotmap import DotMap
class BaseVisionModel(nn.Module): class BaseVisionModel(nn.Module):
def __init__(self, user_config): def __init__(self, user_config, **kwargs):
super().__init__() super().__init__()
self.user_config = user_config self.user_config = user_config
self.config = self.configure_model() self.config = self.configure_model()
config = self.config
def configure_model(self): def configure_model(self):
full_config = {} full_config = {}
......
...@@ -51,13 +51,13 @@ class ResBlockBottleNeck(nn.Module): ...@@ -51,13 +51,13 @@ class ResBlockBottleNeck(nn.Module):
class ResNet(base_image.BaseVisionModel): class ResNet(base_image.BaseVisionModel):
def __init__(self) -> None: def __init__(self, user_config, **kwargs) -> None:
self.default_config = { self.default_config = {
'in_channels': 3, 'in_channels': 3,
'network_size': 18, #ResNet18/34/50/101/152 'network_size': 18, #ResNet18/34/50/101/152
'n_class': 100 'n_class': 100
} }
super().__init__(self.default_config) super().__init__(user_config, **kwargs)
network_config_dict = { network_config_dict = {
18: (False, (2, 2, 2, 2)), 18: (False, (2, 2, 2, 2)),
34: (False, (3, 4, 6, 3)), 34: (False, (3, 4, 6, 3)),
......
...@@ -127,8 +127,8 @@ class VisionTransformer(base_image.BaseVisionModel): ...@@ -127,8 +127,8 @@ class VisionTransformer(base_image.BaseVisionModel):
'activation': torch.nn.GELU(), 'activation': torch.nn.GELU(),
'image_size': (224, 224), 'image_size': (224, 224),
'eps': 1e-5, 'eps': 1e-5,
'device': torch.device('cpu'), 'device': torch.device('cuda'),
'dtype': torch.float32, 'dtype': torch.float16,
} }
super().__init__(self.default_config) super().__init__(self.default_config)
self.embed = ViTEmbeds(self.config) self.embed = ViTEmbeds(self.config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment