pythonproject/test_LearningBasedWB.py
2024-06-25 14:15:07 +08:00

92 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from cv2 import xphoto
import cv2
from PIL import Image, ImageOps
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
def learnWb(img):
wb = xphoto.createLearningBasedWB('color_balance_model.yml')
# 读取图像
# image = cv2.imread('IMG_3575.png')
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 应用白平衡
white_balanced_image = wb.balanceWhite(img)
# 显示或保存结果
cv2.imshow('White Balanced Image', white_balanced_image)
cv2.waitKey(0)
cv2.destroyAllWindows()
def ground_truth(image, img_patch, mode='mean'):
# 使用均值或最大值来进行颜色校正
if mode == 'mean':
image_gt = ((image * (img_patch.mean() / image.mean(axis=(0, 1)))).clip(0, 255).astype(int))
elif mode == 'max':
image_gt = ((image * 1.0 / img_patch.max(axis=(0, 1))).clip(0, 1))
else:
raise ValueError("Invalid mode. Use 'mean' or 'max'.")
# 绘制原始图像和Ground Truth校正后的对比图
fig, ax = plt.subplots(1, 2, figsize=(14, 10))
ax[0].imshow(image)
ax[0].set_title('原始图像')
ax[0].axis('off')
ax[1].imshow(image_gt)
ax[1].set_title('Ground Truth校正后的图像')
ax[1].axis('off')
plt.show()
def gray_world(image_path):
# 读取图像
image = cv2.imread(image_path)
# 转换图像颜色通道顺序OpenCV使用BGR而matplotlib使用RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# 计算图像的平均亮度
avg_brightness = image.mean()
# 计算每个颜色通道的平均亮度
avg_channel_brightness = image.mean(axis=(0, 1))
# 计算亮度比例,并相应地调整每个颜色通道
image_grayworld = ((image * (avg_brightness / avg_channel_brightness)).clip(0, 255).astype(int))
# 绘制原始图像和灰色世界校正后的对比图
fig, ax = plt.subplots(1, 2, figsize=(14, 10))
ax[0].imshow(image)
ax[0].set_title('原始图像')
ax[0].axis('off')
ax[1].imshow(image_grayworld)
ax[1].set_title('灰色世界校正后的图像')
ax[1].axis('off')
plt.show()
# img = Image.open('IMG_3579.png')
# img = np.array(img)
# learnWb(img)
# # 读取图像
# image = cv2.imread('IMG_3578.png')
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#
# # 选择参考色块的区域
# img_patch = image[800:850, 1800:1850, :]
# # 显示图像和参考色块的矩形框
# fig, ax = plt.subplots(figsize=(10, 10))
# ax.set_title('参考块在红色矩形框内')
# ax.imshow(image)
# ax.add_patch(Rectangle((1800, 800), 50, 50, edgecolor='r', facecolor='none'))
#
# # 调用 Ground Truth Algorithm
# ground_truth(image, img_patch, mode='mean')
gray_world("IMG_3577.png")