Commit 12f5bb49 authored by novelailab's avatar novelailab

Fix clip extend

parent a1a34f99
......@@ -41,6 +41,19 @@ def pil_upscale(image, scale=1):
def fix_batch(tensor, bs):
    """Replicate a single conditioning tensor ``bs`` times along a new batch axis.

    Args:
        tensor: Tensor to replicate. A leading dimension of size 1 is
            dropped first; any other leading dimension is kept as-is.
        bs: Number of copies to stack. Must be at least 1, because
            ``torch.stack`` rejects an empty list.

    Returns:
        A new tensor whose dimension 0 has size ``bs`` and whose every
        slice along that dimension equals ``tensor.squeeze(0)``.
    """
    # squeeze(0) is a no-op unless dim 0 has size 1, so callers may pass
    # either a batched-by-one or an unbatched tensor.
    sample = tensor.squeeze(0)
    return torch.stack([sample] * bs, dim=0)
# make uc and prompt shapes match via padding for long prompts
# finetune
# Cached conditioning for the empty prompt; filled in on first use by
# fix_cond_shapes and reused for every later call.
null_cond = None


def fix_cond_shapes(model, prompt_condition, uc):
    """Pad two conditioning tensors along dim 1 until neither is shorter.

    The shorter tensor is extended by appending whole copies of the
    empty-prompt conditioning (``model.get_learned_conditioning([""])``),
    so padding always grows in steps of that tensor's dim-1 length.

    Args:
        model: Object exposing ``get_learned_conditioning(list_of_str)``.
            Only queried the first time; the result is cached in the
            module-level ``null_cond``.
        prompt_condition: Conditioning tensor for the prompt.
        uc: Unconditional conditioning tensor, or ``None``. ``None`` is
            passed through untouched (call sites in this file set
            ``uc = None`` when guidance is disabled).

    Returns:
        ``(prompt_condition, uc)`` after padding.

    Raises:
        ValueError: If padding is required but the cached empty-prompt
            conditioning has zero length along dim 1 (the original loop
            would never terminate in that case).
    """
    global null_cond
    if uc is None:
        # Nothing to align against.
        return prompt_condition, uc
    if null_cond is None:
        null_cond = model.get_learned_conditioning([""])
    pad_len = null_cond.shape[1]

    def _pad_to(short, target_len):
        # Append just enough whole copies of null_cond to reach target_len.
        missing = target_len - short.shape[1]
        if missing <= 0:
            return short
        if pad_len == 0:
            raise ValueError(
                "empty-prompt conditioning has zero length; cannot pad"
            )
        copies = -(-missing // pad_len)  # ceiling division
        padding = null_cond.repeat(short.shape[0], copies, 1)
        return torch.cat((short, padding), dim=1)

    # Same order as the original loops: grow uc first, then the prompt.
    uc = _pad_to(uc, prompt_condition.shape[1])
    prompt_condition = _pad_to(prompt_condition, uc.shape[1])
    return prompt_condition, uc
# mix conditioning vectors for prompts
# @aero
def prompt_mixing(model, prompt_body, batch_size):
......@@ -61,6 +74,7 @@ def prompt_mixing(model, prompt_body, batch_size):
if prompt_sum is None:
prompt_sum = prompt_vector * prompt_power
else:
prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector)
prompt_sum = prompt_sum + (prompt_vector * prompt_power)
prompt_total_power = prompt_total_power + prompt_power
return fix_batch(prompt_sum / prompt_total_power, batch_size)
......@@ -392,6 +406,7 @@ class StableDiffusionModel(nn.Module):
uc = prompt_mixing(self.model, uc[0], request.n_samples)
else:
uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [
request.latent_channels,
......@@ -475,6 +490,7 @@ class StableDiffusionModel(nn.Module):
uc = None
if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [
request.latent_channels,
......@@ -515,12 +531,13 @@ class StableDiffusionModel(nn.Module):
init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
uc = None
if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
# encode (scaled latent)
start_code_terped=None
z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment