Commit 9ddaf826 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub

Merge pull request #135 from rewbs/img2img2-color-correction

Add color correction to img2img loopback to avoid a progressive skew to magenta. Based on codedealer's PR to hlky's repo here: https://github.com/sd-webui/stable-diffusion-webui/pull/698/files.
parents 0959fa2d 21a375e6
import math import math
import cv2
import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageChops
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
...@@ -59,8 +61,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index ...@@ -59,8 +61,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
state.job_count = n_iter state.job_count = n_iter
do_color_correction = False
try:
from skimage import exposure
do_color_correction = True
except:
print("Install scikit-image to perform color correction on loopback")
for i in range(n_iter): for i in range(n_iter):
if do_color_correction and i == 0:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
p.n_iter = 1 p.n_iter = 1
p.batch_size = 1 p.batch_size = 1
p.do_not_save_grid = True p.do_not_save_grid = True
...@@ -72,7 +85,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index ...@@ -72,7 +85,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
initial_seed = processed.seed initial_seed = processed.seed
initial_info = processed.info initial_info = processed.info
p.init_images = [processed.images[0]] init_img = processed.images[0]
if do_color_correction and correction_target is not None:
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(init_img),
cv2.COLOR_RGB2LAB
),
correction_target,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
p.init_images = [init_img]
p.seed = processed.seed + 1 p.seed = processed.seed + 1
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1) p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
history.append(processed.images[0]) history.append(processed.images[0])
......
...@@ -10,5 +10,6 @@ omegaconf ...@@ -10,5 +10,6 @@ omegaconf
pytorch_lightning pytorch_lightning
diffusers diffusers
invisible-watermark invisible-watermark
scikit-image
git+https://github.com/crowsonkb/k-diffusion.git git+https://github.com/crowsonkb/k-diffusion.git
git+https://github.com/TencentARC/GFPGAN.git git+https://github.com/TencentARC/GFPGAN.git
...@@ -8,3 +8,4 @@ torch ...@@ -8,3 +8,4 @@ torch
transformers==4.19.2 transformers==4.19.2
omegaconf==2.1.1 omegaconf==2.1.1
pytorch_lightning==1.7.2 pytorch_lightning==1.7.2
scikit-image==0.19.2
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment