Commit 6b719c49 authored by evshiron's avatar evshiron

Merge branch 'master' into feat/progress-api

parents fddb4883 35c45df2
......@@ -29,4 +29,3 @@ notification.mp3
/textual_inversion
.vscode
/extensions
......@@ -75,6 +75,7 @@ titles = {
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
"Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
"Inpainting conditioning mask strength": "Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.",
"vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).",
......
......@@ -13,6 +13,15 @@ function showModal(event) {
}
lb.style.display = "block";
lb.focus()
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
const tabImg2Img = gradioApp().getElementById("tab_img2img")
// show the save button in modal only on txt2img or img2img tabs
if (tabTxt2Img.style.display != "none" || tabImg2Img.style.display != "none") {
gradioApp().getElementById("modal_save").style.display = "inline"
} else {
gradioApp().getElementById("modal_save").style.display = "none"
}
event.stopPropagation()
}
......@@ -81,6 +90,25 @@ function modalImageSwitch(offset) {
}
}
function saveImage(){
const tabTxt2Img = gradioApp().getElementById("tab_txt2img")
const tabImg2Img = gradioApp().getElementById("tab_img2img")
const saveTxt2Img = "save_txt2img"
const saveImg2Img = "save_img2img"
if (tabTxt2Img.style.display != "none") {
gradioApp().getElementById(saveTxt2Img).click()
} else if (tabImg2Img.style.display != "none") {
gradioApp().getElementById(saveImg2Img).click()
} else {
console.error("missing implementation for saving modal of this type")
}
}
function modalSaveImage(event) {
saveImage()
event.stopPropagation()
}
function modalNextImage(event) {
modalImageSwitch(1)
event.stopPropagation()
......@@ -93,6 +121,9 @@ function modalPrevImage(event) {
function modalKeyHandler(event) {
switch (event.key) {
case "s":
saveImage()
break;
case "ArrowLeft":
modalPrevImage(event)
break;
......@@ -198,6 +229,14 @@ document.addEventListener("DOMContentLoaded", function() {
modalTileImage.title = "Preview tiling";
modalControls.appendChild(modalTileImage)
const modalSave = document.createElement("span")
modalSave.className = "modalSave cursor"
modalSave.id = "modal_save"
modalSave.innerHTML = "🖫"
modalSave.addEventListener("click", modalSaveImage, true)
modalSave.title = "Save Image(s)"
modalControls.appendChild(modalSave)
const modalClose = document.createElement('span')
modalClose.className = 'modalClose cursor';
modalClose.innerHTML = '×'
......
......@@ -45,14 +45,14 @@ function switch_to_txt2img(){
return args_to_array(arguments);
}
function switch_to_img2img_img2img(){
function switch_to_img2img(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
return args_to_array(arguments);
}
function switch_to_img2img_inpaint(){
function switch_to_inpaint(){
gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
......@@ -65,26 +65,6 @@ function switch_to_extras(){
return args_to_array(arguments);
}
function extract_image_from_gallery_txt2img(gallery){
switch_to_txt2img()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_img2img(gallery){
switch_to_img2img_img2img()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_inpaint(gallery){
switch_to_img2img_inpaint()
return extract_image_from_gallery(gallery);
}
function extract_image_from_gallery_extras(gallery){
switch_to_extras()
return extract_image_from_gallery(gallery);
}
function get_tab_index(tabId){
var res = 0
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import time
# import time
# from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
# from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
# from modules.sd_samplers import all_samplers
# from modules.extras import run_pnginfo
# import modules.shared as shared
# from modules import devices
# import uvicorn
# from fastapi import Body, APIRouter, HTTPException
# from fastapi.responses import JSONResponse
# from pydantic import BaseModel, Field, Json
# from typing import List
# import json
# import io
# import base64
# from PIL import Image
# sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
# class TextToImageResponse(BaseModel):
# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
# parameters: Json
# info: Json
# class ImageToImageResponse(BaseModel):
# images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
# parameters: Json
# info: Json
from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, StableDiffusionImg2ImgProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException
import modules.shared as shared
from modules import devices
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
from typing import List
import json
import io
import base64
from PIL import Image
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ImageToImageResponse(BaseModel):
images: List[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class ProgressResponse(BaseModel):
progress: float
eta_relative: float
state: Json
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras
# copy from wrap_gradio_gpu_call of webui.py
# because queue lock will be acquired in api handlers
......@@ -53,30 +59,39 @@ def before_gpu_call():
shared.state.textinfo = None
shared.state.time_start = time.time()
def after_gpu_call():
shared.state.job = ""
shared.state.job_count = 0
devices.torch_gc()
def upscaler_to_index(name: str):
try:
return [x.name.lower() for x in shared.sd_upscalers].index(name.lower())
except:
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}")
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
reqDict['extras_upscaler_2'] = upscaler_to_index(req.upscaler_2)
reqDict.pop('upscaler_1')
reqDict.pop('upscaler_2')
return reqDict
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"])
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"], response_model=TextToImageResponse)
self.app.add_api_route("/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)
self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)
self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)
self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"])
def __base64_to_image(self, base64_string):
# if has a comma, deal with prefix
if "," in base64_string:
base64_string = base64_string.split(",")[1]
imgdata = base64.b64decode(base64_string)
# convert base64 to PIL image
return Image.open(io.BytesIO(imgdata))
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
......@@ -97,15 +112,9 @@ class Api:
processed = process_images(p)
after_gpu_call()
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=processed.js())
b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
......@@ -120,7 +129,7 @@ class Api:
mask = img2imgreq.mask
if mask:
mask = self.__base64_to_image(mask)
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
......@@ -135,7 +144,7 @@ class Api:
imgs = []
for img in init_images:
img = self.__base64_to_image(img)
img = decode_base64_to_image(img)
imgs = [img] * p.batch_size
p.init_images = imgs
......@@ -145,17 +154,39 @@ class Api:
processed = process_images(p)
after_gpu_call()
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=json.dumps(vars(img2imgreq)), info=processed.js())
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
reqDict = setUpscalers(req)
reqDict['image'] = decode_base64_to_image(reqDict['image'])
with self.queue_lock:
result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", **reqDict)
return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])
def extras_batch_images_api(self, req: ExtrasBatchImagesRequest):
reqDict = setUpscalers(req)
def prepareFiles(file):
file = decode_base64_to_file(file.data, file_path=file.name)
file.orig_name = file.name
return file
reqDict['image_folder'] = list(map(prepareFiles, reqDict['imageList']))
reqDict.pop('imageList')
with self.queue_lock:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def progressapi(self):
# copy from check_progress_call of ui.py
......@@ -179,9 +210,6 @@ class Api:
return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js())
def extrasapi(self):
raise NotImplementedError
def pnginfoapi(self):
raise NotImplementedError
......
from array import array
from inflection import underscore
from typing import Any, Dict, Optional
import inspect
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
from inflection import underscore
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
import inspect
from modules.shared import sd_upscalers
API_NOT_ALLOWED = [
"self",
......@@ -51,17 +51,17 @@ class PydanticModelGenerator:
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation
return Optional[field_type]
def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {}
for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
self._model_def = [
......@@ -73,11 +73,11 @@ class PydanticModelGenerator:
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
for fields in additional_fields:
self._model_def.append(ModelDef(
field=underscore(fields["key"]),
field_alias=fields["key"],
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"],
field_exclude=fields["exclude"] if "exclude" in fields else False))
......@@ -94,15 +94,63 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
return DynamicModel
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingImg2Img",
"StableDiffusionProcessingImg2Img",
StableDiffusionProcessingImg2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
).generate_model()
\ No newline at end of file
).generate_model()
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ImageToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: dict
info: str
class ExtrasBaseRequest(BaseModel):
resize_mode: Literal[0, 1] = Field(default=0, title="Resize Mode", description="Sets the resize mode: 0 to upscale by upscaling_resize amount, 1 to upscale up to upscaling_resize_h x upscaling_resize_w.")
show_extras_results: bool = Field(default=True, title="Show results", description="Should the backend return the generated image?")
gfpgan_visibility: float = Field(default=0, title="GFPGAN Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of GFPGAN, values should be between 0 and 1.")
codeformer_visibility: float = Field(default=0, title="CodeFormer Visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of CodeFormer, values should be between 0 and 1.")
codeformer_weight: float = Field(default=0, title="CodeFormer Weight", ge=0, le=1, allow_inf_nan=False, description="Sets the weight of CodeFormer, values should be between 0 and 1.")
upscaling_resize: float = Field(default=2, title="Upscaling Factor", ge=1, le=4, description="By how much to upscale the image, only used when resize_mode=0.")
upscaling_resize_w: int = Field(default=512, title="Target Width", ge=1, description="Target width for the upscaler to hit. Only used when resize_mode=1.")
upscaling_resize_h: int = Field(default=512, title="Target Height", ge=1, description="Target height for the upscaler to hit. Only used when resize_mode=1.")
upscaling_crop: bool = Field(default=True, title="Crop to fit", description="Should the upscaler crop the image to fit in the choosen size?")
upscaler_1: str = Field(default="None", title="Main upscaler", description=f"The name of the main upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
upscaler_2: str = Field(default="None", title="Secondary upscaler", description=f"The name of the secondary upscaler to use, it has to be one of this list: {' , '.join([x.name for x in sd_upscalers])}")
extras_upscaler_2_visibility: float = Field(default=0, title="Secondary upscaler visibility", ge=0, le=1, allow_inf_nan=False, description="Sets the visibility of secondary upscaler, values should be between 0 and 1.")
class ExtraBaseResponse(BaseModel):
html_info: str = Field(title="HTML info", description="A series of HTML tags containing the process info.")
class ExtrasSingleImageRequest(ExtrasBaseRequest):
image: str = Field(default="", title="Image", description="Image to work on, must be a Base64 string containing the image's data.")
class ExtrasSingleImageResponse(ExtraBaseResponse):
image: str = Field(default=None, title="Image", description="The generated image in base64 format.")
class FileData(BaseModel):
data: str = Field(title="File data", description="Base64 representation of the file")
name: str = Field(title="File name")
class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class ProgressResponse(BaseModel):
progress: float
eta_relative: float
state: dict
This diff is collapsed.
import base64
import io
import os
import re
import gradio as gr
from modules.shared import script_path
from modules import shared
import tempfile
from PIL import Image
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())
paste_fields = {}
bind_list = []
def quote(text):
......@@ -20,6 +26,110 @@ def quote(text):
text = text.replace('"', '\\"')
return f'"{text}"'
def image_from_url_text(filedata):
if type(filedata) == dict and filedata["is_file"]:
filename = filedata["name"]
tempdir = os.path.normpath(tempfile.gettempdir())
normfn = os.path.normpath(filename)
assert normfn.startswith(tempdir), 'trying to open image file not in temporary directory'
return Image.open(filename)
if type(filedata) == list:
if len(filedata) == 0:
return None
filedata = filedata[0]
if filedata.startswith("data:image/png;base64,"):
filedata = filedata[len("data:image/png;base64,"):]
filedata = base64.decodebytes(filedata.encode('utf-8'))
image = Image.open(io.BytesIO(filedata))
return image
def add_paste_fields(tabname, init_img, fields):
paste_fields[tabname] = {"init_img": init_img, "fields": fields}
# backwards compatibility for existing extensions
import modules.ui
if tabname == 'txt2img':
modules.ui.txt2img_paste_fields = fields
elif tabname == 'img2img':
modules.ui.img2img_paste_fields = fields
def integrate_settings_paste_fields(component_dict):
from modules import ui
settings_map = {
'sd_hypernetwork': 'Hypernet',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
}
settings_paste_fields = [
(component_dict[k], lambda d, k=k, v=v: ui.apply_setting(k, d.get(v, None)))
for k, v in settings_map.items()
]
for tabname, info in paste_fields.items():
if info["fields"] is not None:
info["fields"] += settings_paste_fields
def create_buttons(tabs_list):
buttons = {}
for tab in tabs_list:
buttons[tab] = gr.Button(f"Send to {tab}")
return buttons
#if send_generate_info is a tab name, mean generate_info comes from the params fields of the tab
def bind_buttons(buttons, send_image, send_generate_info):
bind_list.append([buttons, send_image, send_generate_info])
def run_bind():
for buttons, send_image, send_generate_info in bind_list:
for tab in buttons:
button = buttons[tab]
if send_image and paste_fields[tab]["init_img"]:
if type(send_image) == gr.Gallery:
button.click(
fn=lambda x: image_from_url_text(x),
_js="extract_image_from_gallery",
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
else:
button.click(
fn=lambda x: x,
inputs=[send_image],
outputs=[paste_fields[tab]["init_img"]],
)
if send_generate_info and paste_fields[tab]["fields"] is not None:
if send_generate_info in paste_fields:
paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration', 'Size-1', 'Size-2'] + (["Seed"] if shared.opts.send_seed else [])
button.click(
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[send_generate_info]["fields"] if name in paste_field_names],
outputs=[field for field, name in paste_fields[tab]["fields"] if name in paste_field_names],
)
else:
connect_paste(button, paste_fields[tab]["fields"], send_generate_info)
button.click(
fn=None,
_js=f"switch_to_{tab}",
inputs=None,
outputs=None,
)
def parse_generation_parameters(x: str):
"""parses generation parameters string, the one you see in text field under the picture in UI:
```
......@@ -68,7 +178,7 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
return res
def connect_paste(button, paste_fields, input_comp, js=None):
def connect_paste(button, paste_fields, input_comp, jsfunc=None):
def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
......@@ -106,7 +216,9 @@ def connect_paste(button, paste_fields, input_comp, js=None):
button.click(
fn=paste_func,
_js=js,
_js=jsfunc,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
)
......@@ -25,6 +25,7 @@ from statistics import stdev, mean
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
"linear": torch.nn.Identity,
"relu": torch.nn.ReLU,
"leakyrelu": torch.nn.LeakyReLU,
"elu": torch.nn.ELU,
......@@ -428,7 +429,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.step()
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
steps_done = hypernetwork.step + 1
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
raise RuntimeError("Loss diverged.")
if len(previous_mean_losses) > 1:
......@@ -438,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
pbar.set_description(dataset_loss_info)
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
hypernetwork.save(last_saved_file)
......@@ -449,8 +452,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
"learn_rate": scheduler.learn_rate
})
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{hypernetwork_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
optimizer.zero_grad()
......
......@@ -8,7 +8,8 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack, shared
from modules.hypernetworks import hypernetwork
keys = list(hypernetwork.HypernetworkModule.activation_dict.keys())
not_available = ["hardswish", "multiheadattention"]
keys = ["linear"] + list(x for x in hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available)
def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
# Remove illegal characters from name.
......
......@@ -300,8 +300,8 @@ class FilenameGenerator:
'seed': lambda self: self.seed if self.seed is not None else '',
'steps': lambda self: self.p and self.p.steps,
'cfg': lambda self: self.p and self.p.cfg_scale,
'width': lambda self: self.p and self.p.width,
'height': lambda self: self.p and self.p.height,
'width': lambda self: self.image.width,
'height': lambda self: self.image.height,
'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
......@@ -315,10 +315,11 @@ class FilenameGenerator:
}
default_time_format = '%Y%m%d%H%M%S'
def __init__(self, p, seed, prompt):
def __init__(self, p, seed, prompt, image):
self.p = p
self.seed = seed
self.prompt = prompt
self.image = image
def prompt_no_style(self):
if self.p is None or self.prompt is None:
......@@ -449,7 +450,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
txt_fullfn (`str` or None):
If a text file is saved for this image, this will be its full path. Otherwise None.
"""
namegen = FilenameGenerator(p, seed, prompt)
namegen = FilenameGenerator(p, seed, prompt, image)
if save_to_dirs is None:
save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)
......
......@@ -19,7 +19,7 @@ import modules.scripts
def process_batch(p, input_dir, output_dir, args):
processing.fix_seed(p)
images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
images = shared.listfiles(input_dir)
print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")
......
......@@ -129,6 +129,73 @@ class StableDiffusionProcessing():
self.all_seeds = None
self.all_subseeds = None
def txt2img_image_conditioning(self, x, width=None, height=None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
return torch.zeros(
x.shape[0], 5, 1, 1,
dtype=x.dtype,
device=x.device
)
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
return image_conditioning
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
return torch.zeros(
latent_image.shape[0], 5, 1, 1,
dtype=latent_image.dtype,
device=latent_image.device
)
# Handle the different mask inputs
if image_mask is not None:
if torch.is_tensor(image_mask):
conditioning_mask = image_mask
else:
conditioning_mask = np.array(image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
conditioning_mask = conditioning_mask.to(source_image.device)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)
# Encode the new masked image using first stage of network.
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
image_conditioning = image_conditioning.to(shared.device).type(self.sd_model.dtype)
return image_conditioning
def init(self, all_prompts, all_seeds, all_subseeds):
pass
......@@ -571,37 +638,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def create_dummy_mask(self, x, width=None, height=None):
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
height = height or self.height
width = width or self.width
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
image_conditioning = image_conditioning.to(x.dtype)
else:
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
image_conditioning = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
return image_conditioning
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
if not self.enable_hr:
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x))
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
return samples
x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.create_dummy_mask(x, self.firstphase_width, self.firstphase_height))
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height))
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2]
......@@ -638,7 +684,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=self.create_dummy_mask(samples))
image_conditioning = self.img2img_image_conditioning(
decoded_samples,
samples,
decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
)
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
......@@ -770,33 +821,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
elif self.inpainting_fill == 3:
self.init_latent = self.init_latent * self.mask
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
if self.image_mask is not None:
conditioning_mask = np.array(self.image_mask.convert("L"))
conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
conditioning_mask = torch.from_numpy(conditioning_mask[None, None])
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
conditioning_mask = torch.ones(1, 1, *image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
conditioning_mask = conditioning_mask.to(image.device)
conditioning_image = image * (1.0 - conditioning_mask)
conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:])
conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1)
self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1)
self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype)
else:
self.image_conditioning = torch.zeros(
self.init_latent.shape[0], 5, 1, 1,
dtype=self.init_latent.dtype,
device=self.init_latent.device
)
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, self.image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
......
......@@ -236,7 +236,7 @@ class ScriptRunner:
with gr.Group():
create_script_ui(script, inputs, inputs_alwayson)
dropdown = gr.Dropdown(label="Script", choices=["None"] + self.titles, value="None", type="index")
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
dropdown.save_to_config = True
inputs[0] = dropdown
......
......@@ -3,6 +3,7 @@ import os.path
import sys
from collections import namedtuple
import torch
import re
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
......@@ -35,8 +36,10 @@ def setup_model():
list_models()
def checkpoint_tiles():
return sorted([x.title for x in checkpoints_list.values()])
def checkpoint_tiles():
convert = lambda name: int(name) if name.isdigit() else name.lower()
alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)
def list_models():
......@@ -170,7 +173,9 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd)
missing, extra = model.load_state_dict(sd, strict=False)
del pl_sd
model.load_state_dict(sd, strict=False)
del sd
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
......@@ -194,9 +199,10 @@ def load_model_weights(model, checkpoint_info):
model.first_stage_model.to(devices.dtype_vae)
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
if shared.opts.sd_checkpoint_cache > 0:
checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
checkpoints_loaded.popitem(last=False) # LRU
else:
print(f"Loading weights [{sd_model_hash}] from cache")
checkpoints_loaded.move_to_end(checkpoint_info)
......
......@@ -82,6 +82,7 @@ parser.add_argument("--api", action='store_true', help="use api=True to launch t
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI")
parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None)
parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False)
cmd_opts = parser.parse_args()
restricted_opts = {
......@@ -280,6 +281,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
......@@ -316,6 +318,7 @@ options_templates.update(options_section(('ui', "User interface"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(False, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(False, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
......@@ -462,3 +465,8 @@ total_tqdm = TotalTQDM()
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()
def listfiles(dirname):
filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")]
return [file for file in filenames if os.path.isfile(file)]
......@@ -86,12 +86,12 @@ class PersonalizedBase(Dataset):
assert len(self.dataset) > 0, "No images have been found in the dataset."
self.length = len(self.dataset) * repeats // batch_size
self.initial_indexes = np.arange(len(self.dataset))
self.dataset_length = len(self.dataset)
self.indexes = None
self.shuffle()
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0]).numpy()]
self.indexes = np.random.permutation(self.dataset_length)
def create_text(self, filename_text):
text = random.choice(self.lines)
......
......@@ -52,7 +52,7 @@ class LearnRateScheduler:
self.finished = False
def apply(self, optimizer, step_number):
if step_number <= self.end_step:
if step_number < self.end_step:
return
try:
......
......@@ -184,9 +184,8 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if shared.opts.training_write_csv_every == 0:
return
if step % shared.opts.training_write_csv_every != 0:
if (step + 1) % shared.opts.training_write_csv_every != 0:
return
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
......@@ -196,11 +195,11 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = step // epoch_len
epoch_step = step - epoch * epoch_len
epoch_step = step % epoch_len
csv_writer.writerow({
"step": step + 1,
"epoch": epoch + 1,
"epoch": epoch,
"epoch_step": epoch_step + 1,
**values,
})
......@@ -282,15 +281,16 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
loss.backward()
optimizer.step()
steps_done = embedding.step + 1
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
epoch_step = embedding.step % len(ds)
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint.
embedding.name = f'{embedding_name}-{embedding.step}'
embedding.name = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
embedding.save(last_saved_file)
embedding_yet_to_be_embedded = True
......@@ -300,8 +300,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
"learn_rate": scheduler.learn_rate
})
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
forced_filename = f'{embedding_name}-{embedding.step}'
if images_dir is not None and steps_done % create_image_every == 0:
forced_filename = f'{embedding_name}-{steps_done}'
last_saved_image = os.path.join(images_dir, forced_filename)
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
......@@ -334,7 +334,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
......@@ -350,7 +350,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}v {}s'.format(vectorSize, embedding.step)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
......@@ -380,7 +380,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
embedding.cached_checksum = None
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention).
embedding.name = embedding_name
filename = os.path.join(shared.cmd_opts.embedding_dir, f'{embedding.name}.pt')
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding.name}.pt')
embedding.save(filename)
return embedding, filename
This diff is collapsed.
......@@ -153,7 +153,6 @@ def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
......@@ -178,6 +177,7 @@ axis_options = [
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
]
......
......@@ -314,8 +314,8 @@ input[type="range"]{
.modalControls {
display: grid;
grid-template-columns: 32px auto 1fr 32px;
grid-template-areas: "zoom tile space close";
grid-template-columns: 32px 32px 32px 1fr 32px;
grid-template-areas: "zoom tile save space close";
position: absolute;
top: 0;
left: 0;
......@@ -333,6 +333,10 @@ input[type="range"]{
grid-area: zoom;
}
.modalSave {
grid-area: save;
}
.modalTileImage {
grid-area: tile;
}
......@@ -346,8 +350,18 @@ input[type="range"]{
cursor: pointer;
}
.modalSave {
color: white;
font-size: 28px;
margin-top: 8px;
font-weight: bold;
cursor: pointer;
}
.modalClose:hover,
.modalClose:focus,
.modalSave:hover,
.modalSave:focus,
.modalZoom:hover,
.modalZoom:focus {
color: #999;
......@@ -522,18 +536,23 @@ If you change anything above, you need to make sure it is RTL compliant by just
your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/.
Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/
@media rtl {
/* this part was manualy added */
/* this part was added manually */
:host {
direction: rtl;
}
.output-html:has(.performance), .gr-text-input {
select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress {
direction: ltr;
}
#script_list > label > select,
#x_type > label > select,
#y_type > label > select {
direction: rtl;
}
.gr-radio, .gr-checkbox{
margin-left: 0.25em;
}
/* this part was automatically generated with few manual modifications */
/* automatically generated with few manual modifications */
.performance .time {
margin-right: unset;
margin-left: 0;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment