Commit 365d4b16 authored by discus0434, committed by GitHub

Merge branch 'AUTOMATIC1111:master' into master

parents 3770b8d2 f510a227
@@ -275,7 +275,7 @@ re_attention = re.compile(r"""
 def parse_prompt_attention(text):
     """
-    Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+    Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
     Accepted tokens are:
       (abc) - increases attention to abc by a multiplier of 1.1
       (abc:3.12) - increases attention to abc by a multiplier of 3.12
...
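Note on the hunk above: the docstring fix is a typo correction only. For context, the multipliers it describes compose by repeated application, and in the full parser (not shown in this hunk) nested plain parentheses stack multiplicatively at 1.1 per level, while an explicit (text:weight) sets the factor directly. A minimal illustrative sketch of that arithmetic, not the repository's parse_prompt_attention implementation:

# Illustrative only: the weight implied by the docstring's rules.
def implied_weight(rounds=0, explicit=None):
    # `rounds` = number of plain ( ) pairs around the text,
    # `explicit` = the value from a (text:weight) token, if any.
    if explicit is not None:
        return explicit              # (abc:3.12) sets the weight directly
    return 1.1 ** rounds             # (abc) -> 1.1, ((abc)) -> 1.21, ...

assert implied_weight(1) == 1.1
assert abs(implied_weight(2) - 1.21) < 1e-9
assert implied_weight(explicit=3.12) == 3.12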
@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
     mem_free_torch = mem_reserved - mem_active
     mem_free_total = mem_free_cuda + mem_free_torch
     # Divide factor of safety as there's copying and fragmentation
-    return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
+    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

 def einsum_op(q, k, v):
     if q.device.type == 'cuda':
...
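Note on the hunk above: dropping the stray self. is the whole fix, since einsum_op_cuda is a module-level function, not a method. For context, a sketch of how the MiB budget handed to einsum_op_tensor_mem is derived, assuming the collapsed lines obtain their figures from the standard torch.cuda.memory_stats and torch.cuda.mem_get_info calls:

import torch

def free_vram_budget_mb(device, safety_factor=3.3):
    # Free memory reported by the CUDA driver plus memory PyTorch has
    # reserved but is not actively using.
    stats = torch.cuda.memory_stats(device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide by a safety factor (copying, fragmentation) and convert
    # bytes to MiB, matching mem_free_total / 3.3 / (1 << 20) above.
    return mem_free_total / safety_factor / (1 << 20)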
@@ -148,7 +148,10 @@ def get_state_dict_from_checkpoint(pl_sd):
         if new_key is not None:
             sd[new_key] = v

-    return sd
+    pl_sd.clear()
+    pl_sd.update(sd)
+
+    return pl_sd

 def load_model_weights(model, checkpoint_info):
...
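Note on the hunk above: instead of building a fresh dict and returning it, the checkpoint dict is now emptied and refilled in place, so any caller still holding a reference to pl_sd observes the remapped keys and the entries that were filtered out are released. A minimal sketch of the pattern; remap_key is a hypothetical stand-in for the key-translation logic in the collapsed lines above:

def remap_in_place(pl_sd, remap_key):
    sd = {}
    for k, v in pl_sd.items():
        new_key = remap_key(k)            # returns None to drop a key
        if new_key is not None:
            sd[new_key] = v
    pl_sd.clear()                         # drop the original entries
    pl_sd.update(sd)                      # refill the same dict object
    return pl_sd

checkpoint = {"model.weight": 1, "optimizer_state": 2}
alias = checkpoint                        # a second reference to the same dict
remap_in_place(checkpoint,
               lambda k: k[len("model."):] if k.startswith("model.") else None)
assert alias == {"weight": 1}             # the alias observes the update too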
@@ -12,7 +12,7 @@ import time
 import traceback
 import platform
 import subprocess as sp
-from functools import reduce
+from functools import partial, reduce

 import numpy as np
 import torch
@@ -261,6 +261,19 @@ def wrap_gradio_call(func, extra_outputs=None):
     return f


+def calc_time_left(progress, threshold, label, force_display):
+    if progress == 0:
+        return ""
+    else:
+        time_since_start = time.time() - shared.state.time_start
+        eta = (time_since_start/progress)
+        eta_relative = eta-time_since_start
+        if (eta_relative > threshold and progress > 0.02) or force_display:
+            return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
+        else:
+            return ""
+
+
 def check_progress_call(id_part):
     if shared.state.job_count == 0:
         return "", gr_show(False), gr_show(False), gr_show(False)
@@ -272,11 +285,15 @@ def check_progress_call(id_part):
         if shared.state.sampling_steps > 0:
             progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps

+    time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
+    if time_left != "":
+        shared.state.time_left_force_display = True
+
     progress = min(progress, 1)

     progressbar = ""
     if opts.show_progressbar:
-        progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>"""
+        progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>"""

     image = gr_show(False)
     preview_visibility = gr_show(False)
@@ -308,6 +325,8 @@ def check_progress_call_initial(id_part):
     shared.state.current_latent = None
     shared.state.current_image = None
     shared.state.textinfo = None
+    shared.state.time_start = time.time()
+    shared.state.time_left_force_display = False

     return check_progress_call(id_part)
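Note on the progress-bar hunks above: check_progress_call_initial now records the job start time and resets the force-display flag, and calc_time_left extrapolates the total run time from the elapsed time and the fraction completed, showing the remainder only once it exceeds the threshold (or once it has already been shown, via time_left_force_display). A self-contained restatement of that arithmetic, decoupled from shared.state for illustration:

import time

def eta_string(progress, time_start, threshold=60, label=" ETA:", force_display=False):
    # Same arithmetic as calc_time_left above, with the start time passed in
    # explicitly instead of read from shared.state.
    if progress == 0:
        return ""
    elapsed = time.time() - time_start
    eta_total = elapsed / progress           # projected total duration
    eta_remaining = eta_total - elapsed      # time still left to run
    if (eta_remaining > threshold and progress > 0.02) or force_display:
        return label + time.strftime('%H:%M:%S', time.gmtime(eta_remaining))
    return ""

# e.g. 30 s elapsed at 25% progress -> about 90 s remaining -> " ETA:00:01:30"
print(eta_string(0.25, time.time() - 30, force_display=True))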
@@ -1543,6 +1562,7 @@ Requested path was: {f}
         def reload_scripts():
             modules.scripts.reload_script_body_only()
+            reload_javascript() # need to refresh the html page

         reload_script_bodies.click(
             fn=reload_scripts,
@@ -1801,26 +1821,30 @@ Requested path was: {f}
     return demo


-with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
-    javascript = f'<script>{jsfile.read()}</script>'
+def load_javascript(raw_response):
+    with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
+        javascript = f'<script>{jsfile.read()}</script>'

-jsdir = os.path.join(script_path, "javascript")
-for filename in sorted(os.listdir(jsdir)):
-    with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
-        javascript += f"\n<script>{jsfile.read()}</script>"
+    jsdir = os.path.join(script_path, "javascript")
+    for filename in sorted(os.listdir(jsdir)):
+        with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
+            javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"

-if cmd_opts.theme is not None:
-    javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
+    if cmd_opts.theme is not None:
+        javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"

-javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
+    javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"

-if 'gradio_routes_templates_response' not in globals():
     def template_response(*args, **kwargs):
-        res = gradio_routes_templates_response(*args, **kwargs)
-        res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8"))
+        res = raw_response(*args, **kwargs)
+        res.body = res.body.replace(
+            b'</head>', f'{javascript}</head>'.encode("utf8"))
         res.init_headers()
         return res

-    gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
     gradio.routes.templates.TemplateResponse = template_response
+
+
+reload_javascript = partial(load_javascript,
+                            gradio.routes.templates.TemplateResponse)
+reload_javascript()
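Note on the hunk above: moving the JavaScript injection into load_javascript and binding the untouched gradio.routes.templates.TemplateResponse once via functools.partial lets reload_javascript() be called repeatedly (for example from reload_scripts above) without wrapping an already-patched response class, which is what the removed globals() check used to guard against. A minimal sketch of that pattern, using a hypothetical stand-in rather than the real Gradio template module:

from functools import partial
from types import SimpleNamespace

# Hypothetical stand-in for gradio.routes.templates in this sketch.
templates = SimpleNamespace(
    TemplateResponse=lambda body: f"<html><head></head>{body}</html>"
)

def load_javascript_like(raw_response):
    # Rebuild the payload on every call, then wrap the ORIGINAL response
    # callable captured by functools.partial below.
    javascript = "<script>/* freshly (re)loaded scripts */</script>"

    def template_response(*args, **kwargs):
        res = raw_response(*args, **kwargs)
        return res.replace("</head>", f"{javascript}</head>")

    templates.TemplateResponse = template_response

# Bind the pristine callable once; each reload re-patches from it instead of
# wrapping an already-wrapped function.
reload_javascript_like = partial(load_javascript_like, templates.TemplateResponse)
reload_javascript_like()
reload_javascript_like()  # safe to call again: still one layer of patching
print(templates.TemplateResponse("<body></body>"))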
@@ -34,9 +34,10 @@
 .performance {
     font-size: 0.85em;
     color: #444;
-    display: flex;
-    justify-content: space-between;
-    white-space: nowrap;
+}
+
+.performance p{
+    display: inline-block;
 }

 .performance .time {
@@ -44,8 +45,6 @@
 }

 .performance .vram {
-    margin-left: 0;
-    text-align: right;
 }

 #txt2img_generate, #img2img_generate {
...