Line data Source code
1 : /** 2 : * @file TrainingDataFactory.h 3 : * @author Damien Balima (www.dams-labs.net) 4 : * @brief TrainingData Factory 5 : * @date 2024-04-12 6 : * 7 : * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024 8 : * 9 : */ 10 : #pragma once 11 : #include "Common.h" 12 : #include "DataList.h" 13 : #include "ImageHelper.h" 14 : #include "TrainingDataReader.h" 15 : #include "exception/TrainingDataFactoryException.h" 16 : #include <atomic> 17 : #include <cstddef> 18 : #include <memory> 19 : #include <mutex> 20 : #include <random> 21 : 22 : namespace sipai { 23 : class TrainingDataFactory { 24 : public: 25 35 : static TrainingDataFactory &getInstance() { 26 : static std::once_flag initInstanceFlag; 27 35 : std::call_once(initInstanceFlag, 28 1 : [] { instance_.reset(new TrainingDataFactory); }); 29 35 : return *instance_; 30 : } 31 : static const TrainingDataFactory &getConstInstance() { 32 : return const_cast<const TrainingDataFactory &>(getInstance()); 33 : } 34 : TrainingDataFactory(TrainingDataFactory const &) = delete; 35 : void operator=(TrainingDataFactory const &) = delete; 36 1 : ~TrainingDataFactory() = default; 37 : 38 : enum class DataListType { 39 : INPUT_TARGET, 40 : TARGET_FOLDER, 41 : }; 42 : 43 : /** 44 : * @brief Get the next input and target images for training. 45 : * 46 : * @return Pointer to the next input and target images 47 : * for training, or nullptr if no more images are available. 48 : */ 49 : std::shared_ptr<Data> next(const TrainingPhase &phase); 50 : 51 : /** 52 : * @brief Get training pairs collection size 53 : * 54 : * @return size_t 55 : */ 56 35 : size_t getSize(TrainingPhase phase) const { 57 35 : switch (phase) { 58 23 : case TrainingPhase::Training: 59 23 : return dataList_.data_training.size(); 60 12 : case TrainingPhase::Validation: 61 12 : return dataList_.data_validation.size(); 62 0 : default: 63 0 : throw TrainingDataFactoryException("Non-implemeted TrainingPhase"); 64 : } 65 : } 66 : 67 : /** 68 : * @brief Load the training and validation collections paths 69 : * 70 : */ 71 : void loadData(); 72 : 73 : /** 74 : * @brief Reset training and validation counters 75 : * 76 : */ 77 : void resetCounters(); 78 : 79 : /** 80 : * @brief Indicate if the collections are loaded 81 : * 82 : * @return true 83 : * @return false 84 : */ 85 7 : bool isLoaded() const { return isLoaded_; } 86 : 87 : /** 88 : * @brief Indicate if the training is using a data folder 89 : * 90 : * @return true 91 : * @return false 92 : */ 93 : bool isDataFolder() const; 94 : 95 : /** 96 : * @brief Clear all data and reset counters 97 : * 98 : */ 99 : void clear(); 100 : 101 : /** 102 : * @brief Shuffle a vector 103 : * 104 : * @param data 105 : */ 106 6 : void shuffle(TrainingPhase phase) { 107 6 : switch (phase) { 108 6 : case TrainingPhase::Training: 109 6 : std::shuffle(dataList_.data_training.begin(), 110 6 : dataList_.data_training.end(), gen_); 111 6 : break; 112 0 : case TrainingPhase::Validation: 113 0 : std::shuffle(dataList_.data_validation.begin(), 114 0 : dataList_.data_validation.end(), gen_); 115 0 : default: 116 0 : break; 117 : } 118 6 : } 119 : 120 : private: 121 1 : TrainingDataFactory() : gen_(rd_()) {} 122 : static std::unique_ptr<TrainingDataFactory> instance_; 123 : 124 : TrainingDataReader trainingDataReader_; 125 : ImageHelper imageHelper_; 126 : std::atomic<bool> isLoaded_ = false; 127 : size_t currentTrainingIndex_ = 0; 128 : size_t currentValidationIndex_ = 0; 129 : 130 : // form random 131 : std::random_device rd_; 132 : std::mt19937 gen_; 133 : 134 : DataList dataList_; 135 : DataListType dataListType_; 136 : }; 137 : } // namespace sipai