Line data Source code
1 : #include "RunnerEnhancerOpenCVVisitor.h" 2 : #include "ImageHelper.h" 3 : #include "Manager.h" 4 : #include "SimpleLogger.h" 5 : #include "exception/RunnerVisitorException.h" 6 : #include <exception> 7 : #include <memory> 8 : 9 : using namespace sipai; 10 : 11 4 : void RunnerEnhancerOpenCVVisitor::visit() const { 12 4 : SimpleLogger::LOG_INFO("Image enhancement..."); 13 4 : auto &manager = Manager::getInstance(); 14 : 15 4 : if (!manager.network) { 16 1 : throw RunnerVisitorException("No neural network. Aborting."); 17 : } 18 : 19 3 : if (manager.app_params.input_file.empty()) { 20 1 : throw RunnerVisitorException("No input file. Aborting."); 21 : } 22 : 23 2 : if (manager.app_params.output_file.empty()) { 24 1 : throw RunnerVisitorException("No output file. Aborting."); 25 : } 26 : 27 : try { 28 1 : const auto &app_params = manager.app_params; 29 1 : const auto &network_params = manager.network_params; 30 : 31 : // Load input image parts 32 : const auto &inputImage = imageHelper_.loadImage( 33 1 : app_params.input_file, app_params.image_split, 34 1 : app_params.enable_padding, network_params.input_size_x, 35 1 : network_params.input_size_y); 36 : 37 : // Get output image parts by forward propagation 38 1 : ImageParts outputParts; 39 2 : for (const auto &inputPart : inputImage) { 40 : const auto &outputData = 41 1 : manager.network->forwardPropagation(inputPart->data); 42 : Image output{.data = outputData, 43 2 : .orig_height = inputPart->orig_height, 44 2 : .orig_width = inputPart->orig_width, 45 2 : .orig_type = inputPart->orig_type, 46 1 : .orig_channels = inputPart->orig_channels}; 47 1 : outputParts.push_back(std::make_unique<Image>(output)); 48 1 : } 49 : 50 : // Save the output image parts as a single image 51 1 : size_t outputSizeX = outputParts.front()->orig_width; 52 1 : size_t outputSizeY = outputParts.front()->orig_height; 53 1 : imageHelper_.saveImage(app_params.output_file, outputParts, 54 1 : app_params.image_split, 55 1 : (size_t)(outputSizeX * app_params.output_scale), 56 1 : (size_t)(outputSizeY * app_params.output_scale)); 57 : 58 1 : SimpleLogger::LOG_INFO("Image enhancement done. Image output saved in ", 59 1 : manager.app_params.output_file); 60 : 61 1 : } catch (std::exception &ex) { 62 0 : throw RunnerVisitorException(ex.what()); 63 0 : } 64 1 : }