Commit 4ec4af6e authored by AUTOMATIC

add checkpoint info to saved embeddings

parent 71fe7fa4
@@ -7,7 +7,7 @@ import tqdm
 import html
 import datetime
-from modules import shared, devices, sd_hijack, processing
+from modules import shared, devices, sd_hijack, processing, sd_models
 import modules.textual_inversion.dataset
@@ -17,6 +17,8 @@ class Embedding:
         self.name = name
         self.step = step
         self.cached_checksum = None
+        self.sd_checkpoint = None
+        self.sd_checkpoint_name = None
 
     def save(self, filename):
         embedding_data = {
@@ -24,6 +26,8 @@ class Embedding:
             "string_to_param": {"*": self.vec},
             "name": self.name,
             "step": self.step,
+            "sd_checkpoint": self.sd_checkpoint,
+            "sd_checkpoint_name": self.sd_checkpoint_name,
         }
 
         torch.save(embedding_data, filename)
@@ -41,6 +45,7 @@ class Embedding:
         self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
         return self.cached_checksum
 
+
 class EmbeddingDatabase:
     def __init__(self, embeddings_dir):
         self.ids_lookup = {}
@@ -96,6 +101,8 @@ class EmbeddingDatabase:
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            embedding = Embedding(vec, name)
            embedding.step = data.get('step', None)
+           embedding.sd_checkpoint = data.get('hash', None)
+           embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            self.register_embedding(embedding, shared.sd_model)
 
        for fn in os.listdir(self.embeddings_dir):
@@ -249,6 +256,10 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
 
+    checkpoint = sd_models.select_checkpoint()
+
+    embedding.sd_checkpoint = checkpoint.hash
+    embedding.sd_checkpoint_name = checkpoint.model_name
 
     embedding.cached_checksum = None
     embedding.save(filename)
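For reference, here is a minimal sketch (not part of the commit) of how the fields written by Embedding.save() can be read back from a saved embedding file; the path "my-embedding.pt" is hypothetical.

```python
# Minimal sketch, not part of the commit: load an embedding file written by
# Embedding.save() and print the checkpoint info this commit adds.
# "my-embedding.pt" is a hypothetical path to a saved embedding.
import torch

data = torch.load("my-embedding.pt", map_location="cpu")

print(data.get("name"))                # embedding name
print(data.get("step"))                # training step at the time of saving
print(data.get("sd_checkpoint"))       # hash of the checkpoint used during training
print(data.get("sd_checkpoint_name"))  # model name of that checkpoint
```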