/**
 * @file RunnerTrainingVisitor.h
 * @author Damien Balima (www.dams-labs.net)
 * @brief RunnerTrainingVisitor
 * @date 2024-05-15
 *
 * @copyright Damien Balima (c) CC-BY-NC-SA-4.0 2024
 *
 */
#pragma once

#include "RunnerVisitor.h"
#include "SimpleLogger.h"
#include <csignal>
#include <mutex>

extern volatile std::sig_atomic_t stopTraining;
extern volatile std::sig_atomic_t stopTrainingNow;

extern void signalHandler(int signal);

namespace sipai {
class RunnerTrainingVisitor : public RunnerVisitor {
public:
  virtual ~RunnerTrainingVisitor() = default;

  /**
   * @brief Training and validation, including the loss computation over all
   * images.
   *
   * @param epoch The current epoch.
   * @param phase Indicates whether this is the training or validation phase.
   * @return float The computed loss for the phase.
   */
  virtual float training(size_t epoch, TrainingPhase phase) const = 0;

  /**
   * @brief Determines whether the training should continue based on the
   * provided conditions.
   *
   * @param epoch The current epoch number.
   * @param epochsWithoutImprovement The number of epochs without improvement
   * in validation loss.
   * @param appParams The application parameters containing the maximum number
   * of epochs and the maximum number of epochs without improvement.
   * @return True if the training should continue, false otherwise.
   */
  virtual bool shouldContinueTraining(int epoch,
                                      size_t epochsWithoutImprovement,
                                      const AppParams &appParams) const;

  /**
   * @brief Adapts the learning rate based on the validation loss trend.
   *
   * @param learningRate The learning rate, adjusted in place.
   * @param validationLoss The current validation loss.
   * @param previousValidationLoss The validation loss of the previous epoch.
   * @param enable_adaptive_increase Whether the learning rate may also be
   * increased when the validation loss improves.
   */
  virtual void adaptLearningRate(float &learningRate,
                                 const float &validationLoss,
                                 const float &previousValidationLoss,
                                 const bool &enable_adaptive_increase) const;

  /**
   * @brief Logs the training progress for the current epoch.
   *
   * @param epoch The current epoch number.
   * @param trainingLoss The average training loss for the current epoch.
   * @param validationLoss The average validation loss for the current epoch.
   * @param previousTrainingLoss The average training loss of the previous
   * epoch.
   * @param previousValidationLoss The average validation loss of the previous
   * epoch.
   */
  virtual void logTrainingProgress(const int &epoch, const float &trainingLoss,
                                   const float &validationLoss,
                                   const float &previousTrainingLoss,
                                   const float &previousValidationLoss) const;

  /**
   * @brief Saves and exports the neural network.
   *
   * @param hasLastEpochBeenSaved Set to true once the current state has been
   * saved.
   */
  virtual void saveNetwork(bool &hasLastEpochBeenSaved) const;

protected:
  mutable std::mutex threadMutex_;
};
} // namespace sipai
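/*
 * Illustrative sketch (not part of the header above, and not the project's
 * actual implementation, which lives in the corresponding .cpp files): one
 * plausible way the stop flags, the continuation test, and the adaptive
 * learning-rate hook declared here could fit together. The AppParams member
 * names (max_epochs, max_epochs_without_improvement), the signal semantics,
 * and all numeric factors below are assumptions made for this example only.
 */
#include <algorithm>
#include <csignal>
#include <cstddef>

// Standalone definitions for the sketch; in the project these are declared
// extern in the header and defined in an implementation file.
volatile std::sig_atomic_t stopTraining = 0;
volatile std::sig_atomic_t stopTrainingNow = 0;

// Assumed handler behaviour: a first signal requests a graceful stop at the
// end of the current epoch, a second signal requests an immediate stop.
void signalHandler(int /*signal*/) {
  if (!stopTraining) {
    stopTraining = 1;
  } else {
    stopTrainingNow = 1;
  }
}

// Hypothetical subset of AppParams, for the sketch only.
struct AppParamsSketch {
  std::size_t max_epochs = 100;
  std::size_t max_epochs_without_improvement = 10;
};

// Sketch of the continuation test described in the Doxygen comment: keep
// training while under the epoch budget, while validation has improved
// recently enough, and while no stop has been requested.
bool shouldContinueTrainingSketch(int epoch,
                                  std::size_t epochsWithoutImprovement,
                                  const AppParamsSketch &params) {
  return !stopTraining && !stopTrainingNow &&
         static_cast<std::size_t>(epoch) < params.max_epochs &&
         epochsWithoutImprovement < params.max_epochs_without_improvement;
}

// Sketch of an adaptive learning-rate rule matching the declaration above:
// shrink the rate when the validation loss worsens, optionally grow it
// slightly when it improves. Factors and bounds are illustrative.
void adaptLearningRateSketch(float &learningRate, float validationLoss,
                             float previousValidationLoss,
                             bool enable_adaptive_increase) {
  constexpr float kDecay = 0.5f;
  constexpr float kGrowth = 1.05f;
  constexpr float kMinRate = 1e-6f;
  if (validationLoss > previousValidationLoss) {
    learningRate = std::max(learningRate * kDecay, kMinRate);
  } else if (enable_adaptive_increase) {
    learningRate *= kGrowth;
  }
}

// Minimal usage example: a training-loop skeleton driven by the sketches.
int main() {
  std::signal(SIGINT, signalHandler); // Ctrl-C requests a graceful stop.
  AppParamsSketch params;
  float learningRate = 0.01f;
  float previousValidationLoss = 1.0f;
  std::size_t epochsWithoutImprovement = 0;
  for (int epoch = 0;
       shouldContinueTrainingSketch(epoch, epochsWithoutImprovement, params);
       ++epoch) {
    // Placeholder for a real training and validation pass.
    const float validationLoss = previousValidationLoss * 0.99f;
    adaptLearningRateSketch(learningRate, validationLoss,
                            previousValidationLoss, true);
    epochsWithoutImprovement = (validationLoss < previousValidationLoss)
                                   ? 0
                                   : epochsWithoutImprovement + 1;
    previousValidationLoss = validationLoss;
  }
  return 0;
}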