Commit f80e914a authored by arcticfaded's avatar arcticfaded

example API working with gradio

parent d42125ba
...@@ -23,8 +23,13 @@ class Api: ...@@ -23,8 +23,13 @@ class Api:
app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"]) app.add_api_route("/v1/txt2img", self.text2imgapi, methods=["POST"])
def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ): def text2imgapi(self, txt2imgreq: StableDiffusionProcessingAPI ):
p = StableDiffusionProcessingTxt2Img(**vars(txt2imgreq)) populate = txt2imgreq.copy(update={ # Override __init__ params
p.sd_model = shared.sd_model "sd_model": shared.sd_model,
"sampler_index": 0,
}
)
p = StableDiffusionProcessingTxt2Img(**vars(populate))
# Override object param
processed = process_images(p) processed = process_images(p)
b64images = [] b64images = []
......
...@@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu ...@@ -5,6 +5,24 @@ from modules.processing import StableDiffusionProcessing, Processed, StableDiffu
import inspect import inspect
API_NOT_ALLOWED = [
"self",
"kwargs",
"sd_model",
"outpath_samples",
"outpath_grids",
"sampler_index",
"do_not_save_samples",
"do_not_save_grid",
"extra_generation_params",
"overlay_images",
"do_not_reload_embeddings",
"seed_enable_extras",
"prompt_for_display",
"sampler_noise_scheduler_override",
"ddim_discretize"
]
class ModelDef(BaseModel): class ModelDef(BaseModel):
"""Assistance Class for Pydantic Dynamic Model Generation""" """Assistance Class for Pydantic Dynamic Model Generation"""
...@@ -14,7 +32,7 @@ class ModelDef(BaseModel): ...@@ -14,7 +32,7 @@ class ModelDef(BaseModel):
field_value: Any field_value: Any
class pydanticModelGenerator: class PydanticModelGenerator:
""" """
Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about: Takes in created classes and stubs them out in a way FastAPI/Pydantic is happy about:
source_data is a snapshot of the default values produced by the class source_data is a snapshot of the default values produced by the class
...@@ -24,30 +42,33 @@ class pydanticModelGenerator: ...@@ -24,30 +42,33 @@ class pydanticModelGenerator:
def __init__( def __init__(
self, self,
model_name: str = None, model_name: str = None,
source_data: {} = {}, class_instance = None
params: Dict = {},
overrides: Dict = {},
optionals: Dict = {},
): ):
def field_type_generator(k, v, overrides, optionals): def field_type_generator(k, v):
field_type = str if not overrides.get(k) else overrides[k]["type"] # field_type = str if not overrides.get(k) else overrides[k]["type"]
if v is None: # print(k, v.annotation, v.default)
field_type = Any field_type = v.annotation
else:
field_type = type(v)
return Optional[field_type] return Optional[field_type]
def merge_class_params(class_):
all_classes = list(filter(lambda x: x is not object, inspect.getmro(class_)))
parameters = {}
for classes in all_classes:
parameters = {**parameters, **inspect.signature(classes.__init__).parameters}
return parameters
self._model_name = model_name self._model_name = model_name
self._json_data = source_data self._class_data = merge_class_params(class_instance)
self._model_def = [ self._model_def = [
ModelDef( ModelDef(
field=underscore(k), field=underscore(k),
field_alias=k, field_alias=k,
field_type=field_type_generator(k, v, overrides, optionals), field_type=field_type_generator(k, v),
field_value=v field_value=v.default
) )
for (k,v) in source_data.items() if k in params for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED
] ]
def generate_model(self): def generate_model(self):
...@@ -60,8 +81,7 @@ class pydanticModelGenerator: ...@@ -60,8 +81,7 @@ class pydanticModelGenerator:
} }
DynamicModel = create_model(self._model_name, **fields) DynamicModel = create_model(self._model_name, **fields)
DynamicModel.__config__.allow_population_by_field_name = True DynamicModel.__config__.allow_population_by_field_name = True
DynamicModel.__config__.allow_mutation = True
return DynamicModel return DynamicModel
StableDiffusionProcessingAPI = pydanticModelGenerator("StableDiffusionProcessing", StableDiffusionProcessingAPI = PydanticModelGenerator("StableDiffusionProcessingTxt2Img", StableDiffusionProcessingTxt2Img).generate_model()
StableDiffusionProcessing().__dict__,
inspect.signature(StableDiffusionProcessing.__init__).parameters).generate_model()
...@@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps ...@@ -9,6 +9,7 @@ from PIL import Image, ImageFilter, ImageOps
import random import random
import cv2 import cv2
from skimage import exposure from skimage import exposure
from typing import Any, Dict, List, Optional
import modules.sd_hijack import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram from modules import devices, prompt_parser, masking, sd_samplers, lowvram
...@@ -51,9 +52,15 @@ def get_correct_sampler(p): ...@@ -51,9 +52,15 @@ def get_correct_sampler(p):
return sd_samplers.samplers return sd_samplers.samplers
elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img): elif isinstance(p, modules.processing.StableDiffusionProcessingImg2Img):
return sd_samplers.samplers_for_img2img return sd_samplers.samplers_for_img2img
elif isinstance(p, modules.api.processing.StableDiffusionProcessingAPI):
return sd_samplers.samplers
class StableDiffusionProcessing: class StableDiffusionProcessing():
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None, do_not_reload_embeddings=False): """
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str="", styles: List[str]=None, seed: int=-1, subseed: int=-1, subseed_strength: float=0, seed_resize_from_h: int=-1, seed_resize_from_w: int=-1, seed_enable_extras: bool=True, sampler_index: int=0, batch_size: int=1, n_iter: int=1, steps:int =50, cfg_scale:float=7.0, width:int=512, height:int=512, restore_faces:bool=False, tiling:bool=False, do_not_save_samples:bool=False, do_not_save_grid:bool=False, extra_generation_params: Dict[Any,Any]=None, overlay_images: Any=None, negative_prompt: str=None, eta: float =None, do_not_reload_embeddings: bool=False, denoising_strength: float = 0, ddim_discretize: str = "uniform", s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0):
self.sd_model = sd_model self.sd_model = sd_model
self.outpath_samples: str = outpath_samples self.outpath_samples: str = outpath_samples
self.outpath_grids: str = outpath_grids self.outpath_grids: str = outpath_grids
...@@ -86,10 +93,10 @@ class StableDiffusionProcessing: ...@@ -86,10 +93,10 @@ class StableDiffusionProcessing:
self.denoising_strength: float = 0 self.denoising_strength: float = 0
self.sampler_noise_scheduler_override = None self.sampler_noise_scheduler_override = None
self.ddim_discretize = opts.ddim_discretize self.ddim_discretize = opts.ddim_discretize
self.s_churn = opts.s_churn self.s_churn = s_churn or opts.s_churn
self.s_tmin = opts.s_tmin self.s_tmin = s_tmin or opts.s_tmin
self.s_tmax = float('inf') # not representable as a standard ui option self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
self.s_noise = opts.s_noise self.s_noise = s_noise or opts.s_noise
if not seed_enable_extras: if not seed_enable_extras:
self.subseed = -1 self.subseed = -1
...@@ -97,6 +104,7 @@ class StableDiffusionProcessing: ...@@ -97,6 +104,7 @@ class StableDiffusionProcessing:
self.seed_resize_from_h = 0 self.seed_resize_from_h = 0
self.seed_resize_from_w = 0 self.seed_resize_from_w = 0
def init(self, all_prompts, all_seeds, all_subseeds): def init(self, all_prompts, all_seeds, all_subseeds):
pass pass
...@@ -497,7 +505,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -497,7 +505,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None sampler = None
def __init__(self, enable_hr=False, denoising_strength=0.75, firstphase_width=0, firstphase_height=0, **kwargs): def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.enable_hr = enable_hr self.enable_hr = enable_hr
self.denoising_strength = denoising_strength self.denoising_strength = denoising_strength
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment