Commit 03694e1f authored by DepFA's avatar DepFA Committed by GitHub

add embedding load and save from b64 json

parent fa0c5eb8
...@@ -7,9 +7,11 @@ import tqdm ...@@ -7,9 +7,11 @@ import tqdm
import html import html
import datetime import datetime
from PIL import Image, PngImagePlugin from PIL import Image,PngImagePlugin
from ..images import captionImge
import numpy as np
import base64 import base64
from io import BytesIO import json
from modules import shared, devices, sd_hijack, processing, sd_models from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset import modules.textual_inversion.dataset
...@@ -87,9 +89,9 @@ class EmbeddingDatabase: ...@@ -87,9 +89,9 @@ class EmbeddingDatabase:
if filename.upper().endswith('.PNG'): if filename.upper().endswith('.PNG'):
embed_image = Image.open(path) embed_image = Image.open(path)
if 'sd-embedding' in embed_image.text: if 'sd-ti-embedding' in embed_image.text:
embeddingData = base64.b64decode(embed_image.text['sd-embedding']) data = embeddingFromB64(embed_image.text['sd-ti-embedding'])
data = torch.load(BytesIO(embeddingData), map_location="cpu") name = data.get('name',name)
else: else:
data = torch.load(path, map_location="cpu") data = torch.load(path, map_location="cpu")
...@@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, ...@@ -258,13 +260,23 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps,
if save_image_with_stored_embedding: if save_image_with_stored_embedding:
info = PngImagePlugin.PngInfo() info = PngImagePlugin.PngInfo()
info.add_text("sd-embedding", base64.b64encode(open(last_saved_file,'rb').read())) data = torch.load(last_saved_file)
image.save(last_saved_image, "PNG", pnginfo=info) info.add_text("sd-ti-embedding", embeddingToB64(data))
pre_lines = [((255, 207, 175),"<{}>".format(data.get('name','???')))]
caption_checkpoint_hash = data.get('sd_checkpoint','UNK')
caption_checkpoint_hash = caption_checkpoint_hash.upper() if caption_checkpoint_hash else 'UNK'
caption_stepcount = data.get('step',0)
caption_stepcount = caption_stepcount if caption_stepcount else 0
post_lines = [((240, 223, 175),"Trained against checkpoint [{}] for {} steps".format(caption_checkpoint_hash,
caption_stepcount))]
captioned_image = captionImge(image,prelines=pre_lines,postlines=post_lines)
captioned_image.save(last_saved_image, "PNG", pnginfo=info)
else: else:
image.save(last_saved_image) image.save(last_saved_image)
last_saved_image += f", prompt: {text}" last_saved_image += f", prompt: {text}"
shared.state.job_no = embedding.step shared.state.job_no = embedding.step
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment