LCOV - code coverage report
Current view: top level - include - RunnerTrainingVisitor.h (source / functions) Hit Total Coverage
Test: lcov.info Lines: 1 1 100.0 %
Date: 2024-12-28 17:36:05 Functions: 1 2 50.0 %

          Line data    Source code
       1             : /**
       2             :  * @file RunnerTrainingVisitor.h
       3             :  * @author Damien Balima (www.dams-labs.net)
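       4             :  * @brief Abstract training visitor: the training/validation pass plus epoch-loop helpers (stop check, learning-rate adaptation, logging, saving)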
       5             :  * @date 2024-05-15
       6             :  *
       7             :  * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024
       8             :  *
       9             :  */
      10             : #pragma once
      11             : 
      12             : #include "RunnerVisitor.h"
      13             : #include "SimpleLogger.h"
      14             : #include <csignal>
      15             : #include <mutex>
      16             : 
      17             : extern volatile std::sig_atomic_t stopTraining; // defined in another TU; volatile sig_atomic_t is the type a signal handler may safely write
      18             : extern volatile std::sig_atomic_t stopTrainingNow; // defined in another TU; same signal-safe type as stopTraining
      19             : 
      20             : extern void signalHandler(int signal); // defined in another TU; presumably sets the two flags declared just above — confirm
      21             : 
      22             : namespace sipai {
      23             : class RunnerTrainingVisitor : public RunnerVisitor {
      24             : public:
      25           3 :   virtual ~RunnerTrainingVisitor() = default; // virtual: deleting a derived visitor through a base pointer is well-defined
      26             : 
      27             :   /**
      28             :    * @brief Runs the training or validation phase for one epoch, computing the loss over all images.
      29             :    *
      30             :    * @param epoch the current epoch
      31             :    * @param phase whether this is the training or the validation phase
      32             :    * @return float the loss computed for this phase (presumably averaged over all images — confirm in implementations)
      33             :    */
      34             :   virtual float training(size_t epoch, TrainingPhase phase) const = 0;
      35             : 
      36             :   /**
      37             :    * @brief Determines whether the training should continue based on the
      38             :    * provided conditions.
      39             :    *
      40             :    * @param epoch The current epoch number (NOTE(review): int here, but size_t in training()).
      41             :    * @param epochsWithoutImprovement The number of epochs without improvement in
      42             :    * validation loss.
      43             :    * @param appParams The application parameters containing the maximum number
      44             :    * of epochs and maximum epochs without improvement.
      45             :    * @return True if the training should continue, false otherwise.
      46             :    */
      47             :   virtual bool shouldContinueTraining(int epoch,
      48             :                                       size_t epochsWithoutImprovement,
      49             :                                       const AppParams &appParams) const;
      50             : 
      51             :   /**
      52             :    * @brief Adaptive learning rate: updates learningRate in place (presumably from the change between the two validation losses).
      53             :    *
      54             :    * @param[in,out] learningRate current learning rate; written back through the reference
      55             :    * @param validationLoss latest validation loss (presumably the current epoch's — confirm)
      56             :    * @param previousValidationLoss validation loss it is compared against (presumably the previous epoch's)
      57             :    * @param enable_adaptive_increase presumably allows the rate to be increased as well as decreased — confirm
      58             :    */
      59             :   virtual void adaptLearningRate(float &learningRate,
      60             :                                  const float &validationLoss,
      61             :                                  const float &previousValidationLoss,
      62             :                                  const bool &enable_adaptive_increase) const;
      63             : 
      64             :   /**
      65             :    * @brief Logs the training progress for the current epoch.
      66             :    *
      67             :    * @param epoch The current epoch number.
      68             :    * @param trainingLoss The average training loss for the current epoch.
      69             :    * @param validationLoss The average validation loss for the current epoch.
      70             :    * @param previousTrainingLoss The training loss to compare against (presumably the previous epoch's).
      71             :    * @param previousValidationLoss The validation loss to compare against (presumably the previous epoch's).
      72             :    */
      73             :   virtual void logTrainingProgress(const int &epoch, const float &trainingLoss,
      74             :                                    const float &validationLoss,
      75             :                                    const float &previousTrainingLoss,
      76             :                                    const float &previousValidationLoss) const;
      77             : 
      78             :   /**
      79             :    * @brief Saves and exports the neural network.
      80             :    *
      81             :    * @param[in,out] hasLastEpochBeenSaved flag taken by non-const reference; presumably set to record that the last epoch was saved — confirm
      82             :    */
      83             :   virtual void saveNetwork(bool &hasLastEpochBeenSaved) const;
      84             : 
      85             : protected:
      86             :   mutable std::mutex threadMutex_; // mutable so it can be locked from the const member functions
      87             : };
      88             : } // namespace sipai

Generated by: LCOV version 1.16