Commit d74c3810 authored by Jesse Williams's avatar Jesse Williams Committed by AUTOMATIC1111

Confirm that options are valid before starting

When using the 'Sampler' or 'Checkpoint' options, if one of the entered
names has a typo, an error will only be thrown once the `draw_xy_grid`
loop reaches that name. This can waste a lot of time for large grids
with a typo near the end of a list, since the script needs to start over
and re-generate any earlier images to finish making the grid.

Also fixing typo in variable name in `draw_xy_grid`.
parent 6f6798dd
...@@ -145,7 +145,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ...@@ -145,7 +145,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
ver_texts = [[images.GridAnnotation(y)] for y in y_labels] ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
first_pocessed = None first_processed = None
state.job_count = len(xs) * len(ys) * p.n_iter state.job_count = len(xs) * len(ys) * p.n_iter
...@@ -154,8 +154,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ...@@ -154,8 +154,8 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
processed = cell(x, y) processed = cell(x, y)
if first_pocessed is None: if first_processed is None:
first_pocessed = processed first_processed = processed
try: try:
res.append(processed.images[0]) res.append(processed.images[0])
...@@ -166,9 +166,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ...@@ -166,9 +166,9 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
if draw_legend: if draw_legend:
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
first_pocessed.images = [grid] first_processed.images = [grid]
return first_pocessed return first_processed
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
...@@ -216,7 +216,6 @@ class Script(scripts.Script): ...@@ -216,7 +216,6 @@ class Script(scripts.Script):
m = re_range.fullmatch(val) m = re_range.fullmatch(val)
mc = re_range_count.fullmatch(val) mc = re_range_count.fullmatch(val)
if m is not None: if m is not None:
start = int(m.group(1)) start = int(m.group(1))
end = int(m.group(2))+1 end = int(m.group(2))+1
step = int(m.group(3)) if m.group(3) is not None else 1 step = int(m.group(3)) if m.group(3) is not None else 1
...@@ -258,6 +257,16 @@ class Script(scripts.Script): ...@@ -258,6 +257,16 @@ class Script(scripts.Script):
valslist = list(permutations(valslist)) valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist] valslist = [opt.type(x) for x in valslist]
# Confirm options are valid before starting
if opt.label == "Sampler":
for sampler_val in valslist:
if sampler_val.lower() not in samplers_dict.keys():
raise RuntimeError(f"Unknown sampler: {sampler_val}")
elif opt.label == "Checkpoint name":
for ckpt_val in valslist:
if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
return valslist return valslist
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment