Commit 6b719c49 authored by evshiron's avatar evshiron

Merge branch 'master' into feat/progress-api

parents fddb4883 35c45df2
...@@ -29,4 +29,3 @@ notification.mp3
/textual_inversion
.vscode
/extensions
...@@ -75,6 +75,7 @@ titles = {
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style will use that as a placeholder for your prompt when you use the style in the future.",
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).", "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
......
...@@ -13,6 +13,15 @@ function showModal(event) {
}
lb.style.display = "block";
lb.focus()
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
const tabImg2Img = gradioApp().getElementById("tab_img2img")
// show the save button in modal only on txt2img or img2img tabs
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
gradioApp().getElementById("modal_save").style.display = "inline"
} else {
gradioApp().getElementById("modal_save").style.display = "none"
}
event.stopPropagation()
}
...@@ -81,6 +90,25 @@ function modalImageSwitch(offset) {
}
}
function saveImage(){
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
const tabImg2Img = gradioApp().getElementById("tab_img2img")
const saveTxt2Img = "save_txt2img"
const saveImg2Img = "save_img2img"
if (tabTxt2Img.style.display != "none") {
gradioApp().getElementById(saveTxt2Img).click()
} else if (tabImg2Img.style.display != "none") {
gradioApp().getElementById(saveImg2Img).click()
} else {
console.error("missing implementation for saving modal of this type")
}
}
function modalSaveImage(event) {
saveImage()
event.stopPropagation()
}
function modalNextImage(event) {
modalImageSwitch(1)
event.stopPropagation()
...@@ -93,6 +121,9 @@ function modalPrevImage(event) {
function modalKeyHandler(event) {
switch (event.key) {
case "s":
saveImage()
break;
case "ArrowLeft": case "ArrowLeft":
modalPrevImage(event) modalPrevImage(event)
break; break;
...@@ -198,6 +229,14 @@ document.addEventListener("DOMContentLoaded", function() { ...@@ -198,6 +229,14 @@ document.addEventListener("DOMContentLoaded", function() {
modalTileImage.title = "Preview tiling"; modalTileImage.title = "Preview tiling";
modalControls.appendChild(modalTileImage) modalControls.appendChild(modalTileImage)
const modalSave = document.createElement("span")
modalSave.className = "modalSave cursor"
modalSave.id = "modal_save"
modalSave.innerHTML = "🖫"
modalSave.addEventListener("click", modalSaveImage, true)
modalSave.title = "Save Image(s)"
modalControls.appendChild(modalSave)
const modalClose = document.createElement('span')
modalClose.className = 'modalClose cursor';
modalClose.innerHTML = '×'
......
...@@ -45,14 +45,14 @@ function switch_to_txt2img(){
return args_to_array(arguments);
}
function switch_to_img2img_img2img(){
function switch_to_img2img(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
return args_to_array(arguments);
}
function switch_to_img2img_inpaint(){
function switch_to_inpaint(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
...@@ -65,26 +65,6 @@ function switch_to_extras(){
return args_to_array(arguments);
}
function extract_image_from_gallery_txt2img(gallery){
switch_to_txt2img()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_img2img(gallery){
switch_to_img2img_img2img()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_inpaint(gallery){
switch_to_img2img_inpaint()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_extras(gallery){
switch_to_extras()
return extract_image_from_gallery(gallery);
}
function get_tab_index(tabId){
var res = 0
......
import time # import time
# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
# from modules.sd_samplers import all_samplers
# from modules.extras import run_pnginfo
# import modules.shared as shared
# from modules import devices
# import uvicorn
# from fastapi import Body, APIRouter, HTTPException
# from fastapi.responses import JSONResponse
# from pydantic import BaseModel, Field, Json
# from typing import List
# import json
# import io
# import base64
# from PIL import Image
# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
# class TextToImageResponse(BaseModel):
# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
# parameters: Json
# info: Json
# class ImageToImageResponse(BaseModel):
# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
# parameters: Json
# info: Json
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
import time
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
import uvicorn
from modules.sd_samplers import all_samplers
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from modules.extras import run_pnginfo
from fastapi import APIRouter, HTTPException
import modules.shared as shared
from modules import devices
import uvicorn
from modules.api.models import *
from fastapi import Body, APIRouter, HTTPException
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from fastapi.responses import JSONResponse
from modules.sd_samplers import all_samplers
from pydantic import BaseModel, Field, Json
from modules.extras import run_extras
from typing import List
import json
import io
import base64
from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ProgressResponse(BaseModel):
progress: float
eta_relative: float
state: Json
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
...@@ -53,30 +59,39 @@ def before_gpu_call():
shared.state.textinfo = None
shared.state.time_start = time.time()
def after_gpu_call():
shared.state.job = ""
shared.state.job_count = 0
devices.torch_gc()
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except Exception:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be one of these: {' , '.join([x.name for x in shared.sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
return reqDict
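For orientation, a small sketch of what the two helpers above do to an incoming extras request; the upscaler names, their order in shared.sd_upscalers, and the request value are assumptions for illustration only:

# Suppose shared.sd_upscalers exposes upscalers named ["None", "Lanczos", "Nearest", ...].
# upscaler_to_index("lanczos") would then return 1 (the lookup is case-insensitive),
# and an unknown name raises HTTPException(400) listing the valid choices.
req = ExtrasSingleImageRequest(image="<base64 data>", upscaler_1="Lanczos", upscaler_2="None")
reqDict = setUpscalers(req)
# reqDict now carries the integer indices under the keyword names run_extras expects,
# e.g. {'extras_upscaler_1': 1, 'extras_upscaler_2': 0, ...}, with upscaler_1/upscaler_2 removed.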
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"]) self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
def __base64_to_image(self, base64_string):
# if the string has a comma, strip the data URI prefix
if "," in base64_string:
base64_string = base64_string.split(",")[1]
imgdata = base64.b64decode(base64_string)
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
...@@ -97,15 +112,9 @@ class Api:
processed = process_images(p)
after_gpu_call()
b64images = []
b64images = list(map(encode_pil_to_base64, processed.images))
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
...@@ -120,7 +129,7 @@ class Api:
mask = img2imgreq.mask
if mask:
mask = self.__base64_to_image(mask)
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
...@@ -135,7 +144,7 @@ class Api:
imgs = []
for img in init_images:
img = self.__base64_to_image(img)
img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
...@@ -145,17 +154,39 @@ class Api:
processed = process_images(p)
after_gpu_call()
b64images = []
b64images = list(map(encode_pil_to_base64, processed.images))
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
def prepareFiles(file):
file = decode_base64_to_file(file.data, file_path=file.name)
file.orig_name = file.name
return file
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')
with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
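As a usage illustration for the new extras routes (not part of this diff), a minimal client sketch using the requests library; the host, port and file names are assumptions for a local webui started with --api:

import base64
import requests

URL = "http://127.0.0.1:7860"  # assumed local instance

with open("input.png", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "image": img_b64,           # base64 source image, as ExtrasSingleImageRequest expects
    "upscaling_resize": 2,      # resize_mode defaults to 0: scale by this factor
    "upscaler_1": "Lanczos",    # must match a name from shared.sd_upscalers
}

resp = requests.post(f"{URL}/sdapi/v1/extra-single-image", json=payload)
resp.raise_for_status()

# The reply mirrors ExtrasSingleImageResponse; the image may come back as a
# data URL, so strip any "data:image/png;base64," prefix before decoding.
img_data = resp.json()["image"].split(",", 1)[-1]
with open("upscaled.png", "wb") as f:
    f.write(base64.b64decode(img_data))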
def progressapi(self):
# copy from check_progress_call of ui.py
...@@ -179,9 +210,6 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
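Since the progress route is the point of this branch, a short polling sketch may help; the host/port, polling interval and termination condition are assumptions, not taken from the diff:

import time
import requests

URL = "http://127.0.0.1:7860"  # assumed local webui started with --api

def wait_for_current_job(poll_interval=1.0):
    # Poll /sdapi/v1/progress until the current generation finishes.
    while True:
        r = requests.get(f"{URL}/sdapi/v1/progress").json()
        # 'progress' is a 0..1 float, 'eta_relative' an ETA in seconds,
        # 'state' mirrors shared.state.js() (job name, step counters, flags).
        print(f"{r['progress'] * 100:5.1f}%  eta {r['eta_relative']:.1f}s  job {r['state'].get('job', '')!r}")
        if not r["state"].get("job_count", 0):
            break
        time.sleep(poll_interval)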
def extrasapi(self):
raise NotImplementedError
def pnginfoapi(self):
raise NotImplementedError
......
from array import array
import inspect
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect
from modules.shared import sd_upscalers
API_NOT_ALLOWED = [
"self",
...@@ -106,3 +106,51 @@ StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ExtrasBaseRequest(BaseModel):
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the chosen size?")
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of these: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
class ExtrasSingleImageRequest(ExtrasBaseRequest):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
class ExtrasSingleImageResponse(ExtraBaseResponse):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
class FileData(BaseModel):
data: str = Field(title="File data", description="Base64 representation of the file")
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class ProgressResponse(BaseModel):
progress: float
eta_relative: float
state: dict
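A brief sketch of how these Pydantic request models behave, assuming the webui modules are importable in the current environment; the field values are made up:

from modules.api.models import ExtrasSingleImageRequest

req = ExtrasSingleImageRequest(
    image="<base64 image data>",
    resize_mode=0,           # 0: scale by upscaling_resize, 1: fit to upscaling_resize_w/h
    upscaling_resize=2,
    upscaler_1="None",
)

# Unset fields fall back to their declared defaults, and constrained fields are
# validated: for example gfpgan_visibility=2 would fail the ge=0/le=1 bounds.
print(req.dict()["upscaling_crop"])  # True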
import base64
import io
import os
import re
import gradio as gr
from modules.shared import script_path
from modules import shared
import tempfile
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
paste_fields = {}
bind_list = []
def quote(text):
...@@ -20,6 +26,110 @@ def quote(text):
text = text.replace('"', '\\"')
return f'"{text}"'
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
tempdir = os.path.normpath(tempfile.gettempdir())
normfn = os.path.normpath(filename)
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
return Image.open(filename)
if type(filedata) == list:
if len(filedata) == 0:
return None
filedata = filedata[0]
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
filedata = base64.decodebytes(filedata.encode('utf-8'))
image = Image.open(io.BytesIO(filedata))
return image
def add_paste_fields(tabname, init_img, fields):
paste_fields[tabname] = {"init_img": init_img, "fields": fields}
# backwards compatibility for existing extensions
import modules.ui
if tabname == 'txt2img':
modules.ui.txt2img_paste_fields = fields
elif tabname == 'img2img':
modules.ui.img2img_paste_fields = fields
def integrate_settings_paste_fields(component_dict):
from modules import ui
settings_map = {
'sd_hypernetwork': 'Hypernet',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
for k, v in settings_map.items()
]
for tabname, info in paste_fields.items():
if info["fields"] is not None:
info["fields"] += settings_paste_fields
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}")
return buttons
# if send_generate_info is a tab name, it means generate_info comes from the params fields of that tab
def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
def run_bind():
for buttons, send_image, send_generate_info in bind_list:
for tab in buttons:
button = buttons[tab]
if send_image and paste_fields[tab]["init_img"]:
if type(send_image) == gr.Gallery:
button.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery",
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
else:
button.click(
fn=lambda x: x,
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
if send_generate_info and paste_fields[tab]["fields"] is not None:
if send_generate_info in paste_fields:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
)
else:
connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
button.click(
fn=None,
_js=f"switch_to_{tab}",
inputs=None,
outputs=None,
)
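To show how the new registration helpers above are meant to fit together, here is a hypothetical wiring sketch for one tab; the component names and tab list are illustrative, and every tab passed to create_buttons is assumed to have registered its own paste fields:

import gradio as gr
from modules import generation_parameters_copypaste as parameters_copypaste

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    result_gallery = gr.Gallery(label="Output")

    # Register this tab's paste targets: an init-image slot (none for txt2img)
    # and the (component, parameter name) pairs filled by parse_generation_parameters.
    parameters_copypaste.add_paste_fields("txt2img", None, [(prompt, "Prompt")])

    # Create "Send to ..." buttons for other tabs and queue them for binding.
    buttons = parameters_copypaste.create_buttons(["img2img", "inpaint"])
    parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img")

# Called once, after all tabs have registered their fields.
parameters_copypaste.run_bind()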
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
...@@ -68,7 +178,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
def connect_paste(button, paste_fields, input_comp, js=None):
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
...@@ -106,7 +216,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
button.click(
fn=paste_func,
_js=js,
_js=jsfunc,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
...@@ -25,6 +25,7 @@ from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU, "relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU, "leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU, "elu": torch.nn.ELU,
...@@ -428,6 +429,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -428,6 +429,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step() optimizer.step()
steps_done = hypernetwork.step + 1
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
...@@ -438,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
...@@ -449,8 +452,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
"learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
......
...@@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
not_available = ["hardswish", "multiheadattention"]
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
......
...@@ -300,8 +300,8 @@ class FilenameGenerator:
'seed': lambda self: self.seed if self.seed is not None else '',
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.p and self.p.width,
'width': lambda self: self.image.width,
'height': lambda self: self.p and self.p.height,
'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
...@@ -315,10 +315,11 @@
}
default_time_format = '%Y%m%d%H%M%S'
def __init__(self, p, seed, prompt):
def __init__(self, p, seed, prompt, image):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
def prompt_no_style(self):
if self.p is None or self.prompt is None:
...@@ -449,7 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
"""
namegen = FilenameGenerator(p, seed, prompt)
namegen = FilenameGenerator(p, seed, prompt, image)
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
......
...@@ -19,7 +19,7 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, args):
processing.fix_seed(p)
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
images = shared.listfiles(input_dir)
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
......
...@@ -129,6 +129,73 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
def txt2img_image_conditioning(self, x, width=None, height=None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
return torch.zeros(
x.shape[0], 5, 1, 1,
dtype=x.dtype,
device=x.device
)
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
return torch.zeros(
latent_image.shape[0], 5, 1, 1,
dtype=latent_image.dtype,
device=latent_image.device
)
# Handle the different mask inputs
if image_mask is not None:
if torch.is_tensor(image_mask):
conditioning_mask = image_mask
else:
conditioning_mask = np.array(image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
conditioning_mask = conditioning_mask.to(source_image.device)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)
# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
return image_conditioning
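A tiny numeric sketch (values invented) of the torch.lerp call above, to make the effect of the new inpainting_mask_weight setting concrete:

import torch

source = torch.tensor([1.0, 1.0])   # unmasked source pixels
mask = torch.tensor([0.0, 1.0])     # the second pixel is masked
masked = source * (1.0 - mask)      # -> [1.0, 0.0]

# weight 1.0 reproduces the old fully-masked conditioning,
# weight 0.0 conditions on the unmasked image instead.
for weight in (1.0, 0.5, 0.0):
    print(weight, torch.lerp(source, masked, weight))
# 1.0 -> [1., 0.]   0.5 -> [1., 0.5]   0.0 -> [1., 1.]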
def init(self, all_prompts, all_seeds, all_subseeds):
pass
...@@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
...@@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
image_conditioning = self.img2img_image_conditioning(
decoded_samples,
samples,
decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
...@@ -770,33 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
self.init_latent.shape[0], 5, 1, 1,
dtype=self.init_latent.dtype,
device=self.init_latent.device
)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
......
...@@ -236,7 +236,7 @@ class ScriptRunner:
with gr.Group():
create_script_ui(script, inputs, inputs_alwayson)
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
......
...@@ -3,6 +3,7 @@ import os.path
import sys
from collections import namedtuple
import torch
import re
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
...@@ -36,7 +37,9 @@ def setup_model():
def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()])
convert = lambda name: int(name) if name.isdigit() else name.lower()
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
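The alphanumeric key above gives checkpoint titles a natural sort order; a standalone sketch of the same idea with made-up file names:

import re

convert = lambda name: int(name) if name.isdigit() else name.lower()
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]

titles = ["model-10.ckpt", "model-2.ckpt", "Model-1.ckpt"]
print(sorted(titles))                        # plain sort: model-10 sorts before model-2
print(sorted(titles, key=alphanumeric_key))  # natural: Model-1, model-2, model-10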
def list_models():
...@@ -170,7 +173,9 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
missing, extra = model.load_state_dict(sd, strict=False)
del pl_sd
model.load_state_dict(sd, strict=False)
del sd
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
...@@ -194,6 +199,7 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.to(devices.dtype_vae)
if shared.opts.sd_checkpoint_cache > 0:
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
......
...@@ -82,6 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
cmd_opts = parser.parse_args()
restricted_opts = {
...@@ -280,6 +281,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
...@@ -316,6 +318,7 @@ options_templates.update(options_section(('ui', "User interface"), { ...@@ -316,6 +318,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"), "add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), "disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
...@@ -462,3 +465,8 @@ total_tqdm = TotalTQDM() ...@@ -462,3 +465,8 @@ total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start() mem_mon.start()
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
...@@ -86,12 +86,12 @@ class PersonalizedBase(Dataset):
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text):
text = random.choice(self.lines)
......
...@@ -52,7 +52,7 @@ class LearnRateScheduler:
self.finished = False
def apply(self, optimizer, step_number):
if step_number <= self.end_step:
if step_number < self.end_step:
return
try:
......
...@@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if step % shared.opts.training_write_csv_every != 0:
if (step + 1) % shared.opts.training_write_csv_every != 0:
return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
...@@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = step // epoch_len
epoch_step = step - epoch * epoch_len
epoch_step = step % epoch_len
csv_writer.writerow({
"step": step + 1,
"epoch": epoch + 1,
"epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
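The reworked step/epoch bookkeeping above is easier to see with concrete numbers; epoch_len and the step values below are arbitrary:

epoch_len = 100

for step in (0, 99, 100, 250):
    epoch = step // epoch_len
    epoch_step = step % epoch_len   # equivalent to the old step - epoch * epoch_len
    # The CSV row stays 1-based for step and epoch_step; epoch is now written as computed.
    print({"step": step + 1, "epoch": epoch, "epoch_step": epoch_step + 1})

# step=99  -> {'step': 100, 'epoch': 0, 'epoch_step': 100}
# step=100 -> {'step': 101, 'epoch': 1, 'epoch_step': 1}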
...@@ -282,15 +281,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
epoch_step = embedding.step % len(ds)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding.name = f'{embedding_name}-{embedding.step}'
embedding.name = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
...@@ -300,8 +300,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
"learn_rate": scheduler.learn_rate
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{embedding.step}'
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
...@@ -334,7 +334,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
...@@ -350,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, embedding.step)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
...@@ -380,7 +380,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.cached_checksum = None
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
embedding.name = embedding_name
filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename
...@@ -153,7 +153,6 @@ def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
...@@ -178,6 +177,7 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
]
......
...@@ -314,8 +314,8 @@ input[type="range"]{
.modalControls {
display: grid;
grid-template-columns: 32px auto 1fr 32px;
grid-template-columns: 32px 32px 32px 1fr 32px;
grid-template-areas: "zoom tile space close";
grid-template-areas: "zoom tile save space close";
position: absolute;
top: 0;
left: 0;
...@@ -333,6 +333,10 @@ input[type="range"]{
grid-area: zoom;
}
.modalSave {
grid-area: save;
}
.modalTileImage {
grid-area: tile;
}
...@@ -346,8 +350,18 @@ input[type="range"]{
cursor: pointer;
}
.modalSave {
color: white;
font-size: 28px;
margin-top: 8px;
font-weight: bold;
cursor: pointer;
}
.modalClose:hover,
.modalClose:focus,
.modalSave:hover,
.modalSave:focus,
.modalZoom:hover,
.modalZoom:focus {
color: #999;
...@@ -522,18 +536,23 @@ If you change anything above, you need to make sure it is RTL compliant by just
your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
@media rtl {
/* this part was manualy added */
/* this part was added manually */
:host {
direction: rtl;
}
.output-html:has(.performance), .gr-text-input {
select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
direction: ltr;
}
#script_list > label > select,
#x_type > label > select,
#y_type > label > select {
direction: rtl;
}
.gr-radio, .gr-checkbox{
margin-left: 0.25em;
}
/* this part was automatically generated with few manual modifications */
/* automatically generated with few manual modifications */
.performance .time {
margin-right: unset;
margin-left: 0;
......