Commit d5266c07 authored by AUTOMATIC's avatar AUTOMATIC

split draw_prompt_matrix into two: generalized grid annotation and actual prompt matrix

disabled saving samples for SD upscale
parent 4ed435dd
......@@ -113,6 +113,7 @@ sd_upscalers = {
"None": lambda img: img
}
class Options:
class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None):
......@@ -332,25 +333,32 @@ def combine_grid(grid):
return combined_image
def draw_grid_annotations(im, width, height, hor_texts, ver_texts, hor_crossed_texts, ver_crossed_texts):
def wrap(text, font, line_length):
class GridAnnotation:
def __init__(self, text='', is_active=True):
self.text = text
self.is_active = is_active
self.size = None
def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
def wrap(drawing, text, font, line_length):
lines = ['']
for word in text.split():
line = f'{lines[-1]} {word}'.strip()
if d.textlength(line, font=font) <= line_length:
if drawing.textlength(line, font=font) <= line_length:
lines[-1] = line
else:
lines.append(word)
return '\n'.join(lines)
return lines
def draw_texts(pos, draw_x, draw_y, texts, sizes, active):
for i, (text, size) in enumerate(zip(texts, sizes)):
if not active:
text = '\u0336'.join(text) + '\u0336'
def draw_texts(drawing, draw_x, draw_y, lines):
for i, line in enumerate(lines):
drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
if not line.is_active:
drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
draw_y += size[1] + line_spacing
draw_y += line.size[1] + line_spacing
fontsize = (width + height) // 25
line_spacing = fontsize // 2
......@@ -358,103 +366,65 @@ def draw_grid_annotations(im, width, height, hor_texts, ver_texts, hor_crossed_t
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
pad_top = height // 4
pad_left = width * 3 // 4 if len(hor_texts) > 1 else 0
cols = im.width // width
rows = im.height // height
assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
calc_img = Image.new("RGB", (1, 1), "white")
calc_d = ImageDraw.Draw(calc_img)
for texts in hor_texts + ver_texts:
items = [] + texts
texts.clear()
for line in items:
wrapped = wrap(calc_d, line.text, fnt, width)
texts += [GridAnnotation(x, line.is_active) for x in wrapped]
for line in texts:
bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
pad_top = max(hor_text_heights) + line_spacing * 2
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
result.paste(im, (pad_left, pad_top))
d = ImageDraw.Draw(result)
prompts_horiz = [wrap(x, fnt, width) for x in hor_texts]
prompts_vert = [wrap(x, fnt, pad_left) for x in ver_texts]
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing
ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing
for col in range(cols):
x = pad_left + width * col + width / 2
y = pad_top / 2 - hor_text_height / 2
y = pad_top / 2 - hor_text_heights[col] / 2
draw_texts(col, x, y, prompts_horiz, sizes_hor)
draw_texts(d, x, y, hor_texts[col])
for row in range(rows):
x = pad_left / 2
y = pad_top + height * row + height / 2 - ver_text_height / 2
y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
draw_texts(row, x, y, prompts_vert, sizes_ver)
draw_texts(d, x, y, ver_texts[row])
return result
def draw_prompt_matrix(im, width, height, all_prompts):
def wrap(text, font, line_length):
lines = ['']
for word in text.split():
line = f'{lines[-1]} {word}'.strip()
if d.textlength(line, font=font) <= line_length:
lines[-1] = line
else:
lines.append(word)
return '\n'.join(lines)
def draw_texts(pos, draw_x, draw_y, texts, sizes):
for i, (text, size) in enumerate(zip(texts, sizes)):
active = pos & (1 << i) != 0
if not active:
text = '\u0336'.join(text) + '\u0336'
d.multiline_text((draw_x, draw_y + size[1] / 2), text, font=fnt, fill=color_active if active else color_inactive, anchor="mm", align="center")
draw_y += size[1] + line_spacing
fontsize = (width + height) // 25
line_spacing = fontsize // 2
fnt = ImageFont.truetype("arial.ttf", fontsize)
color_active = (0, 0, 0)
color_inactive = (153, 153, 153)
pad_top = height // 4
pad_left = width * 3 // 4 if len(all_prompts) > 2 else 0
cols = im.width // width
rows = im.height // height
prompts = all_prompts[1:]
result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
result.paste(im, (pad_left, pad_top))
d = ImageDraw.Draw(result)
boundary = math.ceil(len(prompts) / 2)
prompts_horiz = [wrap(x, fnt, width) for x in prompts[:boundary]]
prompts_vert = [wrap(x, fnt, pad_left) for x in prompts[boundary:]]
sizes_hor = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_horiz]]
sizes_ver = [(x[2] - x[0], x[3] - x[1]) for x in [d.multiline_textbbox((0, 0), x, font=fnt) for x in prompts_vert]]
hor_text_height = sum([x[1] + line_spacing for x in sizes_hor]) - line_spacing
ver_text_height = sum([x[1] + line_spacing for x in sizes_ver]) - line_spacing
for col in range(cols):
x = pad_left + width * col + width / 2
y = pad_top / 2 - hor_text_height / 2
draw_texts(col, x, y, prompts_horiz, sizes_hor)
for row in range(rows):
x = pad_left / 2
y = pad_top + height * row + height / 2 - ver_text_height / 2
prompts_horiz = prompts[:boundary]
prompts_vert = prompts[boundary:]
draw_texts(row, x, y, prompts_vert, sizes_ver)
hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
return result
return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
def resize_image(resize_mode, im, width, height):
......@@ -711,7 +681,7 @@ class EmbeddingsWithFixes(nn.Module):
class StableDiffusionProcessing:
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_grid=False, extra_generation_params=None):
def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None):
self.outpath: str = outpath
self.prompt: str = prompt
self.seed: int = seed
......@@ -724,6 +694,7 @@ class StableDiffusionProcessing:
self.height: int = height
self.prompt_matrix: bool = prompt_matrix
self.use_GFPGAN: bool = use_GFPGAN
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params
......@@ -866,7 +837,9 @@ def process_images(p: StableDiffusionProcessing):
x_sample = restored_img
image = Image.fromarray(x_sample)
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
if not p.do_not_save_samples:
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
output_images.append(image)
base_count += 1
......@@ -1106,6 +1079,7 @@ def img2img(prompt: str, init_img, ddim_steps: int, sampler_index: int, use_GFPG
p.n_iter = 1
p.do_not_save_grid = True
p.do_not_save_samples = True
work = []
work_results = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment