Commit 73a0b4bb authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #13944 from wfjsw/dag

implementing script metadata and DAG sorting mechanism
parents 411da7c2 bde439ef
import configparser
import functools
import os import os
import threading import threading
import re
from modules import shared, errors, cache, scripts from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo from modules.gitpython_hack import Repo
...@@ -23,8 +26,9 @@ class Extension: ...@@ -23,8 +26,9 @@ class Extension:
lock = threading.Lock() lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
def __init__(self, name, path, enabled=True, is_builtin=False): def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
self.name = name self.name = name
self.canonical_name = canonical_name or name.lower()
self.path = path self.path = path
self.enabled = enabled self.enabled = enabled
self.status = '' self.status = ''
...@@ -37,6 +41,18 @@ class Extension: ...@@ -37,6 +41,18 @@ class Extension:
self.remote = None self.remote = None
self.have_info_from_repo = False self.have_info_from_repo = False
@functools.cached_property
def metadata(self):
if os.path.isfile(os.path.join(self.path, "metadata.ini")):
try:
config = configparser.ConfigParser()
config.read(os.path.join(self.path, "metadata.ini"))
return config
except Exception:
errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.",
exc_info=True)
return None
def to_dict(self): def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields} return {x: getattr(self, x) for x in self.cached_fields}
...@@ -56,6 +72,7 @@ class Extension: ...@@ -56,6 +72,7 @@ class Extension:
self.do_read_info_from_repo() self.do_read_info_from_repo()
return self.to_dict() return self.to_dict()
try: try:
d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
self.from_dict(d) self.from_dict(d)
...@@ -136,9 +153,6 @@ class Extension: ...@@ -136,9 +153,6 @@ class Extension:
def list_extensions(): def list_extensions():
extensions.clear() extensions.clear()
if not os.path.isdir(extensions_dir):
return
if shared.cmd_opts.disable_all_extensions: if shared.cmd_opts.disable_all_extensions:
print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
elif shared.opts.disable_all_extensions == "all": elif shared.opts.disable_all_extensions == "all":
...@@ -148,18 +162,68 @@ def list_extensions(): ...@@ -148,18 +162,68 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra": elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
extension_paths = [] extension_dependency_map = {}
for dirname in [extensions_dir, extensions_builtin_dir]:
# scan through extensions directory and load metadata
for dirname in [extensions_builtin_dir, extensions_dir]:
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
return continue
for extension_dirname in sorted(os.listdir(dirname)): for extension_dirname in sorted(os.listdir(dirname)):
path = os.path.join(dirname, extension_dirname) path = os.path.join(dirname, extension_dirname)
if not os.path.isdir(path): if not os.path.isdir(path):
continue continue
extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) canonical_name = extension_dirname
requires = None
if os.path.isfile(os.path.join(path, "metadata.ini")):
try:
config = configparser.ConfigParser()
config.read(os.path.join(path, "metadata.ini"))
canonical_name = config.get("Extension", "Name", fallback=canonical_name)
requires = config.get("Extension", "Requires", fallback=None)
except Exception:
errors.report(f"Error reading metadata.ini for extension {extension_dirname}. "
f"Will load regardless.", exc_info=True)
canonical_name = canonical_name.lower().strip()
# check for duplicated canonical names
if canonical_name in extension_dependency_map:
errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
f"The current loading extension will be discarded.", exc_info=False)
continue
for dirname, path, is_builtin in extension_paths: # both "," and " " are accepted as separator
extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
extension_dependency_map[canonical_name] = {
"dirname": extension_dirname,
"path": path,
"requires": requires,
}
# check for requirements
for (_, extension_data) in extension_dependency_map.items():
dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
requirement_met = True
for req in requires:
if req not in extension_dependency_map:
errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
f"The current loading extension will be discarded.", exc_info=False)
requirement_met = False
break
dep_dirname = extension_dependency_map[req]['dirname']
if dep_dirname in shared.opts.disabled_extensions:
errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
f"The current loading extension will be discarded.", exc_info=False)
requirement_met = False
break
is_builtin = dirname == extensions_builtin_dir
extension = Extension(name=dirname, path=path,
enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
is_builtin=is_builtin)
extensions.append(extension) extensions.append(extension)
...@@ -2,6 +2,7 @@ import os ...@@ -2,6 +2,7 @@ import os
import re import re
import sys import sys
import inspect import inspect
from graphlib import TopologicalSorter, CycleError
from collections import namedtuple from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass
...@@ -314,15 +315,120 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi ...@@ -314,15 +315,120 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi
def list_scripts(scriptdirname, extension, *, include_extensions=True): def list_scripts(scriptdirname, extension, *, include_extensions=True):
scripts_list = [] scripts_list = []
script_dependency_map = {}
basedir = os.path.join(paths.script_path, scriptdirname) # build script dependency map
if os.path.exists(basedir):
for filename in sorted(os.listdir(basedir)): root_script_basedir = os.path.join(paths.script_path, scriptdirname)
scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) if os.path.exists(root_script_basedir):
for filename in sorted(os.listdir(root_script_basedir)):
if not os.path.isfile(os.path.join(root_script_basedir, filename)):
continue
script_dependency_map[filename] = {
"extension": None,
"extension_dirname": None,
"script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
"requires": [],
"load_before": [],
"load_after": [],
}
if include_extensions: if include_extensions:
for ext in extensions.active(): for ext in extensions.active():
scripts_list += ext.list_files(scriptdirname, extension) extension_scripts_list = ext.list_files(scriptdirname, extension)
for extension_script in extension_scripts_list:
if not os.path.isfile(extension_script.path):
continue
script_canonical_name = ext.canonical_name + "/" + extension_script.filename
if ext.is_builtin:
script_canonical_name = "builtin/" + script_canonical_name
relative_path = scriptdirname + "/" + extension_script.filename
requires = ''
load_before = ''
load_after = ''
if ext.metadata is not None:
requires = ext.metadata.get(relative_path, "Requires", fallback='')
load_before = ext.metadata.get(relative_path, "Before", fallback='')
load_after = ext.metadata.get(relative_path, "After", fallback='')
# propagate directory level metadata
requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')
requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else []
script_dependency_map[script_canonical_name] = {
"extension": ext.canonical_name,
"extension_dirname": ext.name,
"script_file": extension_script,
"requires": requires,
"load_before": load_before,
"load_after": load_after,
}
# resolve dependencies
loaded_extensions = set()
for ext in extensions.active():
loaded_extensions.add(ext.canonical_name)
for script_canonical_name, script_data in script_dependency_map.items():
# load before requires inverse dependency
# in this case, append the script name into the load_after list of the specified script
for load_before_script in script_data['load_before']:
# if this requires an individual script to be loaded before
if load_before_script in script_dependency_map:
script_dependency_map[load_before_script]['load_after'].append(script_canonical_name)
elif load_before_script in loaded_extensions:
for _, script_data2 in script_dependency_map.items():
if script_data2['extension'] == load_before_script:
script_data2['load_after'].append(script_canonical_name)
break
# resolve extension name in load_after lists
for load_after_script in list(script_data['load_after']):
if load_after_script not in script_dependency_map and load_after_script in loaded_extensions:
script_data['load_after'].remove(load_after_script)
for script_canonical_name2, script_data2 in script_dependency_map.items():
if script_data2['extension'] == load_after_script:
script_data['load_after'].append(script_canonical_name2)
break
# build the DAG
sorter = TopologicalSorter()
for script_canonical_name, script_data in script_dependency_map.items():
requirement_met = True
for required_script in script_data['requires']:
# if this requires an individual script to be loaded
if required_script not in script_dependency_map and required_script not in loaded_extensions:
errors.report(f"Script \"{script_canonical_name}\" "
f"requires \"{required_script}\" to "
f"be loaded, but it is not. Skipping.",
exc_info=False)
requirement_met = False
break
if not requirement_met:
continue
sorter.add(script_canonical_name, *script_data['load_after'])
# sort the scripts
try:
ordered_script = sorter.static_order()
except CycleError:
errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
ordered_script = script_dependency_map.keys()
for script_canonical_name in ordered_script:
script_data = script_dependency_map[script_canonical_name]
scripts_list.append(script_data['script_file'])
scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
...@@ -365,15 +471,9 @@ def load_scripts(): ...@@ -365,15 +471,9 @@ def load_scripts():
elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing): elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module)) postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
def orderby(basedir): # here the scripts_list is already ordered
# 1st webui, 2nd extensions-builtin, 3rd extensions # processing_script is not considered though
priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0} for scriptfile in scripts_list:
for key in priority:
if basedir.startswith(key):
return priority[key]
return 9999
for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
try: try:
if scriptfile.basedir != paths.script_path: if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path sys.path = [scriptfile.basedir] + sys.path
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment