Commit bdbe0982 authored by AUTOMATIC

changed embedding accepted shape detection to use existing code and support the new alt-diffusion model, and reformatted messages a bit #6149
parent c24a314c
@@ -80,23 +80,8 @@ class EmbeddingDatabase:
         return embedding
 
     def get_expected_shape(self):
-        expected_shape = -1 # initialize with unknown
-        idx = torch.tensor(0).to(shared.device)
-        if expected_shape == -1:
-            try: # matches sd15 signature
-                first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
-                expected_shape = first_embedding.shape[0]
-            except:
-                pass
-        if expected_shape == -1:
-            try: # matches sd20 signature
-                first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
-                expected_shape = first_embedding.shape[0]
-            except:
-                pass
-        if expected_shape == -1:
-            print('Could not determine expected embeddings shape from model')
-        return expected_shape
+        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+        return vec.shape[1]
 
     def load_textual_inversion_embeddings(self, force_reload = False):
         mt = os.path.getmtime(self.embeddings_dir)
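For context, the replaced heuristic probed two model-specific attribute paths (the "sd15" and "sd20" signatures) and returned -1 when both failed; the new version instead reuses encode_embedding_init_text, which the cond-stage wrappers already implement (per the commit message, this is what makes the new alt-diffusion model work), so the expected width comes from the same code path used to initialize new embeddings. A minimal runnable sketch of the idea; the FakeCLIPEncoder class, its vocabulary size, and the 768 width are illustrative assumptions, not code from the repo:

import torch

class FakeCLIPEncoder:
    """Stand-in for a cond_stage_model wrapper; illustrative only."""
    def __init__(self, width=768):  # 768 mimics an SD 1.x text encoder; SD 2.x would be 1024
        self.token_embedding = torch.nn.Embedding(49408, width)

    def encode_embedding_init_text(self, init_text, nvpt):
        # Real implementations tokenize init_text and return the token
        # embeddings as an [n, width] tensor; here we fake a single token.
        return self.token_embedding(torch.tensor([0]))[:nvpt]

encoder = FakeCLIPEncoder()
vec = encoder.encode_embedding_init_text(",", 1)
print(vec.shape[1])  # -> 768, the per-token width a loaded embedding must match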
@@ -112,8 +97,6 @@ class EmbeddingDatabase:
 
         def process_file(path, filename):
             name = os.path.splitext(filename)[0]
-            data = []
-
             if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                 embed_image = Image.open(path)
                 if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
@@ -150,11 +133,10 @@ class EmbeddingDatabase:
             embedding.vectors = vec.shape[0]
             embedding.shape = vec.shape[-1]
 
-            if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
+            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                 self.register_embedding(embedding, shared.sd_model)
             else:
                 self.skipped_embeddings.append(name)
-                # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))
 
         for fn in os.listdir(self.embeddings_dir):
             try:
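A note on the check above: self.expected_shape is the per-token width from get_expected_shape(), so an embedding trained against a different text encoder is skipped rather than registered, while -1 (width unknown) keeps the permissive behavior. A toy illustration with assumed widths (768 for SD 1.x, 1024 for SD 2.x; conventional numbers, not taken from this diff):

# Assumed widths: SD 1.x text encoders emit 768-dim token embeddings,
# SD 2.x emit 1024-dim ones. Loading an SD 1.x embedding on an SD 2.x model:
expected_shape = 1024   # what get_expected_shape() would report
embedding_shape = 768   # width stored in the embedding file
ok = expected_shape == -1 or expected_shape == embedding_shape
print("registered" if ok else "skipped")  # -> skipped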
@@ -169,9 +151,9 @@ class EmbeddingDatabase:
                 print(traceback.format_exc(), file=sys.stderr)
                 continue
 
-        print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
-        if (len(self.skipped_embeddings) > 0):
-            print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
+        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+        if len(self.skipped_embeddings) > 0:
+            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
 
     def find_embedding_at_position(self, tokens, offset):
         token = tokens[offset]
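With the reformatted messages, the count now sits in parentheses after the label. With two loaded embeddings and one skipped (hypothetical names), the startup log would read something like:

Textual inversion embeddings loaded(2): my-style, my-subject
Textual inversion embeddings skipped(1): old-sd2-embed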