Commit 63f92eba authored by novelailab's avatar novelailab

fix resnet

parent ec79f2ea
from cmath import exp
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -26,28 +27,29 @@ class ResBlock(nn.Module): ...@@ -26,28 +27,29 @@ class ResBlock(nn.Module):
return F.relu(out) return F.relu(out)
class ResBlockBottleNeck(nn.Module): class ResBlockBottleNeck(nn.Module):
def __init__(self, in_channels, out_channels) -> None: def __init__(self, in_channels, out_channels, expansion, needs_downsample=False) -> None:
super().__init__() super().__init__()
downsample = True if in_channels != out_channels else False
self.residual = nn.Sequential() self.residual = nn.Sequential()
if downsample: self.expansion = expansion
if needs_downsample or in_channels != out_channels * self.expansion:
self.residual = nn.Sequential( self.residual = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2), nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=2 if needs_downsample else 1),
nn.BatchNorm2d(out_channels) nn.BatchNorm2d(out_channels * self.expansion)
) )
self.conv1 = nn.Conv2d(in_channels, out_channels//4, kernel_size=1, stride=1) self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
self.conv2 = nn.Conv2d(out_channels//4, out_channels//4, kernel_size=3, stride=2 if downsample else 1, padding=1) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2 if needs_downsample else 1, padding=1)
self.conv3 = nn.Conv2d(out_channels//4, out_channels, kernel_size=1, stride=1) self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, stride=1)
self.bn1 = nn.BatchNorm2d(out_channels//4) self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels//4) self.bn2 = nn.BatchNorm2d(out_channels)
self.bn3 = nn.BatchNorm2d(out_channels) self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
def forward(self, x): def forward(self, x):
residual = self.residual(x)
out = F.relu((self.bn1(self.conv1(x)))) out = F.relu((self.bn1(self.conv1(x))))
out = F.relu((self.bn2(self.conv2(out)))) out = F.relu((self.bn2(self.conv2(out))))
out = F.relu((self.bn3(self.conv3(out)))) + self.residual(x) out = F.relu((self.bn3(self.conv3(out))))
return F.relu(out) return F.relu(out + residual)
class ResNet(base_image.BaseVisionModel): class ResNet(base_image.BaseVisionModel):
...@@ -76,15 +78,26 @@ class ResNet(base_image.BaseVisionModel): ...@@ -76,15 +78,26 @@ class ResNet(base_image.BaseVisionModel):
is_bottleneck = self.network_config[0] is_bottleneck = self.network_config[0]
curr_chan = 64 curr_chan = 64
prev_chan = curr_chan prev_chan = curr_chan
#dirty hack for downscaling at bottleneck layers
firstlayer = True
for i in self.network_config[1]: for i in self.network_config[1]:
for _ in range(i): for _ in range(i):
resblock = ResBlockBottleNeck(prev_chan, curr_chan) if is_bottleneck else ResBlock(prev_chan, curr_chan) needs_downsample = True
if is_bottleneck:
if firstlayer:
resblock = ResBlockBottleNeck(prev_chan, curr_chan, 4)
firstlayer = False
else:
resblock = ResBlockBottleNeck(prev_chan * 4, curr_chan, 4, needs_downsample)
needs_downsample = False
else:
resblock = ResBlock(prev_chan, curr_chan)
self.resblocks.append(resblock) self.resblocks.append(resblock)
prev_chan = curr_chan prev_chan = curr_chan
curr_chan *= 2 curr_chan *= 2
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(prev_chan, self.config.n_class) self.fc = nn.Linear(prev_chan * 4 if is_bottleneck else prev_chan, self.config.n_class)
def forward(self, x): def forward(self, x):
out = self.layerin(x) out = self.layerin(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment