Commit d48f3470 authored by yfszzx's avatar yfszzx
parents 4a37c7ee 7c890336
...@@ -17,6 +17,7 @@ __pycache__ ...@@ -17,6 +17,7 @@ __pycache__
/webui.settings.bat /webui.settings.bat
/embeddings /embeddings
/styles.csv /styles.csv
/params.txt
/styles.csv.bak /styles.csv.bak
/webui-user.bat /webui-user.bat
/webui-user.sh /webui-user.sh
......
...@@ -94,7 +94,7 @@ contextMenuInit = function(){ ...@@ -94,7 +94,7 @@ contextMenuInit = function(){
} }
gradioApp().addEventListener("click", function(e) { gradioApp().addEventListener("click", function(e) {
let source = e.composedPath()[0] let source = e.composedPath()[0]
if(source.id && source.indexOf('check_progress')>-1){ if(source.id && source.id.indexOf('check_progress')>-1){
return return
} }
......
...@@ -14,7 +14,7 @@ titles = { ...@@ -14,7 +14,7 @@ titles = {
"\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time", "\u{1f3b2}\ufe0f": "Set seed to -1, which will cause a new random number to be used every time",
"\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed", "\u267b\ufe0f": "Reuse seed from last generation, mostly useful if it was randomed",
"\u{1f3a8}": "Add a random artist to the prompt.", "\u{1f3a8}": "Add a random artist to the prompt.",
"\u2199\ufe0f": "Read generation parameters from prompt into user interface.", "\u2199\ufe0f": "Read generation parameters from prompt or last generation if prompt is empty into user interface.",
"\u{1f4c2}": "Open images output directory", "\u{1f4c2}": "Open images output directory",
"Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
...@@ -84,6 +84,8 @@ titles = { ...@@ -84,6 +84,8 @@ titles = {
"Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.", "Filename word regex": "This regular expression will be used extract words from filename, and they will be joined using the option below into label text used for training. Leave empty to keep filename text as it is.",
"Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.", "Filename join string": "This string will be used to hoin split words into a single line if the option above is enabled.",
"Quicksettings list": "List of setting names, separated by commas, for settings that should go to the quick access bar at the top, rather than the usual setting tab. See modules/shared.py for setting names. Requires restarting to apply."
} }
......
...@@ -36,7 +36,7 @@ onUiUpdate(function(){ ...@@ -36,7 +36,7 @@ onUiUpdate(function(){
const notification = new Notification( const notification = new Notification(
'Stable Diffusion', 'Stable Diffusion',
{ {
body: `Generated ${imgs.size > 1 ? imgs.size - 1 : 1} image${imgs.size > 1 ? 's' : ''}`, body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`,
icon: headImg, icon: headImg,
image: headImg, image: headImg,
} }
......
...@@ -33,27 +33,27 @@ function args_to_array(args){ ...@@ -33,27 +33,27 @@ function args_to_array(args){
} }
function switch_to_txt2img(){ function switch_to_txt2img(){
gradioApp().querySelectorAll('button')[0].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[0].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_img2img_img2img(){ function switch_to_img2img_img2img(){
gradioApp().querySelectorAll('button')[1].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click(); gradioApp().getElementById('mode_img2img').querySelectorAll('button')[0].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_img2img_inpaint(){ function switch_to_img2img_inpaint(){
gradioApp().querySelectorAll('button')[1].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[1].click();
gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click(); gradioApp().getElementById('mode_img2img').querySelectorAll('button')[1].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
function switch_to_extras(){ function switch_to_extras(){
gradioApp().querySelectorAll('button')[2].click(); gradioApp().querySelector('#tabs').querySelectorAll('button')[2].click();
return args_to_array(arguments); return args_to_array(arguments);
} }
......
...@@ -19,6 +19,7 @@ def get_deepbooru_tags(pil_image): ...@@ -19,6 +19,7 @@ def get_deepbooru_tags(pil_image):
release_process() release_process()
OPT_INCLUDE_RANKS = "include_ranks"
def create_deepbooru_opts(): def create_deepbooru_opts():
from modules import shared from modules import shared
...@@ -26,6 +27,7 @@ def create_deepbooru_opts(): ...@@ -26,6 +27,7 @@ def create_deepbooru_opts():
"use_spaces": shared.opts.deepbooru_use_spaces, "use_spaces": shared.opts.deepbooru_use_spaces,
"use_escape": shared.opts.deepbooru_escape, "use_escape": shared.opts.deepbooru_escape,
"alpha_sort": shared.opts.deepbooru_sort_alpha, "alpha_sort": shared.opts.deepbooru_sort_alpha,
OPT_INCLUDE_RANKS: shared.opts.interrogate_return_ranks,
} }
...@@ -113,6 +115,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o ...@@ -113,6 +115,7 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
alpha_sort = deepbooru_opts['alpha_sort'] alpha_sort = deepbooru_opts['alpha_sort']
use_spaces = deepbooru_opts['use_spaces'] use_spaces = deepbooru_opts['use_spaces']
use_escape = deepbooru_opts['use_escape'] use_escape = deepbooru_opts['use_escape']
include_ranks = deepbooru_opts['include_ranks']
width = model.input_shape[2] width = model.input_shape[2]
height = model.input_shape[1] height = model.input_shape[1]
...@@ -151,19 +154,20 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o ...@@ -151,19 +154,20 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, deepbooru_o
if alpha_sort: if alpha_sort:
sort_ndx = 1 sort_ndx = 1
# sort by reverse by likelihood and normal for alpha # sort by reverse by likelihood and normal for alpha, and format tag text as requested
unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort))
for weight, tag in unsorted_tags_in_theshold: for weight, tag in unsorted_tags_in_theshold:
result_tags_out.append(tag) # note: tag_outformat will still have a colon if include_ranks is True
tag_outformat = tag.replace(':', ' ')
if use_spaces:
tag_outformat = tag_outformat.replace('_', ' ')
if use_escape:
tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
if include_ranks:
tag_outformat = f"({tag_outformat}:{weight:.3f})"
print('\n'.join(sorted(result_tags_print, reverse=True))) result_tags_out.append(tag_outformat)
tags_text = ', '.join(result_tags_out)
if use_spaces: print('\n'.join(sorted(result_tags_print, reverse=True)))
tags_text = tags_text.replace('_', ' ')
if use_escape:
tags_text = re.sub(re_special, r'\\\1', tags_text)
return tags_text.replace(':', ' ') return ', '.join(result_tags_out)
import os
import re import re
import gradio as gr import gradio as gr
from modules.shared import script_path
from modules import shared
re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param = re.compile(re_param_code) re_param = re.compile(re_param_code)
...@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model ...@@ -61,6 +64,12 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
def connect_paste(button, paste_fields, input_comp, js=None): def connect_paste(button, paste_fields, input_comp, js=None):
def paste_func(prompt): def paste_func(prompt):
if not prompt and not shared.cmd_opts.hide_ui_dir_config:
filename = os.path.join(script_path, "params.txt")
if os.path.exists(filename):
with open(filename, "r", encoding="utf8") as file:
prompt = file.read()
params = parse_generation_parameters(prompt) params = parse_generation_parameters(prompt)
res = [] res = []
......
...@@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler ...@@ -18,6 +18,8 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
class HypernetworkModule(torch.nn.Module): class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
def __init__(self, dim, state_dict=None): def __init__(self, dim, state_dict=None):
super().__init__() super().__init__()
...@@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module): ...@@ -36,7 +38,11 @@ class HypernetworkModule(torch.nn.Module):
self.to(devices.device) self.to(devices.device)
def forward(self, x): def forward(self, x):
return x + (self.linear2(self.linear1(x))) return x + (self.linear2(self.linear1(x))) * self.multiplier
def apply_strength(value=None):
HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength
class Hypernetwork: class Hypernetwork:
......
...@@ -123,7 +123,7 @@ class InterrogateModels: ...@@ -123,7 +123,7 @@ class InterrogateModels:
return caption[0] return caption[0]
def interrogate(self, pil_image): def interrogate(self, pil_image, include_ranks=False):
res = None res = None
try: try:
...@@ -156,7 +156,10 @@ class InterrogateModels: ...@@ -156,7 +156,10 @@ class InterrogateModels:
for name, topn, items in self.categories: for name, topn, items in self.categories:
matches = self.rank(image_features, items, top_count=topn) matches = self.rank(image_features, items, top_count=topn)
for match, score in matches: for match, score in matches:
res += ", " + match if include_ranks:
res += ", " + match
else:
res += f", ({match}:{score})"
except Exception: except Exception:
print(f"Error interrogating", file=sys.stderr) print(f"Error interrogating", file=sys.stderr)
......
...@@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -324,6 +324,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
else: else:
assert p.prompt is not None assert p.prompt is not None
with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
devices.torch_gc() devices.torch_gc()
seed = get_fixed_seed(p.seed) seed = get_fixed_seed(p.seed)
......
...@@ -13,7 +13,7 @@ import modules.memmon ...@@ -13,7 +13,7 @@ import modules.memmon
import modules.sd_models import modules.sd_models
import modules.styles import modules.styles
import modules.devices as devices import modules.devices as devices
from modules import sd_samplers from modules import sd_samplers, sd_models
from modules.hypernetworks import hypernetwork from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path from modules.paths import models_path, script_path, sd_path
...@@ -145,14 +145,14 @@ def realesrgan_models_names(): ...@@ -145,14 +145,14 @@ def realesrgan_models_names():
class OptionInfo: class OptionInfo:
def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False): def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, show_on_main_page=False, refresh=None):
self.default = default self.default = default
self.label = label self.label = label
self.component = component self.component = component
self.component_args = component_args self.component_args = component_args
self.onchange = onchange self.onchange = onchange
self.section = None self.section = None
self.show_on_main_page = show_on_main_page self.refresh = refresh
def options_section(section_identifier, options_dict): def options_section(section_identifier, options_dict):
...@@ -237,8 +237,9 @@ options_templates.update(options_section(('training', "Training"), { ...@@ -237,8 +237,9 @@ options_templates.update(options_section(('training', "Training"), {
})) }))
options_templates.update(options_section(('sd', "Stable Diffusion"), { options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, show_on_main_page=True), "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models),
"sd_hypernetwork": OptionInfo("None", "Stable Diffusion finetune hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
"img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
"save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
"img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."), "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
...@@ -250,14 +251,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { ...@@ -250,14 +251,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
"filter_nsfw": OptionInfo(False, "Filter NSFW content"), "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), 'CLIP_stop_at_last_layers': OptionInfo(1, "Stop At last layers of CLIP model", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
"random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}), "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
'quicksettings': OptionInfo("sd_model_checkpoint", "Quicksettings list"),
})) }))
options_templates.update(options_section(('interrogate', "Interrogate Options"), { options_templates.update(options_section(('interrogate', "Interrogate Options"), {
"interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"), "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
"interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"), "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
"interrogate_return_ranks": OptionInfo(False, "Interrogate: include ranks of model tags matches in results (Has no effect on caption-based interrogators)."),
"interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
"interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
"interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
"interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
"interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
"deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"), "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
"deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"), "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),
...@@ -345,6 +349,8 @@ class Options: ...@@ -345,6 +349,8 @@ class Options:
item = self.data_labels.get(key) item = self.data_labels.get(key)
item.onchange = func item.onchange = func
func()
def dumpjson(self): def dumpjson(self):
d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()} d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
return json.dumps(d) return json.dumps(d)
......
...@@ -17,7 +17,9 @@ def preprocess(process_src, process_dst, process_width, process_height, process_ ...@@ -17,7 +17,9 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
shared.interrogator.load() shared.interrogator.load()
if process_caption_deepbooru: if process_caption_deepbooru:
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, deepbooru.create_deepbooru_opts()) db_opts = deepbooru.create_deepbooru_opts()
db_opts[deepbooru.OPT_INCLUDE_RANKS] = False
deepbooru.create_deepbooru_process(opts.interrogate_deepbooru_score_threshold, db_opts)
preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru) preprocess_work(process_src, process_dst, process_width, process_height, process_flip, process_split, process_caption, process_caption_deepbooru)
......
...@@ -79,6 +79,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️ ...@@ -79,6 +79,8 @@ reuse_symbol = '\u267b\ufe0f' # ♻️
art_symbol = '\U0001f3a8' # 🎨 art_symbol = '\U0001f3a8' # 🎨
paste_symbol = '\u2199\ufe0f' # ↙ paste_symbol = '\u2199\ufe0f' # ↙
folder_symbol = '\U0001f4c2' # 📂 folder_symbol = '\U0001f4c2' # 📂
refresh_symbol = '\U0001f504' # 🔄
def plaintext_to_html(text): def plaintext_to_html(text):
text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>" text = "<p>" + "<br>\n".join([f"{html.escape(x)}" for x in text.split('\n')]) + "</p>"
...@@ -1218,8 +1220,7 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1218,8 +1220,7 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[], outputs=[],
) )
def create_setting_component(key, is_quicksettings=False):
def create_setting_component(key):
def fun(): def fun():
return opts.data[key] if key in opts.data else opts.data_labels[key].default return opts.data[key] if key in opts.data else opts.data_labels[key].default
...@@ -1239,7 +1240,34 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1239,7 +1240,34 @@ def create_ui(wrap_gradio_gpu_call):
else: else:
raise Exception(f'bad options item type: {str(t)} for key {key}') raise Exception(f'bad options item type: {str(t)} for key {key}')
return comp(label=info.label, value=fun, **(args or {})) if info.refresh is not None:
if is_quicksettings:
res = comp(label=info.label, value=fun, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_"+key)
else:
with gr.Row(variant="compact"):
res = comp(label=info.label, value=fun, **(args or {}))
refresh_button = gr.Button(value=refresh_symbol, elem_id="refresh_" + key)
def refresh():
info.refresh()
refreshed_args = info.component_args() if callable(info.component_args) else info.component_args
for k, v in refreshed_args.items():
setattr(res, k, v)
return gr.update(**(refreshed_args or {}))
refresh_button.click(
fn=refresh,
inputs=[],
outputs=[res],
)
else:
res = comp(label=info.label, value=fun, **(args or {}))
return res
components = [] components = []
component_dict = {} component_dict = {}
...@@ -1313,6 +1341,9 @@ Requested path was: {f} ...@@ -1313,6 +1341,9 @@ Requested path was: {f}
settings_cols = 3 settings_cols = 3
items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols) items_per_col = int(len(opts.data_labels) * 0.9 / settings_cols)
quicksettings_names = [x.strip() for x in opts.quicksettings.split(",")]
quicksettings_names = set(x for x in quicksettings_names if x != 'quicksettings')
quicksettings_list = [] quicksettings_list = []
cols_displayed = 0 cols_displayed = 0
...@@ -1337,7 +1368,7 @@ Requested path was: {f} ...@@ -1337,7 +1368,7 @@ Requested path was: {f}
gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1])) gr.HTML(elem_id="settings_header_text_{}".format(item.section[0]), value='<h1 class="gr-button-lg">{}</h1>'.format(item.section[1]))
if item.show_on_main_page: if k in quicksettings_names:
quicksettings_list.append((i, k, item)) quicksettings_list.append((i, k, item))
components.append(dummy_component) components.append(dummy_component)
else: else:
...@@ -1346,7 +1377,11 @@ Requested path was: {f} ...@@ -1346,7 +1377,11 @@ Requested path was: {f}
components.append(component) components.append(component)
items_displayed += 1 items_displayed += 1
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") with gr.Row():
request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
request_notifications.click( request_notifications.click(
fn=lambda: None, fn=lambda: None,
inputs=[], inputs=[],
...@@ -1354,10 +1389,6 @@ Requested path was: {f} ...@@ -1354,10 +1389,6 @@ Requested path was: {f}
_js='function(){}' _js='function(){}'
) )
with gr.Row():
reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary')
restart_gradio = gr.Button(value='Restart Gradio and Refresh components (Custom Scripts, ui.py, js and css only)', variant='primary')
def reload_scripts(): def reload_scripts():
modules.scripts.reload_script_body_only() modules.scripts.reload_script_body_only()
...@@ -1372,7 +1403,6 @@ Requested path was: {f} ...@@ -1372,7 +1403,6 @@ Requested path was: {f}
shared.state.interrupt() shared.state.interrupt()
settings_interface.gradio_ref.do_restart = True settings_interface.gradio_ref.do_restart = True
restart_gradio.click( restart_gradio.click(
fn=request_restart, fn=request_restart,
inputs=[], inputs=[],
...@@ -1408,12 +1438,12 @@ Requested path was: {f} ...@@ -1408,12 +1438,12 @@ Requested path was: {f}
with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo:
with gr.Row(elem_id="quicksettings"): with gr.Row(elem_id="quicksettings"):
for i, k, item in quicksettings_list: for i, k, item in quicksettings_list:
component = create_setting_component(k) component = create_setting_component(k, is_quicksettings=True)
component_dict[k] = component component_dict[k] = component
settings_interface.gradio_ref = demo settings_interface.gradio_ref = demo
with gr.Tabs() as tabs: with gr.Tabs(elem_id="tabs") as tabs:
for interface, label, ifid in interfaces: for interface, label, ifid in interfaces:
with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid): with gr.TabItem(label, id=ifid, elem_id='tab_' + ifid):
interface.render() interface.render()
......
...@@ -120,15 +120,45 @@ class Script(scripts.Script): ...@@ -120,15 +120,45 @@ class Script(scripts.Script):
return is_img2img return is_img2img
def ui(self, is_img2img): def ui(self, is_img2img):
info = gr.Markdown('''
* `CFG Scale` should be 2 or lower.
''')
override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True)
override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True)
original_prompt = gr.Textbox(label="Original prompt", lines=1) original_prompt = gr.Textbox(label="Original prompt", lines=1)
original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1) original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1)
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True)
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True)
cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0) randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False) sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False)
return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment): return [
info,
override_sampler,
override_prompt, original_prompt, original_negative_prompt,
override_steps, st,
override_strength,
cfg, randomness, sigma_adjustment,
]
def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
# Override
if override_sampler:
p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
if override_prompt:
p.prompt = original_prompt
p.negative_prompt = original_negative_prompt
if override_steps:
p.steps = st
if override_strength:
p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
......
...@@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs): ...@@ -107,6 +107,10 @@ def apply_hypernetwork(p, x, xs):
hypernetwork.load_hypernetwork(name) hypernetwork.load_hypernetwork(name)
def apply_hypernetwork_strength(p, x, xs):
hypernetwork.apply_strength(x)
def confirm_hypernetworks(p, xs): def confirm_hypernetworks(p, xs):
for x in xs: for x in xs:
if x.lower() in ["", "none"]: if x.lower() in ["", "none"]:
...@@ -165,6 +169,7 @@ axis_options = [ ...@@ -165,6 +169,7 @@ axis_options = [
AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
...@@ -175,13 +180,17 @@ axis_options = [ ...@@ -175,13 +180,17 @@ axis_options = [
] ]
def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
res = []
ver_texts = [[images.GridAnnotation(y)] for y in y_labels] ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
first_processed = None # Temporary list of all the images that are generated to be populated into the grid.
# Will be filled with empty images for any individual step that fails to process properly
image_cache = []
processed_result = None
cell_mode = "P"
cell_size = (1,1)
state.job_count = len(xs) * len(ys) * p.n_iter state.job_count = len(xs) * len(ys) * p.n_iter
...@@ -189,22 +198,39 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): ...@@ -189,22 +198,39 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
for ix, x in enumerate(xs): for ix, x in enumerate(xs):
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
processed = cell(x, y) processed:Processed = cell(x, y)
if first_processed is None:
first_processed = processed
try: try:
res.append(processed.images[0]) # this dereference will throw an exception if the image was not processed
# (this happens in cases such as if the user stops the process from the UI)
processed_image = processed.images[0]
if processed_result is None:
# Use our first valid processed result as a template container to hold our full results
processed_result = copy(processed)
cell_mode = processed_image.mode
cell_size = processed_image.size
processed_result.images = [Image.new(cell_mode, cell_size)]
image_cache.append(processed_image)
if include_lone_images:
processed_result.images.append(processed_image)
processed_result.all_prompts.append(processed.prompt)
processed_result.all_seeds.append(processed.seed)
processed_result.infotexts.append(processed.infotexts[0])
except: except:
res.append(Image.new(res[0].mode, res[0].size)) image_cache.append(Image.new(cell_mode, cell_size))
if not processed_result:
print("Unexpected error: draw_xy_grid failed to return even a single processed image")
return Processed()
grid = images.image_grid(res, rows=len(ys)) grid = images.image_grid(image_cache, rows=len(ys))
if draw_legend: if draw_legend:
grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts) grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
first_processed.images = [grid] processed_result.images[0] = grid
return first_processed return processed_result
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
...@@ -229,11 +255,12 @@ class Script(scripts.Script): ...@@ -229,11 +255,12 @@ class Script(scripts.Script):
y_values = gr.Textbox(label="Y values", visible=False, lines=1) y_values = gr.Textbox(label="Y values", visible=False, lines=1)
draw_legend = gr.Checkbox(label='Draw legend', value=True) draw_legend = gr.Checkbox(label='Draw legend', value=True)
include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds] return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds): def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
if not no_fixed_seeds: if not no_fixed_seeds:
modules.processing.fix_seed(p) modules.processing.fix_seed(p)
...@@ -344,7 +371,8 @@ class Script(scripts.Script): ...@@ -344,7 +371,8 @@ class Script(scripts.Script):
x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
cell=cell, cell=cell,
draw_legend=draw_legend draw_legend=draw_legend,
include_lone_images=include_lone_images
) )
if opts.grid_save: if opts.grid_save:
...@@ -354,6 +382,8 @@ class Script(scripts.Script): ...@@ -354,6 +382,8 @@ class Script(scripts.Script):
modules.sd_models.reload_model_weights(shared.sd_model) modules.sd_models.reload_model_weights(shared.sd_model)
hypernetwork.load_hypernetwork(opts.sd_hypernetwork) hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
hypernetwork.apply_strength()
opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers opts.data["CLIP_stop_at_last_layers"] = CLIP_stop_at_last_layers
......
...@@ -228,6 +228,8 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s ...@@ -228,6 +228,8 @@ fieldset span.text-gray-500, .gr-block.gr-box span.text-gray-500, label.block s
border-top: 1px solid #eee; border-top: 1px solid #eee;
border-left: 1px solid #eee; border-left: 1px solid #eee;
border-right: 1px solid #eee; border-right: 1px solid #eee;
z-index: 300;
} }
.dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{ .dark fieldset span.text-gray-500, .dark .gr-block.gr-box span.text-gray-500, .dark label.block span{
...@@ -480,16 +482,30 @@ input[type="range"]{ ...@@ -480,16 +482,30 @@ input[type="range"]{
background: #a55000; background: #a55000;
} }
#quicksettings {
gap: 0.4em;
}
#quicksettings > div{ #quicksettings > div{
border: none; border: none;
background: none; background: none;
flex: unset;
gap: 0.5em;
} }
#quicksettings > div > div{ #quicksettings > div > div{
max-width: 32em; max-width: 32em;
min-width: 24em;
padding: 0; padding: 0;
} }
#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork{
max-width: 2.5em;
min-width: 2.5em;
height: 2.4em;
}
canvas[key="mask"] { canvas[key="mask"] {
z-index: 12 !important; z-index: 12 !important;
filter: invert(); filter: invert();
...@@ -506,3 +522,10 @@ canvas[key="mask"] { ...@@ -506,3 +522,10 @@ canvas[key="mask"] {
z-index: 200; z-index: 200;
width: 8em; width: 8em;
} }
#quicksettings .gr-box > div > div > input.gr-text-input {
top: -1.12em;
}
.row.gr-compact{
overflow: visible;
}
...@@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): ...@@ -72,7 +72,6 @@ def wrap_gradio_gpu_call(func, extra_outputs=None):
return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs)
def initialize(): def initialize():
modelloader.cleanup_models() modelloader.cleanup_models()
modules.sd_models.setup_model() modules.sd_models.setup_model()
...@@ -86,6 +85,7 @@ def initialize(): ...@@ -86,6 +85,7 @@ def initialize():
shared.sd_model = modules.sd_models.load_model() shared.sd_model = modules.sd_models.load_model()
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
def webui(): def webui():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment