Commit ff8dc190 authored by Liam's avatar Liam

fixed token counter for prompt editing

parent abdbf1de
...@@ -11,6 +11,7 @@ import time ...@@ -11,6 +11,7 @@ import time
import traceback import traceback
import platform import platform
import subprocess as sp import subprocess as sp
from functools import reduce
import numpy as np import numpy as np
import torch import torch
...@@ -32,6 +33,7 @@ import modules.gfpgan_model ...@@ -32,6 +33,7 @@ import modules.gfpgan_model
import modules.codeformer_model import modules.codeformer_model
import modules.styles import modules.styles
import modules.generation_parameters_copypaste import modules.generation_parameters_copypaste
from modules.prompt_parser import get_learned_conditioning_prompt_schedules
# this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the bowser will not show any UI
mimetypes.init() mimetypes.init()
...@@ -345,8 +347,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: ...@@ -345,8 +347,11 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
outputs=[seed, dummy_component] outputs=[seed, dummy_component]
) )
def update_token_counter(text): def update_token_counter(text, steps):
tokens, token_count, max_length = model_hijack.tokenize(text) prompt_schedules = get_learned_conditioning_prompt_schedules([text], steps)
flat_prompts = reduce(lambda list1, list2: list1+list2, prompt_schedules)
prompts = [prompt_text for step,prompt_text in flat_prompts]
tokens, token_count, max_length = max([model_hijack.tokenize(prompt) for prompt in prompts], key=lambda args: args[1])
style_class = ' class="red"' if (token_count > max_length) else "" style_class = ' class="red"' if (token_count > max_length) else ""
return f"<span {style_class}>{token_count}/{max_length}</span>" return f"<span {style_class}>{token_count}/{max_length}</span>"
...@@ -364,8 +369,7 @@ def create_toprow(is_img2img): ...@@ -364,8 +369,7 @@ def create_toprow(is_img2img):
roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
paste = gr.Button(value=paste_symbol, elem_id="paste") paste = gr.Button(value=paste_symbol, elem_id="paste")
token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter") token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
with gr.Column(scale=10, elem_id="style_pos_col"): with gr.Column(scale=10, elem_id="style_pos_col"):
prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
...@@ -396,7 +400,7 @@ def create_toprow(is_img2img): ...@@ -396,7 +400,7 @@ def create_toprow(is_img2img):
prompt_style_apply = gr.Button('Apply style', elem_id="style_apply") prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create") save_style = gr.Button('Create style', elem_id="style_create")
return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, paste, token_counter, token_button
def setup_progressbar(progressbar, preview, id_part): def setup_progressbar(progressbar, preview, id_part):
...@@ -419,7 +423,7 @@ def setup_progressbar(progressbar, preview, id_part): ...@@ -419,7 +423,7 @@ def setup_progressbar(progressbar, preview, id_part):
def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
with gr.Blocks(analytics_enabled=False) as txt2img_interface: with gr.Blocks(analytics_enabled=False) as txt2img_interface:
txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste = create_toprow(is_img2img=False) txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=False)
dummy_component = gr.Label(visible=False) dummy_component = gr.Label(visible=False)
with gr.Row(elem_id='txt2img_progress_row'): with gr.Row(elem_id='txt2img_progress_row'):
...@@ -568,9 +572,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -568,9 +572,10 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt) modules.generation_parameters_copypaste.connect_paste(paste, txt2img_paste_fields, txt2img_prompt)
token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as img2img_interface: with gr.Blocks(analytics_enabled=False) as img2img_interface:
img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste = create_toprow(is_img2img=True) img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, paste, token_counter, token_button = create_toprow(is_img2img=True)
with gr.Row(elem_id='img2img_progress_row'): with gr.Row(elem_id='img2img_progress_row'):
with gr.Column(scale=1): with gr.Column(scale=1):
...@@ -793,6 +798,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger): ...@@ -793,6 +798,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
(denoising_strength, "Denoising strength"), (denoising_strength, "Denoising strength"),
] ]
modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt) modules.generation_parameters_copypaste.connect_paste(paste, img2img_paste_fields, img2img_prompt)
token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Blocks(analytics_enabled=False) as extras_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment