Line data Source code
1 : #include "RunnerTrainingVisitor.h" 2 : #include "Manager.h" 3 : #include "SimpleLogger.h" 4 : 5 : using namespace sipai; 6 : 7 : volatile std::sig_atomic_t stopTraining = false; 8 : volatile std::sig_atomic_t stopTrainingNow = false; 9 : 10 0 : void signalHandler(int signal) { 11 0 : if (signal == SIGINT) { 12 0 : if (!stopTraining) { 13 0 : SimpleLogger::LOG_INFO( 14 : "Received interrupt signal (CTRL+C). Training will stop after " 15 : "the current epoch. Press another time on (CTRL+C) to force exit " 16 : "immediately without saving."); 17 0 : stopTraining = true; 18 : } else { 19 0 : SimpleLogger::LOG_INFO("Received another interrupt signal (CTRL+C). " 20 : "Forcing quitting immedialty without saving " 21 : "progress. Please wait for cleaning..."); 22 0 : stopTrainingNow = true; 23 : } 24 : } 25 0 : } 26 : 27 9 : bool RunnerTrainingVisitor::shouldContinueTraining( 28 : int epoch, size_t epochsWithoutImprovement, 29 : const AppParams &appParams) const { 30 9 : bool improvementCondition = 31 9 : epochsWithoutImprovement < appParams.max_epochs_without_improvement; 32 9 : bool epochCondition = 33 9 : (appParams.max_epochs == NO_MAX_EPOCHS) || (epoch < (int)appParams.max_epochs); 34 : 35 9 : return improvementCondition && epochCondition; 36 : } 37 : 38 0 : void RunnerTrainingVisitor::adaptLearningRate( 39 : float &learningRate, const float &validationLoss, 40 : const float &previousValidationLoss, 41 : const bool &enable_adaptive_increase) const { 42 0 : std::scoped_lock<std::mutex> lock(threadMutex_); 43 : 44 0 : const auto &manager = Manager::getConstInstance(); 45 0 : const auto &appParams = manager.app_params; 46 0 : const auto &learning_rate_min = appParams.learning_rate_min; 47 0 : const auto &learning_rate_max = appParams.learning_rate_max; 48 0 : const auto &learning_rate_adaptive_factor = 49 : manager.network_params.adaptive_learning_rate_factor; 50 : 51 0 : const float previous_learning_rate = learningRate; 52 0 : const float increase_slower_factor = 1.5f; 53 : 54 0 : if (validationLoss >= previousValidationLoss && 55 0 : learningRate > learning_rate_min) { 56 : // this will decrease learningRate (0.001 * 0.5 = 0.0005) 57 0 : learningRate *= learning_rate_adaptive_factor; 58 0 : } else if (enable_adaptive_increase && 59 0 : validationLoss < previousValidationLoss && 60 0 : learningRate < learning_rate_max) { 61 : // this will increase learningRate but slower (0.001 / (0.5 * 1.5) = 0.0013) 62 0 : learningRate /= (learning_rate_adaptive_factor * increase_slower_factor); 63 : } 64 0 : learningRate = std::clamp(learningRate, learning_rate_min, learning_rate_max); 65 : 66 0 : if (appParams.verbose && learningRate != previous_learning_rate) { 67 0 : const auto current_precision = SimpleLogger::getInstance().getPrecision(); 68 0 : SimpleLogger::getInstance() 69 0 : .setPrecision(6) 70 0 : .info("Learning rate ", previous_learning_rate, " adjusted to ", 71 : learningRate) 72 0 : .setPrecision(current_precision); 73 : } 74 0 : } 75 : 76 6 : void RunnerTrainingVisitor::logTrainingProgress( 77 : const int &epoch, const float &trainingLoss, const float &validationLoss, 78 : const float &previousTrainingLoss, 79 : const float &previousValidationLoss) const { 80 6 : std::stringstream delta; 81 6 : if (epoch > 0) { 82 3 : float dtl = trainingLoss - previousTrainingLoss; 83 3 : float dvl = validationLoss - previousValidationLoss; 84 3 : delta.precision(2); 85 3 : delta << std::fixed << " [" << (dtl > 0 ? "+" : "") << dtl * 100.0f << "%"; 86 3 : delta << std::fixed << "," << (dvl > 0 ? "+" : "") << dvl * 100.0f << "%]"; 87 : } 88 6 : SimpleLogger::LOG_INFO( 89 0 : "Epoch: ", epoch + 1, ", Train Loss: ", trainingLoss * 100.0f, 90 6 : "%, Validation Loss: ", validationLoss * 100.0f, "%", delta.str()); 91 6 : } 92 : 93 3 : void RunnerTrainingVisitor::saveNetwork(bool &hasLastEpochBeenSaved) const { 94 3 : std::scoped_lock<std::mutex> lock(threadMutex_); 95 : try { 96 3 : if (!hasLastEpochBeenSaved) { 97 3 : Manager::getInstance().exportNetwork(); 98 3 : hasLastEpochBeenSaved = true; 99 : } 100 0 : } catch (std::exception &ex) { 101 0 : SimpleLogger::LOG_INFO("Saving the neural network error: ", ex.what()); 102 0 : } 103 3 : }