Commit 9324cdaa authored by MalumaDev's avatar MalumaDev

ui fix, re organization of the code

parent e4f8b5f0
import copy
import itertools import itertools
import os import os
from pathlib import Path from pathlib import Path
...@@ -7,11 +8,12 @@ import gc ...@@ -7,11 +8,12 @@ import gc
import gradio as gr import gradio as gr
import torch import torch
from PIL import Image from PIL import Image
from modules import shared from torch import optim
from modules.shared import device
from transformers import CLIPModel, CLIPProcessor
from tqdm.auto import tqdm from modules import shared
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
from tqdm.auto import tqdm, trange
from modules.shared import opts, device
def get_all_images_in_folder(folder): def get_all_images_in_folder(folder):
...@@ -37,12 +39,39 @@ def iter_to_batched(iterable, n=1): ...@@ -37,12 +39,39 @@ def iter_to_batched(iterable, n=1):
yield chunk yield chunk
def create_ui():
with gr.Group():
with gr.Accordion("Open for Clip Aesthetic!", open=False):
with gr.Row():
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight",
value=0.9)
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
with gr.Row():
aesthetic_lr = gr.Textbox(label='Aesthetic learning rate',
placeholder="Aesthetic learning rate", value="0.0001")
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
aesthetic_imgs = gr.Dropdown(sorted(shared.aesthetic_embeddings.keys()),
label="Aesthetic imgs embedding",
value="None")
with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
placeholder="This text is used to rotate the feature space of the imgs embs",
value="")
aesthetic_slerp_angle = gr.Slider(label='Slerp angle', minimum=0, maximum=1, step=0.01,
value=0.1)
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative
def generate_imgs_embd(name, folder, batch_size): def generate_imgs_embd(name, folder, batch_size):
# clipModel = CLIPModel.from_pretrained( # clipModel = CLIPModel.from_pretrained(
# shared.sd_model.cond_stage_model.clipModel.name_or_path # shared.sd_model.cond_stage_model.clipModel.name_or_path
# ) # )
model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path).to(device) model = shared.clip_model.to(device)
processor = CLIPProcessor.from_pretrained(shared.sd_model.cond_stage_model.clipModel.name_or_path) processor = CLIPProcessor.from_pretrained(model.name_or_path)
with torch.no_grad(): with torch.no_grad():
embs = [] embs = []
...@@ -63,7 +92,6 @@ def generate_imgs_embd(name, folder, batch_size): ...@@ -63,7 +92,6 @@ def generate_imgs_embd(name, folder, batch_size):
torch.save(embs, path) torch.save(embs, path)
model = model.cpu() model = model.cpu()
del model
del processor del processor
del embs del embs
gc.collect() gc.collect()
...@@ -74,4 +102,114 @@ def generate_imgs_embd(name, folder, batch_size): ...@@ -74,4 +102,114 @@ def generate_imgs_embd(name, folder, batch_size):
""" """
shared.update_aesthetic_embeddings() shared.update_aesthetic_embeddings()
return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding", return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()), label="Imgs embedding",
value="None"), res, "" value="None"), \
gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
label="Imgs embedding",
value="None"), res, ""
def slerp(low, high, val):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
class AestheticCLIP:
def __init__(self):
self.skip = False
self.aesthetic_steps = 0
self.aesthetic_weight = 0
self.aesthetic_lr = 0
self.slerp = False
self.aesthetic_text_negative = ""
self.aesthetic_slerp_angle = 0
self.aesthetic_imgs_text = ""
self.image_embs_name = None
self.image_embs = None
self.load_image_embs(None)
def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False):
self.aesthetic_imgs_text = aesthetic_imgs_text
self.aesthetic_slerp_angle = aesthetic_slerp_angle
self.aesthetic_text_negative = aesthetic_text_negative
self.slerp = aesthetic_slerp
self.aesthetic_lr = aesthetic_lr
self.aesthetic_weight = aesthetic_weight
self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name)
def set_skip(self, skip):
self.skip = skip
def load_image_embs(self, image_embs_name):
if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
image_embs_name = None
self.image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
self.image_embs.requires_grad_(False)
def __call__(self, z, remade_batch_tokens):
if not self.skip and self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name is not None:
tokenizer = shared.sd_model.cond_stage_model.tokenizer
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [
[tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens).to(device)
model = copy.deepcopy(shared.clip_model).to(device)
model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features(
**tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
if self.aesthetic_text_negative:
text_embs_2 = self.image_embs - text_embs_2
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
else:
img_embs = self.image_embs
with torch.enable_grad():
# We optimize the model to maximize the similarity
optimizer = optim.Adam(
model.text_model.parameters(), lr=self.aesthetic_lr
)
for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
text_embs = model.get_text_features(input_ids=tokens)
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
sim = text_embs @ img_embs.T
loss = -sim
optimizer.zero_grad()
loss.mean().backward()
optimizer.step()
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
zn = model.text_model.final_layer_norm(zn)
else:
zn = zn.last_hidden_state
model.cpu()
del model
gc.collect()
torch.cuda.empty_cache()
zn = torch.concat([zn[77 * i:77 * (i + 1)] for i in range(max(z.shape[1] // 77, 1))], 1)
if self.slerp:
z = slerp(z, zn, self.aesthetic_weight)
else:
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
return z
...@@ -56,7 +56,14 @@ def process_batch(p, input_dir, output_dir, args): ...@@ -56,7 +56,14 @@ def process_batch(p, input_dir, output_dir, args):
processed_image.save(os.path.join(output_dir, filename)) processed_image.save(os.path.join(output_dir, filename))
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str,
aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None,
aesthetic_slerp=False,
aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False, *args):
is_inpaint = mode == 1 is_inpaint = mode == 1
is_batch = mode == 2 is_batch = mode == 2
...@@ -109,6 +116,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro ...@@ -109,6 +116,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
inpainting_mask_invert=inpainting_mask_invert, inpainting_mask_invert=inpainting_mask_invert,
) )
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative)
if shared.cmd_opts.enable_console_prompts: if shared.cmd_opts.enable_console_prompts:
print(f"\nimg2img: {prompt}", file=shared.progress_print_out) print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
......
...@@ -146,7 +146,8 @@ class Processed: ...@@ -146,7 +146,8 @@ class Processed:
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0] self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1 self.seed = int(self.seed if type(self.seed) != list else self.seed[0]) if self.seed is not None else -1
self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 self.subseed = int(
self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1
self.all_prompts = all_prompts or [self.prompt] self.all_prompts = all_prompts or [self.prompt]
self.all_seeds = all_seeds or [self.seed] self.all_seeds = all_seeds or [self.seed]
...@@ -332,16 +333,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration ...@@ -332,16 +333,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, def process_images(p: StableDiffusionProcessing) -> Processed:
aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False) -> Processed:
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
aesthetic_lr = float(aesthetic_lr)
aesthetic_weight = float(aesthetic_weight)
aesthetic_steps = int(aesthetic_steps)
if type(p.prompt) == list: if type(p.prompt) == list:
assert (len(p.prompt) > 0) assert (len(p.prompt) > 0)
else: else:
...@@ -417,16 +411,10 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh ...@@ -417,16 +411,10 @@ def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weigh
# uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) # uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
# c = p.sd_model.get_learned_conditioning(prompts) # c = p.sd_model.get_learned_conditioning(prompts)
with devices.autocast(): with devices.autocast():
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): shared.aesthetic_clip.set_skip(True)
shared.sd_model.cond_stage_model.set_aesthetic_params()
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt],
p.steps) p.steps)
if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): shared.aesthetic_clip.set_skip(False)
shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight,
aesthetic_steps, aesthetic_imgs,
aesthetic_slerp, aesthetic_imgs_text,
aesthetic_slerp_angle,
aesthetic_text_negative)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
if len(model_hijack.comments) > 0: if len(model_hijack.comments) > 0:
...@@ -582,7 +570,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -582,7 +570,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f
self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model) self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
...@@ -600,10 +587,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -600,10 +587,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
seed_resize_from_w=self.seed_resize_from_w, p=self) seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] samples = samples[:, :, self.truncate_y // 2:samples.shape[2] - self.truncate_y // 2,
self.truncate_x // 2:samples.shape[3] - self.truncate_x // 2]
if opts.use_scale_latent_for_hires_fix: if opts.use_scale_latent_for_hires_fix:
samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f),
mode="bilinear")
else: else:
decoded_samples = decode_first_stage(self.sd_model, samples) decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0) lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
......
...@@ -29,8 +29,8 @@ def apply_optimizations(): ...@@ -29,8 +29,8 @@ def apply_optimizations():
ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.model.nonlinearity = silu
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (
if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)): 6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
print("Applying xformers cross attention optimization.") print("Applying xformers cross attention optimization.")
ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
...@@ -118,33 +118,14 @@ class StableDiffusionModelHijack: ...@@ -118,33 +118,14 @@ class StableDiffusionModelHijack:
return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count) return remade_batch_tokens[0], token_count, get_target_prompt_token_count(token_count)
def slerp(low, high, val):
low_norm = low / torch.norm(low, dim=1, keepdim=True)
high_norm = high / torch.norm(high, dim=1, keepdim=True)
omega = torch.acos((low_norm * high_norm).sum(1))
so = torch.sin(omega)
res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
return res
class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
def __init__(self, wrapped, hijack): def __init__(self, wrapped, hijack):
super().__init__() super().__init__()
self.wrapped = wrapped self.wrapped = wrapped
self.clipModel = CLIPModel.from_pretrained(
self.wrapped.transformer.name_or_path
)
del self.clipModel.vision_model
self.tokenizer = CLIPTokenizer.from_pretrained(self.wrapped.transformer.name_or_path)
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
# self.vision = CLIPVisionModel.from_pretrained(self.wrapped.transformer.name_or_path).eval()
self.image_embs_name = None
self.image_embs = None
self.load_image_embs(None)
self.token_mults = {} self.token_mults = {}
self.hijack: StableDiffusionModelHijack = hijack
self.tokenizer = wrapped.tokenizer
self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0] self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]
tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if
...@@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -164,28 +145,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
if mult != 1.0: if mult != 1.0:
self.token_mults[ident] = mult self.token_mults[ident] = mult
def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
aesthetic_slerp=True, aesthetic_imgs_text="",
aesthetic_slerp_angle=0.15,
aesthetic_text_negative=False):
self.aesthetic_imgs_text = aesthetic_imgs_text
self.aesthetic_slerp_angle = aesthetic_slerp_angle
self.aesthetic_text_negative = aesthetic_text_negative
self.slerp = aesthetic_slerp
self.aesthetic_lr = aesthetic_lr
self.aesthetic_weight = aesthetic_weight
self.aesthetic_steps = aesthetic_steps
self.load_image_embs(image_embs_name)
def load_image_embs(self, image_embs_name):
if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
image_embs_name = None
if image_embs_name is not None and self.image_embs_name != image_embs_name:
self.image_embs_name = image_embs_name
self.image_embs = torch.load(shared.aesthetic_embeddings[self.image_embs_name], map_location=device)
self.image_embs /= self.image_embs.norm(dim=-1, keepdim=True)
self.image_embs.requires_grad_(False)
def tokenize_line(self, line, used_custom_terms, hijack_comments): def tokenize_line(self, line, used_custom_terms, hijack_comments):
id_end = self.wrapped.tokenizer.eos_token_id id_end = self.wrapped.tokenizer.eos_token_id
...@@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module): ...@@ -391,58 +350,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
z1 = self.process_tokens(tokens, multipliers) z1 = self.process_tokens(tokens, multipliers)
z = z1 if z is None else torch.cat((z, z1), axis=-2) z = z1 if z is None else torch.cat((z, z1), axis=-2)
z = shared.aesthetic_clip(z, remade_batch_tokens)
if self.aesthetic_steps != 0 and self.aesthetic_lr != 0 and self.aesthetic_weight != 0 and self.image_embs_name != None:
if not opts.use_old_emphasis_implementation:
remade_batch_tokens = [
[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in
remade_batch_tokens]
tokens = torch.asarray(remade_batch_tokens).to(device)
model = copy.deepcopy(self.clipModel).to(device)
model.requires_grad_(True)
if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
text_embs_2 = model.get_text_features(
**self.tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
if self.aesthetic_text_negative:
text_embs_2 = self.image_embs - text_embs_2
text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
img_embs = slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)
else:
img_embs = self.image_embs
with torch.enable_grad():
# We optimize the model to maximize the similarity
optimizer = optim.Adam(
model.text_model.parameters(), lr=self.aesthetic_lr
)
for i in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
text_embs = model.get_text_features(input_ids=tokens)
text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
sim = text_embs @ img_embs.T
loss = -sim
optimizer.zero_grad()
loss.mean().backward()
optimizer.step()
zn = model.text_model(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers)
if opts.CLIP_stop_at_last_layers > 1:
zn = zn.hidden_states[-opts.CLIP_stop_at_last_layers]
zn = model.text_model.final_layer_norm(zn)
else:
zn = zn.last_hidden_state
model.cpu()
del model
zn = torch.concat([zn for i in range(z.shape[1] // 77)], 1)
if self.slerp:
z = slerp(z, zn, self.aesthetic_weight)
else:
z = z * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight
remade_batch_tokens = rem_tokens remade_batch_tokens = rem_tokens
batch_multipliers = rem_multipliers batch_multipliers = rem_multipliers
i += 1 i += 1
......
...@@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict() ...@@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict()
try: try:
# this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
from transformers import logging from transformers import logging, CLIPModel
logging.set_verbosity_error() logging.set_verbosity_error()
except Exception: except Exception:
...@@ -196,6 +196,9 @@ def load_model(): ...@@ -196,6 +196,9 @@ def load_model():
sd_hijack.model_hijack.hijack(sd_model) sd_hijack.model_hijack.hijack(sd_model)
if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)
sd_model.eval() sd_model.eval()
print(f"Model loaded.") print(f"Model loaded.")
......
...@@ -3,6 +3,7 @@ import datetime ...@@ -3,6 +3,7 @@ import datetime
import json import json
import os import os
import sys import sys
from collections import OrderedDict
import gradio as gr import gradio as gr
import tqdm import tqdm
...@@ -94,15 +95,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) ...@@ -94,15 +95,15 @@ os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True)
hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir)
loaded_hypernetwork = None loaded_hypernetwork = None
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in aesthetic_embeddings = {}
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
aesthetic_embeddings = aesthetic_embeddings | {"None": None}
def update_aesthetic_embeddings(): def update_aesthetic_embeddings():
global aesthetic_embeddings global aesthetic_embeddings
aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in aesthetic_embeddings = {f.replace(".pt",""): os.path.join(cmd_opts.aesthetic_embeddings_dir, f) for f in
os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")} os.listdir(cmd_opts.aesthetic_embeddings_dir) if f.endswith(".pt")}
aesthetic_embeddings = aesthetic_embeddings | {"None": None} aesthetic_embeddings = OrderedDict(**{"None": None}, **aesthetic_embeddings)
update_aesthetic_embeddings()
def reload_hypernetworks(): def reload_hypernetworks():
global hypernetworks global hypernetworks
...@@ -381,6 +382,11 @@ sd_upscalers = [] ...@@ -381,6 +382,11 @@ sd_upscalers = []
sd_model = None sd_model = None
clip_model = None
from modules.aesthetic_clip import AestheticCLIP
aesthetic_clip = AestheticCLIP()
progress_print_out = sys.stdout progress_print_out = sys.stdout
......
...@@ -49,7 +49,7 @@ class PersonalizedBase(Dataset): ...@@ -49,7 +49,7 @@ class PersonalizedBase(Dataset):
print("Preparing dataset...") print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths): for path in tqdm.tqdm(self.image_paths):
try: try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC) image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception: except Exception:
continue continue
......
import modules.scripts import modules.scripts
from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.processing import StableDiffusionProcessing, Processed, StableDiffusionProcessingTxt2Img, \
StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, cmd_opts from modules.shared import opts, cmd_opts
import modules.shared as shared import modules.shared as shared
import modules.processing as processing import modules.processing as processing
from modules.ui import plaintext_to_html from modules.ui import plaintext_to_html
def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int,aesthetic_lr=0, def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int,
restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int,
subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool,
height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int,
firstphase_height: int, aesthetic_lr=0,
aesthetic_weight=0, aesthetic_steps=0, aesthetic_weight=0, aesthetic_steps=0,
aesthetic_imgs=None, aesthetic_imgs=None,
aesthetic_slerp=False, aesthetic_slerp=False,
...@@ -41,15 +46,17 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -41,15 +46,17 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
firstphase_height=firstphase_height if enable_hr else None, firstphase_height=firstphase_height if enable_hr else None,
) )
shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle,
aesthetic_text_negative)
if cmd_opts.enable_console_prompts: if cmd_opts.enable_console_prompts:
print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) print(f"\ntxt2img: {prompt}", file=shared.progress_print_out)
processed = modules.scripts.scripts_txt2img.run(p, *args) processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is None: if processed is None:
processed = process_images(p, aesthetic_lr, aesthetic_weight, aesthetic_steps, aesthetic_imgs, aesthetic_slerp,aesthetic_imgs_text, processed = process_images(p)
aesthetic_slerp_angle,
aesthetic_text_negative)
shared.total_tqdm.clear() shared.total_tqdm.clear()
...@@ -61,4 +68,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: ...@@ -61,4 +68,3 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2:
processed.images = [] processed.images = []
return processed.images, generation_info_js, plaintext_to_html(processed.info) return processed.images, generation_info_js, plaintext_to_html(processed.info)
...@@ -43,7 +43,7 @@ from modules.images import save_image ...@@ -43,7 +43,7 @@ from modules.images import save_image
import modules.textual_inversion.ui import modules.textual_inversion.ui
import modules.hypernetworks.ui import modules.hypernetworks.ui
import modules.aesthetic_clip import modules.aesthetic_clip as aesthetic_clip
import modules.images_history as img_his import modules.images_history as img_his
...@@ -593,23 +593,25 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -593,23 +593,25 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
with gr.Group(): # with gr.Group():
with gr.Accordion("Open for Clip Aesthetic!",open=False): # with gr.Accordion("Open for Clip Aesthetic!",open=False):
with gr.Row(): # with gr.Row():
aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9) # aesthetic_weight = gr.Slider(minimum=0, maximum=1, step=0.01, label="Aesthetic weight", value=0.9)
aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5) # aesthetic_steps = gr.Slider(minimum=0, maximum=50, step=1, label="Aesthetic steps", value=5)
#
with gr.Row(): # with gr.Row():
aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001") # aesthetic_lr = gr.Textbox(label='Aesthetic learning rate', placeholder="Aesthetic learning rate", value="0.0001")
aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False) # aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()), # aesthetic_imgs = gr.Dropdown(sorted(aesthetic_embeddings.keys()),
label="Aesthetic imgs embedding", # label="Aesthetic imgs embedding",
value="None") # value="None")
#
with gr.Row(): # with gr.Row():
aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="") # aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", value="")
aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1) # aesthetic_slerp_angle = gr.Slider(label='Slerp angle',minimum=0, maximum=1, step=0.01, value=0.1)
aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False) # aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)
aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative = aesthetic_clip.create_ui()
with gr.Row(): with gr.Row():
...@@ -840,6 +842,9 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -840,6 +842,9 @@ def create_ui(wrap_gradio_gpu_call):
width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512) width = gr.Slider(minimum=64, maximum=2048, step=64, label="Width", value=512)
height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512) height = gr.Slider(minimum=64, maximum=2048, step=64, label="Height", value=512)
aesthetic_weight_im, aesthetic_steps_im, aesthetic_lr_im, aesthetic_slerp_im, aesthetic_imgs_im, aesthetic_imgs_text_im, aesthetic_slerp_angle_im, aesthetic_text_negative_im = aesthetic_clip.create_ui()
with gr.Row(): with gr.Row():
restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1) restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1)
tiling = gr.Checkbox(label='Tiling', value=False) tiling = gr.Checkbox(label='Tiling', value=False)
...@@ -944,6 +949,14 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -944,6 +949,14 @@ def create_ui(wrap_gradio_gpu_call):
inpainting_mask_invert, inpainting_mask_invert,
img2img_batch_input_dir, img2img_batch_input_dir,
img2img_batch_output_dir, img2img_batch_output_dir,
aesthetic_lr_im,
aesthetic_weight_im,
aesthetic_steps_im,
aesthetic_imgs_im,
aesthetic_slerp_im,
aesthetic_imgs_text_im,
aesthetic_slerp_angle_im,
aesthetic_text_negative_im,
] + custom_inputs, ] + custom_inputs,
outputs=[ outputs=[
img2img_gallery, img2img_gallery,
...@@ -1283,7 +1296,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1283,7 +1296,7 @@ def create_ui(wrap_gradio_gpu_call):
) )
create_embedding_ae.click( create_embedding_ae.click(
fn=modules.aesthetic_clip.generate_imgs_embd, fn=aesthetic_clip.generate_imgs_embd,
inputs=[ inputs=[
new_embedding_name_ae, new_embedding_name_ae,
process_src_ae, process_src_ae,
...@@ -1291,6 +1304,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1291,6 +1304,7 @@ def create_ui(wrap_gradio_gpu_call):
], ],
outputs=[ outputs=[
aesthetic_imgs, aesthetic_imgs,
aesthetic_imgs_im,
ti_output, ti_output,
ti_outcome, ti_outcome,
] ]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment