Commit 88ec0cf5 authored by AUTOMATIC

fix for incorrect embedding token length calculation (will break seeds that use embeddings, you're welcome!)
add option to input initialization text for embeddings
parent 53a3dc60
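
Before this change, the prompt-scanning loop advanced its index by `emb_len`, the number of vectors stored in the embedding (`embedding.vec.shape[0]`), rather than by the number of prompt tokens the embedding's trigger name actually matched. Whenever those two counts differ, later tokens get skipped or re-read, which is why fixed seeds that use embeddings now produce different images. The hunks below make `find_embedding_at_position` report both the embedding and the matched token count, and the callers advance by the latter. Below is a minimal, self-contained sketch of the corrected behaviour; `FakeEmbedding`, `FakeEmbeddingDB` and `remake` are illustrative stand-ins, not code from this repository.

```python
# Minimal, self-contained sketch of the corrected behaviour
# (hypothetical stand-in classes, not code from this repository).

class FakeEmbedding:
    def __init__(self, name, num_vectors):
        self.name = name
        self.num_vectors = num_vectors  # stands in for int(embedding.vec.shape[0])


class FakeEmbeddingDB:
    def __init__(self, trigger_ids, embedding):
        self.trigger_ids = trigger_ids  # token ids the embedding's name tokenizes to
        self.embedding = embedding

    def find_embedding_at_position(self, tokens, offset):
        ids = self.trigger_ids
        if tokens[offset:offset + len(ids)] == ids:
            return self.embedding, len(ids)  # new contract: also report how many tokens matched
        return None, None


def remake(tokens, db, weight=1.0):
    remade_tokens, multipliers = [], []
    i = 0
    while i < len(tokens):
        embedding, embedding_length_in_tokens = db.find_embedding_at_position(tokens, i)
        if embedding is None:
            remade_tokens.append(tokens[i])
            multipliers.append(weight)
            i += 1
        else:
            emb_len = embedding.num_vectors
            remade_tokens += [0] * emb_len     # pad one output slot per *vector*
            multipliers += [weight] * emb_len
            i += embedding_length_in_tokens    # advance by prompt *tokens* matched
                                               # (the old code advanced by emb_len here)
    return remade_tokens, multipliers


# An embedding whose name tokenizes to 2 tokens but which stores 4 vectors:
db = FakeEmbeddingDB(trigger_ids=[101, 102], embedding=FakeEmbedding("my-style", num_vectors=4))
print(remake([101, 102, 7, 8], db))  # ([0, 0, 0, 0, 7, 8], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0])
```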
@@ -130,7 +130,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         while i < len(tokens):
             token = tokens[i]
 
-            embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+            embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
             if embedding is None:
                 remade_tokens.append(token)
@@ -142,7 +142,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 remade_tokens += [0] * emb_len
                 multipliers += [weight] * emb_len
                 used_custom_terms.append((embedding.name, embedding.checksum()))
-                i += emb_len
+                i += embedding_length_in_tokens
 
         if len(remade_tokens) > maxlen - 2:
             vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -213,7 +213,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         while i < len(tokens):
             token = tokens[i]
 
-            embedding = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
+            embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, i)
 
             mult_change = self.token_mults.get(token) if opts.enable_emphasis else None
             if mult_change is not None:
@@ -229,7 +229,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 remade_tokens += [0] * emb_len
                 multipliers += [mult] * emb_len
                 used_custom_terms.append((embedding.name, embedding.checksum()))
-                i += emb_len
+                i += embedding_length_in_tokens
 
         if len(remade_tokens) > maxlen - 2:
             vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
@@ -117,24 +117,21 @@ class EmbeddingDatabase:
         possible_matches = self.ids_lookup.get(token, None)
 
         if possible_matches is None:
-            return None
+            return None, None
 
         for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
-                return embedding
+                return embedding, len(ids)
 
-        return None
+        return None, None
 
 
-def create_embedding(name, num_vectors_per_token):
-    init_text = '*'
-
+def create_embedding(name, num_vectors_per_token, init_text='*'):
     cond_model = shared.sd_model.cond_stage_model
     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
 
     ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
-    embedded = embedding_layer(ids.to(devices.device)).squeeze(0)
+    embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
 
     for i in range(num_vectors_per_token):
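
The second half of this hunk lets the caller choose the text used to seed a new embedding instead of the hard-coded `'*'`: the initialization text is tokenized, pushed through the model's frozen token-embedding table, and the resulting rows are copied into the new trainable vectors. The sketch below mirrors that idea with a toy `nn.Embedding` standing in for the CLIP token embedding; the cyclic copy at the end is an illustrative choice, not necessarily the exact indexing used by the repository (the diff is truncated before that loop body).

```python
# Hedged sketch: seeding new embedding vectors from an initialization text.
# `nn.Embedding` stands in for the CLIP token-embedding layer; the real code
# obtains `ids` from cond_model.tokenizer(init_text, ...).
import torch
import torch.nn as nn


def init_embedding_vectors(token_embedding: nn.Embedding, init_ids: torch.Tensor, num_vectors_per_token: int) -> torch.Tensor:
    embedded = token_embedding(init_ids).squeeze(0).detach()   # (n_init_tokens, dim), no grad needed here
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]))
    for i in range(num_vectors_per_token):
        vec[i] = embedded[i % embedded.shape[0]]               # reuse rows if the init text is short
    return vec


toy_token_embedding = nn.Embedding(1000, 8)    # toy vocab of 1000 tokens, dim 8
init_ids = torch.tensor([[5, 42]])             # pretend the init text tokenized to two ids
vec = init_embedding_vectors(toy_token_embedding, init_ids, num_vectors_per_token=3)
print(vec.shape)  # torch.Size([3, 8])
```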
@@ -6,8 +6,8 @@ import modules.textual_inversion.textual_inversion as ti
 from modules import sd_hijack, shared
 
 
-def create_embedding(name, nvpt):
-    filename = ti.create_embedding(name, nvpt)
+def create_embedding(name, initialization_text, nvpt):
+    filename = ti.create_embedding(name, nvpt, init_text=initialization_text)
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -954,6 +954,7 @@ def create_ui(wrap_gradio_gpu_call):
                 gr.HTML(value="<p style='margin-bottom: 0.7em'>Create a new embedding</p>")
 
                 new_embedding_name = gr.Textbox(label="Name")
+                initialization_text = gr.Textbox(label="Initialization text", value="*")
                 nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
 
                 with gr.Row():
@@ -997,6 +998,7 @@ def create_ui(wrap_gradio_gpu_call):
             fn=modules.textual_inversion.ui.create_embedding,
             inputs=[
                 new_embedding_name,
+                initialization_text,
                 nvpt,
             ],
             outputs=[
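
On the UI side, the new "Initialization text" textbox only needs to be appended to the click handler's `inputs` list in the same position as the new function parameter, as the two hunks above do. Below is a minimal, self-contained Gradio sketch of that wiring pattern; the demo layout and the stub handler are hypothetical, not the webui's actual `create_ui` code.

```python
# Minimal Gradio sketch of the wiring pattern used above: a new Textbox is
# added to the form and passed through the click handler's `inputs` list.
# The layout and the stub handler are hypothetical, not the webui's code.
import gradio as gr


def create_embedding(name, initialization_text, nvpt):
    # stand-in for modules.textual_inversion.ui.create_embedding
    return f"Would create '{name}' initialized from '{initialization_text}' with {int(nvpt)} vector(s) per token."


with gr.Blocks() as demo:
    new_embedding_name = gr.Textbox(label="Name")
    initialization_text = gr.Textbox(label="Initialization text", value="*")
    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=75, step=1, value=1)
    create = gr.Button("Create embedding")
    output = gr.HTML()

    create.click(
        fn=create_embedding,
        inputs=[new_embedding_name, initialization_text, nvpt],
        outputs=[output],
    )

if __name__ == "__main__":
    demo.launch()
```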