Commit 8d5d863a authored by arcticfaded's avatar arcticfaded

gradio and FastAPI

parent 1df3ff25
...@@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel): ...@@ -16,9 +16,11 @@ class TextToImageResponse(BaseModel):
class Api: class Api:
def __init__(self, app): def __init__(self, app, queue_lock):
self.router = APIRouter() self.router = APIRouter()
app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app = app
self.queue_lock = queue_lock
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
...@@ -30,7 +32,8 @@ class Api: ...@@ -30,7 +32,8 @@ class Api:
) )
p = StableDiffusionProcessingTxt2Img(**vars(populate)) p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param # Override object param
processed = process_images(p) with self.queue_lock:
processed = process_images(p)
b64images = [] b64images = []
for i in processed.images: for i in processed.images:
...@@ -52,5 +55,5 @@ class Api: ...@@ -52,5 +55,5 @@ class Api:
raise NotImplementedError raise NotImplementedError
def launch(self, server_name, port): def launch(self, server_name, port):
app.include_router(self.router) self.app.include_router(self.router)
uvicorn.run(app, host=server_name, port=port) uvicorn.run(self.app, host=server_name, port=port)
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import importlib import importlib
import signal import signal
import threading import threading
from fastapi import FastAPI
from fastapi.middleware.gzip import GZipMiddleware from fastapi.middleware.gzip import GZipMiddleware
from modules.paths import script_path from modules.paths import script_path
...@@ -31,7 +31,6 @@ from modules.paths import script_path ...@@ -31,7 +31,6 @@ from modules.paths import script_path
from modules.shared import cmd_opts from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork import modules.hypernetworks.hypernetwork
queue_lock = threading.Lock() queue_lock = threading.Lock()
...@@ -97,7 +96,7 @@ def initialize(): ...@@ -97,7 +96,7 @@ def initialize():
def create_api(app): def create_api(app):
from modules.api.api import Api from modules.api.api import Api
api = Api(app) api = Api(app, queue_lock)
return api return api
def wait_on_server(demo=None): def wait_on_server(demo=None):
...@@ -141,7 +140,7 @@ def webui(launch_api=False): ...@@ -141,7 +140,7 @@ def webui(launch_api=False):
create_api(app) create_api(app)
wait_on_server(demo) wait_on_server(demo)
sd_samplers.set_samplers() sd_samplers.set_samplers()
print('Reloading Custom Scripts') print('Reloading Custom Scripts')
...@@ -153,11 +152,10 @@ def webui(launch_api=False): ...@@ -153,11 +152,10 @@ def webui(launch_api=False):
print('Restarting Gradio') print('Restarting Gradio')
task = []
if __name__ == "__main__": if __name__ == "__main__":
if not cmd_opts.nowebui: if cmd_opts.nowebui:
api_only() api_only()
if cmd_opts.api:
webui(True)
else: else:
webui(False) webui(cmd_opts.api)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment