Line data Source code
#include "RunnerTrainingOpenCVVisitor.h"
#include "AppParams.h"
#include "Common.h"
#include "ImageHelper.h"
#include "Manager.h"
#include "SimpleLogger.h"
#include "TrainingDataFactory.h"
#include "exception/RunnerVisitorException.h"
#include <algorithm>
#include <chrono>
#include <cmath>
#include <csignal>
#include <cstddef>
#include <exception>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
15 :
16 : using namespace sipai;
17 :
18 5 : void RunnerTrainingOpenCVVisitor::visit() const {
19 5 : SimpleLogger::LOG_INFO(
20 : "Starting training monitored, press (CTRL+C) to stop at anytime...");
21 :
22 5 : auto &manager = Manager::getInstance();
23 5 : if (!manager.network) {
24 1 : throw RunnerVisitorException("No neural network. Aborting.");
25 : }
26 :
27 4 : const auto &appParams = manager.app_params;
28 4 : auto &learning_rate = manager.network_params.learning_rate;
29 4 : const auto &adaptive_learning_rate =
30 : manager.network_params.adaptive_learning_rate;
31 4 : const auto &enable_adaptive_increase =
32 : manager.network_params.enable_adaptive_increase;
33 4 : auto &trainingDataFactory = TrainingDataFactory::getInstance();
34 :
35 4 : const auto start{std::chrono::steady_clock::now()}; // starting timer
36 4 : SimpleLogger::getInstance().setPrecision(2);
37 :
38 : try {
39 : // Load training data
40 4 : if (appParams.verbose_debug) {
41 1 : SimpleLogger::LOG_DEBUG("Loading images data...");
42 : }
43 4 : trainingDataFactory.loadData();
44 6 : if (!trainingDataFactory.isLoaded() ||
45 3 : trainingDataFactory.getSize(TrainingPhase::Training) == 0) {
46 0 : throw RunnerVisitorException("No training data found. Aborting.");
47 : }
48 :
49 : // Reset the stopTraining flag
50 3 : stopTraining = false;
51 3 : stopTrainingNow = false;
52 :
53 : // Set up signal handler
54 3 : std::signal(SIGINT, signalHandler);
55 :
56 3 : float trainingLoss = 0.0f;
57 3 : float validationLoss = 0.0f;
58 3 : float previousTrainingLoss = 0.0f;
59 3 : float previousValidationLoss = 0.0f;
60 3 : int epoch = 0;
61 3 : int epochsWithoutImprovement = 0;
62 3 : bool hasLastEpochBeenSaved = false;
63 :
64 18 : while (!stopTraining && !stopTrainingNow &&
65 9 : shouldContinueTraining(epoch, epochsWithoutImprovement, appParams)) {
66 :
67 : // if Adaptive Learning Rate enabled, adapt the learning rate.
68 6 : if (adaptive_learning_rate && epoch > 1) {
69 0 : adaptLearningRate(learning_rate, validationLoss, previousValidationLoss,
70 : enable_adaptive_increase);
71 : }
72 :
73 6 : TrainingDataFactory::getInstance().shuffle(TrainingPhase::Training);
74 :
75 6 : previousTrainingLoss = trainingLoss;
76 6 : previousValidationLoss = validationLoss;
77 :
78 6 : trainingLoss = training(epoch, TrainingPhase::Training);
79 6 : if (stopTrainingNow) {
80 0 : break;
81 : }
82 :
83 6 : validationLoss = training(epoch, TrainingPhase::Validation);
84 6 : if (stopTrainingNow) {
85 0 : break;
86 : }
87 :
88 6 : logTrainingProgress(epoch, trainingLoss, validationLoss,
89 : previousTrainingLoss, previousValidationLoss);
90 :
91 : // check the epochs without improvement counter
92 6 : if (epoch > 0) {
93 3 : if (validationLoss < previousValidationLoss ||
94 1 : trainingLoss < previousTrainingLoss) {
95 3 : epochsWithoutImprovement = 0;
96 : } else {
97 0 : epochsWithoutImprovement++;
98 : }
99 : }
100 :
101 6 : hasLastEpochBeenSaved = false;
102 6 : epoch++;
103 :
104 6 : if (!stopTrainingNow && (epoch % appParams.epoch_autosave == 0)) {
105 : // TODO: an option to save the best validation rate network (if not
106 : // saved)
107 0 : saveNetwork(hasLastEpochBeenSaved);
108 : }
109 : }
110 :
111 3 : SimpleLogger::LOG_INFO("Exiting training...");
112 3 : if (!stopTrainingNow) {
113 3 : saveNetwork(hasLastEpochBeenSaved);
114 : }
115 : // Show elapsed time
116 3 : const auto end{std::chrono::steady_clock::now()};
117 : const std::chrono::duration elapsed_seconds =
118 3 : std::chrono::duration_cast<std::chrono::seconds>(end - start);
119 3 : const auto &hms = Common::getHMSfromS(elapsed_seconds.count());
120 3 : SimpleLogger::LOG_INFO("Elapsed time: ", hms[0], "h ", hms[1], "m ", hms[2],
121 : "s");
122 :
123 1 : } catch (std::exception &ex) {
124 1 : throw RunnerVisitorException(ex.what());
125 1 : }
126 3 : }
127 :
128 12 : float RunnerTrainingOpenCVVisitor::training(size_t epoch,
129 : TrainingPhase phase) const {
130 :
131 : // Initialize the total loss to 0
132 12 : float loss = 0.0f;
133 12 : size_t lossComputed = 0;
134 12 : size_t counter = 0;
135 12 : bool isLossFrequency = false;
136 12 : auto &trainingDataFactory = TrainingDataFactory::getInstance();
137 12 : trainingDataFactory.resetCounters();
138 12 : const auto &app_params = Manager::getConstInstance().app_params;
139 :
140 : // Compute the frequency at which the loss should be computed
141 12 : size_t lossFrequency = std::max(
142 12 : static_cast<size_t>(std::sqrt(trainingDataFactory.getSize(phase))),
143 24 : (size_t)1);
144 :
145 : // Loop over all images
146 72 : while (auto data = trainingDataFactory.next(phase)) {
147 60 : if (stopTrainingNow) {
148 0 : break;
149 : }
150 60 : counter++;
151 60 : if (app_params.verbose) {
152 20 : SimpleLogger::LOG_INFO(
153 20 : "Epoch: ", epoch + 1, ", ", Common::getTrainingPhaseStr(phase), ": ",
154 40 : "image ", counter, "/", trainingDataFactory.getSize(phase), "...");
155 : }
156 :
157 : // Check if the loss should be computed for the current image
158 60 : isLossFrequency = counter % lossFrequency == 0 ? true : false;
159 :
160 : // Compute the image parts loss
161 60 : float imageLoss = _training(epoch, data, phase, isLossFrequency);
162 60 : if (stopTrainingNow) {
163 0 : break;
164 : }
165 :
166 : // If the loss was computed for the current image, add the average loss for
167 : // the current image to the total loss
168 60 : if (isLossFrequency) {
169 36 : loss += imageLoss;
170 36 : lossComputed++;
171 : }
172 132 : }
173 :
174 : // Return the average loss over all images for which the loss was computed
175 12 : if (lossComputed == 0) {
176 0 : return 0;
177 : }
178 12 : return (loss / static_cast<float>(lossComputed));
179 : }
180 :
/// Processes one image (split into parts): forward-propagates each part and,
/// during the Training phase, backward-propagates and updates the weights.
/// The per-part loss is only computed when isLossFrequency is true.
/// Checks stopTrainingNow between every expensive step so a SIGINT request
/// aborts as early as possible.
/// @param epoch zero-based epoch index, used for logging only.
/// @param data image data; img_input and img_target must have the same size.
/// @param phase Training performs backward propagation + weight update;
///              any other phase (e.g. Validation) is forward-only.
/// @param isLossFrequency when true, compute the loss for each part.
/// @return average loss over the processed parts, or 0 if none was computed.
/// @throws ImageHelperException if input and target part counts differ.
float RunnerTrainingOpenCVVisitor::_training(size_t epoch,
                                             std::shared_ptr<Data> data,
                                             TrainingPhase phase,
                                             bool isLossFrequency) const {
  if (data->img_input.size() != data->img_target.size()) {
    throw ImageHelperException(
        "internal exception: input and target parts have different sizes.");
  }

  auto &manager = Manager::getInstance();
  const auto &error_min = manager.network_params.error_min;
  const auto &error_max = manager.network_params.error_max;

  // Initialize the loss for the current image to 0
  float partsLoss = 0.0f;

  // Initialize a counter to keep track of the number of parts for which the
  // loss is computed
  size_t partsLossComputed = 0;

  // Loop over all parts of the current image
  for (size_t i = 0; i < data->img_input.size(); i++) {
    if (stopTrainingNow) {
      break;
    }

    // Get the input and target parts
    const auto &inputPart = data->img_input.at(i);
    const auto &targetPart = data->img_target.at(i);

    // Perform forward propagation
    if (manager.app_params.verbose_debug) {
      SimpleLogger::LOG_DEBUG("forward propagation part ", i + 1, "/",
                              data->img_input.size(), "...");
    }
    const auto &outputData =
        manager.network->forwardPropagation(inputPart->data);

    if (stopTrainingNow) {
      break;
    }

    // If the loss should be computed for the current image, compute the loss
    // for the current part
    if (isLossFrequency) {
      if (manager.app_params.verbose_debug) {
        SimpleLogger::LOG_DEBUG("loss computation...");
      }
      float partLoss = imageHelper_.computeLoss(outputData, targetPart->data);
      if (manager.app_params.verbose_debug) {
        SimpleLogger::LOG_DEBUG("part loss: ", partLoss * 100.0f, "%");
      }
      partsLoss += partLoss;
      partsLossComputed++;
    }
    if (stopTrainingNow) {
      break;
    }

    // If backward propagation and weight update should be performed, perform
    // them (Training phase only; Validation never mutates the network)
    if (phase == TrainingPhase::Training) {
      if (manager.app_params.verbose_debug) {
        SimpleLogger::LOG_DEBUG("backward propagation part ", i + 1, "/",
                                data->img_input.size(), "...");
      }
      manager.network->backwardPropagation(targetPart->data, error_min,
                                           error_max);
      if (stopTrainingNow) {
        break;
      }

      if (manager.app_params.verbose_debug) {
        SimpleLogger::LOG_DEBUG("weights update part ", i + 1, "/",
                                data->img_input.size(), "...");
      }
      manager.network->updateWeights(manager.network_params.learning_rate);
      if (stopTrainingNow) {
        break;
      }
    }
  }

  if (partsLossComputed == 0) {
    return 0;
  }
  return (partsLoss / static_cast<float>(partsLossComputed));
}
|