Line data Source code
1 : #include "TrainingDataFactory.h" 2 : #include "ImageHelper.h" 3 : #include "Manager.h" 4 : #include "SimpleLogger.h" 5 : #include "exception/TrainingDataFactoryException.h" 6 : #include <filesystem> 7 : #include <memory> 8 : #include <numeric> 9 : #include <optional> 10 : 11 : using namespace sipai; 12 : 13 : std::unique_ptr<TrainingDataFactory> TrainingDataFactory::instance_ = nullptr; 14 : 15 2 : bool TrainingDataFactory::isDataFolder() const { 16 2 : const auto &app_params = Manager::getConstInstance().app_params; 17 3 : return !app_params.training_data_folder.empty() && 18 3 : app_params.training_data_file.empty(); 19 : } 20 : 21 4 : void TrainingDataFactory::loadData() { 22 4 : if (isLoaded_) { 23 0 : return; 24 : } 25 : 26 4 : const auto &app_params = Manager::getConstInstance().app_params; 27 4 : if (app_params.verbose) { 28 1 : SimpleLogger::LOG_INFO("Loading images paths..."); 29 : } 30 : 31 4 : std::vector<Data> datas; 32 : // load images paths 33 4 : if (!app_params.training_data_file.empty()) { 34 1 : datas = trainingDataReader_.loadTrainingDataPaths(); 35 1 : dataListType_ = DataListType::INPUT_TARGET; 36 3 : } else if (!app_params.training_data_folder.empty()) { 37 2 : datas = trainingDataReader_.loadTrainingDataFolder(); 38 2 : dataListType_ = DataListType::TARGET_FOLDER; 39 : } else { 40 1 : throw TrainingDataFactoryException( 41 2 : "Invalid training data file or data folder"); 42 : } 43 3 : if (app_params.random_loading) { 44 3 : std::shuffle(datas.begin(), datas.end(), gen_); 45 : } 46 : // split datas 47 : size_t split_index = 48 3 : static_cast<size_t>(datas.size() * app_params.training_split_ratio); 49 33 : for (size_t i = 0; i < datas.size(); ++i) { 50 30 : if (i < split_index) { 51 21 : dataList_.data_training.push_back(datas[i]); 52 : } else { 53 9 : dataList_.data_validation.push_back(datas[i]); 54 : } 55 : } 56 : 57 3 : isLoaded_ = true; 58 3 : if (app_params.verbose) { 59 1 : SimpleLogger::LOG_INFO( 60 1 : "Images paths loaded: ", dataList_.data_training.size(), 61 2 : " images for training, ", dataList_.data_validation.size(), 62 : " images for validation."); 63 : } 64 4 : } 65 : 66 72 : std::shared_ptr<Data> TrainingDataFactory::next(const TrainingPhase &phase) { 67 72 : const auto &manager = Manager::getConstInstance(); 68 72 : const auto &app_params = manager.app_params; 69 72 : const auto &network_params = manager.network_params; 70 72 : size_t *index = nullptr; 71 72 : std::vector<Data> *datas = nullptr; 72 72 : switch (phase) { 73 48 : case TrainingPhase::Training: 74 48 : index = ¤tTrainingIndex_; 75 48 : datas = &dataList_.data_training; 76 48 : break; 77 24 : case TrainingPhase::Validation: 78 24 : index = ¤tValidationIndex_; 79 24 : datas = &dataList_.data_validation; 80 24 : break; 81 0 : default: 82 0 : throw TrainingDataFactoryException("Unimplemented TrainingPhase"); 83 : } 84 : 85 72 : if (*index >= datas->size()) { 86 : // No more training data 87 12 : return nullptr; 88 : } 89 60 : auto &data = datas->at(*index); 90 : // check if bulk_loading and already loaded 91 60 : if (app_params.bulk_loading && data.img_input.size() > 0 && 92 0 : data.img_output.size() > 0) { 93 0 : return std::make_shared<Data>(data); 94 : } 95 : 96 : // load the target image 97 : ImageParts targetImageParts = imageHelper_.loadImage( 98 60 : data.file_target, app_params.image_split, app_params.enable_padding, 99 120 : network_params.output_size_x, network_params.output_size_y); 100 : 101 : // generate or load the input image 102 60 : ImageParts inputImageParts; 103 60 : switch (dataListType_) { 104 40 : case DataListType::TARGET_FOLDER: 105 40 : inputImageParts = imageHelper_.generateInputImage( 106 40 : targetImageParts, app_params.training_reduce_factor, 107 40 : network_params.input_size_x, network_params.input_size_y); 108 40 : break; 109 20 : case DataListType::INPUT_TARGET: 110 20 : inputImageParts = imageHelper_.loadImage( 111 20 : data.file_input, app_params.image_split, app_params.enable_padding, 112 20 : network_params.input_size_x, network_params.input_size_y); 113 20 : break; 114 0 : default: 115 0 : throw TrainingDataFactoryException("Unimplemented DataListType"); 116 : } 117 : 118 60 : (*index)++; 119 : 120 60 : if (app_params.bulk_loading) { 121 0 : data.img_input = inputImageParts; 122 0 : data.img_target = targetImageParts; 123 0 : return std::make_shared<Data>(data); 124 : } else { 125 60 : return std::make_shared<Data>(Data{ 126 60 : .file_input = data.file_input, 127 60 : .file_target = data.file_target, 128 60 : .file_output = data.file_output, 129 : .img_input = inputImageParts, 130 : .img_target = targetImageParts, 131 60 : .img_output = data.img_output, 132 60 : }); 133 : } 134 60 : } 135 : 136 19 : void TrainingDataFactory::resetCounters() { 137 19 : currentTrainingIndex_ = 0; 138 19 : currentValidationIndex_ = 0; 139 19 : } 140 : 141 7 : void TrainingDataFactory::clear() { 142 7 : dataList_.data_training.clear(); 143 7 : dataList_.data_validation.clear(); 144 7 : resetCounters(); 145 7 : isLoaded_ = false; 146 7 : }