LCOV - code coverage report
Current view: top level - src - TrainingDataFactory.cpp (source / functions) Hit Total Coverage
Test: lcov.info Lines: 83 93 89.2 %
Date: 2024-12-28 17:36:05 Functions: 5 5 100.0 %

          Line data    Source code
       1             : #include "TrainingDataFactory.h"
       2             : #include "ImageHelper.h"
       3             : #include "Manager.h"
       4             : #include "SimpleLogger.h"
       5             : #include "exception/TrainingDataFactoryException.h"
       6             : #include <filesystem>
       7             : #include <memory>
       8             : #include <numeric>
       9             : #include <optional>
      10             : 
      11             : using namespace sipai;
      12             : 
      13             : std::unique_ptr<TrainingDataFactory> TrainingDataFactory::instance_ = nullptr;
      14             : 
      15           2 : bool TrainingDataFactory::isDataFolder() const {
      16           2 :   const auto &app_params = Manager::getConstInstance().app_params;
      17           3 :   return !app_params.training_data_folder.empty() &&
      18           3 :          app_params.training_data_file.empty();
      19             : }
      20             : 
      21           4 : void TrainingDataFactory::loadData() {
      22           4 :   if (isLoaded_) {
      23           0 :     return;
      24             :   }
      25             : 
      26           4 :   const auto &app_params = Manager::getConstInstance().app_params;
      27           4 :   if (app_params.verbose) {
      28           1 :     SimpleLogger::LOG_INFO("Loading images paths...");
      29             :   }
      30             : 
      31           4 :   std::vector<Data> datas;
      32             :   // load images paths
      33           4 :   if (!app_params.training_data_file.empty()) {
      34           1 :     datas = trainingDataReader_.loadTrainingDataPaths();
      35           1 :     dataListType_ = DataListType::INPUT_TARGET;
      36           3 :   } else if (!app_params.training_data_folder.empty()) {
      37           2 :     datas = trainingDataReader_.loadTrainingDataFolder();
      38           2 :     dataListType_ = DataListType::TARGET_FOLDER;
      39             :   } else {
      40           1 :     throw TrainingDataFactoryException(
      41           2 :         "Invalid training data file or data folder");
      42             :   }
      43           3 :   if (app_params.random_loading) {
      44           3 :     std::shuffle(datas.begin(), datas.end(), gen_);
      45             :   }
      46             :   // split datas
      47             :   size_t split_index =
      48           3 :       static_cast<size_t>(datas.size() * app_params.training_split_ratio);
      49          33 :   for (size_t i = 0; i < datas.size(); ++i) {
      50          30 :     if (i < split_index) {
      51          21 :       dataList_.data_training.push_back(datas[i]);
      52             :     } else {
      53           9 :       dataList_.data_validation.push_back(datas[i]);
      54             :     }
      55             :   }
      56             : 
      57           3 :   isLoaded_ = true;
      58           3 :   if (app_params.verbose) {
      59           1 :     SimpleLogger::LOG_INFO(
      60           1 :         "Images paths loaded: ", dataList_.data_training.size(),
      61           2 :         " images for training, ", dataList_.data_validation.size(),
      62             :         " images for validation.");
      63             :   }
      64           4 : }
      65             : 
      66          72 : std::shared_ptr<Data> TrainingDataFactory::next(const TrainingPhase &phase) {
      67          72 :   const auto &manager = Manager::getConstInstance();
      68          72 :   const auto &app_params = manager.app_params;
      69          72 :   const auto &network_params = manager.network_params;
      70          72 :   size_t *index = nullptr;
      71          72 :   std::vector<Data> *datas = nullptr;
      72          72 :   switch (phase) {
      73          48 :   case TrainingPhase::Training:
      74          48 :     index = &currentTrainingIndex_;
      75          48 :     datas = &dataList_.data_training;
      76          48 :     break;
      77          24 :   case TrainingPhase::Validation:
      78          24 :     index = &currentValidationIndex_;
      79          24 :     datas = &dataList_.data_validation;
      80          24 :     break;
      81           0 :   default:
      82           0 :     throw TrainingDataFactoryException("Unimplemented TrainingPhase");
      83             :   }
      84             : 
      85          72 :   if (*index >= datas->size()) {
      86             :     // No more training data
      87          12 :     return nullptr;
      88             :   }
      89          60 :   auto &data = datas->at(*index);
      90             :   // check if bulk_loading and already loaded
      91          60 :   if (app_params.bulk_loading && data.img_input.size() > 0 &&
      92           0 :       data.img_output.size() > 0) {
      93           0 :     return std::make_shared<Data>(data);
      94             :   }
      95             : 
      96             :   // load the target image
      97             :   ImageParts targetImageParts = imageHelper_.loadImage(
      98          60 :       data.file_target, app_params.image_split, app_params.enable_padding,
      99         120 :       network_params.output_size_x, network_params.output_size_y);
     100             : 
     101             :   // generate or load the input image
     102          60 :   ImageParts inputImageParts;
     103          60 :   switch (dataListType_) {
     104          40 :   case DataListType::TARGET_FOLDER:
     105          40 :     inputImageParts = imageHelper_.generateInputImage(
     106          40 :         targetImageParts, app_params.training_reduce_factor,
     107          40 :         network_params.input_size_x, network_params.input_size_y);
     108          40 :     break;
     109          20 :   case DataListType::INPUT_TARGET:
     110          20 :     inputImageParts = imageHelper_.loadImage(
     111          20 :         data.file_input, app_params.image_split, app_params.enable_padding,
     112          20 :         network_params.input_size_x, network_params.input_size_y);
     113          20 :     break;
     114           0 :   default:
     115           0 :     throw TrainingDataFactoryException("Unimplemented DataListType");
     116             :   }
     117             : 
     118          60 :   (*index)++;
     119             : 
     120          60 :   if (app_params.bulk_loading) {
     121           0 :     data.img_input = inputImageParts;
     122           0 :     data.img_target = targetImageParts;
     123           0 :     return std::make_shared<Data>(data);
     124             :   } else {
     125          60 :     return std::make_shared<Data>(Data{
     126          60 :         .file_input = data.file_input,
     127          60 :         .file_target = data.file_target,
     128          60 :         .file_output = data.file_output,
     129             :         .img_input = inputImageParts,
     130             :         .img_target = targetImageParts,
     131          60 :         .img_output = data.img_output,
     132          60 :     });
     133             :   }
     134          60 : }
     135             : 
     136          19 : void TrainingDataFactory::resetCounters() {
     137          19 :   currentTrainingIndex_ = 0;
     138          19 :   currentValidationIndex_ = 0;
     139          19 : }
     140             : 
     141           7 : void TrainingDataFactory::clear() {
     142           7 :   dataList_.data_training.clear();
     143           7 :   dataList_.data_validation.clear();
     144           7 :   resetCounters();
     145           7 :   isLoaded_ = false;
     146           7 : }

Generated by: LCOV version 1.16