Commit 63f92eba authored by novelailab

fix resnet

parent ec79f2ea
from cmath import exp
import torch
import torch.nn as nn
import torch.nn.functional as F
......@@ -26,28 +27,29 @@ class ResBlock(nn.Module):
return F.relu(out)
class ResBlockBottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 expanding conv.

    The main path maps ``in_channels`` to ``out_channels`` (1x1), keeps
    ``out_channels`` through the 3x3 convolution, and widens to
    ``out_channels * expansion`` with the final 1x1 convolution. Each
    convolution is followed by batch normalisation and a ReLU. The skip path
    is added to the result and a final ReLU is applied.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Width of the inner (bottleneck) convolutions; the block
            outputs ``out_channels * expansion`` channels.
        expansion: Multiplier applied to ``out_channels`` for the output width.
        needs_downsample: When true, the 3x3 convolution and the skip
            projection both use stride 2.
    """

    def __init__(self, in_channels, out_channels, expansion, needs_downsample=False) -> None:
        super().__init__()
        self.expansion = expansion
        expanded_channels = out_channels * self.expansion
        stride = 2 if needs_downsample else 1

        # An empty Sequential passes its input through unchanged, so it acts
        # as the identity skip when no projection is required.
        self.residual = nn.Sequential()
        if needs_downsample or in_channels != expanded_channels:
            # Project the skip path so its channel count and spatial stride
            # match the main path before the two are added in forward().
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(expanded_channels),
            )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.conv3 = nn.Conv2d(out_channels, expanded_channels, kernel_size=1, stride=1)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.bn3 = nn.BatchNorm2d(expanded_channels)

    def forward(self, x):
        """Run the block on ``x`` and return the activated sum of both paths.

        Assumes ``x`` is an image batch with ``in_channels`` channels, as
        required by the ``nn.Conv2d`` layers built in ``__init__``.
        """
        residual = self.residual(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        # NOTE(review): a ReLU is applied here *before* the skip addition as
        # well as after it. Reference bottleneck implementations (e.g.
        # torchvision) only activate after the addition; kept as-is so that
        # existing checkpoints produce the same outputs -- confirm intent.
        out = F.relu(self.bn3(self.conv3(out)))
        return F.relu(out + residual)
class ResNet(base_image.BaseVisionModel):
......@@ -76,15 +78,26 @@ class ResNet(base_image.BaseVisionModel):
is_bottleneck = self.network_config[0]
curr_chan = 64
prev_chan = curr_chan
#dirty hack for downscaling at bottleneck layers
firstlayer = True
for i in self.network_config[1]:
for _ in range(i):
resblock = ResBlockBottleNeck(prev_chan, curr_chan) if is_bottleneck else ResBlock(prev_chan, curr_chan)
needs_downsample = True
if is_bottleneck:
if firstlayer:
resblock = ResBlockBottleNeck(prev_chan, curr_chan, 4)
firstlayer = False
else:
resblock = ResBlockBottleNeck(prev_chan * 4, curr_chan, 4, needs_downsample)
needs_downsample = False
else:
resblock = ResBlock(prev_chan, curr_chan)
self.resblocks.append(resblock)
prev_chan = curr_chan
curr_chan *= 2
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(prev_chan, self.config.n_class)
self.fc = nn.Linear(prev_chan * 4 if is_bottleneck else prev_chan, self.config.n_class)
def forward(self, x):
out = self.layerin(x)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment