Commit cd4d59c0 authored by Muhammad Rizqi Nur

Merge master

parents 05e2e405 17a2076f
......@@ -29,3 +29,5 @@ notification.mp3
/localizations/ar_AR.json @xmodar @blackneoo
/localizations/de_DE.json @LunixWasTaken
/localizations/es_ES.json @innovaciones
/localizations/fr_FR.json @tumbly
/localizations/it_IT.json @EugenioBuffo
/localizations/ja_JP.json @yuuki76
/localizations/ko_KR.json @36DB
/localizations/pt_BR.json @M-art-ucci
/localizations/ru_RU.json @kabachuha
/localizations/tr_TR.json @camenduru
/localizations/zh_CN.json @dtlnor @bgluminous
/localizations/zh_TW.json @benlisquare
......@@ -128,10 +128,12 @@ def prepare_enviroment():
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
sys.argv += shlex.split(commandline_args)
test_argv = [x for x in sys.argv if x != '--tests']
sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
sys.argv, update_check = extract_arg(sys.argv, '--update-check')
sys.argv, run_tests = extract_arg(sys.argv, '--tests')
xformers = '--xformers' in sys.argv
deepdanbooru = '--deepdanbooru' in sys.argv
ngrok = '--ngrok' in sys.argv
......@@ -194,6 +196,26 @@ def prepare_enviroment():
print("Exiting because of --exit argument")
if run_tests:
def tests(argv):
if "--api" not in argv:
print(f"Launching Web UI in another process for testing with arguments: {' '.join(argv[1:])}")
with open('test/stdout.txt', "w", encoding="utf8") as stdout, open('test/stderr.txt', "w", encoding="utf8") as stderr:
proc = subprocess.Popen([sys.executable, *argv], stdout=stdout, stderr=stderr)
import test.server_poll
print(f"Stopping Web UI process with id {}")
def start_webui():
print(f"Launching Web UI with arguments: {' '.join(sys.argv[1:])}")
This diff is collapsed.
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -202,6 +202,7 @@
"Inpaint at full resolution padding, pixels": "전체 해상도로 인페인트시 패딩값(픽셀 단위)",
"Inpaint masked": "마스크만 처리",
"Inpaint not masked": "마스크 이외만 처리",
"Inpainting conditioning mask strength": "인페인팅 조절 마스크 강도",
"Input directory": "인풋 이미지 경로",
"Input images directory": "이미지 경로 입력",
"Interpolation Method": "보간 방법",
......@@ -218,6 +219,7 @@
"Interrogate: use artists from artists.csv": "분석 : artists.csv의 작가들 사용하기",
"Interrupt": "중단",
"Is negative text": "네거티브 텍스트일시 체크",
"Iterate seed every line": "줄마다 시드 반복하기",
"Just resize": "리사이징",
"Keep -1 for seeds": "시드값 -1로 유지",
"keep whatever was there originally": "이미지 원본 유지",
......@@ -234,6 +236,7 @@
"Leave blank to save images to the default path.": "기존 저장 경로에 이미지들을 저장하려면 비워두세요.",
"left": "왼쪽",
"linear": "linear",
"List of prompt inputs": "프롬프트 입력 리스트",
"List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/ for setting names. Requires restarting to apply.": "설정 탭이 아니라 상단의 빠른 설정 바에 위치시킬 설정 이름을 쉼표로 분리해서 입력하십시오. 설정 이름은 modules/shared.py에서 찾을 수 있습니다. 재시작이 필요합니다.",
"LMS": "LMS",
"LMS Karras": "LMS Karras",
......@@ -261,7 +264,7 @@
"Multiplier (M) - set to 0 to get model A": "배율 (M) - 0으로 적용하면 모델 A를 얻게 됩니다",
"Name": "이름",
"Negative prompt": "네거티브 프롬프트",
"Negative prompt (press Ctrl+Enter or Alt+Enter to generate)": "네거티브 프롬프트 입력(Ctrl+Enter나 Alt+Enter로 생성 시작)",
"Negative prompt (press Ctrl+Enter or Alt+Enter to generate)": "네거티브 프롬프트(Prompt) 입력(Ctrl+Enter나 Alt+Enter로 생성 시작)",
"Next batch": "다음 묶음",
"Next Page": "다음 페이지",
"None": "없음",
......@@ -274,6 +277,7 @@
"Number of repeats for a single input image per epoch; used only for displaying epoch number": "세대(Epoch)당 단일 인풋 이미지의 반복 횟수 - 세대(Epoch) 숫자를 표시하는 데에만 사용됩니다. ",
"Number of rows on the page": "각 페이지마다 표시할 세로줄 수",
"Number of vectors per token": "토큰별 벡터 수",
"Only applies to inpainting models. Determines how strongly to mask off the original image for inpainting and img2img. 1.0 means fully masked, which is the default behaviour. 0.0 means a fully unmasked conditioning. Lower values will help preserve the overall composition of the image, but will struggle with large changes.": "인페인팅 모델에만 적용됩니다. 인페인팅과 이미지→이미지에서 원본 이미지를 얼마나 마스킹 처리할지 결정하는 값입니다. 1.0은 완전히 마스킹함(기본 설정)을 의미하고, 0.0은 완전히 언마스킹된 이미지를 의미합니다. 낮은 값일수록 이미지의 전체적인 구성을 유지하는 데에 도움되겠지만, 변화량이 많을수록 불안정해집니다.",
"Open for Clip Aesthetic!": "클립 스타일 기능을 활성화하려면 클릭!",
"Open images output directory": "이미지 저장 경로 열기",
"Open output directory": "저장 경로 열기",
......@@ -319,7 +323,7 @@
"Process images in a directory on the same machine where the server is running.": "WebUI 서버가 돌아가고 있는 디바이스에 존재하는 디렉토리의 이미지들을 처리합니다.",
"Produce an image that can be tiled.": "타일링 가능한 이미지를 생성합니다.",
"Prompt": "프롬프트",
"Prompt (press Ctrl+Enter or Alt+Enter to generate)": "프롬프트 입력(Ctrl+Enter나 Alt+Enter로 생성 시작)",
"Prompt (press Ctrl+Enter or Alt+Enter to generate)": "프롬프트(Prompt) 입력(Ctrl+Enter나 Alt+Enter로 생성 시작)",
"Prompt matrix": "프롬프트 매트릭스",
"Prompt order": "프롬프트 순서",
"Prompt S/R": "프롬프트 스타일 변경",
......@@ -388,6 +392,7 @@
"Select activation function of hypernetwork": "하이퍼네트워크 활성화 함수 선택",
"Select Layer weights initialization. relu-like - Kaiming, sigmoid-like - Xavier is recommended": "레이어 가중치 초기화 방식 선택 - relu류 : Kaiming 추천, sigmoid류 : Xavier 추천",
"Select which Real-ESRGAN models to show in the web UI. (Requires restart)": "WebUI에 표시할 Real-ESRGAN 모델을 선택하십시오. (재시작 필요)",
"Send seed when sending prompt or image to other interface": "다른 화면으로 프롬프트나 이미지를 보낼 때 시드도 함께 보내기",
"Send to extras": "부가기능으로 전송",
"Send to img2img": "이미지→이미지로 전송",
"Send to inpaint": "인페인트로 전송",
......@@ -464,6 +469,8 @@
"uniform": "uniform",
"up": "위쪽",
"Upload mask": "마스크 업로드하기",
"Upload prompt inputs": "입력할 프롬프트를 업로드하십시오",
"Upscale Before Restoring Faces": "얼굴 보정을 진행하기 전에 업스케일링 먼저 진행하기",
"Upscale latent space image when doing hires. fix": "고해상도 보정 사용시 잠재 공간 이미지 업스케일하기",
"Upscale masked region to target resolution, do inpainting, downscale back and paste into original image": "마스크된 부분을 설정된 해상도로 업스케일하고, 인페인팅을 진행한 뒤, 다시 다운스케일 후 원본 이미지에 붙여넣습니다.",
"Upscaler": "업스케일러",
This diff is collapsed.
import time
import uvicorn
from gradio.processing_utils import encode_pil_to_base64, decode_base64_to_file, decode_base64_to_image
from fastapi import APIRouter, HTTPException
from fastapi import APIRouter, Depends, HTTPException
import modules.shared as shared
from modules import devices
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras
from modules.extras import run_extras, run_pnginfo
def upscaler_to_index(name: str):
......@@ -13,8 +16,10 @@ def upscaler_to_index(name: str):
raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([ for x in sd_upscalers])}")
def sampler_to_index(name):
    """Find a sampler by case-insensitive name.

    Returns the first ``(index, sampler)`` pair produced by
    ``enumerate(all_samplers)`` whose ``sampler.name`` equals ``name`` when both
    are lower-cased, or ``None`` when nothing matches. Callers take element
    ``[0]`` of the pair as the sampler index.
    """
    for row in enumerate(all_samplers):
        # row is (index, sampler); compare names ignoring case
        if name.lower() == row[1].name.lower():
            return row
    return None
def setUpscalers(req: dict):
reqDict = vars(req)
reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1)
......@@ -23,6 +28,7 @@ def setUpscalers(req: dict):
return reqDict
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
......@@ -32,15 +38,17 @@ class Api:"/sdapi/v1/img2img", self.img2imgapi, methods=["POST"], response_model=ImageToImageResponse)"/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse)"/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse)"/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse)"/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse)
def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
......@@ -48,34 +56,39 @@ class Api:
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
with self.queue_lock:
processed = process_images(p)
b64images = list(map(encode_pil_to_base64, processed.images))
return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
def img2imgapi(self, img2imgreq: StableDiffusionImg2ImgProcessingAPI):
sampler_index = sampler_to_index(img2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
raise HTTPException(status_code=404, detail="Sampler not found")
init_images = img2imgreq.init_images
if init_images is None:
raise HTTPException(status_code=404, detail="Init image not found")
raise HTTPException(status_code=404, detail="Init image not found")
mask = img2imgreq.mask
if mask:
mask = decode_base64_to_image(mask)
populate = img2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True,
"do_not_save_grid": True,
"mask": mask
......@@ -87,16 +100,20 @@ class Api:
imgs = [img] * p.batch_size
p.init_images = imgs
# Override object param
with self.queue_lock:
processed = process_images(p)
b64images = list(map(encode_pil_to_base64, processed.images))
if (not img2imgreq.include_init_images):
img2imgreq.init_images = None
img2imgreq.mask = None
return ImageToImageResponse(images=b64images, parameters=vars(img2imgreq), info=processed.js())
def extras_single_image_api(self, req: ExtrasSingleImageRequest):
......@@ -124,9 +141,40 @@ class Api:
result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", **reqDict)
return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])
def pnginfoapi(self):
raise NotImplementedError
def pnginfoapi(self, req: PNGInfoRequest):
    """Return the info string stored in a base64-encoded image.

    A blank (empty or whitespace-only) ``req.image`` produces a response with
    an empty ``info``. Otherwise the stripped payload is decoded and handed to
    ``run_pnginfo``; the second element of its result becomes ``info``.
    """
    encoded = req.image.strip()
    if encoded:
        decoded = decode_base64_to_image(encoded)
        extracted = run_pnginfo(decoded)
        return PNGInfoResponse(info=extracted[1])
    return PNGInfoResponse(info="")
def progressapi(self, req: ProgressRequest = Depends()):
    """Report how far the current job has progressed.

    Reads ``shared.state`` and returns a ``ProgressResponse`` holding the
    progress value (capped at 1), the estimated remaining seconds
    (``eta_relative``), a snapshot of the state and, unless
    ``req.skip_current_image`` is set, the current image encoded as base64.
    """
    # Adapted from check_progress_call; the original comment is truncated, so the
    # module that function lives in is not named here (TODO confirm).
    if shared.state.job_count == 0:
        # A job count of zero is reported as zero progress with no ETA.
        return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict())
    # Start slightly above zero to avoid dividing by zero when computing the ETA.
    progress = 0.01
    if shared.state.job_count > 0:
        # share of whole jobs already completed
        progress += shared.state.job_no / shared.state.job_count
    if shared.state.sampling_steps > 0:
        # share of the current job, weighted by one job's slice of the total
        # NOTE(review): with a negative job_count this term is negative — confirm
        # that progress cannot reach zero before the division below.
        progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
    time_since_start = time.time() - shared.state.time_start
    # projected total duration = elapsed / fraction done
    eta = (time_since_start/progress)
    # remaining time = projected total - elapsed
    eta_relative = eta-time_since_start
    progress = min(progress, 1)
    current_image = None
    if shared.state.current_image and not req.skip_current_image:
        current_image = encode_pil_to_base64(shared.state.current_image)
    return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict(), current_image=current_image)
def launch(self, server_name, port):
import inspect
from click import prompt
from pydantic import BaseModel, Field, create_model
from typing import Any, Optional
from typing_extensions import Literal
......@@ -51,17 +52,17 @@ class PydanticModelGenerator:
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation
return Optional[field_type]
def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {}
for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
self._model_def = [
......@@ -73,11 +74,11 @@ class PydanticModelGenerator:
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
for fields in additional_fields:
field_exclude=fields["exclude"] if "exclude" in fields else False))
......@@ -94,15 +95,15 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
return DynamicModel
StableDiffusionTxt2ImgProcessingAPI = PydanticModelGenerator(
[{"key": "sampler_index", "type": str, "default": "Euler"}]
StableDiffusionImg2ImgProcessingAPI = PydanticModelGenerator(
[{"key": "sampler_index", "type": str, "default": "Euler"}, {"key": "init_images", "type": list, "default": None}, {"key": "denoising_strength", "type": float, "default": 0.75}, {"key": "mask", "type": str, "default": None}, {"key": "include_init_images", "type": bool, "default": False, "exclude" : True}]
......@@ -148,4 +149,19 @@ class ExtrasBatchImagesRequest(ExtrasBaseRequest):
imageList: list[FileData] = Field(title="Images", description="List of images to work on. Must be Base64 strings")
class ExtrasBatchImagesResponse(ExtraBaseResponse):
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
\ No newline at end of file
images: list[str] = Field(title="Images", description="The generated images in base64 format.")
class PNGInfoRequest(BaseModel):
    """Request model carrying a single base64-encoded PNG image."""
    image: str = Field(title="Image", description="The base64 encoded PNG image")
class PNGInfoResponse(BaseModel):
    """Response model carrying the info string read from an image."""
    info: str = Field(title="Image info", description="A string with all the info the image had")
class ProgressRequest(BaseModel):
    """Options for a progress query."""
    # Defaults to False; per the field description, a true value skips
    # serialization of the current image.
    skip_current_image: bool = Field(default=False, title="Skip current image", description="Skip current image serialization")
class ProgressResponse(BaseModel):
    """Response model describing the progress of the current job."""
    progress: float = Field(title="Progress", description="The progress with a range of 0 to 1")
    eta_relative: float = Field(title="ETA in secs")
    state: dict = Field(title="State", description="The current state snapshot")
    # Defaults to None, so the annotation is Optional[str] rather than plain str.
    current_image: Optional[str] = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
......@@ -66,6 +66,7 @@ def integrate_settings_paste_fields(component_dict):
settings_map = {
'sd_hypernetwork': 'Hypernet',
'sd_hypernetwork_strength': 'Hypernet strength',
'CLIP_stop_at_last_layers': 'Clip skip',
'sd_model_checkpoint': 'Model hash',
......@@ -209,13 +209,16 @@ def list_hypernetworks(path):
res = {}
for filename in glob.iglob(os.path.join(path, '**/*.pt'), recursive=True):
name = os.path.splitext(os.path.basename(filename))[0]
res[name] = filename
# Prevent a hypothetical "None.pt" from being listed.
if name != "None":
res[name] = filename
return res
def load_hypernetwork(filename):
path = shared.hypernetworks.get(filename, None)
if path is not None:
# Prevent any file named "None.pt" from being loaded.
if path is not None and filename != "None":
print(f"Loading hypernetwork {filename}")
shared.loaded_hypernetwork = Hypernetwork()
......@@ -332,7 +335,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
assert hypernetwork_name, 'hypernetwork not selected'
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
......@@ -358,18 +363,25 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
images_dir = None
hypernetwork = shared.loaded_hypernetwork
checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload:
hypernetwork = shared.loaded_hypernetwork
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
size = len(ds.indexes)
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
losses = torch.zeros((size,))
......@@ -393,11 +405,18 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
weights = hypernetwork.weights()
for weight in weights:
weight.requires_grad = True
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
steps_without_grad = 0
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, entries in pbar:
hypernetwork.step = i + ititial_step
......@@ -458,9 +477,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint. = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{}.pt')
hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
......@@ -521,13 +540,23 @@ Last saved image: {html.escape(last_saved_image)}<br/>
checkpoint = sd_models.select_checkpoint()
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name
# Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). = hypernetwork_name
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{}.pt')
filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
return hypernetwork, filename
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
old_hypernetwork_name =
old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
hypernetwork.sd_checkpoint = checkpoint.hash
hypernetwork.sd_checkpoint_name = checkpoint.model_name = hypernetwork_name
hypernetwork.sd_checkpoint = old_sd_checkpoint
hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name = old_hypernetwork_name
......@@ -396,6 +396,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Hypernet": (None if shared.loaded_hypernetwork is None else,
"Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength),
"Batch size": (None if p.batch_size < 2 else p.batch_size),
"Batch pos": (None if p.batch_size < 2 else position_in_batch),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
......@@ -478,7 +479,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.scripts is not None:
infotexts = []
output_images = []
......@@ -501,7 +502,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if (len(prompts) == 0):
if len(prompts) == 0:
with devices.autocast():
......@@ -590,7 +591,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
return Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
res = Processed(p, output_images, p.all_seeds[0], infotext() + "".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], all_prompts=p.all_prompts, all_seeds=p.all_seeds, all_subseeds=p.all_subseeds, index_of_first_image=index_of_first_image, infotexts=infotexts)
if p.scripts is not None:
p.scripts.postprocess(p, res)
return res
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
......@@ -680,15 +687,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
image_conditioning = self.txt2img_image_conditioning(x)
# GC now before running the next img2img to prevent running out of memory
x = None
image_conditioning = self.img2img_image_conditioning(
decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning)
return samples
......@@ -64,7 +64,16 @@ class Script:
def process(self, p, *args):
This function is called before processing begins for AlwaysVisible scripts.
scripts. You can modify the processing object (p) here, inject hooks, etc.
You can modify the processing object (p) here, inject hooks, etc.
args contains all values returned by components from ui()
def postprocess(self, p, processed, *args):
This function is called after processing ends for AlwaysVisible scripts.
args contains all values returned by components from ui()
......@@ -289,13 +298,22 @@ class ScriptRunner:
return processed
def run_alwayson_scripts(self, p):
def process(self, p):
for script in self.alwayson_scripts:
script_args = p.script_args[script.args_from:script.args_to]
script.process(p, *script_args)
except Exception:
print(f"Error running alwayson script: {script.filename}", file=sys.stderr)
print(f"Error running process: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
def postprocess(self, p, processed):
    """Run the ``postprocess`` hook of every always-on script.

    Each script receives ``p``, ``processed`` and its own slice of
    ``p.script_args`` (from ``script.args_from`` to ``script.args_to``).
    An exception raised by one script is reported on stderr with its
    traceback and does not stop the remaining scripts from running.
    """
    for script in self.alwayson_scripts:
        try:
            script_args = p.script_args[script.args_from:script.args_to]
            script.postprocess(p, processed, *script_args)
        except Exception:
            # Best-effort: report the failing script and carry on with the rest.
            print(f"Error running postprocess: {script.filename}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
def reload_sources(self, cache):
......@@ -144,9 +144,38 @@ class State:
self.sampling_step = 0
self.current_image_sampling_step = 0
def get_job_timestamp(self):
    """Return the current local time formatted as ``YYYYMMDDhhmmss``."""
    import datetime  # function-scope import keeps this method self-contained

    # NOTE(review): this formats "now" rather than returning self.job_timestamp;
    # the original author raised the same question ("shouldn't this return job_timestamp?").
    return datetime.datetime.now().strftime("%Y%m%d%H%M%S")
def dict(self):
    """Return a plain-dict snapshot of the job state.

    The snapshot holds the ``skipped`` and ``interrupted`` flags, the job
    label, the job counters and the sampling-step counters.
    """
    obj = {
        "skipped": self.skipped,
        # report the interrupted flag itself, not a copy of `skipped`
        "interrupted": self.interrupted,
        "job": self.job,
        "job_count": self.job_count,
        "job_no": self.job_no,
        "sampling_step": self.sampling_step,
        "sampling_steps": self.sampling_steps,
    }
    return obj
def begin(self):
    """Reset the per-job state fields and stamp the job with the current time.

    ``job_count`` is set to -1, the step and job counters to 0, the current
    latent/image and text info to ``None``, and both the ``skipped`` and
    ``interrupted`` flags to ``False``. ``job_timestamp`` becomes the current
    local time formatted as ``YYYYMMDDhhmmss``.
    """
    import datetime  # function-scope import keeps this method self-contained

    self.sampling_step = 0
    self.job_count = -1
    self.job_no = 0
    self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    self.current_latent = None
    self.current_image = None
    self.current_image_sampling_step = 0
    self.skipped = False
    self.interrupted = False
    self.textinfo = None
def end(self):
    """Clear the job label and reset the job count to zero."""
    self.job = ""
    self.job_count = 0
state = State()
......@@ -42,6 +42,8 @@ class PersonalizedBase(Dataset):
self.lines = lines
assert data_root, 'dataset directory not specified'
assert os.path.isdir(data_root), "Dataset directory doesn't exist"
assert os.listdir(data_root), "Dataset directory is empty"
cond_model = shared.sd_model.cond_stage_model
......@@ -4,30 +4,37 @@ import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
pairs = learn_rate.split(',')
self.rates = [] = 0
self.maxit = 0
for i, pair in enumerate(pairs):
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
for i, pair in enumerate(pairs):
if not pair.strip():
tmp = pair.split(':')
if len(tmp) == 2:
step = int(tmp[1])
if step > cur_step:
self.rates.append((float(tmp[0]), min(step, max_steps)))
self.maxit += 1
if step > max_steps:
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
elif step == -1:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
assert self.rates
except (ValueError, AssertionError):
raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
def __iter__(self):
return self
......@@ -119,7 +119,7 @@ class EmbeddingDatabase:
vec = emb.detach().to(devices.device, dtype=torch.float32)
embedding = Embedding(vec, name)
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('hash', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
self.register_embedding(embedding, shared.sd_model)
......@@ -204,9 +204,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    """Check the user-supplied training settings before training starts.

    Each check raises ``AssertionError`` with a user-facing message when it
    fails; the function returns ``None`` when every check passes. ``name`` is
    the kind of model being trained and is interpolated into the messages
    that mention it.

    A log directory is only required when ``save_model_every`` or
    ``create_image_every`` is non-zero.
    """
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_file, "Prompt template file is empty"
    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    # These two messages interpolate `name`, so they must be f-strings.
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
......@@ -232,38 +253,41 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True)
images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
cond_model = shared.sd_model.cond_stage_model
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
embedding.vec.requires_grad = True
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
checkpoint = sd_models.select_checkpoint()
ititial_step = embedding.step or 0
if ititial_step > steps:
if ititial_step >= steps:
shared.state.textinfo = f"Model has already been trained beyond specified max steps"
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
clip_grad_mode_value = clip_grad_mode == "value"
clip_grad_mode_norm = clip_grad_mode == "norm"
clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
if clip_grad_enabled:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
losses = torch.zeros((32,))
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, entries in pbar:
embedding.step = i + ititial_step
......@@ -305,9 +329,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if embedding_dir is not None and steps_done % save_embedding_every == 0:
# Before saving, change name to match current checkpoint. = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{}.pt')
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
......@@ -388,14 +412,26 @@ Last saved image: {html.escape(last_saved_image)}<br/>
checkpoint = sd_models.select_checkpoint()
embedding.sd_checkpoint = checkpoint.hash
embedding.sd_checkpoint_name = checkpoint.model_name
embedding.cached_checksum = None
# Before saving for the last time, change name back to base name (as opposed to the save_embedding_every step-suffixed naming convention). = embedding_name
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{}.pt')
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
return embedding, filename
def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Stamp *embedding* with checkpoint metadata and a name, then save it.

    The embedding is mutated in place: its ``sd_checkpoint`` and
    ``sd_checkpoint_name`` are taken from *checkpoint*, its ``name`` is set to
    *embedding_name*, and ``cached_checksum`` is cleared when
    *remove_cached_checksum* is true. If anything fails while saving, every
    touched attribute is rolled back to its previous value and the exception
    is re-raised.

    NOTE(review): the extracted source lost the attribute being renamed, the
    save call and the error-path structure; this body restores them on the
    assumption that ``embedding`` exposes ``name`` and ``save(filename)`` --
    confirm against the embedding class.
    """
    # Snapshot the current state so a failed save leaves the object untouched.
    old_embedding_name = embedding.name
    old_sd_checkpoint = getattr(embedding, "sd_checkpoint", None)
    old_sd_checkpoint_name = getattr(embedding, "sd_checkpoint_name", None)
    old_cached_checksum = getattr(embedding, "cached_checksum", None)
    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.save(filename)
    except BaseException:
        # Roll back on any failure (including interrupts), then propagate.
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        raise
import unittest
class TestExtrasWorking(unittest.TestCase):
    """Fixture for requests against the ``extra-single-image`` API endpoint.

    Only ``setUp`` is visible in this class; it prepares the endpoint URL and
    a minimal request payload.
    """

    def setUp(self):
        """Store the endpoint URL and a baseline extras payload on the case."""
        # The attribute name is kept as-is even though the URL targets the
        # extras endpoint, so any code reading ``url_img2img`` keeps working.
        self.url_img2img = "http://localhost:7860/sdapi/v1/extra-single-image"
        self.simple_extras = {
            "resize_mode": 0,
            "show_extras_results": True,
            "gfpgan_visibility": 0,
            "codeformer_visibility": 0,
            "codeformer_weight": 0,
            "upscaling_resize": 2,
            "upscaling_resize_w": 512,
            "upscaling_resize_h": 512,
            "upscaling_crop": True,
            "upscaler_1": "None",
            "upscaler_2": "None",
            "extras_upscaler_2_visibility": 0,
            "image": "",
        }
class TestExtrasCorrectness(unittest.TestCase):
    """Test case for extras output checks; no test methods are visible here."""

    pass
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
import unittest
import requests
from gradio.processing_utils import encode_pil_to_base64
from PIL import Image
class TestImg2ImgWorking(unittest.TestCase):
    """Smoke tests that POST to the ``img2img`` API endpoint and expect HTTP 200.

    The tests talk to a server on ``localhost:7860`` and read image files
    under ``test/test_files``.
    """

    def setUp(self):
        """Store the endpoint URL and a baseline img2img payload on the case."""
        self.url_img2img = "http://localhost:7860/sdapi/v1/img2img"
        self.simple_img2img = {
            # The API takes base64 strings, so the file is loaded with PIL and encoded.
            "init_images": [encode_pil_to_base64(Image.open("test/test_files/img2img_basic.png"))],
            "resize_mode": 0,
            "denoising_strength": 0.75,
            "mask": None,
            "mask_blur": 4,
            "inpainting_fill": 0,
            "inpaint_full_res": False,
            "inpaint_full_res_padding": 0,
            "inpainting_mask_invert": 0,
            "prompt": "example prompt",
            "styles": [],
            "seed": -1,
            "subseed": -1,
            "subseed_strength": 0,
            "seed_resize_from_h": -1,
            "seed_resize_from_w": -1,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 3,
            "cfg_scale": 7,
            "width": 64,
            "height": 64,
            "restore_faces": False,
            "tiling": False,
            "negative_prompt": "",
            "eta": 0,
            "s_churn": 0,
            "s_tmax": 0,
            "s_tmin": 0,
            "s_noise": 1,
            "override_settings": {},
            "sampler_index": "Euler a",
            "include_init_images": False,
        }

    def test_img2img_simple_performed(self):
        """The baseline payload is accepted."""
        self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)

    def test_inpainting_masked_performed(self):
        """The payload is accepted when an inpainting mask is supplied."""
        self.simple_img2img["mask"] = encode_pil_to_base64(Image.open("test/test_files/mask_basic.png"))
        self.assertEqual(requests.post(self.url_img2img, json=self.simple_img2img).status_code, 200)
class TestImg2ImgCorrectness(unittest.TestCase):
    """Test case for img2img output checks; no test methods are visible here."""

    pass
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
import unittest
import requests
import time
def run_tests():
    """Wait for the local Web UI to answer, then discover and run the tests.

    The server at ``http://localhost:7860/`` is probed until it accepts a
    connection or ``timeout_threshold`` seconds have elapsed. If it came up in
    time, the unittest suite is discovered and run with verbose output;
    otherwise "Launch unsuccessful" is printed.

    NOTE(review): the probe request was lost in extraction; a ``HEAD`` request
    to the server root is assumed -- confirm against the original file.
    """
    timeout_threshold = 240
    start_time = time.time()
    while time.time() - start_time < timeout_threshold:
        try:
            requests.head("http://localhost:7860/")
            break
        except requests.exceptions.ConnectionError:
            # Server is not listening yet; pause briefly instead of spinning.
            time.sleep(1)
    if time.time() - start_time < timeout_threshold:
        suite = unittest.TestLoader().discover('', pattern='*')
        unittest.TextTestRunner(verbosity=2).run(suite)
    else:
        print("Launch unsuccessful")
import unittest
import requests
class TestTxt2ImgWorking(unittest.TestCase):
    """Smoke tests that POST to the ``txt2img`` API endpoint and expect HTTP 200.

    Each test tweaks one field of the baseline payload built in ``setUp`` and
    checks that a server on ``localhost:7860`` accepts the request.
    """

    def setUp(self):
        """Store the endpoint URL and a baseline txt2img payload on the case."""
        self.url_txt2img = "http://localhost:7860/sdapi/v1/txt2img"
        self.simple_txt2img = {
            "enable_hr": False,
            "denoising_strength": 0,
            "firstphase_width": 0,
            "firstphase_height": 0,
            "prompt": "example prompt",
            "styles": [],
            "seed": -1,
            "subseed": -1,
            "subseed_strength": 0,
            "seed_resize_from_h": -1,
            "seed_resize_from_w": -1,
            "batch_size": 1,
            "n_iter": 1,
            "steps": 3,
            "cfg_scale": 7,
            "width": 64,
            "height": 64,
            "restore_faces": False,
            "tiling": False,
            "negative_prompt": "",
            "eta": 0,
            "s_churn": 0,
            "s_tmax": 0,
            "s_tmin": 0,
            "s_noise": 1,
            "sampler_index": "Euler a",
        }

    def test_txt2img_simple_performed(self):
        """The baseline payload is accepted."""
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_negative_prompt_performed(self):
        """The payload is accepted with a negative prompt set."""
        self.simple_txt2img["negative_prompt"] = "example negative prompt"
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_not_square_image_performed(self):
        """The payload is accepted when height differs from width."""
        self.simple_txt2img["height"] = 128
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_hrfix_performed(self):
        """The payload is accepted with ``enable_hr`` switched on."""
        self.simple_txt2img["enable_hr"] = True
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_restore_faces_performed(self):
        """The payload is accepted with ``restore_faces`` switched on."""
        self.simple_txt2img["restore_faces"] = True
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_tiling_faces_performed(self):
        """The payload is accepted with ``tiling`` switched on."""
        self.simple_txt2img["tiling"] = True
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_with_vanilla_sampler_performed(self):
        """The payload is accepted with the ``PLMS`` sampler selected."""
        self.simple_txt2img["sampler_index"] = "PLMS"
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)

    def test_txt2img_multiple_batches_performed(self):
        """The payload is accepted with ``n_iter`` raised to 2."""
        self.simple_txt2img["n_iter"] = 2
        self.assertEqual(requests.post(self.url_txt2img, json=self.simple_txt2img).status_code, 200)
class TestTxt2ImgCorrectness(unittest.TestCase):
    """Test case for txt2img output checks; no test methods are visible here."""

    pass
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()
......@@ -46,26 +46,13 @@ def wrap_queued_call(func):
def wrap_gradio_gpu_call(func, extra_outputs=None):
def f(*args, **kwargs):
shared.state.sampling_step = 0
shared.state.job_count = -1
shared.state.job_no = 0
shared.state.job_timestamp = shared.state.get_job_timestamp()
shared.state.current_latent = None
shared.state.current_image = None
shared.state.current_image_sampling_step = 0
shared.state.skipped = False
shared.state.interrupted = False
shared.state.textinfo = None
with queue_lock:
res = func(*args, **kwargs)
shared.state.job = ""
shared.state.job_count = 0
return res
......@@ -102,15 +102,14 @@ then
exit 1
printf "\n%s\n" "${delimiter}"
printf "Clone or update stable-diffusion-webui"
printf "\n%s\n" "${delimiter}"
cd "${install_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/, aborting...\e[0m" "${install_dir}"; exit 1; }
if [[ -d "${clone_dir}" ]]
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
"${GIT}" pull
printf "\n%s\n" "${delimiter}"
printf "Clone stable-diffusion-webui"
printf "\n%s\n" "${delimiter}"
"${GIT}" clone "${clone_dir}"
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment