Commit 1f1b3279 authored by Chris OBryan's avatar Chris OBryan

extras: Make image cache LRU

This changes the extras image cache into a Least-Recently-Used
cache. This allows more experimentation with different upscalers
without missing the cache.

Max cache size is increased to 5 and is cleared on source image
update.
parent bde4731f
from __future__ import annotations
import math import math
import os import os
...@@ -7,7 +8,7 @@ from PIL import Image ...@@ -7,7 +8,7 @@ from PIL import Image
import torch import torch
import tqdm import tqdm
from typing import Callable, Dict, List, Tuple from typing import Callable, List, OrderedDict, Tuple
from functools import partial from functools import partial
from dataclasses import dataclass from dataclasses import dataclass
...@@ -21,18 +22,34 @@ import piexif.helper ...@@ -21,18 +22,34 @@ import piexif.helper
import gradio as gr import gradio as gr
@dataclass(frozen=True) class LruCache(OrderedDict):
class CacheKey: @dataclass(frozen=True)
image_hash: int class Key:
info_hash: int image_hash: int
args_hash: int info_hash: int
args_hash: int
@dataclass @dataclass
class CacheEntry: class Value:
image: Image.Image image: Image.Image
info: str info: str
def __init__(self, max_size:int = 5, *args, **kwargs):
super().__init__(*args, **kwargs)
self._max_size = max_size
def get(self, key: LruCache.Key) -> LruCache.Value:
ret = super().get(key)
if ret is not None:
self.move_to_end(key) # Move to end of eviction list
return ret
def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
self[key] = value
while len(self) > self._max_size:
self.popitem(last=False)
cached_images: Dict[CacheKey, CacheEntry] = {} cached_images: LruCache = LruCache(max_size = 5)
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ): def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
...@@ -121,14 +138,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -121,14 +138,14 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
blended_result: Image.Image = None blended_result: Image.Image = None
for upscaler in params: for upscaler in params:
upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop) upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
cache_key = CacheKey( image_hash = hash(np.array(image.getdata()).tobytes()), cache_key = LruCache.Key( image_hash = hash(np.array(image.getdata()).tobytes()),
info_hash = hash(info), info_hash = hash(info),
args_hash = hash(upscale_args) ) args_hash = hash(upscale_args + (upscaler.blend_alpha,)) )
cached_entry = cached_images.get(cache_key) cached_entry = cached_images.get(cache_key)
if cached_entry is None: if cached_entry is None:
res = upscale(image, *upscale_args) res = upscale(image, *upscale_args)
info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n" info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
cached_images[cache_key] = CacheEntry(image=res, info=info) cached_images.put(cache_key, LruCache.Value(image=res, info=info))
else: else:
res, info = cached_entry.image, cached_entry.info res, info = cached_entry.image, cached_entry.info
...@@ -140,14 +157,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -140,14 +157,11 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
# Build a list of operations to run # Build a list of operations to run
facefix_ops: List[Callable] = [] facefix_ops: List[Callable] = []
if gfpgan_visibility > 0: facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
facefix_ops.append(run_gfpgan) facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []
if codeformer_visibility > 0:
facefix_ops.append(run_codeformer)
upscale_ops: List[Callable] = [] upscale_ops: List[Callable] = []
if resize_mode == 1: upscale_ops += [run_prepare_crop] if resize_mode == 1 else []
upscale_ops.append(run_prepare_crop)
if upscaling_resize != 0: if upscaling_resize != 0:
step_params: List[UpscaleParams] = [] step_params: List[UpscaleParams] = []
...@@ -157,12 +171,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -157,12 +171,7 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
upscale_ops.append( partial(run_upscalers_blend, step_params) ) upscale_ops.append( partial(run_upscalers_blend, step_params) )
extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)
extras_ops: List[Callable] = []
if upscale_first:
extras_ops = upscale_ops + facefix_ops
else:
extras_ops = facefix_ops + upscale_ops
for image, image_name in zip(imageArr, imageNameArr): for image, image_name in zip(imageArr, imageNameArr):
...@@ -176,9 +185,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -176,9 +185,6 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
for op in extras_ops: for op in extras_ops:
image, info = op(image, info) image, info = op(image, info)
while len(cached_images) > 2:
del cached_images[next(iter(cached_images.keys()))]
if opts.use_original_name_batch and image_name != None: if opts.use_original_name_batch and image_name != None:
basename = os.path.splitext(os.path.basename(image_name))[0] basename = os.path.splitext(os.path.basename(image_name))[0]
else: else:
...@@ -198,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_ ...@@ -198,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
return outputs, plaintext_to_html(info), '' return outputs, plaintext_to_html(info), ''
def clear_cache():
cached_images.clear()
def run_pnginfo(image): def run_pnginfo(image):
if image is None: if image is None:
......
...@@ -1178,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call): ...@@ -1178,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call):
outputs=[init_img_with_mask], outputs=[init_img_with_mask],
) )
extras_image.change(
fn=modules.extras.clear_cache,
inputs=[], outputs=[]
)
with gr.Blocks(analytics_enabled=False) as pnginfo_interface: with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
with gr.Row().style(equal_height=False): with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'): with gr.Column(variant='panel'):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment