Commit a9d7eb72 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge branch 'master' into saving

parents f28ce3e3 4e72a1aa
__pycache__
/ESRGAN
*.ckpt
*.pth
/ESRGAN/*
/SwinIR/*
/repositories
/venv
/tmp
/model.ckpt
/models/**/*.ckpt
/models/**/*
/GFPGANv1.3.pth
/gfpgan/weights/*.pth
/ui-config.json
......@@ -22,3 +25,4 @@ __pycache__
/.idea
notification.mp3
/SwinIR
/textual_inversion
......@@ -3,50 +3,64 @@ A browser interface based on Gradio library for Stable Diffusion.
![](txt2img_Screenshot.png)
Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
- Original txt2img and img2img modes
- One click install and run script (but you still must install python and git)
- Outpainting
- Inpainting
- Prompt matrix
- Stable Diffusion upscale
- Attention
- Loopback
- X/Y plot
- Prompt Matrix
- Stable Diffusion Upscale
- Attention, specify parts of text that the model should pay more attention to
- a man in a ((tuxedo)) - will pay more attention to tuxedo
- a man in a (tuxedo:1.21) - alternative syntax
- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion
- have as many embeddings as you want and use any names you like for them
- use multiple embeddings with different numbers of vectors per token
- works with half precision floating point numbers
- Extras tab with:
- GFPGAN, neural network that fixes faces
- CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler
- ESRGAN, neural network with a lot of third party models
- ESRGAN, neural network upscaler with a lot of third party models
- SwinIR, neural network upscaler
- LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options
- Sampling method selection
- Interrupt processing at any time
- 4GB video card support
- Correct seeds for batches
- 4GB video card support (also reports of 2GB working)
- Correct seeds for batches
- Prompt length validation
- Generation parameters added as text to PNG
- Tab to view an existing picture's generation parameters
- get length of prompt in tokens as you type
- get a warning after generation if some text was truncated
- Generation parameters
- parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings
- Settings page
- Running custom code from UI
- Running arbitrary python code from UI (must run with --allow-code to enable)
- Mouseover hints for most UI elements
- Possible to change defaults/mix/max/step values for UI elements via text config
- Random artist button
- Tiling support: UI checkbox to create images that can be tiled like textures
- Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview
- Negative prompt
- Styles
- Variations
- Seed resizing
- CLIP interrogator
- Prompt Editing
- Batch Processing
- Negative prompt, an extra text field that allows you to list what you don't want to see in generated image
- Styles, a way to save part of prompt and easily apply them via dropdown later
- Variations, a way to generate same image but with tiny differences
- Seed resizing, a way to generate same image but at slightly different resolution
- CLIP interrogator, a button that tries to guess prompt from an image
- Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
- Batch Processing, process a group of files using img2img
- Img2img Alternative
- Highres Fix
- LDSR Upscaling
- Highres Fix, a convenience option to produce high resolution pictures in one click without usual distortions
- Reloading checkpoints on the fly
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
......@@ -83,6 +97,9 @@ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusio
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
## Contributing
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
## Documentation
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
......
......@@ -359,7 +359,6 @@ Antanas Sutkus,0.7369492,black-white
Leonora Carrington,0.73726475,scribbles
Hieronymus Bosch,0.7369955,scribbles
A. J. Casson,0.73666203,scribbles
A.J.Casson,0.73666203,scribbles
Chaim Soutine,0.73662066,scribbles
Artur Bordalo,0.7364549,weird
Thomas Allom,0.68792284,fineart
......@@ -1907,7 +1906,6 @@ Alex Schomburg,0.46614102,digipa-low-impact
Bastien L. Deharme,0.583349,special
František Jakub Prokyš,0.58782333,fineart
Jesper Ejsing,0.58782053,fineart
Jesper Ejsing,0.58782053,fineart
Odd Nerdrum,0.53551745,digipa-high-impact
Tom Lovell,0.5877577,fineart
Ayami Kojima,0.5877416,fineart
......
......@@ -58,8 +58,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Process an image, use it as an input, repeat.",
......
......@@ -30,6 +30,7 @@ function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_inte
onUiUpdate(function(){
check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery')
check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery')
check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', 'ti_interrupt', 'ti_preview', 'ti_gallery')
})
function requestMoreProgress(id_part, id_progressbar_span, id_interrupt){
......
function start_training_textual_inversion(){
requestProgress('ti')
gradioApp().querySelector('#ti_error').innerHTML=''
return args_to_array(arguments)
}
......@@ -186,10 +186,12 @@ onUiUpdate(function(){
if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
}
if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
}
})
......@@ -197,6 +199,14 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;
function submit_prompt(event, generate_button_id) {
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
}
function update_token_counter(button_id) {
if (token_timeout)
clearTimeout(token_timeout);
......
# this scripts installs necessary requirements and launches main program in webui.py
import subprocess
import os
import sys
......@@ -19,10 +18,9 @@ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/Tencen
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args)
......@@ -120,8 +118,6 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
......@@ -130,6 +126,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
sys.argv += args
if "--exit" in args:
print("Exiting because of --exit argument")
exit(0)
def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
......
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import shared, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
class UpscalerBSRGAN(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "BSRGAN"
self.model_path = os.path.join(models_path, self.name)
self.model_name = "BSRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
model.to(shared.device)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
return None
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
return model
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.sf==4:
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
if self.sf==4:
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
\ No newline at end of file
......@@ -5,31 +5,31 @@ import traceback
import cv2
import torch
from modules import shared, devices
from modules.paths import script_path
import modules.shared
import modules.face_restoration
from importlib import reload
import modules.shared
from modules import shared, devices, modelloader
from modules.paths import script_path, models_path
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
# codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False
codeformer = None
def setup_codeformer():
def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
path = modules.paths.paths.get("CodeFormer", None)
if path is None:
return
# both GFPGAN and CodeFormer use bascisr, one has it installed from pip the other uses its own
#stored_sys_path = sys.path
#sys.path = [path] + sys.path
try:
from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer
......@@ -44,18 +44,23 @@ def setup_codeformer():
def name(self):
return "CodeFormer"
def __init__(self):
def __init__(self, dirname):
self.net = None
self.face_helper = None
self.cmd_dir = dirname
def create_models(self):
if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer)
return self.net, self.face_helper
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
print("Unable to load codeformer model.")
return None, None
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint)
net.eval()
......@@ -74,6 +79,9 @@ def setup_codeformer():
original_resolution = np_image.shape[0:2]
self.create_models()
if self.net is None or self.face_helper is None:
return np_image
self.face_helper.clean_all()
self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
......@@ -114,7 +122,7 @@ def setup_codeformer():
have_codeformer = True
global codeformer
codeformer = FaceRestorerCodeFormer()
codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer)
except Exception:
......
......@@ -32,10 +32,9 @@ def enable_tf32():
errors.run(enable_tf32, "Enabling TF32")
device = get_optimal_device()
device_codeformer = cpu if has_mps else device
dtype = torch.float16
def randn(seed, shape):
# Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
......
import os
import sys
import traceback
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch
from modules import shared
from modules.shared import opts
from modules import shared, modelloader, images
from modules.devices import has_mps
import modules.images
from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
def load_model(filename):
def fix_model_layers(crt_model, pretrained_net):
# this code is adapted from https://github.com/xinntao/ESRGAN
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
if 'conv_first.weight' in pretrained_net:
crt_model.load_state_dict(pretrained_net)
return crt_model
return pretrained_net
if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
......@@ -72,9 +68,59 @@ def load_model(filename):
crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
crt_model.load_state_dict(crt_net)
crt_model.eval()
return crt_model
return crt_net
class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
self.model_name = "ESRGAN 4x"
self.scalers = []
self.user_path = dirname
self.model_path = os.path.join(models_path, self.name)
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
scaler_data = UpscalerData(name, file, self, 4)
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
model = self.load_model(selected_model)
if model is None:
return img
model.to(shared.device)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="%s.pth" % self.model_name,
progress=True)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_path, filename))
return None
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
crt_model.load_state_dict(pretrained_net)
crt_model.eval()
return crt_model
def upscale_without_tiling(model, img):
img = np.array(img)
......@@ -95,7 +141,7 @@ def esrgan_upscale(model, img):
if opts.ESRGAN_tile == 0:
return upscale_without_tiling(model, img)
grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
newtiles = []
scale_factor = 1
......@@ -110,32 +156,6 @@ def esrgan_upscale(model, img):
newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow])
newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = modules.images.combine_grid(newgrid)
newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = images.combine_grid(newgrid)
return output
class UpscalerESRGAN(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
self.model = load_model(filename)
def do_upscale(self, img):
model = self.model.to(shared.device)
img = esrgan_upscale(model, img)
return img
def load_models(dirname):
for file in os.listdir(dirname):
path = os.path.join(dirname, file)
model_name, extension = os.path.splitext(file)
if extension != '.pt' and extension != '.pth':
continue
try:
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
except Exception:
print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
......@@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
outputs = []
for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {}
image = image.convert("RGB")
......@@ -74,7 +76,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
c = cached_images.get(key)
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.upscale(image, image.width * resize, image.height * resize)
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
cached_images[key] = c
return c
......@@ -189,9 +191,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
if save_as_half:
theta_0[key] = theta_0[key].half()
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename)
output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname)
......
import os
import sys
import traceback
from glob import glob
from modules import shared, devices
from modules.shared import cmd_opts
from modules.paths import script_path
import modules.face_restoration
def gfpgan_model_path():
from modules.shared import cmd_opts
filemask = 'GFPGAN*.pth'
if cmd_opts.gfpgan_model is not None:
return cmd_opts.gfpgan_model
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
filename = None
for place in places:
filename = next(iter(glob(os.path.join(place, filemask))), None)
if filename is not None:
break
return filename
import facexlib
import gfpgan
import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path
model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None
def gfpgan():
def gfpgann():
global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device)
return loaded_gfpgan_model
......@@ -41,7 +27,16 @@ def gfpgan():
if gfpgan_constructor is None:
return None
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
if len(models) == 1 and "http" in models[0]:
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
model_file = latest_file
else:
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device)
loaded_gfpgan_model = model
......@@ -49,8 +44,9 @@ def gfpgan():
def gfpgan_fix_faces(np_image):
model = gfpgan()
model = gfpgann()
if model is None:
return np_image
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
......@@ -61,21 +57,39 @@ def gfpgan_fix_faces(np_image):
return np_image
have_gfpgan = False
gfpgan_constructor = None
def setup_gfpgan():
try:
gfpgan_model_path()
if os.path.exists(cmd_opts.gfpgan_dir):
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
from gfpgan import GFPGANer
def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
try:
from gfpgan import GFPGANer
from facexlib import detection, parsing
global user_path
global have_gfpgan
have_gfpgan = True
global gfpgan_constructor
load_file_from_url_orig = gfpgan.utils.load_file_from_url
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs):
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
def facex_load_file_from_url(**kwargs):
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
def facex_load_file_from_url2(**kwargs):
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
gfpgan.utils.load_file_from_url = my_load_file_from_url
facexlib.detection.load_file_from_url = facex_load_file_from_url
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
user_path = dirname
have_gfpgan = True
gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
......@@ -84,7 +98,7 @@ def setup_gfpgan():
def restore(self, np_image):
np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
cropped_faces, restored_faces, gfpgan_output_bgr = gfpgann().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1]
return np_image
......
......@@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto
import string
import modules.shared
from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts
......@@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
cols = math.ceil((w - overlap) / non_overlap_width)
rows = math.ceil((h - overlap) / non_overlap_height)
dx = (w - tile_w) / (cols-1) if cols > 1 else 0
dy = (h - tile_h) / (rows-1) if rows > 1 else 0
dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows):
......@@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
for col in range(cols):
x = int(col * dx)
if x+tile_w >= w:
if x + tile_w >= w:
x = w - tile_w
tile = image.crop((x, y, x + tile_w, y + tile_h))
......@@ -132,7 +131,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active:
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing
......@@ -171,7 +170,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
ver_texts]
pad_top = max(hor_text_heights) + line_spacing * 2
......@@ -213,8 +213,19 @@ def resize_image(resize_mode, im, width, height):
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS)
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
return upscaler.upscale(im, w, h)
scale = max(w / im.width, h / im.height)
if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
if im.width != w or im.height != h:
im = im.resize((w, h), resample=LANCZOS)
return im
if resize_mode == 0:
res = resize(im, width, height)
......@@ -256,7 +267,7 @@ def resize_image(resize_mode, im, width, height):
invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s'+string.punctuation+']+')
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
max_filename_part_length = 128
......@@ -278,6 +289,16 @@ def apply_filename_pattern(x, p, seed, prompt):
if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x:
prompt_no_style = prompt
for style in shared.prompt_styles.get_style_prompts(p.styles):
if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")]
for part in style_parts:
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
if "[prompt_words]" in x:
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
......@@ -290,10 +311,12 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height))
#currently disabled if using the save button, will work otherwise
# if enabled it will cause a bug because styles is not included in the save_files data dictionary
if hasattr(p, "styles"):
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
......@@ -306,6 +329,7 @@ def apply_filename_pattern(x, p, seed, prompt):
return x
def get_next_sequence_number(path, basename):
"""
Determines and returns the next sequence number to use when saving an image in the specified directory.
......@@ -319,7 +343,7 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename)
for p in os.listdir(path):
if p.startswith(basename):
l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try:
result = max(int(l[0]), result)
except ValueError:
......@@ -327,6 +351,7 @@ def get_next_sequence_number(path, basename):
return result + 1
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None:
file_decoration = ""
......@@ -364,7 +389,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
fullfn = "a.png"
fullfn_without_extension = "a"
for i in range(500):
fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn):
......@@ -406,31 +431,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
file.write(info + "\n")
class Upscaler:
name = "Lanczos"
def do_upscale(self, img):
return img
def upscale(self, img, w, h):
for i in range(3):
if img.width >= w and img.height >= h:
break
img = self.do_upscale(img)
if img.width != w or img.height != h:
img = img.resize((int(w), int(h)), resample=LANCZOS)
return img
class UpscalerNone(Upscaler):
name = "None"
def upscale(self, img, w, h):
return img
modules.shared.sd_upscalers.append(UpscalerNone())
modules.shared.sd_upscalers.append(Upscaler())
import os
import sys
import traceback
from collections import namedtuple
from basicsr.utils.download_util import load_file_from_url
import modules.images
from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR
from modules import shared
from modules.paths import script_path
from modules.paths import models_path
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
ldsr_models = []
have_ldsr = False
LDSR_obj = None
class UpscalerLDSR(modules.images.Upscaler):
def __init__(self, steps):
self.steps = steps
class UpscalerLDSR(Upscaler):
def __init__(self, user_path):
self.name = "LDSR"
def do_upscale(self, img):
return upscale_with_ldsr(img)
def add_lsdr():
modules.shared.sd_upscalers.append(UpscalerLDSR(100))
def setup_ldsr():
path = modules.paths.paths.get("LDSR", None)
if path is None:
return
global have_ldsr
global LDSR_obj
try:
from LDSR import LDSR
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
repo_path = 'latent-diffusion/experiments/pretrained_models/'
model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path),
progress=True, file_name="model.chkpt")
yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path),
progress=True, file_name="project.yaml")
have_ldsr = True
LDSR_obj = LDSR(model_path, yaml_path)
except Exception:
print("Error importing LDSR:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
have_ldsr = False
def upscale_with_ldsr(image):
setup_ldsr()
if not have_ldsr or LDSR_obj is None:
return image
ddim_steps = shared.opts.ldsr_steps
pre_scale = shared.opts.ldsr_pre_down
post_scale = shared.opts.ldsr_post_down
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
return image
self.model_path = os.path.join(models_path, self.name)
self.user_path = user_path
self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
super().__init__()
scaler_data = UpscalerData("LDSR", None, self)
self.scalers = [scaler_data]
def load_model(self, path: str):
# Remove incorrect project.yaml file if too big
yaml_path = os.path.join(self.model_path, "project.yaml")
old_model_path = os.path.join(self.model_path, "model.pth")
new_model_path = os.path.join(self.model_path, "model.ckpt")
if os.path.exists(yaml_path):
statinfo = os.stat(yaml_path)
if statinfo.st_size >= 10485760:
print("Removing invalid LDSR YAML file.")
os.remove(yaml_path)
if os.path.exists(old_model_path):
print("Renaming model from model.pth to model.ckpt")
os.rename(old_model_path, new_model_path)
model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="model.ckpt", progress=True)
yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
file_name="project.yaml", progress=True)
try:
return LDSR(model, yaml)
except Exception:
print("Error importing LDSR:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
def do_upscale(self, img, path):
ldsr = self.load_model(path)
if ldsr is None:
print("NO LDSR!")
return img
ddim_steps = shared.opts.ldsr_steps
return ldsr.super_resolution(img, ddim_steps, self.scale)
import gc
import time
import warnings
import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
warnings.filterwarnings("ignore", category=UserWarning)
# Create LDSR Class
class LDSR:
def load_model_from_config(self, half_attention):
print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"]
config = OmegaConf.load(self.yamlPath)
model = instantiate_from_config(config.model)
model.load_state_dict(sd, strict=False)
model.cuda()
if half_attention:
model = model.half()
model.eval()
return {"model": model}
def __init__(self, model_path, yaml_path):
self.modelPath = model_path
self.yamlPath = yaml_path
@staticmethod
def run(model, selected_path, custom_steps, eta):
example = get_cond(selected_path)
n_runs = 1
guider = None
ckwargs = None
ddim_use_x0_pred = False
temperature = 1.
eta = eta
custom_shape = None
height, width = example["image"].shape[1:3]
split_input = height >= 128 and width >= 128
if split_input:
ks = 128
stride = 64
vqf = 4 #
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
x_t = None
logs = None
for n in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
logs = make_convolutional_sample(example, model,
custom_steps=custom_steps,
eta=eta, quantize_x0=False,
custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
ddim_use_x0_pred=ddim_use_x0_pred
)
return logs
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
model = self.load_model_from_config(half_attention)
# Run settings
diffusion_steps = int(steps)
eta = 1.0
down_sample_method = 'Lanczos'
gc.collect()
torch.cuda.empty_cache()
im_og = image
width_og, height_og = im_og.size
# If we can adjust the max upscale size, then the 4 below should be our variable
down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
width_downsampled_pre = int(wd)
height_downsampled_pre = int(hd)
if down_sample_rate != 1:
print(
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
logs = self.run(model["model"], im_og, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
sample = torch.clamp(sample, -1., 1.)
sample = (sample + 1.) / 2. * 255
sample = sample.numpy().astype(np.uint8)
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
del model
gc.collect()
torch.cuda.empty_cache()
return a
def get_cond(selected_path):
example = dict()
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
antialias=True)
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(torch.device("cuda"))
example["LR_image"] = c
example["image"] = c_up
return example
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
corrector_kwargs=None, x_t=None
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
mask=mask, x0=x0, temperature=temperature, verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, x_t=x_t)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = dict()
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, 'split_input_params')
and model.cond_stage_key == 'coordinates_bbox'),
return_original_cond=True)
if custom_shape is not None:
z = torch.randn(custom_shape)
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
z0 = None
log["input"] = x
log["reconstruction"] = xrec
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, 'cond_stage_key'):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key == 'class_label':
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
eta=eta,
quantize_x0=quantize_x0, mask=None, x0=z0,
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
x_t=x_T)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates['pred_x0'][-1]
x_sample = model.decode_first_stage(sample)
try:
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except:
pass
log["sample"] = x_sample
log["time"] = t1 - t0
return log
import glob
import os
import shutil
import importlib
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@param download_name: Specify to download from model_url immediately.
@param model_url: If no other models are found, this will be downloaded on upscale.
@param model_path: The location to store/find models in.
@param command_path: A command-line argument to search for models in first.
@param ext_filter: An optional list of filename extensions to filter by
@return: A list of paths containing the desired model(s)
"""
output = []
if ext_filter is None:
ext_filter = []
try:
places = []
if command_path is not None and command_path != model_path:
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
if os.path.exists(pretrained_path):
print(f"Appending path: {pretrained_path}")
places.append(pretrained_path)
elif os.path.exists(command_path):
places.append(command_path)
places.append(model_path)
for place in places:
if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True):
full_path = file
if os.path.isdir(full_path):
continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
continue
if file not in output:
output.append(full_path)
if model_url is not None and len(output) == 0:
if download_name is not None:
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:
output.append(model_url)
except Exception:
pass
return output
def friendly_name(file: str):
if "http" in file:
file = urlparse(file).path
file = os.path.basename(file)
model_name, extension = os.path.splitext(file)
return model_name
def cleanup_models():
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
# somehow auto-register and just do these things...
root_path = script_path
src_path = models_path
dest_path = os.path.join(models_path, "Stable-diffusion")
move_files(src_path, dest_path, ".ckpt")
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "SwinIR")
dest_path = os.path.join(models_path, "SwinIR")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
dest_path = os.path.join(models_path, "LDSR")
move_files(src_path, dest_path)
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try:
if not os.path.exists(dest_path):
os.makedirs(dest_path)
if os.path.exists(src_path):
for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file)
if os.path.isfile(fullpath):
if ext_filter is not None:
if ext_filter not in file:
continue
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
except:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
except:
pass
def load_upscalers():
datas = []
for cls in Upscaler.__subclasses__():
name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
opt_string = None
try:
opt_string = shared.opts.__getattr__(cmd_name)
except:
pass
scaler = class_(opt_string)
for child in scaler.scalers:
datas.append(child)
shared.sd_upscalers = datas
......@@ -3,9 +3,10 @@ import os
import sys
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
sys.path.insert(0, script_path)
# search for directory of stable diffsuion in following palces
# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
......@@ -15,21 +16,24 @@ for possible_sd_path in possible_sd_paths:
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion'),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers'),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer'),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP'),
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR'),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion'),
(sd_path, 'ldm', 'Stable Diffusion', []),
(os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
(os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
(os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
(os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
(os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]
paths = {}
for d, must_exist, what in path_dirs:
for d, must_exist, what, options in path_dirs:
must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
if not os.path.exists(must_exist_path):
print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
else:
d = os.path.abspath(d)
sys.path.append(d)
if "atstart" in options:
sys.path.insert(0, d)
else:
sys.path.append(d)
paths[what] = d
......@@ -56,7 +56,7 @@ class StableDiffusionProcessing:
self.prompt: str = prompt
self.prompt_for_display: str = None
self.negative_prompt: str = (negative_prompt or "")
self.styles: str = styles
self.styles: list = styles or []
self.seed: int = seed
self.subseed: int = subseed
self.subseed_strength: float = subseed_strength
......@@ -79,7 +79,7 @@ class StableDiffusionProcessing:
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn
self.s_tmin = opts.s_tmin
......@@ -130,7 +130,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0])
......@@ -271,7 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
"Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
}
generation_params.update(p.extra_generation_params)
......@@ -295,8 +295,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
fix_seed(p)
os.makedirs(p.outpath_samples, exist_ok=True)
os.makedirs(p.outpath_grids, exist_ok=True)
if p.outpath_samples is not None:
os.makedirs(p.outpath_samples, exist_ok=True)
if p.outpath_grids is not None:
os.makedirs(p.outpath_grids, exist_ok=True)
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
......@@ -323,7 +326,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
output_images = []
......@@ -492,8 +495,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
x_sample = x_sample.astype(np.uint8)
image = Image.fromarray(x_sample)
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
image = upscaler.upscale(image, self.width, self.height)
image = images.resize_image(0, image, self.width, self.height)
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
batch_images.append(image)
......
import os
import sys
import traceback
from collections import namedtuple
import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
import modules.images
from modules.upscaler import Upscaler, UpscalerData
from modules.paths import models_path
from modules.shared import cmd_opts, opts
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
realesrgan_models = []
have_realesrgan = False
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
self.name = "RealESRGAN"
self.model_path = os.path.join(models_path, self.name)
self.user_path = path
super().__init__()
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
self.enable = True
self.scalers = []
scalers = self.load_models(path)
for scaler in scalers:
if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler)
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.enable = False
self.scalers = []
def do_upscale(self, img, path):
if not self.enable:
return img
info = self.load_model(path)
if not os.path.exists(info.data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
model_path=info.data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
image = Image.fromarray(upsampled)
return image
def load_model(self, path):
try:
info = None
for scaler in self.scalers:
if scaler.data_path == path:
info = scaler
if info is None:
print(f"Unable to find model info: {path}")
return None
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
info.data_path = model_file
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
def get_realesrgan_models():
def load_models(self, _):
return get_realesrgan_models(self)
def get_realesrgan_models(scaler):
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
models = [
RealesrganModelInfo(
name="Real-ESRGAN General x4x3",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
netscale=4,
UpscalerData(
name="R-ESRGAN General 4xV3",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
scale=4,
upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
),
RealesrganModelInfo(
name="Real-ESRGAN General WDN x4x3",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
netscale=4,
UpscalerData(
name="R-ESRGAN General WDN 4xV3",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
scale=4,
upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
),
RealesrganModelInfo(
name="Real-ESRGAN AnimeVideo",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
netscale=4,
UpscalerData(
name="R-ESRGAN AnimeVideo",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
scale=4,
upscaler=scaler,
model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
),
RealesrganModelInfo(
name="Real-ESRGAN 4x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
netscale=4,
UpscalerData(
name="R-ESRGAN 4x+",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
scale=4,
upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
),
RealesrganModelInfo(
name="Real-ESRGAN 4x plus anime 6B",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
netscale=4,
UpscalerData(
name="R-ESRGAN 4x+ Anime6B",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
scale=4,
upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
),
RealesrganModelInfo(
name="Real-ESRGAN 2x plus",
location="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
netscale=2,
UpscalerData(
name="R-ESRGAN 2x+",
path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
scale=2,
upscaler=scaler,
model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
),
]
return models
except Exception as e:
print("Error makeing Real-ESRGAN midels list:", file=sys.stderr)
print("Error making Real-ESRGAN models list:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
class UpscalerRealESRGAN(modules.images.Upscaler):
def __init__(self, upscaling, model_index):
self.upscaling = upscaling
self.model_index = model_index
self.name = realesrgan_models[model_index].name
def do_upscale(self, img):
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
def setup_realesrgan():
global realesrgan_models
global have_realesrgan
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
realesrgan_models = get_realesrgan_models()
have_realesrgan = True
for i, model in enumerate(realesrgan_models):
if model.name in opts.realesrgan_enabled_models:
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
have_realesrgan = False
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
if not have_realesrgan:
return image
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,
model=model,
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
image = Image.fromarray(upsampled)
return image
......@@ -6,253 +6,51 @@ import torch
import numpy as np
from torch import einsum
from modules import prompt_parser
import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules.shared import opts, device, cmd_opts
from ldm.util import default
from einops import rearrange
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
def apply_optimizations():
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
def undo_optimizations():
ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
class StableDiffusionModelHijack:
ids_lookup = {}
word_embeddings = {}
word_embeddings_checksums = {}
fixes = None
comments = []
dir_mtime = None
layers = None
circular_enabled = False
clip = None
def load_textual_inversion_embeddings(self, dirname, model):
mt = os.path.getmtime(dirname)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
tokenizer = model.cond_stage_model.tokenizer
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
self.word_embeddings[name] = emb.detach().to(device)
self.word_embeddings_checksums[name] = f'{const_hash(emb.reshape(-1)*100)&0xffff:04x}'
ids = tokenizer([name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, name))
for fn in os.listdir(dirname):
try:
process_file(os.path.join(dirname, fn), fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
def hijack(self, m):
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
self.clip = m.cond_stage_model
if cmd_opts.opt_split_attention_v1:
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
ldm.modules.diffusionmodules.model.nonlinearity = nonlinearity_hijack
ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
apply_optimizations()
def flatten(el):
flattened = [flatten(children) for children in el.children()]
......@@ -263,6 +61,14 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
def apply_circular(self, enable):
if self.circular_enabled == enable:
return
......@@ -282,7 +88,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack):
super().__init__()
self.wrapped = wrapped
self.hijack = hijack
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.max_length = wrapped.max_length
self.token_mults = {}
......@@ -303,7 +109,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0:
self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
......@@ -325,28 +130,19 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
if possible_matches is None:
if embedding is None:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
else:
found = False
for ids, word in possible_matches:
if tokens[i:i + len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
......@@ -417,32 +213,23 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
if mult_change is not None:
mult *= mult_change
elif possible_matches is None:
i += 1
elif embedding is None:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
else:
found = False
for ids, word in possible_matches:
if tokens[i:i+len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(mult)
i += 1
emb_len = int(embedding.vec.shape[0])
fixes.append((len(remade_tokens), embedding))
remade_tokens += [0] * emb_len
multipliers += [mult] * emb_len
used_custom_terms.append((embedding.name, embedding.checksum()))
i += emb_len
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
......@@ -450,6 +237,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
......@@ -470,7 +258,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
self.hijack.fixes = hijack_fixes
self.hijack.comments = hijack_comments
......@@ -503,14 +290,19 @@ class EmbeddingsWithFixes(torch.nn.Module):
inputs_embeds = self.wrapped(input_ids)
if batch_fixes is not None:
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, word in fixes:
emb = self.embeddings.word_embeddings[word]
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor[offset+1:offset+1+emb_len] = self.embeddings.word_embeddings[word][0:emb_len]
if batch_fixes is None or len(batch_fixes) == 0 or max([len(x) for x in batch_fixes]) == 0:
return inputs_embeds
vecs = []
for fixes, tensor in zip(batch_fixes, inputs_embeds):
for offset, embedding in fixes:
emb = embedding.vec
emb_len = min(tensor.shape[0]-offset-1, emb.shape[0])
tensor = torch.cat([tensor[0:offset+1], emb[0:emb_len], tensor[offset+1+emb_len:]])
vecs.append(tensor)
return inputs_embeds
return torch.stack(vecs)
def add_circular_option_to_conv_2d():
......
import math
import torch
from torch import einsum
from ldm.util import default
from einops import rearrange
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
for i in range(0, q.shape[0], 2):
end = i + 2
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
# taken from https://github.com/Doggettx/stable-diffusion
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context) * self.scale
v_in = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1
if mem_required > mem_free_total:
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
def nonlinearity_hijack(x):
# swish
t = torch.sigmoid(x)
x *= t
del t
return x
def cross_attention_attnblock_forward(self, x):
h_ = x
h_ = self.norm(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q1.shape
q2 = q1.reshape(b, c, h*w)
del q1
q = q2.permute(0, 2, 1) # b,hw,c
del q2
k = k1.reshape(b, c, h*w) # b,c,hw
del k1
h_ = torch.zeros_like(k, device=q.device)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4
h2 = h_.reshape(b, c, h, w)
del h_
h3 = self.proj_out(h2)
del h2
h3 += x
return h3
......@@ -8,7 +8,14 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from modules import shared
from modules import shared, modelloader, devices
from modules.paths import models_path
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
......@@ -23,20 +30,30 @@ except Exception:
pass
def setup_model(dirname):
global user_dir
user_dir = dirname
if not os.path.exists(model_path):
os.makedirs(model_path)
checkpoints_list.clear()
list_models()
def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()])
def list_models():
checkpoints_list.clear()
model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)
def modeltitle(path, h):
def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
if abspath.startswith(model_dir):
name = abspath.replace(model_dir, '')
if user_dir is not None and abspath.startswith(user_dir):
name = abspath.replace(user_dir, '')
elif abspath.startswith(model_path):
name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)
......@@ -45,21 +62,27 @@ def list_models():
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
return f'{name} [{h}]', shortname
return f'{name} [{shorthash}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title, model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
title, short_model_name = modeltitle(cmd_ckpt, h)
checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
shared.opts.sd_model_checkpoint = title
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
title, short_model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
if os.path.exists(model_dir):
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename)
title, model_name = modeltitle(filename, h)
checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
def get_closet_checkpoint_match(searchString):
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
if len(applicable) > 0:
return applicable[0]
return None
def model_hash(filename):
......@@ -111,6 +134,8 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):
if not shared.cmd_opts.no_half:
model.half()
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
model.sd_model_hash = sd_model_hash
model.sd_model_checkpint = checkpoint_file
......@@ -137,7 +162,7 @@ def load_model():
def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices
from modules import lowvram, devices, sd_hijack
checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename:
......@@ -148,8 +173,12 @@ def reload_model_weights(sd_model, info=None):
else:
sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(sd_model)
load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
sd_hijack.model_hijack.hijack(sd_model)
if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
sd_model.to(devices.device)
......
......@@ -4,7 +4,6 @@ import torch
import tqdm
from PIL import Image
import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
......@@ -23,6 +22,8 @@ samplers_k_diffusion = [
('Heun', 'sample_heun', ['k_heun']),
('DPM2', 'sample_dpm_2', ['k_dpm_2']),
('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
]
samplers_data_k_diffusion = [
......@@ -36,7 +37,7 @@ samplers = [
SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]
sampler_extra_params = {
'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
......@@ -289,7 +290,10 @@ class KDiffusionSampler:
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
steps, t_enc = setup_img2img_steps(p, steps)
sigmas = self.model_wrap.get_sigmas(steps)
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
else:
sigmas = self.model_wrap.get_sigmas(steps)
noise = noise * sigmas[steps - t_enc - 1]
xi = x + noise
......@@ -305,12 +309,20 @@ class KDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
sigmas = self.model_wrap.get_sigmas(steps)
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
else:
sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0]
extra_params_kwargs = self.initialize(p)
samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
if 'sigma_min' in inspect.signature(self.func).parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
if 'n' in inspect.signature(self.func).parameters:
extra_params_kwargs['n'] = steps
else:
extra_params_kwargs['sigmas'] = sigmas
samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples
import sys
import argparse
import datetime
import json
import os
import sys
import gradio as gr
import tqdm
import datetime
import modules.artists
from modules.paths import script_path, sd_path
from modules.devices import get_optimal_device
import modules.styles
import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
from modules.devices import get_optimal_device
from modules.paths import script_path, sd_path
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
......@@ -34,8 +35,13 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR'))
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
......@@ -53,7 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
cmd_opts = parser.parse_args()
device = get_optimal_device()
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
......@@ -61,6 +66,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
config_filename = cmd_opts.ui_settings_file
class State:
interrupted = False
job = ""
......@@ -72,6 +78,7 @@ class State:
current_latent = None
current_image = None
current_image_sampling_step = 0
textinfo = None
def interrupt(self):
self.interrupted = True
......@@ -82,7 +89,7 @@ class State:
self.current_image_sampling_step = 0
def get_job_timestamp(self):
return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp?
state = State()
......@@ -95,13 +102,13 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")
face_restorers = []
modules.sd_models.list_models()
# This was moved to webui.py with the other model "setup" calls.
# modules.sd_models.list_models()
def realesrgan_models_names():
import modules.realesrgan_model
return [x.name for x in modules.realesrgan_model.get_realesrgan_models()]
return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
class OptionInfo:
......@@ -167,13 +174,10 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), {
"ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
"ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
"SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
"SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
"ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
"ldsr_pre_down": OptionInfo(1, "LDSR Pre-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
"ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
"upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
......@@ -190,9 +194,9 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
"enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
"enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
......
......@@ -53,6 +53,12 @@ class StyleDatabase:
negative_prompt = row.get("negative_prompt", "")
self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)
def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
def get_negative_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).negative_prompt for x in styles]
def apply_styles_to_prompt(self, prompt, styles):
return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
......
import sys
import traceback
import cv2
import os
import contextlib
import numpy as np
from PIL import Image
import torch
import modules.images
from modules.shared import cmd_opts, opts, device
from modules.swinir_arch import SwinIR as net
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
def load_model(filename, scale=4):
model = net(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="3conv",
)
pretrained_model = torch.load(filename)
model.load_state_dict(pretrained_model["params_ema"], strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
def load_models(dirname):
for file in os.listdir(dirname):
path = os.path.join(dirname, file)
model_name, extension = os.path.splitext(file)
if extension != ".pt" and extension != ".pth":
continue
try:
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name))
except Exception:
print(f"Error loading SwinIR model: {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def upscale(
img,
model,
tile=opts.SWIN_tile,
tile_overlap=opts.SWIN_tile_overlap,
window_size=8,
scale=4,
):
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = inference(img, model, tile, tile_overlap, window_size, scale)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(
output[[2, 1, 0], :, :], (1, 2, 0)
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return Image.fromarray(output, "RGB")
def inference(img, model, tile, tile_overlap, window_size, scale):
# test the image tile by tile
b, c, h, w = img.size()
tile = min(tile, h, w)
assert tile % window_size == 0, "tile size should be a multiple of window_size"
sf = scale
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device)
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf
].add_(out_patch_mask)
output = E.div_(W)
return output
class UpscalerSwin(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
self.model = load_model(filename)
def do_upscale(self, img):
model = self.model.to(device)
img = upscale(img, model)
return img
import contextlib
import os
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm
from modules import modelloader
from modules.paths import models_path
from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
from modules.upscaler import Upscaler, UpscalerData
precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
class UpscalerSwinIR(Upscaler):
def __init__(self, dirname):
self.name = "SwinIR"
self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
"/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
"-L_x4_GAN.pth "
self.model_name = "SwinIR 4x"
self.model_path = os.path.join(models_path, self.name)
self.user_path = dirname
super().__init__()
scalers = []
model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
if "http" in model:
name = self.model_name
else:
name = modelloader.friendly_name(model)
model_data = UpscalerData(name, model, self)
scalers.append(model_data)
self.scalers = scalers
def do_upscale(self, img, model_file):
model = self.load_model(model_file)
if model is None:
return img
model = model.to(device)
img = upscale(img, model)
try:
torch.cuda.empty_cache()
except:
pass
return img
def load_model(self, path, scale=4):
if "http" in path:
dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
else:
filename = path
if filename is None or not os.path.exists(filename):
return None
model = net(
upscale=scale,
in_chans=3,
img_size=64,
window_size=8,
img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=240,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2,
upsampler="nearest+conv",
resi_connection="3conv",
)
pretrained_model = torch.load(filename)
model.load_state_dict(pretrained_model["params_ema"], strict=True)
if not cmd_opts.no_half:
model = model.half()
return model
def upscale(
img,
model,
tile=opts.SWIN_tile,
tile_overlap=opts.SWIN_tile_overlap,
window_size=8,
scale=4,
):
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(device)
with torch.no_grad(), precision_scope("cuda"):
_, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
output = inference(img, model, tile, tile_overlap, window_size, scale)
output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(
output[[2, 1, 0], :, :], (1, 2, 0)
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return Image.fromarray(output, "RGB")
def inference(img, model, tile, tile_overlap, window_size, scale):
# test the image tile by tile
b, c, h, w = img.size()
tile = min(tile, h, w)
assert tile % window_size == 0, "tile size should be a multiple of window_size"
sf = scale
stride = tile - tile_overlap
h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
W = torch.zeros_like(E, dtype=torch.half, device=device)
with tqdm(total=len(h_idx_list) * len(w_idx_list), desc="SwinIR tiles") as pbar:
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
pbar.update(1)
output = E.div_(W)
return output
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(nn.Module):
r""" SwinIR
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0],
patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
if self.upscale == 4:
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, :H*self.upscale, :W*self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
model = SwinIR(upscale=2, img_size=(height, width),
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
# -----------------------------------------------------------------------------------
# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
# Originally Written by Ze Liu, Modified by Jingyun Liang.
# -----------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
attn_mask = self.calculate_mask(self.input_resolution)
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
H, W = x_size
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
def forward(self, x, x_size):
H, W = x_size
B, L, C = x.shape
# assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
if self.input_resolution == x_size:
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
else:
attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x, x_size):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, x_size)
else:
x = blk(x, x_size)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class RSTB(nn.Module):
"""Residual Swin Transformer Block (RSTB).
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
img_size: Input image size.
patch_size: Patch size.
resi_connection: The convolutional block before residual connection.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False,
img_size=224, patch_size=4, resi_connection='1conv'):
super(RSTB, self).__init__()
self.dim = dim
self.input_resolution = input_resolution
self.residual_group = BasicLayer(dim=dim,
input_resolution=input_resolution,
depth=depth,
num_heads=num_heads,
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
use_checkpoint=use_checkpoint)
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(dim // 4, dim, 3, 1, 1))
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
norm_layer=None)
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
def flops(self):
flops = 0
flops += self.residual_group.flops()
H, W = self.input_resolution
flops += H * W * self.dim * self.dim * 9
flops += self.patch_embed.flops()
flops += self.patch_unembed.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
flops = 0
H, W = self.img_size
if self.norm is not None:
flops += H * W * self.embed_dim
return flops
class PatchUnEmbed(nn.Module):
r""" Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
def forward(self, x, x_size):
B, HW, C = x.shape
x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
return x
def flops(self):
flops = 0
return flops
class Upsample(nn.Sequential):
"""Upsample module.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat):
m = []
if (scale & (scale - 1)) == 0: # scale = 2^n
for _ in range(int(math.log(scale, 2))):
m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(2))
elif scale == 3:
m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
m.append(nn.PixelShuffle(3))
else:
raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
super(Upsample, self).__init__(*m)
class UpsampleOneStep(nn.Sequential):
"""UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
Used in lightweight SR to save parameters.
Args:
scale (int): Scale factor. Supported scales: 2^n and 3.
num_feat (int): Channel number of intermediate features.
"""
def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
self.num_feat = num_feat
self.input_resolution = input_resolution
m = []
m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
m.append(nn.PixelShuffle(scale))
super(UpsampleOneStep, self).__init__(*m)
def flops(self):
H, W = self.input_resolution
flops = H * W * self.num_feat * 3 * 9
return flops
class SwinIR(nn.Module):
r""" SwinIR
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
Args:
img_size (int | tuple(int)): Input image size. Default 64
patch_size (int | tuple(int)): Patch size. Default: 1
in_chans (int): Number of input image channels. Default: 3
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
img_range: Image range. 1. or 255.
upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
def __init__(self, img_size=64, patch_size=1, in_chans=3,
embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv',
**kwargs):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
num_feat = 64
self.img_range = img_range
if in_chans == 3:
rgb_mean = (0.4488, 0.4371, 0.4040)
self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
else:
self.mean = torch.zeros(1, 1, 1, 1)
self.upscale = upscale
self.upsampler = upsampler
self.window_size = window_size
#####################################################################################################
################################### 1, shallow feature extraction ###################################
self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
#####################################################################################################
################################### 2, deep feature extraction ######################################
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = embed_dim
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# merge non-overlapping patches into image
self.patch_unembed = PatchUnEmbed(
img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build Residual Swin Transformer blocks (RSTB)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = RSTB(dim=embed_dim,
input_resolution=(patches_resolution[0],
patches_resolution[1]),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
resi_connection=resi_connection
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# build the last conv layer in deep feature extraction
if resi_connection == '1conv':
self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
#####################################################################################################
################################ 3, high quality image reconstruction ################################
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
(patches_resolution[0], patches_resolution[1]))
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
nn.LeakyReLU(inplace=True))
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
if self.upscale == 4:
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
else:
# for image denoising and JPEG compression artifact reduction
self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def check_image_size(self, x):
_, _, h, w = x.size()
mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
return x
def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x, x_size)
x = self.norm(x) # B L C
x = self.patch_unembed(x, x_size)
return x
def forward(self, x):
H, W = x.shape[2:]
x = self.check_image_size(x)
self.mean = self.mean.type_as(x)
x = (x - self.mean) * self.img_range
if self.upsampler == 'pixelshuffle':
# for classical SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.conv_last(self.upsample(x))
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.upsample(x)
elif self.upsampler == 'nearest+conv':
# for real-world SR
x = self.conv_first(x)
x = self.conv_after_body(self.forward_features(x)) + x
x = self.conv_before_upsample(x)
x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
if self.upscale == 4:
x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
x = self.conv_last(self.lrelu(self.conv_hr(x)))
else:
# for image denoising and JPEG compression artifact reduction
x_first = self.conv_first(x)
res = self.conv_after_body(self.forward_features(x_first)) + x_first
x = x + self.conv_last(res)
x = x / self.img_range + self.mean
return x[:, :, :H*self.upscale, :W*self.upscale]
def flops(self):
flops = 0
H, W = self.patches_resolution
flops += H * W * 3 * self.embed_dim * 9
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
flops += H * W * 3 * self.embed_dim * self.embed_dim
flops += self.upsample.flops()
return flops
if __name__ == '__main__':
upscale = 4
window_size = 8
height = (1024 // upscale // window_size + 1) * window_size
width = (720 // upscale // window_size + 1) * window_size
model = SwinIR(upscale=2, img_size=(height, width),
window_size=window_size, img_range=1., depths=[6, 6, 6, 6],
embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
print(model)
print(height, width, model.flops() / 1e9)
x = torch.randn((1, 3, height, width))
x = model(x)
print(x.shape)
import os
import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import random
import tqdm
class PersonalizedBase(Dataset):
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
self.placeholder_token = placeholder_token
self.size = size
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
self.dataset = []
with open(template_file, "r") as file:
lines = [x.strip() for x in file.readlines()]
self.lines = lines
assert data_root, 'dataset directory not specified'
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
image = Image.open(path)
image = image.convert('RGB')
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
filename_tokens = [token for token in filename_tokens if token.isalpha()]
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
torchdata = torch.moveaxis(torchdata, 2, 0)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
self.dataset.append((init_latent, filename_tokens))
self.length = len(self.dataset) * repeats
self.initial_indexes = np.arange(self.length) % len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def __len__(self):
return self.length
def __getitem__(self, i):
if i % len(self.dataset) == 0:
self.shuffle()
index = self.indexes[i % len(self.indexes)]
x, filename_tokens = self.dataset[index]
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens))
return x, text
import os
import sys
import traceback
import torch
import tqdm
import html
import datetime
from modules import shared, devices, sd_hijack, processing
import modules.textual_inversion.dataset
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
self.cached_checksum = None
def save(self, filename):
embedding_data = {
"string_to_token": {"*": 265},
"string_to_param": {"*": self.vec},
"name": self.name,
"step": self.step,
}
torch.save(embedding_data, filename)
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
def const_hash(a):
r = 0
for v in a:
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
return r
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
return self.cached_checksum
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
first_id = ids[0]
if first_id not in self.ids_lookup:
self.ids_lookup[first_id] = []
self.ids_lookup[first_id].append((ids, embedding))
return embedding
def load_textual_inversion_embeddings(self):
mt = os.path.getmtime(self.embeddings_dir)
if self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
param_dict = data['string_to_param']
if hasattr(param_dict, '_parameters'):
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
emb = next(iter(param_dict.items()))[1]
# diffuser concepts
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
emb = next(iter(data.values()))
if len(emb.shape) == 1:
emb = emb.unsqueeze(0)
else:
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
self.register_embedding(embedding, shared.sd_model)
for fn in os.listdir(self.embeddings_dir):
try:
fullfn = os.path.join(self.embeddings_dir, fn)
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
except Exception:
print(f"Error loading emedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
possible_matches = self.ids_lookup.get(token, None)
if possible_matches is None:
return None
for ids, embedding in possible_matches:
if tokens[offset:offset + len(ids)] == ids:
return embedding
return None
def create_embedding(name, num_vectors_per_token):
init_text = '*'
cond_model = shared.sd_model.cond_stage_model
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
for i in range(num_vectors_per_token):
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
embedding.save(fn)
return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%d-%m"), embedding_name)
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
os.makedirs(embedding_dir, exist_ok=True)
else:
embedding_dir = None
if create_image_every > 0:
images_dir = os.path.join(log_directory, "images")
os.makedirs(images_dir, exist_ok=True)
else:
images_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
if embedding.step > steps:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = cond_model([text])
loss = shared.sd_model(x.unsqueeze(0), c)[0]
losses[embedding.step % losses.shape[0]] = loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
pbar.set_description(f"loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
prompt=text,
steps=20,
do_not_save_grid=True,
do_not_save_samples=True,
)
processed = processing.process_images(p)
image = processed.images[0]
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step
shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
embedding.cached_checksum = None
embedding.save(filename)
return embedding, filename
import html
import gradio as gr
import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack, shared
def create_embedding(name, nvpt):
filename = ti.create_embedding(name, nvpt)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
def train_embedding(*args):
try:
sd_hijack.undo_optimizations()
embedding, filename = ti.train_embedding(*args)
res = f"""
Training {'interrupted' if shared.state.interrupted else 'finished'} after {embedding.step} steps.
Embedding saved to {html.escape(filename)}
"""
return res, ""
except Exception:
raise
finally:
sd_hijack.apply_optimizations()
......@@ -15,11 +15,13 @@ import subprocess as sp
import numpy as np
import torch
from PIL import Image, PngImagePlugin
import piexif
import gradio as gr
import gradio.utils
import gradio.routes
from modules import sd_hijack
from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
......@@ -32,6 +34,7 @@ import modules.codeformer_model
import modules.styles
import modules.generation_parameters_copypaste
from modules.images import apply_filename_pattern, get_next_sequence_number
import modules.textual_inversion.ui
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init()
......@@ -129,27 +132,37 @@ def save_files(js_data, images, index):
writer = csv.writer(file)
if at_start:
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"
if file_decoration != "":
file_decoration = "-" + file_decoration.lower()
file_decoration = apply_filename_pattern(file_decoration, p, p.seed, p.prompt)
truncated = (file_decoration[:240] + '..') if len(file_decoration) > 240 else file_decoration
filename_base = truncated
extension = opts.samples_format.lower()
basecount = get_next_sequence_number(path, "")
for i, filedata in enumerate(images):
file_number = f"{basecount+i:05}"
filename = file_number + filename_base + ".png"
filename = file_number + filename_base + f".{extension}"
filepath = os.path.join(path, filename)
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text('parameters', infotexts[i])
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo)
if opts.enable_pnginfo and extension == 'png':
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text('parameters', infotexts[i])
image.save(filepath, pnginfo=pnginfo)
else:
image.save(filepath, quality=opts.jpeg_quality)
if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"):
piexif.insert(piexif.dump({"Exif": {
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
}}), filepath)
filenames.append(filename)
......@@ -158,8 +171,8 @@ def save_files(js_data, images, index):
return '', '', plaintext_to_html(f"Saved: {filenames[0]}")
def wrap_gradio_call(func):
def f(*args, **kwargs):
def wrap_gradio_call(func, extra_outputs=None):
def f(*args, extra_outputs_array=extra_outputs, **kwargs):
run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled
if run_memmon:
shared.mem_mon.monitor()
......@@ -175,7 +188,10 @@ def wrap_gradio_call(func):
shared.state.job = ""
shared.state.job_count = 0
res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]
elapsed = time.perf_counter() - t
......@@ -195,6 +211,7 @@ def wrap_gradio_call(func):
res[-1] += f"<div class='performance'><p class='time'>Time taken: <wbr>{elapsed:.2f}s</p>{vram_html}</div>"
shared.state.interrupted = False
shared.state.job_count = 0
return tuple(res)
......@@ -203,7 +220,7 @@ def wrap_gradio_call(func):
def check_progress_call(id_part):
if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False)
return "", gr_show(False), gr_show(False), gr_show(False)
progress = 0
......@@ -235,13 +252,19 @@ def check_progress_call(id_part):
else:
preview_visibility = gr_show(True)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
if shared.state.textinfo is not None:
textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True)
else:
textinfo_result = gr_show(False)
return f"<span id='{id_part}_progress_span' style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image, textinfo_result
def check_progress_call_initial(id_part):
shared.state.job_count = -1
shared.state.current_latent = None
shared.state.current_image = None
shared.state.textinfo = None
return check_progress_call(id_part)
......@@ -396,7 +419,7 @@ def create_toprow(is_img2img):
with gr.Column(scale=1):
with gr.Row():
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id="generate", variant='primary')
submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
interrupt.click(
fn=lambda: shared.state.interrupt(),
......@@ -415,13 +438,16 @@ def create_toprow(is_img2img):
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste
def setup_progressbar(progressbar, preview, id_part):
def setup_progressbar(progressbar, preview, id_part, textinfo=None):
if textinfo is None:
textinfo = gr.HTML(visible=False)
check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False)
check_progress.click(
fn=lambda: check_progress_call(id_part),
show_progress=False,
inputs=[],
outputs=[progressbar, preview, preview],
outputs=[progressbar, preview, preview, textinfo],
)
check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False)
......@@ -429,11 +455,14 @@ def setup_progressbar(progressbar, preview, id_part):
fn=lambda: check_progress_call_initial(id_part),
show_progress=False,
inputs=[],
outputs=[progressbar, preview, preview],
outputs=[progressbar, preview, preview, textinfo],
)
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def create_ui(wrap_gradio_gpu_call):
import modules.img2img
import modules.txt2img
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
......@@ -499,7 +528,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)
txt2img_args = dict(
fn=txt2img,
fn=wrap_gradio_gpu_call(modules.txt2img.txt2img),
_js="submit",
inputs=[
txt2img_prompt,
......@@ -615,7 +644,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
......@@ -691,7 +720,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
)
img2img_args = dict(
fn=img2img,
fn=wrap_gradio_gpu_call(modules.img2img.img2img),
_js="submit_img2img",
inputs=[
dummy_component,
......@@ -844,7 +873,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click(
fn=run_extras,
fn=wrap_gradio_gpu_call(modules.extras.run_extras),
_js="get_extras_tab_index",
inputs=[
dummy_component,
......@@ -894,7 +923,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
pnginfo_send_to_img2img = gr.Button('Send to img2img')
image.change(
fn=wrap_gradio_call(run_pnginfo),
fn=wrap_gradio_call(modules.extras.run_pnginfo),
inputs=[image],
outputs=[html, generation_info, html2],
)
......@@ -903,7 +932,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
gr.HTML(value="<p>A merger of the two checkpoints will be generated in your <b>checkpoint</b> directory.</p>")
with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
......@@ -912,10 +941,96 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16")
modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
with gr.Column(variant='panel'):
submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
with gr.Blocks() as textual_inversion_interface:
with gr.Row().style(equal_height=False):
with gr.Column():
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
new_embedding_name = gr.Textbox(label="Name")
nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
with gr.Row():
with gr.Column(scale=3):
gr.HTML(value="")
with gr.Column():
create_embedding = gr.Button(value="Create", variant='primary')
with gr.Group():
gr.HTML(value="<p style='margin-bottom: 0.7em'>Train an embedding; must specify a directory with a set of 512x512 images</p>")
train_embedding_name = gr.Dropdown(label='Embedding', choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys()))
learn_rate = gr.Number(label='Learning rate', value=5.0e-03)
dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images")
log_directory = gr.Textbox(label='Log directory', placeholder="Path to directory where to write outputs", value="textual_inversion")
template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"))
steps = gr.Number(label='Max steps', value=100000, precision=0)
create_image_every = gr.Number(label='Save an image to log directory every N steps, 0 to disable', value=1000, precision=0)
save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=1000, precision=0)
with gr.Row():
with gr.Column(scale=2):
gr.HTML(value="")
with gr.Column():
with gr.Row():
interrupt_training = gr.Button(value="Interrupt")
train_embedding = gr.Button(value="Train", variant='primary')
with gr.Column():
progressbar = gr.HTML(elem_id="ti_progressbar")
ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4)
ti_preview = gr.Image(elem_id='ti_preview', visible=False)
ti_progress = gr.HTML(elem_id="ti_progress", value="")
ti_outcome = gr.HTML(elem_id="ti_error", value="")
setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress)
create_embedding.click(
fn=modules.textual_inversion.ui.create_embedding,
inputs=[
new_embedding_name,
nvpt,
],
outputs=[
train_embedding_name,
ti_output,
ti_outcome,
]
)
train_embedding.click(
fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
_js="start_training_textual_inversion",
inputs=[
train_embedding_name,
learn_rate,
dataset_directory,
log_directory,
steps,
create_image_every,
save_embedding_every,
template_file,
],
outputs=[
ti_output,
ti_outcome,
]
)
interrupt_training.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
outputs=[],
)
def create_setting_component(key):
def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default
......@@ -1027,6 +1142,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(extras_interface, "Extras", "extras"),
(pnginfo_interface, "PNG Info", "pnginfo"),
(modelmerger_interface, "Checkpoint Merger", "modelmerger"),
(textual_inversion_interface, "Textual inversion", "ti"),
(settings_interface, "Settings", "settings"),
]
......@@ -1060,11 +1176,11 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
def modelmerger(*args):
try:
results = run_modelmerger(*args)
results = modules.extras.run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() #To remove the potentially missing models from the list
modules.sd_models.list_models() # to remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results
......
import os
from abc import abstractmethod
import PIL
import numpy as np
import torch
from PIL import Image
import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
from modules.paths import models_path
class Upscaler:
name = None
model_path = None
model_name = None
model_url = None
enable = True
filter = None
model = None
user_path = None
scalers: []
tile = True
def __init__(self, create_dirs=False):
self.mod_pad_h = None
self.tile_size = modules.shared.opts.ESRGAN_tile
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
self.device = modules.shared.device
self.img = None
self.output = None
self.scale = 1
self.half = not modules.shared.cmd_opts.no_half
self.pre_pad = 0
self.mod_scale = None
if self.name is not None and create_dirs:
self.model_path = os.path.join(models_path, self.name)
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
try:
import cv2
self.can_tile = True
except:
pass
@abstractmethod
def do_upscale(self, img: PIL.Image, selected_model: str):
return img
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
if img.width >= dest_w and img.height >= dest_h:
break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
return img
@abstractmethod
def load_model(self, path: str):
pass
def find_models(self, ext_filter=None) -> list:
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
def update_status(self, prompt):
print(f"\nextras: {prompt}", file=shared.progress_print_out)
class UpscalerData:
name = None
data_path = None
scale: int = 4
scaler: Upscaler = None
model: None
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model
class UpscalerNone(Upscaler):
name = "None"
scalers = []
def load_model(self, path):
pass
def do_upscale(self, img, selected_model=None):
return img
def __init__(self, dirname=None):
super().__init__(False)
self.scalers = [UpscalerData("None", None, self)]
class UpscalerLanczos(Upscaler):
scalers = []
def do_upscale(self, img, selected_model=None):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
def load_model(self, _):
pass
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]
......@@ -11,46 +11,8 @@ from modules import images, processing, devices
from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
# https://github.com/parlance-zz/g-diffuser-bot
def expand(x, dir, amount, power=0.75):
is_left = dir == 3
is_right = dir == 1
is_up = dir == 0
is_down = dir == 2
if is_left or is_right:
noise = np.zeros((x.shape[0], amount, 3), dtype=float)
indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
if is_right:
indexes = 1 - indexes
indexes = (indexes * (x.shape[1] - 1)).astype(int)
for row in range(x.shape[0]):
if is_left:
noise[row] = x[row][indexes[row]]
else:
noise[row] = np.flip(x[row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
return x
if is_up or is_down:
noise = np.zeros((amount, x.shape[1], 3), dtype=float)
indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
if is_down:
indexes = 1 - indexes
indexes = (indexes * x.shape[0] - 1).astype(int)
for row in range(x.shape[1]):
if is_up:
noise[:, row] = x[:, row][indexes[row]]
else:
noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
return x
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
# helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data):
......
......@@ -34,7 +34,7 @@ class Script(scripts.Script):
seed = p.seed
init_img = p.init_images[0]
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)
img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
devices.torch_gc()
......
......@@ -45,11 +45,8 @@ def apply_sampler(p, x, xs):
def apply_checkpoint(p, x, xs):
applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
assert len(applicable) > 0, f'Checkpoint {x} for found'
info = applicable[0]
info = modules.sd_models.get_closet_checkpoint_match(x)
assert info is not None, f'Checkpoint for {x} not found'
modules.sd_models.reload_model_weights(shared.sd_model, info)
......@@ -159,6 +156,9 @@ class Script(scripts.Script):
p.batch_size = 1
def process_axis(opt, vals):
if opt.label == 'Nothing':
return [0]
valslist = [x.strip() for x in vals.split(",")]
if opt.type == int:
......
......@@ -23,7 +23,7 @@
text-align: right;
}
#generate{
#txt2img_generate, #img2img_generate {
min-height: 4.5em;
}
......@@ -157,7 +157,7 @@ button{
max-width: 10em;
}
#txt2img_preview, #img2img_preview{
#txt2img_preview, #img2img_preview, #ti_preview{
position: absolute;
width: 320px;
left: 0;
......@@ -172,18 +172,18 @@ button{
}
@media screen and (min-width: 768px) {
#txt2img_preview, #img2img_preview {
#txt2img_preview, #img2img_preview, #ti_preview {
position: absolute;
}
}
@media screen and (max-width: 767px) {
#txt2img_preview, #img2img_preview {
#txt2img_preview, #img2img_preview, #ti_preview {
position: relative;
}
}
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0, #ti_preview div.left-0.top-0{
display: none;
}
......@@ -247,7 +247,7 @@ input[type="range"]{
#txt2img_negative_prompt, #img2img_negative_prompt{
}
#txt2img_progressbar, #img2img_progressbar{
#txt2img_progressbar, #img2img_progressbar, #ti_progressbar{
position: absolute;
z-index: 1000;
right: 0;
......
a painting, art by [name]
a rendering, art by [name]
a cropped painting, art by [name]
the painting, art by [name]
a clean painting, art by [name]
a dirty painting, art by [name]
a dark painting, art by [name]
a picture, art by [name]
a cool painting, art by [name]
a close-up painting, art by [name]
a bright painting, art by [name]
a cropped painting, art by [name]
a good painting, art by [name]
a close-up painting, art by [name]
a rendition, art by [name]
a nice painting, art by [name]
a small painting, art by [name]
a weird painting, art by [name]
a large painting, art by [name]
a painting of [filewords], art by [name]
a rendering of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
the painting of [filewords], art by [name]
a clean painting of [filewords], art by [name]
a dirty painting of [filewords], art by [name]
a dark painting of [filewords], art by [name]
a picture of [filewords], art by [name]
a cool painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a bright painting of [filewords], art by [name]
a cropped painting of [filewords], art by [name]
a good painting of [filewords], art by [name]
a close-up painting of [filewords], art by [name]
a rendition of [filewords], art by [name]
a nice painting of [filewords], art by [name]
a small painting of [filewords], art by [name]
a weird painting of [filewords], art by [name]
a large painting of [filewords], art by [name]
a photo of a [name]
a rendering of a [name]
a cropped photo of the [name]
the photo of a [name]
a photo of a clean [name]
a photo of a dirty [name]
a dark photo of the [name]
a photo of my [name]
a photo of the cool [name]
a close-up photo of a [name]
a bright photo of the [name]
a cropped photo of a [name]
a photo of the [name]
a good photo of the [name]
a photo of one [name]
a close-up photo of the [name]
a rendition of the [name]
a photo of the clean [name]
a rendition of a [name]
a photo of a nice [name]
a good photo of a [name]
a photo of the nice [name]
a photo of the small [name]
a photo of the weird [name]
a photo of the large [name]
a photo of a cool [name]
a photo of a small [name]
a photo of a [name], [filewords]
a rendering of a [name], [filewords]
a cropped photo of the [name], [filewords]
the photo of a [name], [filewords]
a photo of a clean [name], [filewords]
a photo of a dirty [name], [filewords]
a dark photo of the [name], [filewords]
a photo of my [name], [filewords]
a photo of the cool [name], [filewords]
a close-up photo of a [name], [filewords]
a bright photo of the [name], [filewords]
a cropped photo of a [name], [filewords]
a photo of the [name], [filewords]
a good photo of the [name], [filewords]
a photo of one [name], [filewords]
a close-up photo of the [name], [filewords]
a rendition of the [name], [filewords]
a photo of the clean [name], [filewords]
a rendition of a [name], [filewords]
a photo of a nice [name], [filewords]
a good photo of a [name], [filewords]
a photo of the nice [name], [filewords]
a photo of the small [name], [filewords]
a photo of the weird [name], [filewords]
a photo of the large [name], [filewords]
a photo of a cool [name], [filewords]
a photo of a small [name], [filewords]
......@@ -21,6 +21,9 @@ export COMMANDLINE_ARGS=""
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
#venv_dir="venv"
# script to launch to start the app
#export LAUNCH_SCRIPT="launch.py"
# install command for torch
#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
......
......@@ -3,36 +3,34 @@ import threading
from modules import devices
from modules.paths import script_path
import signal
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.ui
import modules.scripts
import modules.sd_hijack
import modules.codeformer_model
import modules.gfpgan_model
import modules.face_restoration
import modules.realesrgan_model as realesrgan
import threading
import modules.paths
import modules.codeformer_model as codeformer
import modules.esrgan_model as esrgan
import modules.ldsr_model as ldsr
import modules.bsrgan_model as bsrgan
import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.ldsr_model as ldsr
import modules.lowvram
import modules.txt2img
import modules.img2img
import modules.swinir as swinir
import modules.realesrgan_model as realesrgan
import modules.scripts
import modules.sd_hijack
import modules.sd_models
import modules.shared as shared
import modules.swinir_model as swinir
import modules.ui
from modules import modelloader
from modules.paths import script_path
from modules.shared import cmd_opts
modules.codeformer_model.setup_codeformer()
modules.gfpgan_model.setup_gfpgan()
modelloader.cleanup_models()
modules.sd_models.setup_model(cmd_opts.ckpt_dir)
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
esrgan.load_models(cmd_opts.esrgan_models_path)
swinir.load_models(cmd_opts.swinir_models_path)
realesrgan.setup_realesrgan()
ldsr.add_lsdr()
modelloader.load_upscalers()
queue_lock = threading.Lock()
......@@ -46,7 +44,7 @@ def wrap_queued_call(func):
return f
def wrap_gradio_gpu_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
devices.torch_gc()
......@@ -58,6 +56,7 @@ def wrap_gradio_gpu_call(func):
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.interrupted = False
shared.state.textinfo = None
with queue_lock:
res = func(*args, **kwargs)
......@@ -69,7 +68,7 @@ def wrap_gradio_gpu_call(func):
return res
return modules.ui.wrap_gradio_call(f)
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
......@@ -86,13 +85,7 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler)
demo = modules.ui.create_ui(
txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
run_pnginfo=modules.extras.run_pnginfo,
run_modelmerger=modules.extras.run_modelmerger
)
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
demo.launch(
share=cmd_opts.share,
......
......@@ -41,6 +41,11 @@ then
venv_dir="venv"
fi
if [[ -z "${LAUNCH_SCRIPT}" ]]
then
LAUNCH_SCRIPT="launch.py"
fi
# Disable sentry logging
export ERROR_REPORTING=FALSE
......@@ -133,4 +138,4 @@ fi
printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}"
"${python_cmd}" launch.py
"${python_cmd}" "${LAUNCH_SCRIPT}"
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment