Commit 2362d5f0 authored by MalumaDev's avatar MalumaDev Committed by GitHub

Merge branch 'master' into test_resolve_conflicts

parents c2765c9b 1b91cbbc
from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_pnginfo
import modules.shared as shared
import uvicorn
from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json
import json
import io
import base64
sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None)
class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json
info: Json
class Api:
def __init__(self, app, queue_lock):
self.router = APIRouter()
self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model,
"sampler_index": sampler_index[0],
"do_not_save_samples": True,
"do_not_save_grid": True
}
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
with self.queue_lock:
processed = process_images(p)
b64images = []
for i in processed.images:
buffer = io.BytesIO()
i.save(buffer, format="png")
b64images.append(base64.b64encode(buffer.getvalue()))
return TextToImageResponse(images=b64images, parameters=json.dumps(vars(txt2imgreq)), info=json.dumps(processed.info))
def img2imgapi(self):
raise NotImplementedError
def extrasapi(self):
raise NotImplementedError
def pnginfoapi(self):
raise NotImplementedError
def launch(self, server_name, port):
self.app.include_router(self.router)
uvicorn.run(self.app, host=server_name, port=port)
from inflection import underscore
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, create_model
from modules.processing import StableDiffusionProcessingTxt2Img
import inspect
API_NOT_ALLOWED = [
"self",
"kwargs",
"sd_model",
"outpath_samples",
"outpath_grids",
"sampler_index",
"do_not_save_samples",
"do_not_save_grid",
"extra_generation_params",
"overlay_images",
"do_not_reload_embeddings",
"seed_enable_extras",
"prompt_for_display",
"sampler_noise_scheduler_override",
"ddim_discretize"
]
class ModelDef(BaseModel):
"""Assistance Class for Pydantic Dynamic Model Generation"""
field: str
field_alias: str
field_type: Any
field_value: Any
class PydanticModelGenerator:
"""
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
source_data is a snapshot of the default values produced by the class
params are the names of the actual keys required by __init__
"""
def __init__(
self,
model_name: str = None,
class_instance = None,
additional_fields = None,
):
def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"]
# print(k, v.annotation, v.default)
field_type = v.annotation
return Optional[field_type]
def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {}
for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
self._model_name = model_name
self._class_data = merge_class_params(class_instance)
self._model_def = [
ModelDef(
field=underscore(k),
field_alias=k,
field_type=field_type_generator(k, v),
field_value=v.default
)
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
]
for fields in additional_fields:
self._model_def.append(ModelDef(
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"]))
def generate_model(self):
"""
Creates a pydantic BaseModel
from the json and overrides provided at initialization
"""
fields = {
d.field: (d.field_type, Field(default=d.field_value, alias=d.field_alias)) for d in self._model_def
}
DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
return DynamicModel
StableDiffusionProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "Euler"}]
).generate_model()
\ No newline at end of file
...@@ -196,7 +196,7 @@ def stack_conds(conds): ...@@ -196,7 +196,7 @@ def stack_conds(conds):
return torch.stack(conds) return torch.stack(conds)
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert hypernetwork_name, 'hypernetwork not selected' assert hypernetwork_name, 'hypernetwork not selected'
path = shared.hypernetworks.get(hypernetwork_name, None) path = shared.hypernetworks.get(hypernetwork_name, None)
...@@ -225,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log ...@@ -225,7 +225,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"): with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size) ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
if unload: if unload:
shared.sd_model.cond_stage_model.to(devices.cpu) shared.sd_model.cond_stage_model.to(devices.cpu)
......
...@@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps ...@@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps
import random import random
import cv2 import cv2
from skimage import exposure from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram from modules import devices, prompt_parser, masking, sd_samplers, lowvram
...@@ -50,10 +51,16 @@ def get_correct_sampler(p): ...@@ -50,10 +51,16 @@ def get_correct_sampler(p):
return sd_samplers.samplers return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img return sd_samplers.samplers_for_img2img
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
return sd_samplers.samplers
class StableDiffusionProcessing: class StableDiffusionProcessing():
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False): """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
...@@ -86,10 +93,10 @@ class StableDiffusionProcessing: ...@@ -86,10 +93,10 @@ class StableDiffusionProcessing:
self.denoising_strength: float = 0 self.denoising_strength: float = 0
self.sampler_noise_scheduler_override = None self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn self.s_churn = s_churn or opts.s_churn
self.s_tmin = opts.s_tmin self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise self.s_noise = s_noise or opts.s_noise
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
...@@ -97,6 +104,7 @@ class StableDiffusionProcessing: ...@@ -97,6 +104,7 @@ class StableDiffusionProcessing:
self.seed_resize_from_h = 0 self.seed_resize_from_h = 0
self.seed_resize_from_w = 0 self.seed_resize_from_w = 0
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
pass pass
...@@ -515,7 +523,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -515,7 +523,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength
...@@ -759,4 +767,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -759,4 +767,4 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
del x del x
devices.torch_gc() devices.torch_gc()
return samples return samples
\ No newline at end of file
...@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v): ...@@ -181,7 +181,7 @@ def einsum_op_cuda(q, k, v):
mem_free_torch = mem_reserved - mem_active mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation # Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
def einsum_op(q, k, v): def einsum_op(q, k, v):
if q.device.type == 'cuda': if q.device.type == 'cuda':
......
...@@ -122,11 +122,33 @@ def select_checkpoint(): ...@@ -122,11 +122,33 @@ def select_checkpoint():
return checkpoint_info return checkpoint_info
chckpoint_dict_replacements = {
'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
def transform_checkpoint_dict_key(k):
for text, replacement in chckpoint_dict_replacements.items():
if k.startswith(text):
k = replacement + k[len(text):]
return k
def get_state_dict_from_checkpoint(pl_sd): def get_state_dict_from_checkpoint(pl_sd):
if "state_dict" in pl_sd: if "state_dict" in pl_sd:
return pl_sd["state_dict"] pl_sd = pl_sd["state_dict"]
sd = {}
for k, v in pl_sd.items():
new_key = transform_checkpoint_dict_key(k)
if new_key is not None:
sd[new_key] = v
return pl_sd return sd
def load_model_weights(model, checkpoint_info): def load_model_weights(model, checkpoint_info):
...@@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info): ...@@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
print(f"Global Step: {pl_sd['global_step']}") print(f"Global Step: {pl_sd['global_step']}")
sd = get_state_dict_from_checkpoint(pl_sd) sd = get_state_dict_from_checkpoint(pl_sd)
model.load_state_dict(sd, strict=False) missing, extra = model.load_state_dict(sd, strict=False)
if shared.cmd_opts.opt_channelslast: if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last) model.to(memory_format=torch.channels_last)
......
...@@ -79,6 +79,8 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help= ...@@ -79,6 +79,8 @@ parser.add_argument("--disable-console-progressbars", action='store_true', help=
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the api with the webui")
parser.add_argument("--nowebui", action='store_true', help="use api=True to launch the api instead of the webui")
cmd_opts = parser.parse_args() cmd_opts = parser.parse_args()
restricted_opts = [ restricted_opts = [
......
...@@ -12,7 +12,7 @@ import time ...@@ -12,7 +12,7 @@ import time
import traceback import traceback
import platform import platform
import subprocess as sp import subprocess as sp
from functools import reduce from functools import partial, reduce
import numpy as np import numpy as np
import torch import torch
...@@ -266,6 +266,19 @@ def wrap_gradio_call(func, extra_outputs=None): ...@@ -266,6 +266,19 @@ def wrap_gradio_call(func, extra_outputs=None):
return f return f
def calc_time_left(progress, threshold, label, force_display):
if progress == 0:
return ""
else:
time_since_start = time.time() - shared.state.time_start
eta = (time_since_start/progress)
eta_relative = eta-time_since_start
if (eta_relative > threshold and progress > 0.02) or force_display:
return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative))
else:
return ""
def check_progress_call(id_part): def check_progress_call(id_part):
if shared.state.job_count == 0: if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False), gr_show(False) return "", gr_show(False), gr_show(False), gr_show(False)
...@@ -277,11 +290,15 @@ def check_progress_call(id_part): ...@@ -277,11 +290,15 @@ def check_progress_call(id_part):
if shared.state.sampling_steps > 0: if shared.state.sampling_steps > 0:
progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps
time_left = calc_time_left( progress, 60, " ETA:", shared.state.time_left_force_display )
if time_left != "":
shared.state.time_left_force_display = True
progress = min(progress, 1) progress = min(progress, 1)
progressbar = "" progressbar = ""
if opts.show_progressbar: if opts.show_progressbar:
progressbar = f"""<div class='progressDiv'><div class='progress' style="width:{progress * 100}%">{str(int(progress*100))+"%" if progress > 0.01 else ""}</div></div>""" progressbar = f"""<div class='progressDiv'><div class='progress' style="overflow:hidden;width:{progress * 100}%">{str(int(progress*100))+"%"+time_left if progress > 0.01 else ""}</div></div>"""
image = gr_show(False) image = gr_show(False)
preview_visibility = gr_show(False) preview_visibility = gr_show(False)
...@@ -313,6 +330,8 @@ def check_progress_call_initial(id_part): ...@@ -313,6 +330,8 @@ def check_progress_call_initial(id_part):
shared.state.current_latent = None shared.state.current_latent = None
shared.state.current_image = None shared.state.current_image = None
shared.state.textinfo = None shared.state.textinfo = None
shared.state.time_start = time.time()
shared.state.time_left_force_display = False
return check_progress_call(id_part) return check_progress_call(id_part)
...@@ -1417,6 +1436,8 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1417,6 +1436,8 @@ def create_ui(wrap_gradio_gpu_call):
batch_size, batch_size,
dataset_directory, dataset_directory,
log_directory, log_directory,
training_width,
training_height,
steps, steps,
create_image_every, create_image_every,
save_embedding_every, save_embedding_every,
...@@ -1613,6 +1634,7 @@ Requested path was: {f} ...@@ -1613,6 +1634,7 @@ Requested path was: {f}
def reload_scripts(): def reload_scripts():
modules.scripts.reload_script_body_only() modules.scripts.reload_script_body_only()
reload_javascript() # need to refresh the html page
reload_script_bodies.click( reload_script_bodies.click(
fn=reload_scripts, fn=reload_scripts,
...@@ -1871,26 +1893,30 @@ Requested path was: {f} ...@@ -1871,26 +1893,30 @@ Requested path was: {f}
return demo return demo
with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: def load_javascript(raw_response):
javascript = f'<script>{jsfile.read()}</script>' with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile:
javascript = f'<script>{jsfile.read()}</script>'
jsdir = os.path.join(script_path, "javascript") jsdir = os.path.join(script_path, "javascript")
for filename in sorted(os.listdir(jsdir)): for filename in sorted(os.listdir(jsdir)):
with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile: with open(os.path.join(jsdir, filename), "r", encoding="utf8") as jsfile:
javascript += f"\n<script>{jsfile.read()}</script>" javascript += f"\n<!-- {filename} --><script>{jsfile.read()}</script>"
if cmd_opts.theme is not None: if cmd_opts.theme is not None:
javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n" javascript += f"\n<script>set_theme('{cmd_opts.theme}');</script>\n"
javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>" javascript += f"\n<script>{localization.localization_js(shared.opts.localization)}</script>"
if 'gradio_routes_templates_response' not in globals():
def template_response(*args, **kwargs): def template_response(*args, **kwargs):
res = gradio_routes_templates_response(*args, **kwargs) res = raw_response(*args, **kwargs)
res.body = res.body.replace(b'</head>', f'{javascript}</head>'.encode("utf8")) res.body = res.body.replace(
b'</head>', f'{javascript}</head>'.encode("utf8"))
res.init_headers() res.init_headers()
return res return res
gradio_routes_templates_response = gradio.routes.templates.TemplateResponse
gradio.routes.templates.TemplateResponse = template_response gradio.routes.templates.TemplateResponse = template_response
reload_javascript = partial(load_javascript,
gradio.routes.templates.TemplateResponse)
reload_javascript()
...@@ -23,3 +23,4 @@ resize-right ...@@ -23,3 +23,4 @@ resize-right
torchdiffeq torchdiffeq
kornia kornia
lark lark
inflection
...@@ -22,3 +22,4 @@ resize-right==0.0.2 ...@@ -22,3 +22,4 @@ resize-right==0.0.2
torchdiffeq==0.2.3 torchdiffeq==0.2.3
kornia==0.6.7 kornia==0.6.7
lark==1.1.2 lark==1.1.2
inflection==0.5.1
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import importlib import importlib
import signal import signal
import threading import threading
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
...@@ -31,7 +31,6 @@ from modules.paths import script_path ...@@ -31,7 +31,6 @@ from modules.paths import script_path
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
queue_lock = threading.Lock() queue_lock = threading.Lock()
...@@ -87,10 +86,6 @@ def initialize(): ...@@ -87,10 +86,6 @@ def initialize():
shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
def webui():
initialize()
# make the program just exit at ctrl+c without waiting for anything # make the program just exit at ctrl+c without waiting for anything
def sigint_handler(sig, frame): def sigint_handler(sig, frame):
print(f'Interrupted with signal {sig} in {frame}') print(f'Interrupted with signal {sig} in {frame}')
...@@ -98,10 +93,37 @@ def webui(): ...@@ -98,10 +93,37 @@ def webui():
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def create_api(app):
from modules.api.api import Api
api = Api(app, queue_lock)
return api
def wait_on_server(demo=None):
while 1: while 1:
time.sleep(0.5)
if demo and getattr(demo, 'do_restart', False):
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
def api_only():
initialize()
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000)
api = create_api(app)
api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
def webui(launch_api=False):
initialize()
while 1:
demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call) demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
app, local_url, share_url = demo.launch( app, local_url, share_url = demo.launch(
share=cmd_opts.share, share=cmd_opts.share,
server_name="0.0.0.0" if cmd_opts.listen else None, server_name="0.0.0.0" if cmd_opts.listen else None,
...@@ -111,17 +133,14 @@ def webui(): ...@@ -111,17 +133,14 @@ def webui():
inbrowser=cmd_opts.autolaunch, inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True prevent_thread_lock=True
) )
app.add_middleware(GZipMiddleware, minimum_size=1000) app.add_middleware(GZipMiddleware, minimum_size=1000)
while 1: if (launch_api):
time.sleep(0.5) create_api(app)
if getattr(demo, 'do_restart', False):
time.sleep(0.5)
demo.close()
time.sleep(0.5)
break
wait_on_server(demo)
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading Custom Scripts') print('Reloading Custom Scripts')
...@@ -133,5 +152,10 @@ def webui(): ...@@ -133,5 +152,10 @@ def webui():
print('Restarting Gradio') print('Restarting Gradio')
task = []
if __name__ == "__main__": if __name__ == "__main__":
webui() if cmd_opts.nowebui:
api_only()
else:
webui(cmd_opts.api)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment