Commit d8ed6998 authored by C43H66N12O12S2's avatar C43H66N12O12S2 Committed by AUTOMATIC1111

Update swinir.py

parent 5f71ecfe
import sys import sys
import traceback import traceback
import cv2 import cv2
from collections import OrderedDict
import os import os
import requests import contextlib
from collections import namedtuple
import numpy as np import numpy as np
from PIL import Image from PIL import Image
import torch import torch
import modules.images import modules.images
from modules.shared import cmd_opts, opts, device from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net from modules.swinir_arch import SwinIR as net
precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
def load_model(task = "realsr", large_model = True, model_path="C:/sd/ESRGANn/4x-large.pth", scale=4): precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
try: )
modules.shared.sd_upscalers.append(UpscalerSwin("McSwinnySwin"))
except Exception:
print(f"Error loading ESRGAN model", file=sys.stderr) def load_model(filename, scale=4):
print(traceback.format_exc(), file=sys.stderr) model = net(
if not large_model: upscale=scale,
# use 'nearest+conv' to avoid block artifacts in_chans=3,
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8, img_size=64,
img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180, num_heads=[6, 6, 6, 6, 6, 6], window_size=8,
mlp_ratio=2, upsampler='nearest+conv', resi_connection='1conv') img_range=1.0,
else: depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
# larger model size; use '3conv' to save parameters and memory; use ema for GAN training embed_dim=240,
model = net(upscale=scale, in_chans=3, img_size=64, window_size=8, num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
img_range=1., depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], embed_dim=240, mlp_ratio=2,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], upsampler="nearest+conv",
mlp_ratio=2, upsampler='nearest+conv', resi_connection='3conv') resi_connection="3conv",
)
pretrained_model = torch.load(model_path)
pretrained_model = torch.load(filename)
model.load_state_dict(pretrained_model["params_ema"], strict=True) model.load_state_dict(pretrained_model["params_ema"], strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
def load_models(dirname):
for file in os.listdir(dirname):
path = os.path.join(dirname, file)
model_name, extension = os.path.splitext(file)
if extension != ".pt" and extension != ".pth":
continue
return model.half().to(device) try:
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
def upscale(img, tile=opts.ESRGAN_tile, tile_overlap=opts.ESRGAN_tile_overlap, window_size = 8, scale = 4): except Exception:
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def upscale(
img,
model,
tile=opts.GAN_tile,
tile_overlap=opts.GAN_tile_overlap,
window_size=8,
scale=4,
):
img = np.array(img) img = np.array(img)
img = img[:, :, ::-1] img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255 img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float() img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device) img = img.unsqueeze(0).to(device)
model = load_model()
with torch.no_grad(), precision_scope("cuda"): with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size() _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h_old + h_pad, :] img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w_old + w_pad] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = inference(img, model, tile, tile_overlap, window_size, scale) output = inference(img, model, tile, tile_overlap, window_size, scale)
output = output[..., :h_old * scale, :w_old * scale] output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3: if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR output = np.transpose(
output[[2, 1, 0], :, :], (1, 2, 0)
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8 output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return Image.fromarray(output, 'RGB') return Image.fromarray(output, "RGB")
def inference(img, model, tile, tile_overlap, window_size, scale): def inference(img, model, tile, tile_overlap, window_size, scale):
# test the image tile by tile # test the image tile by tile
b, c, h, w = img.size() b, c, h, w = img.size()
...@@ -66,27 +90,34 @@ def inference(img, model, tile, tile_overlap, window_size, scale): ...@@ -66,27 +90,34 @@ def inference(img, model, tile, tile_overlap, window_size, scale):
sf = scale sf = scale
stride = tile - tile_overlap stride = tile - tile_overlap
h_idx_list = list(range(0, h-tile, stride)) + [h-tile] h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w-tile, stride)) + [w-tile] w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h*sf, w*sf, dtype=torch.half, device=device).type_as(img) E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device) W = torch.zeros_like(E, dtype=torch.half, device=device)
for h_idx in h_idx_list: for h_idx in h_idx_list:
for w_idx in w_idx_list: for w_idx in w_idx_list:
in_patch = img[..., h_idx:h_idx+tile, w_idx:w_idx+tile] in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
out_patch = model(in_patch) out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch) out_patch_mask = torch.ones_like(out_patch)
E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch) E[
W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask) ..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
].add_(out_patch_mask)
output = E.div_(W) output = E.div_(W)
return output return output
class UpscalerSwin(modules.images.Upscaler): class UpscalerSwin(modules.images.Upscaler):
def __init__(self, title): def __init__(self, filename, title):
self.name = title self.name = title
self.model = load_model(filename)
def do_upscale(self, img): def do_upscale(self, img):
img = upscale(img) model = self.model.to(device)
img = upscale(img, model)
return img return img
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment