LCOV - code coverage report
Current view: top level - src - NetworkImportExportJSON.cpp (source / functions) Hit Total Coverage
Test: lcov.info Lines: 112 128 87.5 %
Date: 2024-12-28 17:36:05 Functions: 2 2 100.0 %

          Line data    Source code
#include "AppParams.h"
#include "LayerHidden.h"
#include "LayerInput.h"
#include "LayerOutput.h"
#include "NeuralNetwork.h"
#include "NeuralNetworkImportExportJSON.h"
#include "NeuralNetworkParams.h"
#include "SimpleLogger.h"
#include "exception/ImportExportException.h"
#include "json.hpp"
#include <filesystem>
#include <fstream>
#include <memory>
#include <string>
// for nlohmann json doc, see https://github.com/nlohmann/json

using namespace sipai;
      16             : 
      17           4 : void NeuralNetworkImportExportJSON::exportModel(
      18             :     const std::unique_ptr<NeuralNetwork> &network,
      19             :     const NeuralNetworkParams &networkParams,
      20             :     const AppParams &appParams) const {
      21             :   using json = nlohmann::json;
      22           4 :   json json_network;
      23             : 
      24             :   // Serialize the version
      25           4 :   json_network["version"] = appParams.version;
      26             : 
      27             :   // Serialize the layers to JSON.
      28          16 :   for (auto layer : network->layers) {
      29          12 :     json json_layer = {{"type", layer->getLayerTypeStr()},
      30          12 :                        {"size_x", layer->size_x},
      31          12 :                        {"size_y", layer->size_y},
      32         156 :                        {"neurons", layer->neurons.size()}};
      33          12 :     json_network["layers"].push_back(json_layer);
      34          12 :   }
      35             : 
      36             :   // max weights info
      37           4 :   json_network["max_weights"] = network->max_weights;
      38             : 
      39             :   // Serialize the parameters to JSON.
      40           4 :   json_network["parameters"]["input_size_x"] = json(networkParams.input_size_x);
      41           4 :   json_network["parameters"]["input_size_y"] = json(networkParams.input_size_y);
      42           4 :   json_network["parameters"]["hidden_size_x"] =
      43           8 :       json(networkParams.hidden_size_x);
      44           4 :   json_network["parameters"]["hidden_size_y"] =
      45           8 :       json(networkParams.hidden_size_y);
      46           4 :   json_network["parameters"]["output_size_x"] =
      47           8 :       json(networkParams.output_size_x);
      48           4 :   json_network["parameters"]["output_size_y"] =
      49           8 :       json(networkParams.output_size_y);
      50           4 :   json_network["parameters"]["hiddens_count"] =
      51           8 :       json(networkParams.hiddens_count);
      52           4 :   json_network["parameters"]["learning_rate"] =
      53           8 :       json(networkParams.learning_rate);
      54           4 :   json_network["parameters"]["adaptive_learning_rate"] =
      55           8 :       json(networkParams.adaptive_learning_rate);
      56           4 :   json_network["parameters"]["adaptive_learning_rate_factor"] =
      57           8 :       json(networkParams.adaptive_learning_rate_factor);
      58           4 :   json_network["parameters"]["enable_adaptive_increase"] =
      59           8 :       json(networkParams.enable_adaptive_increase);
      60           4 :   json_network["parameters"]["error_min"] = json(networkParams.error_min);
      61           4 :   json_network["parameters"]["error_max"] = json(networkParams.error_max);
      62           4 :   json_network["parameters"]["hidden_activation_alpha"] =
      63           8 :       json(networkParams.hidden_activation_alpha);
      64           4 :   json_network["parameters"]["output_activation_alpha"] =
      65           8 :       json(networkParams.output_activation_alpha);
      66           4 :   json_network["parameters"]["hidden_activation_function"] =
      67           8 :       json(networkParams.hidden_activation_function);
      68           4 :   json_network["parameters"]["output_activation_function"] =
      69           8 :       json(networkParams.output_activation_function);
      70             : 
      71             :   // Write the JSON object to the file.
      72             :   // The 4 argument specifies the indentation level of the resulting string.
      73           4 :   std::ofstream file(appParams.network_to_export);
      74           4 :   file << json_network.dump(2);
      75           4 :   file.close();
      76         136 : }
      77             : 
      78             : std::unique_ptr<NeuralNetwork>
      79           1 : NeuralNetworkImportExportJSON::importModel(const AppParams &appParams,
      80             :                                            NeuralNetworkParams &networkParams) {
      81             :   using json = nlohmann::json;
      82           1 :   const auto &logger = SimpleLogger::getInstance();
      83             : 
      84           1 :   if (appParams.network_to_import.empty()) {
      85           0 :     throw ImportExportException("Empty parameter network_to_import");
      86             :   }
      87             : 
      88           1 :   std::string path_in_ext = appParams.network_to_import;
      89           1 :   if (std::filesystem::path p(path_in_ext); p.parent_path().empty()) {
      90           1 :     path_in_ext = "./" + path_in_ext;
      91           1 :   }
      92             : 
      93           1 :   std::ifstream file(path_in_ext);
      94           1 :   json json_model;
      95             : 
      96           1 :   if (!file.is_open()) {
      97           0 :     throw ImportExportException("Failed to open file: " + path_in_ext);
      98             :   }
      99             : 
     100           1 :   if (!json::accept(file)) {
     101           0 :     file.close();
     102           0 :     throw ImportExportException("Json parsing error: " + path_in_ext);
     103             :   }
     104           1 :   file.seekg(0, std::ifstream::beg);
     105             : 
     106             :   try {
     107           1 :     json_model = json::parse(file);
     108           1 :     auto network = std::make_unique<NeuralNetwork>();
     109             : 
     110           1 :     if (std::string jversion = json_model["version"];
     111           1 :         jversion != appParams.version) {
     112           0 :       logger.warn("The model version of the file is different from the current "
     113           0 :                   "version: " +
     114           0 :                   jversion + " vs " + appParams.version);
     115           1 :     }
     116             : 
     117             :     // Create a new Network object and deserialize the JSON data into it.
     118           1 :     networkParams.input_size_x = json_model["parameters"]["input_size_x"];
     119           1 :     networkParams.input_size_y = json_model["parameters"]["input_size_y"];
     120           1 :     networkParams.hidden_size_x = json_model["parameters"]["hidden_size_x"];
     121           1 :     networkParams.hidden_size_y = json_model["parameters"]["hidden_size_y"];
     122           1 :     networkParams.output_size_x = json_model["parameters"]["output_size_x"];
     123           1 :     networkParams.output_size_y = json_model["parameters"]["output_size_y"];
     124           1 :     networkParams.hiddens_count = json_model["parameters"]["hiddens_count"];
     125           1 :     networkParams.learning_rate = json_model["parameters"]["learning_rate"];
     126           1 :     networkParams.adaptive_learning_rate =
     127           1 :         json_model["parameters"]["adaptive_learning_rate"];
     128           1 :     networkParams.adaptive_learning_rate_factor =
     129           1 :         json_model["parameters"]["adaptive_learning_rate_factor"];
     130           1 :     networkParams.enable_adaptive_increase =
     131           1 :         json_model["parameters"]["enable_adaptive_increase"];
     132           1 :     networkParams.error_min = json_model["parameters"]["error_min"];
     133           1 :     networkParams.error_max = json_model["parameters"]["error_max"];
     134           1 :     networkParams.hidden_activation_alpha =
     135           1 :         json_model["parameters"]["hidden_activation_alpha"];
     136           1 :     networkParams.output_activation_alpha =
     137           1 :         json_model["parameters"]["output_activation_alpha"];
     138           1 :     networkParams.hidden_activation_function =
     139           1 :         json_model["parameters"]["hidden_activation_function"];
     140           1 :     networkParams.output_activation_function =
     141           1 :         json_model["parameters"]["output_activation_function"];
     142             : 
     143           1 :     network->max_weights = json_model["max_weights"];
     144             : 
     145           4 :     for (auto json_layer : json_model["layers"]) {
     146             :       // Get the type of the layer.
     147           3 :       std::string layer_type_str = json_layer["type"];
     148           3 :       LayerType layer_type = layer_map.at(layer_type_str);
     149             : 
     150             :       // // Create a new layer object of the appropriate type.
     151           3 :       Layer *layer = nullptr;
     152           3 :       switch (layer_type) {
     153           1 :       case LayerType::LayerInput:
     154           0 :         layer = new LayerInput((size_t)json_layer["size_x"],
     155           1 :                                (size_t)json_layer["size_y"]);
     156           1 :         break;
     157           1 :       case LayerType::LayerHidden:
     158           0 :         layer = new LayerHidden((size_t)json_layer["size_x"],
     159           1 :                                 (size_t)json_layer["size_y"]);
     160           1 :         layer->eactivationFunction = networkParams.hidden_activation_function;
     161           1 :         layer->activationFunctionAlpha = networkParams.hidden_activation_alpha;
     162           1 :         break;
     163           1 :       case LayerType::LayerOutput:
     164           0 :         layer = new LayerOutput((size_t)json_layer["size_x"],
     165           1 :                                 (size_t)json_layer["size_y"]);
     166           1 :         layer->eactivationFunction = networkParams.output_activation_function;
     167           1 :         layer->activationFunctionAlpha = networkParams.output_activation_alpha;
     168           1 :         break;
     169           0 :       default:
     170           0 :         throw ImportExportException("Layer type not recognized");
     171             :       }
     172             : 
     173             :       // Add the layer to the network.
     174           3 :       network->layers.push_back(layer);
     175           3 :     }
     176             : 
     177           1 :     if (network->layers.front()->layerType != LayerType::LayerInput) {
     178           0 :       throw ImportExportException("Invalid input layer");
     179             :     }
     180             : 
     181           1 :     if (network->layers.back()->layerType != LayerType::LayerOutput) {
     182           0 :       throw ImportExportException("Invalid output layer");
     183             :     }
     184           2 :     return network;
     185           1 :   } catch (const nlohmann::json::parse_error &e) {
     186           0 :     throw ImportExportException("Json parsing error: " + std::string(e.what()));
     187           0 :   }
     188           1 : }

Generated by: LCOV version 1.16