Commit fb3cefb3 authored by Greg Fuller

Merge remote-tracking branch 'upstream/master' into interrogate_include_ranks_in_output

parents d717eb07 698d303b
......@@ -1045,7 +1045,6 @@ Bakemono Zukushi,0.67051035,anime
Lucy Madox Brown,0.67032814,fineart
Paul Wonner,0.6700563,scribbles
Guido Borelli Da Caluso,0.66966087,digipa-high-impact
Guido Borelli da Caluso,0.66966087,digipa-high-impact
Emil Alzamora,0.5844039,nudity
Heinrich Brocksieper,0.64469147,fineart
Dan Smith,0.669563,digipa-high-impact
......
......@@ -3,9 +3,9 @@ channels:
- pytorch
- defaults
dependencies:
- python=3.8.5
- pip=20.3
- python=3.10
- pip=22.2.2
- cudatoolkit=11.3
- pytorch=1.11.0
- torchvision=0.12.0
- numpy=1.19.2
- pytorch=1.12.1
- torchvision=0.13.1
- numpy=1.23.1
\ No newline at end of file
......@@ -25,6 +25,7 @@ addEventListener('keydown', (event) => {
} else {
end = target.value.slice(selectionEnd + 1).indexOf(")") + 1;
weight = parseFloat(target.value.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
if (event.key == minus) weight -= 0.1;
if (event.key == plus) weight += 0.1;
......
......@@ -80,7 +80,10 @@ titles = {
"Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.",
"Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.",
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be bevaing in an unethical manner.",
"Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.",
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
"Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.",
}
......
......@@ -101,7 +101,8 @@ function create_tab_index_args(tabId, args){
}
function get_extras_tab_index(){
return create_tab_index_args('mode_extras', arguments)
const [,,...args] = [...arguments]
return [get_tab_index('mode_extras'), get_tab_index('extras_resize_mode'), ...args]
}
function create_submit_args(args){
......
import os.path
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context
import multiprocessing
import time
import re
re_special = re.compile(r'([\\()])')
def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
def get_deepbooru_tags(pil_image):
"""
This method is for running only one image at a time for simple use. Used by the img2img interrogate.
"""
from modules import shared # prevents circular reference
try:
create_deepbooru_process(shared.opts.interrogate_deepbooru_score_threshold, create_deepbooru_opts())
return get_tags_from_process(pil_image)
finally:
release_process()
def create_deepbooru_opts():
from modules import shared
return {
"use_spaces": shared.opts.deepbooru_use_spaces,
"use_escape": shared.opts.deepbooru_escape,
"alpha_sort": shared.opts.deepbooru_sort_alpha,
}
def deepbooru_process(queue, deepbooru_process_return, threshold, deepbooru_opts):
model, tags = get_deepbooru_tags_model()
while True: # while process is running, keep monitoring queue for new image
pil_image = queue.get()
if pil_image == "QUIT":
break
else:
deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts)
def create_deepbooru_process(threshold, deepbooru_opts):
"""
Creates deepbooru process. A queue is created to send images into the process. This enables multiple images
to be processed in a row without reloading the model or creating a new process. To return the data, a shared
dictionary is created to hold the tags created. To wait for tags to be returned, a value of -1 is assigned
to the dictionary and the method adding the image to the queue should wait for this value to be updated with
the tags.
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_manager = multiprocessing.Manager()
shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue()
shared.deepbooru_process_return = shared.deepbooru_process_manager.dict()
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, deepbooru_opts))
shared.deepbooru_process.start()
def get_tags_from_process(image):
from modules import shared
shared.deepbooru_process_return["value"] = -1
shared.deepbooru_process_queue.put(image)
while shared.deepbooru_process_return["value"] == -1:
time.sleep(0.2)
caption = shared.deepbooru_process_return["value"]
shared.deepbooru_process_return["value"] = -1
return caption
def release_process():
"""
Stops the deepbooru process and releases the memory it was using
"""
from modules import shared # prevents circular reference
shared.deepbooru_process_queue.put("QUIT")
shared.deepbooru_process.join()
shared.deepbooru_process_queue = None
shared.deepbooru_process = None
shared.deepbooru_process_return = None
shared.deepbooru_process_manager = None
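The three helpers above form a small lifecycle API: create_deepbooru_process starts the worker and allocates the queue and return dict, get_tags_from_process round-trips one image through it, and release_process shuts it down. A minimal sketch of batching several images through a single worker, assuming the webui's modules package is importable and shared options are initialized (the threshold and image paths are placeholders):

import modules.deepbooru as deepbooru
from PIL import Image

def tag_images(paths, threshold=0.5):
    opts = deepbooru.create_deepbooru_opts()
    deepbooru.create_deepbooru_process(threshold, opts)
    try:
        # one worker process, many images: the model is loaded only once
        return [deepbooru.get_tags_from_process(Image.open(p)) for p in paths]
    finally:
        deepbooru.release_process()

tags_per_image = tag_images(["a.png", "b.png"])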
def get_deepbooru_tags_model():
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
this_folder = os.path.dirname(__file__)
model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
if not os.path.exists(os.path.join(model_path, 'project.json')):
# there is no point importing these every time
import zipfile
from basicsr.utils.download_util import load_file_from_url
load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
model_path)
load_file_from_url(
r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
model_path)
with zipfile.ZipFile(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"), "r") as zip_ref:
zip_ref.extractall(model_path)
os.remove(os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip"))
......@@ -24,6 +102,17 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
model = dd.project.load_model_from_project(
model_path, compile_model=True
)
return model, tags
def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_opts):
import deepdanbooru as dd
import tensorflow as tf
import numpy as np
alpha_sort = deepbooru_opts['alpha_sort']
use_spaces = deepbooru_opts['use_spaces']
use_escape = deepbooru_opts['use_escape']
width = model.input_shape[2]
height = model.input_shape[1]
......@@ -46,32 +135,35 @@ def _load_tf_and_return_tags(pil_image, threshold, include_ranks):
for i, tag in enumerate(tags):
result_dict[tag] = y[i]
result_tags_out = []
unsorted_tags_in_theshold = []
result_tags_print = []
for tag in tags:
if result_dict[tag] >= threshold:
if tag.startswith("rating:"):
continue
tag_formatted = tag.replace('_', ' ').replace(':', ' ')
if include_ranks:
result_tags_out.append(f'({tag_formatted}:{result_dict[tag]})')
else:
result_tags_out.append(tag_formatted)
unsorted_tags_in_theshold.append((result_dict[tag], tag))
result_tags_print.append(f'{result_dict[tag]} {tag}')
print('\n'.join(sorted(result_tags_print, reverse=True)))
# sort tags
result_tags_out = []
sort_ndx = 0
if alpha_sort:
sort_ndx = 1
# sort descending by likelihood, or ascending alphabetically when alpha_sort is set
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold:
result_tags_out.append(tag)
return ', '.join(result_tags_out)
print('\n'.join(sorted(result_tags_print, reverse=True)))
tags_text = ', '.join(result_tags_out)
def subprocess_init_no_cuda():
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
if use_spaces:
tags_text = tags_text.replace('_', ' ')
if use_escape:
tags_text = re.sub(re_special, r'\\\1', tags_text)
def get_deepbooru_tags(pil_image, threshold=0.5, include_ranks=False):
context = get_context('spawn')
with ProcessPoolExecutor(initializer=subprocess_init_no_cuda, mp_context=context) as executor:
f = executor.submit(_load_tf_and_return_tags, pil_image, threshold, include_ranks)
ret = f.result() # will rethrow any exceptions
return ret
\ No newline at end of file
return tags_text.replace(':', ' ')
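get_deepbooru_tags_from_model filters tags by the score threshold, drops rating: tags, sorts either by score or alphabetically, and then optionally space-separates and escapes brackets. A standalone re-creation of just those post-processing steps on made-up scores (tag names and values are illustrative, not model output):

import re

re_special = re.compile(r'([\\()])')

def postprocess(result_dict, threshold=0.5, alpha_sort=False, use_spaces=False, use_escape=True):
    # keep tags above the threshold, skipping rating:* tags
    kept = [(score, tag) for tag, score in result_dict.items()
            if score >= threshold and not tag.startswith("rating:")]
    # sort descending by score, or ascending alphabetically when alpha_sort is set
    kept.sort(key=lambda pair: pair[1] if alpha_sort else pair[0], reverse=not alpha_sort)
    text = ', '.join(tag for _, tag in kept)
    if use_spaces:
        text = text.replace('_', ' ')
    if use_escape:
        text = re_special.sub(r'\\\1', text)
    return text.replace(':', ' ')

print(postprocess({"1girl": 0.99, "solo_(fake)": 0.80, "rating:safe": 0.90, "sky": 0.30}))
# -> 1girl, solo_\(fake\)   (sky is below the threshold, the rating tag is skipped)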
import math
import os
import numpy as np
......@@ -19,7 +20,7 @@ import gradio as gr
cached_images = {}
def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
def run_extras(extras_mode, resize_mode, image, image_folder, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
devices.torch_gc()
imageArr = []
......@@ -67,8 +68,13 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
image = res
if resize_mode == 1:
upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
crop_info = " (crop)" if upscaling_crop else ""
info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
if upscaling_resize != 1.0:
def upscale(image, scaler_index, resize):
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
pixels = tuple(np.array(small).flatten().tolist())
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight) + pixels
......@@ -77,15 +83,19 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
if c is None:
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
c = cropped
cached_images[key] = c
return c
info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
res = upscale(image, extras_upscaler_1, upscaling_resize)
res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
res2 = upscale(image, extras_upscaler_2, upscaling_resize)
res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
res = Image.blend(res, res2, extras_upscaler_2_visibility)
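With the new resize_mode == 1 path, the user gives a target width and height instead of a scale factor; the scale is derived as the larger of the two ratios so the upscaled result covers the target, and the crop branch then pastes it onto an exact-size canvas. A quick illustration of the arithmetic (image size and targets are made-up numbers):

# Illustrative numbers only: a 512x768 source resized to 1024x1024 with crop enabled.
image_w, image_h = 512, 768
resize_w, resize_h = 1024, 1024

scale = max(resize_w / image_w, resize_h / image_h)           # max(2.0, 1.333...) -> 2.0
upscaled = (round(image_w * scale), round(image_h * scale))   # (1024, 1536)
# with upscaling_crop enabled, the 1024x1536 result is centre-pasted onto a
# 1024x1024 canvas, trimming 256 px from the top and 256 px from the bottom
print(scale, upscaled)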
......@@ -190,7 +200,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
theta_0[key] = theta_func(theta_0[key], theta_1[key], (float(1.0) - interp_amount)) # Need to reverse the interp_amount to match the desired mix ratio in the merged checkpoint
if save_as_half:
theta_0[key] = theta_0[key].half()
for key in theta_1.keys():
if 'model' in key and key not in theta_0:
theta_0[key] = theta_1[key]
......
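The comment about reversing interp_amount on the theta_func call is easiest to see with concrete numbers. A sketch under the assumption that theta_func is a plain weighted sum (the actual function lives in the collapsed portion of this diff):

import torch

def weighted_sum(theta0, theta1, alpha):
    # assumed alpha-blend form of theta_func, used here only to show the arithmetic
    return theta0 * (1 - alpha) + theta1 * alpha

a = torch.tensor([1.0, 0.0])  # stand-in for a primary-model tensor
b = torch.tensor([0.0, 1.0])  # stand-in for a secondary-model tensor

interp_amount = 0.3
merged = weighted_sum(a, b, 1.0 - interp_amount)
print(merged)  # tensor([0.3000, 0.7000]): 30% primary, 70% secondary after the reversal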
......@@ -14,7 +14,7 @@ import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule
from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module):
......@@ -120,6 +120,17 @@ def load_hypernetwork(filename):
shared.loaded_hypernetwork = None
def find_closest_hypernetwork_name(search: str):
if not search:
return None
search = search.lower()
applicable = [name for name in shared.hypernetworks if search in name.lower()]
if not applicable:
return None
applicable = sorted(applicable, key=lambda name: len(name))
return applicable[0]
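find_closest_hypernetwork_name does a case-insensitive substring match over the loaded hypernetwork names and prefers the shortest hit, which is what the X/Y plot confirm callbacks further down rely on. A small standalone sketch of that matching rule (the name list is made up; the real one comes from shared.hypernetworks):

# stand-in for the keys of shared.hypernetworks
names = ["anime", "anime_3", "incase_8000"]

def find_closest(search):
    if not search:
        return None
    search = search.lower()
    applicable = sorted([n for n in names if search in n.lower()], key=len)
    return applicable[0] if applicable else None

print(find_closest("Anime"))    # anime   (case-insensitive, shortest match wins)
print(find_closest("incase"))   # incase_8000
print(find_closest("missing"))  # None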
def apply_hypernetwork(hypernetwork, context, layer=None):
hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
......@@ -164,7 +175,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None):
def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_image_prompt):
assert hypernetwork_name, 'embedding not selected'
assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None)
shared.loaded_hypernetwork = Hypernetwork()
......@@ -212,31 +223,23 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if ititial_step > steps:
return hypernetwork, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
(learn_rate, end_step) = next(schedules)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW(weights, lr=learn_rate)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
for i, (x, text, cond) in pbar:
for i, entry in pbar:
hypernetwork.step = i + ititial_step
if hypernetwork.step > end_step:
try:
(learn_rate, end_step) = next(schedules)
except Exception:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
scheduler.apply(optimizer, hypernetwork.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
cond = cond.to(devices.device)
x = x.to(devices.device)
cond = entry.cond.to(devices.device)
x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), cond)[0]
del x
del cond
......@@ -256,7 +259,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
preview_text = text if preview_image_prompt == "" else preview_image_prompt
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
optimizer.zero_grad()
shared.sd_model.cond_stage_model.to(devices.device)
......@@ -271,16 +274,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
)
processed = processing.process_images(p)
image = processed.images[0]
image = processed.images[0] if len(processed.images)>0 else None
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
if image is not None:
shared.state.current_image = image
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
shared.state.job_no = hypernetwork.step
......@@ -288,7 +291,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory,
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
......
......@@ -321,7 +321,17 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
fixes.append(fix[1])
self.hijack.fixes.append(fixes)
z1 = self.process_tokens([x[:75] for x in remade_batch_tokens], [x[:75] for x in batch_multipliers])
tokens = []
multipliers = []
for j in range(len(remade_batch_tokens)):
if len(remade_batch_tokens[j]) > 0:
tokens.append(remade_batch_tokens[j][:75])
multipliers.append(batch_multipliers[j][:75])
else:
tokens.append([self.wrapped.tokenizer.eos_token_id] * 75)
multipliers.append([1.0] * 75)
z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2)
remade_batch_tokens = rem_tokens
......
......@@ -86,6 +86,7 @@ parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram
xformers_available = False
config_filename = cmd_opts.ui_settings_file
os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None
......@@ -229,7 +230,10 @@ options_templates.update(options_section(('system', "System"), {
}))
options_templates.update(options_section(('training', "Training"), {
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP form VRAM when training"),
"unload_models_when_training": OptionInfo(False, "Unload VAE and CLIP from VRAM when training"),
"dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
"dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
"training_image_repeats_per_epoch": OptionInfo(100, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
}))
options_templates.update(options_section(('sd', "Stable Diffusion"), {
......@@ -255,8 +259,10 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
"deepbooru_escape": OptionInfo(True, "escape (\\) brackets in deepbooru (so they are used as literal brackets and not for emphasis)"),
}))
options_templates.update(options_section(('ui', "User interface"), {
......
......@@ -11,11 +11,21 @@ import tqdm
from modules import devices, shared
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
re_numbers_at_start = re.compile(r"^[-\d]+\s*")
class DatasetEntry:
def __init__(self, filename=None, latent=None, filename_text=None):
self.filename = filename
self.latent = latent
self.filename_text = filename_text
self.cond = None
self.cond_text = None
class PersonalizedBase(Dataset):
def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex)>0 else None
self.placeholder_token = placeholder_token
......@@ -42,9 +52,18 @@ class PersonalizedBase(Dataset):
except Exception:
continue
text_filename = os.path.splitext(path)[0] + ".txt"
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0]
filename_tokens = re_tag.findall(filename_tokens)
if os.path.exists(text_filename):
with open(text_filename, "r", encoding="utf8") as file:
filename_text = file.read()
else:
filename_text = os.path.splitext(filename)[0]
filename_text = re.sub(re_numbers_at_start, '', filename_text)
if re_word:
tokens = re_word.findall(filename_text)
filename_text = (shared.opts.dataset_filename_join_string or "").join(tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
......@@ -55,13 +74,13 @@ class PersonalizedBase(Dataset):
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
init_latent = init_latent.to(devices.cpu)
entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)
if include_cond:
text = self.create_text(filename_tokens)
cond = cond_model([text]).to(devices.cpu)
else:
cond = None
entry.cond_text = self.create_text(filename_text)
entry.cond = cond_model([entry.cond_text]).to(devices.cpu)
self.dataset.append((init_latent, filename_tokens, cond))
self.dataset.append(entry)
self.length = len(self.dataset) * repeats
......@@ -72,10 +91,10 @@ class PersonalizedBase(Dataset):
def shuffle(self):
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
def create_text(self, filename_tokens):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
text = text.replace("[filewords]", ' '.join(filename_tokens))
text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
......@@ -86,7 +105,9 @@ class PersonalizedBase(Dataset):
self.shuffle()
index = self.indexes[i % len(self.indexes)]
x, filename_tokens, cond = self.dataset[index]
entry = self.dataset[index]
if entry.cond is None:
entry.cond_text = self.create_text(entry.filename_text)
text = self.create_text(filename_tokens)
return x, text, cond
return entry
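Taken together, the dataset changes derive caption text from either a sidecar .txt file or the filename: leading numbers are stripped, the optional "Filename word regex" splits the name into words, and the "Filename join string" glues them back before [filewords] substitution. A small sketch of that path in isolation (the regex, join string, and filename are example values, not defaults):

import re

re_numbers_at_start = re.compile(r"^[-\d]+\s*")

def filename_to_text(filename, word_regex=r"[a-zA-Z]+", join_string=" "):
    # "00012-cute_cat-sitting.png" -> "cute cat sitting" with these example settings
    text = re.sub(re_numbers_at_start, '', filename.rsplit('.', 1)[0])
    if word_regex:
        tokens = re.compile(word_regex).findall(text)
        text = (join_string or "").join(tokens)
    return text

prompt_template = "a photo of [name], [filewords]"
filename_text = filename_to_text("00012-cute_cat-sitting.png")
print(prompt_template.replace("[name]", "mytoken").replace("[filewords]", filename_text))
# -> a photo of mytoken, cute cat sitting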
import base64
import json
import numpy as np
import zlib
from PIL import Image, PngImagePlugin, ImageDraw, ImageFont
from fonts.ttf import Roboto
import torch
class EmbeddingEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, torch.Tensor):
return {'TORCHTENSOR': obj.cpu().detach().numpy().tolist()}
return json.JSONEncoder.default(self, obj)
class EmbeddingDecoder(json.JSONDecoder):
def __init__(self, *args, **kwargs):
json.JSONDecoder.__init__(self, object_hook=self.object_hook, *args, **kwargs)
def object_hook(self, d):
if 'TORCHTENSOR' in d:
return torch.from_numpy(np.array(d['TORCHTENSOR']))
return d
def embedding_to_b64(data):
d = json.dumps(data, cls=EmbeddingEncoder)
return base64.b64encode(d.encode())
def embedding_from_b64(data):
d = base64.b64decode(data)
return json.loads(d, cls=EmbeddingDecoder)
def lcg(m=2**32, a=1664525, c=1013904223, seed=0):
while True:
seed = (a * seed + c) % m
yield seed % 255
def xor_block(block):
g = lcg()
randblock = np.array([next(g) for _ in range(np.product(block.shape))]).astype(np.uint8).reshape(block.shape)
return np.bitwise_xor(block.astype(np.uint8), randblock & 0x0F)
def style_block(block, sequence):
im = Image.new('RGB', (block.shape[1], block.shape[0]))
draw = ImageDraw.Draw(im)
i = 0
for x in range(-6, im.size[0], 8):
for yi, y in enumerate(range(-6, im.size[1], 8)):
offset = 0
if yi % 2 == 0:
offset = 4
shade = sequence[i % len(sequence)]
i += 1
draw.ellipse((x+offset, y, x+6+offset, y+6), fill=(shade, shade, shade))
fg = np.array(im).astype(np.uint8) & 0xF0
return block ^ fg
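xor_block XORs the low nibble of a block with a deterministic stream from the lcg generator above, so applying it a second time restores the original values; that is what lets extract_image_data_embed undo the obfuscation without storing a key. A quick property check, assuming xor_block from above is in scope (the block contents are arbitrary test data):

import numpy as np

# arbitrary low-nibble payload, same (h, w, 3) shape convention as the embed code
block = np.random.randint(0, 16, size=(8, 8, 3), dtype=np.uint8)
assert np.array_equal(xor_block(xor_block(block)), block)  # XOR with the same LCG stream undoes itself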
def insert_image_data_embed(image, data):
d = 3
data_compressed = zlib.compress(json.dumps(data, cls=EmbeddingEncoder).encode(), level=9)
data_np_ = np.frombuffer(data_compressed, np.uint8).copy()
data_np_high = data_np_ >> 4
data_np_low = data_np_ & 0x0F
h = image.size[1]
next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h))
next_size = next_size + ((h*d)-(next_size % (h*d)))
data_np_low.resize(next_size)
data_np_low = data_np_low.reshape((h, -1, d))
data_np_high.resize(next_size)
data_np_high = data_np_high.reshape((h, -1, d))
edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024]
edge_style = (np.abs(edge_style)/np.max(np.abs(edge_style))*255).astype(np.uint8)
data_np_low = style_block(data_np_low, sequence=edge_style)
data_np_low = xor_block(data_np_low)
data_np_high = style_block(data_np_high, sequence=edge_style[::-1])
data_np_high = xor_block(data_np_high)
im_low = Image.fromarray(data_np_low, mode='RGB')
im_high = Image.fromarray(data_np_high, mode='RGB')
background = Image.new('RGB', (image.size[0]+im_low.size[0]+im_high.size[0]+2, image.size[1]), (0, 0, 0))
background.paste(im_low, (0, 0))
background.paste(image, (im_low.size[0]+1, 0))
background.paste(im_high, (im_low.size[0]+1+image.size[0]+1, 0))
return background
def crop_black(img, tol=0):
mask = (img > tol).all(2)
mask0, mask1 = mask.any(0), mask.any(1)
col_start, col_end = mask0.argmax(), mask.shape[1]-mask0[::-1].argmax()
row_start, row_end = mask1.argmax(), mask.shape[0]-mask1[::-1].argmax()
return img[row_start:row_end, col_start:col_end]
def extract_image_data_embed(image):
d = 3
outarr = crop_black(np.array(image.convert('RGB').getdata()).reshape(image.size[1], image.size[0], d).astype(np.uint8)) & 0x0F
black_cols = np.where(np.sum(outarr, axis=(0, 2)) == 0)
if black_cols[0].shape[0] < 2:
print('No Image data blocks found.')
return None
data_block_lower = outarr[:, :black_cols[0].min(), :].astype(np.uint8)
data_block_upper = outarr[:, black_cols[0].max()+1:, :].astype(np.uint8)
data_block_lower = xor_block(data_block_lower)
data_block_upper = xor_block(data_block_upper)
data_block = (data_block_upper << 4) | (data_block_lower)
data_block = data_block.flatten().tobytes()
data = zlib.decompress(data_block)
return json.loads(data, cls=EmbeddingDecoder)
def caption_image_overlay(srcimage, title, footerLeft, footerMid, footerRight, textfont=None):
from math import cos
image = srcimage.copy()
if textfont is None:
try:
textfont = ImageFont.truetype(opts.font or Roboto, fontsize)
textfont = opts.font or Roboto
except Exception:
textfont = Roboto
factor = 1.5
gradient = Image.new('RGBA', (1, image.size[1]), color=(0, 0, 0, 0))
for y in range(image.size[1]):
mag = 1-cos(y/image.size[1]*factor)
mag = max(mag, 1-cos((image.size[1]-y)/image.size[1]*factor*1.1))
gradient.putpixel((0, y), (0, 0, 0, int(mag*255)))
image = Image.alpha_composite(image.convert('RGBA'), gradient.resize(image.size))
draw = ImageDraw.Draw(image)
fontsize = 32
font = ImageFont.truetype(textfont, fontsize)
padding = 10
_, _, w, h = draw.textbbox((0, 0), title, font=font)
fontsize = min(int(fontsize * (((image.size[0]*0.75)-(padding*4))/w)), 72)
font = ImageFont.truetype(textfont, fontsize)
_, _, w, h = draw.textbbox((0, 0), title, font=font)
draw.text((padding, padding), title, anchor='lt', font=font, fill=(255, 255, 255, 230))
_, _, w, h = draw.textbbox((0, 0), footerLeft, font=font)
fontsize_left = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
_, _, w, h = draw.textbbox((0, 0), footerMid, font=font)
fontsize_mid = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
_, _, w, h = draw.textbbox((0, 0), footerRight, font=font)
fontsize_right = min(int(fontsize * (((image.size[0]/3)-(padding))/w)), 72)
font = ImageFont.truetype(textfont, min(fontsize_left, fontsize_mid, fontsize_right))
draw.text((padding, image.size[1]-padding), footerLeft, anchor='ls', font=font, fill=(255, 255, 255, 230))
draw.text((image.size[0]/2, image.size[1]-padding), footerMid, anchor='ms', font=font, fill=(255, 255, 255, 230))
draw.text((image.size[0]-padding, image.size[1]-padding), footerRight, anchor='rs', font=font, fill=(255, 255, 255, 230))
return image
if __name__ == '__main__':
testEmbed = Image.open('test_embedding.png')
data = extract_image_data_embed(testEmbed)
assert data is not None
data = embedding_from_b64(testEmbed.text['sd-ti-embedding'])
assert data is not None
image = Image.new('RGBA', (512, 512), (255, 255, 200, 255))
cap_image = caption_image_overlay(image, 'title', 'footerLeft', 'footerMid', 'footerRight')
test_embed = {'string_to_param': {'*': torch.from_numpy(np.random.random((2, 4096)))}}
embedded_image = insert_image_data_embed(cap_image, test_embed)
retrived_embed = extract_image_data_embed(embedded_image)
assert str(retrived_embed) == str(test_embed)
embedded_image2 = insert_image_data_embed(cap_image, retrived_embed)
assert embedded_image == embedded_image2
g = lcg()
shared_random = np.array([next(g) for _ in range(100)]).astype(np.uint8).tolist()
reference_random = [253, 242, 127, 44, 157, 27, 239, 133, 38, 79, 167, 4, 177,
95, 130, 79, 78, 14, 52, 215, 220, 194, 126, 28, 240, 179,
160, 153, 149, 50, 105, 14, 21, 218, 199, 18, 54, 198, 193,
38, 128, 19, 53, 195, 124, 75, 205, 12, 6, 145, 0, 28,
30, 148, 8, 45, 218, 171, 55, 249, 97, 166, 12, 35, 0,
41, 221, 122, 215, 170, 31, 113, 186, 97, 119, 31, 23, 185,
66, 140, 30, 41, 37, 63, 137, 109, 216, 55, 159, 145, 82,
204, 86, 73, 222, 44, 198, 118, 240, 97]
assert shared_random == reference_random
hunna_kay_random_sum = sum(np.array([next(g) for _ in range(100000)]).astype(np.uint8).tolist())
assert 12731374 == hunna_kay_random_sum
import tqdm
class LearnSchedule:
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have an lr of 0.001 until step 100, 0.00001 until step 1000, and 1e-5 until step 10000
"""
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
......@@ -32,3 +38,32 @@ class LearnSchedule:
return self.rates[self.it - 1]
else:
raise StopIteration
class LearnRateScheduler:
def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
(self.learn_rate, self.end_step) = next(self.schedules)
self.verbose = verbose
if self.verbose:
print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
self.finished = False
def apply(self, optimizer, step_number):
if step_number <= self.end_step:
return
try:
(self.learn_rate, self.end_step) = next(self.schedules)
except Exception:
self.finished = True
return
if self.verbose:
tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
for pg in optimizer.param_groups:
pg['lr'] = self.learn_rate
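LearnRateScheduler wraps LearnScheduleIterator so the training loops above shrink to a single apply() call per step plus a finished check. A minimal sketch of the intended call pattern, using the schedule syntax from the docstring (the parameters, schedule string, and step count are placeholders):

import torch

params = [torch.nn.Parameter(torch.zeros(4))]
scheduler = LearnRateScheduler("0.001:100, 1e-5:1000", max_steps=1000, cur_step=0)
optimizer = torch.optim.AdamW(params, lr=scheduler.learn_rate)

for step in range(1000):
    scheduler.apply(optimizer, step)  # advances to the next rate once end_step is passed
    if scheduler.finished:
        break
    # ... compute loss, loss.backward(), optimizer.step() ...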
......@@ -3,11 +3,35 @@ from PIL import Image, ImageOps
import platform
import sys
import tqdm
import time
from modules import shared, images
from modules.shared import opts, cmd_opts
if cmd_opts.deepdanbooru:
import modules.deepbooru as deepbooru
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption):
def preprocess(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
try:
if process_caption:
shared.interrogator.load()
if process_caption_deepbooru:
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts())
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
finally:
if process_caption:
shared.interrogator.send_blip_to_ram()
if process_caption_deepbooru:
deepbooru.release_process()
def preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru=False):
width = process_width
height = process_height
src = os.path.abspath(process_src)
......@@ -22,19 +46,28 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
if process_caption:
shared.interrogator.load()
def save_pic_with_caption(image, index):
caption = ""
if process_caption:
caption = "-" + shared.interrogator.generate_caption(image)
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
else:
caption = filename
caption = os.path.splitext(caption)[0]
caption = os.path.basename(caption)
caption += shared.interrogator.generate_caption(image)
if process_caption_deepbooru:
if len(caption) > 0:
caption += ", "
caption += deepbooru.get_tags_from_process(image)
filename_part = filename
filename_part = os.path.splitext(filename_part)[0]
filename_part = os.path.basename(filename_part)
basename = f"{index:05}-{subindex[0]}-{filename_part}"
image.save(os.path.join(dst, f"{basename}.png"))
if len(caption) > 0:
with open(os.path.join(dst, f"{basename}.txt"), "w", encoding="utf8") as file:
file.write(caption)
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
def save_pic(image, index):
......@@ -79,30 +112,3 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
save_pic(img, index)
shared.state.nextjob()
if process_caption:
shared.interrogator.send_blip_to_ram()
def sanitize_caption(base_path, original_caption, suffix):
operating_system = platform.system().lower()
if (operating_system == "windows"):
invalid_path_characters = "\\/:*?\"<>|"
max_path_length = 259
else:
invalid_path_characters = "/" #linux/macos
max_path_length = 1023
caption = original_caption
for invalid_character in invalid_path_characters:
caption = caption.replace(invalid_character, "")
fixed_path_length = len(base_path) + len(suffix)
if fixed_path_length + len(caption) <= max_path_length:
return caption
caption_tokens = caption.split()
new_caption = ""
for token in caption_tokens:
last_caption = new_caption
new_caption = new_caption + token + " "
if (len(new_caption) + fixed_path_length - 1 > max_path_length):
break
print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
return last_caption.strip()
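The reworked preprocess is now a thin wrapper that loads whichever interrogators are requested, runs preprocess_work, and releases them in a finally block, so a failure mid-run still frees BLIP or the deepbooru process. A hedged example of invoking it with deepbooru captions enabled, assuming the usual module path and that --deepdanbooru is on the command line (directories and sizes are placeholders):

from modules.textual_inversion.preprocess import preprocess

preprocess(
    process_src="/data/raw",
    process_dst="/data/processed",
    process_width=512,
    process_height=512,
    process_flip=False,
    process_split=False,
    process_caption=True,             # BLIP caption
    process_caption_deepbooru=True,   # deepbooru tags appended after the caption
)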
......@@ -7,11 +7,15 @@ import tqdm
import html
import datetime
from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnSchedule
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
caption_image_overlay)
class Embedding:
def __init__(self, vec, name, step=None):
......@@ -81,7 +85,18 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
data = torch.load(path, map_location="cpu")
data = []
if filename.upper().endswith('.PNG'):
embed_image = Image.open(path)
if 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
name = data.get('name', name)
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
else:
data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
......@@ -157,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
......@@ -179,11 +194,17 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
else:
images_dir = None
if create_image_every > 0 and save_image_with_stored_embedding:
images_embeds_dir = os.path.join(log_directory, "image_embeddings")
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
......@@ -199,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
(learn_rate, end_step) = next(schedules)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text, _) in pbar:
for i, entry in pbar:
embedding.step = i + ititial_step
if embedding.step > end_step:
try:
(learn_rate, end_step) = next(schedules)
except:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
for pg in optimizer.param_groups:
pg['lr'] = learn_rate
scheduler.apply(optimizer, embedding.step)
if scheduler.finished:
break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
c = cond_model([text])
c = cond_model([entry.cond_text])
x = x.to(devices.device)
x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
......@@ -246,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
preview_text = text if preview_image_prompt == "" else preview_image_prompt
preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
......@@ -262,6 +275,26 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
image = processed.images[0]
shared.state.current_image = image
if save_image_with_stored_embedding and os.path.exists(last_saved_file):
last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
info = PngImagePlugin.PngInfo()
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???'))
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.hash)
footer_right = '{}'.format(embedding.step)
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
......@@ -272,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(text)}<br/>
Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
......
......@@ -129,8 +129,6 @@ class Script(scripts.Script):
return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment):
p.batch_size = 1
p.batch_count = 1
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
......@@ -154,7 +152,7 @@ class Script(scripts.Script):
rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])])
rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
......
......@@ -77,14 +77,42 @@ def apply_sampler(p, x, xs):
p.sampler_index = sampler_index
def confirm_samplers(p, xs):
samplers_dict = build_samplers_dict(p)
for x in xs:
if x.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {x}")
def apply_checkpoint(p, x, xs):
info = modules.sd_models.get_closet_checkpoint_match(x)
assert info is not None, f'Checkpoint for {x} not found'
if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
def confirm_checkpoints(p, xs):
for x in xs:
if modules.sd_models.get_closet_checkpoint_match(x) is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
def apply_hypernetwork(p, x, xs):
hypernetwork.load_hypernetwork(x)
if x.lower() in ["", "none"]:
name = None
else:
name = hypernetwork.find_closest_hypernetwork_name(x)
if not name:
raise RuntimeError(f"Unknown hypernetwork: {x}")
hypernetwork.load_hypernetwork(name)
def confirm_hypernetworks(p, xs):
for x in xs:
if x.lower() in ["", "none"]:
continue
if not hypernetwork.find_closest_hypernetwork_name(x):
raise RuntimeError(f"Unknown hypernetwork: {x}")
def apply_clip_skip(p, x, xs):
......@@ -121,29 +149,29 @@ def str_permutations(x):
return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
axis_options = [
AxisOption("Nothing", str, do_nothing, format_nothing),
AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
AxisOption("Nothing", str, do_nothing, format_nothing, None),
AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
......@@ -269,17 +297,10 @@ class Script(scripts.Script):
valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist]
# Confirm options are valid before starting
if opt.label == "Sampler":
samplers_dict = build_samplers_dict(p)
for sampler_val in valslist:
if sampler_val.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {sampler_val}")
elif opt.label == "Checkpoint name":
for ckpt_val in valslist:
if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
if opt.confirm:
opt.confirm(p, valslist)
return valslist
......
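The per-option confirm callbacks replace the old inline label checks, so validation for a new axis type is declared next to its apply function instead of being patched into the value-parsing loop. A sketch of how a hypothetical new option would plug in (the option name and both functions are illustrative, not part of this diff):

# Hypothetical axis option: validate values up front, apply them per grid cell.
def apply_example_setting(p, x, xs):
    p.example_setting = x  # illustrative field, not a real processing attribute

def confirm_example_setting(p, xs):
    for x in xs:
        if x <= 0:
            raise RuntimeError(f"Invalid example setting: {x}")

axis_options.append(
    AxisOption("Example setting", float, apply_example_setting,
               format_value_add_label, confirm_example_setting)
)
# the value-parsing code later calls opt.confirm(p, valslist) before any image is rendered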
......@@ -31,12 +31,7 @@ from modules.paths import script_path
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
modelloader.cleanup_models()
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
queue_lock = threading.Lock()
......@@ -78,15 +73,24 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
def initialize():
modelloader.cleanup_models()
modules.sd_models.setup_model()
codeformer.setup_model(cmd_opts.codeformer_models_path)
gfpgan.setup_model(cmd_opts.gfpgan_models_path)
shared.face_restorers.append(modules.face_restoration.FaceRestoration())
modelloader.load_upscalers()
shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
def webui():
initialize()
# make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}')
......@@ -98,7 +102,7 @@ def webui():
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
app,local_url,share_url = demo.launch(
app, local_url, share_url = demo.launch(
share=cmd_opts.share,
server_name="0.0.0.0" if cmd_opts.listen else None,
server_port=cmd_opts.port,
......@@ -129,6 +133,5 @@ def webui():
print('Restarting Gradio')
if __name__ == "__main__":
webui()