Commit fd661997 authored by AUTOMATIC

added preview option

parent db6db585
@@ -176,6 +176,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
+            if state.interrupted:
+                # if we are interrupted, sample returns just noise;
+                # use the image collected previously in the sampler loop
+                samples_ddim = shared.state.current_latent
+
            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
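For context, a minimal sketch (not part of the commit) of the fallback this hunk implements, assuming `p` and `shared.state` as used in the diff; `decode_current_batch` is a hypothetical helper name:

import torch

def decode_current_batch(p, samples_ddim):
    # Hypothetical helper: when the user interrupts, p.sample() returns noise,
    # so fall back to the latent the sampler stored on its most recent step.
    if shared.state.interrupted and shared.state.current_latent is not None:
        samples_ddim = shared.state.current_latent

    x_samples = p.sd_model.decode_first_stage(samples_ddim)        # latents -> images in [-1, 1]
    return torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)  # rescale to [0, 1]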
......
@@ -42,6 +42,8 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
        img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
        x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec

+    state.current_latent = x_dec
+
    return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
@@ -141,6 +143,9 @@ class KDiffusionSampler:
        self.func = getattr(k_diffusion.sampling, self.funcname)
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

+    def callback_state(self, d):
+        state.current_latent = d["denoised"]
+
    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
        sigmas = self.model_wrap.get_sigmas(p.steps)
@@ -157,7 +162,7 @@ class KDiffusionSampler:
        if hasattr(k_diffusion.sampling, 'trange'):
            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)

-        return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
+        return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)

    def sample(self, p, x, conditioning, unconditional_conditioning):
        sigmas = self.model_wrap.get_sigmas(p.steps)

@@ -166,6 +171,6 @@ class KDiffusionSampler:
        if hasattr(k_diffusion.sampling, 'trange'):
            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)

-        samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False)
+        samples_ddim = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
        return samples_ddim
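For context, a hedged sketch of the callback contract this relies on: the k-diffusion samplers of this era accept a `callback` argument and call it once per step with a dict describing the step, which is why `callback_state` can simply record `d["denoised"]` as the preview latent.

# minimal sketch, assuming k_diffusion.sampling's per-step callback dict
# (the samplers pass keys such as 'x', 'i', 'sigma' and 'denoised')
def preview_callback(d):
    # d["i"] is the current step index; d["denoised"] is the model's current
    # estimate of the clean latent, which is what the preview decodes
    state.current_latent = d["denoised"]

# usage (sketch): self.func(..., disable=False, callback=preview_callback)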
@@ -39,6 +39,7 @@ gpu = torch.device("cuda")
device = gpu if torch.cuda.is_available() else cpu
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)


class State:
    interrupted = False
    job = ""

@@ -46,6 +47,8 @@ class State:
    job_count = 0
    sampling_step = 0
    sampling_steps = 0
+    current_latent = None
+    current_image = None

    def interrupt(self):
        self.interrupted = True
@@ -99,6 +102,7 @@ class Options:
    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
    "upscale_at_full_resolution_padding": OptionInfo(16, "Inpainting at full resolution: padding, in pixels, for the masked region.", gr.Slider, {"minimum": 0, "maximum": 128, "step": 4}),
    "show_progressbar": OptionInfo(True, "Show progressbar"),
+    "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
}

    def __init__(self):
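A quick worked example of how the new option and State fields interact (a sketch based on the `(sampling_step - 1) % n == 0` gate in the progress-check code further down): with the slider at 4, a preview is decoded on steps 1, 5, 9, and so on, while the default of 0 disables previews entirely.

# minimal sketch of the gating used by the progress check below
def wants_preview(sampling_step: int, every_n: int) -> bool:
    # every_n == 0 (the option's default) disables live previews
    return every_n > 0 and (sampling_step - 1) % every_n == 0

# e.g. wants_preview(1, 4) -> True, wants_preview(2, 4) -> False, wants_preview(5, 4) -> True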
......
@@ -9,6 +9,8 @@ import sys
import time
import traceback

+import numpy as np
+import torch
from PIL import Image

import gradio as gr
@@ -119,6 +121,9 @@ def wrap_gradio_call(func):
            print("Arguments:", args, kwargs, file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

+            shared.state.job = ""
+            shared.state.job_count = 0
+
            res = [None, '', f"<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>"]

        elapsed = time.perf_counter() - t
@@ -134,11 +139,9 @@ def wrap_gradio_call(func):

def check_progress_call():
-    if not opts.show_progressbar:
-        return ""
-
    if shared.state.job_count == 0:
-        return ""
+        return "", gr_show(False), gr_show(False)

    progress = 0

@@ -149,9 +152,29 @@ def check_progress_call():
    progress = min(progress, 1)

-    progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
+    progressbar = ""
+    if opts.show_progressbar:
+        progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
+
+    image = gr_show(False)
+    preview_visibility = gr_show(False)
+
+    if opts.show_progress_every_n_steps > 0:
+        if (shared.state.sampling_step - 1) % opts.show_progress_every_n_steps == 0 and shared.state.current_latent is not None:
+            x_sample = shared.sd_model.decode_first_stage(shared.state.current_latent[0:1].type(shared.sd_model.dtype))[0]
+            x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
+            x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+            x_sample = x_sample.astype(np.uint8)
+            shared.state.current_image = Image.fromarray(x_sample)
+
+        image = shared.state.current_image
+
+        if image is None or progress >= 1:
+            image = gr.update(value=None)
+        else:
+            preview_visibility = gr_show(True)

-    return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>"
+    return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image


def roll_artist(prompt):
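check_progress_call now returns a 3-tuple: the progress HTML, a visibility update for the preview component, and the preview image itself; the preview component is therefore listed twice in the click handlers below, once for each of the last two values. `gr_show` is a helper already defined elsewhere in this file; a hedged sketch of what it is assumed to return:

# minimal sketch, assuming the gr_show helper used above: a Gradio update dict
# that changes only a component's visibility
def gr_show(visible=True):
    return {"visible": visible, "__type__": "update"}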
@@ -204,6 +227,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
        with gr.Column(variant='panel'):
            with gr.Group():
+                txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
                txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery')

@@ -251,8 +275,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
        check_progress.click(
            fn=check_progress_call,
+            show_progress=False,
            inputs=[],
-            outputs=[progressbar],
+            outputs=[progressbar, txt2img_preview, txt2img_preview],
        )
@@ -337,13 +362,16 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
        with gr.Column(variant='panel'):
            with gr.Group():
+                img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
                img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery')

            with gr.Group():
                with gr.Row():
+                    interrupt = gr.Button('Interrupt')
                    save = gr.Button('Save')
+                    img2img_send_to_img2img = gr.Button('Send to img2img')
+                    img2img_send_to_inpaint = gr.Button('Send to inpaint')
                    img2img_send_to_extras = gr.Button('Send to extras')
-                    interrupt = gr.Button('Interrupt')

            progressbar = gr.HTML(elem_id="progressbar")
@@ -426,8 +454,9 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
        check_progress.click(
            fn=check_progress_call,
+            show_progress=False,
            inputs=[],
-            outputs=[progressbar],
+            outputs=[progressbar, img2img_preview, img2img_preview],
        )

        interrupt.click(

@@ -463,6 +492,20 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
            outputs=[init_img_with_mask],
        )
+        img2img_send_to_img2img.click(
+            fn=lambda x: image_from_url_text(x),
+            _js="extract_image_from_gallery",
+            inputs=[img2img_gallery],
+            outputs=[init_img],
+        )
+
+        img2img_send_to_inpaint.click(
+            fn=lambda x: image_from_url_text(x),
+            _js="extract_image_from_gallery",
+            inputs=[img2img_gallery],
+            outputs=[init_img_with_mask],
+        )

    with gr.Blocks(analytics_enabled=False) as extras_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
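The two new "Send to" handlers reuse the existing gallery-forwarding pattern: the `extract_image_from_gallery` JS hook passes the currently selected gallery image to Python, and `image_from_url_text` converts it back into a PIL image. A hedged sketch of that conversion, assuming a base64 data-URL payload rather than the repo's exact implementation:

import base64
import io
from PIL import Image

def image_from_url_text(filedata):
    # sketch: strip the data-URL prefix and decode the base64 payload into a PIL image
    prefix = "data:image/png;base64,"
    if filedata.startswith(prefix):
        filedata = filedata[len(prefix):]
    return Image.open(io.BytesIO(base64.b64decode(filedata)))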
......
@@ -79,6 +79,23 @@ function addTitles(root){
        global_progressbar = progressbar

        var mutationObserver = new MutationObserver(function(m){
+            txt2img_preview = gradioApp().getElementById('txt2img_preview')
+            txt2img_gallery = gradioApp().getElementById('txt2img_gallery')
+
+            img2img_preview = gradioApp().getElementById('img2img_preview')
+            img2img_gallery = gradioApp().getElementById('img2img_gallery')
+
+            if(txt2img_preview != null && txt2img_gallery != null){
+                txt2img_preview.style.width = txt2img_gallery.clientWidth + "px"
+                txt2img_preview.style.height = txt2img_gallery.clientHeight + "px"
+            }
+
+            if(img2img_preview != null && img2img_gallery != null){
+                img2img_preview.style.width = img2img_gallery.clientWidth + "px"
+                img2img_preview.style.height = img2img_gallery.clientHeight + "px"
+            }
+
            window.setTimeout(requestProgress, 500)
        });
        mutationObserver.observe( progressbar, { childList:true, subtree:true })
......
@@ -31,6 +31,20 @@ button{
    max-width: 10em;
}

+#txt2img_preview, #img2img_preview{
+    position: absolute;
+    width: 320px;
+    left: 0;
+    right: 0;
+    margin-left: auto;
+    margin-right: auto;
+    z-index: 100;
+}
+
+#txt2img_preview div.left-0.top-0, #img2img_preview div.left-0.top-0{
+    display: none;
+}
+
fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block span{
    position: absolute;
    top: -0.6em;

@@ -96,3 +110,4 @@ input[type="range"]{
    text-align: right;
    border-radius: 8px;
}
@@ -125,7 +125,8 @@ def wrap_gradio_gpu_call(func):
        shared.state.sampling_step = 0
        shared.state.job_count = -1
        shared.state.job_no = 0
+        shared.state.current_latent = None
+        shared.state.current_image = None

        with queue_lock:
            res = func(*args, **kwargs)

@@ -163,7 +164,7 @@ modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
if __name__ == "__main__":
    # make the program just exit at ctrl+c without waiting for anything
    def sigint_handler(sig, frame):
-        print(f'Interrupted with singal {sig} in {frame}')
+        print(f'Interrupted with signal {sig} in {frame}')
        os._exit(0)
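The hunk above fixes the typo in the Ctrl+C handler's message. For completeness, a minimal sketch of how such a handler is registered with the standard library; the registration call itself sits outside the visible hunk:

import os
import signal

def sigint_handler(sig, frame):
    # exit immediately instead of waiting for worker threads to shut down
    print(f'Interrupted with signal {sig} in {frame}')
    os._exit(0)

signal.signal(signal.SIGINT, sigint_handler)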
......