LCOV - code coverage report
Current view: top level - include - TrainingDataFactory.h (source / functions) Hit Total Coverage
Test: lcov.info Lines: 20 27 74.1 %
Date: 2024-12-28 17:36:05 Functions: 7 7 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file TrainingDataFactory.h
       3             :  * @author Damien Balima (www.dams-labs.net)
       4             :  * @brief TrainingData Factory
       5             :  * @date 2024-04-12
       6             :  *
       7             :  * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024
       8             :  *
       9             :  */
      10             : #pragma once
      11             : #include "Common.h"
      12             : #include "DataList.h"
      13             : #include "ImageHelper.h"
      14             : #include "TrainingDataReader.h"
      15             : #include "exception/TrainingDataFactoryException.h"
      16             : #include <atomic>
      17             : #include <cstddef>
      18             : #include <memory>
      19             : #include <mutex>
      20             : #include <random>
      21             : 
      22             : namespace sipai {
      23             : class TrainingDataFactory {
      24             : public:
      25          35 :   static TrainingDataFactory &getInstance() {
      26             :     static std::once_flag initInstanceFlag;
      27          35 :     std::call_once(initInstanceFlag,
      28           1 :                    [] { instance_.reset(new TrainingDataFactory); });
      29          35 :     return *instance_;
      30             :   }
      31             :   static const TrainingDataFactory &getConstInstance() {
      32             :     return const_cast<const TrainingDataFactory &>(getInstance());
      33             :   }
      34             :   TrainingDataFactory(TrainingDataFactory const &) = delete;
      35             :   void operator=(TrainingDataFactory const &) = delete;
      36           1 :   ~TrainingDataFactory() = default;
      37             : 
      38             :   enum class DataListType {
      39             :     INPUT_TARGET,
      40             :     TARGET_FOLDER,
      41             :   };
      42             : 
      43             :   /**
      44             :    * @brief Get the next input and target images for training.
      45             :    *
      46             :    * @return Pointer to the next input and target images
      47             :    * for training, or nullptr if no more images are available.
      48             :    */
      49             :   std::shared_ptr<Data> next(const TrainingPhase &phase);
      50             : 
      51             :   /**
      52             :    * @brief Get training pairs collection size
      53             :    *
      54             :    * @return size_t
      55             :    */
      56          35 :   size_t getSize(TrainingPhase phase) const {
      57          35 :     switch (phase) {
      58          23 :     case TrainingPhase::Training:
      59          23 :       return dataList_.data_training.size();
      60          12 :     case TrainingPhase::Validation:
      61          12 :       return dataList_.data_validation.size();
      62           0 :     default:
      63           0 :       throw TrainingDataFactoryException("Non-implemeted TrainingPhase");
      64             :     }
      65             :   }
      66             : 
      67             :   /**
      68             :    * @brief Load the training and validation collections paths
      69             :    *
      70             :    */
      71             :   void loadData();
      72             : 
      73             :   /**
      74             :    * @brief Reset training and validation counters
      75             :    *
      76             :    */
      77             :   void resetCounters();
      78             : 
      79             :   /**
      80             :    * @brief Indicate if the collections are loaded
      81             :    *
      82             :    * @return true
      83             :    * @return false
      84             :    */
      85           7 :   bool isLoaded() const { return isLoaded_; }
      86             : 
      87             :   /**
      88             :    * @brief Indicate if the training is using a data folder
      89             :    *
      90             :    * @return true
      91             :    * @return false
      92             :    */
      93             :   bool isDataFolder() const;
      94             : 
      95             :   /**
      96             :    * @brief Clear all data and reset counters
      97             :    *
      98             :    */
      99             :   void clear();
     100             : 
     101             :   /**
     102             :    * @brief Shuffle a vector
     103             :    *
     104             :    * @param data
     105             :    */
     106           6 :   void shuffle(TrainingPhase phase) {
     107           6 :     switch (phase) {
     108           6 :     case TrainingPhase::Training:
     109           6 :       std::shuffle(dataList_.data_training.begin(),
     110           6 :                    dataList_.data_training.end(), gen_);
     111           6 :       break;
     112           0 :     case TrainingPhase::Validation:
     113           0 :       std::shuffle(dataList_.data_validation.begin(),
     114           0 :                    dataList_.data_validation.end(), gen_);
     115           0 :     default:
     116           0 :       break;
     117             :     }
     118           6 :   }
     119             : 
     120             : private:
     121           1 :   TrainingDataFactory() : gen_(rd_()) {}
     122             :   static std::unique_ptr<TrainingDataFactory> instance_;
     123             : 
     124             :   TrainingDataReader trainingDataReader_;
     125             :   ImageHelper imageHelper_;
     126             :   std::atomic<bool> isLoaded_ = false;
     127             :   size_t currentTrainingIndex_ = 0;
     128             :   size_t currentValidationIndex_ = 0;
     129             : 
     130             :   // form random
     131             :   std::random_device rd_;
     132             :   std::mt19937 gen_;
     133             : 
     134             :   DataList dataList_;
     135             :   DataListType dataListType_;
     136             : };
     137             : } // namespace sipai

Generated by: LCOV version 1.16