Commit 2adb2497 authored by brkirch's avatar brkirch

Merge branch 'cpu-cmdline-opt' of...

Merge branch 'cpu-cmdline-opt' of https://github.com/brkirch/stable-diffusion-webui into cpu-cmdline-opt
parents eeab7aed 35a00b01
......@@ -47,6 +47,7 @@ titles = {
"Custom code": "Run Python code. Advanced user only. Must run program with --allow-code for this to work",
"Prompt S/R": "Separate a list of words with commas, and the first word will be used as a keyword: script will search for this word in the prompt, and replace it with others",
"Prompt order": "Separate a list of words with commas, and the script will make a variation of prompt with those words for their every possible order",
"Tiling": "Produce an image that can be tiled.",
"Tile overlap": "For SD upscale, how much overlap in pixels should there be between tiles. Tiles overlap so that when they are merged back into one picture, there is no clearly visible seam.",
......
......@@ -9,6 +9,9 @@ from torchvision import transforms
import random
import tqdm
from modules import devices
import re
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
class PersonalizedBase(Dataset):
......@@ -38,8 +41,8 @@ class PersonalizedBase(Dataset):
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
filename = os.path.basename(path)
filename_tokens = os.path.splitext(filename)[0].replace('_', '-').replace(' ', '-').split('-')
filename_tokens = [token for token in filename_tokens if token.isalpha()]
filename_tokens = os.path.splitext(filename)[0]
filename_tokens = re_tag.findall(filename_tokens)
npimage = np.array(image).astype(np.uint8)
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
......
......@@ -26,7 +26,9 @@ def preprocess(process_src, process_dst, process_flip, process_split, process_ca
if process_caption:
caption = "-" + shared.interrogator.generate_caption(image)
else:
caption = ""
caption = filename
caption = os.path.splitext(caption)[0]
caption = os.path.basename(caption)
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
subindex[0] += 1
......
from collections import namedtuple
from copy import copy
from itertools import permutations
import random
from PIL import Image
......@@ -29,6 +30,31 @@ def apply_prompt(p, x, xs):
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
def apply_order(p, x, xs):
token_order = []
# Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
for token in x:
token_order.append((p.prompt.find(token), token))
token_order.sort(key=lambda t: t[0])
prompt_parts = []
# Split the prompt up, taking out the tokens
for _, token in token_order:
n = p.prompt.find(token)
prompt_parts.append(p.prompt[0:n])
p.prompt = p.prompt[n + len(token):]
# Rebuild the prompt with the tokens in the order we want
prompt_tmp = ""
for idx, part in enumerate(prompt_parts):
prompt_tmp += part
prompt_tmp += x[idx]
p.prompt = prompt_tmp + p.prompt
samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers):
samplers_dict[sampler.name.lower()] = i
......@@ -60,16 +86,26 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x):
if type(x) == float:
x = round(x, 8)
return x
def format_value_join_list(p, opt, x):
return ", ".join(x)
def do_nothing(p, x, xs):
pass
def format_nothing(p, opt, x):
return ""
def str_permutations(x):
"""dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
return x
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
......@@ -82,6 +118,7 @@ axis_options = [
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
......@@ -131,6 +168,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d
re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
class Script(scripts.Script):
def title(self):
return "X/Y plot"
......@@ -206,6 +244,8 @@ class Script(scripts.Script):
valslist_ext.append(val)
valslist = valslist_ext
elif opt.type == str_permutations:
valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment