Commit ad1fbbae authored by Jairo Correa

Merge branch 'master' into fix-vram

parents c2d5b290 84e97a98
__pycache__ __pycache__
/ESRGAN *.ckpt
*.pth
/ESRGAN/*
/SwinIR/*
/repositories /repositories
/venv /venv
/tmp /tmp
/model.ckpt /model.ckpt
/models/**/*.ckpt /models/**/*
/GFPGANv1.3.pth /GFPGANv1.3.pth
/gfpgan/weights/*.pth /gfpgan/weights/*.pth
/ui-config.json /ui-config.json
......
...@@ -3,50 +3,64 @@ A browser interface based on Gradio library for Stable Diffusion. ...@@ -3,50 +3,64 @@ A browser interface based on Gradio library for Stable Diffusion.
![](txt2img_Screenshot.png) ![](txt2img_Screenshot.png)
Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) wiki page for extra scripts developed by users.
## Features ## Features
[Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features): [Detailed feature showcase with images](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features):
- Original txt2img and img2img modes - Original txt2img and img2img modes
- One click install and run script (but you still must install python and git) - One click install and run script (but you still must install python and git)
- Outpainting - Outpainting
- Inpainting - Inpainting
- Prompt matrix - Prompt matrix
- Stable Diffusion upscale - Stable Diffusion upscale
- Attention - Attention, specify parts of text that the model should pay more attention to
- Loopback - a man in a ((tuxedo)) - will pay more attention to tuxedo
- X/Y plot - a man in a (tuxedo:1.21) - alternative syntax
- Loopback, run img2img processing multiple times
- X/Y plot, a way to draw a 2 dimensional plot of images with different parameters
- Textual Inversion - Textual Inversion
- have as many embeddings as you want and use any names you like for them
- use multiple embeddings with different numbers of vectors per token
- works with half precision floating point numbers
- Extras tab with: - Extras tab with:
- GFPGAN, neural network that fixes faces - GFPGAN, neural network that fixes faces
- CodeFormer, face restoration tool as an alternative to GFPGAN - CodeFormer, face restoration tool as an alternative to GFPGAN
- RealESRGAN, neural network upscaler - RealESRGAN, neural network upscaler
- ESRGAN, neural network with a lot of third party models - ESRGAN, neural network upscaler with a lot of third party models
- SwinIR, neural network upscaler - SwinIR, neural network upscaler
- LDSR, Latent diffusion super resolution upscaling - LDSR, Latent diffusion super resolution upscaling
- Resizing aspect ratio options - Resizing aspect ratio options
- Sampling method selection - Sampling method selection
- Interrupt processing at any time - Interrupt processing at any time
- 4GB video card support - 4GB video card support (also reports of 2GB working)
- Correct seeds for batches - Correct seeds for batches
- Prompt length validation - Prompt length validation
- Generation parameters added as text to PNG - get length of prompt in tokens as you type
- Tab to view an existing picture's generation parameters - get a warning after generation if some text was truncated
- Generation parameters
- parameters you used to generate images are saved with that image
- in PNG chunks for PNG, in EXIF for JPEG
- can drag the image to PNG info tab to restore generation parameters and automatically copy them into UI
- can be disabled in settings
- Settings page - Settings page
- Running custom code from UI - Running arbitrary python code from UI (must run with commandline flag to enable)
- Mouseover hints for most UI elements - Mouseover hints for most UI elements
- Possible to change defaults/min/max/step values for UI elements via text config - Possible to change defaults/min/max/step values for UI elements via text config
- Random artist button - Random artist button
- Tiling support: UI checkbox to create images that can be tiled like textures - Tiling support, a checkbox to create images that can be tiled like textures
- Progress bar and live image generation preview - Progress bar and live image generation preview
- Negative prompt - Negative prompt, an extra text field that allows you to list what you don't want to see in the generated image
- Styles - Styles, a way to save parts of a prompt and easily apply them via a dropdown later
- Variations - Variations, a way to generate the same image but with tiny differences
- Seed resizing - Seed resizing, a way to generate the same image but at a slightly different resolution
- CLIP interrogator - CLIP interrogator, a button that tries to guess prompt from an image
- Prompt Editing - Prompt Editing, a way to change prompt mid-generation, say to start making a watermelon and switch to anime girl midway
- Batch Processing - Batch Processing, process a group of files using img2img
- Img2img Alternative - Img2img Alternative
- Highres Fix - Highres Fix, a convenience option to produce high resolution pictures in one click without the usual distortions
- LDSR Upscaling - Reloading checkpoints on the fly
- Checkpoint Merger, a tab that allows you to merge two checkpoints into one
- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community
## Installation and Running ## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
...@@ -83,6 +97,9 @@ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusio ...@@ -83,6 +97,9 @@ bash <(wget -qO- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusio
Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon). Find the instructions [here](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Installation-on-Apple-Silicon).
## Contributing
Here's how to add code to this repo: [Contributing](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing)
## Documentation ## Documentation
The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki). The documentation was moved from this README over to the project's [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki).
......
...@@ -359,7 +359,6 @@ Antanas Sutkus,0.7369492,black-white ...@@ -359,7 +359,6 @@ Antanas Sutkus,0.7369492,black-white
Leonora Carrington,0.73726475,scribbles Leonora Carrington,0.73726475,scribbles
Hieronymus Bosch,0.7369955,scribbles Hieronymus Bosch,0.7369955,scribbles
A. J. Casson,0.73666203,scribbles A. J. Casson,0.73666203,scribbles
A.J.Casson,0.73666203,scribbles
Chaim Soutine,0.73662066,scribbles Chaim Soutine,0.73662066,scribbles
Artur Bordalo,0.7364549,weird Artur Bordalo,0.7364549,weird
Thomas Allom,0.68792284,fineart Thomas Allom,0.68792284,fineart
...@@ -1907,7 +1906,6 @@ Alex Schomburg,0.46614102,digipa-low-impact ...@@ -1907,7 +1906,6 @@ Alex Schomburg,0.46614102,digipa-low-impact
Bastien L. Deharme,0.583349,special Bastien L. Deharme,0.583349,special
František Jakub Prokyš,0.58782333,fineart František Jakub Prokyš,0.58782333,fineart
Jesper Ejsing,0.58782053,fineart Jesper Ejsing,0.58782053,fineart
Jesper Ejsing,0.58782053,fineart
Odd Nerdrum,0.53551745,digipa-high-impact Odd Nerdrum,0.53551745,digipa-high-impact
Tom Lovell,0.5877577,fineart Tom Lovell,0.5877577,fineart
Ayami Kojima,0.5877416,fineart Ayami Kojima,0.5877416,fineart
......
...@@ -15,6 +15,7 @@ titles = { ...@@ -15,6 +15,7 @@ titles = {
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.", "\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt into user interface.",
"\uD83D\uDCC2": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
"SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
...@@ -57,8 +58,8 @@ titles = { ...@@ -57,8 +58,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.", "Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.", "Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.", "Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [prompt_words], [date], [datetime], [job_timestamp]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle", "Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Process an image, use it as an input, repeat.", "Loopback": "Process an image, use it as an input, repeat.",
......
...@@ -186,10 +186,12 @@ onUiUpdate(function(){ ...@@ -186,10 +186,12 @@ onUiUpdate(function(){
if (!txt2img_textarea) { if (!txt2img_textarea) {
txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea"); txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button")); txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
txt2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "txt2img_generate"));
} }
if (!img2img_textarea) { if (!img2img_textarea) {
img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea"); img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button")); img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
img2img_textarea?.addEventListener("keyup", (event) => submit_prompt(event, "img2img_generate"));
} }
}) })
...@@ -197,6 +199,14 @@ let txt2img_textarea, img2img_textarea = undefined; ...@@ -197,6 +199,14 @@ let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800 let wait_time = 800
let token_timeout; let token_timeout;
function submit_prompt(event, generate_button_id) {
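// Alt+Enter in the prompt textarea submits by clicking the tab's Generate button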
if (event.altKey && event.keyCode === 13) {
event.preventDefault();
gradioApp().getElementById(generate_button_id).click();
return;
}
}
function update_token_counter(button_id) { function update_token_counter(button_id) {
if (token_timeout) if (token_timeout)
clearTimeout(token_timeout); clearTimeout(token_timeout);
......
# this scripts installs necessary requirements and launches main program in webui.py # this scripts installs necessary requirements and launches main program in webui.py
import subprocess import subprocess
import os import os
import sys import sys
...@@ -19,10 +18,9 @@ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/Tencen ...@@ -19,10 +18,9 @@ gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/Tencen
stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc") stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "9e3002b7cd64df7870e08527b7664eb2f2f5f3f5") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "a7ec1974d4ccb394c2dca275f42cd97490618924")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH', "abf33e7002d59d9085081bce93ec798dcabd49af")
args = shlex.split(commandline_args) args = shlex.split(commandline_args)
...@@ -120,8 +118,6 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming ...@@ -120,8 +118,6 @@ git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming
git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash) git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)
if not is_installed("lpips"): if not is_installed("lpips"):
run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer") run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")
...@@ -130,6 +126,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI") ...@@ -130,6 +126,9 @@ run_pip(f"install -r {requirements_file}", "requirements for Web UI")
sys.argv += args sys.argv += args
if "--exit" in args:
print("Exiting because of --exit argument")
exit(0)
def start_webui(): def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}") print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
......
import os.path
import sys
import traceback
import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import shared, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path
class UpscalerBSRGAN(modules.upscaler.Upscaler):
def __init__(self, dirname):
self.name = "BSRGAN"
self.model_path = os.path.join(models_path, self.name)
self.model_name = "BSRGAN 4x"
self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
self.user_path = dirname
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
try:
scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
scalers.append(scaler_data)
except Exception:
print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.scalers = scalers
def do_upscale(self, img: PIL.Image, selected_file):
torch.cuda.empty_cache()
model = self.load_model(selected_file)
if model is None:
return img
model.to(shared.device)
torch.cuda.empty_cache()
img = np.array(img)
img = img[:, :, ::-1]
img = np.moveaxis(img, 2, 0) / 255
img = torch.from_numpy(img).float()
img = img.unsqueeze(0).to(shared.device)
with torch.no_grad():
output = model(img)
output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
output = 255. * np.moveaxis(output, 0, 2)
output = output.astype(np.uint8)
output = output[:, :, ::-1]
torch.cuda.empty_cache()
return PIL.Image.fromarray(output, 'RGB')
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
progress=True)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print(f"BSRGAN: Unable to load model from {filename}", file=sys.stderr)
return None
model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4) # define network
model.load_state_dict(torch.load(filename), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
return model
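A short usage sketch (not part of the commit) for the BSRGAN upscaler above, assuming the webui modules are importable and either a local model file or the default download is available. The data_path attribute is assumed from the UpscalerData objects used elsewhere in this commit (extras.py calls upscaler.scaler.upscale(image, resize, upscaler.data_path)):

import PIL.Image
from modules.bsrgan_model import UpscalerBSRGAN

upscaler = UpscalerBSRGAN("models/BSRGAN")          # dirname: user model directory
data = upscaler.scalers[0]                          # first registered model (may point at the download URL)
img = PIL.Image.open("input.png").convert("RGB")
result = upscaler.do_upscale(img, data.data_path)   # runs the 4x RRDB network
result.save("output_4x.png")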
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
def initialize_weights(net_l, scale=1):
if not isinstance(net_l, list):
net_l = [net_l]
for net in net_l:
for m in net.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale # for residual block
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
init.kaiming_normal_(m.weight, a=0, mode='fan_in')
m.weight.data *= scale
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias.data, 0.0)
def make_layer(block, n_layers):
layers = []
for _ in range(n_layers):
layers.append(block())
return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
def __init__(self, nf=64, gc=32, bias=True):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
# initialization
initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
def forward(self, x):
x1 = self.lrelu(self.conv1(x))
x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5 * 0.2 + x
class RRDB(nn.Module):
'''Residual in Residual Dense Block'''
def __init__(self, nf, gc=32):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nf, gc)
self.RDB2 = ResidualDenseBlock_5C(nf, gc)
self.RDB3 = ResidualDenseBlock_5C(nf, gc)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out * 0.2 + x
class RRDBNet(nn.Module):
def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
super(RRDBNet, self).__init__()
RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
self.sf = sf
self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
self.RRDB_trunk = make_layer(RRDB_block_f, nb)
self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
#### upsampling
self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
if self.sf==4:
self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
fea = self.conv_first(x)
trunk = self.trunk_conv(self.RRDB_trunk(fea))
fea = fea + trunk
fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
if self.sf==4:
fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
out = self.conv_last(self.lrelu(self.HRconv(fea)))
return out
\ No newline at end of file
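A quick, illustrative shape check for the RRDB network defined above: with sf=4, a 3-channel 64x64 input should come out four times larger on each side.

import torch
from modules.bsrgan_model_arch import RRDBNet

model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4)
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 64, 64)
    y = model(x)
print(y.shape)  # expected: torch.Size([1, 3, 256, 256])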
...@@ -5,31 +5,31 @@ import traceback ...@@ -5,31 +5,31 @@ import traceback
import cv2 import cv2
import torch import torch
from modules import shared, devices
from modules.paths import script_path
import modules.shared
import modules.face_restoration import modules.face_restoration
from importlib import reload import modules.shared
from modules import shared, devices, modelloader
from modules.paths import script_path, models_path
# codeformer people made a choice to include modified basicsr librry to their projectwhich makes # codeformer people made a choice to include modified basicsr library to their project which makes
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN. # it utterly impossible to use it alongside with other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue. # I am making a choice to include some files from codeformer to work around this issue.
model_dir = "Codeformer"
pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
have_codeformer = False have_codeformer = False
codeformer = None codeformer = None
def setup_codeformer():
def setup_model(dirname):
global model_path
if not os.path.exists(model_path):
os.makedirs(model_path)
path = modules.paths.paths.get("CodeFormer", None) path = modules.paths.paths.get("CodeFormer", None)
if path is None: if path is None:
return return
# both GFPGAN and CodeFormer use basicsr; one has it installed from pip, the other uses its own
#stored_sys_path = sys.path
#sys.path = [path] + sys.path
try: try:
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
from modules.codeformer.codeformer_arch import CodeFormer from modules.codeformer.codeformer_arch import CodeFormer
...@@ -44,18 +44,23 @@ def setup_codeformer(): ...@@ -44,18 +44,23 @@ def setup_codeformer():
def name(self): def name(self):
return "CodeFormer" return "CodeFormer"
def __init__(self): def __init__(self, dirname):
self.net = None self.net = None
self.face_helper = None self.face_helper = None
self.cmd_dir = dirname
def create_models(self): def create_models(self):
if self.net is not None and self.face_helper is not None: if self.net is not None and self.face_helper is not None:
self.net.to(devices.device_codeformer) self.net.to(devices.device_codeformer)
return self.net, self.face_helper return self.net, self.face_helper
model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir, download_name='codeformer-v0.1.0.pth')
if len(model_paths) != 0:
ckpt_path = model_paths[0]
else:
print("Unable to load codeformer model.")
return None, None
net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer) net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
checkpoint = torch.load(ckpt_path)['params_ema'] checkpoint = torch.load(ckpt_path)['params_ema']
net.load_state_dict(checkpoint) net.load_state_dict(checkpoint)
net.eval() net.eval()
...@@ -74,6 +79,9 @@ def setup_codeformer(): ...@@ -74,6 +79,9 @@ def setup_codeformer():
original_resolution = np_image.shape[0:2] original_resolution = np_image.shape[0:2]
self.create_models() self.create_models()
if self.net is None or self.face_helper is None:
return np_image
self.face_helper.clean_all() self.face_helper.clean_all()
self.face_helper.read_image(np_image) self.face_helper.read_image(np_image)
self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5) self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
...@@ -116,7 +124,7 @@ def setup_codeformer(): ...@@ -116,7 +124,7 @@ def setup_codeformer():
have_codeformer = True have_codeformer = True
global codeformer global codeformer
codeformer = FaceRestorerCodeFormer() codeformer = FaceRestorerCodeFormer(dirname)
shared.face_restorers.append(codeformer) shared.face_restorers.append(codeformer)
except Exception: except Exception:
......
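A hedged usage sketch for the CodeFormer changes above (not from the commit): setup_model() registers a FaceRestorerCodeFormer in shared.face_restorers, and the module-level codeformer object exposes a restore() call on an RGB numpy image; restore() is assumed here from the surrounding code, which returns the input unchanged if the model failed to load:

import numpy as np
import modules.codeformer_model as codeformer_model

codeformer_model.setup_model("models/Codeformer")         # dirname for user-supplied weights
if codeformer_model.have_codeformer:
    image = np.zeros((512, 512, 3), dtype=np.uint8)       # stand-in for a real photo
    restored = codeformer_model.codeformer.restore(image)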
import os import os
import sys
import traceback
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from basicsr.utils.download_util import load_file_from_url
import modules.esrgam_model_arch as arch import modules.esrgam_model_arch as arch
from modules import shared from modules import shared, modelloader, images
from modules.shared import opts
from modules.devices import has_mps from modules.devices import has_mps
import modules.images from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
def load_model(filename): def fix_model_layers(crt_model, pretrained_net):
# this code is adapted from https://github.com/xinntao/ESRGAN # this code is adapted from https://github.com/xinntao/ESRGAN
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
if 'conv_first.weight' in pretrained_net: if 'conv_first.weight' in pretrained_net:
crt_model.load_state_dict(pretrained_net) return pretrained_net
return crt_model
if 'model.0.weight' not in pretrained_net: if 'model.0.weight' not in pretrained_net:
is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"] is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
...@@ -72,9 +68,59 @@ def load_model(filename): ...@@ -72,9 +68,59 @@ def load_model(filename):
crt_net['conv_last.weight'] = pretrained_net['model.10.weight'] crt_net['conv_last.weight'] = pretrained_net['model.10.weight']
crt_net['conv_last.bias'] = pretrained_net['model.10.bias'] crt_net['conv_last.bias'] = pretrained_net['model.10.bias']
crt_model.load_state_dict(crt_net) return crt_net
crt_model.eval()
return crt_model class UpscalerESRGAN(Upscaler):
def __init__(self, dirname):
self.name = "ESRGAN"
self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
self.model_name = "ESRGAN 4x"
self.scalers = []
self.user_path = dirname
self.model_path = os.path.join(models_path, self.name)
super().__init__()
model_paths = self.find_models(ext_filter=[".pt", ".pth"])
scalers = []
if len(model_paths) == 0:
scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
if "http" in file:
name = self.model_name
else:
name = modelloader.friendly_name(file)
scaler_data = UpscalerData(name, file, self, 4)
self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
model = self.load_model(selected_model)
if model is None:
return img
model.to(shared.device)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="%s.pth" % self.model_name,
progress=True)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_path, filename))
return None
pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
pretrained_net = fix_model_layers(crt_model, pretrained_net)
crt_model.load_state_dict(pretrained_net)
crt_model.eval()
return crt_model
def upscale_without_tiling(model, img): def upscale_without_tiling(model, img):
img = np.array(img) img = np.array(img)
...@@ -95,7 +141,7 @@ def esrgan_upscale(model, img): ...@@ -95,7 +141,7 @@ def esrgan_upscale(model, img):
if opts.ESRGAN_tile == 0: if opts.ESRGAN_tile == 0:
return upscale_without_tiling(model, img) return upscale_without_tiling(model, img)
grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap) grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
newtiles = [] newtiles = []
scale_factor = 1 scale_factor = 1
...@@ -110,32 +156,6 @@ def esrgan_upscale(model, img): ...@@ -110,32 +156,6 @@ def esrgan_upscale(model, img):
newrow.append([x * scale_factor, w * scale_factor, output]) newrow.append([x * scale_factor, w * scale_factor, output])
newtiles.append([y * scale_factor, h * scale_factor, newrow]) newtiles.append([y * scale_factor, h * scale_factor, newrow])
newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor) newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
output = modules.images.combine_grid(newgrid) output = images.combine_grid(newgrid)
return output return output
class UpscalerESRGAN(modules.images.Upscaler):
def __init__(self, filename, title):
self.name = title
self.model = load_model(filename)
def do_upscale(self, img):
model = self.model.to(shared.device)
img = esrgan_upscale(model, img)
return img
def load_models(dirname):
for file in os.listdir(dirname):
path = os.path.join(dirname, file)
model_name, extension = os.path.splitext(file)
if extension != '.pt' and extension != '.pth':
continue
try:
modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
except Exception:
print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
...@@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v ...@@ -40,6 +40,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
outputs = [] outputs = []
for image, image_name in zip(imageArr, imageNameArr): for image, image_name in zip(imageArr, imageNameArr):
if image is None:
return outputs, "Please select an input image.", ''
existing_pnginfo = image.info or {} existing_pnginfo = image.info or {}
image = image.convert("RGB") image = image.convert("RGB")
...@@ -74,7 +76,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v ...@@ -74,7 +76,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
c = cached_images.get(key) c = cached_images.get(key)
if c is None: if c is None:
upscaler = shared.sd_upscalers[scaler_index] upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.upscale(image, image.width * resize, image.height * resize) c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
cached_images[key] = c cached_images[key] = c
return c return c
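The cached call above reflects the new convention in this commit: entries of shared.sd_upscalers are UpscalerData objects, and the work is done by data.scaler.upscale(img, scale, data.data_path). A minimal sketch of looking one up by name (the helper name and the 2x factor are illustrative):

from modules import shared

def upscale_by_name(image, name, factor=2):
    matches = [x for x in shared.sd_upscalers if x.name == name]
    assert len(matches) > 0, f"could not find upscaler named {name}"
    data = matches[0]
    # same call signature as used in extras.py and images.py in this commit
    return data.scaler.upscale(image, factor, data.data_path)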
...@@ -143,7 +145,7 @@ def run_pnginfo(image): ...@@ -143,7 +145,7 @@ def run_pnginfo(image):
return '', geninfo, info return '', geninfo, info
def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half): def run_modelmerger(primary_model_name, secondary_model_name, interp_method, interp_amount, save_as_half, custom_name):
# Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation) # Linear interpolation (https://en.wikipedia.org/wiki/Linear_interpolation)
def weighted_sum(theta0, theta1, alpha): def weighted_sum(theta0, theta1, alpha):
return ((1 - alpha) * theta0) + (alpha * theta1) return ((1 - alpha) * theta0) + (alpha * theta1)
...@@ -191,8 +193,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int ...@@ -191,8 +193,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
if save_as_half: if save_as_half:
theta_0[key] = theta_0[key].half() theta_0[key] = theta_0[key].half()
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt' filename = primary_model_info.model_name + '_' + str(round(interp_amount, 2)) + '-' + secondary_model_info.model_name + '_' + str(round((float(1.0) - interp_amount), 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
output_modelname = os.path.join(shared.cmd_opts.ckpt_dir, filename) filename = filename if custom_name == '' else (custom_name + '.ckpt')
output_modelname = os.path.join(ckpt_dir, filename)
print(f"Saving to {output_modelname}...") print(f"Saving to {output_modelname}...")
torch.save(primary_model, output_modelname) torch.save(primary_model, output_modelname)
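The "weighted sum" interpolation used by the model merger above is a plain linear blend of the two checkpoints' tensors. A tiny worked example with made-up values:

import torch

def weighted_sum(theta0, theta1, alpha):
    return ((1 - alpha) * theta0) + (alpha * theta1)

a = torch.tensor([1.0, 2.0])
b = torch.tensor([3.0, 6.0])
print(weighted_sum(a, b, 0.25))  # tensor([1.5000, 3.0000])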
......
import os import os
import sys import sys
import traceback import traceback
from glob import glob
from modules import shared, devices import facexlib
from modules.shared import cmd_opts import gfpgan
from modules.paths import script_path
import modules.face_restoration
def gfpgan_model_path():
from modules.shared import cmd_opts
filemask = 'GFPGAN*.pth'
if cmd_opts.gfpgan_model is not None:
return cmd_opts.gfpgan_model
places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
filename = None
for place in places:
filename = next(iter(glob(os.path.join(place, filemask))), None)
if filename is not None:
break
return filename
import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path
model_dir = "GFPGAN"
user_path = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
have_gfpgan = False
loaded_gfpgan_model = None loaded_gfpgan_model = None
def gfpgan(): def gfpgann():
global loaded_gfpgan_model global loaded_gfpgan_model
global model_path
if loaded_gfpgan_model is not None: if loaded_gfpgan_model is not None:
loaded_gfpgan_model.gfpgan.to(shared.device) loaded_gfpgan_model.gfpgan.to(shared.device)
return loaded_gfpgan_model return loaded_gfpgan_model
...@@ -41,7 +27,16 @@ def gfpgan(): ...@@ -41,7 +27,16 @@ def gfpgan():
if gfpgan_constructor is None: if gfpgan_constructor is None:
return None return None
model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None) models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
if len(models) == 1 and "http" in models[0]:
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
model_file = latest_file
else:
print("Unable to load gfpgan model!")
return None
model = gfpgan_constructor(model_path=model_file, upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
model.gfpgan.to(shared.device) model.gfpgan.to(shared.device)
loaded_gfpgan_model = model loaded_gfpgan_model = model
...@@ -50,8 +45,9 @@ def gfpgan(): ...@@ -50,8 +45,9 @@ def gfpgan():
def gfpgan_fix_faces(np_image): def gfpgan_fix_faces(np_image):
global loaded_gfpgan_model global loaded_gfpgan_model
model = gfpgan() model = gfpgann()
if model is None:
return np_image
np_image_bgr = np_image[:, :, ::-1] np_image_bgr = np_image[:, :, ::-1]
cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True) cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
np_image = gfpgan_output_bgr[:, :, ::-1] np_image = gfpgan_output_bgr[:, :, ::-1]
...@@ -64,21 +60,39 @@ def gfpgan_fix_faces(np_image): ...@@ -64,21 +60,39 @@ def gfpgan_fix_faces(np_image):
return np_image return np_image
have_gfpgan = False
gfpgan_constructor = None gfpgan_constructor = None
def setup_gfpgan():
try:
gfpgan_model_path()
if os.path.exists(cmd_opts.gfpgan_dir): def setup_model(dirname):
sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir)) global model_path
from gfpgan import GFPGANer if not os.path.exists(model_path):
os.makedirs(model_path)
try:
from gfpgan import GFPGANer
from facexlib import detection, parsing
global user_path
global have_gfpgan global have_gfpgan
have_gfpgan = True
global gfpgan_constructor global gfpgan_constructor
load_file_from_url_orig = gfpgan.utils.load_file_from_url
facex_load_file_from_url_orig = facexlib.detection.load_file_from_url
facex_load_file_from_url_orig2 = facexlib.parsing.load_file_from_url
def my_load_file_from_url(**kwargs):
return load_file_from_url_orig(**dict(kwargs, model_dir=model_path))
def facex_load_file_from_url(**kwargs):
return facex_load_file_from_url_orig(**dict(kwargs, save_dir=model_path, model_dir=None))
def facex_load_file_from_url2(**kwargs):
return facex_load_file_from_url_orig2(**dict(kwargs, save_dir=model_path, model_dir=None))
gfpgan.utils.load_file_from_url = my_load_file_from_url
facexlib.detection.load_file_from_url = facex_load_file_from_url
facexlib.parsing.load_file_from_url = facex_load_file_from_url2
user_path = dirname
have_gfpgan = True
gfpgan_constructor = GFPGANer gfpgan_constructor = GFPGANer
class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration): class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
......
...@@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin ...@@ -11,7 +11,6 @@ from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
from fonts.ttf import Roboto from fonts.ttf import Roboto
import string import string
import modules.shared
from modules import sd_samplers, shared from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
...@@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): ...@@ -52,8 +51,8 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
cols = math.ceil((w - overlap) / non_overlap_width) cols = math.ceil((w - overlap) / non_overlap_width)
rows = math.ceil((h - overlap) / non_overlap_height) rows = math.ceil((h - overlap) / non_overlap_height)
dx = (w - tile_w) / (cols-1) if cols > 1 else 0 dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
dy = (h - tile_h) / (rows-1) if rows > 1 else 0 dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap) grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows): for row in range(rows):
...@@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64): ...@@ -67,7 +66,7 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
for col in range(cols): for col in range(cols):
x = int(col * dx) x = int(col * dx)
if x+tile_w >= w: if x + tile_w >= w:
x = w - tile_w x = w - tile_w
tile = image.crop((x, y, x + tile_w, y + tile_h)) tile = image.crop((x, y, x + tile_w, y + tile_h))
...@@ -132,7 +131,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ...@@ -132,7 +131,7 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center") drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
if not line.is_active: if not line.is_active:
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4) drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)
draw_y += line.size[1] + line_spacing draw_y += line.size[1] + line_spacing
...@@ -171,7 +170,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts): ...@@ -171,7 +170,8 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1]) line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts] hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts] ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
ver_texts]
pad_top = max(hor_text_heights) + line_spacing * 2 pad_top = max(hor_text_heights) + line_spacing * 2
...@@ -213,8 +213,19 @@ def resize_image(resize_mode, im, width, height): ...@@ -213,8 +213,19 @@ def resize_image(resize_mode, im, width, height):
if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
return im.resize((w, h), resample=LANCZOS) return im.resize((w, h), resample=LANCZOS)
upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0] scale = max(w / im.width, h / im.height)
return upscaler.upscale(im, w, h)
if scale > 1.0:
upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"
upscaler = upscalers[0]
im = upscaler.scaler.upscale(im, scale, upscaler.data_path)
if im.width != w or im.height != h:
im = im.resize((w, h), resample=LANCZOS)
return im
if resize_mode == 0: if resize_mode == 0:
res = resize(im, width, height) res = resize(im, width, height)
...@@ -256,7 +267,7 @@ def resize_image(resize_mode, im, width, height): ...@@ -256,7 +267,7 @@ def resize_image(resize_mode, im, width, height):
invalid_filename_chars = '<>:"/\\|?*\n' invalid_filename_chars = '<>:"/\\|?*\n'
invalid_filename_prefix = ' ' invalid_filename_prefix = ' '
invalid_filename_postfix = ' .' invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s'+string.punctuation+']+') re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
max_filename_part_length = 128 max_filename_part_length = 128
...@@ -278,6 +289,16 @@ def apply_filename_pattern(x, p, seed, prompt): ...@@ -278,6 +289,16 @@ def apply_filename_pattern(x, p, seed, prompt):
if prompt is not None: if prompt is not None:
x = x.replace("[prompt]", sanitize_filename_part(prompt)) x = x.replace("[prompt]", sanitize_filename_part(prompt))
if "[prompt_no_styles]" in x:
prompt_no_style = prompt
for style in shared.prompt_styles.get_style_prompts(p.styles):
if len(style) > 0:
style_parts = [y for y in style.split("{prompt}")]
for part in style_parts:
prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()
x = x.replace("[prompt_no_styles]", sanitize_filename_part(prompt_no_style, replace_spaces=False))
x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False)) x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
if "[prompt_words]" in x: if "[prompt_words]" in x:
words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0] words = [x for x in re_nonletters.split(prompt or "") if len(x) > 0]
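A worked example of the [prompt_no_styles] branch a few lines above, using a made-up style template: everything the style contributed around the "{prompt}" placeholder is stripped back out of the full prompt.

style = "masterpiece, {prompt}, highly detailed"
prompt_no_style = "masterpiece, a cat sitting, highly detailed"   # prompt after the style was applied
for part in style.split("{prompt}"):
    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')
print(prompt_no_style)  # a cat sitting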
...@@ -290,7 +311,7 @@ def apply_filename_pattern(x, p, seed, prompt): ...@@ -290,7 +311,7 @@ def apply_filename_pattern(x, p, seed, prompt):
x = x.replace("[cfg]", str(p.cfg_scale)) x = x.replace("[cfg]", str(p.cfg_scale))
x = x.replace("[width]", str(p.width)) x = x.replace("[width]", str(p.width))
x = x.replace("[height]", str(p.height)) x = x.replace("[height]", str(p.height))
x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False)) x = x.replace("[styles]", sanitize_filename_part(", ".join([x for x in p.styles if not x == "None"]), replace_spaces=False))
x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False)) x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))
x = x.replace("[model_hash]", shared.sd_model.sd_model_hash) x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
...@@ -303,6 +324,7 @@ def apply_filename_pattern(x, p, seed, prompt): ...@@ -303,6 +324,7 @@ def apply_filename_pattern(x, p, seed, prompt):
return x return x
def get_next_sequence_number(path, basename): def get_next_sequence_number(path, basename):
""" """
Determines and returns the next sequence number to use when saving an image in the specified directory. Determines and returns the next sequence number to use when saving an image in the specified directory.
...@@ -316,7 +338,7 @@ def get_next_sequence_number(path, basename): ...@@ -316,7 +338,7 @@ def get_next_sequence_number(path, basename):
prefix_length = len(basename) prefix_length = len(basename)
for p in os.listdir(path): for p in os.listdir(path):
if p.startswith(basename): if p.startswith(basename):
l = os.path.splitext(p[prefix_length:])[0].split('-') #splits the filename (removing the basename first if one is defined, so the sequence number is always the first element) l = os.path.splitext(p[prefix_length:])[0].split('-') # splits the filename (removing the basename first if one is defined, so the sequence number is always the first element)
try: try:
result = max(int(l[0]), result) result = max(int(l[0]), result)
except ValueError: except ValueError:
...@@ -324,6 +346,7 @@ def get_next_sequence_number(path, basename): ...@@ -324,6 +346,7 @@ def get_next_sequence_number(path, basename):
return result + 1 return result + 1
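An illustrative check of the sequence-number scan above, using a throwaway directory and an empty basename (assuming modules.images is importable): the highest existing numeric prefix wins and the function returns the next number.

import os
import tempfile
from modules.images import get_next_sequence_number

with tempfile.TemporaryDirectory() as path:
    for name in ("00007-a cat.png", "00012-a dog.png"):
        open(os.path.join(path, name), "w").close()
    print(get_next_sequence_number(path, ""))  # 13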
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""): def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
if short_filename or prompt is None or seed is None: if short_filename or prompt is None or seed is None:
file_decoration = "" file_decoration = ""
...@@ -361,7 +384,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -361,7 +384,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
fullfn = "a.png" fullfn = "a.png"
fullfn_without_extension = "a" fullfn_without_extension = "a"
for i in range(500): for i in range(500):
fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}" fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}") fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}") fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
if not os.path.exists(fullfn): if not os.path.exists(fullfn):
...@@ -403,31 +426,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i ...@@ -403,31 +426,3 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
file.write(info + "\n") file.write(info + "\n")
class Upscaler:
name = "Lanczos"
def do_upscale(self, img):
return img
def upscale(self, img, w, h):
for i in range(3):
if img.width >= w and img.height >= h:
break
img = self.do_upscale(img)
if img.width != w or img.height != h:
img = img.resize((int(w), int(h)), resample=LANCZOS)
return img
class UpscalerNone(Upscaler):
name = "None"
def upscale(self, img, w, h):
return img
modules.shared.sd_upscalers.append(UpscalerNone())
modules.shared.sd_upscalers.append(Upscaler())
import os import os
import sys import sys
import traceback import traceback
from collections import namedtuple
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
import modules.images from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR
from modules import shared from modules import shared
from modules.paths import script_path from modules.paths import models_path
LDSRModelInfo = namedtuple("LDSRModelInfo", ["name", "location", "model", "netscale"])
ldsr_models = [] class UpscalerLDSR(Upscaler):
have_ldsr = False def __init__(self, user_path):
LDSR_obj = None
class UpscalerLDSR(modules.images.Upscaler):
def __init__(self, steps):
self.steps = steps
self.name = "LDSR" self.name = "LDSR"
self.model_path = os.path.join(models_path, self.name)
def do_upscale(self, img): self.user_path = user_path
return upscale_with_ldsr(img) self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
super().__init__()
def add_lsdr(): scaler_data = UpscalerData("LDSR", None, self)
modules.shared.sd_upscalers.append(UpscalerLDSR(100)) self.scalers = [scaler_data]
def load_model(self, path: str):
def setup_ldsr(): # Remove incorrect project.yaml file if too big
path = modules.paths.paths.get("LDSR", None) yaml_path = os.path.join(self.model_path, "project.yaml")
if path is None: old_model_path = os.path.join(self.model_path, "model.pth")
return new_model_path = os.path.join(self.model_path, "model.ckpt")
global have_ldsr if os.path.exists(yaml_path):
global LDSR_obj statinfo = os.stat(yaml_path)
try: if statinfo.st_size >= 10485760:
from LDSR import LDSR print("Removing invalid LDSR YAML file.")
model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1" os.remove(yaml_path)
yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1" if os.path.exists(old_model_path):
repo_path = 'latent-diffusion/experiments/pretrained_models/' print("Renaming model from model.pth to model.ckpt")
model_path = load_file_from_url(url=model_url, model_dir=os.path.join("repositories", repo_path), os.rename(old_model_path, new_model_path)
progress=True, file_name="model.chkpt") model = load_file_from_url(url=self.model_url, model_dir=self.model_path,
yaml_path = load_file_from_url(url=yaml_url, model_dir=os.path.join("repositories", repo_path), file_name="model.ckpt", progress=True)
progress=True, file_name="project.yaml") yaml = load_file_from_url(url=self.yaml_url, model_dir=self.model_path,
have_ldsr = True file_name="project.yaml", progress=True)
LDSR_obj = LDSR(model_path, yaml_path)
try:
return LDSR(model, yaml)
except Exception:
print("Error importing LDSR:", file=sys.stderr) except Exception:
print(traceback.format_exc(), file=sys.stderr) print("Error importing LDSR:", file=sys.stderr)
have_ldsr = False print(traceback.format_exc(), file=sys.stderr)
return None
def upscale_with_ldsr(image): def do_upscale(self, img, path):
setup_ldsr() ldsr = self.load_model(path)
if not have_ldsr or LDSR_obj is None: if ldsr is None:
return image print("NO LDSR!")
return img
ddim_steps = shared.opts.ldsr_steps ddim_steps = shared.opts.ldsr_steps
pre_scale = shared.opts.ldsr_pre_down return ldsr.super_resolution(img, ddim_steps, self.scale)
post_scale = shared.opts.ldsr_post_down
image = LDSR_obj.super_resolution(image, ddim_steps, pre_scale, post_scale)
return image
import gc
import time
import warnings
import numpy as np
import torch
import torchvision
from PIL import Image
from einops import rearrange, repeat
from omegaconf import OmegaConf
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
warnings.filterwarnings("ignore", category=UserWarning)
# Create LDSR Class
class LDSR:
def load_model_from_config(self, half_attention):
print(f"Loading model from {self.modelPath}")
pl_sd = torch.load(self.modelPath, map_location="cpu")
sd = pl_sd["state_dict"]
config = OmegaConf.load(self.yamlPath)
model = instantiate_from_config(config.model)
model.load_state_dict(sd, strict=False)
model.cuda()
if half_attention:
model = model.half()
model.eval()
return {"model": model}
def __init__(self, model_path, yaml_path):
self.modelPath = model_path
self.yamlPath = yaml_path
@staticmethod
def run(model, selected_path, custom_steps, eta):
example = get_cond(selected_path)
n_runs = 1
guider = None
ckwargs = None
ddim_use_x0_pred = False
temperature = 1.
eta = eta
custom_shape = None
height, width = example["image"].shape[1:3]
split_input = height >= 128 and width >= 128
if split_input:
ks = 128
stride = 64
vqf = 4 #
model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
"vqf": vqf,
"patch_distributed_vq": True,
"tie_braker": False,
"clip_max_weight": 0.5,
"clip_min_weight": 0.01,
"clip_max_tie_weight": 0.5,
"clip_min_tie_weight": 0.01}
else:
if hasattr(model, "split_input_params"):
delattr(model, "split_input_params")
x_t = None
logs = None
for n in range(n_runs):
if custom_shape is not None:
x_t = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
x_t = repeat(x_t, '1 c h w -> b c h w', b=custom_shape[0])
logs = make_convolutional_sample(example, model,
custom_steps=custom_steps,
eta=eta, quantize_x0=False,
custom_shape=custom_shape,
temperature=temperature, noise_dropout=0.,
corrector=guider, corrector_kwargs=ckwargs, x_T=x_t,
ddim_use_x0_pred=ddim_use_x0_pred
)
return logs
def super_resolution(self, image, steps=100, target_scale=2, half_attention=False):
model = self.load_model_from_config(half_attention)
# Run settings
diffusion_steps = int(steps)
eta = 1.0
down_sample_method = 'Lanczos'
gc.collect()
torch.cuda.empty_cache()
im_og = image
width_og, height_og = im_og.size
# If we can adjust the max upscale size, then the 4 below should be our variable
down_sample_rate = target_scale / 4
wd = width_og * down_sample_rate
hd = height_og * down_sample_rate
width_downsampled_pre = int(wd)
height_downsampled_pre = int(hd)
if down_sample_rate != 1:
print(
f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
else:
print(f"Down sample rate is 1 from {target_scale} / 4 (Not downsampling)")
logs = self.run(model["model"], im_og, diffusion_steps, eta)
sample = logs["sample"]
sample = sample.detach().cpu()
sample = torch.clamp(sample, -1., 1.)
sample = (sample + 1.) / 2. * 255
sample = sample.numpy().astype(np.uint8)
sample = np.transpose(sample, (0, 2, 3, 1))
a = Image.fromarray(sample[0])
del model
gc.collect()
torch.cuda.empty_cache()
return a
def get_cond(selected_path):
example = dict()
up_f = 4
c = selected_path.convert('RGB')
c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]],
antialias=True)
c_up = rearrange(c_up, '1 c h w -> 1 h w c')
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(torch.device("cuda"))
example["LR_image"] = c
example["image"] = c_up
return example
@torch.no_grad()
def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
mask=None, x0=None, quantize_x0=False, temperature=1., score_corrector=None,
corrector_kwargs=None, x_t=None
):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
print(f"Sampling with eta = {eta}; steps: {steps}")
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
mask=mask, x0=x0, temperature=temperature, verbose=False,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs, x_t=x_t)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(batch, model, custom_steps=None, eta=1.0, quantize_x0=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
corrector_kwargs=None, x_T=None, ddim_use_x0_pred=False):
log = dict()
z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
return_first_stage_outputs=True,
force_c_encode=not (hasattr(model, 'split_input_params')
and model.cond_stage_key == 'coordinates_bbox'),
return_original_cond=True)
if custom_shape is not None:
z = torch.randn(custom_shape)
print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
z0 = None
log["input"] = x
log["reconstruction"] = xrec
if ismap(xc):
log["original_conditioning"] = model.to_rgb(xc)
if hasattr(model, 'cond_stage_key'):
log[model.cond_stage_key] = model.to_rgb(xc)
else:
log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_model:
log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
if model.cond_stage_key == 'class_label':
log[model.cond_stage_key] = xc[model.cond_stage_key]
with model.ema_scope("Plotting"):
t0 = time.time()
sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
eta=eta,
quantize_x0=quantize_x0, mask=None, x0=z0,
temperature=temperature, score_corrector=corrector, corrector_kwargs=corrector_kwargs,
x_t=x_T)
t1 = time.time()
if ddim_use_x0_pred:
sample = intermediates['pred_x0'][-1]
x_sample = model.decode_first_stage(sample)
try:
x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
log["sample_noquant"] = x_sample_noquant
log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
except:
pass
log["sample"] = x_sample
log["time"] = t1 - t0
return log
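# Illustrative usage sketch (added for clarity, not part of the original file): driving the LDSR
# class above end-to-end. The checkpoint/config paths and the input file name are assumptions.
if __name__ == "__main__":
    ldsr = LDSR("models/LDSR/model.ckpt", "models/LDSR/project.yaml")  # hypothetical paths
    lowres = Image.open("input.png").convert("RGB")
    upscaled = ldsr.super_resolution(lowres, steps=100, target_scale=2)
    upscaled.save("ldsr_output.png")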
import glob
import os
import shutil
import importlib
from urllib.parse import urlparse
from basicsr.utils.download_util import load_file_from_url
from modules import shared
from modules.upscaler import Upscaler
from modules.paths import script_path, models_path
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
"""
    A one-and-done loader that tries to find the desired models in the specified directories.
@param download_name: Specify to download from model_url immediately.
@param model_url: If no other models are found, this will be downloaded on upscale.
@param model_path: The location to store/find models in.
@param command_path: A command-line argument to search for models in first.
@param ext_filter: An optional list of filename extensions to filter by
@return: A list of paths containing the desired model(s)
"""
output = []
if ext_filter is None:
ext_filter = []
try:
places = []
if command_path is not None and command_path != model_path:
pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
if os.path.exists(pretrained_path):
print(f"Appending path: {pretrained_path}")
places.append(pretrained_path)
elif os.path.exists(command_path):
places.append(command_path)
places.append(model_path)
for place in places:
if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True):
full_path = os.path.join(place, file)
if os.path.isdir(full_path):
continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
continue
if file not in output:
output.append(full_path)
if model_url is not None and len(output) == 0:
if download_name is not None:
dl = load_file_from_url(model_url, model_path, True, download_name)
output.append(dl)
else:
output.append(model_url)
except Exception:
pass
return output
def friendly_name(file: str):
if "http" in file:
file = urlparse(file).path
file = os.path.basename(file)
model_name, extension = os.path.splitext(file)
return model_name
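# Illustrative sketch (added for clarity, not part of the original file): how an upscaler module
# typically combines the two helpers above -- gather model files (downloading a default one when
# nothing is found) and derive display names. The ESRGAN directory, URL and file name below are
# placeholders for the example.
def _example_list_esrgan_models():
    found = load_models(
        model_path=os.path.join(models_path, "ESRGAN"),
        model_url="https://example.com/ESRGAN_4x.pth",  # hypothetical download URL
        command_path=None,
        ext_filter=[".pth"],
        download_name="ESRGAN_4x.pth",
    )
    return [(friendly_name(path), path) for path in found]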
def cleanup_models():
# This code could probably be more efficient if we used a tuple list or something to store the src/destinations
# and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler
# somehow auto-register and just do these things...
root_path = script_path
src_path = models_path
dest_path = os.path.join(models_path, "Stable-diffusion")
move_files(src_path, dest_path, ".ckpt")
src_path = os.path.join(root_path, "ESRGAN")
dest_path = os.path.join(models_path, "ESRGAN")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "gfpgan")
dest_path = os.path.join(models_path, "GFPGAN")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "SwinIR")
dest_path = os.path.join(models_path, "SwinIR")
move_files(src_path, dest_path)
src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/")
dest_path = os.path.join(models_path, "LDSR")
move_files(src_path, dest_path)
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try:
if not os.path.exists(dest_path):
os.makedirs(dest_path)
if os.path.exists(src_path):
for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file)
if os.path.isfile(fullpath):
if ext_filter is not None:
if ext_filter not in file:
continue
print(f"Moving {file} from {src_path} to {dest_path}.")
try:
shutil.move(fullpath, dest_path)
except:
pass
if len(os.listdir(src_path)) == 0:
print(f"Removing empty folder: {src_path}")
shutil.rmtree(src_path, True)
except:
pass
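# Illustrative usage (added for clarity, not part of the original file): the same helper that
# cleanup_models() relies on above, shown moving stray .ckpt files from the webui root into the
# models/Stable-diffusion folder.
def _example_move_root_checkpoints():
    move_files(script_path, os.path.join(models_path, "Stable-diffusion"), ".ckpt")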
def load_upscalers():
datas = []
for cls in Upscaler.__subclasses__():
name = cls.__name__
module_name = cls.__module__
module = importlib.import_module(module_name)
class_ = getattr(module, name)
cmd_name = f"{name.lower().replace('upscaler', '')}-models-path"
opt_string = None
try:
opt_string = shared.opts.__getattr__(cmd_name)
except:
pass
scaler = class_(opt_string)
for child in scaler.scalers:
datas.append(child)
shared.sd_upscalers = datas
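# Illustrative usage (added for clarity, not part of the original file): once the individual
# upscaler modules have been imported, every Upscaler subclass is discovered automatically via
# __subclasses__(), so a single call fills shared.sd_upscalers.
if __name__ == "__main__":
    load_upscalers()
    for upscaler_data in shared.sd_upscalers:
        print(upscaler_data.name, upscaler_data.data_path)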
...@@ -3,9 +3,10 @@ import os
import sys

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
sys.path.insert(0, script_path)

# search for directory of stable diffusion in following places
sd_path = None
possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion'), '.', os.path.dirname(script_path)]
for possible_sd_path in possible_sd_paths:
...@@ -15,21 +16,24 @@ for possible_sd_path in possible_sd_paths:
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)

path_dirs = [
    (sd_path, 'ldm', 'Stable Diffusion', []),
    (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []),
    (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
    (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
    (os.path.join(sd_path, '../latent-diffusion'), 'LDSR.py', 'LDSR', []),
    (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
]

paths = {}

for d, must_exist, what, options in path_dirs:
    must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
    if not os.path.exists(must_exist_path):
        print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
    else:
        d = os.path.abspath(d)
        if "atstart" in options:
            sys.path.insert(0, d)
        else:
            sys.path.append(d)

        paths[what] = d
...@@ -508,8 +508,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)
                image = Image.fromarray(x_sample)
                image = images.resize_image(0, image, self.width, self.height)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)
                batch_images.append(image)
......
...@@ -126,5 +126,93 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
    return res
re_attention = re.compile(r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""", re.X)
def parse_prompt_attention(text):
"""
    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
Accepted tokens are:
(abc) - increases attention to abc by a multiplier of 1.1
(abc:3.12) - increases attention to abc by a multiplier of 3.12
[abc] - decreases attention to abc by a multiplier of 1.1
\( - literal character '('
\[ - literal character '['
\) - literal character ')'
\] - literal character ']'
\\ - literal character '\'
anything else - just text
Example:
'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
produces:
[
['a ', 1.0],
['house', 1.5730000000000004],
[' ', 1.1],
['on', 1.0],
[' a ', 1.1],
['hill', 0.55],
[', sun, ', 1.1],
['sky', 1.4641000000000006],
['.', 1.1]
]
"""
    # get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
    res = []
round_brackets = []
square_brackets = []
round_bracket_multiplier = 1.1
square_bracket_multiplier = 1 / 1.1
def multiply_range(start_position, multiplier):
for p in range(start_position, len(res)):
res[p][1] *= multiplier
for m in re_attention.finditer(text):
text = m.group(0)
weight = m.group(1)
if text.startswith('\\'):
res.append([text[1:], 1.0])
elif text == '(':
round_brackets.append(len(res))
elif text == '[':
square_brackets.append(len(res))
elif weight is not None and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), float(weight))
elif text == ')' and len(round_brackets) > 0:
multiply_range(round_brackets.pop(), round_bracket_multiplier)
elif text == ']' and len(square_brackets) > 0:
multiply_range(square_brackets.pop(), square_bracket_multiplier)
else:
res.append([text, 1.0])
for pos in round_brackets:
multiply_range(pos, round_bracket_multiplier)
for pos in square_brackets:
multiply_range(pos, square_bracket_multiplier)
if len(res) == 0:
res = [["", 1.0]]
return res
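# Illustrative check (added for clarity, not part of the original file): emphasis brackets scale
# the weight of the enclosed text by 1.1 per level, and (text:weight) sets an explicit multiplier.
if __name__ == "__main__":
    print(parse_prompt_attention("a man in a ((tuxedo))"))     # [['a man in a ', 1.0], ['tuxedo', 1.21...]]
    print(parse_prompt_attention("a man in a (tuxedo:1.21)"))  # [['a man in a ', 1.0], ['tuxedo', 1.21]]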
import os
import sys
import traceback
from collections import namedtuple

import numpy as np
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

import modules.images
from modules.upscaler import Upscaler, UpscalerData
from modules.paths import models_path
from modules.shared import cmd_opts, opts
RealesrganModelInfo = namedtuple("RealesrganModelInfo", ["name", "location", "model", "netscale"])
realesrgan_models = []
have_realesrgan = False
class UpscalerRealESRGAN(Upscaler):
def __init__(self, path):
self.name = "RealESRGAN"
self.model_path = os.path.join(models_path, self.name)
self.user_path = path
super().__init__()
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
self.enable = True
self.scalers = []
scalers = self.load_models(path)
for scaler in scalers:
if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler)
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
self.enable = False
self.scalers = []
def do_upscale(self, img, path):
if not self.enable:
return img
info = self.load_model(path)
if not os.path.exists(info.data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
return img
upsampler = RealESRGANer(
scale=info.scale,
model_path=info.data_path,
model=info.model(),
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0]
image = Image.fromarray(upsampled)
return image
def load_model(self, path):
try:
info = None
for scaler in self.scalers:
if scaler.data_path == path:
info = scaler
if info is None:
print(f"Unable to find model info: {path}")
return None
model_file = load_file_from_url(url=info.data_path, model_dir=self.model_path, progress=True)
info.data_path = model_file
return info
except Exception as e:
print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
return None
    def load_models(self, _):
        return get_realesrgan_models(self)


def get_realesrgan_models(scaler):
    try:
        from basicsr.archs.rrdbnet_arch import RRDBNet
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        models = [
            UpscalerData(
                name="R-ESRGAN General 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN General WDN 4xV3",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN AnimeVideo",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
            ),
            UpscalerData(
                name="R-ESRGAN 4x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 4x+ Anime6B",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
                scale=4,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
            ),
            UpscalerData(
                name="R-ESRGAN 2x+",
                path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
                scale=2,
                upscaler=scaler,
                model=lambda: RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
            ),
        ]
        return models
    except Exception as e:
        print("Error making Real-ESRGAN models list:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
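# Illustrative usage sketch (added for clarity, not part of the original file) for the new
# Upscaler-based UpscalerRealESRGAN class above: instantiate it, pick one of its registered
# models and upscale an image. The input/output file names are assumptions for the example.
if __name__ == "__main__":
    upscaler = UpscalerRealESRGAN(None)
    if upscaler.scalers:
        chosen = upscaler.scalers[0]
        result = upscaler.do_upscale(Image.open("input.png").convert("RGB"), chosen.data_path)
        result.save("realesrgan_output.png")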
class UpscalerRealESRGAN(modules.images.Upscaler):
def __init__(self, upscaling, model_index):
self.upscaling = upscaling
self.model_index = model_index
self.name = realesrgan_models[model_index].name
def do_upscale(self, img):
return upscale_with_realesrgan(img, self.upscaling, self.model_index)
def setup_realesrgan():
global realesrgan_models
global have_realesrgan
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
realesrgan_models = get_realesrgan_models()
have_realesrgan = True
for i, model in enumerate(realesrgan_models):
if model.name in opts.realesrgan_enabled_models:
modules.shared.sd_upscalers.append(UpscalerRealESRGAN(model.netscale, i))
except Exception:
print("Error importing Real-ESRGAN:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
realesrgan_models = [RealesrganModelInfo('None', '', 0, None)]
have_realesrgan = False
def upscale_with_realesrgan(image, RealESRGAN_upscaling, RealESRGAN_model_index):
if not have_realesrgan:
return image
info = realesrgan_models[RealESRGAN_model_index]
model = info.model()
upsampler = RealESRGANer(
scale=info.netscale,
model_path=info.location,
model=model,
half=not cmd_opts.no_half,
tile=opts.ESRGAN_tile,
tile_pad=opts.ESRGAN_tile_overlap,
)
upsampled = upsampler.enhance(np.array(image), outscale=RealESRGAN_upscaling)[0]
image = Image.fromarray(upsampled)
return image
...@@ -55,7 +55,7 @@ def load_scripts(basedir):
    if not os.path.exists(basedir):
        return

    for filename in sorted(os.listdir(basedir)):
        path = os.path.join(basedir, filename)

        if not os.path.isfile(path):
......
...@@ -6,6 +6,7 @@ import torch
import numpy as np
from torch import einsum

from modules import prompt_parser
from modules.shared import opts, device, cmd_opts
from ldm.util import default
...@@ -204,6 +205,7 @@ class StableDiffusionModelHijack:
            param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
        # diffuser concepts
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
...@@ -223,19 +225,25 @@ class StableDiffusionModelHijack:
        for fn in os.listdir(dirname):
            try:
                fullfn = os.path.join(dirname, fn)

                if os.stat(fullfn).st_size == 0:
                    continue

                process_file(fullfn, fn)
            except Exception:
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
    def hijack(self, m):
        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        self.clip = m.cond_stage_model

        ldm.modules.diffusionmodules.model.nonlinearity = silu
...@@ -255,6 +263,14 @@ class StableDiffusionModelHijack:
        self.layers = flatten(m)
def undo_hijack(self, m):
if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
m.cond_stage_model = m.cond_stage_model.wrapped
model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return
...@@ -269,6 +285,7 @@ class StableDiffusionModelHijack:
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, max_length


class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    def __init__(self, wrapped, hijack):
        super().__init__()
...@@ -294,7 +311,92 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
            if mult != 1.0:
                self.token_mults[ident] = mult
def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_start = self.wrapped.tokenizer.bos_token_id
id_end = self.wrapped.tokenizer.eos_token_id
maxlen = self.wrapped.max_length
if opts.enable_emphasis:
parsed = prompt_parser.parse_prompt_attention(line)
else:
parsed = [[line, 1.0]]
tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
fixes = []
remade_tokens = []
multipliers = []
for tokens, (text, weight) in zip(tokenized, parsed):
i = 0
while i < len(tokens):
token = tokens[i]
possible_matches = self.hijack.ids_lookup.get(token, None)
if possible_matches is None:
remade_tokens.append(token)
multipliers.append(weight)
else:
found = False
for ids, word in possible_matches:
if tokens[i:i + len(ids)] == ids:
emb_len = int(self.hijack.word_embeddings[word].shape[0])
fixes.append((len(remade_tokens), word))
remade_tokens += [0] * emb_len
multipliers += [weight] * emb_len
i += len(ids) - 1
found = True
used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
break
if not found:
remade_tokens.append(token)
multipliers.append(weight)
i += 1
if len(remade_tokens) > maxlen - 2:
vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
ovf = remade_tokens[maxlen - 2:]
overflowing_words = [vocab.get(int(x), "") for x in ovf]
overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
token_count = len(remade_tokens)
remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
return remade_tokens, fixes, multipliers, token_count
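        # Note (added for clarity, not part of the original diff): tokenize_line() keeps the token
        # ids and the multipliers aligned one-to-one -- every token produced for a parsed chunk
        # inherits that chunk's weight -- and both lists are then padded/cropped to maxlen slots
        # (75 usable tokens plus BOS/EOS for CLIP), so remade_tokens and multipliers always have
        # the same length by the time forward() consumes them.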
def process_text(self, texts):
used_custom_terms = []
remade_batch_tokens = []
hijack_comments = []
hijack_fixes = []
token_count = 0
cache = {}
batch_multipliers = []
for line in texts:
if line in cache:
remade_tokens, fixes, multipliers = cache[line]
else:
remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
cache[line] = (remade_tokens, fixes, multipliers)
remade_batch_tokens.append(remade_tokens)
hijack_fixes.append(fixes)
batch_multipliers.append(multipliers)
return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
def process_text_old(self, text):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length
...@@ -370,12 +472,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

    def forward(self, text):
if opts.use_old_emphasis_implementation:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
else:
batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
        self.hijack.fixes = hijack_fixes
        self.hijack.comments = hijack_comments

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))

        tokens = torch.asarray(remade_batch_tokens).to(device)
        outputs = self.wrapped.transformer(input_ids=tokens)
......
...@@ -8,7 +8,14 @@ from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

from modules import shared, modelloader
from modules.paths import models_path

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
user_dir = None

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
...@@ -23,21 +30,30 @@ except Exception:
    pass
def setup_model(dirname):
global user_dir
user_dir = dirname
if not os.path.exists(model_path):
os.makedirs(model_path)
checkpoints_list.clear()
list_models()
def checkpoint_tiles():
    print(sorted([x.title for x in checkpoints_list.values()]))
    return sorted([x.title for x in checkpoints_list.values()])
def list_models():
    checkpoints_list.clear()
    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)

    def modeltitle(path, shorthash):
        abspath = os.path.abspath(path)

        if user_dir is not None and abspath.startswith(user_dir):
            name = abspath.replace(user_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(path)
...@@ -46,21 +62,27 @@ def list_models():
        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]

        return f'{name} [{shorthash}]', shortname

    cmd_ckpt = shared.cmd_opts.ckpt
    if os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
        shared.opts.sd_model_checkpoint = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)


def get_closet_checkpoint_match(searchString):
    applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key=lambda x: len(x.title))
    if len(applicable) > 0:
        return applicable[0]
    return None
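# Illustrative usage sketch (added for clarity, not part of the original file): once
# checkpoints_list has been populated by setup_model()/list_models(), a partial checkpoint name
# resolves to the shortest matching title. The query string is just an example.
def _example_resolve_checkpoint(query="sd-v1-4"):
    info = get_closet_checkpoint_match(query)
    return info.title if info is not None else None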
def model_hash(filename): def model_hash(filename):
...@@ -138,7 +160,7 @@ def load_model():

def reload_model_weights(sd_model, info=None):
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if sd_model.sd_model_checkpint == checkpoint_info.filename:
...@@ -149,8 +171,12 @@ def reload_model_weights(sd_model, info=None):
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

    sd_hijack.model_hijack.hijack(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)
......
...@@ -4,7 +4,6 @@ import torch
import tqdm
from PIL import Image
import inspect
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
...@@ -23,6 +22,8 @@ samplers_k_diffusion = [
    ('Heun', 'sample_heun', ['k_heun']),
    ('DPM2', 'sample_dpm_2', ['k_dpm_2']),
    ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
    ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast']),
    ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
]

samplers_data_k_diffusion = [
...@@ -36,7 +37,7 @@ samplers = [
    SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []),
    SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []),
]
samplers_for_img2img = [x for x in samplers if x.name not in ['PLMS', 'DPM fast', 'DPM adaptive']]

sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
...@@ -309,8 +310,13 @@ class KDiffusionSampler:
        x = x * sigmas[0]

        extra_params_kwargs = self.initialize(p)
        if 'sigma_min' in inspect.signature(self.func).parameters:
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            if 'n' in inspect.signature(self.func).parameters:
                extra_params_kwargs['n'] = steps
        else:
            extra_params_kwargs['sigmas'] = sigmas
        samples = self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)

        return samples
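# Note (added for clarity, not part of the original diff): 'DPM fast' and 'DPM adaptive' in
# k-diffusion take an explicit (sigma_min, sigma_max) range (plus a step count 'n' for DPM fast)
# instead of a precomputed sigma schedule, which is why the branch above inspects the sampler
# function's signature. A minimal standalone illustration of the same check:
import inspect as _inspect

def _takes_sigma_range(sampler_func):
    return 'sigma_min' in _inspect.signature(sampler_func).parameters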
import argparse
import datetime
import json
import os
import sys

import gradio as gr
import tqdm

import modules.artists
import modules.interrogate
import modules.memmon
import modules.sd_models
import modules.styles
from modules.devices import get_optimal_device
from modules.paths import script_path, sd_path

sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
model_path = os.path.join(script_path, 'models')
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
...@@ -34,8 +35,13 @@ parser.add_argument("--always-batch-cond-uncond", action='store_true', help="dis
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(model_path, 'Codeformer'))
parser.add_argument("--gfpgan-models-path", type=str, help="Path to directory with GFPGAN model file(s).", default=os.path.join(model_path, 'GFPGAN'))
parser.add_argument("--esrgan-models-path", type=str, help="Path to directory with ESRGAN model file(s).", default=os.path.join(model_path, 'ESRGAN'))
parser.add_argument("--bsrgan-models-path", type=str, help="Path to directory with BSRGAN model file(s).", default=os.path.join(model_path, 'BSRGAN'))
parser.add_argument("--realesrgan-models-path", type=str, help="Path to directory with RealESRGAN model file(s).", default=os.path.join(model_path, 'RealESRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(model_path, 'SwinIR'))
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(model_path, 'LDSR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
...@@ -53,7 +59,6 @@ parser.add_argument("--autolaunch", action='store_true', help="open the webui UR
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)

cmd_opts = parser.parse_args()
device = get_optimal_device()

batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
...@@ -61,6 +66,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
config_filename = cmd_opts.ui_settings_file


class State:
    interrupted = False
    job = ""
...@@ -95,13 +101,13 @@ prompt_styles = modules.styles.StyleDatabase(styles_filename)
interrogator = modules.interrogate.InterrogateModels("interrogate")

face_restorers = []
# This was moved to webui.py with the other model "setup" calls.
# modules.sd_models.list_models()


def realesrgan_models_names():
    import modules.realesrgan_model
    return [x.name for x in modules.realesrgan_model.get_realesrgan_models(None)]
class OptionInfo:
...@@ -167,13 +173,10 @@ options_templates.update(options_section(('saving-to-dirs', "Saving to a directo
options_templates.update(options_section(('upscaling', "Upscaling"), {
    "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "realesrgan_enabled_models": OptionInfo(["R-ESRGAN x4+", "R-ESRGAN x4+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
    "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
    "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
    "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))
...@@ -190,12 +193,13 @@ options_templates.update(options_section(('system', "System"), {
}))

options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
    "enable_emphasis": OptionInfo(True, "Emphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
    "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
    "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
......
...@@ -53,6 +53,12 @@ class StyleDatabase:
                negative_prompt = row.get("negative_prompt", "")
                self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt)

    def get_style_prompts(self, styles):
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])
......
import sys import contextlib
import traceback import os
import cv2
import os import numpy as np
import contextlib import torch
import numpy as np from PIL import Image
from PIL import Image from basicsr.utils.download_util import load_file_from_url
import torch
import modules.images from modules import modelloader
from modules.shared import cmd_opts, opts, device from modules.paths import models_path
from modules.swinir_arch import SwinIR as net from modules.shared import cmd_opts, opts, device
from modules.swinir_model_arch import SwinIR as net
precision_scope = ( from modules.upscaler import Upscaler, UpscalerData
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
) precision_scope = (
torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
)
def load_model(filename, scale=4):
model = net(
upscale=scale, class UpscalerSwinIR(Upscaler):
in_chans=3, def __init__(self, dirname):
img_size=64, self.name = "SwinIR"
window_size=8, self.model_url = "https://github.com/JingyunLiang/SwinIR/releases/download/v0.0" \
img_range=1.0, "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6], "-L_x4_GAN.pth "
embed_dim=240, self.model_name = "SwinIR 4x"
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8], self.model_path = os.path.join(models_path, self.name)
mlp_ratio=2, self.user_path = dirname
upsampler="nearest+conv", super().__init__()
resi_connection="3conv", scalers = []
) model_files = self.find_models(ext_filter=[".pt", ".pth"])
for model in model_files:
pretrained_model = torch.load(filename) if "http" in model:
model.load_state_dict(pretrained_model["params_ema"], strict=True) name = self.model_name
if not cmd_opts.no_half: else:
model = model.half() name = modelloader.friendly_name(model)
return model model_data = UpscalerData(name, model, self)
scalers.append(model_data)
self.scalers = scalers
def load_models(dirname):
for file in os.listdir(dirname): def do_upscale(self, img, model_file):
path = os.path.join(dirname, file) model = self.load_model(model_file)
model_name, extension = os.path.splitext(file) if model is None:
return img
if extension != ".pt" and extension != ".pth": model = model.to(device)
continue img = upscale(img, model)
try:
try: torch.cuda.empty_cache()
modules.shared.sd_upscalers.append(UpscalerSwin(path, model_name)) except:
except Exception: pass
print(f"Error loading SwinIR model: {path}", file=sys.stderr) return img
print(traceback.format_exc(), file=sys.stderr)
def load_model(self, path, scale=4):
if "http" in path:
def upscale( dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
img, filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
model, else:
tile=opts.SWIN_tile, filename = path
tile_overlap=opts.SWIN_tile_overlap, if filename is None or not os.path.exists(filename):
window_size=8, return None
scale=4, model = net(
): upscale=scale,
img = np.array(img) in_chans=3,
img = img[:, :, ::-1] img_size=64,
img = np.moveaxis(img, 2, 0) / 255 window_size=8,
img = torch.from_numpy(img).float() img_range=1.0,
img = img.unsqueeze(0).to(device) depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
with torch.no_grad(), precision_scope("cuda"): embed_dim=240,
_, _, h_old, w_old = img.size() num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
h_pad = (h_old // window_size + 1) * window_size - h_old mlp_ratio=2,
w_pad = (w_old // window_size + 1) * window_size - w_old upsampler="nearest+conv",
img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :] resi_connection="3conv",
img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad] )
output = inference(img, model, tile, tile_overlap, window_size, scale)
output = output[..., : h_old * scale, : w_old * scale] pretrained_model = torch.load(filename)
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy() model.load_state_dict(pretrained_model["params_ema"], strict=True)
if output.ndim == 3: if not cmd_opts.no_half:
output = np.transpose( model = model.half()
output[[2, 1, 0], :, :], (1, 2, 0) return model
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
return Image.fromarray(output, "RGB") def upscale(
img,
model,
def inference(img, model, tile, tile_overlap, window_size, scale): tile=opts.SWIN_tile,
# test the image tile by tile tile_overlap=opts.SWIN_tile_overlap,
b, c, h, w = img.size() window_size=8,
tile = min(tile, h, w) scale=4,
assert tile % window_size == 0, "tile size should be a multiple of window_size" ):
sf = scale img = np.array(img)
img = img[:, :, ::-1]
stride = tile - tile_overlap img = np.moveaxis(img, 2, 0) / 255
h_idx_list = list(range(0, h - tile, stride)) + [h - tile] img = torch.from_numpy(img).float()
w_idx_list = list(range(0, w - tile, stride)) + [w - tile] img = img.unsqueeze(0).to(device)
E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img) with torch.no_grad(), precision_scope("cuda"):
W = torch.zeros_like(E, dtype=torch.half, device=device) _, _, h_old, w_old = img.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
for h_idx in h_idx_list: w_pad = (w_old // window_size + 1) * window_size - w_old
for w_idx in w_idx_list: img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, : h_old + h_pad, :]
in_patch = img[..., h_idx : h_idx + tile, w_idx : w_idx + tile] img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, : w_old + w_pad]
out_patch = model(in_patch) output = inference(img, model, tile, tile_overlap, window_size, scale)
out_patch_mask = torch.ones_like(out_patch) output = output[..., : h_old * scale, : w_old * scale]
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
E[ if output.ndim == 3:
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf output = np.transpose(
].add_(out_patch) output[[2, 1, 0], :, :], (1, 2, 0)
W[ ) # CHW-RGB to HCW-BGR
..., h_idx * sf : (h_idx + tile) * sf, w_idx * sf : (w_idx + tile) * sf output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
].add_(out_patch_mask) return Image.fromarray(output, "RGB")
output = E.div_(W)
return output def inference(img, model, tile, tile_overlap, window_size, scale):
# test the image tile by tile
b, c, h, w = img.size()
class UpscalerSwin(modules.images.Upscaler): tile = min(tile, h, w)
def __init__(self, filename, title): assert tile % window_size == 0, "tile size should be a multiple of window_size"
self.name = title sf = scale
self.model = load_model(filename)
stride = tile - tile_overlap
def do_upscale(self, img): h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
model = self.model.to(device) w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
img = upscale(img, model) E = torch.zeros(b, c, h * sf, w * sf, dtype=torch.half, device=device).type_as(img)
return img W = torch.zeros_like(E, dtype=torch.half, device=device)
for h_idx in h_idx_list:
for w_idx in w_idx_list:
in_patch = img[..., h_idx: h_idx + tile, w_idx: w_idx + tile]
out_patch = model(in_patch)
out_patch_mask = torch.ones_like(out_patch)
E[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch)
W[
..., h_idx * sf: (h_idx + tile) * sf, w_idx * sf: (w_idx + tile) * sf
].add_(out_patch_mask)
output = E.div_(W)
return output
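The `inference` helper above processes the image in overlapping tiles: each upscaled patch is summed into `E`, a count of contributions is summed into `W`, and the final division turns every overlapped pixel into a plain average of the tiles that produced it. A minimal NumPy sketch of that accumulate-and-divide pattern (illustrative only; `run_tiled` and its defaults are not part of the codebase):

```python
import numpy as np

def run_tiled(img, model, tile=192, overlap=8, scale=4):
    # img: (C, H, W) array; model: callable that upscales one (C, tile, tile) patch.
    c, h, w = img.shape
    tile = min(tile, h, w)
    stride = tile - overlap
    acc = np.zeros((c, h * scale, w * scale), dtype=np.float32)   # running sum of tile outputs (E)
    weight = np.zeros_like(acc)                                   # how many tiles covered each pixel (W)

    for y in list(range(0, h - tile, stride)) + [h - tile]:
        for x in list(range(0, w - tile, stride)) + [w - tile]:
            out = model(img[:, y:y + tile, x:x + tile])           # (C, tile*scale, tile*scale)
            ys, xs = y * scale, x * scale
            acc[:, ys:ys + tile * scale, xs:xs + tile * scale] += out
            weight[:, ys:ys + tile * scale, xs:xs + tile * scale] += 1.0

    return acc / weight  # overlapped pixels become the mean of every tile that covered them
```

Overlapping by `tile_overlap` pixels hides seams between tiles at the cost of some repeated computation, which is why both values are exposed as the `SWIN_tile` / `SWIN_tile_overlap` settings used above.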
...@@ -9,10 +9,13 @@ import random ...@@ -9,10 +9,13 @@ import random
import sys import sys
import time import time
import traceback import traceback
import platform
import subprocess as sp
import numpy as np import numpy as np
import torch import torch
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
import piexif
import gradio as gr import gradio as gr
import gradio.utils import gradio.utils
...@@ -61,7 +64,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️ ...@@ -61,7 +64,7 @@ random_symbol = '\U0001f3b2\ufe0f' # 🎲️
reuse_symbol = '\u267b\ufe0f' # ♻️ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨 art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2'  # 📂
def plaintext_to_html(text): def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>" text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
...@@ -111,18 +114,26 @@ def save_files(js_data, images, index): ...@@ -111,18 +114,26 @@ def save_files(js_data, images, index):
writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"]) writer.writerow(["prompt", "seed", "width", "height", "sampler", "cfgs", "steps", "filename", "negative_prompt"])
filename_base = str(int(time.time() * 1000)) filename_base = str(int(time.time() * 1000))
extension = opts.samples_format.lower()
for i, filedata in enumerate(images): for i, filedata in enumerate(images):
filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + ".png" filename = filename_base + ("" if len(images) == 1 else "-" + str(i + 1)) + f".{extension}"
filepath = os.path.join(opts.outdir_save, filename) filepath = os.path.join(opts.outdir_save, filename)
if filedata.startswith("data:image/png;base64,"): if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):] filedata = filedata[len("data:image/png;base64,"):]
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text('parameters', infotexts[i])
image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8')))) image = Image.open(io.BytesIO(base64.decodebytes(filedata.encode('utf-8'))))
image.save(filepath, quality=opts.jpeg_quality, pnginfo=pnginfo) if opts.enable_pnginfo and extension == 'png':
pnginfo = PngImagePlugin.PngInfo()
pnginfo.add_text('parameters', infotexts[i])
image.save(filepath, pnginfo=pnginfo)
else:
image.save(filepath, quality=opts.jpeg_quality)
if opts.enable_pnginfo and extension in ("jpg", "jpeg", "webp"):
piexif.insert(piexif.dump({"Exif": {
piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(infotexts[i], encoding="unicode")
}}), filepath)
filenames.append(filename) filenames.append(filename)
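With `samples_format` now honoured, the save path branches: PNG files keep the generation parameters in a `parameters` text chunk, while JPEG and WebP files get them as an EXIF `UserComment` written with piexif. A sketch of how such a comment could be read back (`read_parameters` is a hypothetical helper, not part of this commit):

```python
import piexif
import piexif.helper

def read_parameters(filepath):
    # Load the EXIF block and decode the UserComment field written by save_files().
    exif = piexif.load(filepath)
    comment = exif.get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
    return piexif.helper.UserComment.load(comment) if comment else ""
```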
...@@ -369,7 +380,7 @@ def create_toprow(is_img2img): ...@@ -369,7 +380,7 @@ def create_toprow(is_img2img):
with gr.Column(scale=1): with gr.Column(scale=1):
with gr.Row(): with gr.Row():
interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
submit = gr.Button('Generate', elem_id="generate", variant='primary') submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
interrupt.click( interrupt.click(
fn=lambda: shared.state.interrupt(), fn=lambda: shared.state.interrupt(),
...@@ -461,6 +472,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -461,6 +472,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
send_to_img2img = gr.Button('Send to img2img') send_to_img2img = gr.Button('Send to img2img')
send_to_inpaint = gr.Button('Send to inpaint') send_to_inpaint = gr.Button('Send to inpaint')
send_to_extras = gr.Button('Send to extras') send_to_extras = gr.Button('Send to extras')
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_txt2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
...@@ -586,7 +599,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -586,7 +599,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode") mask_mode = gr.Radio(label="Mask mode", show_label=False, choices=["Draw mask", "Upload mask"], type="index", value="Draw mask", elem_id="mask_mode")
inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index") inpainting_mask_invert = gr.Radio(label='Masking mode', show_label=False, choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index")
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='original', type="index")
with gr.Row(): with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False) inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False)
...@@ -637,6 +650,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -637,6 +650,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
img2img_send_to_img2img = gr.Button('Send to img2img') img2img_send_to_img2img = gr.Button('Send to img2img')
img2img_send_to_inpaint = gr.Button('Send to inpaint') img2img_send_to_inpaint = gr.Button('Send to inpaint')
img2img_send_to_extras = gr.Button('Send to extras') img2img_send_to_extras = gr.Button('Send to extras')
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder'
open_img2img_folder = gr.Button(folder_symbol, elem_id=button_id)
with gr.Group(): with gr.Group():
html_info = gr.HTML() html_info = gr.HTML()
...@@ -809,6 +824,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -809,6 +824,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
html_info = gr.HTML() html_info = gr.HTML()
extras_send_to_img2img = gr.Button('Send to img2img') extras_send_to_img2img = gr.Button('Send to img2img')
extras_send_to_inpaint = gr.Button('Send to inpaint') extras_send_to_inpaint = gr.Button('Send to inpaint')
button_id = "hidden_element" if shared.cmd_opts.hide_ui_dir_config else ''
open_extras_folder = gr.Button('Open output directory', elem_id=button_id)
submit.click( submit.click(
fn=run_extras, fn=run_extras,
...@@ -874,6 +891,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -874,6 +891,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Row(): with gr.Row():
primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name") primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary Model Name")
secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name") secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary Model Name")
custom_name = gr.Textbox(label="Custom Name (Optional)")
interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3) interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Interpolation Amount', value=0.3)
interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method") interp_method = gr.Radio(choices=["Weighted Sum", "Sigmoid", "Inverse Sigmoid"], value="Weighted Sum", label="Interpolation Method")
save_as_half = gr.Checkbox(value=False, label="Safe as float16") save_as_half = gr.Checkbox(value=False, label="Safe as float16")
...@@ -907,6 +925,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -907,6 +925,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
components = [] components = []
component_dict = {} component_dict = {}
def open_folder(f):
if not shared.cmd_opts.hide_ui_dir_config:
path = os.path.normpath(f)
if platform.system() == "Windows":
os.startfile(path)
elif platform.system() == "Darwin":
sp.Popen(["open", path])
else:
sp.Popen(["xdg-open", path])
def run_settings(*args): def run_settings(*args):
changed = 0 changed = 0
...@@ -1013,15 +1041,26 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1013,15 +1041,26 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
inputs=components, inputs=components,
outputs=[result, text_settings], outputs=[result, text_settings],
) )
def modelmerger(*args):
try:
results = run_modelmerger(*args)
except Exception as e:
print("Error loading/saving model file:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
modules.sd_models.list_models() #To remove the potentially missing models from the list
return ["Error loading/saving model file. It doesn't exist or the name contains illegal characters"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(3)]
return results
modelmerger_merge.click( modelmerger_merge.click(
fn=run_modelmerger, fn=modelmerger,
inputs=[ inputs=[
primary_model_name, primary_model_name,
secondary_model_name, secondary_model_name,
interp_method, interp_method,
interp_amount, interp_amount,
save_as_half, save_as_half,
custom_name,
], ],
outputs=[ outputs=[
submit_result, submit_result,
...@@ -1068,6 +1107,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -1068,6 +1107,24 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
outputs=[extras_image], outputs=[extras_image],
) )
open_txt2img_folder.click(
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_txt2img_samples),
inputs=[],
outputs=[],
)
open_img2img_folder.click(
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_img2img_samples),
inputs=[],
outputs=[],
)
open_extras_folder.click(
fn=lambda: open_folder(opts.outdir_samples or opts.outdir_extras_samples),
inputs=[],
outputs=[],
)
img2img_send_to_extras.click( img2img_send_to_extras.click(
fn=lambda x: image_from_url_text(x), fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery_extras", _js="extract_image_from_gallery_extras",
......
import os
from abc import abstractmethod
import PIL
import numpy as np
import torch
from PIL import Image
import modules.shared
from modules import modelloader, shared
LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
from modules.paths import models_path
class Upscaler:
name = None
model_path = None
model_name = None
model_url = None
enable = True
filter = None
model = None
user_path = None
scalers: []
tile = True
def __init__(self, create_dirs=False):
self.mod_pad_h = None
self.tile_size = modules.shared.opts.ESRGAN_tile
self.tile_pad = modules.shared.opts.ESRGAN_tile_overlap
self.device = modules.shared.device
self.img = None
self.output = None
self.scale = 1
self.half = not modules.shared.cmd_opts.no_half
self.pre_pad = 0
self.mod_scale = None
if self.name is not None and create_dirs:
self.model_path = os.path.join(models_path, self.name)
if not os.path.exists(self.model_path):
os.makedirs(self.model_path)
try:
import cv2
self.can_tile = True
except:
pass
@abstractmethod
def do_upscale(self, img: PIL.Image, selected_model: str):
return img
def upscale(self, img: PIL.Image, scale: int, selected_model: str = None):
self.scale = scale
dest_w = img.width * scale
dest_h = img.height * scale
for i in range(3):
if img.width >= dest_w and img.height >= dest_h:
break
img = self.do_upscale(img, selected_model)
if img.width != dest_w or img.height != dest_h:
img = img.resize((int(dest_w), int(dest_h)), resample=LANCZOS)
return img
@abstractmethod
def load_model(self, path: str):
pass
def find_models(self, ext_filter=None) -> list:
return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path)
def update_status(self, prompt):
print(f"\nextras: {prompt}", file=shared.progress_print_out)
class UpscalerData:
name = None
data_path = None
scale: int = 4
scaler: Upscaler = None
model: None
def __init__(self, name: str, path: str, upscaler: Upscaler = None, scale: int = 4, model=None):
self.name = name
self.data_path = path
self.scaler = upscaler
self.scale = scale
self.model = model
class UpscalerNone(Upscaler):
name = "None"
scalers = []
def load_model(self, path):
pass
def do_upscale(self, img, selected_model=None):
return img
def __init__(self, dirname=None):
super().__init__(False)
self.scalers = [UpscalerData("None", None, self)]
class UpscalerLanczos(Upscaler):
scalers = []
def do_upscale(self, img, selected_model=None):
return img.resize((int(img.width * self.scale), int(img.height * self.scale)), resample=LANCZOS)
def load_model(self, _):
pass
def __init__(self, dirname=None):
super().__init__(False)
self.name = "Lanczos"
self.scalers = [UpscalerData("Lanczos", None, self)]
...@@ -11,46 +11,8 @@ from modules import images, processing, devices ...@@ -11,46 +11,8 @@ from modules import images, processing, devices
from modules.processing import Processed, process_images from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state from modules.shared import opts, cmd_opts, state
# https://github.com/parlance-zz/g-diffuser-bot
def expand(x, dir, amount, power=0.75):
is_left = dir == 3
is_right = dir == 1
is_up = dir == 0
is_down = dir == 2
if is_left or is_right:
noise = np.zeros((x.shape[0], amount, 3), dtype=float)
indexes = np.random.random((x.shape[0], amount)) ** power * (1 - np.arange(amount) / amount)
if is_right:
indexes = 1 - indexes
indexes = (indexes * (x.shape[1] - 1)).astype(int)
for row in range(x.shape[0]):
if is_left:
noise[row] = x[row][indexes[row]]
else:
noise[row] = np.flip(x[row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_left else [x, noise], axis=1)
return x
if is_up or is_down:
noise = np.zeros((amount, x.shape[1], 3), dtype=float)
indexes = np.random.random((x.shape[1], amount)) ** power * (1 - np.arange(amount) / amount)
if is_down:
indexes = 1 - indexes
indexes = (indexes * x.shape[0] - 1).astype(int)
for row in range(x.shape[1]):
if is_up:
noise[:, row] = x[:, row][indexes[row]]
else:
noise[:, row] = np.flip(x[:, row][indexes[row]], axis=0)
x = np.concatenate([noise, x] if is_up else [x, noise], axis=0)
return x
# this function is taken from https://github.com/parlance-zz/g-diffuser-bot
def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05): def get_matched_noise(_np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05):
# helper fft routines that keep ortho normalization and auto-shift before and after fft # helper fft routines that keep ortho normalization and auto-shift before and after fft
def _fft2(data): def _fft2(data):
......
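`get_matched_noise` begins with `_fft2`, described as keeping ortho normalization and auto-shifting before and after the FFT; the body is truncated in this hunk. A pair of helpers with that behaviour would plausibly look like the following (an assumption about the elided code, not a copy of it):

```python
import numpy as np

def _fft2(data):
    # Center the image, apply an orthonormal 2-D FFT per channel, then center the spectrum.
    if data.ndim > 2:  # image with channels
        out = np.zeros_like(data, dtype=np.complex128)
        for c in range(data.shape[2]):
            out[:, :, c] = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(data[:, :, c]), norm="ortho"))
        return out
    return np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(data), norm="ortho"))

def _ifft2(data):
    # Inverse of _fft2 with the same normalization and shifting convention.
    if data.ndim > 2:
        out = np.zeros_like(data, dtype=np.complex128)
        for c in range(data.shape[2]):
            out[:, :, c] = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(data[:, :, c]), norm="ortho"))
        return out
    return np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(data), norm="ortho"))
```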
...@@ -34,7 +34,7 @@ class Script(scripts.Script): ...@@ -34,7 +34,7 @@ class Script(scripts.Script):
seed = p.seed seed = p.seed
init_img = p.init_images[0] init_img = p.init_images[0]
img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2) img = upscaler.scaler.upscale(init_img, 2, upscaler.data_path)
devices.torch_gc() devices.torch_gc()
......
...@@ -45,11 +45,8 @@ def apply_sampler(p, x, xs): ...@@ -45,11 +45,8 @@ def apply_sampler(p, x, xs):
def apply_checkpoint(p, x, xs): def apply_checkpoint(p, x, xs):
applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title] info = modules.sd_models.get_closet_checkpoint_match(x)
assert len(applicable) > 0, f'Checkpoint {x} for found' assert info is not None, f'Checkpoint for {x} not found'
info = applicable[0]
modules.sd_models.reload_model_weights(shared.sd_model, info) modules.sd_models.reload_model_weights(shared.sd_model, info)
...@@ -159,6 +156,9 @@ class Script(scripts.Script): ...@@ -159,6 +156,9 @@ class Script(scripts.Script):
p.batch_size = 1 p.batch_size = 1
def process_axis(opt, vals): def process_axis(opt, vals):
if opt.label == 'Nothing':
return [0]
valslist = [x.strip() for x in vals.split(",")] valslist = [x.strip() for x in vals.split(",")]
if opt.type == int: if opt.type == int:
......
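`process_axis` now returns a single placeholder value for a "Nothing" axis, so the X/Y grid keeps exactly one row or column on that side; otherwise it splits the comma-separated input and converts each entry to the axis option's type. A simplified sketch of that parsing step (`parse_axis_values` is illustrative; the real script also expands numeric ranges, which is omitted here):

```python
def parse_axis_values(opt_type, vals, label="X values"):
    # A 'Nothing' axis still contributes one dummy entry so the grid has one cell on that side.
    if label == 'Nothing':
        return [0]
    return [opt_type(x.strip()) for x in vals.split(",") if x.strip()]
```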
.output-html p {margin: 0 0.5em;} .output-html p {margin: 0 0.5em;}
.row > *,
.row > .gr-form > * {
min-width: min(120px, 100%);
flex: 1 1 0%;
}
.performance { .performance {
font-size: 0.85em; font-size: 0.85em;
color: #444; color: #444;
...@@ -17,7 +23,7 @@ ...@@ -17,7 +23,7 @@
text-align: right; text-align: right;
} }
#generate{ #txt2img_generate, #img2img_generate {
min-height: 4.5em; min-height: 4.5em;
} }
...@@ -43,13 +49,17 @@ ...@@ -43,13 +49,17 @@
margin-right: auto; margin-right: auto;
} }
#random_seed, #random_subseed, #reuse_seed, #reuse_subseed{ #random_seed, #random_subseed, #reuse_seed, #reuse_subseed, #open_folder{
min-width: auto; min-width: auto;
flex-grow: 0; flex-grow: 0;
padding-left: 0.25em; padding-left: 0.25em;
padding-right: 0.25em; padding-right: 0.25em;
} }
#hidden_element{
display: none;
}
#seed_row, #subseed_row{ #seed_row, #subseed_row{
gap: 0.5rem; gap: 0.5rem;
} }
......
...@@ -21,6 +21,9 @@ export COMMANDLINE_ARGS="" ...@@ -21,6 +21,9 @@ export COMMANDLINE_ARGS=""
# python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
#venv_dir="venv" #venv_dir="venv"
# script to launch to start the app
#export LAUNCH_SCRIPT="launch.py"
# install command for torch # install command for torch
#export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113"
......
import os import os
import threading import threading
from modules import devices
from modules.paths import script_path from modules.paths import script_path
import signal import signal
import threading
from modules.shared import opts, cmd_opts, state import modules.paths
import modules.shared as shared import modules.codeformer_model as codeformer
import modules.ui
import modules.scripts
import modules.sd_hijack
import modules.codeformer_model
import modules.gfpgan_model
import modules.face_restoration
import modules.realesrgan_model as realesrgan
import modules.esrgan_model as esrgan import modules.esrgan_model as esrgan
import modules.ldsr_model as ldsr import modules.bsrgan_model as bsrgan
import modules.extras import modules.extras
import modules.lowvram import modules.face_restoration
import modules.txt2img import modules.gfpgan_model as gfpgan
import modules.img2img import modules.img2img
import modules.swinir as swinir import modules.ldsr_model as ldsr
import modules.lowvram
import modules.realesrgan_model as realesrgan
import modules.scripts
import modules.sd_hijack
import modules.sd_models import modules.sd_models
import modules.shared as shared
import modules.swinir_model as swinir
import modules.txt2img
import modules.ui
from modules import modelloader
from modules.paths import script_path
from modules.shared import cmd_opts
modelloader.cleanup_models()
modules.codeformer_model.setup_codeformer() modules.sd_models.setup_model(cmd_opts.ckpt_dir)
modules.gfpgan_model.setup_gfpgan() codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration()) shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
esrgan.load_models(cmd_opts.esrgan_models_path)
swinir.load_models(cmd_opts.swinir_models_path)
realesrgan.setup_realesrgan()
ldsr.add_lsdr()
queue_lock = threading.Lock() queue_lock = threading.Lock()
...@@ -47,6 +48,8 @@ def wrap_queued_call(func): ...@@ -47,6 +48,8 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func): def wrap_gradio_gpu_call(func):
def f(*args, **kwargs): def f(*args, **kwargs):
devices.torch_gc()
shared.state.sampling_step = 0 shared.state.sampling_step = 0
shared.state.job_count = -1 shared.state.job_count = -1
shared.state.job_no = 0 shared.state.job_no = 0
...@@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func): ...@@ -62,6 +65,8 @@ def wrap_gradio_gpu_call(func):
shared.state.job = "" shared.state.job = ""
shared.state.job_count = 0 shared.state.job_count = 0
devices.torch_gc()
return res return res
return modules.ui.wrap_gradio_call(f) return modules.ui.wrap_gradio_call(f)
......
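`wrap_gradio_gpu_call` now runs `devices.torch_gc()` both before and after every queued GPU job, which is the point of this fix-vram merge: VRAM held by the previous generation is released before the next one starts. The helper itself is not shown in this diff; it presumably amounts to something like the sketch below (an assumption, not a copy of `modules/devices.py`):

```python
import torch

def torch_gc():
    # Free cached CUDA allocations between jobs; a no-op on CPU-only machines.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
```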
...@@ -41,6 +41,11 @@ then ...@@ -41,6 +41,11 @@ then
venv_dir="venv" venv_dir="venv"
fi fi
if [[ -z "${LAUNCH_SCRIPT}" ]]
then
LAUNCH_SCRIPT="launch.py"
fi
# Disable sentry logging # Disable sentry logging
export ERROR_REPORTING=FALSE export ERROR_REPORTING=FALSE
...@@ -133,4 +138,4 @@ fi ...@@ -133,4 +138,4 @@ fi
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
printf "Launching launch.py..." printf "Launching launch.py..."
printf "\n%s\n" "${delimiter}" printf "\n%s\n" "${delimiter}"
"${python_cmd}" launch.py "${python_cmd}" "${LAUNCH_SCRIPT}"