LCOV - code coverage report
Current view: top level - src - Manager.cpp (source / functions) Hit Total Coverage
Test: lcov.info Lines: 36 102 35.3 %
Date: 2024-12-28 17:36:05 Functions: 4 5 80.0 %

          Line data    Source code
       1             : #include "Manager.h"
       2             : #include "AppParams.h"
       3             : #include "Common.h"
       4             : #include "NeuralNetwork.h"
       5             : #include "SimpleLogger.h"
       6             : #include "VulkanController.h"
       7             : #include <algorithm>
       8             : #include <cstddef>
       9             : #include <cstdint>
      10             : #include <exception>
      11             : #include <memory>
      12             : #include <numeric>
      13             : 
      14             : using namespace sipai;
      15             : 
      16             : std::unique_ptr<Manager> Manager::instance_ = nullptr;
      17             : 
      18             : Manager &
      19           8 : Manager::createOrImportNetwork(std::function<void(int)> progressCallback) {
      20           8 :   if (network) {
      21           0 :     network.reset();
      22             :   }
      23           8 :   auto builder = std::make_unique<NeuralNetworkBuilder>();
      24          16 :   network = builder->with(progressCallback)
      25           8 :                 .createOrImport()
      26           8 :                 .addLayers()
      27           8 :                 .bindLayers()
      28           8 :                 .addNeighbors()
      29           8 :                 .initializeWeights()
      30           8 :                 .setActivationFunction()
      31           8 :                 .build();
      32           8 :   return *this;
      33           8 : }
      34             : 
      35           4 : void Manager::exportNetwork() {
      36           4 :   if (!app_params.network_to_export.empty()) {
      37           4 :     SimpleLogger::LOG_INFO(
      38           4 :         "Saving the neural network to ", app_params.network_to_export, " and ",
      39           8 :         Common::getFilenameCsv(app_params.network_to_export), "...");
      40           4 :     auto exportator = std::make_unique<NeuralNetworkImportExportFacade>();
      41           4 :     exportator->exportModel(network, network_params, app_params);
      42           4 :   }
      43           4 : }
      44             : 
      45           0 : Manager &Manager::showParameters() {
      46           0 :   SimpleLogger::LOG_INFO(
      47           0 :       "Parameters: ", "\nmode: ", Common::getRunModeStr(app_params.run_mode),
      48           0 :       "\nauto-save every ", app_params.epoch_autosave, " epochs",
      49           0 :       "\nauto-exit after ", app_params.max_epochs_without_improvement,
      50             :       " epochs without improvement",
      51           0 :       app_params.max_epochs == NO_MAX_EPOCHS
      52           0 :           ? "\nno maximum epochs"
      53           0 :           : "\nauto-exit after a maximum of " +
      54           0 :                 std::to_string(app_params.max_epochs) + " epochs",
      55           0 :       "\ntraining/validation ratio: ", app_params.training_split_ratio,
      56           0 :       "\nlearning rate: ", network_params.learning_rate,
      57             :       "\nadaptive learning rate: ",
      58           0 :       network_params.adaptive_learning_rate ? "true" : "false",
      59             :       "\nadaptive learning rate increase: ",
      60           0 :       network_params.enable_adaptive_increase ? "true" : "false",
      61             :       "\nadaptive learning rate factor: ",
      62           0 :       network_params.adaptive_learning_rate_factor,
      63           0 :       "\ntraining error min: ", network_params.error_min,
      64           0 :       "\ntraining error max: ", network_params.error_max,
      65           0 :       "\ninput layer size: ", network_params.input_size_x, "x",
      66           0 :       network_params.input_size_y,
      67           0 :       "\nhidden layer size: ", network_params.hidden_size_x, "x",
      68           0 :       network_params.hidden_size_y,
      69           0 :       "\noutput layer size: ", network_params.output_size_x, "x",
      70           0 :       network_params.output_size_y,
      71           0 :       "\nhidden layers: ", network_params.hiddens_count,
      72             :       "\nhidden activation function: ",
      73           0 :       getActivationStr(network_params.hidden_activation_function),
      74           0 :       "\nhidden activation alpha: ", network_params.hidden_activation_alpha,
      75             :       "\noutput activation function: ",
      76           0 :       getActivationStr(network_params.output_activation_function),
      77           0 :       "\noutput activation alpha: ", network_params.output_activation_alpha,
      78           0 :       "\ninput reduce factor: ", app_params.training_reduce_factor,
      79           0 :       "\noutput scale: ", app_params.output_scale,
      80           0 :       "\nimage split: ", app_params.image_split,
      81           0 :       "\nimages random loading: ", app_params.random_loading ? "true" : "false",
      82           0 :       "\nimages bulk loading: ", app_params.bulk_loading ? "true" : "false",
      83             :       "\nimages padding enabled: ",
      84           0 :       app_params.enable_padding ? "true" : "false",
      85             :       "\nCPU parallelism enabled: ",
      86           0 :       app_params.enable_parallel ? "true" : "false",
      87           0 :       "\nGPU Vulkan enabled: ", app_params.enable_vulkan ? "true" : "false",
      88           0 :       "\nverbose logs enabled: ", app_params.verbose ? "true" : "false",
      89           0 :       "\ndebug logs enabled: ", app_params.verbose_debug ? "true" : "false",
      90           0 :       "\ndebug vulkan enabled: ", app_params.vulkan_debug ? "true" : "false");
      91             : 
      92           0 :   return *this;
      93             : }
      94             : 
      95           2 : void Manager::run() {
      96             :   // Some checking
      97           2 :   if (app_params.image_split == NO_IMAGE_SPLIT) {
      98           1 :     app_params.image_split = 1; // same as no split
      99             :   }
     100             : 
     101             :   // Enabling GPU Vulkan
     102           2 :   bool wasParallel = app_params.enable_parallel;
     103           2 :   if (app_params.enable_vulkan) {
     104           0 :     SimpleLogger::LOG_INFO("Enabling Vulkan...");
     105           0 :     if (wasParallel) {
     106           0 :       app_params.enable_parallel = false;
     107             :     }
     108             :     try {
     109           0 :       VulkanController::getInstance().initialize();
     110           0 :     } catch (std::exception &ex) {
     111           0 :       SimpleLogger::LOG_ERROR("Enabling Vulkan error: ", ex.what());
     112           0 :       app_params.enable_vulkan = false;
     113           0 :       SimpleLogger::LOG_INFO("Vulkan GPU acceleration disabled.");
     114           0 :       if (wasParallel) {
     115           0 :         app_params.enable_parallel = true;
     116             :       }
     117           0 :     }
     118             :   }
     119             : 
     120             :   // Enabling CPU parallelism
     121           2 :   if (app_params.enable_parallel) {
     122           2 :     SimpleLogger::LOG_INFO("Enabling CPU parallelism...");
     123             :     try {
     124           2 :       cv::setNumThreads(std::thread::hardware_concurrency());
     125           0 :     } catch (std::exception &ex) {
     126           0 :       SimpleLogger::LOG_ERROR("Enabling CPU parallelism error: ", ex.what());
     127           0 :       cv::setNumThreads(0);
     128           0 :       app_params.enable_parallel = false;
     129           0 :       SimpleLogger::LOG_INFO("CPU threads parallelism disabled.");
     130           0 :     }
     131             :   } else {
     132           0 :     cv::setNumThreads(0);
     133             :   }
     134             : 
     135             :   // Run with visitor
     136             :   try {
     137           2 :     switch (app_params.run_mode) {
     138           2 :     case ERunMode::Training:
     139           2 :       runWithVisitor(runnerVisitorFactory_.getTrainingVisitor());
     140           2 :       break;
     141           0 :     case ERunMode::Enhancer:
     142           0 :       runWithVisitor(runnerVisitorFactory_.getEnhancerVisitor());
     143           0 :       break;
     144           0 :     default:
     145           0 :       break;
     146             :     }
     147           0 :   } catch (std::exception &ex) {
     148           0 :     SimpleLogger::LOG_ERROR("Error: ", ex.what());
     149           0 :   }
     150           2 : }
     151             : 
     152           3 : void Manager::runWithVisitor(const RunnerVisitor &visitor) { visitor.visit(); }

Generated by: LCOV version 1.16