Commit 0027ce1f authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #12457 from rubberbaron/shared-hires-prompt-test

prompt editing timeline has separate range for first pass and hires-fix pass
parents 06f18186 99ab3d43
...@@ -407,12 +407,14 @@ class StableDiffusionProcessing: ...@@ -407,12 +407,14 @@ class StableDiffusionProcessing:
self.main_prompt = self.all_prompts[0] self.main_prompt = self.all_prompts[0]
self.main_negative_prompt = self.all_negative_prompts[0] self.main_negative_prompt = self.all_negative_prompts[0]
def cached_params(self, required_prompts, steps, extra_network_data): def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False):
"""Returns parameters that invalidate the cond cache if changed""" """Returns parameters that invalidate the cond cache if changed"""
return ( return (
required_prompts, required_prompts,
steps, steps,
hires_steps,
use_old_scheduling,
opts.CLIP_stop_at_last_layers, opts.CLIP_stop_at_last_layers,
shared.sd_model.sd_checkpoint_info, shared.sd_model.sd_checkpoint_info,
extra_network_data, extra_network_data,
...@@ -422,7 +424,7 @@ class StableDiffusionProcessing: ...@@ -422,7 +424,7 @@ class StableDiffusionProcessing:
self.height, self.height,
) )
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data, hires_steps=None):
""" """
Returns the result of calling function(shared.sd_model, required_prompts, steps) Returns the result of calling function(shared.sd_model, required_prompts, steps)
using a cache to store the result if the same arguments have been used before. using a cache to store the result if the same arguments have been used before.
...@@ -435,7 +437,7 @@ class StableDiffusionProcessing: ...@@ -435,7 +437,7 @@ class StableDiffusionProcessing:
caches is a list with items described above. caches is a list with items described above.
""" """
cached_params = self.cached_params(required_prompts, steps, extra_network_data) cached_params = self.cached_params(required_prompts, steps, extra_network_data, hires_steps, shared.opts.use_old_scheduling)
for cache in caches: for cache in caches:
if cache[0] is not None and cached_params == cache[0]: if cache[0] is not None and cached_params == cache[0]:
...@@ -444,7 +446,7 @@ class StableDiffusionProcessing: ...@@ -444,7 +446,7 @@ class StableDiffusionProcessing:
cache = caches[0] cache = caches[0]
with devices.autocast(): with devices.autocast():
cache[1] = function(shared.sd_model, required_prompts, steps) cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling)
cache[0] = cached_params cache[0] = cached_params
return cache[1] return cache[1]
...@@ -456,6 +458,8 @@ class StableDiffusionProcessing: ...@@ -456,6 +458,8 @@ class StableDiffusionProcessing:
sampler_config = sd_samplers.find_sampler_config(self.sampler_name) sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps
self.step_multiplier = total_steps // self.steps self.step_multiplier = total_steps // self.steps
self.firstpass_steps = total_steps
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data) self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data)
...@@ -1292,8 +1296,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): ...@@ -1292,8 +1296,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
steps = self.hr_second_pass_steps or self.steps steps = self.hr_second_pass_steps or self.steps
total_steps = sampler_config.total_steps(steps) if sampler_config else steps total_steps = sampler_config.total_steps(steps) if sampler_config else steps
self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps)
self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps)
def setup_conds(self): def setup_conds(self):
if self.is_hr_pass: if self.is_hr_pass:
......
...@@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/ ...@@ -26,7 +26,7 @@ plain: /([^\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER %import common.SIGNED_NUMBER -> NUMBER
""") """)
def get_learned_conditioning_prompt_schedules(prompts, steps): def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False):
""" """
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]
>>> g("test") >>> g("test")
...@@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ...@@ -57,18 +57,39 @@ def get_learned_conditioning_prompt_schedules(prompts, steps):
[[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g("[fe|||]male") >>> g("[fe|||]male")
[[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
>>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0]
>>> g("a [b:.5] c")
[[10, 'a b c']]
>>> g("a [b:1.5] c")
[[5, 'a c'], [10, 'a b c']]
""" """
if hires_steps is None or use_old_scheduling:
int_offset = 0
flt_offset = 0
steps = base_steps
else:
int_offset = base_steps
flt_offset = 1.0
steps = hires_steps
def collect_steps(steps, tree): def collect_steps(steps, tree):
res = [steps] res = [steps]
class CollectSteps(lark.Visitor): class CollectSteps(lark.Visitor):
def scheduled(self, tree): def scheduled(self, tree):
tree.children[-2] = float(tree.children[-2]) s = tree.children[-2]
if tree.children[-2] < 1: v = float(s)
tree.children[-2] *= steps if use_old_scheduling:
tree.children[-2] = min(steps, int(tree.children[-2])) v = v*steps if v<1 else v
res.append(tree.children[-2]) else:
if "." in s:
v = (v - flt_offset) * steps
else:
v = (v - int_offset)
tree.children[-2] = min(steps, int(v))
if tree.children[-2] >= 1:
res.append(tree.children[-2])
def alternate(self, tree): def alternate(self, tree):
res.extend(range(1, steps+1)) res.extend(range(1, steps+1))
...@@ -134,7 +155,7 @@ class SdConditioning(list): ...@@ -134,7 +155,7 @@ class SdConditioning(list):
def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False):
"""converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),
and the sampling step at which this condition is to be replaced by the next one. and the sampling step at which this condition is to be replaced by the next one.
...@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): ...@@ -154,7 +175,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
""" """
res = [] res = []
prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling)
cache = {} cache = {}
for prompt, prompt_schedule in zip(prompts, prompt_schedules): for prompt, prompt_schedule in zip(prompts, prompt_schedules):
...@@ -229,7 +250,7 @@ class MulticondLearnedConditioning: ...@@ -229,7 +250,7 @@ class MulticondLearnedConditioning:
self.batch: List[List[ComposableScheduledPromptConditioning]] = batch self.batch: List[List[ComposableScheduledPromptConditioning]] = batch
def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: def get_multicond_learned_conditioning(model, prompts, steps, hires_steps=None, use_old_scheduling=False) -> MulticondLearnedConditioning:
"""same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.
For each prompt, the list is obtained by splitting the prompt using the AND separator. For each prompt, the list is obtained by splitting the prompt using the AND separator.
...@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne ...@@ -238,7 +259,7 @@ def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearne
res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts)
learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps) learned_conditioning = get_learned_conditioning(model, prompt_flat_list, steps, hires_steps, use_old_scheduling)
res = [] res = []
for indexes in res_indexes: for indexes in res_indexes:
......
...@@ -203,6 +203,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { ...@@ -203,6 +203,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), {
"use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."),
"dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
"hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
"use_old_scheduling": OptionInfo(False, "Use old prompt where first pass and hires both used the same timeline, and < 1 meant relative and >= 1 meant absolute"),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate"), { options_templates.update(options_section(('interrogate', "Interrogate"), {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment