Commit cff3389c authored by Arda Cihaner's avatar Arda Cihaner

conversion script for resnet

parent 188c98ed
import os
from typing import OrderedDict
import torch
model_dir = 'pretrained/resnet/'
new_state_dict = {}
weights : OrderedDict = torch.load(model_dir + 'resnet_18.pth')
net_conf = (False, (2, 2, 2, 2))
counter = 0
new_state_dict['layerin.0.weight'] = weights['conv1.weight']
new_state_dict['layerin.2.weight'] = weights['bn1.weight']
new_state_dict['layerin.2.bias'] = weights['bn1.bias']
for i, j in enumerate(net_conf[1], 1):
for k in range(j):
curr_layer = f"layer{i}.{k}."
curr_state_dict_key = f"resblocks.{counter}."
new_state_dict[curr_state_dict_key + "conv1.weight"] = weights[curr_layer + "conv1.weight"]
new_state_dict[curr_state_dict_key + "conv2.weight"] = weights[curr_layer + "conv2.weight"]
new_state_dict[curr_state_dict_key + "bn1.weight"] = weights[curr_layer + "bn1.weight"]
new_state_dict[curr_state_dict_key + "bn1.bias"] = weights[curr_layer + "bn1.bias"]
new_state_dict[curr_state_dict_key + "bn2.weight"] = weights[curr_layer + "bn2.weight"]
new_state_dict[curr_state_dict_key + "bn2.bias"] = weights[curr_layer + "bn2.bias"]
if net_conf[0]:
new_state_dict[curr_state_dict_key + "conv3.weight"] = weights[curr_layer + "conv3.weight"]
new_state_dict[curr_state_dict_key + "bn3.weight"] = weights[curr_layer + "bn3.weight"]
new_state_dict[curr_state_dict_key + "bn3.bias"] = weights[curr_layer + "bn3.bias"]
counter += 1
torch.save(new_state_dict, "pretrained/resnet/modified/resnet_18.pth")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment