Commit 03ee297a authored by w-e-w's avatar w-e-w

fix Auto focal point crop for opencv >= 4.8.x

autocrop.download_and_cache_models
in opencv >= 4.8 the face detection model was updated
download the base on opencv version
returns the model path or raise exception
parent f0f100e6
...@@ -3,6 +3,8 @@ import requests ...@@ -3,6 +3,8 @@ import requests
import os import os
import numpy as np import numpy as np
from PIL import ImageDraw from PIL import ImageDraw
from modules import paths_internal
from pkg_resources import parse_version
GREEN = "#0F0" GREEN = "#0F0"
BLUE = "#00F" BLUE = "#00F"
...@@ -294,22 +296,23 @@ def is_square(w, h): ...@@ -294,22 +296,23 @@ def is_square(w, h):
return w == h return w == h
def download_and_cache_models(dirname): model_dir_opencv = os.path.join(paths_internal.models_path, 'opencv')
download_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true' if parse_version(cv2.__version__) >= parse_version('4.8'):
model_file_name = 'face_detection_yunet.onnx' model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet_2023mar.onnx')
model_url = 'https://github.com/opencv/opencv_zoo/blob/b6e370b10f641879a87890d44e42173077154a05/models/face_detection_yunet/face_detection_yunet_2023mar.onnx?raw=true'
else:
model_file_path = os.path.join(model_dir_opencv, 'face_detection_yunet.onnx')
model_url = 'https://github.com/opencv/opencv_zoo/blob/91fb0290f50896f38a0ab1e558b74b16bc009428/models/face_detection_yunet/face_detection_yunet_2022mar.onnx?raw=true'
os.makedirs(dirname, exist_ok=True)
cache_file = os.path.join(dirname, model_file_name) def download_and_cache_models():
if not os.path.exists(cache_file): if not os.path.exists(model_file_path):
print(f"downloading face detection model from '{download_url}' to '{cache_file}'") os.makedirs(model_dir_opencv, exist_ok=True)
response = requests.get(download_url) print(f"downloading face detection model from '{model_url}' to '{model_file_path}'")
with open(cache_file, "wb") as f: response = requests.get(model_url)
with open(model_file_path, "wb") as f:
f.write(response.content) f.write(response.content)
return model_file_path
if os.path.exists(cache_file):
return cache_file
return None
class PointOfInterest: class PointOfInterest:
......
...@@ -3,7 +3,7 @@ from PIL import Image, ImageOps ...@@ -3,7 +3,7 @@ from PIL import Image, ImageOps
import math import math
import tqdm import tqdm
from modules import paths, shared, images, deepbooru from modules import shared, images, deepbooru
from modules.textual_inversion import autocrop from modules.textual_inversion import autocrop
...@@ -196,7 +196,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre ...@@ -196,7 +196,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
dnn_model_path = None dnn_model_path = None
try: try:
dnn_model_path = autocrop.download_and_cache_models(os.path.join(paths.models_path, "opencv")) dnn_model_path = autocrop.download_and_cache_models()
except Exception as e: except Exception as e:
print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e) print("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", e)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment