Line data Source code
1 : /** 2 : * @file NeuralNetworkBuilder.h 3 : * @author Damien Balima (www.dams-labs.net) 4 : * @brief NeuralNetwork Builder 5 : * @date 2024-03-30 6 : * 7 : * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024 8 : * 9 : */ 10 : #pragma once 11 : #include "AppParams.h" 12 : #include "NeuralNetwork.h" 13 : #include "NeuralNetworkImportExportFacade.h" 14 : #include "NeuralNetworkParams.h" 15 : #include <memory.h> 16 : #include <memory> 17 : 18 : namespace sipai { 19 : class NeuralNetworkBuilder { 20 : public: 21 : NeuralNetworkBuilder(); 22 : NeuralNetworkBuilder(AppParams &appParams, 23 : NeuralNetworkParams &networkParams); 24 : /** 25 : * @brief Using an external network to build up. 26 : * 27 : * @param network 28 : * @return NeuralNetworkBuilder& 29 : */ 30 1 : NeuralNetworkBuilder &with(std::unique_ptr<NeuralNetwork> &network) { 31 1 : network_ = std::move(network); 32 1 : return *this; 33 : } 34 : 35 : /** 36 : * @brief Using external network parameters. 37 : * 38 : * @param network_params 39 : * @return NeuralNetworkBuilder& 40 : */ 41 8 : NeuralNetworkBuilder &with(NeuralNetworkParams &network_params) { 42 8 : network_params_ = network_params; 43 8 : return *this; 44 : } 45 : 46 : /** 47 : * @brief Using external app parameters. 48 : * 49 : * @param app_params 50 : * @return NeuralNetworkBuilder& 51 : */ 52 : NeuralNetworkBuilder &with(const AppParams &app_params) { 53 : app_params_ = app_params; 54 : return *this; 55 : } 56 : 57 : /** 58 : * @brief Using a progress callback 59 : * 60 : * @param progressCallback 61 : * @return NeuralNetworkBuilder& 62 : */ 63 8 : NeuralNetworkBuilder &with(std::function<void(int)> progressCallback) { 64 8 : progressCallback_ = progressCallback; 65 8 : progressCallbackValue_ = 0; 66 8 : return *this; 67 : } 68 : 69 : /** 70 : * @brief Create a Or import a neural network 71 : * 72 : * @return const NeuralNetworkBuilder& 73 : */ 74 : NeuralNetworkBuilder &createOrImport(); 75 : 76 : /** 77 : * @brief Add the neurons layers of the network. 78 : * 79 : */ 80 : NeuralNetworkBuilder &addLayers(); 81 : 82 : /** 83 : * @brief Add the neighbors connections in a same layer, for all the layers. 84 : */ 85 : NeuralNetworkBuilder &addNeighbors(); 86 : 87 : /** 88 : * @brief Binds the layers of the network together. 89 : */ 90 : NeuralNetworkBuilder &bindLayers(); 91 : 92 : /** 93 : * @brief Initializes the weights of the neurons. 94 : */ 95 : NeuralNetworkBuilder &initializeWeights(); 96 : 97 : /** 98 : * @brief Sets the activation function for the layers. 99 : * 100 : */ 101 : NeuralNetworkBuilder &setActivationFunction(); 102 : 103 : /** 104 : * @brief Build the neural network following the methods chain. 105 : * 106 : * @return std::unique_ptr<NeuralNetwork> 107 : */ 108 : std::unique_ptr<NeuralNetwork> build(); 109 : 110 : private: 111 : std::unique_ptr<NeuralNetwork> network_ = nullptr; 112 : AppParams &app_params_; 113 : NeuralNetworkParams &network_params_; 114 : bool isImported = false; 115 : std::function<void(int)> progressCallback_ = {}; 116 : int progressCallbackValue_ = 0; 117 : 118 55 : void _incrementProgress(int increment) { 119 55 : if (progressCallback_) { 120 0 : progressCallbackValue_ = progressCallbackValue_ + increment > 100 121 0 : ? 100 122 : : progressCallbackValue_ + increment; 123 0 : progressCallback_(progressCallbackValue_); 124 : } 125 55 : } 126 : }; 127 : } // namespace sipai