diff --git a/nntrainer/graph/graph_core.cpp b/nntrainer/graph/graph_core.cpp
index b624e066e4..3eafbb9261 100644
--- a/nntrainer/graph/graph_core.cpp
+++ b/nntrainer/graph/graph_core.cpp
@@ -35,6 +35,10 @@ GraphCore::getSortedNode(unsigned int ith) const {
   return Sorted.at(ith);
 }
 
+const unsigned int GraphCore::getSortedNodeIdx(const std::string &name) const {
+  return sorted_node_map.at(name);
+}
+
 void GraphCore::makeAdjacencyList(
   std::vector<std::list<std::shared_ptr<GraphNode>>> &adj) {
   /** initialize the adj list */
@@ -93,6 +97,11 @@ void GraphCore::topologicalSort() {
 
   if (Sorted.size() != node_list.size())
     throw std::runtime_error("Internal error in topologicalSort");
+
+  unsigned int idx = 0;
+  for (auto n : Sorted) {
+    sorted_node_map[n->getName()] = idx;
+    idx++;
+  }
 }
 
 const std::shared_ptr<GraphNode> &
diff --git a/nntrainer/graph/graph_core.h b/nntrainer/graph/graph_core.h
index 83d3ce7c39..77aa63666a 100644
--- a/nntrainer/graph/graph_core.h
+++ b/nntrainer/graph/graph_core.h
@@ -91,6 +91,13 @@ class GraphCore {
    */
   const std::shared_ptr<GraphNode> &getSortedNode(unsigned int ith) const;
 
+  /**
+   * @brief     getter of Sorted GraphNode index with name
+   * @param[in] name layer name
+   * @return    index of the node in the sorted node list
+   */
+  const unsigned int getSortedNodeIdx(const std::string &name) const;
+
   /**
    * @brief     getter of GraphNode with node name
   * @param[in] node name
@@ -252,6 +259,7 @@ class GraphCore {
   std::vector<std::shared_ptr<GraphNode>>
     node_list;                                    /**< Unordered Node List */
   std::unordered_map<std::string, int> node_map;  /**< Unordered Node map */
+  std::unordered_map<std::string, unsigned int> sorted_node_map; /**< Sorted Node map */
   std::vector<std::shared_ptr<GraphNode>> Sorted; /**< Ordered Node List */
   bool sorted; /** if the node_list is sorted */
 
diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index 297cd3e881..ac703e490b 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -337,7 +337,7 @@ void NetworkGraph::applyGradients(
       continue;
     }
 
-    if (rc.isGradientClipByGlobalNorm(i)) {
+    if (rc.isGradientClipByGlobalNorm(i) || rc.isMixedPrecision(i)) {
       /**
        * @note the weights whose gradient are to be clipped by global norm will
        * be clipped at once at the end of iteration and applied then.
@@ -393,56 +393,100 @@ sharedConstTensors NetworkGraph::incremental_forwarding(
   return out;
 }
 
-void NetworkGraph::backwarding(
+bool NetworkGraph::backwarding(
   int iteration,
-  std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
-  std::function<void(Weight &, int)> &apply_grad_clip_op,
-  std::function<bool(void *)> stop_cb, void *userdata) const {
+  std::function<void(std::shared_ptr<LayerNode>, bool)> &forwarding_op,
+  std::function<bool(std::shared_ptr<LayerNode>, int)> &backwarding_op,
+  std::function<void(Weight &, int)> &lazy_apply_grad_op,
+  std::function<bool(void *)> stop_cb, void *userdata) {
   /**
    * last layer backwarding is run out of this loop
    */
   auto iter_begin = getBackwardingBeginIter();
   auto iter_end = getBackwardingEndIter();
+  bool has_nan = false;
 
   /// there is no layer to train, so backwarding is essentially noop
   if (iter_begin == iter_end) {
-    return;
+    return true;
   }
 
   auto const &lptr_begin = (*iter_begin);
+  // graph_const_reverse_iterator
+  auto iter_ = iter_begin;
 
   if (lptr_begin->requireLabel() == false)
     throw std::runtime_error(
       "Error: last layer does not accept label, we can't train");
 
-  for (auto iter = iter_begin; iter != iter_end && !stop_cb(userdata); iter++) {
-    auto &ln = *iter;
+  for (iter_ = iter_begin; iter_ != iter_end && !stop_cb(userdata); iter_++) {
+    auto &ln = *iter_;
     PROFILE_TIME_START(profile_keys.at(ln->getType()));
-    backwarding_op(ln, iteration);
+    has_nan = backwarding_op(ln, iteration);
     PROFILE_TIME_END(profile_keys.at(ln->getType()));
+
+    if (has_nan) {
+      std::cout << "Gradient has NaN" << std::endl;
+      break;
+    }
   }
 
-  /** perform clipping of the gradients by global norm if any */
-  if (clip_weights.empty())
-    return;
+  if (has_nan) {
+    /** if has NaN
+     * 1. reset the loss scale
+     * 2. run forwarding again from the current node to the end of the graph
+     * 3. return false --> run backwarding again
+     */
+    float scale = (*iter_)->getRunContext().getLossScale();
+    float s = scale > 1.5f ? scale - 0.5f : 1.0f;
+
+    resetLossScale(s);
 
-  /** calculate the global norm */
-  Tensor global_norm_t(
-    TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()}));
-  float *global_norm_data = global_norm_t.getData();
-  for (unsigned int idx = 0; idx < clip_weights.size(); idx++) {
-    auto const &w = clip_weights[idx];
-    global_norm_data[idx] = w->getGradientNorm();
+    auto f_iter = cbegin() + graph.getSortedNodeIdx((*iter_)->getName());
+
+    for (auto iter = f_iter; iter != cend() && !stop_cb(userdata); iter++) {
+      auto &ln = *iter;
+      PROFILE_TIME_START(profile_keys.at(ln->getType()));
+      forwarding_op(*iter, true);
+      PROFILE_TIME_END(profile_keys.at(ln->getType()));
+    }
+
+    return false;
   }
-  float global_norm = global_norm_t.l2norm();
-  /** apply the gradient with the above global norm */
-  for (auto w : clip_weights) {
-    w->clipGradientByGlobalNorm(global_norm);
+
+  /** perform clipping of the gradients by global norm if any */
+  if (lazy_weights.empty())
+    return true;
+
+  if (is_clip_grad) {
+    /** calculate the global norm */
+    Tensor global_norm_t(
+      TensorDim({1u, 1u, 1u, (unsigned int)lazy_weights.size()}));
+    float *global_norm_data = global_norm_t.getData();
+    for (unsigned int idx = 0; idx < lazy_weights.size(); idx++) {
+      auto const &w = lazy_weights[idx];
+      global_norm_data[idx] = w->getGradientNorm();
+    }
+    float global_norm = global_norm_t.l2norm();
+    /** apply the gradient with the above global norm */
+    for (auto w : lazy_weights) {
+      w->clipGradientByGlobalNorm(global_norm);
+    }
   }
   /** apply the gradient with the above global norm */
-  for (auto w : clip_weights) {
-    apply_grad_clip_op(*w, iteration);
+  for (auto w : lazy_weights) {
+    lazy_apply_grad_op(*w, iteration);
+  }
+  nan_count++;
+
+  if (nan_count > 10) {
+    float scale = (*iter_)->getRunContext().getLossScale();
+    float s = scale + 2.0f;
+    resetLossScale(s);
+    nan_count = 0;
   }
+
+  return true;
 }
 
 LayerNode *NetworkGraph::computeBackwardEnd() {
@@ -1290,11 +1334,19 @@ int NetworkGraph::initialize(ExecutionMode mode,
 
   /** select weights which would require clipping of the gradients by global
    * norm if any */
-  clip_weights = tensor_manager->getWeights([](const Weight *w) {
+  lazy_weights = tensor_manager->getWeights([](const Weight *w) {
     return w->hasGradient() && w->isGradientLastAccess() &&
-           w->isGradientClipByGlobalNorm();
+           (w->isGradientClipByGlobalNorm() || w->isMixedPrecision());
   });
 
+  is_clip_grad = false;
+  for (auto w : lazy_weights) {
+    if (w->isGradientClipByGlobalNorm()) {
+      is_clip_grad = true;
+      break;
+    }
+  }
+
   return ML_ERROR_NONE;
 }
 
@@ -1566,4 +1618,11 @@ void NetworkGraph::requestOptimizerVariable(
   }
 }
 
+void NetworkGraph::resetLossScale(float scale) {
+  for (auto iter = cbegin(); iter != cend(); iter++) {
+    auto &ln = *iter;
+    ln->getRunContext().setLossScale(scale);
+  }
+}
+
 } /* namespace nntrainer */
diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h
index 5c9adf0363..22f14e1b73 100644
--- a/nntrainer/graph/network_graph.h
+++ b/nntrainer/graph/network_graph.h
@@ -51,7 +51,9 @@ class NetworkGraph {
     optimize_memory(true),
     exec_mode(ExecutionMode::TRAIN),
     tensor_format("NCHW"),
-    tensor_dtype(split("FP32-FP32", getRegex("\\-"))) {}
+    tensor_dtype(split("FP32-FP32", getRegex("\\-"))) {
+    nan_count = 0;
+  }
 
   /**
    * @brief     Constructor of NeuralNetwork Graph Class
@@ -73,7 +75,9 @@ class NetworkGraph {
     optimize_memory(true),
     exec_mode(ExecutionMode::TRAIN),
     tensor_format(tensor_format_),
-    tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) {}
+    tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) {
+    nan_count = 0;
+  }
 
   /**
    * @brief     Destructor of the NeuralNetwork Graph class
@@ -206,13 +210,14 @@ class NetworkGraph {
    * @param[in] backwarding_op operation for the backwarding
    * @param[in] apply_grad_clip_op operation for applying the clip gradients
    */
-  void backwarding(
+  bool backwarding(
    int iteration,
-    std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
-    std::function<void(Weight &, int)> &apply_grad_clip_op,
+    std::function<void(std::shared_ptr<LayerNode>, bool)> &forwarding_op,
+    std::function<bool(std::shared_ptr<LayerNode>, int)> &backwarding_op,
+    std::function<void(Weight &, int)> &lazy_apply_grad_op,
     std::function<bool(void *)> stop_cb =
      [](void *user_data) { return false; },
-    void *user_data = nullptr) const;
+    void *user_data = nullptr);
 
   /**
    * @brief     get begin iterator for the graph
@@ -444,6 +449,12 @@ class NetworkGraph {
   getLayerExecutionOrders(const std::shared_ptr<LayerNode> &lnode);
 #endif // ENABLE_TEST
 
+  /**
+   * @brief     reset the loss scale
+   * @param[in] scale new loss scale value
+   */
+  void resetLossScale(float scale);
+
 private:
   std::map<std::string, std::string> sub_in_out;
   /** This is map to identify input and output layer name of subgraph */
@@ -480,7 +491,10 @@ class NetworkGraph {
   std::unordered_map<std::string, int>
     profile_keys; /**< profile keys based on the layer type */
   std::vector<Weight *>
-    clip_weights; /**< weights with global norm based clipping enabled */
+    lazy_weights; /**< weights with global norm based clipping enabled */
+  bool is_clip_grad;
+
+  unsigned int nan_count;
 
   /**
    * @brief     topological sort
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index fbbc9ecaff..5862e6af14 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -416,6 +416,17 @@ bool RunLayerContext::isGradientClipByGlobalNorm(unsigned int idx) const {
   return weights[idx]->isGradientClipByGlobalNorm();
 }
 
+bool RunLayerContext::isMixedPrecision(unsigned int idx) const {
+  return weights[idx]->isMixedPrecision();
+}
+
+bool RunLayerContext::isMixedPrecision() const {
+  for (auto w : weights)
+    if (w->isMixedPrecision())
+      return true;
+  return false;
+}
+
 /**
  * @brief Get the tensor name
  *
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index 09bccc2c73..c68c42f11d 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -689,6 +689,20 @@ class RunLayerContext {
    */
   bool isGradientClipByGlobalNorm(unsigned int idx) const;
 
+  /**
+   * @brief check if the weight is mixed precision
+   *
+   * @param idx index
+   * @return bool true if it is mixed precision
+   */
+  bool isMixedPrecision(unsigned int idx) const;
+
+  /**
+   * @brief check if any of the weights is mixed precision
+   * @return bool true if it is mixed precision
+   */
+  bool isMixedPrecision() const;
+
   /**
    * @brief Get the tensor name
    *
@@ -910,6 +924,18 @@ class RunLayerContext {
    */
   float getLossScale() { return loss_scale; }
 
+  /**
+   * @brief set loss scale
+   *
+   * @param[in] scale loss scale value
+   */
+  void setLossScale(float scale) {
+    loss_scale = scale;
+    for (auto w : weights) {
+      w->setLossScale(scale);
+    }
+  }
+
 private:
   std::tuple<props::Name, props::Trainable> props; /**< props of the layer */
   float loss;                                      /**< loss of the layer */
diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp
index d0e542825f..f7c0914d32 100644
--- a/nntrainer/models/neuralnet.cpp
+++ b/nntrainer/models/neuralnet.cpp
@@ -412,9 +412,21 @@ void NeuralNetwork::backwarding(int iteration,
   NNTR_THROW_IF(!opt, std::invalid_argument) << "optimizer is null!";
 #endif
 
-  std::function<void(std::shared_ptr<LayerNode>, int)> backwarding_op =
+  std::function<void(std::shared_ptr<LayerNode>, bool)> forwarding_op =
+    [this, stop_cb, userdata](std::shared_ptr<LayerNode> node,
+                              bool training) -> void {
+    (void)this;
+    PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName());
+
+    auto f = std::get<0>(node->getExecutionOrder());
+    model_graph.flushCacheExcept(f);
+
+    node->forwarding(training);
+  };
+
+  std::function<bool(std::shared_ptr<LayerNode>, int)> backwarding_op =
     [this, stop_cb, userdata](std::shared_ptr<LayerNode> node,
-                              int iteration) -> void {
+                              int iteration) -> bool {
     /**
     * Do not change this order:
     * 1. calcGradient
@@ -448,19 +460,29 @@ void NeuralNetwork::backwarding(int iteration,
 
       /** If gradient must be applied and its not gradient mode, calculate
       * gradient */
-      if (!dynamic_training_opt.isGradientMode() && apply_gradient)
+      if (!dynamic_training_opt.isGradientMode() && apply_gradient) {
         node->calcGradient();
+
+        RunLayerContext &rc = node->getRunContext();
+        if (rc.isMixedPrecision()) {
+          for (auto w : rc.getWeights()) {
+            if (w->getGradientRef().hasNaN())
+              return true;
+          }
+        }
+      }
     }
 
     model_graph.flushCacheExcept(std::get<2>(node->getExecutionOrder()));
     PROFILE_MEM_ANNOTATE("CalcDerivative: " + node->getName());
 
     if (stop_cb(userdata)) {
-      return;
+      return false;
     }
 
-    if (node->needsCalcDerivative())
+    if (node->needsCalcDerivative()) {
       node->calcDerivative();
+    }
 
     model_graph.flushCacheExcept(std::get<3>(node->getExecutionOrder()));
     PROFILE_MEM_ANNOTATE("ApplyGradient: " + node->getName());
@@ -476,9 +498,10 @@ void NeuralNetwork::backwarding(int iteration,
         opt_->applyGradient(opt_context);
       });
     }
+    return false;
   };
 
-  std::function<void(Weight &, int)> apply_grad_clip_op =
+  std::function<void(Weight &, int)> lazy_apply_grad_op =
     [opt_ = opt.get()](Weight &w, int iteration) -> void {
      w.calcRegularizationGradient();
      w.calcWeightDecayGradient();
@@ -487,8 +510,12 @@ void NeuralNetwork::backwarding(int iteration,
      opt_->applyGradient(opt_context);
    };
 
-  model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op,
-                          stop_cb, userdata);
+  bool ret = false;
+
+  while (!ret) {
+    ret = model_graph.backwarding(iteration, forwarding_op, backwarding_op,
+                                  lazy_apply_grad_op, stop_cb, userdata);
+  }
 }
 
 void NeuralNetwork::save(const std::string &file_path,
diff --git a/nntrainer/tensor/blas_avx.cpp b/nntrainer/tensor/blas_avx.cpp
index 2fd4908463..ce4d8de47f 100644
--- a/nntrainer/tensor/blas_avx.cpp
+++ b/nntrainer/tensor/blas_avx.cpp
@@ -127,7 +127,7 @@ bool hasNaN(const size_t N, const _Float16 *input) {
     const __m256 vec0 =
       _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)input));
     const __m256 vec1 =
-      _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)input + 8));
+      _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(input + 8)));
 
     input += 16;
 
@@ -161,6 +161,7 @@ bool hasNaN(const size_t N, const _Float16 *input) {
       return true;
     }
     ++input;
+    ++idx;
   }
 
   return false;
@@ -205,6 +206,7 @@ bool hasNaN(const size_t N, const float *input) {
       return true;
     }
     ++input;
+    ++idx;
   }
 
   return false;
diff --git a/nntrainer/tensor/weight.cpp b/nntrainer/tensor/weight.cpp
index df262f50d9..0e9879540a 100644
--- a/nntrainer/tensor/weight.cpp
+++ b/nntrainer/tensor/weight.cpp
@@ -153,7 +153,7 @@ void Weight::quantizeWeight() {
     // NYI
     break;
   case ml::train::TensorDim::DataType::FP16:
-    getVariableRef().copy(getVariableFP32Ref());
+    getVariableRef().copyData(getVariableFP32Ref());
     break;
   case ml::train::TensorDim::DataType::FP32:
     break;
diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h
index 5382c686e1..8ac3aa0190 100644
--- a/nntrainer/tensor/weight.h
+++ b/nntrainer/tensor/weight.h
@@ -349,6 +349,13 @@ class Weight : public Var_Grad {
    */
   void quantizeWeight();
 
+  /**
+   * @brief set loss scale
+   *
+   * @param[in] scale loss scale value
+   */
+  void setLossScale(float scale) { loss_scale = scale; };
+
 private:
   static constexpr float epsilon = 1e-6; /**< epsilon for zero comparison */
   static constexpr float epsilon_decay =