Commit 52dcf0f0 authored by AUTOMATIC's avatar AUTOMATIC

record startup time

parent f968270f
...@@ -12,11 +12,22 @@ from packaging import version ...@@ -12,11 +12,22 @@ from packaging import version
import logging import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import import_hook, errors, extra_networks, ui_extra_networks_checkpoints from modules import paths, timer, import_hook, errors
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call startup_timer = timer.Timer()
import torch import torch
startup_timer.record("import torch")
import gradio
startup_timer.record("import gradio")
import ldm.modules.encoders.modules
startup_timer.record("import ldm")
from modules import extra_networks, ui_extra_networks_checkpoints
from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors
if ".dev" in torch.__version__ or "+git" in torch.__version__: if ".dev" in torch.__version__ or "+git" in torch.__version__:
...@@ -30,7 +41,6 @@ import modules.gfpgan_model as gfpgan ...@@ -30,7 +41,6 @@ import modules.gfpgan_model as gfpgan
import modules.img2img import modules.img2img
import modules.lowvram import modules.lowvram
import modules.paths
import modules.scripts import modules.scripts
import modules.sd_hijack import modules.sd_hijack
import modules.sd_models import modules.sd_models
...@@ -45,6 +55,8 @@ from modules import modelloader ...@@ -45,6 +55,8 @@ from modules import modelloader
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
startup_timer.record("other imports")
if cmd_opts.server_name: if cmd_opts.server_name:
server_name = cmd_opts.server_name server_name = cmd_opts.server_name
...@@ -88,6 +100,7 @@ def initialize(): ...@@ -88,6 +100,7 @@ def initialize():
extensions.list_extensions() extensions.list_extensions()
localization.list_localizations(cmd_opts.localizations_dir) localization.list_localizations(cmd_opts.localizations_dir)
startup_timer.record("list extensions")
if cmd_opts.ui_debug_mode: if cmd_opts.ui_debug_mode:
shared.sd_upscalers = upscaler.UpscalerLanczos().scalers shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
...@@ -96,16 +109,28 @@ def initialize(): ...@@ -96,16 +109,28 @@ def initialize():
modelloader.cleanup_models() modelloader.cleanup_models()
modules.sd_models.setup_model() modules.sd_models.setup_model()
startup_timer.record("list SD models")
codeformer.setup_model(cmd_opts.codeformer_models_path) codeformer.setup_model(cmd_opts.codeformer_models_path)
startup_timer.record("setup codeformer")
gfpgan.setup_model(cmd_opts.gfpgan_models_path) gfpgan.setup_model(cmd_opts.gfpgan_models_path)
startup_timer.record("setup gfpgan")
modelloader.list_builtin_upscalers() modelloader.list_builtin_upscalers()
startup_timer.record("list builtin upscalers")
modules.scripts.load_scripts() modules.scripts.load_scripts()
startup_timer.record("load scripts")
modelloader.load_upscalers() modelloader.load_upscalers()
startup_timer.record("load upscalers")
modules.sd_vae.refresh_vae_list() modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates() modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
try: try:
modules.sd_models.load_model() modules.sd_models.load_model()
...@@ -114,6 +139,7 @@ def initialize(): ...@@ -114,6 +139,7 @@ def initialize():
print("", file=sys.stderr) print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr) print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1) exit(1)
startup_timer.record("load SD checkpoint")
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
...@@ -121,8 +147,10 @@ def initialize(): ...@@ -121,8 +147,10 @@ def initialize():
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
startup_timer.record("opts onchange")
shared.reload_hypernetworks() shared.reload_hypernetworks()
startup_timer.record("reload hypernets")
ui_extra_networks.intialize() ui_extra_networks.intialize()
ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion())
...@@ -131,6 +159,7 @@ def initialize(): ...@@ -131,6 +159,7 @@ def initialize():
extra_networks.initialize() extra_networks.initialize()
extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet())
startup_timer.record("extra networks")
if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
...@@ -144,6 +173,7 @@ def initialize(): ...@@ -144,6 +173,7 @@ def initialize():
print("TLS setup invalid, running webui without TLS") print("TLS setup invalid, running webui without TLS")
else: else:
print("Running with TLS") print("Running with TLS")
startup_timer.record("TLS")
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
...@@ -189,6 +219,7 @@ def api_only(): ...@@ -189,6 +219,7 @@ def api_only():
modules.script_callbacks.app_started_callback(None, app) modules.script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
...@@ -199,10 +230,13 @@ def webui(): ...@@ -199,10 +230,13 @@ def webui():
while 1: while 1:
if shared.opts.clean_temp_dir_at_start: if shared.opts.clean_temp_dir_at_start:
ui_tempdir.cleanup_tmpdr() ui_tempdir.cleanup_tmpdr()
startup_timer.record("cleanup temp dir")
modules.script_callbacks.before_ui_callback() modules.script_callbacks.before_ui_callback()
startup_timer.record("scripts before_ui_callback")
shared.demo = modules.ui.create_ui() shared.demo = modules.ui.create_ui()
startup_timer.record("create ui")
if cmd_opts.gradio_queue: if cmd_opts.gradio_queue:
shared.demo.queue(64) shared.demo.queue(64)
...@@ -229,6 +263,8 @@ def webui(): ...@@ -229,6 +263,8 @@ def webui():
# after initial launch, disable --autolaunch for subsequent restarts # after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False cmd_opts.autolaunch = False
startup_timer.record("gradio launch")
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
# an attacker to trick the user into opening a malicious HTML page, which makes a request to the # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
# running web ui and do whatever the attacker wants, including installing an extension and # running web ui and do whatever the attacker wants, including installing an extension and
...@@ -247,6 +283,9 @@ def webui(): ...@@ -247,6 +283,9 @@ def webui():
ui_extra_networks.add_pages_to_demo(app) ui_extra_networks.add_pages_to_demo(app)
modules.script_callbacks.app_started_callback(shared.demo, app) modules.script_callbacks.app_started_callback(shared.demo, app)
startup_timer.record("scripts app_started_callback")
print(f"Startup time: {startup_timer.summary()}.")
wait_on_server(shared.demo) wait_on_server(shared.demo)
print('Restarting UI...') print('Restarting UI...')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment