Commit 6fdad291 authored by ssysm's avatar ssysm


Merge branch 'master' of https://github.com/AUTOMATIC1111/stable-diffusion-webui into upstream-master
parents cc92dc1f 45fbd1c5
@@ -66,6 +66,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- separate prompts using uppercase `AND`
- also supports weights for prompts: `a cat :1.2 AND a dog AND a penguin :2.2`
- No token limit for prompts (original stable diffusion lets you use up to 75 tokens)
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)

## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
@@ -123,4 +124,5 @@ The documentation was moved from this README over to the project's [wiki](https:
- Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot
- CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru
- (You)
// A full size 'lightbox' preview modal shown when left clicking on gallery previews
function closeModal() {
    gradioApp().getElementById("lightboxModal").style.display = "none";
}

function showModal(event) {
    const source = event.target || event.srcElement;
    const modalImage = gradioApp().getElementById("modalImage")
    const lb = gradioApp().getElementById("lightboxModal")
    modalImage.src = source.src
    if (modalImage.style.display === 'none') {
        lb.style.setProperty('background-image', 'url(' + source.src + ')');
    }
    lb.style.display = "block";
    lb.focus()
    event.stopPropagation()
}

function negmod(n, m) {
    return ((n % m) + m) % m;
}
function updateOnBackgroundChange() {
    const modalImage = gradioApp().getElementById("modalImage")
    const modal = gradioApp().getElementById("lightboxModal")
    if (modalImage && modalImage.offsetParent) {
        let allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
        let currentButton = null
        allcurrentButtons.forEach(function(elem) {
            if (elem.parentElement.offsetParent) {
                currentButton = elem;
            }
        })

        if (currentButton && modalImage.src != currentButton.children[0].src) {
            modalImage.src = currentButton.children[0].src;
            if (modalImage.style.display === 'none') {
                modal.style.setProperty('background-image', `url(${modalImage.src})`)
            }
        }
    }
}

function modalImageSwitch(offset) {
    var allgalleryButtons = gradioApp().querySelectorAll(".gallery-item.transition-all")
    var galleryButtons = []
    allgalleryButtons.forEach(function(elem) {
        if (elem.parentElement.offsetParent) {
            galleryButtons.push(elem);
        }
    })

    if (galleryButtons.length > 1) {
        var allcurrentButtons = gradioApp().querySelectorAll(".gallery-item.transition-all.\\!ring-2")
        var currentButton = null
        allcurrentButtons.forEach(function(elem) {
            if (elem.parentElement.offsetParent) {
                currentButton = elem;
            }
        })

        var result = -1
        galleryButtons.forEach(function(v, i) {
            if (v == currentButton) {
                result = i
            }
        })

        if (result != -1) {
            nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
            nextButton.click()
            const modalImage = gradioApp().getElementById("modalImage");
            const modal = gradioApp().getElementById("lightboxModal");
            modalImage.src = nextButton.children[0].src;
            if (modalImage.style.display === 'none') {
                modal.style.setProperty('background-image', `url(${modalImage.src})`)
            }
            setTimeout(function() {
                modal.focus()
            }, 10)
        }
    }
}
function modalNextImage(event) {
    modalImageSwitch(1)
    event.stopPropagation()
}

function modalPrevImage(event) {
    modalImageSwitch(-1)
    event.stopPropagation()
}

function modalKeyHandler(event) {
    switch (event.key) {
        case "ArrowLeft":
            modalPrevImage(event)
@@ -80,24 +105,22 @@ function modalKeyHandler(event){
    }
}
function showGalleryImage() {
    setTimeout(function() {
        fullImg_preview = gradioApp().querySelectorAll('img.w-full.object-contain')

        if (fullImg_preview != null) {
            fullImg_preview.forEach(function function_name(e) {
                if (e.dataset.modded)
                    return;
                e.dataset.modded = true;
                if(e && e.parentElement.tagName == 'DIV'){
                    e.style.cursor='pointer'
                    e.addEventListener('click', function (evt) {
                        if(!opts.js_modal_lightbox) return;
                        modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed)
                        showModal(evt)
                    }, true);
                }
            });
        }
@@ -105,21 +128,21 @@ function showGalleryImage(){
    }, 100);
}
function modalZoomSet(modalImage, enable) {
    if (enable) {
        modalImage.classList.add('modalImageFullscreen');
    } else {
        modalImage.classList.remove('modalImageFullscreen');
    }
}

function modalZoomToggle(event) {
    modalImage = gradioApp().getElementById("modalImage");
    modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
    event.stopPropagation()
}

function modalTileImageToggle(event) {
    const modalImage = gradioApp().getElementById("modalImage");
    const modal = gradioApp().getElementById("lightboxModal");
    const isTiling = modalImage.style.display === 'none';
@@ -134,17 +157,18 @@ function modalTileImageToggle(event){
    event.stopPropagation()
}

function galleryImageHandler(e) {
    if (e && e.parentElement.tagName == 'BUTTON') {
        e.onclick = showGalleryImage;
    }
}

onUiUpdate(function() {
    fullImg_preview = gradioApp().querySelectorAll('img.w-full')
    if (fullImg_preview != null) {
        fullImg_preview.forEach(galleryImageHandler);
    }
    updateOnBackgroundChange();
})
document.addEventListener("DOMContentLoaded", function() {
@@ -152,13 +176,13 @@ document.addEventListener("DOMContentLoaded", function() {
    const modal = document.createElement('div')
    modal.onclick = closeModal;
    modal.id = "lightboxModal";
    modal.tabIndex = 0
    modal.addEventListener('keydown', modalKeyHandler, true)

    const modalControls = document.createElement('div')
    modalControls.className = 'modalControls gradio-container';
    modal.append(modalControls);

    const modalZoom = document.createElement('span')
    modalZoom.className = 'modalZoom cursor';
    modalZoom.innerHTML = '&#10530;'
@@ -183,30 +207,30 @@ document.addEventListener("DOMContentLoaded", function() {
    const modalImage = document.createElement('img')
    modalImage.id = 'modalImage';
    modalImage.onclick = closeModal;
    modalImage.tabIndex = 0
    modalImage.addEventListener('keydown', modalKeyHandler, true)
    modal.appendChild(modalImage)

    const modalPrev = document.createElement('a')
    modalPrev.className = 'modalPrev';
    modalPrev.innerHTML = '&#10094;'
    modalPrev.tabIndex = 0
    modalPrev.addEventListener('click', modalPrevImage, true);
    modalPrev.addEventListener('keydown', modalKeyHandler, true)
    modal.appendChild(modalPrev)

    const modalNext = document.createElement('a')
    modalNext.className = 'modalNext';
    modalNext.innerHTML = '&#10095;'
    modalNext.tabIndex = 0
    modalNext.addEventListener('click', modalNextImage, true);
    modalNext.addEventListener('keydown', modalKeyHandler, true)
    modal.appendChild(modalNext)

    gradioApp().getRootNode().appendChild(modal)

    document.body.appendChild(modalFragment);
});
@@ -7,38 +7,14 @@ import shlex
import platform

dir_repos = "repositories"

python = sys.executable
git = os.environ.get('GIT', "git")


def extract_arg(args, name):
    return [x for x in args if x != name], name in args


def run(command, desc=None, errdesc=None):
    if desc is not None:
        print(desc)
@@ -58,23 +34,11 @@ stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.st
    return result.stdout.decode(encoding="utf8", errors="ignore")


def check_run(command):
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
    return result.returncode == 0


def is_installed(package):
    try:
        spec = importlib.util.find_spec(package)
@@ -84,6 +48,22 @@ def is_installed(package):
    return spec is not None


def repo_dir(name):
    return os.path.join(dir_repos, name)


def run_python(code, desc=None, errdesc=None):
    return run(f'"{python}" -c "{code}"', desc, errdesc)


def run_pip(args, desc=None):
    return run(f'"{python}" -m pip {args} --prefer-binary', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")


def check_run_python(code):
    return check_run(f'"{python}" -c "{code}"')


def git_clone(url, dir, name, commithash=None):
    # TODO clone into temporary dir and move if successful
@@ -105,56 +85,81 @@ def git_clone(url, dir, name, commithash=None):
    run(f'"{git}" -C {dir} checkout {commithash}', None, f"Couldn't checkout {name}'s hash: {commithash}")


def prepare_enviroment():
    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
    commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
    clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1")

    stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc")
    taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
    k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "f4e99857772fc3a126ba886aadf795a332774878")
    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

    args = shlex.split(commandline_args)

    args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
    xformers = '--xformers' in args
    deepdanbooru = '--deepdanbooru' in args

    try:
        commit = run(f"{git} rev-parse HEAD").strip()
    except Exception:
        commit = "<none>"

    print(f"Python {sys.version}")
    print(f"Commit hash: {commit}")

    if not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch")

    if not skip_torch_cuda_test:
        run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

    if not is_installed("gfpgan"):
        run_pip(f"install {gfpgan_package}", "gfpgan")

    if not is_installed("clip"):
        run_pip(f"install {clip_package}", "clip")

    if not is_installed("xformers") and xformers and platform.python_version().startswith("3.10"):
        if platform.system() == "Windows":
            run_pip("install https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/a/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl", "xformers")
        elif platform.system() == "Linux":
            run_pip("install xformers", "xformers")

    if not is_installed("deepdanbooru") and deepdanbooru:
        run_pip("install git+https://github.com/KichangKim/DeepDanbooru.git@edf73df4cdaeea2cf00e9ac08bd8a9026b7a7b26#egg=deepdanbooru[tensorflow] tensorflow==2.10.0 tensorflow-io==0.27.0", "deepdanbooru")

    os.makedirs(dir_repos, exist_ok=True)

    git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-diffusion'), "Stable Diffusion", stable_diffusion_commit_hash)
    git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
    git_clone("https://github.com/crowsonkb/k-diffusion.git", repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
    git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
    git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)

    if not is_installed("lpips"):
        run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

    run_pip(f"install -r {requirements_file}", "requirements for Web UI")

    sys.argv += args

    if "--exit" in args:
        print("Exiting because of --exit argument")
        exit(0)


def start_webui():
    print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
    import webui
    webui.webui()


if __name__ == "__main__":
    prepare_enviroment()
    start_webui()
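The `extract_arg` helper both strips a flag from the argument list and reports whether it was present, which is how `--skip-torch-cuda-test` is consumed before the remaining args are handed to the UI. A quick illustration (the argument values here are made up):

args = ['--ckpt', 'model.ckpt', '--skip-torch-cuda-test', '--xformers']
args, skip_torch_cuda_test = extract_arg(args, '--skip-torch-cuda-test')
# args is now ['--ckpt', 'model.ckpt', '--xformers']; skip_torch_cuda_test is True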
@@ -10,13 +10,11 @@ from basicsr.utils.download_util import load_file_from_url
import modules.upscaler
from modules import devices, modelloader
from modules.bsrgan_model_arch import RRDBNet


class UpscalerBSRGAN(modules.upscaler.Upscaler):
    def __init__(self, dirname):
        self.name = "BSRGAN"
        self.model_name = "BSRGAN 4x"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
        self.user_path = dirname
import os.path
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context


def _load_tf_and_return_tags(pil_image, threshold):
    import deepdanbooru as dd
    import tensorflow as tf
    import numpy as np

    this_folder = os.path.dirname(__file__)
    model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
    if not os.path.exists(os.path.join(model_path, 'project.json')):
        # there is no point importing these every time
        import zipfile
        from basicsr.utils.download_util import load_file_from_url
        load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
                           model_path)
        with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
            zip_ref.extractall(model_path)
        os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))

    tags = dd.project.load_tags_from_project(model_path)
    model = dd.project.load_model_from_project(
        model_path, compile_model=True
    )

    width = model.input_shape[2]
    height = model.input_shape[1]
    image = np.array(pil_image)
    image = tf.image.resize(
        image,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    image = image.numpy()  # EagerTensor to np.array
    image = dd.image.transform_and_pad_image(image, width, height)
    image = image / 255.0
    image_shape = image.shape
    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))

    y = model.predict(image)[0]

    result_dict = {}

    for i, tag in enumerate(tags):
        result_dict[tag] = y[i]
    result_tags_out = []
    result_tags_print = []
    for tag in tags:
        if result_dict[tag] >= threshold:
            if tag.startswith("rating:"):
                continue
            result_tags_out.append(tag)
            result_tags_print.append(f'{result_dict[tag]} {tag}')

    print('\n'.join(sorted(result_tags_print, reverse=True)))

    return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ')


def subprocess_init_no_cuda():
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


def get_deepbooru_tags(pil_image, threshold=0.5):
    context = get_context('spawn')
    with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
        f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, )
        ret = f.result()  # will rethrow any exceptions
    return ret
\ No newline at end of file
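A usage sketch for the module above (the image path is hypothetical); `get_deepbooru_tags` runs the TensorFlow model in a spawned subprocess with CUDA hidden via `CUDA_VISIBLE_DEVICES=-1`, so the GPU stays free for the main torch process:

from PIL import Image
from modules.deepbooru import get_deepbooru_tags

pil_image = Image.open("example_anime_image.png").convert("RGB")
tags = get_deepbooru_tags(pil_image, threshold=0.75)
print(tags)  # e.g. "1girl, long hair, smile" -- comma-separated, underscores replaced with spaces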
@@ -5,9 +5,8 @@ import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgan_model_arch as arch
from modules import shared, modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
@@ -76,7 +75,6 @@ class UpscalerESRGAN(Upscaler):
        self.model_name = "ESRGAN_4x"
        self.scalers = []
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        scalers = []
@@ -29,7 +29,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
    if extras_mode == 1:
        #convert file to pillow image
        for img in image_folder:
            image = Image.open(img)
            imageArr.append(image)
            imageNameArr.append(os.path.splitext(img.orig_name)[0])
    else:
@@ -98,6 +98,10 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
                    no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo,
                    forced_filename=image_name if opts.use_original_name_batch else None)

        if opts.enable_pnginfo:
            image.info = existing_pnginfo
            image.info["extras"] = info

        outputs.append(image)

    devices.torch_gc()
@@ -169,9 +173,9 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
    print(f"Loading {secondary_model_info.filename}...")
    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')

    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)

    theta_funcs = {
        "Weighted Sum": weighted_sum,
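The body of `weighted_sum` is outside this diff; under the usual interpretation of a "Weighted Sum" checkpoint merge it would be a per-tensor linear interpolation applied to every key of the two state dicts, roughly:

def weighted_sum(theta0, theta1, alpha):
    # linear interpolation between corresponding tensors of two checkpoints
    return ((1 - alpha) * theta0) + (alpha * theta1)

# applied per key over the two state dicts, e.g.:
# for key in theta_0.keys():
#     if key in theta_1:
#         theta_0[key] = weighted_sum(theta_0[key], theta_1[key], interp_amount)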
@@ -40,18 +40,28 @@ class Hypernetwork:
        self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))


def list_hypernetworks(path):
    res = {}
    for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
        name = os.path.splitext(os.path.basename(filename))[0]
        res[name] = filename
    return res


def load_hypernetwork(filename):
    path = shared.hypernetworks.get(filename, None)
    if path is not None:
        print(f"Loading hypernetwork {filename}")
        try:
            shared.loaded_hypernetwork = Hypernetwork(path)
        except Exception:
            print(f"Error loading hypernetwork {path}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
    else:
        if shared.loaded_hypernetwork is not None:
            print(f"Unloading hypernetwork")

        shared.loaded_hypernetwork = None
def attention_CrossAttention_forward(self, x, context=None, mask=None):
@@ -60,7 +70,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
    q = self.to_q(x)
    context = default(context, x)

    hypernetwork = shared.loaded_hypernetwork
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
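Taken together, the refactor splits discovery from loading: `shared.hypernetworks` now maps names to file paths, and `load_hypernetwork` mutates `shared.loaded_hypernetwork` in place, which the attention forwards read directly. A sketch of the intended use, with a made-up hypernetwork name:

from modules import shared, hypernetwork

shared.hypernetworks = hypernetwork.list_hypernetworks("models/hypernetworks")
hypernetwork.load_hypernetwork("anime_3")  # loads models/hypernetworks/anime_3.pt if listed
hypernetwork.load_hypernetwork("None")     # unknown name: unloads, sets loaded_hypernetwork to None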
@@ -349,6 +349,38 @@ def get_next_sequence_number(path, basename):
def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
    '''Save an image.

    Args:
        image (`PIL.Image`):
            The image to be saved.
        path (`str`):
            The directory to save the image in. Note that the option `save_to_dirs` will cause the image to be saved into a subdirectory of `path`.
        basename (`str`):
            The base filename which will be applied to `filename pattern`.
        seed, prompt, short_filename,
        extension (`str`):
            Image file extension, default is `png`.
        pnginfo_section_name (`str`):
            Specify the name of the section in which `info` will be saved.
        info (`str` or `PngImagePlugin.iTXt`):
            PNG info chunks.
        existing_info (`dict`):
            Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`
        no_prompt:
            TODO I don't know its meaning.
        p (`StableDiffusionProcessing`)
        forced_filename (`str`):
            If specified, `basename` and filename pattern will be ignored.
        save_to_dirs (bool):
            If true, the image will be saved into a subdirectory of `path`.

    Returns: (fullfn, txt_fullfn)
        fullfn (`str`):
            The full path of the saved image.
        txt_fullfn (`str` or None):
            If a text file is saved for this image, this will be its full path. Otherwise None.
    '''
    if short_filename or prompt is None or seed is None:
        file_decoration = ""
    elif opts.save_to_dirs:
@@ -424,7 +456,10 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
        piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")

    if opts.save_txt and info is not None:
        txt_fullfn = f"{fullfn_without_extension}.txt"
        with open(txt_fullfn, "w", encoding="utf8") as file:
            file.write(info + "\n")
    else:
        txt_fullfn = None

    return fullfn, txt_fullfn
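With this change `save_image` returns a `(fullfn, txt_fullfn)` pair, so callers that previously took a single path must now unpack two values; a hypothetical call (paths and values invented for illustration):

from PIL import Image

img = Image.new("RGB", (512, 512))
fullfn, txt_fullfn = save_image(img, "outputs/txt2img-images", "test",
                                seed=1234, prompt="a cat", info="a cat\nSteps: 20")
# txt_fullfn is the .txt sidecar path when opts.save_txt is on, otherwise None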
@@ -7,13 +7,11 @@ from basicsr.utils.download_util import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from modules.ldsr_model_arch import LDSR
from modules import shared


class UpscalerLDSR(Upscaler):
    def __init__(self, user_path):
        self.name = "LDSR"
        self.user_path = user_path
        self.model_url = "https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1"
        self.yaml_url = "https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1"
import argparse
import os
import sys
import modules.safe

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
models_path = os.path.join(script_path, "models")
@@ -46,6 +46,12 @@ def apply_color_correction(correction, image):
    return image


def get_correct_sampler(p):
    if isinstance(p, modules.processing.StableDiffusionProcessingTxt2Img):
        return sd_samplers.samplers
    elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
        return sd_samplers.samplers_for_img2img


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
        self.sd_model = sd_model
@@ -123,7 +129,7 @@ class Processed:
        self.index_of_first_image = index_of_first_image
        self.styles = p.styles
        self.job_timestamp = state.job_timestamp
        self.clip_skip = opts.CLIP_stop_at_last_layers

        self.eta = p.eta
        self.ddim_discretize = p.ddim_discretize
@@ -268,16 +274,18 @@ def fix_seed(p):
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0):
    index = position_in_batch + iteration * p.batch_size

    clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

    generation_params = {
        "Steps": p.steps,
        "Sampler": get_correct_sampler(p)[p.sampler_index].name,
        "CFG scale": p.cfg_scale,
        "Seed": all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
        "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name.replace(',', '').replace(':', '')),
        "Batch size": (None if p.batch_size < 2 else p.batch_size),
        "Batch pos": (None if p.batch_size < 2 else position_in_batch),
        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
@@ -285,7 +293,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler is None or p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
        "Clip skip": None if clip_skip <= 1 else clip_skip,
    }

    generation_params.update(p.extra_generation_params)
@@ -445,7 +453,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                text = infotext(n, i)
                infotexts.append(text)
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)

            del x_samples_ddim
@@ -464,7 +473,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
        if opts.return_grid:
            text = infotext()
            infotexts.insert(0, text)
            if opts.enable_pnginfo:
                grid.info["parameters"] = text
            output_images.insert(0, grid)

            index_of_first_image = 1
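The `generation_params` dict is later flattened into the "parameters" text stored in PNG info; entries whose value is `None` (e.g. `Hypernet` when nothing is loaded) are dropped. A sketch of that flattening step, consistent with how the keys above read (the exact formatting code is outside this hunk):

generation_params = {"Steps": 20, "Sampler": "Euler a", "CFG scale": 7.0,
                     "Seed": 1234, "Hypernet": None, "Clip skip": 2}
text = ", ".join(f"{k}: {v}" for k, v in generation_params.items() if v is not None)
# "Steps: 20, Sampler: Euler a, CFG scale: 7.0, Seed: 1234, Clip skip: 2"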
@@ -8,14 +8,12 @@ from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer

from modules.upscaler import Upscaler, UpscalerData
from modules.shared import cmd_opts, opts


class UpscalerRealESRGAN(Upscaler):
    def __init__(self, path):
        self.name = "RealESRGAN"
        self.user_path = path
        super().__init__()
        try:
# this code is adapted from the script contributed by anon from /h/

import io
import pickle
import collections
import sys
import traceback

import torch
import numpy
import _codecs
import zipfile


def encode(*args):
    out = _codecs.encode(*args)
    return out


class RestrictedUnpickler(pickle.Unpickler):
    def persistent_load(self, saved_id):
        assert saved_id[0] == 'storage'
        return torch.storage._TypedStorage()

    def find_class(self, module, name):
        if module == 'collections' and name == 'OrderedDict':
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)
        if module == 'numpy.core.multiarray' and name == 'scalar':
            return numpy.core.multiarray.scalar
        if module == 'numpy' and name == 'dtype':
            return numpy.dtype
        if module == '_codecs' and name == 'encode':
            return encode
        if module == "pytorch_lightning.callbacks" and name == 'model_checkpoint':
            import pytorch_lightning.callbacks
            return pytorch_lightning.callbacks.model_checkpoint
        if module == "pytorch_lightning.callbacks.model_checkpoint" and name == 'ModelCheckpoint':
            import pytorch_lightning.callbacks.model_checkpoint
            return pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint
        if module == "__builtin__" and name == 'set':
            return set

        # Forbid everything else.
        raise pickle.UnpicklingError(f"global '{module}/{name}' is forbidden")


def check_pt(filename):
    try:
        # new pytorch format is a zip file
        with zipfile.ZipFile(filename) as z:
            with z.open('archive/data.pkl') as file:
                unpickler = RestrictedUnpickler(file)
                unpickler.load()
    except zipfile.BadZipfile:
        # if it's not a zip file, it's an old pytorch format, with five objects written to pickle
        with open(filename, "rb") as file:
            unpickler = RestrictedUnpickler(file)
            for i in range(5):
                unpickler.load()


def load(filename, *args, **kwargs):
    from modules import shared

    try:
        if not shared.cmd_opts.disable_safe_unpickle:
            check_pt(filename)
    except Exception:
        print(f"Error verifying pickled file from {filename}:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr)
        print(f"You can skip this check with --disable-safe-unpickle commandline argument.", file=sys.stderr)
        return None

    return unsafe_torch_load(filename, *args, **kwargs)


unsafe_torch_load = torch.load
torch.load = load
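Because the module swaps `torch.load` at import time, callers do not need to change; a sketch of the effect (the checkpoint path here is hypothetical):

import torch
import modules.safe  # importing installs the checked loader in place of torch.load

# an ordinary checkpoint only needs the allow-listed globals, so this loads as before
sd = torch.load("models/Stable-diffusion/model.ckpt", map_location="cpu")

# a pickle that references, say, os.system fails find_class with
# pickle.UnpicklingError: global 'os/system' is forbidden,
# and load() prints the warning and returns None instead of executing anything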
@@ -9,14 +9,12 @@ from basicsr.utils.download_util import load_file_from_url

import modules.upscaler
from modules import devices, modelloader
from modules.scunet_model_arch import SCUNet as net


class UpscalerScuNET(modules.upscaler.Upscaler):
    def __init__(self, dirname):
        self.name = "ScuNET"
        self.model_name = "ScuNET GAN"
        self.model_name2 = "ScuNET PSNR"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/scunet_color_real_gan.pth"
@@ -282,14 +282,12 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
            remade_batch_tokens_of_same_length = [x + [self.wrapped.tokenizer.eos_token_id] * (target_token_count - len(x)) for x in remade_batch_tokens]
            tokens = torch.asarray(remade_batch_tokens_of_same_length).to(device)

            outputs = self.wrapped.transformer(input_ids=tokens, position_ids=position_ids, output_hidden_states=-opts.CLIP_stop_at_last_layers)
            if opts.CLIP_stop_at_last_layers > 1:
                z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers]
                z = self.wrapped.transformer.text_model.final_layer_norm(z)
            else:
                z = outputs.last_hidden_state

            # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
            batch_multipliers_of_same_length = [x + [1.0] * (target_token_count - len(x)) for x in batch_multipliers]
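The new option is 1-based: `CLIP_stop_at_last_layers = 1` keeps the old default (`last_hidden_state`), while `n > 1` takes the n-th hidden state from the end and re-applies the final layer norm. A standalone sketch of the same "clip skip" idea using plain Hugging Face transformers (the model name and prompt are examples, not from this diff):

from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(["a photo of a cat"], return_tensors="pt")
outputs = model(**tokens, output_hidden_states=True)

stop_at = 2  # "clip skip" of 2: use the penultimate layer
z = outputs.hidden_states[-stop_at]
z = model.text_model.final_layer_norm(z)  # hidden states are pre-norm, so re-apply it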
@@ -28,7 +28,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
    q_in = self.to_q(x)
    context = default(context, x)

    hypernetwork = shared.loaded_hypernetwork
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
@@ -68,7 +68,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
    q_in = self.to_q(x)
    context = default(context, x)

    hypernetwork = shared.loaded_hypernetwork
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
@@ -132,7 +132,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    hypernetwork = shared.loaded_hypernetwork
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is not None:
        k_in = self.to_k(hypernetwork_layers[0](context))
@@ -5,7 +5,6 @@ from collections import namedtuple
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config

from modules import shared, modelloader, devices
@@ -122,6 +121,13 @@ def select_checkpoint():
    return checkpoint_info
def get_state_dict_from_checkpoint(pl_sd):
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]

    return pl_sd
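A quick illustration of why the helper exists: some checkpoints nest their weights under a `state_dict` key (alongside metadata like `global_step`), others are a bare state dict (the tensor values below are placeholders):

import torch

wrapped = {"state_dict": {"model.weight": torch.zeros(1)}, "global_step": 470000}
bare = {"model.weight": torch.zeros(1)}

assert get_state_dict_from_checkpoint(wrapped) is wrapped["state_dict"]
assert get_state_dict_from_checkpoint(bare) is bare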
def load_model_weights(model, checkpoint_info):
    checkpoint_file = checkpoint_info.filename
    sd_model_hash = checkpoint_info.hash
@@ -131,11 +137,8 @@ def load_model_weights(model, checkpoint_info):
    pl_sd = torch.load(checkpoint_file, map_location="cpu")

    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)

    model.load_state_dict(sd, strict=False)
@@ -165,7 +168,7 @@ def load_model():
    checkpoint_info = select_checkpoint()

    if checkpoint_info.config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_info.config}")

    sd_config = OmegaConf.load(checkpoint_info.config)
    sd_model = instantiate_from_config(sd_config.model)
@@ -192,7 +195,8 @@ def reload_model_weights(sd_model, info=None):
        return

    if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
        shared.sd_model = load_model()
        return shared.sd_model

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
...@@ -45,6 +45,7 @@ parser.add_argument("--swinir-models-path", type=str, help="Path to directory wi ...@@ -45,6 +45,7 @@ parser.add_argument("--swinir-models-path", type=str, help="Path to directory wi
parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR')) parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(models_path, 'LDSR'))
parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers") parser.add_argument("--xformers", action='store_true', help="enable xformers for cross attention layers")
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work") parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--deepdanbooru", action='store_true', help="enable deepdanbooru interrogator")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.") parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
...@@ -65,6 +66,8 @@ parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox ...@@ -65,6 +66,8 @@ parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
@@ -78,11 +81,8 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
 xformers_available = False
 config_filename = cmd_opts.ui_settings_file
-hypernetworks = hypernetwork.load_hypernetworks(os.path.join(models_path, 'hypernetworks'))
-def selected_hypernetwork():
-    return hypernetworks.get(opts.sd_hypernetwork, None)
+hypernetworks = hypernetwork.list_hypernetworks(os.path.join(models_path, 'hypernetworks'))
+loaded_hypernetwork = None
 class State:
@@ -132,13 +132,14 @@ def realesrgan_models_names():
 class OptionInfo:
-    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
+    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False):
         self.default = default
         self.label = label
         self.component = component
         self.component_args = component_args
         self.onchange = onchange
         self.section = None
+        self.show_on_main_page = show_on_main_page
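The new show_on_main_page flag is only stored on the option; UI code elsewhere decides whether to surface such options in the quicksettings row. A hedged sketch of how a declaration picks it up, with OptionInfo trimmed to the fields shown above; existing call sites need no change because the parameter defaults to False:

class OptionInfo:
    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False):
        self.default = default
        self.label = label
        self.component = component
        self.component_args = component_args
        self.onchange = onchange
        self.section = None
        self.show_on_main_page = show_on_main_page

# Mirrors the sd_model_checkpoint declaration below.
checkpoint_opt = OptionInfo(None, "Stable Diffusion checkpoint", show_on_main_page=True)
print(checkpoint_opt.show_on_main_page)  # True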
 def options_section(section_identifier, options_dict):
@@ -215,7 +216,7 @@ options_templates.update(options_section(('system', "System"), {
 }))
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
-    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}),
+    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True),
     "sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}),
     "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
     "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
@@ -225,7 +226,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
-    'CLIP_ignore_last_layers': OptionInfo(0, "Ignore last layers of CLIP model", gr.Slider, {"minimum": 0, "maximum": 5, "step": 1}),
+    'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
     "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
 }))
@@ -240,10 +241,11 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
 options_templates.update(options_section(('ui', "User interface"), {
     "show_progressbar": OptionInfo(True, "Show progressbar"),
-    "show_progress_every_n_steps": OptionInfo(0, "Show show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
+    "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
     "return_grid": OptionInfo(True, "Show grid in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
+    "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
     "font": OptionInfo("", "Font for image grids that have text"),
     "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
     "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
...
@@ -8,7 +8,6 @@ from basicsr.utils.download_util import load_file_from_url
 from tqdm import tqdm
 from modules import modelloader
-from modules.paths import models_path
 from modules.shared import cmd_opts, opts, device
 from modules.swinir_model_arch import SwinIR as net
 from modules.upscaler import Upscaler, UpscalerData
@@ -25,7 +24,6 @@ class UpscalerSwinIR(Upscaler):
                      "/003_realSR_BSRGAN_DFOWMFC_s64w8_SwinIR" \
                      "-L_x4_GAN.pth "
         self.model_name = "SwinIR 4x"
-        self.model_path = os.path.join(models_path, self.name)
         self.user_path = dirname
         super().__init__()
         scalers = []
...
This diff is collapsed.
@@ -36,10 +36,11 @@ class Upscaler:
         self.half = not modules.shared.cmd_opts.no_half
         self.pre_pad = 0
         self.mod_scale = None
-        if self.name is not None and create_dirs:
-            self.model_path = os.path.join(models_path, self.name)
-            if not os.path.exists(self.model_path):
-                os.makedirs(self.model_path)
+        if self.model_path is None and self.name:
+            self.model_path = os.path.join(models_path, self.name)
+        if self.model_path and create_dirs:
+            os.makedirs(self.model_path, exist_ok=True)
         try:
             import cv2
...
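Besides moving the default-path logic into the base class, the rewrite swaps the exists-then-create dance for os.makedirs with exist_ok=True, which is shorter and free of the race between the existence check and the creation. A minimal standalone sketch of the difference (the path is illustrative):

import os
import tempfile

path = os.path.join(tempfile.gettempdir(), "models", "SwinIR")

# Old pattern: check, then create. Another process can create the
# directory between the two calls, and makedirs() would then raise.
if not os.path.exists(path):
    os.makedirs(path)

# New pattern: one idempotent call, safe to repeat.
os.makedirs(path, exist_ok=True)
os.makedirs(path, exist_ok=True)  # no error the second time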
@@ -10,7 +10,6 @@ from modules.processing import Processed, process_images
 from PIL import Image
 from modules.shared import opts, cmd_opts, state
 class Script(scripts.Script):
     def title(self):
         return "Prompts from file or textbox"
@@ -29,6 +28,9 @@ class Script(scripts.Script):
         checkbox_txt.change(fn=lambda x: [gr.File.update(visible = not x), gr.TextArea.update(visible = x)], inputs=[checkbox_txt], outputs=[file, prompt_txt])
         return [checkbox_txt, file, prompt_txt]
+    def on_show(self, checkbox_txt, file, prompt_txt):
+        return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
     def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
         if (checkbox_txt):
             lines = [x.strip() for x in prompt_txt.splitlines()]
...
@@ -10,8 +10,8 @@ import numpy as np
 import modules.scripts as scripts
 import gradio as gr
-from modules import images
-from modules.processing import process_images, Processed
+from modules import images, hypernetwork
+from modules.processing import process_images, Processed, get_correct_sampler
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
 import modules.sd_samplers
@@ -56,15 +56,17 @@ def apply_order(p, x, xs):
     p.prompt = prompt_tmp + p.prompt
-samplers_dict = {}
-for i, sampler in enumerate(modules.sd_samplers.samplers):
-    samplers_dict[sampler.name.lower()] = i
-    for alias in sampler.aliases:
-        samplers_dict[alias.lower()] = i
+def build_samplers_dict(p):
+    samplers_dict = {}
+    for i, sampler in enumerate(get_correct_sampler(p)):
+        samplers_dict[sampler.name.lower()] = i
+        for alias in sampler.aliases:
+            samplers_dict[alias.lower()] = i
+    return samplers_dict
 def apply_sampler(p, x, xs):
-    sampler_index = samplers_dict.get(x.lower(), None)
+    sampler_index = build_samplers_dict(p).get(x.lower(), None)
     if sampler_index is None:
         raise RuntimeError(f"Unknown sampler: {x}")
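Turning the module-level samplers_dict into build_samplers_dict(p) matters because txt2img and img2img expose different sampler lists, so the lookup table has to be derived from the job being processed rather than computed once at import time. A self-contained sketch of the alias-aware index; SamplerStub stands in for the real records from modules.sd_samplers:

from collections import namedtuple

SamplerStub = namedtuple("SamplerStub", ["name", "aliases"])

def build_samplers_dict(samplers):
    # Map every canonical name and every alias (lowercased) to the sampler's index.
    samplers_dict = {}
    for i, sampler in enumerate(samplers):
        samplers_dict[sampler.name.lower()] = i
        for alias in sampler.aliases:
            samplers_dict[alias.lower()] = i
    return samplers_dict

samplers = [SamplerStub("Euler a", ["k_euler_a"]), SamplerStub("DDIM", [])]
lookup = build_samplers_dict(samplers)
print(lookup["euler a"], lookup["k_euler_a"], lookup["ddim"])  # 0 0 1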
@@ -78,8 +80,11 @@ def apply_checkpoint(p, x, xs):
 def apply_hypernetwork(p, x, xs):
-    hn = shared.hypernetworks.get(x, None)
-    opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None'
+    hypernetwork.load_hypernetwork(x)
+def apply_clip_skip(p, x, xs):
+    opts.data["CLIP_stop_at_last_layers"] = x
 def format_value_add_label(p, opt, x):
@@ -133,6 +138,7 @@ axis_options = [
     AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
     AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
     AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
+    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
     AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),  # as it is now all AxisOptionImg2Img items must go after AxisOption ones
 ]
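Each AxisOption pairs a UI label with an apply callback that mutates either the processing object or the global options before a cell is rendered; the new Clip skip axis is just another instance of that pattern. A hedged sketch of the dispatch, with AxisOption reduced to the fields used here and opts_data standing in for opts.data:

from collections import namedtuple

AxisOption = namedtuple("AxisOption", ["label", "type", "apply"])

opts_data = {}  # stand-in for opts.data

def apply_clip_skip(p, x, xs):
    opts_data["CLIP_stop_at_last_layers"] = x

axis_options = [AxisOption("Clip skip", int, apply_clip_skip)]

# For each grid cell, the runner coerces the raw string and applies it first.
opt, raw_value = axis_options[0], "2"
opt.apply(None, opt.type(raw_value), None)
print(opts_data)  # {'CLIP_stop_at_last_layers': 2}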
@@ -143,7 +149,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
     ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
     hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
-    first_pocessed = None
+    first_processed = None
     state.job_count = len(xs) * len(ys) * p.n_iter
@@ -152,8 +158,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
             state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
             processed = cell(x, y)
-            if first_pocessed is None:
-                first_pocessed = processed
+            if first_processed is None:
+                first_processed = processed
             try:
                 res.append(processed.images[0])
@@ -164,9 +170,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
     if draw_legend:
         grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
-    first_pocessed.images = [grid]
-    return first_pocessed
+    first_processed.images = [grid]
+    return first_processed
 re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
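re_range is what lets an axis value like "1-9 (+2)" expand into a sequence: group 1 is the start, group 2 the inclusive end, and the optional parenthesised group 3 the step. A quick standalone check of how it parses:

import re

re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")

m = re_range.fullmatch("1-9 (+2)")
start = int(m.group(1))
end = int(m.group(2)) + 1                          # make the range inclusive
step = int(m.group(3)) if m.group(3) is not None else 1
print(list(range(start, end, step)))               # [1, 3, 5, 7, 9]

print(re_range.fullmatch("5-1 (-1)").groups())     # ('5', '1', '-1'): descending ranges parse too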
@@ -196,10 +202,11 @@ class Script(scripts.Script):
         return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds]
     def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds):
-        modules.processing.fix_seed(p)
+        if not no_fixed_seeds:
+            modules.processing.fix_seed(p)
         p.batch_size = 1
-        initial_hn = opts.sd_hypernetwork
+        CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
         def process_axis(opt, vals):
             if opt.label == 'Nothing':
@@ -214,7 +221,6 @@ class Script(scripts.Script):
                 m = re_range.fullmatch(val)
                 mc = re_range_count.fullmatch(val)
                 if m is not None:
                     start = int(m.group(1))
                     end = int(m.group(2))+1
                     step = int(m.group(3)) if m.group(3) is not None else 1
@@ -256,6 +262,17 @@ class Script(scripts.Script):
                 valslist = list(permutations(valslist))
             valslist = [opt.type(x) for x in valslist]
+            # Confirm options are valid before starting
+            if opt.label == "Sampler":
+                samplers_dict = build_samplers_dict(p)
+                for sampler_val in valslist:
+                    if sampler_val.lower() not in samplers_dict.keys():
+                        raise RuntimeError(f"Unknown sampler: {sampler_val}")
+            elif opt.label == "Checkpoint name":
+                for ckpt_val in valslist:
+                    if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
+                        raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
             return valslist
@@ -308,6 +325,8 @@ class Script(scripts.Script):
         # restore checkpoint in case it was changed by axes
         modules.sd_models.reload_model_weights(shared.sd_model)
-        opts.data["sd_hypernetwork"] = initial_hn
+        hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
+        opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
         return processed
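The run method follows a save-mutate-restore pattern: it snapshots CLIP_stop_at_last_layers up front, lets the axis callbacks overwrite it per cell, and writes the snapshot back once the grid is done. A hedged sketch of the shape, with a try/finally added so the restore also happens if a cell raises (the original restores only on the success path; opts_data stands in for opts.data):

opts_data = {"CLIP_stop_at_last_layers": 1}  # stand-in for opts.data

def run_grid(cells):
    saved_clip_skip = opts_data["CLIP_stop_at_last_layers"]
    try:
        for value in cells:
            opts_data["CLIP_stop_at_last_layers"] = value  # what apply_clip_skip does
            # ... render the cell here ...
    finally:
        opts_data["CLIP_stop_at_last_layers"] = saved_clip_skip

run_grid([2, 3, 4])
print(opts_data["CLIP_stop_at_last_layers"])  # 1: restored after the grid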
@@ -103,7 +103,12 @@
 #style_apply, #style_create, #interrogate{
     margin: 0.75em 0.25em 0.25em 0.25em;
-    min-width: 3em;
+    min-width: 5em;
+}
+
+#style_apply, #style_create, #deepbooru{
+    margin: 0.75em 0.25em 0.25em 0.25em;
+    min-width: 5em;
 }
 #style_pos_col, #style_neg_col{
@@ -448,3 +453,13 @@ input[type="range"]{
 .context-menu-items a:hover{
     background: #a55000;
 }
+
+#quicksettings > div{
+    border: none;
+    background: none;
+}
+
+#quicksettings > div > div{
+    max-width: 32em;
+    padding: 0;
+}
txt2img_Screenshot.png: binary image replaced (526 KB → 329 KB).
@@ -82,6 +82,9 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
 shared.sd_model = modules.sd_models.load_model()
 shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
+loaded_hypernetwork = modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)
+shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
 def webui():
     # make the program just exit at ctrl+c without waiting for anything
...
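shared.opts.onchange registers a callback that fires whenever the named setting is written, so switching the hypernetwork in the UI reloads it immediately instead of waiting for a restart. A minimal sketch of that observer pattern; this Opts class is a stand-in, not the real modules.shared implementation:

class Opts:
    """Stand-in settings store: set() fires the handler registered for a key."""
    def __init__(self):
        self.data = {}
        self.handlers = {}

    def onchange(self, key, func):
        self.handlers[key] = func

    def set(self, key, value):
        self.data[key] = value
        if key in self.handlers:
            self.handlers[key]()

opts = Opts()
opts.onchange("sd_hypernetwork", lambda: print("reloading hypernetwork:", opts.data["sd_hypernetwork"]))
opts.set("sd_hypernetwork", "anime_3")  # prints: reloading hypernetwork: anime_3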