Commit 8939ff4f authored by Arda Cihaner's avatar Arda Cihaner

changing ResNet class to inherit BaseVisionModel

parent 558c30dc
...@@ -3,6 +3,8 @@ from . import gpt2 ...@@ -3,6 +3,8 @@ from . import gpt2
from . import fairseq from . import fairseq
from . import gptneo from . import gptneo
from . import alibi from . import alibi
from . import vit
from . import resnet
from . import fast from . import fast
MODEL_MAP = { MODEL_MAP = {
...@@ -11,6 +13,8 @@ MODEL_MAP = { ...@@ -11,6 +13,8 @@ MODEL_MAP = {
"gpt-fairseq": fairseq.GPTFairModel, "gpt-fairseq": fairseq.GPTFairModel,
"gpt-neo": gptneo.GPTNeoModel, "gpt-neo": gptneo.GPTNeoModel,
"alibi": alibi.AlibiModel, "alibi": alibi.AlibiModel,
"vit": vit.VisionTransformer,
"resnet": resnet.ResNet
} }
def get_model(model_name: str): def get_model(model_name: str):
......
...@@ -2,6 +2,8 @@ import torch ...@@ -2,6 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from basedformer.models import base_image
class ResBlock(nn.Module): class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels): def __init__(self, in_channels, out_channels):
super().__init__() super().__init__()
...@@ -48,10 +50,14 @@ class ResBlockBottleNeck(nn.Module): ...@@ -48,10 +50,14 @@ class ResBlockBottleNeck(nn.Module):
return F.relu(out) return F.relu(out)
class ResNet(nn.Module): class ResNet(base_image.BaseVisionModel):
def __init__(self, in_channels, out_size=1000, network_layers=18) -> None: def __init__(self) -> None:
super().__init__() self.default_config = {
base_chan = 64 'in_channels': 3,
'network_size': 18, #ResNet18/34/50/101/152
'n_class': 100
}
super().__init__(self.default_config)
network_config_dict = { network_config_dict = {
18: (False, (2, 2, 2, 2)), 18: (False, (2, 2, 2, 2)),
34: (False, (3, 4, 6, 3)), 34: (False, (3, 4, 6, 3)),
...@@ -60,15 +66,15 @@ class ResNet(nn.Module): ...@@ -60,15 +66,15 @@ class ResNet(nn.Module):
152: (True, (3, 4, 36, 3)) 152: (True, (3, 4, 36, 3))
} }
self.layerin = nn.Sequential( self.layerin = nn.Sequential(
nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3), nn.Conv2d(self.config.in_channels, 64, kernel_size=7, stride=2, padding=3),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1), nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64), nn.BatchNorm2d(64),
nn.ReLU() nn.ReLU()
) )
self.resblocks = nn.ModuleList() self.resblocks = nn.ModuleList()
network_config = network_config_dict[network_layers] network_config = network_config_dict[self.config.network_layers]
is_bottleneck = network_config[0] is_bottleneck = network_config[0]
curr_chan = base_chan curr_chan = 64
prev_chan = curr_chan prev_chan = curr_chan
for i in network_config[1]: for i in network_config[1]:
for _ in range(i): for _ in range(i):
...@@ -78,7 +84,7 @@ class ResNet(nn.Module): ...@@ -78,7 +84,7 @@ class ResNet(nn.Module):
curr_chan *= 2 curr_chan *= 2
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(prev_chan, out_size) self.fc = nn.Linear(prev_chan, self.config.out_size)
def forward(self, x): def forward(self, x):
out = self.layerin(x) out = self.layerin(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment