LCOV - code coverage report
Current view: top level - include - NeuralNetworkImportExportFacade.h (source / functions) Hit Total Coverage
Test: lcov.info Lines: 1 1 100.0 %
Date: 2024-12-28 17:36:05 Functions: 2 2 100.0 %

          Line data    Source code
       1             : /**
       2             :  * @file NeuralNetworkImportExportFacade.h
       3             :  * @author Damien Balima (www.dams-labs.net)
       4             :  * @brief Import/export of the neural network
       5             :  * @date 2024-02-20
       6             :  *
       7             :  * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024
       8             :  *
       9             :  */
      10             : #pragma once
      11             : 
      12             : #include "NeuralNetwork.h"
      13             : #include "NeuralNetworkImportExportCSV.h"
      14             : #include "NeuralNetworkImportExportJSON.h"
      15             : #include <memory>
      16             : 
      17             : // TODO: find a better serialization framework for big data
      18             : namespace sipai {
      19             : class NeuralNetworkImportExportFacade {
      20             : public:
      21          10 :   virtual ~NeuralNetworkImportExportFacade() = default;
      22             : 
      23             :   /**
      24             :    * @brief Import a network model from a JSON model file (without weights).
      25             :    *
      26             :    * @param appParams application parameters (read-only)
      27             :    * @param networkParams network parameters; passed by non-const reference, so presumably updated from the imported model - TODO confirm
      28             :    * @return std::unique_ptr<NeuralNetwork> owning pointer to the imported network
      29             :    */
      30             :   virtual std::unique_ptr<NeuralNetwork>
      31             :   importModel(const AppParams &appParams, NeuralNetworkParams &networkParams);
      32             : 
      33             :   /**
      34             :    * @brief Import the network weights from a CSV weights file (the network
      35             :    * model should be imported first).
      36             :    *
      37             :    * @param network network whose weights are imported (its model should already be imported)
      38             :    * @param appParams application parameters (read-only)
      39             :    * @param progressCallback optional callback taking an int (defaults to empty); presumably reports import progress - TODO confirm
      40             :    * @param progressInitialValue defaults to 0; presumably the starting value for the reported progress - TODO confirm
      41             :    */
      42             :   void importWeights(std::unique_ptr<NeuralNetwork> &network,
      43             :                      const AppParams &appParams,
      44             :                      std::function<void(int)> progressCallback = {},
      45             :                      int progressInitialValue = 0);
      46             : 
      47             :   /**
      48             :    * @brief Export the network model files (JSON metadata and CSV neuron data).
      49             :    *
      50             :    * @param network network to export (read-only)
      51             :    * @param networkParams network parameters (read-only)
      52             :    * @param appParams application parameters (read-only)
      53             :    */
      54             :   virtual void exportModel(const std::unique_ptr<NeuralNetwork> &network,
      55             :                            const NeuralNetworkParams &networkParams,
      56             :                            const AppParams &appParams) const;
      57             : 
      58             : protected:
      59             :   NeuralNetworkImportExportCSV csvIE; ///< CSV import/export helper (presumably handles the weights/neuron data files - TODO confirm)
      60             :   NeuralNetworkImportExportJSON jsonIE; ///< JSON import/export helper (presumably handles the model metadata file - TODO confirm)
      61             : };
      62             : } // namespace sipai

Generated by: LCOV version 1.16