Commit c84e3336 authored by AUTOMATIC's avatar AUTOMATIC

color correction option for all img2img modes #363

parent 823cf946
import math import math
import cv2
import numpy as np import numpy as np
from PIL import Image, ImageOps, ImageChops from PIL import Image, ImageOps, ImageChops
...@@ -76,18 +75,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init ...@@ -76,18 +75,7 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
state.job_count = n_iter state.job_count = n_iter
do_color_correction = False
try:
from skimage import exposure
do_color_correction = True
except:
print("Install scikit-image to perform color correction on loopback")
for i in range(n_iter): for i in range(n_iter):
if do_color_correction and i == 0:
correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
p.n_iter = 1 p.n_iter = 1
p.batch_size = 1 p.batch_size = 1
p.do_not_save_grid = True p.do_not_save_grid = True
...@@ -101,16 +89,6 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init ...@@ -101,16 +89,6 @@ def img2img(prompt: str, negative_prompt: str, prompt_style: str, init_img, init
init_img = processed.images[0] init_img = processed.images[0]
if do_color_correction and correction_target is not None:
init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(init_img),
cv2.COLOR_RGB2LAB
),
correction_target,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
p.init_images = [init_img] p.init_images = [init_img]
p.seed = processed.seed + 1 p.seed = processed.seed + 1
p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1) p.denoising_strength = min(max(p.denoising_strength * denoising_strength_change_factor, 0.1), 1)
......
...@@ -8,6 +8,8 @@ import torch ...@@ -8,6 +8,8 @@ import torch
import numpy as np import numpy as np
from PIL import Image, ImageFilter, ImageOps from PIL import Image, ImageFilter, ImageOps
import random import random
import cv2
from skimage import exposure
import modules.sd_hijack import modules.sd_hijack
from modules import devices from modules import devices
...@@ -19,11 +21,30 @@ import modules.face_restoration ...@@ -19,11 +21,30 @@ import modules.face_restoration
import modules.images as images import modules.images as images
import modules.styles import modules.styles
# some of those options should not be changed at all because they would break the model, so I removed them from options. # some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4 opt_C = 4
opt_f = 8 opt_f = 8
def setup_color_correction(image):
correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB)
return correction_target
def apply_color_correction(correction, image):
image = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
cv2.cvtColor(
np.asarray(image),
cv2.COLOR_RGB2LAB
),
correction,
channel_axis=2
), cv2.COLOR_LAB2RGB).astype("uint8"))
return image
class StableDiffusionProcessing: class StableDiffusionProcessing:
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None): def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", prompt_style="None", seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
self.sd_model = sd_model self.sd_model = sd_model
...@@ -52,6 +73,7 @@ class StableDiffusionProcessing: ...@@ -52,6 +73,7 @@ class StableDiffusionProcessing:
self.extra_generation_params: dict = extra_generation_params self.extra_generation_params: dict = extra_generation_params
self.overlay_images = overlay_images self.overlay_images = overlay_images
self.paste_to = None self.paste_to = None
self.color_corrections = None
def init(self, seed): def init(self, seed):
pass pass
...@@ -265,6 +287,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed: ...@@ -265,6 +287,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
image = Image.fromarray(x_sample) image = Image.fromarray(x_sample)
if p.color_corrections is not None and i < len(p.color_corrections):
image = apply_color_correction(p.color_corrections[i], image)
if p.overlay_images is not None and i < len(p.overlay_images): if p.overlay_images is not None and i < len(p.overlay_images):
overlay = p.overlay_images[i] overlay = p.overlay_images[i]
...@@ -420,6 +444,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -420,6 +444,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask latent_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
self.color_corrections = []
imgs = [] imgs = []
for img in self.init_images: for img in self.init_images:
image = img.convert("RGB") image = img.convert("RGB")
...@@ -441,6 +466,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): ...@@ -441,6 +466,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_fill != 1: if self.inpainting_fill != 1:
image = fill(image, latent_mask) image = fill(image, latent_mask)
if opts.img2img_color_correction:
self.color_corrections.append(setup_color_correction(image))
image = np.array(image).astype(np.float32) / 255.0 image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0) image = np.moveaxis(image, 2, 0)
......
...@@ -122,6 +122,7 @@ class Options: ...@@ -122,6 +122,7 @@ class Options:
"export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"), "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),
"enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
"add_model_hash_to_info": OptionInfo(False, "Add model hash to generation information"), "add_model_hash_to_info": OptionInfo(False, "Add model hash to generation information"),
"img2img_color_correction": OptionInfo(True, "Apply color correction to img2img results to match original colors."),
"font": OptionInfo("", "Font for image grids that have text"), "font": OptionInfo("", "Font for image grids that have text"),
"enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"), "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text text and [text] to make it pay less attention"),
"save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment