LCOV - code coverage report
Current view: top level - include - NeuralNetworkBuilder.h (source / functions) Hit Total Coverage
Test: lcov.info Lines: 13 16 81.2 %
Date: 2024-12-28 17:36:05 Functions: 4 4 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file NeuralNetworkBuilder.h
       3             :  * @author Damien Balima (www.dams-labs.net)
       4             :  * @brief NeuralNetwork Builder
       5             :  * @date 2024-03-30
       6             :  *
       7             :  * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024
       8             :  *
       9             :  */
      10             : #pragma once
      11             : #include "AppParams.h"
      12             : #include "NeuralNetwork.h"
      13             : #include "NeuralNetworkImportExportFacade.h"
      14             : #include "NeuralNetworkParams.h"
      15             : #include <memory.h>
      16             : #include <memory>
      17             : 
      18             : namespace sipai {
      19             : class NeuralNetworkBuilder {
      20             : public:
      21             :   NeuralNetworkBuilder();
      22             :   NeuralNetworkBuilder(AppParams &appParams,
      23             :                        NeuralNetworkParams &networkParams);
      24             :   /**
      25             :    * @brief Using an external network to build up.
      26             :    *
      27             :    * @param network
      28             :    * @return NeuralNetworkBuilder&
      29             :    */
      30           1 :   NeuralNetworkBuilder &with(std::unique_ptr<NeuralNetwork> &network) {
      31           1 :     network_ = std::move(network);
      32           1 :     return *this;
      33             :   }
      34             : 
      35             :   /**
      36             :    * @brief Using external network parameters.
      37             :    *
      38             :    * @param network_params
      39             :    * @return NeuralNetworkBuilder&
      40             :    */
      41           8 :   NeuralNetworkBuilder &with(NeuralNetworkParams &network_params) {
      42           8 :     network_params_ = network_params;
      43           8 :     return *this;
      44             :   }
      45             : 
      46             :   /**
      47             :    * @brief Using external app parameters.
      48             :    *
      49             :    * @param app_params
      50             :    * @return NeuralNetworkBuilder&
      51             :    */
      52             :   NeuralNetworkBuilder &with(const AppParams &app_params) {
      53             :     app_params_ = app_params;
      54             :     return *this;
      55             :   }
      56             : 
      57             :   /**
      58             :    * @brief Using a progress callback
      59             :    *
      60             :    * @param progressCallback
      61             :    * @return NeuralNetworkBuilder&
      62             :    */
      63           8 :   NeuralNetworkBuilder &with(std::function<void(int)> progressCallback) {
      64           8 :     progressCallback_ = progressCallback;
      65           8 :     progressCallbackValue_ = 0;
      66           8 :     return *this;
      67             :   }
      68             : 
      69             :   /**
      70             :    * @brief Create a Or import a neural network
      71             :    *
      72             :    * @return const NeuralNetworkBuilder&
      73             :    */
      74             :   NeuralNetworkBuilder &createOrImport();
      75             : 
      76             :   /**
      77             :    * @brief Add the neurons layers of the network.
      78             :    *
      79             :    */
      80             :   NeuralNetworkBuilder &addLayers();
      81             : 
      82             :   /**
      83             :    * @brief Add the neighbors connections in a same layer, for all the layers.
      84             :    */
      85             :   NeuralNetworkBuilder &addNeighbors();
      86             : 
      87             :   /**
      88             :    * @brief Binds the layers of the network together.
      89             :    */
      90             :   NeuralNetworkBuilder &bindLayers();
      91             : 
      92             :   /**
      93             :    * @brief Initializes the weights of the neurons.
      94             :    */
      95             :   NeuralNetworkBuilder &initializeWeights();
      96             : 
      97             :   /**
      98             :    * @brief Sets the activation function for the layers.
      99             :    *
     100             :    */
     101             :   NeuralNetworkBuilder &setActivationFunction();
     102             : 
     103             :   /**
     104             :    * @brief Build the neural network following the methods chain.
     105             :    *
     106             :    * @return std::unique_ptr<NeuralNetwork>
     107             :    */
     108             :   std::unique_ptr<NeuralNetwork> build();
     109             : 
     110             : private:
     111             :   std::unique_ptr<NeuralNetwork> network_ = nullptr;
     112             :   AppParams &app_params_;
     113             :   NeuralNetworkParams &network_params_;
     114             :   bool isImported = false;
     115             :   std::function<void(int)> progressCallback_ = {};
     116             :   int progressCallbackValue_ = 0;
     117             : 
     118          55 :   void _incrementProgress(int increment) {
     119          55 :     if (progressCallback_) {
     120           0 :       progressCallbackValue_ = progressCallbackValue_ + increment > 100
     121           0 :                                    ? 100
     122             :                                    : progressCallbackValue_ + increment;
     123           0 :       progressCallback_(progressCallbackValue_);
     124             :     }
     125          55 :   }
     126             : };
     127             : } // namespace sipai

Generated by: LCOV version 1.16