Line data Source code
1 : #include "ImageHelper.h"
2 : #include "Common.h"
3 : #include "Data.h"
4 : #include "SimpleLogger.h"
5 : #include "exception/ImageHelperException.h"
6 : #include <filesystem>
7 : #include <memory>
8 : #include <opencv2/core/matx.hpp>
9 : #include <opencv2/imgcodecs.hpp>
10 : #include <opencv2/imgproc.hpp>
11 : #include <stdexcept>
12 : #include <string>
13 :
14 : using namespace sipai;
15 :
16 85 : ImageParts ImageHelper::loadImage(const std::string &imagePath, size_t split,
17 : bool withPadding, size_t resize_x,
18 : size_t resize_y) const {
19 85 : if (split == 0) {
20 0 : throw ImageHelperException("internal exception: split 0.");
21 : }
22 : // Check the path
23 85 : if (!std::filesystem::exists(imagePath)) {
24 0 : throw ImageHelperException("Could not find the image: " + imagePath);
25 : }
26 :
27 : // Load the image
28 : try {
29 : cv::Mat mat =
30 85 : cv::imread(imagePath, cv::IMREAD_ANYCOLOR | cv::IMREAD_ANYDEPTH);
31 :
32 85 : if (mat.empty()) {
33 0 : throw ImageHelperException("Could not open the image: " + imagePath);
34 : }
35 170 : Image orig{.orig_height = (size_t)mat.size().height,
36 170 : .orig_width = (size_t)mat.size().width,
37 85 : .orig_type = mat.type(),
38 85 : .orig_channels = mat.channels()};
39 :
40 : // Ensure the image is in BGR format
41 85 : switch (mat.channels()) {
42 0 : case 1:
43 0 : cv::cvtColor(mat, mat, cv::COLOR_GRAY2BGRA);
44 0 : break;
45 85 : case 3:
46 85 : cv::cvtColor(mat, mat, cv::COLOR_RGB2BGRA);
47 85 : break;
48 0 : case 4:
49 0 : cv::cvtColor(mat, mat, cv::COLOR_RGBA2BGRA);
50 0 : break;
51 0 : default:
52 0 : SimpleLogger::LOG_WARN(
53 0 : "Non implemented image colors channels processing: ", mat.channels());
54 0 : break;
55 : }
56 :
57 : // If the image has only 3 channels (BGR), create and merge an alpha channel
58 85 : if (mat.channels() == 3) {
59 0 : cv::Mat alphaMat(mat.size(), CV_8UC1, cv::Scalar(255));
60 0 : std::vector<cv::Mat> channels{mat, alphaMat};
61 0 : cv::Mat bgraMat;
62 0 : cv::merge(channels, bgraMat);
63 0 : mat = bgraMat;
64 0 : }
65 :
66 : // Convert to floating-point range [0, 1] with 4 channels
67 85 : mat.convertTo(mat, CV_32FC4, 1.0 / 255.0);
68 85 : if (mat.channels() != 4) {
69 0 : throw ImageHelperException("incorrect image channels");
70 : }
71 :
72 : // cv::imshow("Original Image step 3", mat);
73 : // cv::waitKey(1000 * 60 * 2);
74 :
75 85 : ImageParts imagesParts;
76 170 : auto matParts = splitImage(mat, split, withPadding);
77 192 : for (auto &matPart : matParts) {
78 : Image image{.data = matPart,
79 107 : .orig_height = orig.orig_height,
80 107 : .orig_width = orig.orig_width,
81 107 : .orig_type = orig.orig_type,
82 107 : .orig_channels = orig.orig_channels};
83 107 : image.resize(resize_x, resize_y);
84 107 : auto image_ptr = std::make_shared<Image>(image);
85 107 : imagesParts.push_back(image_ptr);
86 107 : }
87 :
88 : // Rq. C++ use Return Value Optimization (RVO) to avoid the extra copy or
89 : // move operation associated with the return.
90 170 : return imagesParts;
91 85 : } catch (const cv::Exception &e) {
92 0 : throw ImageHelperException("Error loading image: " + imagePath + ": " +
93 0 : e.what());
94 0 : }
95 0 : }
96 :
97 41 : ImageParts ImageHelper::generateInputImage(const ImageParts &targetImage,
98 : size_t reduce_factor,
99 : size_t resize_x,
100 : size_t resize_y) const {
101 41 : ImageParts imagesParts;
102 90 : for (auto &targetPart : targetImage) {
103 : // clone of the Target image to the Input image
104 49 : Image inputImage = {.data = targetPart->data.clone(),
105 98 : .orig_height = targetPart->orig_height,
106 98 : .orig_width = targetPart->orig_width,
107 98 : .orig_type = targetPart->orig_type,
108 49 : .orig_channels = targetPart->orig_channels};
109 :
110 : // reduce the resolution of the input image
111 49 : if (reduce_factor != 0) {
112 : int new_width =
113 49 : (int)(inputImage.data.size().width / (float)reduce_factor);
114 : int new_height =
115 49 : (int)(inputImage.data.size().height / (float)reduce_factor);
116 49 : inputImage.resize(new_width, new_height);
117 : }
118 :
119 : // then resize to the layer resolution
120 49 : inputImage.resize(resize_x, resize_y);
121 :
122 : // finally convert back to Image
123 49 : auto image = std::make_shared<Image>(inputImage);
124 49 : imagesParts.push_back(image);
125 49 : }
126 :
127 41 : return imagesParts;
128 0 : }
129 :
130 85 : std::vector<cv::Mat> ImageHelper::splitImage(const cv::Mat &inputImage,
131 : size_t split,
132 : bool withPadding) const {
133 85 : if (split == 0) {
134 0 : throw ImageHelperException("internal exception: split 0.");
135 : }
136 :
137 85 : std::vector<cv::Mat> outputImages;
138 :
139 85 : if (split == 1) {
140 81 : outputImages.push_back(inputImage);
141 81 : return outputImages;
142 : }
143 :
144 : // Calculate the size of each part in pixels
145 4 : int partSizeX = (int)((inputImage.cols + split - 1) / split);
146 4 : int partSizeY = (int)((inputImage.rows + split - 1) / split);
147 :
148 : // Calculate the number of splits in x and y directions
149 4 : int splitsX = (inputImage.cols + partSizeX - 1) / partSizeX;
150 4 : int splitsY = (inputImage.rows + partSizeY - 1) / partSizeY;
151 :
152 4 : cv::Mat paddedImage;
153 4 : if (withPadding) {
154 : // Calculate the size of padding to make the image size a multiple of
155 : // partSize
156 2 : int paddingX = splitsX * partSizeX - inputImage.cols;
157 2 : int paddingY = splitsY * partSizeY - inputImage.rows;
158 :
159 : // Create a copy of the image with padding black (cv::Scalar(0,0,0)) on the
160 : // right and bottom.
161 2 : cv::copyMakeBorder(inputImage, paddedImage, 0, paddingY, 0, paddingX,
162 4 : cv::BORDER_CONSTANT, cv::Scalar(0, 0, 0));
163 : }
164 : // Loop over the image and create the smaller Region of Interest (roi) parts
165 14 : for (int i = 0; i < splitsY; ++i) {
166 36 : for (int j = 0; j < splitsX; ++j) {
167 26 : int roiWidth =
168 26 : (j == (int)split - 1) ? inputImage.cols - j * partSizeX : partSizeX;
169 26 : int roiHeight =
170 26 : (i == (int)split - 1) ? inputImage.rows - i * partSizeY : partSizeY;
171 26 : cv::Rect roi(j * partSizeX, i * partSizeY, roiWidth, roiHeight);
172 70 : outputImages.push_back(withPadding ? paddedImage(roi).clone()
173 44 : : inputImage(roi).clone());
174 : }
175 : }
176 :
177 4 : return outputImages;
178 4 : }
179 :
180 2 : void ImageHelper::saveImage(const std::string &imagePath,
181 : const ImageParts &imageParts, size_t split,
182 : size_t resize_x, size_t resize_y) const {
183 3 : if (imageParts.empty() || split == 0 ||
184 1 : (split == 1 && imageParts.size() != 1)) {
185 0 : throw ImageHelperException(
186 0 : "internal exception: invalid image parts or split number.");
187 : }
188 : try {
189 :
190 2 : auto image = split == 1 ? *imageParts.front()
191 2 : : joinImages(imageParts, (int)split, (int)split);
192 :
193 2 : if (image.data.empty()) {
194 0 : throw ImageHelperException("Image data is empty.");
195 : }
196 :
197 2 : image.resize(resize_x, resize_y);
198 :
199 : // convert back the [0,1] float range image to 255 pixel values
200 2 : image.data.convertTo(image.data, image.orig_type, 255.0);
201 :
202 : // Convert back to the original color format
203 2 : cv::Mat tmp;
204 2 : switch (image.orig_channels) {
205 0 : case 1:
206 0 : cv::cvtColor(image.data, tmp, cv::COLOR_BGRA2GRAY);
207 0 : break;
208 2 : case 3:
209 2 : cv::cvtColor(image.data, tmp, cv::COLOR_BGRA2RGB);
210 2 : break;
211 0 : case 4:
212 0 : cv::cvtColor(image.data, tmp, cv::COLOR_BGRA2RGBA);
213 0 : break;
214 0 : default:
215 0 : SimpleLogger::LOG_WARN(
216 : "Non implemented image colors channels processing: ",
217 : image.orig_channels);
218 0 : tmp = image.data;
219 0 : break;
220 : }
221 :
222 : // write the image
223 : // std::vector<int> params;
224 : // params.push_back(cv::IMWRITE_PNG_COMPRESSION);
225 : // params.push_back(9); // Compression level
226 : // if (!cv::imwrite(imagePath, mat, params)) {
227 2 : if (!cv::imwrite(imagePath, tmp)) {
228 0 : throw ImageHelperException("Error saving image: " + imagePath);
229 : }
230 2 : } catch (ImageHelperException &ihe) {
231 0 : throw ihe;
232 0 : } catch (const cv::Exception &e) {
233 0 : throw ImageHelperException("Error saving image: " + imagePath + ": " +
234 0 : e.what());
235 0 : } catch (std::exception &ex) {
236 0 : throw ImageHelperException(ex.what());
237 0 : }
238 2 : }
239 :
240 1 : Image ImageHelper::joinImages(const ImageParts &images, int splitsX,
241 : int splitsY) const {
242 1 : if (images.empty()) {
243 0 : throw ImageHelperException("internal exception: empty parts.");
244 : }
245 1 : if (splitsX == 0 || splitsY == 0) {
246 0 : throw ImageHelperException("internal exception: split 0.");
247 : }
248 1 : std::vector<cv::Mat> rows;
249 3 : for (int i = 0; i < splitsY; ++i) {
250 2 : std::vector<cv::Mat> row;
251 6 : for (int j = 0; j < splitsX; ++j) {
252 4 : row.push_back(images[i * splitsX + j]->data);
253 : }
254 2 : cv::Mat rowImage;
255 2 : cv::hconcat(row, rowImage);
256 2 : rows.push_back(rowImage);
257 2 : }
258 1 : cv::Mat result;
259 1 : cv::vconcat(rows, result);
260 :
261 : Image image{.data = result,
262 1 : .orig_height = images.front()->orig_height,
263 1 : .orig_width = images.front()->orig_height,
264 1 : .orig_type = images.front()->orig_type,
265 4 : .orig_channels = images.front()->orig_channels};
266 2 : return image;
267 1 : }
268 :
269 38 : float ImageHelper::computeLoss(const cv::Mat &outputData,
270 : const cv::Mat &targetData) const {
271 76 : if (outputData.total() != targetData.total() || outputData.total() == 0 ||
272 38 : targetData.total() == 0) {
273 0 : throw std::invalid_argument("Output and target images have different "
274 0 : "sizes, or some are empty.");
275 : }
276 :
277 : // Calculate element-wise squared differences
278 38 : cv::Mat diff;
279 38 : cv::absdiff(outputData, targetData, diff);
280 38 : diff = diff.mul(diff);
281 :
282 : // Compute the sum of squared differences
283 38 : cv::Scalar sumSquaredDiff = cv::sum(diff);
284 :
285 : // Compute the number of pixels
286 38 : size_t numPixels = outputData.total();
287 :
288 : // Calculate the MSE loss
289 38 : float mseLoss = 0.0f;
290 : if (sumSquaredDiff.rows > 0) { // sumSquaredDiff is 1 col, 4 rows (rgba).
291 190 : for (int i = 0; i < sumSquaredDiff.rows; i++) {
292 152 : mseLoss += static_cast<float>(sumSquaredDiff.val[i]) /
293 152 : static_cast<float>(numPixels);
294 : }
295 :
296 38 : mseLoss /= static_cast<float>(sumSquaredDiff.rows);
297 : }
298 :
299 38 : return mseLoss;
300 38 : }
|