Commit 085427de authored by AUTOMATIC's avatar AUTOMATIC

make it possible for extensions/scripts to add their own embedding directories

parent a0c87f1f
......@@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
clip = None
optimization_method = None
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
def hijack(self, m):
def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
......@@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
self.layers = flatten(m)
def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped
......
......@@ -66,17 +66,41 @@ class Embedding:
return self.cached_checksum
class DirWithTextualInversionEmbeddings:
def __init__(self, path):
self.path = path
self.mtime = None
def has_changed(self):
if not os.path.isdir(self.path):
return False
mt = os.path.getmtime(self.path)
if self.mtime is None or mt > self.mtime:
return True
def update(self):
if not os.path.isdir(self.path):
return
self.mtime = os.path.getmtime(self.path)
class EmbeddingDatabase:
def __init__(self, embeddings_dir):
def __init__(self):
self.ids_lookup = {}
self.word_embeddings = {}
self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1
self.embedding_dirs = {}
def register_embedding(self, embedding, model):
def add_embedding_dir(self, path):
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
def clear_embedding_dirs(self):
self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenize([embedding.name])[0]
......@@ -93,18 +117,7 @@ class EmbeddingDatabase:
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
def load_from_file(self, path, filename):
name, ext = os.path.splitext(filename)
ext = ext.upper()
......@@ -155,7 +168,11 @@ class EmbeddingDatabase:
else:
self.skipped_embeddings[name] = embedding
for root, dirs, fns in os.walk(self.embeddings_dir):
def load_from_dir(self, embdir):
if not os.path.isdir(embdir.path):
return
for root, dirs, fns in os.walk(embdir.path):
for fn in fns:
try:
fullfn = os.path.join(root, fn)
......@@ -163,12 +180,32 @@ class EmbeddingDatabase:
if os.stat(fullfn).st_size == 0:
continue
process_file(fullfn, fn)
self.load_from_file(fullfn, fn)
except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
for path, embdir in self.embedding_dirs.items():
if embdir.has_changed():
need_reload = True
break
if not need_reload:
return
self.ids_lookup.clear()
self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
for path, embdir in self.embedding_dirs.items():
self.load_from_dir(embdir)
embdir.update()
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
......@@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert os.path.isfile(template_file), "Prompt template file doesn't exist"
assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer"
assert steps > 0 , "Max steps must be positive"
assert steps > 0, "Max steps must be positive"
assert isinstance(save_model_every, int), "Save {name} must be integer"
assert save_model_every >= 0 , "Save {name} must be positive or 0"
assert save_model_every >= 0, "Save {name} must be positive or 0"
assert isinstance(create_image_every, int), "Create image must be integer"
assert create_image_every >= 0 , "Create image must be positive or 0"
assert create_image_every >= 0, "Create image must be positive or 0"
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment