LCOV - code coverage report
Current view: top level - src - RunnerTrainingVisitor.cpp (source / functions) Hit Total Coverage
Test: lcov.info Lines: 23 59 39.0 %
Date: 2024-12-28 17:36:05 Functions: 3 5 60.0 %

          Line data    Source code
       1             : #include "RunnerTrainingVisitor.h"
       2             : #include "Manager.h"
       3             : #include "SimpleLogger.h"
       4             : 
       5             : using namespace sipai;
       6             : 
       7             : volatile std::sig_atomic_t stopTraining = false;
       8             : volatile std::sig_atomic_t stopTrainingNow = false;
       9             : 
      10           0 : void signalHandler(int signal) {
      11           0 :   if (signal == SIGINT) {
      12           0 :     if (!stopTraining) {
      13           0 :       SimpleLogger::LOG_INFO(
      14             :           "Received interrupt signal (CTRL+C). Training will stop after "
      15             :           "the current epoch. Press another time on (CTRL+C) to force exit "
      16             :           "immediately without saving.");
      17           0 :       stopTraining = true;
      18             :     } else {
      19           0 :       SimpleLogger::LOG_INFO("Received another interrupt signal (CTRL+C). "
      20             :                              "Forcing quitting immedialty without saving "
      21             :                              "progress. Please wait for cleaning...");
      22           0 :       stopTrainingNow = true;
      23             :     }
      24             :   }
      25           0 : }
      26             : 
      27           9 : bool RunnerTrainingVisitor::shouldContinueTraining(
      28             :     int epoch, size_t epochsWithoutImprovement,
      29             :     const AppParams &appParams) const {
      30           9 :   bool improvementCondition =
      31           9 :       epochsWithoutImprovement < appParams.max_epochs_without_improvement;
      32           9 :   bool epochCondition =
      33           9 :       (appParams.max_epochs == NO_MAX_EPOCHS) || (epoch < (int)appParams.max_epochs);
      34             : 
      35           9 :   return improvementCondition && epochCondition;
      36             : }
      37             : 
      38           0 : void RunnerTrainingVisitor::adaptLearningRate(
      39             :     float &learningRate, const float &validationLoss,
      40             :     const float &previousValidationLoss,
      41             :     const bool &enable_adaptive_increase) const {
      42           0 :   std::scoped_lock<std::mutex> lock(threadMutex_);
      43             : 
      44           0 :   const auto &manager = Manager::getConstInstance();
      45           0 :   const auto &appParams = manager.app_params;
      46           0 :   const auto &learning_rate_min = appParams.learning_rate_min;
      47           0 :   const auto &learning_rate_max = appParams.learning_rate_max;
      48           0 :   const auto &learning_rate_adaptive_factor =
      49             :       manager.network_params.adaptive_learning_rate_factor;
      50             : 
      51           0 :   const float previous_learning_rate = learningRate;
      52           0 :   const float increase_slower_factor = 1.5f;
      53             : 
      54           0 :   if (validationLoss >= previousValidationLoss &&
      55           0 :       learningRate > learning_rate_min) {
      56             :     // this will decrease learningRate (0.001 * 0.5 = 0.0005)
      57           0 :     learningRate *= learning_rate_adaptive_factor;
      58           0 :   } else if (enable_adaptive_increase &&
      59           0 :              validationLoss < previousValidationLoss &&
      60           0 :              learningRate < learning_rate_max) {
      61             :     // this will increase learningRate but slower (0.001 / (0.5 * 1.5) = 0.0013)
      62           0 :     learningRate /= (learning_rate_adaptive_factor * increase_slower_factor);
      63             :   }
      64           0 :   learningRate = std::clamp(learningRate, learning_rate_min, learning_rate_max);
      65             : 
      66           0 :   if (appParams.verbose && learningRate != previous_learning_rate) {
      67           0 :     const auto current_precision = SimpleLogger::getInstance().getPrecision();
      68           0 :     SimpleLogger::getInstance()
      69           0 :         .setPrecision(6)
      70           0 :         .info("Learning rate ", previous_learning_rate, " adjusted to ",
      71             :               learningRate)
      72           0 :         .setPrecision(current_precision);
      73             :   }
      74           0 : }
      75             : 
      76           6 : void RunnerTrainingVisitor::logTrainingProgress(
      77             :     const int &epoch, const float &trainingLoss, const float &validationLoss,
      78             :     const float &previousTrainingLoss,
      79             :     const float &previousValidationLoss) const {
      80           6 :   std::stringstream delta;
      81           6 :   if (epoch > 0) {
      82           3 :     float dtl = trainingLoss - previousTrainingLoss;
      83           3 :     float dvl = validationLoss - previousValidationLoss;
      84           3 :     delta.precision(2);
      85           3 :     delta << std::fixed << " [" << (dtl > 0 ? "+" : "") << dtl * 100.0f << "%";
      86           3 :     delta << std::fixed << "," << (dvl > 0 ? "+" : "") << dvl * 100.0f << "%]";
      87             :   }
      88           6 :   SimpleLogger::LOG_INFO(
      89           0 :       "Epoch: ", epoch + 1, ", Train Loss: ", trainingLoss * 100.0f,
      90           6 :       "%, Validation Loss: ", validationLoss * 100.0f, "%", delta.str());
      91           6 : }
      92             : 
      93           3 : void RunnerTrainingVisitor::saveNetwork(bool &hasLastEpochBeenSaved) const {
      94           3 :   std::scoped_lock<std::mutex> lock(threadMutex_);
      95             :   try {
      96           3 :     if (!hasLastEpochBeenSaved) {
      97           3 :       Manager::getInstance().exportNetwork();
      98           3 :       hasLastEpochBeenSaved = true;
      99             :     }
     100           0 :   } catch (std::exception &ex) {
     101           0 :     SimpleLogger::LOG_INFO("Saving the neural network error: ", ex.what());
     102           0 :   }
     103           3 : }

Generated by: LCOV version 1.16