Commit e7f48085 authored by arcticfaded's avatar arcticfaded

provide sampler by name

parent 8d5d863a
from modules.api.processing import StableDiffusionProcessingAPI from modules.api.processing import StableDiffusionProcessingAPI
from modules.processing import StableDiffusionProcessingTxt2Img, process_images from modules.processing import StableDiffusionProcessingTxt2Img, process_images
from modules.sd_samplers import samplers_k_diffusion
import modules.shared as shared import modules.shared as shared
import uvicorn import uvicorn
from fastapi import Body, APIRouter from fastapi import Body, APIRouter, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, Json from pydantic import BaseModel, Field, Json
import json import json
import io import io
import base64 import base64
sampler_to_index = lambda name: next(filter(lambda row: name in row[1][2], enumerate(samplers_k_diffusion)), None)
class TextToImageResponse(BaseModel): class TextToImageResponse(BaseModel):
images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.") images: list[str] = Field(default=None, title="Image", description="The generated image in base64 format.")
parameters: Json parameters: Json
...@@ -23,9 +26,14 @@ class Api: ...@@ -23,9 +26,14 @@ class Api:
self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"]) self.app.add_api_route("/sdapi/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
sampler_index = sampler_to_index(txt2imgreq.sampler_index)
if sampler_index is None:
raise HTTPException(status_code=404, detail="Sampler not found")
populate = txt2imgreq.copy(update={ # Override __init__ params populate = txt2imgreq.copy(update={ # Override __init__ params
"sd_model": shared.sd_model, "sd_model": shared.sd_model,
"sampler_index": 0, "sampler_index": sampler_index[0],
"do_not_save_samples": True, "do_not_save_samples": True,
"do_not_save_grid": True "do_not_save_grid": True
} }
......
...@@ -42,7 +42,8 @@ class PydanticModelGenerator: ...@@ -42,7 +42,8 @@ class PydanticModelGenerator:
def __init__( def __init__(
self, self,
model_name: str = None, model_name: str = None,
class_instance = None class_instance = None,
additional_fields = None,
): ):
def field_type_generator(k, v): def field_type_generator(k, v):
# field_type = str if not overrides.get(k) else overrides[k]["type"] # field_type = str if not overrides.get(k) else overrides[k]["type"]
...@@ -70,6 +71,13 @@ class PydanticModelGenerator: ...@@ -70,6 +71,13 @@ class PydanticModelGenerator:
) )
for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
] ]
for fields in additional_fields:
self._model_def.append(ModelDef(
field=underscore(fields["key"]),
field_alias=fields["key"],
field_type=fields["type"],
field_value=fields["default"]))
def generate_model(self): def generate_model(self):
""" """
...@@ -84,4 +92,8 @@ class PydanticModelGenerator: ...@@ -84,4 +92,8 @@ class PydanticModelGenerator:
DynamicModel.__config__.allow_mutation = True DynamicModel.__config__.allow_mutation = True
return DynamicModel return DynamicModel
StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model() StableDiffusionProcessingAPI = PydanticModelGenerator(
"StableDiffusionProcessingTxt2Img",
StableDiffusionProcessingTxt2Img,
[{"key": "sampler_index", "type": str, "default": "k_euler_a"}]
).generate_model()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment