Commit bde4731f authored by Chris OBryan's avatar Chris OBryan

extras: Rework image cache

Bit of a refactor to the image cache to make it easier to extend.
Also takes into account the entire image instead of just a cropped portion.
parent 26d08193
...@@ -7,7 +7,7 @@ from PIL import Image ...@@ -7,7 +7,7 @@ from PIL import Image
import torch import torch
import tqdm import tqdm
from typing import Callable, List, Tuple from typing import Callable, Dict, List, Tuple
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
...@@ -21,7 +21,18 @@ import piexif.helper ...@@ -21,7 +21,18 @@ import piexif.helper
import gradio as gr import gradio as gr
cached_images = {} @dataclass(frozen=True)
class CacheKey:
image_hash: int
info_hash: int
args_hash: int
@dataclass
class CacheEntry:
image: Image.Image
info: str
cached_images: Dict[CacheKey, CacheEntry] = {}
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
...@@ -84,22 +95,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -84,22 +95,13 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop): def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10)) upscaler = shared.sd_upscalers[scaler_index]
pixels = tuple(np.array(small).flatten().tolist()) res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, if mode == 1 and crop:
resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
c = cached_images.get(key) res = cropped
if c is None: return res
upscaler = shared.sd_upscalers[scaler_index]
c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
if mode == 1 and crop:
cropped = Image.new("RGB", (resize_w, resize_h))
cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
c = cropped
cached_images[key] = c
return c
def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]: def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
# Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
...@@ -118,8 +120,18 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -118,8 +120,18 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]: def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
blended_result: Image.Image = None blended_result: Image.Image = None
for upscaler in params: for upscaler in params:
res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()),
info_hash = hash(info),
args_hash = hash(upscale_args) )
cached_entry = cached_images.get(cache_key)
if cached_entry is None:
res = upscale(image, *upscale_args)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
cached_images[cache_key] = CacheEntry(image=res, info=info)
else:
res, info = cached_entry.image, cached_entry.info
if blended_result is None: if blended_result is None:
blended_result = res blended_result = res
else: else:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment