Commit 8faac8b9 authored by AUTOMATIC


run a basic torch calculation at startup, in parallel, to reduce the performance impact of the first generation
parent 1f318292
import sys
import contextlib
from functools import lru_cache
import torch
from modules import errors
@@ -154,3 +156,19 @@ def test_for_nans(x, where):
message += " Use --disable-nan-check commandline argument to disable this check."
raise NansException(message)
@lru_cache
def first_time_calculation():
"""
just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
spends about 2.7 seconds doing that, at least wih NVidia.
"""
x = torch.zeros((1, 1)).to(device, dtype)
linear = torch.nn.Linear(1, 1).to(device, dtype)
linear(x)
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
conv2d(x)
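To illustrate the cost this warm-up avoids, here is a minimal, self-contained timing sketch (not part of the commit). It assumes a CUDA-capable machine; the device and dtype chosen below are illustrative stand-ins for the webui's shared device/dtype globals, and the measured times will differ from the ~700MB / ~2.7s cited in the docstring depending on hardware.

# standalone sketch: measure the one-time cost of the first torch forward pass
import time
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

def timed_forward():
    start = time.time()
    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)
    if device.type == "cuda":
        torch.cuda.synchronize()  # make sure the kernel actually ran before timing
    return time.time() - start

print(f"first forward pass:  {timed_forward():.2f}s")   # pays the one-time CUDA/torch init cost
print(f"second forward pass: {timed_forward():.2f}s")   # near-instant once torch is warmed up

Because first_time_calculation is wrapped in @lru_cache, calling it again later in the webui is effectively free.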
@@ -20,7 +20,7 @@ import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
from modules import paths, timer, import_hook, errors # noqa: F401
from modules import paths, timer, import_hook, errors, devices # noqa: F401
startup_timer = timer.Timer()
@@ -295,6 +295,8 @@ def initialize_rest(*, reload_script_modules=False):
    # (when reloading, this does nothing)
    Thread(target=lambda: shared.sd_model).start()

    Thread(target=devices.first_time_calculation).start()

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernetworks")