Line data Source code
1 : /** 2 : * @file NeuralNetworkImportExportFacade.h 3 : * @author Damien Balima (www.dams-labs.net) 4 : * @brief Import/export of the neural network 5 : * @date 2024-02-20 6 : * 7 : * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024 8 : * 9 : */ 10 : #pragma once 11 : 12 : #include "NeuralNetwork.h" 13 : #include "NeuralNetworkImportExportCSV.h" 14 : #include "NeuralNetworkImportExportJSON.h" 15 : #include <memory> 16 : 17 : // TODO: find a better serialization framework for big data 18 : namespace sipai { 19 : class NeuralNetworkImportExportFacade { 20 : public: 21 10 : virtual ~NeuralNetworkImportExportFacade() = default; 22 : 23 : /** 24 : * @brief Import a network model from JSON model file (without weights) 25 : * 26 : * @param appParams 27 : * @param networkParams 28 : * @return std::unique_ptr<NeuralNetwork> 29 : */ 30 : virtual std::unique_ptr<NeuralNetwork> 31 : importModel(const AppParams &appParams, NeuralNetworkParams &networkParams); 32 : 33 : /** 34 : * @brief Import the network weights from a CSV weights file (network model 35 : * should be imported first) 36 : * 37 : * @param network 38 : * @param appParams 39 : * @param progressCallback 40 : * @param progressInitialValue 41 : */ 42 : void importWeights(std::unique_ptr<NeuralNetwork> &network, 43 : const AppParams &appParams, 44 : std::function<void(int)> progressCallback = {}, 45 : int progressInitialValue = 0); 46 : 47 : /** 48 : * @brief Export a network model files (JSON meta data and CSV neurons data) 49 : * 50 : * @param network 51 : * @param networkParams 52 : * @param appParams 53 : */ 54 : virtual void exportModel(const std::unique_ptr<NeuralNetwork> &network, 55 : const NeuralNetworkParams &networkParams, 56 : const AppParams &appParams) const; 57 : 58 : protected: 59 : NeuralNetworkImportExportCSV csvIE; 60 : NeuralNetworkImportExportJSON jsonIE; 61 : }; 62 : } // namespace sipai