Commit 085427de authored by AUTOMATIC's avatar AUTOMATIC

make it possible for extensions/scripts to add their own embedding directories

parent a0c87f1f
...@@ -83,10 +83,12 @@ class StableDiffusionModelHijack: ...@@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
clip = None clip = None
optimization_method = None optimization_method = None
embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir) embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
def hijack(self, m): def __init__(self):
self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
def hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings = m.cond_stage_model.roberta.embeddings
model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
...@@ -117,7 +119,6 @@ class StableDiffusionModelHijack: ...@@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
self.layers = flatten(m) self.layers = flatten(m)
def undo_hijack(self, m): def undo_hijack(self, m):
if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
m.cond_stage_model = m.cond_stage_model.wrapped m.cond_stage_model = m.cond_stage_model.wrapped
......
...@@ -66,17 +66,41 @@ class Embedding: ...@@ -66,17 +66,41 @@ class Embedding:
return self.cached_checksum return self.cached_checksum
class DirWithTextualInversionEmbeddings:
def __init__(self, path):
self.path = path
self.mtime = None
def has_changed(self):
if not os.path.isdir(self.path):
return False
mt = os.path.getmtime(self.path)
if self.mtime is None or mt > self.mtime:
return True
def update(self):
if not os.path.isdir(self.path):
return
self.mtime = os.path.getmtime(self.path)
class EmbeddingDatabase: class EmbeddingDatabase:
def __init__(self, embeddings_dir): def __init__(self):
self.ids_lookup = {} self.ids_lookup = {}
self.word_embeddings = {} self.word_embeddings = {}
self.skipped_embeddings = {} self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
self.expected_shape = -1 self.expected_shape = -1
self.embedding_dirs = {}
def register_embedding(self, embedding, model): def add_embedding_dir(self, path):
self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
def clear_embedding_dirs(self):
self.embedding_dirs.clear()
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding self.word_embeddings[embedding.name] = embedding
ids = model.cond_stage_model.tokenize([embedding.name])[0] ids = model.cond_stage_model.tokenize([embedding.name])[0]
...@@ -93,18 +117,7 @@ class EmbeddingDatabase: ...@@ -93,18 +117,7 @@ class EmbeddingDatabase:
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1) vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
return vec.shape[1] return vec.shape[1]
def load_textual_inversion_embeddings(self, force_reload = False): def load_from_file(self, path, filename):
mt = os.path.getmtime(self.embeddings_dir)
if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
name, ext = os.path.splitext(filename) name, ext = os.path.splitext(filename)
ext = ext.upper() ext = ext.upper()
...@@ -155,7 +168,11 @@ class EmbeddingDatabase: ...@@ -155,7 +168,11 @@ class EmbeddingDatabase:
else: else:
self.skipped_embeddings[name] = embedding self.skipped_embeddings[name] = embedding
for root, dirs, fns in os.walk(self.embeddings_dir): def load_from_dir(self, embdir):
if not os.path.isdir(embdir.path):
return
for root, dirs, fns in os.walk(embdir.path):
for fn in fns: for fn in fns:
try: try:
fullfn = os.path.join(root, fn) fullfn = os.path.join(root, fn)
...@@ -163,12 +180,32 @@ class EmbeddingDatabase: ...@@ -163,12 +180,32 @@ class EmbeddingDatabase:
if os.stat(fullfn).st_size == 0: if os.stat(fullfn).st_size == 0:
continue continue
process_file(fullfn, fn) self.load_from_file(fullfn, fn)
except Exception: except Exception:
print(f"Error loading embedding {fn}:", file=sys.stderr) print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
continue continue
def load_textual_inversion_embeddings(self, force_reload=False):
if not force_reload:
need_reload = False
for path, embdir in self.embedding_dirs.items():
if embdir.has_changed():
need_reload = True
break
if not need_reload:
return
self.ids_lookup.clear()
self.word_embeddings.clear()
self.skipped_embeddings.clear()
self.expected_shape = self.get_expected_shape()
for path, embdir in self.embedding_dirs.items():
self.load_from_dir(embdir)
embdir.update()
print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if len(self.skipped_embeddings) > 0: if len(self.skipped_embeddings) > 0:
print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
...@@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat ...@@ -251,14 +288,15 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
assert os.path.isfile(template_file), "Prompt template file doesn't exist" assert os.path.isfile(template_file), "Prompt template file doesn't exist"
assert steps, "Max steps is empty or 0" assert steps, "Max steps is empty or 0"
assert isinstance(steps, int), "Max steps must be integer" assert isinstance(steps, int), "Max steps must be integer"
assert steps > 0 , "Max steps must be positive" assert steps > 0, "Max steps must be positive"
assert isinstance(save_model_every, int), "Save {name} must be integer" assert isinstance(save_model_every, int), "Save {name} must be integer"
assert save_model_every >= 0 , "Save {name} must be positive or 0" assert save_model_every >= 0, "Save {name} must be positive or 0"
assert isinstance(create_image_every, int), "Create image must be integer" assert isinstance(create_image_every, int), "Create image must be integer"
assert create_image_every >= 0 , "Create image must be positive or 0" assert create_image_every >= 0, "Create image must be positive or 0"
if save_model_every or create_image_every: if save_model_every or create_image_every:
assert log_directory, "Log directory is empty" assert log_directory, "Log directory is empty"
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0 save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0 create_image_every = create_image_every or 0
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment