Commit 8939ff4f authored by Arda Cihaner's avatar Arda Cihaner

changing ResNet class to inherit BaseVisionModel

parent 558c30dc
......@@ -3,6 +3,8 @@ from . import gpt2
from . import fairseq
from . import gptneo
from . import alibi
from . import vit
from . import resnet
from . import fast
MODEL_MAP = {
......@@ -11,6 +13,8 @@ MODEL_MAP = {
"gpt-fairseq": fairseq.GPTFairModel,
"gpt-neo": gptneo.GPTNeoModel,
"alibi": alibi.AlibiModel,
"vit": vit.VisionTransformer,
"resnet": resnet.ResNet
}
def get_model(model_name: str):
......
......@@ -2,6 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.models import base_image
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
......@@ -48,10 +50,14 @@ class ResBlockBottleNeck(nn.Module):
return F.relu(out)
class ResNet(nn.Module):
def __init__(self, in_channels, out_size=1000, network_layers=18) -> None:
super().__init__()
base_chan = 64
class ResNet(base_image.BaseVisionModel):
def __init__(self) -> None:
self.default_config = {
'in_channels': 3,
'network_size': 18, #ResNet18/34/50/101/152
'n_class': 100
}
super().__init__(self.default_config)
network_config_dict = {
18: (False, (2, 2, 2, 2)),
34: (False, (3, 4, 6, 3)),
......@@ -60,15 +66,15 @@ class ResNet(nn.Module):
152: (True, (3, 4, 36, 3))
}
self.layerin = nn.Sequential(
nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
nn.Conv2d(self.config.in_channels, 64, kernel_size=7, stride=2, padding=3),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.resblocks = nn.ModuleList()
network_config = network_config_dict[network_layers]
network_config = network_config_dict[self.config.network_layers]
is_bottleneck = network_config[0]
curr_chan = base_chan
curr_chan = 64
prev_chan = curr_chan
for i in network_config[1]:
for _ in range(i):
......@@ -78,7 +84,7 @@ class ResNet(nn.Module):
curr_chan *= 2
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(prev_chan, out_size)
self.fc = nn.Linear(prev_chan, self.config.out_size)
def forward(self, x):
out = self.layerin(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment