Commit 0c5fa9a6 authored by AUTOMATIC's avatar AUTOMATIC

do not reload embeddings from disk when doing textual inversion

parent be1596ce
......@@ -53,7 +53,7 @@ def get_correct_sampler(p):
return sd_samplers.samplers_for_img2img
class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False):
self.sd_model = sd_model
self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids
......@@ -80,6 +80,7 @@ class StableDiffusionProcessing:
self.extra_generation_params: dict = extra_generation_params or {}
self.overlay_images = overlay_images
self.eta = eta
self.do_not_reload_embeddings = do_not_reload_embeddings
self.paste_to = None
self.color_corrections = None
self.denoising_strength: float = 0
......@@ -364,7 +365,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
def infotext(iteration=0, position_in_batch=0):
return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch)
if os.path.exists(cmd_opts.embeddings_dir):
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
infotexts = []
......
......@@ -296,6 +296,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
sd_model=shared.sd_model,
do_not_save_grid=True,
do_not_save_samples=True,
do_not_reload_embeddings=True,
)
if preview_from_txt2img:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment