Commit b08500ce authored by AUTOMATIC's avatar AUTOMATIC

Merge branch 'release_candidate'

parents 5ab7f213 231562ea
## 1.2.0
### Features:
* do not wait for stable diffusion model to load at startup
* add filename patterns: [denoising]
* directory hiding for extra networks: dirs starting with . will hide their cards on extra network tabs unless specifically searched for
* Lora: for the `<...>` text in prompt, use name of Lora that is in the metdata of the file, if present, instead of filename (both can be used to activate lora)
* Lora: read infotext params from kohya-ss's extension parameters if they are present and if his extension is not active
* Lora: Fix some Loras not working (ones that have 3x3 convolution layer)
* Lora: add an option to use old method of applying loras (producing same results as with kohya-ss)
* add version to infotext, footer and console output when starting
* add links to wiki for filename pattern settings
* add extended info for quicksettings setting and use multiselect input instead of a text field
### Minor:
* gradio bumped to 3.29.0
* torch bumped to 2.0.1
* --subpath option for gradio for use with reverse proxy
* linux/OSX: use existing virtualenv if already active (the VIRTUAL_ENV environment variable)
* possible frontend optimization: do not apply localizations if there are none
* Add extra `None` option for VAE in XYZ plot
* print error to console when batch processing in img2img fails
* create HTML for extra network pages only on demand
* allow directories starting with . to still list their models for lora, checkpoints, etc
* put infotext options into their own category in settings tab
* do not show licenses page when user selects Show all pages in settings
### Extensions:
* Tooltip localization support
* Add api method to get LoRA models with prompt
### Bug Fixes:
* re-add /docs endpoint
* fix gamepad navigation
* make the lightbox fullscreen image function properly
* fix squished thumbnails in extras tab
* keep "search" filter for extra networks when user refreshes the tab (previously it showed everthing after you refreshed)
* fix webui showing the same image if you configure the generation to always save results into same file
* fix bug with upscalers not working properly
* Fix MPS on PyTorch 2.0.1, Intel Macs
* make it so that custom context menu from contextMenu.js only disappears after user's click, ignoring non-user click events
* prevent Reload UI button/link from reloading the page when it's not yet ready
* fix prompts from file script failing to read contents from a drag/drop file
## 1.1.1
### Bug Fixes:
* fix an error that prevents running webui on torch<2.0 without --disable-safe-unpickle
......
from modules import extra_networks, shared
import lora
class ExtraNetworkLora(extra_networks.ExtraNetwork):
def __init__(self):
super().__init__('lora')
......
......@@ -4,7 +4,7 @@ import re
import torch
from typing import Union
from modules import shared, devices, sd_models, errors
from modules import shared, devices, sd_models, errors, scripts
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
......@@ -93,6 +93,7 @@ class LoraOnDisk:
self.metadata = m
self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text
self.alias = self.metadata.get('ss_output_name', self.name)
class LoraModule:
......@@ -165,8 +166,10 @@ def load_lora(name, filename):
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.MultiheadAttention:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d:
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
else:
print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
continue
......@@ -199,11 +202,11 @@ def load_loras(names, multipliers=None):
loaded_loras.clear()
loras_on_disk = [available_loras.get(name, None) for name in names]
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
if any([x is None for x in loras_on_disk]):
list_available_loras()
loras_on_disk = [available_loras.get(name, None) for name in names]
loras_on_disk = [available_lora_aliases.get(name, None) for name in names]
for i, name in enumerate(names):
lora = already_loaded.get(name, None)
......@@ -232,6 +235,8 @@ def lora_calc_updown(lora, module, target):
if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
else:
updown = up @ down
......@@ -240,6 +245,19 @@ def lora_calc_updown(lora, module, target):
return updown
def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
return
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
"""
Applies the currently selected set of Loras to the weights of torch layer self.
......@@ -264,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
self.lora_weights_backup = weights_backup
if current_names != wanted_names:
if weights_backup is not None:
if isinstance(self, torch.nn.MultiheadAttention):
self.in_proj_weight.copy_(weights_backup[0])
self.out_proj.weight.copy_(weights_backup[1])
else:
self.weight.copy_(weights_backup)
lora_restore_weights_from_backup(self)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
......@@ -300,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
setattr(self, "lora_current_names", wanted_names)
def lora_forward(module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_loras) == 0:
return original_forward(module, input)
input = devices.cond_cast_unet(input)
lora_restore_weights_from_backup(module)
lora_reset_cached_weight(module)
res = original_forward(module, input)
lora_layer_name = getattr(module, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is None:
continue
module.up.to(device=devices.device)
module.down.to(device=devices.device)
res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res
def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
setattr(self, "lora_current_names", ())
setattr(self, "lora_weights_backup", None)
def lora_Linear_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
......@@ -318,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
def lora_Conv2d_forward(self, input):
if shared.opts.lora_functional:
return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
lora_apply_weights(self)
return torch.nn.Conv2d_forward_before_lora(self, input)
......@@ -343,24 +392,59 @@ def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
def list_available_loras():
available_loras.clear()
available_lora_aliases.clear()
os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)
candidates = \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.pt'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \
glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True)
candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
for filename in sorted(candidates, key=str.lower):
if os.path.isdir(filename):
continue
name = os.path.splitext(os.path.basename(filename))[0]
entry = LoraOnDisk(name, filename)
available_loras[name] = entry
available_lora_aliases[name] = entry
available_lora_aliases[entry.alias] = entry
re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
return # if the other extension is active, it will handle those fields, no need to do anything
added = []
for k, v in params.items():
if not k.startswith("AddNet Model "):
continue
num = k[13:]
if params.get("AddNet Module " + num) != "LoRA":
continue
name = params.get("AddNet Model " + num)
if name is None:
continue
m = re_lora_name.match(name)
if m:
name = m.group(1)
multiplier = params.get("AddNet Weight A " + num, "1.0")
available_loras[name] = LoraOnDisk(name, filename)
added.append(f"<lora:{name}:{multiplier}>")
if added:
params["Prompt"] += "\n" + "".join(added)
available_loras = {}
available_lora_aliases = {}
loaded_loras = []
list_available_loras()
import torch
import gradio as gr
from fastapi import FastAPI
import lora
import extra_networks_lora
import ui_extra_networks_lora
from modules import script_callbacks, ui_extra_networks, extra_networks, shared
def unload():
torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
......@@ -49,8 +49,33 @@ torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
script_callbacks.on_before_ui(before_ui)
script_callbacks.on_infotext_pasted(lora.infotext_pasted)
shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
"sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in lora.available_loras]}, refresh=lora.list_available_loras),
}))
shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
"lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
}))
def create_lora_json(obj: lora.LoraOnDisk):
return {
"name": obj.name,
"alias": obj.alias,
"path": obj.filename,
"metadata": obj.metadata,
}
def api_loras(_: gr.Blocks, app: FastAPI):
@app.get("/sdapi/v1/loras")
async def get_loras():
return [create_lora_json(obj) for obj in lora.available_loras.values()]
script_callbacks.on_app_started(api_loras)
......@@ -21,7 +21,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
"preview": self.find_preview(path),
"description": self.find_description(path),
"search_term": self.search_terms_from_path(lora_on_disk.filename),
"prompt": json.dumps(f"<lora:{name}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"prompt": json.dumps(f"<lora:{lora_on_disk.alias}:") + " + opts.extra_networks_default_multiplier + " + json.dumps(">"),
"local_preview": f"{path}.{shared.opts.samples_format}",
"metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None,
}
......
......@@ -6,7 +6,7 @@
<ul>
<a href="#" title="replace preview image with currently selected in gallery" onclick={save_card_preview}>replace preview</a>
</ul>
<span style="display:none" class='search_term'>{search_term}</span>
<span style="display:none" class='search_term{serach_only}'>{search_term}</span>
</div>
<span class='name'>{name}</span>
<span class='description'>{description}</span>
......
......@@ -45,29 +45,24 @@ function dimensionChange(e, is_width, is_height){
var viewportOffset = targetElement.getBoundingClientRect();
viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
var viewportscale = Math.min( targetElement.clientWidth/targetElement.naturalWidth, targetElement.clientHeight/targetElement.naturalHeight )
scaledx = targetElement.naturalWidth*viewportscale
scaledy = targetElement.naturalHeight*viewportscale
var scaledx = targetElement.naturalWidth*viewportscale
var scaledy = targetElement.naturalHeight*viewportscale
cleintRectTop = (viewportOffset.top+window.scrollY)
cleintRectLeft = (viewportOffset.left+window.scrollX)
cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
var cleintRectTop = (viewportOffset.top+window.scrollY)
var cleintRectLeft = (viewportOffset.left+window.scrollX)
var cleintRectCentreY = cleintRectTop + (targetElement.clientHeight/2)
var cleintRectCentreX = cleintRectLeft + (targetElement.clientWidth/2)
viewRectTop = cleintRectCentreY-(scaledy/2)
viewRectLeft = cleintRectCentreX-(scaledx/2)
arRectWidth = scaledx
arRectHeight = scaledy
var arscale = Math.min( scaledx/currentWidth, scaledy/currentHeight )
var arscaledx = currentWidth*arscale
var arscaledy = currentHeight*arscale
arscale = Math.min( arRectWidth/currentWidth, arRectHeight/currentHeight )
arscaledx = currentWidth*arscale
arscaledy = currentHeight*arscale
arRectTop = cleintRectCentreY-(arscaledy/2)
arRectLeft = cleintRectCentreX-(arscaledx/2)
arRectWidth = arscaledx
arRectHeight = arscaledy
var arRectTop = cleintRectCentreY-(arscaledy/2)
var arRectLeft = cleintRectCentreX-(arscaledx/2)
var arRectWidth = arscaledx
var arRectHeight = arscaledy
arPreviewRect.style.top = arRectTop+'px';
arPreviewRect.style.left = arRectLeft+'px';
......
......@@ -4,7 +4,7 @@ contextMenuInit = function(){
let menuSpecs = new Map();
const uid = function(){
return Date.now().toString(36) + Math.random().toString(36).substr(2);
return Date.now().toString(36) + Math.random().toString(36).substring(2);
}
function showContextMenu(event,element,menuEntries){
......@@ -16,8 +16,7 @@ contextMenuInit = function(){
oldMenu.remove()
}
let tabButton = uiCurrentTab
let baseStyle = window.getComputedStyle(tabButton)
let baseStyle = window.getComputedStyle(uiCurrentTab)
const contextMenu = document.createElement('nav')
contextMenu.id = "context-menu"
......@@ -36,7 +35,7 @@ contextMenuInit = function(){
menuEntries.forEach(function(entry){
let contextMenuEntry = document.createElement('a')
contextMenuEntry.innerHTML = entry['name']
contextMenuEntry.addEventListener("click", function(e) {
contextMenuEntry.addEventListener("click", function() {
entry['func']();
})
contextMenuList.append(contextMenuEntry);
......@@ -63,7 +62,7 @@ contextMenuInit = function(){
function appendContextMenuOption(targetElementSelector,entryName,entryFunction){
currentItems = menuSpecs.get(targetElementSelector)
var currentItems = menuSpecs.get(targetElementSelector)
if(!currentItems){
currentItems = []
......@@ -79,7 +78,7 @@ contextMenuInit = function(){
}
function removeContextMenuOption(uid){
menuSpecs.forEach(function(v,k) {
menuSpecs.forEach(function(v) {
let index = -1
v.forEach(function(e,ei){if(e['id']==uid){index=ei}})
if(index>=0){
......@@ -93,8 +92,7 @@ contextMenuInit = function(){
return;
}
gradioApp().addEventListener("click", function(e) {
let source = e.composedPath()[0]
if(source.id && source.id.indexOf('check_progress')>-1){
if(! e.isTrusted){
return
}
......@@ -112,7 +110,6 @@ contextMenuInit = function(){
if(e.composedPath()[0].matches(k)){
showContextMenu(e,e.composedPath()[0],v)
e.preventDefault()
return
}
})
});
......
......@@ -69,8 +69,8 @@ function keyupEditAttention(event){
event.preventDefault();
closeCharacter = ')'
delta = opts.keyedit_precision_attention
var closeCharacter = ')'
var delta = opts.keyedit_precision_attention
if (selectionStart > 0 && text[selectionStart - 1] == '<'){
closeCharacter = '>'
......@@ -91,8 +91,8 @@ function keyupEditAttention(event){
selectionEnd += 1;
}
end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
var end = text.slice(selectionEnd + 1).indexOf(closeCharacter) + 1;
var weight = parseFloat(text.slice(selectionEnd + 1, selectionEnd + 1 + end));
if (isNaN(weight)) return;
weight += isPlus ? delta : -delta;
......
function extensions_apply(_, _, disable_all){
function extensions_apply(_disabled_list, _update_list, disable_all){
var disable = []
var update = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7))
disable.push(x.name.substring(7))
if(x.name.startsWith("update_") && x.checked)
update.push(x.name.substr(7))
update.push(x.name.substring(7))
})
restart_reload()
......@@ -16,12 +16,12 @@ function extensions_apply(_, _, disable_all){
return [JSON.stringify(disable), JSON.stringify(update), disable_all]
}
function extensions_check(_, _){
function extensions_check(){
var disable = []
gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){
if(x.name.startsWith("enable_") && ! x.checked)
disable.push(x.name.substr(7))
disable.push(x.name.substring(7))
})
gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){
......@@ -41,7 +41,7 @@ function install_extension_from_index(button, url){
button.disabled = "disabled"
button.value = "Installing..."
textarea = gradioApp().querySelector('#extension_to_install textarea')
var textarea = gradioApp().querySelector('#extension_to_install textarea')
textarea.value = url
updateInput(textarea)
......
function setupExtraNetworksForTab(tabname){
gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks')
......@@ -10,16 +9,34 @@ function setupExtraNetworksForTab(tabname){
tabs.appendChild(search)
tabs.appendChild(refresh)
search.addEventListener("input", function(evt){
searchTerm = search.value.toLowerCase()
var applyFilter = function(){
var searchTerm = search.value.toLowerCase()
gradioApp().querySelectorAll('#'+tabname+'_extra_tabs div.card').forEach(function(elem){
text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
elem.style.display = text.indexOf(searchTerm) == -1 ? "none" : ""
var searchOnly = elem.querySelector('.search_only')
var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase()
var visible = text.indexOf(searchTerm) != -1
if(searchOnly && searchTerm.length < 4){
visible = false
}
elem.style.display = visible ? "" : "none"
})
});
}
search.addEventListener("input", applyFilter);
applyFilter();
extraNetworksApplyFilter[tabname] = applyFilter;
}
function applyExtraNetworkFilter(tabname){
setTimeout(extraNetworksApplyFilter[tabname], 1);
}
var extraNetworksApplyFilter = {}
var activePromptTextarea = {};
function setupExtraNetworks(){
......@@ -55,7 +72,7 @@ function tryToRemoveExtraNetworkFromPrompt(textarea, text){
var partToSearch = m[1]
var replaced = false
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found, index){
var newTextareaText = textarea.value.replaceAll(re_extranet_g, function(found){
m = found.match(re_extranet);
if(m[1] == partToSearch){
replaced = true;
......@@ -96,9 +113,9 @@ function saveCardPreview(event, tabname, filename){
}
function extraNetworksSearchButton(tabs_id, event){
searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
button = event.target
text = button.classList.contains("search-all") ? "" : button.textContent.trim()
var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea')
var button = event.target
var text = button.classList.contains("search-all") ? "" : button.textContent.trim()
searchTextarea.value = text
updateInput(searchTextarea)
......@@ -133,7 +150,7 @@ function popup(contents){
}
function extraNetworksShowMetadata(text){
elem = document.createElement('pre')
var elem = document.createElement('pre')
elem.classList.add('popup-metadata');
elem.textContent = text;
......@@ -165,7 +182,7 @@ function requestGet(url, data, handler, errorHandler){
}
function extraNetworksRequestMetadata(event, extraPage, cardName){
showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
var showError = function(){ extraNetworksShowMetadata("there was an error getting metadata"); }
requestGet("./sd_extra_networks/metadata", {"page": extraPage, "item": cardName}, function(data){
if(data && data.metadata){
......
......@@ -23,7 +23,7 @@ let modalObserver = new MutationObserver(function(mutations) {
});
function attachGalleryListeners(tab_name) {
gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
var gallery = gradioApp().querySelector('#'+tab_name+'_gallery')
gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name+"_generation_info_button").click());
gallery?.addEventListener('keydown', (e) => {
if (e.keyCode == 37 || e.keyCode == 39) // left or right arrow
......
......@@ -66,8 +66,8 @@ titles = {
"Interrogate": "Reconstruct prompt from existing image and put it into the prompt field.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Images filename pattern": "Use following tags to define how filenames for images are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Directory name pattern": "Use following tags to define how subdirectories for images and grids are chosen: [steps], [cfg], [denoising], [clip_skip], [batch_number], [generation_number], [prompt_hash], [prompt], [prompt_no_styles], [prompt_spaces], [width], [height], [styles], [sampler], [seed], [model_hash], [model_name], [prompt_words], [date], [datetime], [datetime<Format>], [datetime<Format><Time Zone>], [job_timestamp], [hasprompt<prompt1|default><prompt2>..]; leave empty for default.",
"Max prompt words": "Set the maximum number of words to be used in the [prompt_words] option; ATTENTION: If the words are too long, they may exceed the maximum length of the file path that the system can handle",
"Loopback": "Performs img2img processing multiple times. Output images are used as input for the next loop.",
......@@ -118,16 +118,18 @@ titles = {
onUiUpdate(function(){
gradioApp().querySelectorAll('span, button, select, p').forEach(function(span){
tooltip = titles[span.textContent];
if (span.title) return; // already has a title
if(!tooltip){
tooltip = titles[span.value];
let tooltip = localization[titles[span.textContent]] || titles[span.textContent];
if(!tooltip){
tooltip = localization[titles[span.value]] || titles[span.value];
}
if(!tooltip){
for (const c of span.classList) {
if (c in titles) {
tooltip = titles[c];
tooltip = localization[titles[c]] || titles[c];
break;
}
}
......@@ -142,7 +144,7 @@ onUiUpdate(function(){
if (select.onchange != null) return;
select.onchange = function(){
select.title = titles[select.value] || "";
select.title = localization[titles[select.value]] || titles[select.value] || "";
}
})
})
function setInactive(elem, inactive){
if(inactive){
elem.classList.add('inactive')
} else{
elem.classList.remove('inactive')
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
function setInactive(elem, inactive){
elem.classList.toggle('inactive', !!inactive)
}
}
function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y){
hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale')
var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x')
var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y')
gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""
......
......@@ -2,11 +2,10 @@
* temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668
* @see https://github.com/gradio-app/gradio/issues/1721
*/
window.addEventListener( 'resize', () => imageMaskResize());
function imageMaskResize() {
const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas');
if ( ! canvases.length ) {
canvases_fixed = false;
canvases_fixed = false; // TODO: this is unused..?
window.removeEventListener( 'resize', imageMaskResize );
return;
}
......@@ -15,7 +14,7 @@ function imageMaskResize() {
const previewImage = wrapper.previousElementSibling;
if ( ! previewImage.complete ) {
previewImage.addEventListener( 'load', () => imageMaskResize());
previewImage.addEventListener( 'load', imageMaskResize);
return;
}
......@@ -24,7 +23,6 @@ function imageMaskResize() {
const nw = previewImage.naturalWidth;
const nh = previewImage.naturalHeight;
const portrait = nh > nw;
const factor = portrait;
const wW = Math.min(w, portrait ? h/nh*nw : w/nw*nw);
const wH = Math.min(h, portrait ? h/nh*nh : w/nw*nh);
......@@ -40,6 +38,7 @@ function imageMaskResize() {
c.style.maxHeight = '100%';
c.style.objectFit = 'contain';
});
}
onUiUpdate(() => imageMaskResize());
}
onUiUpdate(imageMaskResize);
window.addEventListener( 'resize', imageMaskResize);
window.onload = (function(){
window.addEventListener('drop', e => {
const target = e.composedPath()[0];
const idx = selected_gallery_index();
if (target.placeholder.indexOf("Prompt") == -1) return;
let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image";
......
......@@ -57,7 +57,7 @@ function modalImageSwitch(offset) {
})
if (result != -1) {
nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
var nextButton = galleryButtons[negmod((result + offset), galleryButtons.length)]
nextButton.click()
const modalImage = gradioApp().getElementById("modalImage");
const modal = gradioApp().getElementById("lightboxModal");
......@@ -144,15 +144,11 @@ function setupImageForLightbox(e) {
}
function modalZoomSet(modalImage, enable) {
if (enable) {
modalImage.classList.add('modalImageFullscreen');
} else {
modalImage.classList.remove('modalImageFullscreen');
}
if(modalImage) modalImage.classList.toggle('modalImageFullscreen', !!enable);
}
function modalZoomToggle(event) {
modalImage = gradioApp().getElementById("modalImage");
var modalImage = gradioApp().getElementById("modalImage");
modalZoomSet(modalImage, !modalImage.classList.contains('modalImageFullscreen'))
event.stopPropagation()
}
......@@ -179,7 +175,7 @@ function galleryImageHandler(e) {
}
onUiUpdate(function() {
fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img')
if (fullImg_preview != null) {
fullImg_preview.forEach(setupImageForLightbox);
}
......
let delay = 350//ms
window.addEventListener('gamepadconnected', (e) => {
console.log("Gamepad connected!")
const gamepad = e.gamepad;
setInterval(() => {
const xValue = gamepad.axes[0].toFixed(2);
if (xValue < -0.3) {
modalPrevImage(e);
} else if (xValue > 0.3) {
modalNextImage(e);
}
}, delay);
});
/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
*/
let isScrolling = false;
window.addEventListener('wheel', (e) => {
if (isScrolling) return;
isScrolling = true;
if (e.deltaX <= -0.6) {
window.addEventListener('gamepadconnected', (e) => {
const index = e.gamepad.index;
let isWaiting = false;
setInterval(async () => {
if (!opts.js_modal_lightbox_gamepad || isWaiting) return;
const gamepad = navigator.getGamepads()[index];
const xValue = gamepad.axes[0];
if (xValue <= -0.3) {
modalPrevImage(e);
} else if (e.deltaX >= 0.6) {
isWaiting = true;
} else if (xValue >= 0.3) {
modalNextImage(e);
isWaiting = true;
}
if (isWaiting) {
await sleepUntil(() => {
const xValue = navigator.getGamepads()[index].axes[0]
if (xValue < 0.3 && xValue > -0.3) {
return true;
}
}, opts.js_modal_lightbox_gamepad_repeat);
isWaiting = false;
}
}, 10);
});
/*
Primarily for vr controller type pointer devices.
I use the wheel event because there's currently no way to do it properly with web xr.
*/
let isScrolling = false;
window.addEventListener('wheel', (e) => {
if (!opts.js_modal_lightbox_gamepad || isScrolling) return;
isScrolling = true;
if (e.deltaX <= -0.6) {
modalPrevImage(e);
} else if (e.deltaX >= 0.6) {
modalNextImage(e);
}
setTimeout(() => {
isScrolling = false;
}, delay);
});
\ No newline at end of file
setTimeout(() => {
isScrolling = false;
}, opts.js_modal_lightbox_gamepad_repeat);
});
function sleepUntil(f, timeout) {
return new Promise((resolve) => {
const timeStart = new Date();
const wait = setInterval(function() {
if (f() || new Date() - timeStart > timeout) {
clearInterval(wait);
resolve();
}
}, 20);
});
}
......@@ -25,6 +25,10 @@ re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u
original_lines = {}
translated_lines = {}
function hasLocalization() {
return window.localization && Object.keys(window.localization).length > 0;
}
function textNodesUnder(el){
var n, a=[], walk=document.createTreeWalker(el,NodeFilter.SHOW_TEXT,null,false);
while(n=walk.nextNode()) a.push(n);
......@@ -35,11 +39,11 @@ function canBeTranslated(node, text){
if(! text) return false;
if(! node.parentElement) return false;
parentType = node.parentElement.nodeName
var parentType = node.parentElement.nodeName
if(parentType=='SCRIPT' || parentType=='STYLE' || parentType=='TEXTAREA') return false;
if (parentType=='OPTION' || parentType=='SPAN'){
pnode = node
var pnode = node
for(var level=0; level<4; level++){
pnode = pnode.parentElement
if(! pnode) break;
......@@ -69,7 +73,7 @@ function getTranslation(text){
}
function processTextNode(node){
text = node.textContent.trim()
var text = node.textContent.trim()
if(! canBeTranslated(node, text)) return
......@@ -105,7 +109,7 @@ function processNode(node){
}
function dumpTranslations(){
dumped = {}
var dumped = {}
if (localization.rtl) {
dumped.rtl = true
}
......@@ -119,39 +123,8 @@ function dumpTranslations(){
return dumped
}
onUiUpdate(function(m){
m.forEach(function(mutation){
mutation.addedNodes.forEach(function(node){
processNode(node)
})
});
})
document.addEventListener("DOMContentLoaded", function() {
processNode(gradioApp())
if (localization.rtl) { // if the language is from right to left,
(new MutationObserver((mutations, observer) => { // wait for the style to load
mutations.forEach(mutation => {
mutation.addedNodes.forEach(node => {
if (node.tagName === 'STYLE') {
observer.disconnect();
for (const x of node.sheet.rules) { // find all rtl media rules
if (Array.from(x.media || []).includes('rtl')) {
x.media.appendMedium('all'); // enable them
}
}
}
})
});
})).observe(gradioApp(), { childList: true });
}
})
function download_localization() {
text = JSON.stringify(dumpTranslations(), null, 4)
var text = JSON.stringify(dumpTranslations(), null, 4)
var element = document.createElement('a');
element.setAttribute('href', 'data:text/plain;charset=utf-8,' + encodeURIComponent(text));
......@@ -163,3 +136,36 @@ function download_localization() {
document.body.removeChild(element);
}
if(hasLocalization()) {
onUiUpdate(function (m) {
m.forEach(function (mutation) {
mutation.addedNodes.forEach(function (node) {
processNode(node)
})
});
})
document.addEventListener("DOMContentLoaded", function () {
processNode(gradioApp())
if (localization.rtl) { // if the language is from right to left,
(new MutationObserver((mutations, observer) => { // wait for the style to load
mutations.forEach(mutation => {
mutation.addedNodes.forEach(node => {
if (node.tagName === 'STYLE') {
observer.disconnect();
for (const x of node.sheet.rules) { // find all rtl media rules
if (Array.from(x.media || []).includes('rtl')) {
x.media.appendMedium('all'); // enable them
}
}
}
})
});
})).observe(gradioApp(), { childList: true });
}
})
}
......@@ -2,15 +2,15 @@
let lastHeadImg = null;
notificationButton = null
let notificationButton = null;
onUiUpdate(function(){
if(notificationButton == null){
notificationButton = gradioApp().getElementById('request_notifications')
if(notificationButton != null){
notificationButton.addEventListener('click', function (evt) {
Notification.requestPermission();
notificationButton.addEventListener('click', () => {
void Notification.requestPermission();
},true);
}
}
......
// code related to showing and updating progressbar shown as the image is being made
function rememberGallerySelection(id_gallery){
function rememberGallerySelection(){
}
function getGallerySelectedIndex(id_gallery){
function getGallerySelectedIndex(){
}
function request(url, data, handler, errorHandler){
var xhr = new XMLHttpRequest();
var url = url;
xhr.open("POST", url, true);
xhr.setRequestHeader("Content-Type", "application/json");
xhr.onreadystatechange = function () {
......@@ -107,7 +106,7 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgre
divProgress.style.width = rect.width + "px";
}
progressText = ""
let progressText = ""
divInner.style.width = ((res.progress || 0) * 100.0) + '%'
divInner.style.background = res.progress ? "" : "transparent"
......
// various functions for interaction with ui.py not large enough to warrant putting them in separate files
function set_theme(theme){
gradioURL = window.location.href
var gradioURL = window.location.href
if (!gradioURL.includes('?__theme=')) {
window.location.replace(gradioURL + '?__theme=' + theme);
}
......@@ -47,7 +47,7 @@ function extract_image_from_gallery(gallery){
return [gallery[0]];
}
index = selected_gallery_index()
var index = selected_gallery_index()
if (index < 0 || index >= gallery.length){
// Use the first image in the gallery as the default
......@@ -58,7 +58,7 @@ function extract_image_from_gallery(gallery){
}
function args_to_array(args){
res = []
var res = []
for(var i=0;i<args.length;i++){
res.push(args[i])
}
......@@ -138,7 +138,7 @@ function get_img2img_tab_index() {
}
function create_submit_args(args){
res = []
var res = []
for(var i=0;i<args.length;i++){
res.push(args[i])
}
......@@ -160,7 +160,7 @@ function showSubmitButtons(tabname, show){
}
function showRestoreProgressButton(tabname, show){
button = gradioApp().getElementById(tabname + "_restore_progress")
var button = gradioApp().getElementById(tabname + "_restore_progress")
if(! button) return
button.style.display = show ? "flex" : "none"
......@@ -207,8 +207,9 @@ function submit_img2img(){
return res
}
function restoreProgressTxt2img(x){
function restoreProgressTxt2img(){
showRestoreProgressButton("txt2img", false)
var id = localStorage.getItem("txt2img_task_id")
id = localStorage.getItem("txt2img_task_id")
......@@ -220,10 +221,11 @@ function restoreProgressTxt2img(x){
return id
}
function restoreProgressImg2img(x){
showRestoreProgressButton("img2img", false)
id = localStorage.getItem("img2img_task_id")
function restoreProgressImg2img(){
showRestoreProgressButton("img2img", false)
var id = localStorage.getItem("img2img_task_id")
if(id) {
requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){
......@@ -252,7 +254,7 @@ function modelmerger(){
function ask_for_style_name(_, prompt_text, negative_prompt_text) {
name_ = prompt('Style name:')
var name_ = prompt('Style name:')
return [name_, prompt_text, negative_prompt_text]
}
......@@ -287,11 +289,11 @@ function recalculate_prompts_img2img(){
}
opts = {}
var opts = {}
onUiUpdate(function(){
if(Object.keys(opts).length != 0) return;
json_elem = gradioApp().getElementById('settings_json')
var json_elem = gradioApp().getElementById('settings_json')
if(json_elem == null) return;
var textarea = json_elem.querySelector('textarea')
......@@ -340,12 +342,15 @@ onUiUpdate(function(){
registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button')
registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button')
show_all_pages = gradioApp().getElementById('settings_show_all_pages')
settings_tabs = gradioApp().querySelector('#settings div')
var show_all_pages = gradioApp().getElementById('settings_show_all_pages')
var settings_tabs = gradioApp().querySelector('#settings div')
if(show_all_pages && settings_tabs){
settings_tabs.appendChild(show_all_pages)
show_all_pages.onclick = function(){
gradioApp().querySelectorAll('#settings > div').forEach(function(elem){
if(elem.id == "settings_tab_licenses")
return;
elem.style.display = "block";
})
}
......@@ -353,9 +358,9 @@ onUiUpdate(function(){
})
onOptionsChanged(function(){
elem = gradioApp().getElementById('sd_checkpoint_hash')
sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
shorthash = sd_checkpoint_hash.substr(0,10)
var elem = gradioApp().getElementById('sd_checkpoint_hash')
var sd_checkpoint_hash = opts.sd_checkpoint_hash || ""
var shorthash = sd_checkpoint_hash.substring(0,10)
if(elem && elem.textContent != shorthash){
elem.textContent = shorthash
......@@ -390,7 +395,16 @@ function update_token_counter(button_id) {
function restart_reload(){
document.body.innerHTML='<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>';
setTimeout(function(){location.reload()},2000)
var requestPing = function(){
requestGet("./internal/ping", {}, function(data){
location.reload();
}, function(){
setTimeout(requestPing, 500);
})
}
setTimeout(requestPing, 2000);
return []
}
......
// various hints and extra info for the settings tab
onUiLoaded(function(){
createLink = function(elem_id, text, href){
var a = document.createElement('A')
a.textContent = text
a.target = '_blank';
elem = gradioApp().querySelector('#'+elem_id)
elem.insertBefore(a, elem.querySelector('label'))
return a
}
createLink("setting_samples_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
createLink("setting_directories_filename_pattern", "[wiki] ").href = "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"
createLink("setting_quicksettings_list", "[info] ").addEventListener("click", function(event){
requestGet("./internal/quicksettings-hint", {}, function(data){
var table = document.createElement('table')
table.className = 'settings-value-table'
data.forEach(function(obj){
var tr = document.createElement('tr')
var td = document.createElement('td')
td.textContent = obj.name
tr.appendChild(td)
var td = document.createElement('td')
td.textContent = obj.label
tr.appendChild(td)
table.appendChild(tr)
})
popup(table);
})
});
})
......@@ -19,6 +19,7 @@ python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
stored_git_tag = None
dir_repos = "repositories"
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
......@@ -70,6 +71,20 @@ def commit_hash():
return stored_commit_hash
def git_tag():
global stored_git_tag
if stored_git_tag is not None:
return stored_git_tag
try:
stored_git_tag = run(f"{git} describe --tags").strip()
except Exception:
stored_git_tag = "<none>"
return stored_git_tag
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
if desc is not None:
print(desc)
......@@ -222,7 +237,7 @@ def run_extensions_installers(settings_file):
def prepare_environment():
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --extra-index-url https://download.pytorch.org/whl/cu118")
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118")
requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17')
......@@ -246,8 +261,10 @@ def prepare_environment():
check_python_version()
commit = commit_hash()
tag = git_tag()
print(f"Python {sys.version}")
print(f"Version: {tag}")
print(f"Commit hash: {commit}")
if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
......
......@@ -570,20 +570,20 @@ class Api:
filename = create_embedding(**args) # create empty embedding
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() # reload embeddings so new one can be immediately used
shared.state.end()
return CreateResponse(info = "create embedding filename: {filename}".format(filename = filename))
return CreateResponse(info=f"create embedding filename: {filename}")
except AssertionError as e:
shared.state.end()
return TrainResponse(info = "create embedding error: {error}".format(error = e))
return TrainResponse(info=f"create embedding error: {e}")
def create_hypernetwork(self, args: dict):
try:
shared.state.begin()
filename = create_hypernetwork(**args) # create empty embedding
shared.state.end()
return CreateResponse(info = "create hypernetwork filename: {filename}".format(filename = filename))
return CreateResponse(info=f"create hypernetwork filename: {filename}")
except AssertionError as e:
shared.state.end()
return TrainResponse(info = "create hypernetwork error: {error}".format(error = e))
return TrainResponse(info=f"create hypernetwork error: {e}")
def preprocess(self, args: dict):
try:
......@@ -593,13 +593,13 @@ class Api:
return PreprocessResponse(info = 'preprocess complete')
except KeyError as e:
shared.state.end()
return PreprocessResponse(info = "preprocess error: invalid token: {error}".format(error = e))
return PreprocessResponse(info=f"preprocess error: invalid token: {e}")
except AssertionError as e:
shared.state.end()
return PreprocessResponse(info = "preprocess error: {error}".format(error = e))
return PreprocessResponse(info=f"preprocess error: {e}")
except FileNotFoundError as e:
shared.state.end()
return PreprocessResponse(info = 'preprocess error: {error}'.format(error = e))
return PreprocessResponse(info=f'preprocess error: {e}')
def train_embedding(self, args: dict):
try:
......@@ -617,10 +617,10 @@ class Api:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error))
return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg:
shared.state.end()
return TrainResponse(info = "train embedding error: {msg}".format(msg = msg))
return TrainResponse(info=f"train embedding error: {msg}")
def train_hypernetwork(self, args: dict):
try:
......@@ -641,10 +641,10 @@ class Api:
if not apply_optimizations:
sd_hijack.apply_optimizations()
shared.state.end()
return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error))
return TrainResponse(info=f"train embedding complete: filename: {filename} error: {error}")
except AssertionError as msg:
shared.state.end()
return TrainResponse(info="train embedding error: {error}".format(error=error))
return TrainResponse(info=f"train embedding error: {error}")
def get_memory(self):
try:
......
......@@ -60,7 +60,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
max_debug_str_len = 131072 # (1024*1024)/8
print("Error completing request", file=sys.stderr)
argStr = f"Arguments: {str(args)} {str(kwargs)}"
argStr = f"Arguments: {args} {kwargs}"
print(argStr[:max_debug_str_len], file=sys.stderr)
if len(argStr) > max_debug_str_len:
print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr)
......@@ -73,7 +73,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
if extra_outputs_array is None:
extra_outputs_array = [None, '']
res = extra_outputs_array + [f"<div class='error'>{html.escape(type(e).__name__+': '+str(e))}</div>"]
error_message = f'{type(e).__name__}: {e}'
res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]
shared.state.skipped = False
shared.state.interrupted = False
......
......@@ -102,3 +102,4 @@ parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gra
parser.add_argument("--skip-version-check", action='store_true', help="Do not check versions of torch and xformers")
parser.add_argument("--no-hashing", action='store_true', help="disable sha256 hashing of checkpoints to help loading performance", default=False)
parser.add_argument("--no-download-sd-model", action='store_true', help="don't download SD1.5 model even if no model is found in --ckpt-dir", default=False)
parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
\ No newline at end of file
......@@ -156,13 +156,16 @@ class UpscalerESRGAN(Upscaler):
def load_model(self, path: str):
if "http" in path:
filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
file_name="%s.pth" % self.model_name,
progress=True)
filename = load_file_from_url(
url=self.model_url,
model_dir=self.model_path,
file_name=f"{self.model_name}.pth",
progress=True,
)
else:
filename = path
if not os.path.exists(filename) or filename is None:
print("Unable to load %s from %s" % (self.model_path, filename))
print(f"Unable to load {self.model_path} from {filename}")
return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
......
......@@ -38,7 +38,7 @@ class RRDBNet(nn.Module):
elif upsample_mode == 'pixelshuffle':
upsample_block = pixelshuffle_block
else:
raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode))
raise NotImplementedError(f'upsample mode [{upsample_mode}] is not found')
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type, convtype=convtype)
else:
......@@ -261,10 +261,10 @@ class Upsample(nn.Module):
def extra_repr(self):
if self.scale_factor is not None:
info = 'scale_factor=' + str(self.scale_factor)
info = f'scale_factor={self.scale_factor}'
else:
info = 'size=' + str(self.size)
info += ', mode=' + self.mode
info = f'size={self.size}'
info += f', mode={self.mode}'
return info
......@@ -350,7 +350,7 @@ def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1, beta=1.0):
elif act_type == 'sigmoid': # [0, 1] range output
layer = nn.Sigmoid()
else:
raise NotImplementedError('activation layer [{:s}] is not found'.format(act_type))
raise NotImplementedError(f'activation layer [{act_type}] is not found')
return layer
......@@ -372,7 +372,7 @@ def norm(norm_type, nc):
elif norm_type == 'none':
def norm_layer(x): return Identity()
else:
raise NotImplementedError('normalization layer [{:s}] is not found'.format(norm_type))
raise NotImplementedError(f'normalization layer [{norm_type}] is not found')
return layer
......@@ -388,7 +388,7 @@ def pad(pad_type, padding):
elif pad_type == 'zero':
layer = nn.ZeroPad2d(padding)
else:
raise NotImplementedError('padding layer [{:s}] is not implemented'.format(pad_type))
raise NotImplementedError(f'padding layer [{pad_type}] is not implemented')
return layer
......@@ -432,7 +432,7 @@ def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=
pad_type='zero', norm_type=None, act_type='relu', mode='CNA', convtype='Conv2D',
spectral_norm=False):
""" Conv layer with padding, normalization, activation """
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wrong conv mode [{:s}]'.format(mode)
assert mode in ['CNA', 'NAC', 'CNAC'], f'Wrong conv mode [{mode}]'
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0
......
......@@ -10,7 +10,8 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork):
additional = shared.opts.sd_hypernetwork
if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0:
p.all_prompts = [x + f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>"
p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts]
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
......
......@@ -59,6 +59,7 @@ def image_from_url_text(filedata):
is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
assert is_in_right_dir, 'trying to open image file outside of allowed directories'
filename = filename.rsplit('?', 1)[0]
return Image.open(filename)
if type(filedata) == list:
......@@ -129,6 +130,7 @@ def connect_paste_params_buttons():
_js=jsfunc,
inputs=[binding.source_image_component],
outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
show_progress=False,
)
if binding.source_text_component is not None and fields is not None:
......@@ -140,6 +142,7 @@ def connect_paste_params_buttons():
fn=lambda *x: x,
inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
outputs=[field for field, name in fields if name in paste_field_names],
show_progress=False,
)
binding.paste_button.click(
......@@ -147,6 +150,7 @@ def connect_paste_params_buttons():
_js=f"switch_to_{binding.tabname}",
inputs=None,
outputs=None,
show_progress=False,
)
......@@ -265,8 +269,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v
m = re_imagesize.match(v)
if m is not None:
res[k+"-1"] = m.group(1)
res[k+"-2"] = m.group(2)
res[f"{k}-1"] = m.group(1)
res[f"{k}-2"] = m.group(2)
else:
res[k] = v
......@@ -409,12 +413,14 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component,
fn=paste_func,
inputs=[input_comp],
outputs=[x[0] for x in paste_fields],
show_progress=False,
)
button.click(
fn=None,
_js=f"recalculate_prompts_{tabname}",
inputs=[],
outputs=[],
show_progress=False,
)
......@@ -13,7 +13,7 @@ cache_data = None
def dump_cache():
with filelock.FileLock(cache_filename+".lock"):
with filelock.FileLock(f"{cache_filename}.lock"):
with open(cache_filename, "w", encoding="utf8") as file:
json.dump(cache_data, file, indent=4)
......@@ -22,7 +22,7 @@ def cache(subsection):
global cache_data
if cache_data is None:
with filelock.FileLock(cache_filename+".lock"):
with filelock.FileLock(f"{cache_filename}.lock"):
if not os.path.isfile(cache_filename):
cache_data = {}
else:
......
......@@ -357,6 +357,7 @@ class FilenameGenerator:
'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1,
'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt<prompt1|default><prompt2>..]
'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"],
'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT,
}
default_time_format = '%Y%m%d%H%M%S'
......@@ -466,7 +467,7 @@ def get_next_sequence_number(path, basename):
"""
result = -1
if basename != '':
basename = basename + "-"
basename = f"{basename}-"
prefix_length = len(basename)
for p in os.listdir(path):
......@@ -535,7 +536,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
add_number = opts.save_images_add_number or file_decoration == ''
if file_decoration != "" and add_number:
file_decoration = "-" + file_decoration
file_decoration = f"-{file_decoration}"
file_decoration = namegen.apply(file_decoration) + suffix
......@@ -565,7 +566,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
def _atomically_save_image(image_to_save, filename_without_extension, extension):
# save image with .tmp extension to avoid race condition when another process detects new image in the directory
temp_file_path = filename_without_extension + ".tmp"
temp_file_path = f"{filename_without_extension}.tmp"
image_format = Image.registered_extensions()[extension]
if extension.lower() == '.png':
......@@ -625,7 +626,7 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
if opts.save_txt and info is not None:
txt_fullfn = f"{fullfn_without_extension}.txt"
with open(txt_fullfn, "w", encoding="utf8") as file:
file.write(info + "\n")
file.write(f"{info}\n")
else:
txt_fullfn = None
......
......@@ -48,7 +48,8 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args):
try:
img = Image.open(image)
except UnidentifiedImageError:
except UnidentifiedImageError as e:
print(e)
continue
# Use the EXIF orientation of photos taken by smartphones.
img = ImageOps.exif_transpose(img)
......
......@@ -28,7 +28,7 @@ def category_types():
def download_default_clip_interrogate_categories(content_dir):
print("Downloading CLIP categories...")
tmpdir = content_dir + "_tmp"
tmpdir = f"{content_dir}_tmp"
category_types = ["artists", "flavors", "mediums", "movements"]
try:
......@@ -214,7 +214,7 @@ class InterrogateModels:
if shared.opts.interrogate_return_ranks:
res += f", ({match}:{score/100:.3f})"
else:
res += ", " + match
res += f", {match}"
except Exception:
print("Error interrogating", file=sys.stderr)
......
......@@ -54,6 +54,11 @@ if has_mps:
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
if version.parse(torch.__version__) == version.parse("2.0"):
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
\ No newline at end of file
......@@ -22,9 +22,6 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
"""
output = []
if ext_filter is None:
ext_filter = []
try:
places = []
......@@ -39,22 +36,14 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
places.append(model_path)
for place in places:
if os.path.exists(place):
for file in glob.iglob(place + '**/**', recursive=True):
full_path = file
if os.path.isdir(full_path):
continue
if os.path.islink(full_path) and not os.path.exists(full_path):
print(f"Skipping broken symlink: {full_path}")
continue
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
continue
if len(ext_filter) != 0:
model_name, extension = os.path.splitext(file)
if extension not in ext_filter:
continue
if file not in output:
output.append(full_path)
for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
if os.path.islink(full_path) and not os.path.exists(full_path):
print(f"Skipping broken symlink: {full_path}")
continue
if ext_blacklist is not None and any([full_path.endswith(x) for x in ext_blacklist]):
continue
if full_path not in output:
output.append(full_path)
if model_url is not None and len(output) == 0:
if download_name is not None:
......@@ -133,12 +122,9 @@ forbidden_upscaler_classes = set()
def list_builtin_upscalers():
load_upscalers()
builtin_upscaler_classes.clear()
builtin_upscaler_classes.extend(Upscaler.__subclasses__())
def forbid_loaded_nonbuiltin_upscalers():
for cls in Upscaler.__subclasses__():
if cls not in builtin_upscaler_classes:
......
......@@ -223,7 +223,7 @@ class DDPM(pl.LightningModule):
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
print(f"Deleting key {k} from state_dict.")
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
sd, strict=False)
......@@ -386,7 +386,7 @@ class DDPM(pl.LightningModule):
_, loss_dict_no_ema = self.shared_step(batch)
with self.ema_scope():
_, loss_dict_ema = self.shared_step(batch)
loss_dict_ema = {key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
loss_dict_ema = {f"{key}_ema": loss_dict_ema[key] for key in loss_dict_ema}
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
......
......@@ -94,7 +94,7 @@ class NoiseScheduleVP:
"""
if schedule not in ['discrete', 'linear', 'cosine']:
raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
raise ValueError(f"Unsupported noise schedule {schedule}. The schedule needs to be 'discrete' or 'linear' or 'cosine'")
self.schedule = schedule
if schedule == 'discrete':
......@@ -469,7 +469,7 @@ class UniPC:
t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
return t
else:
raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
raise ValueError(f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'")
def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
"""
......
......@@ -7,8 +7,8 @@ def connect(token, port, region):
else:
if ':' in token:
# token = authtoken:username:password
account = token.split(':')[1] + ':' + token.split(':')[-1]
token = token.split(':')[0]
token, username, password = token.split(':', 2)
account = f"{username}:{password}"
config = conf.PyngrokConfig(
auth_token=token, region=region
......
......@@ -16,7 +16,7 @@ for possible_sd_path in possible_sd_paths:
sd_path = os.path.abspath(possible_sd_path)
break
assert sd_path is not None, "Couldn't find Stable Diffusion in any of: " + str(possible_sd_paths)
assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
path_dirs = [
(sd_path, 'ldm', 'Stable Diffusion', []),
......
......@@ -458,6 +458,16 @@ def fix_seed(p):
p.subseed = get_fixed_seed(p.subseed)
def program_version():
import launch
res = launch.git_tag()
if res == "<none>":
res = None
return res
def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
index = position_in_batch + iteration * p.batch_size
......@@ -483,13 +493,14 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
"Version": program_version() if opts.add_version_to_infotext else None,
}
generation_params.update(p.extra_generation_params)
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
negative_prompt_text = "\nNegative prompt: " + p.all_negative_prompts[index] if p.all_negative_prompts[index] else ""
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
......@@ -769,7 +780,16 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
devices.torch_gc()
res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts)
res = Processed(
p,
images_list=output_images,
seed=p.all_seeds[0],
info=infotext(),
comments="".join(f"\n\n{comment}" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
)
if p.scripts is not None:
p.scripts.postprocess(p, res)
......
......@@ -96,7 +96,8 @@ def progressapi(req: ProgressRequest):
if image is not None:
buffered = io.BytesIO()
image.save(buffered, format="png")
live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii")
base64_image = base64.b64encode(buffered.getvalue()).decode('ascii')
live_preview = f"data:image/png;base64,{base64_image}"
id_live_preview = shared.state.id_live_preview
else:
live_preview = None
......
......@@ -28,9 +28,9 @@ class UpscalerRealESRGAN(Upscaler):
for scaler in scalers:
if scaler.local_data_path.startswith("http"):
filename = modelloader.friendly_name(scaler.local_data_path)
local = next(iter([local_model for local_model in local_model_paths if local_model.endswith(filename + '.pth')]), None)
if local:
scaler.local_data_path = local
local_model_candidates = [local_model for local_model in local_model_paths if local_model.endswith(f"{filename}.pth")]
if local_model_candidates:
scaler.local_data_path = local_model_candidates[0]
if scaler.name in opts.realesrgan_enabled_models:
self.scalers.append(scaler)
......@@ -47,7 +47,7 @@ class UpscalerRealESRGAN(Upscaler):
info = self.load_model(path)
if not os.path.exists(info.local_data_path):
print("Unable to load RealESRGAN model: %s" % info.name)
print(f"Unable to load RealESRGAN model: {info.name}")
return img
upsampler = RealESRGANer(
......
......@@ -163,7 +163,8 @@ class Script:
"""helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
need_tabname = self.show(True) == self.show(False)
tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
tabkind = 'img2img' if self.is_img2img else 'txt2txt'
tabname = f"{tabkind}_" if need_tabname else ""
title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
return f'script_{tabname}{title}_{item_id}'
......@@ -526,7 +527,7 @@ def add_classes_to_gradio_component(comp):
this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
"""
comp.elem_classes = ["gradio-" + comp.get_block_name(), *(comp.elem_classes or [])]
comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
if getattr(comp, 'multiselect', False):
comp.elem_classes.append('multiselect')
......
......@@ -75,7 +75,8 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text
self.hijack.comments += hijack_comments
if len(used_custom_terms) > 0:
self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms)
self.hijack.comments.append(f"Used embeddings: {embedding_names}")
self.hijack.fixes = hijack_fixes
return self.process_tokens(remade_batch_tokens, batch_multipliers)
......@@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
if q.device.type == 'mps':
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
dtype = q.dtype
if shared.opts.upcast_attn:
q, k = q.float(), k.float()
......
......@@ -18,7 +18,7 @@ class TorchHijackForUnet:
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def cat(self, tensors, *args, **kwargs):
if len(tensors) == 2:
......
......@@ -2,6 +2,8 @@ import collections
import os.path
import sys
import gc
import threading
import torch
import re
import safetensors.torch
......@@ -45,7 +47,7 @@ class CheckpointInfo:
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
self.hash = model_hash(filename)
self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
......@@ -67,7 +69,7 @@ class CheckpointInfo:
checkpoint_alisases[id] = self
def calculate_shorthash(self):
self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
if self.sha256 is None:
return
......@@ -404,13 +406,39 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
class SdModelData:
def __init__(self):
self.sd_model = None
self.lock = threading.Lock()
def get_sd_model(self):
if self.sd_model is None:
with self.lock:
try:
load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load", file=sys.stderr)
self.sd_model = None
return self.sd_model
def set_sd_model(self, v):
self.sd_model = v
model_data = SdModelData()
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
from modules import lowvram, sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
if shared.sd_model:
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
if model_data.sd_model:
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
model_data.sd_model = None
gc.collect()
devices.torch_gc()
......@@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
timer.record("hijack")
sd_model.eval()
shared.sd_model = sd_model
model_data.sd_model = sd_model
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
......@@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
checkpoint_info = info or select_checkpoint()
if not sd_model:
sd_model = shared.sd_model
sd_model = model_data.sd_model
if sd_model is None: # previous model load failed
current_checkpoint_info = None
......@@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
del sd_model
checkpoints_loaded.clear()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return shared.sd_model
return model_data.sd_model
try:
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
......@@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):
return sd_model
def unload_model_weights(sd_model=None, info=None):
from modules import lowvram, devices, sd_hijack
timer = Timer()
if shared.sd_model:
# shared.sd_model.cond_stage_model.to(devices.cpu)
# shared.sd_model.first_stage_model.to(devices.cpu)
shared.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(shared.sd_model)
shared.sd_model = None
if model_data.sd_model:
model_data.sd_model.to(devices.cpu)
sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
model_data.sd_model = None
sd_model = None
gc.collect()
devices.torch_gc()
......
......@@ -111,7 +111,7 @@ def find_checkpoint_config_near_filename(info):
if info is None:
return None
config = os.path.splitext(info.filename)[0] + ".yaml"
config = f"{os.path.splitext(info.filename)[0]}.yaml"
if os.path.exists(config):
return config
......
......@@ -198,7 +198,7 @@ class TorchHijack:
if hasattr(torch, item):
return getattr(torch, item)
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
if self.sampler_noises:
......
......@@ -89,7 +89,7 @@ def refresh_vae_list():
def find_vae_near_checkpoint(checkpoint_file):
checkpoint_path = os.path.splitext(checkpoint_file)[0]
for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
for vae_location in [f"{checkpoint_path}.vae.pt", f"{checkpoint_path}.vae.ckpt", f"{checkpoint_path}.vae.safetensors"]:
if os.path.isfile(vae_location):
return vae_location
......
......@@ -16,6 +16,7 @@ import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
from ldm.models.diffusion.ddpm import LatentDiffusion
demo = None
......@@ -391,21 +392,20 @@ options_templates.update(options_section(('ui', "User interface"), {
"return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
"return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
"do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
"send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
"send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
"font": OptionInfo("", "Font for image grids that have text"),
"js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
"js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
"js_modal_lightbox_gamepad": OptionInfo(True, "Navigate image viewer with gamepad"),
"js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"),
"show_progress_in_title": OptionInfo(True, "Show generation progress in window title."),
"samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group"),
"dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row"),
"keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}),
"keyedit_delimiters": OptionInfo(".,\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings": OptionInfo("sd_model_checkpoint", "Quicksettings list"),
"keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"),
"quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}),
"hidden_tabs": OptionInfo([], "Hidden UI tabs (requires restart)", ui_components.DropdownMulti, lambda: {"choices": [x for x in tab_names]}),
"ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"),
"ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"),
......@@ -413,6 +413,13 @@ options_templates.update(options_section(('ui', "User interface"), {
"gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes})
}))
options_templates.update(options_section(('infotext', "Infotext"), {
"add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
"add_model_name_to_info": OptionInfo(True, "Add model name to generation information"),
"add_version_to_infotext": OptionInfo(True, "Add program version to generation information"),
"disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."),
}))
options_templates.update(options_section(('ui', "Live previews"), {
"show_progressbar": OptionInfo(True, "Show progressbar"),
"live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
......@@ -542,6 +549,10 @@ class Options:
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
bad_settings = 0
for k, v in self.data.items():
info = self.data_labels.get(k, None)
......@@ -600,13 +611,37 @@ class Options:
return value
opts = Options()
if os.path.exists(config_filename):
opts.load(config_filename)
class Shared(sys.modules[__name__].__class__):
"""
this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
at program startup.
"""
sd_model_val = None
@property
def sd_model(self):
import modules.sd_models
return modules.sd_models.model_data.get_sd_model()
@sd_model.setter
def sd_model(self, value):
import modules.sd_models
modules.sd_models.model_data.set_sd_model(value)
sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
sys.modules[__name__].__class__ = Shared
settings_components = None
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""
latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
......@@ -620,8 +655,6 @@ latent_upscale_modes = {
sd_upscalers = []
sd_model = None
clip_model = None
progress_print_out = sys.stdout
......@@ -639,8 +672,8 @@ def reload_gradio_theme(theme_name=None):
else:
try:
gradio_theme = gr.themes.ThemeClass.from_hub(theme_name)
except requests.exceptions.ConnectionError:
print("Can't access HuggingFace Hub, falling back to default Gradio theme")
except Exception as e:
errors.display(e, "changing gradio theme")
gradio_theme = gr.themes.Default()
......@@ -701,3 +734,20 @@ def html(filename):
return file.read()
return ""
def walk_files(path, allowed_extensions=None):
if not os.path.exists(path):
return
if allowed_extensions is not None:
allowed_extensions = set(allowed_extensions)
for root, dirs, files in os.walk(path):
for filename in files:
if allowed_extensions is not None:
_, ext = os.path.splitext(filename)
if ext not in allowed_extensions:
continue
yield os.path.join(root, filename)
......@@ -74,7 +74,7 @@ class StyleDatabase:
def save_styles(self, path: str) -> None:
# Always keep a backup file around
if os.path.exists(path):
shutil.copy(path, path + ".bak")
shutil.copy(path, f"{path}.bak")
fd = os.open(path, os.O_RDWR|os.O_CREAT)
with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
......
......@@ -111,7 +111,7 @@ def focal_point(im, settings):
if corner_centroid is not None:
color = BLUE
box = corner_centroid.bounding(max_size * corner_centroid.weight)
d.text((box[0], box[1]-15), "Edge: %.02f" % corner_centroid.weight, fill=color)
d.text((box[0], box[1]-15), f"Edge: {corner_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color)
if len(corner_points) > 1:
for f in corner_points:
......@@ -119,7 +119,7 @@ def focal_point(im, settings):
if entropy_centroid is not None:
color = "#ff0"
box = entropy_centroid.bounding(max_size * entropy_centroid.weight)
d.text((box[0], box[1]-15), "Entropy: %.02f" % entropy_centroid.weight, fill=color)
d.text((box[0], box[1]-15), f"Entropy: {entropy_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color)
if len(entropy_points) > 1:
for f in entropy_points:
......@@ -127,7 +127,7 @@ def focal_point(im, settings):
if face_centroid is not None:
color = RED
box = face_centroid.bounding(max_size * face_centroid.weight)
d.text((box[0], box[1]-15), "Face: %.02f" % face_centroid.weight, fill=color)
d.text((box[0], box[1]-15), f"Face: {face_centroid.weight:.02f}", fill=color)
d.ellipse(box, outline=color)
if len(face_points) > 1:
for f in face_points:
......
......@@ -72,7 +72,7 @@ class PersonalizedBase(Dataset):
except Exception:
continue
text_filename = os.path.splitext(path)[0] + ".txt"
text_filename = f"{os.path.splitext(path)[0]}.txt"
filename = os.path.basename(path)
if os.path.exists(text_filename):
......
......@@ -63,9 +63,9 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti
image.save(os.path.join(params.dstdir, f"{basename}.png"))
if params.preprocess_txt_action == 'prepend' and existing_caption:
caption = existing_caption + ' ' + caption
caption = f"{existing_caption} {caption}"
elif params.preprocess_txt_action == 'append' and existing_caption:
caption = caption + ' ' + existing_caption
caption = f"{caption} {existing_caption}"
elif params.preprocess_txt_action == 'copy' and existing_caption:
caption = existing_caption
......@@ -174,7 +174,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
params.src = filename
existing_caption = None
existing_caption_filename = os.path.splitext(filename)[0] + '.txt'
existing_caption_filename = f"{os.path.splitext(filename)[0]}.txt"
if os.path.exists(existing_caption_filename):
with open(existing_caption_filename, 'r', encoding="utf8") as file:
existing_caption = file.read()
......
......@@ -69,7 +69,7 @@ class Embedding:
'hash': self.checksum(),
'optimizer_state_dict': self.optimizer_state_dict,
}
torch.save(optimizer_saved_dict, filename + '.optim')
torch.save(optimizer_saved_dict, f"{filename}.optim")
def checksum(self):
if self.cached_checksum is not None:
......@@ -437,8 +437,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
if shared.opts.save_optimizer_state:
optimizer_state_dict = None
if os.path.exists(filename + '.optim'):
optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
if os.path.exists(f"{filename}.optim"):
optimizer_saved_dict = torch.load(f"{filename}.optim", map_location='cpu')
if embedding.checksum() == optimizer_saved_dict.get('hash', None):
optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
......@@ -599,7 +599,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
data = torch.load(last_saved_file)
info.add_text("sd-ti-embedding", embedding_to_b64(data))
title = "<{}>".format(data.get('name', '???'))
title = f"<{data.get('name', '???')}>"
try:
vectorSize = list(data['string_to_param'].values())[0].shape[0]
......@@ -608,8 +608,8 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
checkpoint = sd_models.select_checkpoint()
footer_left = checkpoint.model_name
footer_mid = '[{}]'.format(checkpoint.shorthash)
footer_right = '{}v {}s'.format(vectorSize, steps_done)
footer_mid = f'[{checkpoint.shorthash}]'
footer_right = f'{vectorSize}v {steps_done}s'
captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
captioned_image = insert_image_data_embed(captioned_image, data)
......
This diff is collapsed.
......@@ -61,7 +61,8 @@ def save_config_state(name):
if not name:
name = "Config"
current_config_state["name"] = name
filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_" + name + ".json")
timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S')
filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json")
print(f"Saving backup of webui/extension state to {filename}.")
with open(filename, "w", encoding="utf-8") as f:
json.dump(current_config_state, f)
......
......@@ -69,7 +69,9 @@ class ExtraNetworksPage:
pass
def link_preview(self, filename):
return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))
quoted_filename = urllib.parse.quote(filename.replace('\\', '/'))
mtime = os.path.getmtime(filename)
return f"./sd_extra_networks/thumb?filename={quoted_filename}&mtime={mtime}"
def search_terms_from_path(self, filename, possible_directories=None):
abspath = os.path.abspath(filename)
......@@ -89,19 +91,22 @@ class ExtraNetworksPage:
subdirs = {}
for parentdir in [os.path.abspath(x) for x in self.allowed_directories_for_previews()]:
for x in glob.glob(os.path.join(parentdir, '**/*'), recursive=True):
if not os.path.isdir(x):
continue
for root, dirs, files in os.walk(parentdir):
for dirname in dirs:
x = os.path.join(root, dirname)
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
while subdir.startswith("/"):
subdir = subdir[1:]
if not os.path.isdir(x):
continue
is_empty = len(os.listdir(x)) == 0
if not is_empty and not subdir.endswith("/"):
subdir = subdir + "/"
subdir = os.path.abspath(x)[len(parentdir):].replace("\\", "/")
while subdir.startswith("/"):
subdir = subdir[1:]
subdirs[subdir] = 1
is_empty = len(os.listdir(x)) == 0
if not is_empty and not subdir.endswith("/"):
subdir = subdir + "/"
subdirs[subdir] = 1
if subdirs:
subdirs = {"": 1, **subdirs}
......@@ -157,8 +162,20 @@ class ExtraNetworksPage:
if metadata:
metadata_button = f"<div class='metadata-button' title='Show metadata' onclick='extraNetworksRequestMetadata(event, {json.dumps(self.name)}, {json.dumps(item['name'])})'></div>"
local_path = ""
filename = item.get("filename", "")
for reldir in self.allowed_directories_for_previews():
absdir = os.path.abspath(reldir)
if filename.startswith(absdir):
local_path = filename[len(absdir):]
# if this is true, the item must not be show in the default view, and must instead only be
# shown when searching for it
serach_only = "/." in local_path or "\\." in local_path
args = {
"style": f"'{height}{width}{background_image}'",
"style": f"'display: none; {height}{width}{background_image}'",
"prompt": item.get("prompt", None),
"tabname": json.dumps(tabname),
"local_preview": json.dumps(item["local_preview"]),
......@@ -168,6 +185,7 @@ class ExtraNetworksPage:
"save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
"search_term": item.get("search_term", ""),
"metadata_button": metadata_button,
"serach_only": " search_only" if serach_only else "",
}
return self.card_page.format(**args)
......@@ -209,6 +227,11 @@ def intialize():
class ExtraNetworksUi:
def __init__(self):
self.pages = None
"""gradio HTML components related to extra networks' pages"""
self.page_contents = None
"""HTML content of the above; empty initially, filled when extra pages have to be shown"""
self.stored_extra_pages = None
self.button_save_preview = None
......@@ -236,17 +259,22 @@ def pages_in_preferred_order(pages):
def create_ui(container, button, tabname):
ui = ExtraNetworksUi()
ui.pages = []
ui.pages_contents = []
ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy())
ui.tabname = tabname
with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs:
for page in ui.stored_extra_pages:
with gr.Tab(page.title, id=page.title.lower().replace(" ", "_")):
page_id = page.title.lower().replace(" ", "_")
page_elem = gr.HTML(page.create_html(ui.tabname))
with gr.Tab(page.title, id=page_id):
elem_id = f"{tabname}_{page_id}_cards_html"
page_elem = gr.HTML('', elem_id=elem_id)
ui.pages.append(page_elem)
filter = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[])
gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False)
button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh")
ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False)
......@@ -254,19 +282,22 @@ def create_ui(container, button, tabname):
def toggle_visibility(is_visible):
is_visible = not is_visible
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary"))
if is_visible and not ui.pages_contents:
refresh()
return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")), *ui.pages_contents
state_visible = gr.State(value=False)
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button])
button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button, *ui.pages])
def refresh():
res = []
for pg in ui.stored_extra_pages:
pg.refresh()
res.append(pg.create_html(ui.tabname))
return res
ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages]
return ui.pages_contents
button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages)
......
......@@ -36,7 +36,7 @@ def save_pil_to_file(pil_image, dir=None):
if already_saved_as and os.path.isfile(already_saved_as):
register_tmp_file(shared.demo, already_saved_as)
file_obj = Savedfile(already_saved_as)
file_obj = Savedfile(f'{already_saved_as}?{os.path.getmtime(already_saved_as)}')
return file_obj
if shared.opts.temp_dir != "":
......
......@@ -3,7 +3,7 @@ transformers==4.25.1
accelerate==0.18.0
basicsr==1.4.2
gfpgan==1.3.8
gradio==3.28.1
gradio==3.29.0
numpy==1.23.5
Pillow==9.4.0
realesrgan==0.3.0
......
......@@ -77,7 +77,7 @@ return process_images(p)
module.display = display
indent = " " * indent_level
indented = code.replace('\n', '\n' + indent)
indented = code.replace('\n', f"\n{indent}")
body = f"""def __webuitemp__():
{indent}{indented}
__webuitemp__()"""
......
......@@ -84,7 +84,7 @@ class Script(scripts.Script):
p.color_corrections = initial_color_corrections
if append_interrogation != "None":
p.prompt = original_prompt + ", " if original_prompt != "" else ""
p.prompt = f"{original_prompt}, " if original_prompt else ""
if append_interrogation == "CLIP":
p.prompt += shared.interrogator.interrogate(p.init_images[0])
elif append_interrogation == "DeepBooru":
......
......@@ -100,30 +100,29 @@ def cmdargs(line):
def load_prompt_file(file):
if file is None:
lines = []
return None, gr.update(), gr.update(lines=7)
else:
lines = [x.strip() for x in file.decode('utf8', errors='ignore').split("\n")]
return None, "\n".join(lines), gr.update(lines=7)
return None, "\n".join(lines), gr.update(lines=7)
class Script(scripts.Script):
def title(self):
return "Prompts from file or textbox"
def ui(self, is_img2img):
def ui(self, is_img2img):
checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate"))
checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch"))
prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt"))
file = gr.File(label="Upload prompt inputs", type='binary', elem_id=self.elem_id("file"))
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt])
file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt], show_progress=False)
# We start at one line. When the text changes, we jump to seven lines, or two lines if no \n.
# We don't shrink back to 1, because that causes the control to ignore [enter], and it may
# be unclear to the user that shift-enter is needed.
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt])
prompt_txt.change(lambda tb: gr.update(lines=7) if ("\n" in tb) else gr.update(lines=2), inputs=[prompt_txt], outputs=[prompt_txt], show_progress=False)
return [checkbox_iterate, checkbox_iterate_batch, prompt_txt]
def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str):
......
......@@ -222,7 +222,7 @@ axis_options = [
AxisOption("Denoising", float, apply_field("denoising_strength")),
AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]),
AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)),
AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
......@@ -346,7 +346,7 @@ class SharedSettingsStackHelper(object):
self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
self.vae = opts.sd_vae
self.uni_pc_order = opts.uni_pc_order
def __exit__(self, exc_type, exc_value, tb):
opts.data["sd_vae"] = self.vae
opts.data["uni_pc_order"] = self.uni_pc_order
......@@ -399,7 +399,7 @@ class Script(scripts.Script):
include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
with gr.Column():
margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
with gr.Row(variant="compact", elem_id="swap_axes"):
swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
swap_yz_axes_button = gr.Button(value="Swap Y/Z axes", elem_id="yz_grid_swap_axes_button")
......@@ -439,7 +439,7 @@ class Script(scripts.Script):
z_type.change(fn=select_axis, inputs=[z_type,z_values_dropdown], outputs=[fill_z_button,z_values,z_values_dropdown])
def get_dropdown_update_from_params(axis,params):
val_key = axis + " Values"
val_key = f"{axis} Values"
vals = params.get(val_key,"")
valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]
return gr.update(value = valslist)
......@@ -490,7 +490,7 @@ class Script(scripts.Script):
start = int(mc.group(1))
end = int(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1
valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()]
else:
valslist_ext.append(val)
......@@ -512,7 +512,7 @@ class Script(scripts.Script):
start = float(mc.group(1))
end = float(mc.group(2))
num = int(mc.group(3)) if mc.group(3) is not None else 1
valslist_ext += np.linspace(start=start, stop=end, num=num).tolist()
else:
valslist_ext.append(val)
......
......@@ -125,6 +125,10 @@ div.gradio-html.min{
text-decoration: none;
}
a{
font-weight: bold;
cursor: pointer;
}
/* general styled components */
......@@ -246,7 +250,7 @@ button.custom-button{
}
}
#txt2img_gallery img, #img2img_gallery img{
#txt2img_gallery img, #img2img_gallery img, #extras_gallery img{
object-fit: scale-down;
}
#txt2img_actions_column, #img2img_actions_column {
......@@ -397,6 +401,18 @@ div#extras_scale_to_tab div.form{
margin: 0 1.2em;
}
table.settings-value-table{
background: white;
border-collapse: collapse;
margin: 1em;
border: 4px solid white;
}
table.settings-value-table td{
padding: 0.4em;
border: 1px solid #ccc;
max-width: 36em;
}
/* live preview */
.progressDiv{
......@@ -534,6 +550,8 @@ div#extras_scale_to_tab div.form{
#lightboxModal > img.modalImageFullscreen{
object-fit: contain;
height: 100%;
width: 100%;
min-height: 0;
}
.modalPrev,
......
......@@ -6,6 +6,8 @@ import signal
import re
import warnings
import json
from threading import Thread
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
......@@ -185,24 +187,19 @@ def initialize():
modules.scripts.load_scripts()
startup_timer.record("load scripts")
modelloader.load_upscalers()
#startup_timer.record("load upscalers") #Is this necessary? I don't know.
modules.sd_vae.refresh_vae_list()
startup_timer.record("refresh VAE")
modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
startup_timer.record("refresh textual inversion templates")
try:
modules.sd_models.load_model()
except Exception as e:
errors.display(e, "loading stable diffusion model")
print("", file=sys.stderr)
print("Stable diffusion model failed to load, exiting", file=sys.stderr)
exit(1)
startup_timer.record("load SD checkpoint")
# load model in parallel to other startup stuff
Thread(target=lambda: shared.sd_model).start()
shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
......@@ -286,7 +283,6 @@ def api_only():
print(f"Startup time: {startup_timer.summary()}.")
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui():
launch_api = cmd_opts.api
initialize()
......@@ -313,6 +309,16 @@ def webui():
for line in file.readlines():
gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
# this restores the missing /docs endpoint
if launch_api and not hasattr(FastAPI, 'original_setup'):
def fastapi_setup(self):
self.docs_url = "/docs"
self.redoc_url = "/redoc"
self.original_setup()
FastAPI.original_setup = FastAPI.setup
FastAPI.setup = fastapi_setup
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
......@@ -339,6 +345,7 @@ def webui():
setup_middleware(app)
modules.progress.setup_progress_api(app)
modules.ui.setup_ui_api(app)
if launch_api:
create_api(app)
......@@ -350,6 +357,11 @@ def webui():
print(f"Startup time: {startup_timer.summary()}.")
if cmd_opts.subpath:
redirector = FastAPI()
redirector.get("/")
mounted_app = gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
wait_on_server(shared.demo)
print('Restarting UI...')
......
......@@ -153,24 +153,31 @@ else
cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
fi
printf "\n%s\n" "${delimiter}"
printf "Create and activate python venv"
printf "\n%s\n" "${delimiter}"
cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
if [[ ! -d "${venv_dir}" ]]
then
"${python_cmd}" -m venv "${venv_dir}"
first_launch=1
fi
# shellcheck source=/dev/null
if [[ -f "${venv_dir}"/bin/activate ]]
if [[ -z "${VIRTUAL_ENV}" ]];
then
source "${venv_dir}"/bin/activate
printf "\n%s\n" "${delimiter}"
printf "Create and activate python venv"
printf "\n%s\n" "${delimiter}"
cd "${install_dir}"/"${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
if [[ ! -d "${venv_dir}" ]]
then
"${python_cmd}" -m venv "${venv_dir}"
first_launch=1
fi
# shellcheck source=/dev/null
if [[ -f "${venv_dir}"/bin/activate ]]
then
source "${venv_dir}"/bin/activate
else
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
printf "\n%s\n" "${delimiter}"
exit 1
fi
else
printf "\n%s\n" "${delimiter}"
printf "\e[1m\e[31mERROR: Cannot activate python venv, aborting...\e[0m"
printf "python venv already activate: ${VIRTUAL_ENV}"
printf "\n%s\n" "${delimiter}"
exit 1
fi
# Try using TCMalloc on Linux
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment