Commit bdbe0982 authored by AUTOMATIC's avatar AUTOMATIC

changed embedding accepted shape detection to use existing code and support...

changed embedding accepted shape detection to use existing code and support the new alt-diffusion model, and reformatted messages a bit #6149
parent c24a314c
...@@ -80,23 +80,8 @@ class EmbeddingDatabase: ...@@ -80,23 +80,8 @@ class EmbeddingDatabase:
return embedding return embedding
def get_expected_shape(self): def get_expected_shape(self):
expected_shape = -1 # initialize with unknown vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
idx = torch.tensor(0).to(shared.device) return vec.shape[1]
if expected_shape == -1:
try: # matches sd15 signature
first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
expected_shape = first_embedding.shape[0]
except:
pass
if expected_shape == -1:
try: # matches sd20 signature
first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
expected_shape = first_embedding.shape[0]
except:
pass
if expected_shape == -1:
print('Could not determine expected embeddings shape from model')
return expected_shape
def load_textual_inversion_embeddings(self, force_reload = False): def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir) mt = os.path.getmtime(self.embeddings_dir)
...@@ -112,8 +97,6 @@ class EmbeddingDatabase: ...@@ -112,8 +97,6 @@ class EmbeddingDatabase:
def process_file(path, filename): def process_file(path, filename):
name = os.path.splitext(filename)[0] name = os.path.splitext(filename)[0]
data = []
if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']: if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path) embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text: if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
...@@ -150,11 +133,10 @@ class EmbeddingDatabase: ...@@ -150,11 +133,10 @@ class EmbeddingDatabase:
embedding.vectors = vec.shape[0] embedding.vectors = vec.shape[0]
embedding.shape = vec.shape[-1] embedding.shape = vec.shape[-1]
if (self.expected_shape == -1) or (self.expected_shape == embedding.shape): if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model) self.register_embedding(embedding, shared.sd_model)
else: else:
self.skipped_embeddings.append(name) self.skipped_embeddings.append(name)
# print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
for fn in os.listdir(self.embeddings_dir): for fn in os.listdir(self.embeddings_dir):
try: try:
...@@ -169,9 +151,9 @@ class EmbeddingDatabase: ...@@ -169,9 +151,9 @@ class EmbeddingDatabase:
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
continue continue
print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys()))) print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
if (len(self.skipped_embeddings) > 0): if len(self.skipped_embeddings) > 0:
print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings))) print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
def find_embedding_at_position(self, tokens, offset): def find_embedding_at_position(self, tokens, offset):
token = tokens[offset] token = tokens[offset]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment