Commit b47ef0e9 authored by Eren Doğan, committed by GitHub

Merge pull request #6 from NovelAI/imagestuff

parents c58dfef8 cff3389c
......@@ -3,6 +3,8 @@ from . import gpt2
from . import fairseq
from . import gptneo
from . import alibi
from . import vit
from . import resnet
from . import fast
MODEL_MAP = {
......@@ -11,6 +13,8 @@ MODEL_MAP = {
"gpt-fairseq": fairseq.GPTFairModel,
"gpt-neo": gptneo.GPTNeoModel,
"alibi": alibi.AlibiModel,
"vit": vit.VisionTransformer,
"resnet": resnet.ResNet
}
def get_model(model_name: str):
......
import torch.nn as nn
from dotmap import DotMap
class BaseVisionModel(nn.Module):
def __init__(self, user_config, **kwargs):
super().__init__()
self.user_config = user_config
self.config = self.configure_model()
def configure_model(self):
full_config = {}
if not hasattr(self, 'default_config'):
raise ValueError("No default config found, add one for the model to function")
#apply defaults
for k, v in self.default_config.items():
full_config[k] = v
#apply user defined config if provided
for k, v in self.user_config.items():
full_config[k] = v
full_config = DotMap(full_config)
return full_config
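# --- usage sketch (not part of this diff): how the config merge above is meant
# to be used. A subclass sets self.default_config before calling super().__init__,
# and any key in user_config overrides the default. ToyVisionModel and its keys
# are illustrative only.
class ToyVisionModel(BaseVisionModel):
    def __init__(self, user_config, **kwargs):
        # defaults must exist before super().__init__ runs configure_model()
        self.default_config = {'in_channels': 3, 'n_class': 10}
        super().__init__(user_config, **kwargs)

if __name__ == "__main__":
    cfg = ToyVisionModel({'n_class': 100}).config
    assert cfg.in_channels == 3 and cfg.n_class == 100  # default kept, user override applied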
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.models import base_image
class ResBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
        downsample = in_channels != out_channels
self.residual = nn.Sequential()
if downsample:
self.residual = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
nn.BatchNorm2d(out_channels)
)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=2 if downsample else 1)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out)) + self.residual(x)
return F.relu(out)
class ResBlockBottleNeck(nn.Module):
def __init__(self, in_channels, out_channels) -> None:
super().__init__()
        downsample = in_channels != out_channels
self.residual = nn.Sequential()
if downsample:
self.residual = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
nn.BatchNorm2d(out_channels)
)
self.conv1 = nn.Conv2d(in_channels, out_channels//4, kernel_size=1, stride=1)
self.conv2 = nn.Conv2d(out_channels//4, out_channels//4, kernel_size=3, stride=2 if downsample else 1, padding=1)
self.conv3 = nn.Conv2d(out_channels//4, out_channels, kernel_size=1, stride=1)
self.bn1 = nn.BatchNorm2d(out_channels//4)
self.bn2 = nn.BatchNorm2d(out_channels//4)
self.bn3 = nn.BatchNorm2d(out_channels)
def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        # no ReLU before the addition; the skip connection is added to the raw bn3 output
        out = self.bn3(self.conv3(out)) + self.residual(x)
return F.relu(out)
class ResNet(base_image.BaseVisionModel):
def __init__(self, user_config, **kwargs) -> None:
self.default_config = {
'in_channels': 3,
'network_size': 18, #ResNet18/34/50/101/152
'n_class': 100
}
super().__init__(user_config, **kwargs)
network_config_dict = {
18: (False, (2, 2, 2, 2)),
34: (False, (3, 4, 6, 3)),
50: (True, (3, 4, 6, 3)),
101: (True, (3, 4, 23, 3)),
152: (True, (3, 4, 36, 3))
}
self.layerin = nn.Sequential(
nn.Conv2d(self.config.in_channels, 64, kernel_size=7, stride=2, padding=3),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU()
)
self.resblocks = nn.ModuleList()
self.network_config = network_config_dict[self.config.network_size]
is_bottleneck = self.network_config[0]
curr_chan = 64
prev_chan = curr_chan
for i in self.network_config[1]:
for _ in range(i):
resblock = ResBlockBottleNeck(prev_chan, curr_chan) if is_bottleneck else ResBlock(prev_chan, curr_chan)
self.resblocks.append(resblock)
prev_chan = curr_chan
curr_chan *= 2
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(prev_chan, self.config.n_class)
def forward(self, x):
out = self.layerin(x)
for layer in self.resblocks:
out = layer(out)
out = self.avgpool(out)
out = out.view(out.size(0), -1)
return self.fc(out)
\ No newline at end of file
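# --- usage sketch (not part of this diff): running the ResNet above on a dummy
# batch. The 224x224 input and the config values are assumptions mirroring the
# defaults; 'network_size' and 'n_class' come from default_config.
if __name__ == "__main__":
    model = ResNet({'network_size': 18, 'n_class': 1000})
    images = torch.randn(2, 3, 224, 224)   # (batch, channels, height, width)
    logits = model(images)                  # -> torch.Size([2, 1000])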
import torch
import torch.nn as nn
import torch.nn.functional as F
from basedformer.utils import *
from basedformer.models import base_image
import einops
def _attn(query, key, value, attention_mask=None, scale_attn=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
attn_weights = attn_weights / scale_attn
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
attn_output = torch.matmul(attn_weights, value).to(value.dtype)
return attn_output
class SelfAttention(nn.Module):
# Code copied from HF, might want to sanity check later.
def __init__(self, config):
nn.Module.__init__(self)
self.head_dim = config.hidden_dim // config.n_head
self.rotary_dim = self.head_dim // 4
self.hidden_dim = config.hidden_dim
self.n_head = config.n_head
device = config.device
dtype = config.dtype
self.register_buffer("scale_attn", torch.sqrt(torch.tensor(self.head_dim, requires_grad=False).float()))
attn_bias = False
self.k_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.v_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.q_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
self.out_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=attn_bias, device=device, dtype=dtype)
def forward(self, x, kv=None, cache=False):
B, S, H = x.shape # batch, sequence, hidden_dim
# split heads into: [batch, head, sequence, head_dim]
query = self.q_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
key = self.k_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
value = self.v_proj(x).view(B, S, self.n_head, self.head_dim).transpose(1, 2)
        if kv is not None:
            k, v = kv
            # prepend the cached keys/values (everything except the newly added
            # tokens is cached) so the new query can attend to the whole sequence
            key = torch.cat([k, key], dim=-2)
            value = torch.cat([v, value], dim=-2)
x = _attn(
query, key, value, None, self.scale_attn
)
x = x.transpose(1, 2).contiguous().view(B, S, H)
x = self.out_proj(x)
if cache:
return x, (key, value)
else:
return x, None
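# --- hedged sketch (not part of this diff) of the cache contract above: with
# cache=True the layer returns its (key, value) pair, which can be passed back
# as kv on the next call so new tokens attend over the cached prefix as well.
def _cache_usage_sketch(attn: SelfAttention, x_prefix, x_next):
    _, kv = attn(x_prefix, cache=True)         # prime the cache with the prefix
    out, kv = attn(x_next, kv=kv, cache=True)  # new tokens see cached + new keys/values
    return out, kv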
class FeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.ff1 = nn.Linear(config.hidden_dim, config.hidden_dim*4, device=config.device, dtype=config.dtype)
self.ff2 = nn.Linear(config.hidden_dim*4, config.hidden_dim, device=config.device, dtype=config.dtype)
self.activation = config.activation
def forward(self, x):
x = self.ff1(x)
x = self.activation(x)
x = self.ff2(x)
return x
class ViTEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_dim
self.ln_preattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ln_postattn = nn.LayerNorm(config.hidden_dim, eps=config.eps, device=config.device, dtype=config.dtype)
self.ff = FeedForward(config)
self.attn = SelfAttention(config)
def forward(self, x):
residual = x
x = self.ln_preattn(x)
x = self.attn(x)[0]
x = residual + x
residual = x
x = self.ln_postattn(x)
x = self.ff(x)
return x + residual
class ViTEmbeds(nn.Module):
def __init__(self, config) -> None:
super().__init__()
p_size = config.patch_size
channels = config.channels
dim = config.hidden_dim
num_patches = (config.image_size[1] // p_size) * (config.image_size[0] // p_size)
self.lin_emb = nn.Linear((p_size ** 2) * channels, dim)
self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
self.pos = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
def forward(self, x: torch.Tensor):
embed = self.lin_emb(x)
batch_size = x.size()[0]
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
embed = torch.cat((cls_tokens, embed), dim=1)
return embed + self.pos
class VisionTransformer(base_image.BaseVisionModel):
    def __init__(self, user_config, **kwargs) -> None:
self.default_config = {
'n_layer': 12,
'n_head': 8,
'channels': 3,
'patch_size': 16,
'hidden_dim': 768,
'n_classes' : 1000,
'activation': torch.nn.GELU(),
'image_size': (224, 224),
'eps': 1e-5,
'device': torch.device('cuda'),
'dtype': torch.float16,
}
        super().__init__(user_config, **kwargs)
self.embed = ViTEmbeds(self.config)
self.encoder_layers = nn.ModuleList()
for _ in range(self.config.n_layer):
self.encoder_layers.append(ViTEncoder(self.config))
self.mlp_head = nn.Linear(self.config.hidden_dim, self.config.n_classes)
def forward(self, x):
p_size = self.config.patch_size
patches = einops.rearrange(x, 'b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=p_size, s2=p_size)
patches = self.embed(patches)
for encoder in self.encoder_layers:
patches = encoder(patches)
        return self.mlp_head(patches[:, 0])  # classify from the cls token prepended in ViTEmbeds
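# --- usage sketch (not part of this diff): classifying a dummy batch with the
# VisionTransformer above. Overriding device/dtype is an assumption, since the
# defaults request CUDA and float16 while the embedding and head layers are
# created with PyTorch defaults.
if __name__ == "__main__":
    vit_model = VisionTransformer({'device': torch.device('cpu'), 'dtype': torch.float32})
    images = torch.randn(2, 3, 224, 224)   # matches the default image_size and channels
    logits = vit_model(images)              # -> torch.Size([2, 1000]), read off the cls token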
import os
from typing import OrderedDict
import torch
model_dir = 'pretrained/resnet/'
new_state_dict = {}
weights : OrderedDict = torch.load(model_dir + 'resnet_18.pth')
net_conf = (False, (2, 2, 2, 2))
counter = 0
new_state_dict['layerin.0.weight'] = weights['conv1.weight']
new_state_dict['layerin.2.weight'] = weights['bn1.weight']
new_state_dict['layerin.2.bias'] = weights['bn1.bias']
for i, j in enumerate(net_conf[1], 1):
for k in range(j):
curr_layer = f"layer{i}.{k}."
curr_state_dict_key = f"resblocks.{counter}."
new_state_dict[curr_state_dict_key + "conv1.weight"] = weights[curr_layer + "conv1.weight"]
new_state_dict[curr_state_dict_key + "conv2.weight"] = weights[curr_layer + "conv2.weight"]
new_state_dict[curr_state_dict_key + "bn1.weight"] = weights[curr_layer + "bn1.weight"]
new_state_dict[curr_state_dict_key + "bn1.bias"] = weights[curr_layer + "bn1.bias"]
new_state_dict[curr_state_dict_key + "bn2.weight"] = weights[curr_layer + "bn2.weight"]
new_state_dict[curr_state_dict_key + "bn2.bias"] = weights[curr_layer + "bn2.bias"]
if net_conf[0]:
new_state_dict[curr_state_dict_key + "conv3.weight"] = weights[curr_layer + "conv3.weight"]
new_state_dict[curr_state_dict_key + "bn3.weight"] = weights[curr_layer + "bn3.weight"]
new_state_dict[curr_state_dict_key + "bn3.bias"] = weights[curr_layer + "bn3.bias"]
counter += 1
os.makedirs("pretrained/resnet/modified", exist_ok=True)
torch.save(new_state_dict, "pretrained/resnet/modified/resnet_18.pth")
\ No newline at end of file
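# --- loading sketch (not part of this diff): the conversion above copies only
# conv weights and BatchNorm weights/biases, so BN running statistics, conv
# biases, the downsample projections, and the final fc are absent from the
# converted file and strict=False is required. Module path taken from the
# imports earlier in this merge request.
from basedformer.models.resnet import ResNet

model = ResNet({'network_size': 18, 'n_class': 1000})
state = torch.load("pretrained/resnet/modified/resnet_18.pth")
missing, unexpected = model.load_state_dict(state, strict=False)
print(f"{len(missing)} missing and {len(unexpected)} unexpected keys")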