LCOV - code coverage report
Current view: top level - src - RunnerTrainingOpenCVVisitor.cpp (source / functions) Hit Total Coverage
Test: lcov.info Lines: 120 139 86.3 %
Date: 2024-12-28 17:36:05 Functions: 3 3 100.0 %

          Line data    Source code
       1             : #include "RunnerTrainingOpenCVVisitor.h"
       2             : #include "AppParams.h"
       3             : #include "Common.h"
       4             : #include "ImageHelper.h"
       5             : #include "Manager.h"
       6             : #include "SimpleLogger.h"
       7             : #include "TrainingDataFactory.h"
       8             : #include "exception/RunnerVisitorException.h"
       9             : #include <cstddef>
      10             : #include <exception>
      11             : #include <memory>
      12             : #include <sstream>
      13             : #include <string>
      14             : #include <utility>
      15             : 
      16             : using namespace sipai;
      17             : 
      18           5 : void RunnerTrainingOpenCVVisitor::visit() const {
      19           5 :   SimpleLogger::LOG_INFO(
      20             :       "Starting training monitored, press (CTRL+C) to stop at anytime...");
      21             : 
      22           5 :   auto &manager = Manager::getInstance();
      23           5 :   if (!manager.network) {
      24           1 :     throw RunnerVisitorException("No neural network. Aborting.");
      25             :   }
      26             : 
      27           4 :   const auto &appParams = manager.app_params;
      28           4 :   auto &learning_rate = manager.network_params.learning_rate;
      29           4 :   const auto &adaptive_learning_rate =
      30             :       manager.network_params.adaptive_learning_rate;
      31           4 :   const auto &enable_adaptive_increase =
      32             :       manager.network_params.enable_adaptive_increase;
      33           4 :   auto &trainingDataFactory = TrainingDataFactory::getInstance();
      34             : 
      35           4 :   const auto start{std::chrono::steady_clock::now()}; // starting timer
      36           4 :   SimpleLogger::getInstance().setPrecision(2);
      37             : 
      38             :   try {
      39             :     // Load training data
      40           4 :     if (appParams.verbose_debug) {
      41           1 :       SimpleLogger::LOG_DEBUG("Loading images data...");
      42             :     }
      43           4 :     trainingDataFactory.loadData();
      44           6 :     if (!trainingDataFactory.isLoaded() ||
      45           3 :         trainingDataFactory.getSize(TrainingPhase::Training) == 0) {
      46           0 :       throw RunnerVisitorException("No training data found. Aborting.");
      47             :     }
      48             : 
      49             :     // Reset the stopTraining flag
      50           3 :     stopTraining = false;
      51           3 :     stopTrainingNow = false;
      52             : 
      53             :     // Set up signal handler
      54           3 :     std::signal(SIGINT, signalHandler);
      55             : 
      56           3 :     float trainingLoss = 0.0f;
      57           3 :     float validationLoss = 0.0f;
      58           3 :     float previousTrainingLoss = 0.0f;
      59           3 :     float previousValidationLoss = 0.0f;
      60           3 :     int epoch = 0;
      61           3 :     int epochsWithoutImprovement = 0;
      62           3 :     bool hasLastEpochBeenSaved = false;
      63             : 
      64          18 :     while (!stopTraining && !stopTrainingNow &&
      65           9 :            shouldContinueTraining(epoch, epochsWithoutImprovement, appParams)) {
      66             : 
      67             :       // if Adaptive Learning Rate enabled, adapt the learning rate.
      68           6 :       if (adaptive_learning_rate && epoch > 1) {
      69           0 :         adaptLearningRate(learning_rate, validationLoss, previousValidationLoss,
      70             :                           enable_adaptive_increase);
      71             :       }
      72             : 
      73           6 :       TrainingDataFactory::getInstance().shuffle(TrainingPhase::Training);
      74             : 
      75           6 :       previousTrainingLoss = trainingLoss;
      76           6 :       previousValidationLoss = validationLoss;
      77             : 
      78           6 :       trainingLoss = training(epoch, TrainingPhase::Training);
      79           6 :       if (stopTrainingNow) {
      80           0 :         break;
      81             :       }
      82             : 
      83           6 :       validationLoss = training(epoch, TrainingPhase::Validation);
      84           6 :       if (stopTrainingNow) {
      85           0 :         break;
      86             :       }
      87             : 
      88           6 :       logTrainingProgress(epoch, trainingLoss, validationLoss,
      89             :                           previousTrainingLoss, previousValidationLoss);
      90             : 
      91             :       // check the epochs without improvement counter
      92           6 :       if (epoch > 0) {
      93           3 :         if (validationLoss < previousValidationLoss ||
      94           1 :             trainingLoss < previousTrainingLoss) {
      95           3 :           epochsWithoutImprovement = 0;
      96             :         } else {
      97           0 :           epochsWithoutImprovement++;
      98             :         }
      99             :       }
     100             : 
     101           6 :       hasLastEpochBeenSaved = false;
     102           6 :       epoch++;
     103             : 
     104           6 :       if (!stopTrainingNow && (epoch % appParams.epoch_autosave == 0)) {
     105             :         // TODO: an option to save the best validation rate network (if not
     106             :         // saved)
     107           0 :         saveNetwork(hasLastEpochBeenSaved);
     108             :       }
     109             :     }
     110             : 
     111           3 :     SimpleLogger::LOG_INFO("Exiting training...");
     112           3 :     if (!stopTrainingNow) {
     113           3 :       saveNetwork(hasLastEpochBeenSaved);
     114             :     }
     115             :     // Show elapsed time
     116           3 :     const auto end{std::chrono::steady_clock::now()};
     117             :     const std::chrono::duration elapsed_seconds =
     118           3 :         std::chrono::duration_cast<std::chrono::seconds>(end - start);
     119           3 :     const auto &hms = Common::getHMSfromS(elapsed_seconds.count());
     120           3 :     SimpleLogger::LOG_INFO("Elapsed time: ", hms[0], "h ", hms[1], "m ", hms[2],
     121             :                            "s");
     122             : 
     123           1 :   } catch (std::exception &ex) {
     124           1 :     throw RunnerVisitorException(ex.what());
     125           1 :   }
     126           3 : }
     127             : 
     128          12 : float RunnerTrainingOpenCVVisitor::training(size_t epoch,
     129             :                                             TrainingPhase phase) const {
     130             : 
     131             :   // Initialize the total loss to 0
     132          12 :   float loss = 0.0f;
     133          12 :   size_t lossComputed = 0;
     134          12 :   size_t counter = 0;
     135          12 :   bool isLossFrequency = false;
     136          12 :   auto &trainingDataFactory = TrainingDataFactory::getInstance();
     137          12 :   trainingDataFactory.resetCounters();
     138          12 :   const auto &app_params = Manager::getConstInstance().app_params;
     139             : 
     140             :   // Compute the frequency at which the loss should be computed
     141          12 :   size_t lossFrequency = std::max(
     142          12 :       static_cast<size_t>(std::sqrt(trainingDataFactory.getSize(phase))),
     143          24 :       (size_t)1);
     144             : 
     145             :   // Loop over all images
     146          72 :   while (auto data = trainingDataFactory.next(phase)) {
     147          60 :     if (stopTrainingNow) {
     148           0 :       break;
     149             :     }
     150          60 :     counter++;
     151          60 :     if (app_params.verbose) {
     152          20 :       SimpleLogger::LOG_INFO(
     153          20 :           "Epoch: ", epoch + 1, ", ", Common::getTrainingPhaseStr(phase), ": ",
     154          40 :           "image ", counter, "/", trainingDataFactory.getSize(phase), "...");
     155             :     }
     156             : 
     157             :     // Check if the loss should be computed for the current image
     158          60 :     isLossFrequency = counter % lossFrequency == 0 ? true : false;
     159             : 
     160             :     // Compute the image parts loss
     161          60 :     float imageLoss = _training(epoch, data, phase, isLossFrequency);
     162          60 :     if (stopTrainingNow) {
     163           0 :       break;
     164             :     }
     165             : 
     166             :     // If the loss was computed for the current image, add the average loss for
     167             :     // the current image to the total loss
     168          60 :     if (isLossFrequency) {
     169          36 :       loss += imageLoss;
     170          36 :       lossComputed++;
     171             :     }
     172         132 :   }
     173             : 
     174             :   // Return the average loss over all images for which the loss was computed
     175          12 :   if (lossComputed == 0) {
     176           0 :     return 0;
     177             :   }
     178          12 :   return (loss / static_cast<float>(lossComputed));
     179             : }
     180             : 
     181          60 : float RunnerTrainingOpenCVVisitor::_training(size_t epoch,
     182             :                                              std::shared_ptr<Data> data,
     183             :                                              TrainingPhase phase,
     184             :                                              bool isLossFrequency) const {
     185          60 :   if (data->img_input.size() != data->img_target.size()) {
     186           0 :     throw ImageHelperException(
     187           0 :         "internal exception: input and target parts have different sizes.");
     188             :   }
     189             : 
     190          60 :   auto &manager = Manager::getInstance();
     191          60 :   const auto &error_min = manager.network_params.error_min;
     192          60 :   const auto &error_max = manager.network_params.error_max;
     193             : 
     194             :   // Initialize the loss for the current image to 0
     195          60 :   float partsLoss = 0.0f;
     196             : 
     197             :   // Initialize a counter to keep track of the number of parts for which the
     198             :   // loss is computed
     199          60 :   size_t partsLossComputed = 0;
     200             : 
     201             :   // Loop over all parts of the current image
     202         120 :   for (size_t i = 0; i < data->img_input.size(); i++) {
     203          60 :     if (stopTrainingNow) {
     204           0 :       break;
     205             :     }
     206             : 
     207             :     // Get the input and target parts
     208          60 :     const auto &inputPart = data->img_input.at(i);
     209          60 :     const auto &targetPart = data->img_target.at(i);
     210             : 
     211             :     // Perform forward propagation
     212          60 :     if (manager.app_params.verbose_debug) {
     213           0 :       SimpleLogger::LOG_DEBUG("forward propagation part ", i + 1, "/",
     214          20 :                               data->img_input.size(), "...");
     215             :     }
     216             :     const auto &outputData =
     217          60 :         manager.network->forwardPropagation(inputPart->data);
     218             : 
     219          60 :     if (stopTrainingNow) {
     220           0 :       break;
     221             :     }
     222             : 
     223             :     // If the loss should be computed for the current image, compute the loss
     224             :     // for the current part
     225          60 :     if (isLossFrequency) {
     226          36 :       if (manager.app_params.verbose_debug) {
     227          12 :         SimpleLogger::LOG_DEBUG("loss computation...");
     228             :       }
     229          36 :       float partLoss = imageHelper_.computeLoss(outputData, targetPart->data);
     230          36 :       if (manager.app_params.verbose_debug) {
     231          12 :         SimpleLogger::LOG_DEBUG("part loss: ", partLoss * 100.0f, "%");
     232             :       }
     233          36 :       partsLoss += partLoss;
     234          36 :       partsLossComputed++;
     235             :     }
     236          60 :     if (stopTrainingNow) {
     237           0 :       break;
     238             :     }
     239             : 
     240             :     // If backward propagation and weight update should be performed, perform
     241             :     // them
     242          60 :     if (phase == TrainingPhase::Training) {
     243          42 :       if (manager.app_params.verbose_debug) {
     244           0 :         SimpleLogger::LOG_DEBUG("backward propagation part ", i + 1, "/",
     245          14 :                                 data->img_input.size(), "...");
     246             :       }
     247          42 :       manager.network->backwardPropagation(targetPart->data, error_min,
     248             :                                            error_max);
     249          42 :       if (stopTrainingNow) {
     250           0 :         break;
     251             :       }
     252             : 
     253          42 :       if (manager.app_params.verbose_debug) {
     254           0 :         SimpleLogger::LOG_DEBUG("weights update part ", i + 1, "/",
     255          14 :                                 data->img_input.size(), "...");
     256             :       }
     257          42 :       manager.network->updateWeights(manager.network_params.learning_rate);
     258          42 :       if (stopTrainingNow) {
     259           0 :         break;
     260             :       }
     261             :     }
     262          60 :   }
     263             : 
     264          60 :   if (partsLossComputed == 0) {
     265          24 :     return 0;
     266             :   }
     267          36 :   return (partsLoss / static_cast<float>(partsLossComputed));
     268             : }

Generated by: LCOV version 1.16