Commit 7350c712 authored by AUTOMATIC's avatar AUTOMATIC

added poor man's inpainting script

parent af133859
......@@ -39,23 +39,26 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
w = image.width
h = image.height
now = tile_w - overlap # non-overlap width
noh = tile_h - overlap
non_overlap_width = tile_w - overlap
non_overlap_height = tile_h - overlap
cols = math.ceil((w - overlap) / now)
rows = math.ceil((h - overlap) / noh)
cols = math.ceil((w - overlap) / non_overlap_width)
rows = math.ceil((h - overlap) / non_overlap_height)
dx = (w - tile_w) // (cols-1) if cols > 1 else 0
dy = (h - tile_h) // (rows-1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows):
row_images = []
y = row * noh
y = row * dy
if y + tile_h >= h:
y = h - tile_h
for col in range(cols):
x = col * now
x = col * dx
if x+tile_w >= w:
x = w - tile_w
......
......@@ -130,7 +130,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
else:
processed = modules.scripts.run(p, *args)
processed = modules.scripts.scripts_img2img.run(p, *args)
if processed is None:
processed = process_images(p)
......
......@@ -271,7 +271,7 @@ def fill(image, mask):
image_masked = image_masked.convert('RGBa')
for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
for _ in range(repeats):
image_mod.alpha_composite(blurred)
......@@ -290,6 +290,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.image_mask = mask
#self.image_unblurred_mask = None
self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
self.inpainting_fill = inpainting_fill
......@@ -308,6 +310,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask)
#self.image_unblurred_mask = self.image_mask
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
......@@ -368,7 +372,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
if self.image_mask is not None:
latmask = self.image_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
init_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
latmask = latmask[0]
latmask = np.tile(latmask[None], (4, 1, 1))
......
......@@ -18,6 +18,9 @@ class Script:
def ui(self, is_img2img):
pass
def show(self, is_img2img):
return True
def run(self, *args):
raise NotImplementedError()
......@@ -25,7 +28,7 @@ class Script:
return ""
scripts = []
scripts_data = []
def load_scripts(basedir):
......@@ -49,10 +52,8 @@ def load_scripts(basedir):
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
obj = script_class()
obj.filename = path
scripts_data.append((script_class, path))
scripts.append(obj)
except Exception:
print(f"Error loading script: {filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
......@@ -69,52 +70,75 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
return default
def setup_ui(is_img2img):
titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in scripts]
class ScriptRunner:
def __init__(self):
self.scripts = []
def setup_ui(self, is_img2img):
for script_class, path in scripts_data:
script = script_class()
script.filename = path
if not script.show(is_img2img):
continue
self.scripts.append(script)
titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
inputs = [dropdown]
for script in self.scripts:
script.args_from = len(inputs)
controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
if controls is None:
continue
dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
for control in controls:
control.visible = False
inputs = [dropdown]
inputs += controls
script.args_to = len(inputs)
for script in scripts:
script.args_from = len(inputs)
controls = script.ui(is_img2img)
def select_script(script_index):
if 0 < script_index <= len(self.scripts):
script = self.scripts[script_index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
for control in controls:
control.visible = False
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
inputs += controls
script.args_to = len(inputs)
dropdown.change(
fn=select_script,
inputs=[dropdown],
outputs=inputs
)
def select_script(index):
if index > 0:
script = scripts[index-1]
args_from = script.args_from
args_to = script.args_to
else:
args_from = 0
args_to = 0
return inputs
return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
dropdown.change(
fn=select_script,
inputs=[dropdown],
outputs=inputs
)
def run(self, p: StableDiffusionProcessing, *args):
script_index = args[0]
return inputs
if script_index == 0:
return None
script = self.scripts[script_index-1]
def run(p: StableDiffusionProcessing, *args):
script_index = args[0] - 1
if script is None:
return None
if script_index < 0 or script_index >= len(scripts):
return None
script_args = args[script.args_from:script.args_to]
processed = script.run(p, *script_args)
script = scripts[script_index]
return processed
script_args = args[script.args_from:script.args_to]
processed = script.run(p, *script_args)
return processed
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
......@@ -24,7 +24,7 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
use_GFPGAN=use_GFPGAN
)
processed = modules.scripts.run(p, *args)
processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is not None:
pass
......
......@@ -162,7 +162,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
seed = gr.Number(label='Seed', value=-1)
with gr.Group():
custom_inputs = modules.scripts.setup_ui(is_img2img=False)
custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
with gr.Column(variant='panel'):
with gr.Group():
......@@ -244,7 +244,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
with gr.Row():
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=True, visible=False)
inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, visible=False)
inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False)
with gr.Row():
......@@ -269,7 +269,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
seed = gr.Number(label='Seed', value=-1)
with gr.Group():
custom_inputs = modules.scripts.setup_ui(is_img2img=True)
custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
with gr.Column(variant='panel'):
......
import math
import modules.scripts as scripts
import gradio as gr
from PIL import Image, ImageDraw
from modules import images, processing
from modules.processing import Processed, process_images
from modules.shared import opts, cmd_opts, state
class Script(scripts.Script):
def title(self):
return "Poor man's outpainting"
def show(self, is_img2img):
return is_img2img
def ui(self, is_img2img):
if not is_img2img:
return None
pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=128, step=8)
mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
return [pixels, mask_blur, inpainting_fill]
def run(self, p, pixels, mask_blur, inpainting_fill):
initial_seed = None
initial_info = None
p.mask_blur = mask_blur
p.inpainting_fill = inpainting_fill
p.inpaint_full_res = False
init_img = p.init_images[0]
target_w = math.ceil((init_img.width + pixels * 2) / 64) * 64
target_h = math.ceil((init_img.height + pixels * 2) / 64) * 64
border_x = (target_w - init_img.width)//2
border_y = (target_h - init_img.height)//2
img = Image.new("RGB", (target_w, target_h))
img.paste(init_img, (border_x, border_y))
mask = Image.new("L", (img.width, img.height), "white")
draw = ImageDraw.Draw(mask)
draw.rectangle((border_x + mask_blur * 2, border_y + mask_blur * 2, mask.width - border_x - mask_blur * 2, mask.height - border_y - mask_blur * 2), fill="black")
latent_mask = Image.new("L", (img.width, img.height), "white")
latent_draw = ImageDraw.Draw(latent_mask)
latent_draw.rectangle((border_x + 1, border_y + 1, mask.width - border_x - 1, mask.height - border_y - 1), fill="black")
processing.torch_gc()
grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
grid_latent_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
p.n_iter = 1
p.batch_size = 1
p.do_not_save_grid = True
p.do_not_save_samples = True
work = []
work_mask = []
work_latent_mask = []
work_results = []
for (_, _, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
work.append(tiledata[2])
work_mask.append(tiledata_mask[2])
work_latent_mask.append(tiledata_latent_mask[2])
batch_count = len(work)
print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.")
for i in range(batch_count):
p.init_images = [work[i]]
p.image_mask = work_mask[i]
p.latent_mask = work_latent_mask[i]
state.job = f"Batch {i + 1} out of {batch_count}"
processed = process_images(p)
if initial_seed is None:
initial_seed = processed.seed
initial_info = processed.info
p.seed = processed.seed + 1
work_results += processed.images
image_index = 0
for y, h, row in grid.tiles:
for tiledata in row:
tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
image_index += 1
combined_image = images.combine_grid(grid)
if opts.samples_save:
images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info)
processed = Processed(p, [combined_image], initial_seed, initial_info)
return processed
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment