Commit 188c98ed authored by Arda Cihaner's avatar Arda Cihaner

resnet config fix

parent dcd0bba5
......@@ -2,11 +2,10 @@ import torch.nn as nn
from dotmap import DotMap
class BaseVisionModel(nn.Module):
def __init__(self, user_config):
def __init__(self, user_config, **kwargs):
super().__init__()
self.user_config = user_config
self.config = self.configure_model()
config = self.config
def configure_model(self):
full_config = {}
......
......@@ -51,13 +51,13 @@ class ResBlockBottleNeck(nn.Module):
class ResNet(base_image.BaseVisionModel):
def __init__(self) -> None:
def __init__(self, user_config, **kwargs) -> None:
self.default_config = {
'in_channels': 3,
'network_size': 18, #ResNet18/34/50/101/152
'n_class': 100
}
super().__init__(self.default_config)
super().__init__(user_config, **kwargs)
network_config_dict = {
18: (False, (2, 2, 2, 2)),
34: (False, (3, 4, 6, 3)),
......
......@@ -127,8 +127,8 @@ class VisionTransformer(base_image.BaseVisionModel):
'activation': torch.nn.GELU(),
'image_size': (224, 224),
'eps': 1e-5,
'device': torch.device('cpu'),
'dtype': torch.float32,
'device': torch.device('cuda'),
'dtype': torch.float16,
}
super().__init__(self.default_config)
self.embed = ViTEmbeds(self.config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment