Commit 12f5bb49 authored by novelailab

Fix clip extend

parent a1a34f99
...@@ -41,6 +41,19 @@ def pil_upscale(image, scale=1): ...@@ -41,6 +41,19 @@ def pil_upscale(image, scale=1):
def fix_batch(tensor, bs): def fix_batch(tensor, bs):
return torch.stack([tensor.squeeze(0)]*bs, dim=0) return torch.stack([tensor.squeeze(0)]*bs, dim=0)
# make uc and prompt shapes match via padding for long prompts
# finetune
null_cond = None


def fix_cond_shapes(model, prompt_condition, uc):
    """Pad the shorter of two conditioning tensors until dim 1 matches.

    Padding is done by appending copies of the conditioning that
    ``model.get_learned_conditioning([""])`` returns for the empty prompt,
    repeated along dim 0 to match the batch size of the tensor being padded.
    That empty-prompt conditioning is computed on the first call and cached
    in the module-level ``null_cond``.

    Args:
        model: Object exposing ``get_learned_conditioning(list_of_prompts)``.
        prompt_condition: Conditioning tensor for the prompt. Assumed to be
            3-D with the sequence length on dim 1 -- TODO confirm.
        uc: Unconditional conditioning tensor with the same layout, or
            ``None`` when the caller did not compute one.

    Returns:
        A ``(prompt_condition, uc)`` tuple. When ``uc`` is ``None`` both
        values are returned untouched.
    """
    global null_cond

    # Callers initialise ``uc`` to None and only fill it in conditionally;
    # there is nothing to align in that case, so pass both through instead
    # of failing on ``None.shape``.
    if uc is None:
        return prompt_condition, uc

    # NOTE(review): the cache is filled from whichever model is passed first
    # and reused for every later call, even if a different model is given.
    if null_cond is None:
        null_cond = model.get_learned_conditioning([""])

    # NOTE(review): each iteration appends one whole ``null_cond`` chunk, so
    # if the two lengths do not differ by a multiple of the chunk length the
    # returned tensors can still have different sizes on dim 1.
    while prompt_condition.shape[1] > uc.shape[1]:
        uc = torch.cat((uc, null_cond.repeat((uc.shape[0], 1, 1))), dim=1)
    while prompt_condition.shape[1] < uc.shape[1]:
        prompt_condition = torch.cat(
            (
                prompt_condition,
                null_cond.repeat((prompt_condition.shape[0], 1, 1)),
            ),
            dim=1,
        )
    return prompt_condition, uc
# mix conditioning vectors for prompts # mix conditioning vectors for prompts
# @aero # @aero
def prompt_mixing(model, prompt_body, batch_size): def prompt_mixing(model, prompt_body, batch_size):
...@@ -61,6 +74,7 @@ def prompt_mixing(model, prompt_body, batch_size): ...@@ -61,6 +74,7 @@ def prompt_mixing(model, prompt_body, batch_size):
if prompt_sum is None: if prompt_sum is None:
prompt_sum = prompt_vector * prompt_power prompt_sum = prompt_vector * prompt_power
else: else:
prompt_sum, prompt_vector = fix_cond_shapes(model, prompt_sum, prompt_vector)
prompt_sum = prompt_sum + (prompt_vector * prompt_power) prompt_sum = prompt_sum + (prompt_vector * prompt_power)
prompt_total_power = prompt_total_power + prompt_power prompt_total_power = prompt_total_power + prompt_power
return fix_batch(prompt_sum / prompt_total_power, batch_size) return fix_batch(prompt_sum / prompt_total_power, batch_size)
...@@ -392,6 +406,7 @@ class StableDiffusionModel(nn.Module): ...@@ -392,6 +406,7 @@ class StableDiffusionModel(nn.Module):
uc = prompt_mixing(self.model, uc[0], request.n_samples) uc = prompt_mixing(self.model, uc[0], request.n_samples)
else: else:
uc = self.model.get_learned_conditioning(request.n_samples * [""]) uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [ shape = [
request.latent_channels, request.latent_channels,
...@@ -475,6 +490,7 @@ class StableDiffusionModel(nn.Module): ...@@ -475,6 +490,7 @@ class StableDiffusionModel(nn.Module):
uc = None uc = None
if request.scale != 1.0: if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""]) uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
shape = [ shape = [
request.latent_channels, request.latent_channels,
...@@ -515,12 +531,13 @@ class StableDiffusionModel(nn.Module): ...@@ -515,12 +531,13 @@ class StableDiffusionModel(nn.Module):
init_latent = init_latent + (torch.randn_like(init_latent) * request.noise) init_latent = init_latent + (torch.randn_like(init_latent) * request.noise)
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
uc = None uc = None
if request.scale != 1.0: if request.scale != 1.0:
uc = self.model.get_learned_conditioning(request.n_samples * [""]) uc = self.model.get_learned_conditioning(request.n_samples * [""])
prompt_condition, uc = fix_cond_shapes(self.model, prompt_condition, uc)
prompt_condition = prompt_mixing(self.model, prompt[0], request.n_samples)
# encode (scaled latent) # encode (scaled latent)
start_code_terped=None start_code_terped=None
z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped) z_enc = self.ddim.stochastic_encode(init_latent, torch.tensor([t_enc]*request.n_samples).to(self.device), noise=start_code_terped)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment