diff --git a/Applications/KNN/jni/meson.build b/Applications/KNN/jni/meson.build index bc50dc0214..58ca099d75 100644 --- a/Applications/KNN/jni/meson.build +++ b/Applications/KNN/jni/meson.build @@ -15,4 +15,4 @@ e = executable('knn_sample', install_dir: application_install_dir ) -test('app_knn', e, args: [nntr_app_resdir / 'KNN']) +test('app_knn', e, args: [nntr_app_resdir / 'KNN/']) diff --git a/meson.build b/meson.build index d4aea330a4..7ae692e6d9 100644 --- a/meson.build +++ b/meson.build @@ -64,9 +64,19 @@ warning_c_flags = [ '-Wno-error=varargs' ] +arch = host_machine.cpu_family() + +if get_option('enable-avx') + extra_defines += '-DUSE_AVX=1' + if get_option('platform') == 'tizen' + add_project_arguments(['-mavx2'], language: ['c','cpp']) + else + add_project_arguments(['-march=native'], language: ['c','cpp']) + endif + message('-march=native added for AVX hardware acceleration.') +endif if get_option('enable-fp16') - arch = host_machine.cpu_family() if get_option('platform') == 'android' add_project_arguments('-mfp16-format=ieee', language: ['c', 'cpp']) extra_defines += '-DENABLE_FP16=1' @@ -105,11 +115,6 @@ if get_option('enable-fp16') if cc.version().version_compare('>=12.1.0') message ('Float16 for x86_64 enabled. Modern gcc-x64 generally supports float16 with _Float16.') extra_defines += '-DENABLE_FP16=1' - if get_option('enable-avx') - extra_defines += '-DUSE_AVX=1' - add_project_arguments(['-march=native'], language: ['c','cpp']) - message('-march=native added for AVX hardware acceleration.') - endif else warning ('Float16 for x86_64 enabled. However, software emulation is applied for fp16, making it slower and inconsistent. Use GCC 12+ for FP16 support. This build will probably fail unless you bring a compiler that supports fp16 for x64.') endif diff --git a/nntrainer/graph/graph_core.cpp b/nntrainer/graph/graph_core.cpp index b624e066e4..3eafbb9261 100644 --- a/nntrainer/graph/graph_core.cpp +++ b/nntrainer/graph/graph_core.cpp @@ -35,6 +35,10 @@ GraphCore::getSortedNode(unsigned int ith) const { return Sorted.at(ith); } +const unsigned int GraphCore::getSortedNodeIdx(const std::string &name) const { + return sorted_node_map.at(name); +} + void GraphCore::makeAdjacencyList( std::vector>> &adj) { /** initialize the adj list */ @@ -93,6 +97,11 @@ void GraphCore::topologicalSort() { if (Sorted.size() != node_list.size()) throw std::runtime_error("Internal error in topologicalSort"); + unsigned int idx = 0; + for (auto n : Sorted) { + sorted_node_map[n->getName()] = idx; + idx++; + } } const std::shared_ptr & diff --git a/nntrainer/graph/graph_core.h b/nntrainer/graph/graph_core.h index 83d3ce7c39..77aa63666a 100644 --- a/nntrainer/graph/graph_core.h +++ b/nntrainer/graph/graph_core.h @@ -91,6 +91,13 @@ class GraphCore { */ const std::shared_ptr &getSortedNode(unsigned int ith) const; + /** + * @brief getter of Sorted GraphNode index with name + * @param[in] layer name + * @ret index + */ + const unsigned int getSortedNodeIdx(const std::string &name) const; + /** * @brief getter of GraphNode with node name * @param[in] node name @@ -252,6 +259,7 @@ class GraphCore { std::vector> node_list; /**< Unordered Node List */ std::unordered_map node_map; /**< Unordered Node map */ + std::unordered_map sorted_node_map; /**< Unordered Node map */ std::vector> Sorted; /**< Ordered Node List */ bool sorted; /** if the node_list is sorted */ diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp index 2d4cfdc769..821731e949 100644 --- 
a/nntrainer/graph/network_graph.cpp +++ b/nntrainer/graph/network_graph.cpp @@ -337,7 +337,7 @@ void NetworkGraph::applyGradients( continue; } - if (rc.isGradientClipByGlobalNorm(i)) { + if (rc.isGradientClipByGlobalNorm(i) || rc.isMixedPrecision(i)) { /** * @note the weights whose gradient are to be clipped by global norm will * be clipped at once at the end of iteration and applied then. @@ -393,56 +393,100 @@ sharedConstTensors NetworkGraph::incremental_forwarding( return out; } -void NetworkGraph::backwarding( +bool NetworkGraph::backwarding( int iteration, - std::function, int)> &backwarding_op, - std::function &apply_grad_clip_op, - std::function stop_cb, void *userdata) const { + std::function, bool)> &forwarding_op, + std::function, int)> &backwarding_op, + std::function &lazy_apply_grad_op, + std::function stop_cb, void *userdata) { /** * last layer backwarding is run out of this loop */ auto iter_begin = getBackwardingBeginIter(); auto iter_end = getBackwardingEndIter(); + bool is_valid = true; /// there is no layer to train, so backwarding is essentially noop if (iter_begin == iter_end) { - return; + return true; } auto const &lptr_begin = (*iter_begin); + // graph_const_reverse_iterator + auto iter_ = iter_begin; if (lptr_begin->requireLabel() == false) throw std::runtime_error( "Error: last layer does not accept label, we can't train"); - for (auto iter = iter_begin; iter != iter_end && !stop_cb(userdata); iter++) { - auto &ln = *iter; + for (iter_ = iter_begin; iter_ != iter_end && !stop_cb(userdata); iter_++) { + auto &ln = *iter_; PROFILE_TIME_START(profile_keys.at(ln->getType())); - backwarding_op(ln, iteration); + is_valid = backwarding_op(ln, iteration); PROFILE_TIME_END(profile_keys.at(ln->getType())); + + if (!is_valid) { + std::cout << "Gradient has NaN" << std::endl; + break; + } } - /** perform clipping of the gradients by global norm if any */ - if (clip_weights.empty()) - return; + if (!is_valid) { + /** if has NaN + * 1. reset the loss scale. + * 2. run forwarding from cur_iter to cend() && !stop_cb(userdata); + * 3. return false --> run backwarding again; + */ + float scale = (*iter_)->getRunContext().getLossScale(); + float s = scale > 1.5f ? 
scale - 0.5f : 1.0f; + + resetLossScale(s); + + auto f_iter = cbegin() + graph.getSortedNodeIdx((*iter_)->getName()); + + for (auto iter = f_iter; iter != cend() && !stop_cb(userdata); iter++) { + auto &ln = *iter; + PROFILE_TIME_START(profile_keys.at(ln->getType())); + forwarding_op(*iter, true); + PROFILE_TIME_END(profile_keys.at(ln->getType())); + } - /** calculate the global norm */ - Tensor global_norm_t( - TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()})); - float *global_norm_data = global_norm_t.getData(); - for (unsigned int idx = 0; idx < clip_weights.size(); idx++) { - auto const &w = clip_weights[idx]; - global_norm_data[idx] = w->getGradientNorm(); + return false; } - float global_norm = global_norm_t.l2norm(); - /** apply the gradient with the above global norm */ - for (auto w : clip_weights) { - w->clipGradientByGlobalNorm(global_norm); + + /** perform clipping of the gradients by global norm if any */ + if (lazy_weights.empty()) + return true; + + if (is_clip_grad) { + /** calculate the global norm */ + Tensor global_norm_t( + TensorDim({1u, 1u, 1u, (unsigned int)lazy_weights.size()})); + float *global_norm_data = global_norm_t.getData(); + for (unsigned int idx = 0; idx < lazy_weights.size(); idx++) { + auto const &w = lazy_weights[idx]; + global_norm_data[idx] = w->getGradientNorm(); + } + float global_norm = global_norm_t.l2norm(); + /** apply the gradient with the above global norm */ + for (auto w : lazy_weights) { + w->clipGradientByGlobalNorm(global_norm); + } } /** apply the gradient with the above global norm */ - for (auto w : clip_weights) { - apply_grad_clip_op(*w, iteration); + for (auto w : lazy_weights) { + lazy_apply_grad_op(*w, iteration); + } + nan_count++; + + if (nan_count > 10) { + float scale = (*iter_)->getRunContext().getLossScale(); + float s = scale + 2.0f; + resetLossScale(s); + nan_count = 0; } + + return true; } LayerNode *NetworkGraph::computeBackwardEnd() { @@ -768,9 +812,10 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, * node is going to be used with in-place optimizations. 
*/ auto out_specs = init_context.getOutSpecs(); + /// @note try move inplace control to finalize bool shared_var = false, shared_grad = false; - if (lnode->executeInPlace() != InPlace::NONE) { + if (lnode->executeInPlace() != InPlace::NONE && lnode->supportInPlace()) { setInplaceSharedMemoryConfigByLayer(lnode, shared_var, shared_grad); for (unsigned int i = 0; i < out_specs.size(); ++i) { auto &s = out_specs.at(i); @@ -879,7 +924,8 @@ NetworkGraph::finalizeContext(const std::shared_ptr &lnode, lnode->getTrainable(), shared_weight_names), inputs, outputs, tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(), - lnode->getTrainable(), shared_tensor_names)); + lnode->getTrainable(), shared_tensor_names), + init_context.getLossScale()); return outputs; } @@ -1027,7 +1073,8 @@ NetworkGraph::refinalizeContext(const std::shared_ptr &lnode, // TODO: update weights spec for trainable based on layer trainable prop weights, inputs, outputs, tensor_manager->requestTensors(gnode, init_context.getTensorsSpec(), - lnode->getTrainable(), shared_tensor_names)); + lnode->getTrainable(), shared_tensor_names), + init_context.getLossScale()); return outputs; } @@ -1287,11 +1334,19 @@ int NetworkGraph::initialize(ExecutionMode mode, /** select weights which would require clipping of the gradients by global * norm if any */ - clip_weights = tensor_manager->getWeights([](const Weight *w) { + lazy_weights = tensor_manager->getWeights([](const Weight *w) { return w->hasGradient() && w->isGradientLastAccess() && - w->isGradientClipByGlobalNorm(); + (w->isGradientClipByGlobalNorm() || w->isMixedPrecision()); }); + is_clip_grad = false; + for (auto w : lazy_weights) { + if (w->isGradientClipByGlobalNorm()) { + is_clip_grad = true; + break; + } + } + return ML_ERROR_NONE; } @@ -1556,10 +1611,18 @@ void NetworkGraph::requestOptimizerVariable( const TensorDim &dim = w->getDim(); std::vector dims = cb(dim); w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables( - dims, w->getName(), TensorLifespan::MAX_LIFESPAN, - w->isGradientClipByGlobalNorm(), Tensor::Initializer::ZEROS)); + dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN, + w->isGradientClipByGlobalNorm(), w->isMixedPrecision(), + Tensor::Initializer::ZEROS)); } } } +void NetworkGraph::resetLossScale(float scale) { + for (auto iter = cbegin(); iter != cend(); iter++) { + auto &ln = *iter; + ln->getRunContext().setLossScale(scale); + } +} + } /* namespace nntrainer */ diff --git a/nntrainer/graph/network_graph.h b/nntrainer/graph/network_graph.h index 5c9adf0363..22f14e1b73 100644 --- a/nntrainer/graph/network_graph.h +++ b/nntrainer/graph/network_graph.h @@ -51,7 +51,9 @@ class NetworkGraph { optimize_memory(true), exec_mode(ExecutionMode::TRAIN), tensor_format("NCHW"), - tensor_dtype(split("FP32-FP32", getRegex("\\-"))) {} + tensor_dtype(split("FP32-FP32", getRegex("\\-"))) { + nan_count = 0; + } /** * @brief Constructor of NeuralNetwork Graph Class @@ -73,7 +75,9 @@ class NetworkGraph { optimize_memory(true), exec_mode(ExecutionMode::TRAIN), tensor_format(tensor_format_), - tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) {} + tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) { + nan_count = 0; + } /** * @brief Destructor of the NeuralNetwork Graph class @@ -206,13 +210,14 @@ class NetworkGraph { * @param[in] backwarding_op operation for the backwarding * @param[in] apply_grad_clip_op operation for applying the clip gradients */ - void backwarding( + bool backwarding( int iteration, - std::function, int)> 
&backwarding_op,
-    std::function &apply_grad_clip_op,
+    std::function, bool)> &forwarding_op,
+    std::function, int)> &backwarding_op,
+    std::function &lazy_apply_grad_op,
     std::function stop_cb =
       [](void *user_data) { return false; },
-    void *user_data = nullptr) const;
+    void *user_data = nullptr);
 
   /**
    * @brief get begin iterator for the graph
@@ -444,6 +449,12 @@ class NetworkGraph {
   getLayerExecutionOrders(const std::shared_ptr &lnode);
 #endif // ENABLE_TEST
 
+  /**
+   * @brief reset the loss scale
+   * @param[in] scale
+   */
+  void resetLossScale(float scale);
+
 private:
   std::map
     sub_in_out; /** This is map to identify input and output layer name of
                    subgraph */
@@ -480,7 +491,10 @@ class NetworkGraph {
   std::unordered_map
     profile_keys; /**< profile keys based on the layer type */
   std::vector
-    clip_weights; /**< weights with global norm based clipping enabled */
+    lazy_weights; /**< weights with global norm based clipping enabled */
+  bool is_clip_grad;
+
+  unsigned int nan_count;
 
   /**
    * @brief topological sort
diff --git a/nntrainer/layers/input_layer.cpp b/nntrainer/layers/input_layer.cpp
index eabd40b297..d9f058d8ce 100644
--- a/nntrainer/layers/input_layer.cpp
+++ b/nntrainer/layers/input_layer.cpp
@@ -33,8 +33,7 @@ namespace nntrainer {
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 InputLayer::InputLayer() :
-  Layer(),
-  input_props(props::Normalization(), props::Standardization()) {}
+  Layer(), input_props(props::Normalization(), props::Standardization()) {}
 
 void InputLayer::setProperty(const std::vector &values) {
   auto remain_props = loadProperties(values, input_props);
@@ -47,7 +46,7 @@ void InputLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
   if (!context.executeInPlace()) {
     Tensor &input_ = context.getInput(SINGLE_INOUT_IDX);
-    hidden_.copy(input_);
+    hidden_.copyData(input_);
   }
 
   if (std::get(input_props))
@@ -70,7 +69,21 @@ void InputLayer::finalize(InitLayerContext &context) {
   std::vector output_dims = context.getInputDimensions();
 
+  for (auto &d : output_dims) {
+    d.setDataType(context.getActivationDataType());
+  }
+
   context.setOutputDimensions(output_dims);
+
+  is_inplace = true;
+
+  /**
+   * @note Input Layer assumes that the input tensor is always FP32. Therefore,
+   * if the activation data type is not fp32, then it does not support in-place
+   * operation.
+ */ + if (context.getActivationDataType() != ml::train::TensorDim::DataType::FP32) + is_inplace = false; } } /* namespace nntrainer */ diff --git a/nntrainer/layers/input_layer.h b/nntrainer/layers/input_layer.h index f6728d676b..e9183e23d1 100644 --- a/nntrainer/layers/input_layer.h +++ b/nntrainer/layers/input_layer.h @@ -82,7 +82,7 @@ class InputLayer : public Layer { /** * @copydoc Layer::supportInPlace() */ - bool supportInPlace() const override { return true; } + bool supportInPlace() const override { return is_inplace; } /** * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods @@ -105,6 +105,7 @@ class InputLayer : public Layer { private: std::tuple input_props; + bool is_inplace; }; } // namespace nntrainer diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp index fff2eb15ec..5862e6af14 100644 --- a/nntrainer/layers/layer_context.cpp +++ b/nntrainer/layers/layer_context.cpp @@ -126,13 +126,14 @@ const std::vector &InitLayerContext::getOutSpecs() const { } RunLayerContext::RunLayerContext(const std::string &name, bool trainable, - float l, bool in_place_, + float l, bool in_place_, float loss_scale_, const std::vector &w, const std::vector &in, const std::vector &out, const std::vector &t) : loss(l), in_place(in_place_), + loss_scale(loss_scale_), weights(w), inputs(in), outputs(out), @@ -169,6 +170,19 @@ Tensor &RunLayerContext::getWeightGrad(unsigned int idx) const { return weights[idx]->getGradientRef(); } +/** + * @brief Get the Weight Gradient tensor object + * + * @param idx Identifier of the weight + * @return Tensor& Reference to the weight grad tensor + */ +Tensor &RunLayerContext::getWeightFP32(unsigned int idx) const { + if (!weights[idx]->hasGradient()) + throw std::invalid_argument( + "Requesting gradient for a non-trainable weight."); + return weights[idx]->getVariableFP32Ref(); +} + /** * @brief Get the Weight Optimizer Variable tensor object * @@ -402,6 +416,17 @@ bool RunLayerContext::isGradientClipByGlobalNorm(unsigned int idx) const { return weights[idx]->isGradientClipByGlobalNorm(); } +bool RunLayerContext::isMixedPrecision(unsigned int idx) const { + return weights[idx]->isMixedPrecision(); +} + +bool RunLayerContext::isMixedPrecision() const { + for (auto w : weights) + if (w->isMixedPrecision()) + return true; + return false; +} + /** * @brief Get the tensor name * diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h index e5c6759638..c68c42f11d 100644 --- a/nntrainer/layers/layer_context.h +++ b/nntrainer/layers/layer_context.h @@ -63,7 +63,7 @@ class InitLayerContext { const float max_norm = 0.0, std::array tensor_type_ = {"NCHW", "FP32", "FP32"}, - const float loss_scale = 0.0); + const float loss_scale = 1.0); /** * @brief get Tensor Format of Layer * @@ -348,6 +348,14 @@ class InitLayerContext { */ bool executeInPlace() const { return in_place; } + /** + * @brief get Initial value of Loss_Scale. 
This is set to RunLayerContext + * and updated + * + * @return loss_scale + */ + float getLossScale() const { return loss_scale; } + private: std::vector input_dim; /**< Input dimensions for the layer */ bool in_place; /**< if the layer is expected to run in-place */ @@ -385,7 +393,7 @@ class RunLayerContext { * @brief Construct a new Run Layer Context object * */ - RunLayerContext() : loss(0.0), in_place(false) {} + RunLayerContext() : loss(0.0), in_place(false), loss_scale(1.0) {} /** * @brief Construct a new Run Layer Context object @@ -396,6 +404,17 @@ class RunLayerContext { std::get(props).set(name); } + /** + * @brief Construct a new Run Layer Context object + * + */ + RunLayerContext(const std::string &name, bool in_place_, float loss_scale_) : + RunLayerContext() { + in_place = in_place_; + std::get(props).set(name); + loss_scale = loss_scale_; + } + /** * @brief Construct a new Run Layer Context object * @@ -403,13 +422,15 @@ class RunLayerContext { * @param trainable if the layer is trainable * @param l loss of the layer * @param in_place_ execution in-place of the layer + * @param loss_scale loss_scale of the layer * @param w weights of the layer * @param in inputs of the layer * @param out outputs of the layer * @param t extra tensors of the layer */ RunLayerContext(const std::string &name, bool trainable, float l, - bool in_place_, const std::vector &w, + bool in_place_, float loss_scale_, + const std::vector &w, const std::vector &in, const std::vector &out, const std::vector &t); @@ -463,6 +484,15 @@ class RunLayerContext { Tensor &getWeightGrad(unsigned int idx) const; /** + * @brief Get the Weight Gradient tensor object + * + * @param idx Identifier of the weight + * @return Tensor& Reference to the weight grad tensor + */ + Tensor &getWeightFP32(unsigned int idx) const; + + /** + * @brief Get the Weight Optimizer Variable tensor object * * @param idx Identifier of the weight @@ -659,6 +689,20 @@ class RunLayerContext { */ bool isGradientClipByGlobalNorm(unsigned int idx) const; + /** + * @brief check if the weight is mixed precsion + * + * @param idx index + * @return bool true if it is mixed precision + */ + bool isMixedPrecision(unsigned int idx) const; + + /** + * @brief check if the weight is mixed precsion + * @return bool true if it is mixed precision + */ + bool isMixedPrecision() const; + /** * @brief Get the tensor name * @@ -874,10 +918,29 @@ class RunLayerContext { */ ml::train::LayerComputeEngine getComputeEngine() { return compute_engine; } + /** + * @brief get loss scale + * @return loss scale + */ + float getLossScale() { return loss_scale; } + + /** + * @brief set Loss_Scale. + * + * @return loss_scale + */ + void setLossScale(float scale) { + loss_scale = scale; + for (auto w : weights) { + w->setLossScale(scale); + } + } + private: std::tuple props; /**< props of the layer */ float loss; /**< loss of the layer */ - bool in_place; /**< if the layer is expected to run in-place */ + bool in_place; /**< if the layer is expected to run in-place */ + float loss_scale; /**< loss_scale of the layer */ std::vector weights; /**< weights of the layer */ std::vector inputs; /**< inputs of the layer */ diff --git a/nntrainer/layers/layer_node.cpp b/nntrainer/layers/layer_node.cpp index 8b18d80762..f41752a4d8 100644 --- a/nntrainer/layers/layer_node.cpp +++ b/nntrainer/layers/layer_node.cpp @@ -599,7 +599,7 @@ InitLayerContext LayerNode::finalize(const std::vector &input_dims, const auto &scope = getSharedFrom().empty() ? 
getName() : getSharedFrom(); float max_norm = 0.0; - float loss_scale = 0.0; + float loss_scale = 1.0; if (!std::get(*layer_node_props).empty()) max_norm = std::get(*layer_node_props).get(); @@ -864,10 +864,11 @@ float LayerNode::getLoss() const { return *loss; } void LayerNode::configureRunContext(const std::vector &weights, const std::vector &inputs, const std::vector &outputs, - const std::vector &tensors) { + const std::vector &tensors, + float loss_scale) { run_context = std::make_unique( - getName(), getTrainable(), 0.0f, executeInPlace() != InPlace::NONE, weights, - inputs, outputs, tensors); + getName(), getTrainable(), 0.0f, executeInPlace() != InPlace::NONE, + loss_scale, weights, inputs, outputs, tensors); } /** diff --git a/nntrainer/layers/layer_node.h b/nntrainer/layers/layer_node.h index 93e7ac7069..3fd2d55b97 100644 --- a/nntrainer/layers/layer_node.h +++ b/nntrainer/layers/layer_node.h @@ -487,6 +487,7 @@ class LayerNode final : public ml::train::Layer, public GraphNode { const std::vector getOutputDimensions() const; /** * @brief Get the Weight object + * currently, only unittest uses this func. * * @param idx Identifier of the weight * @return Weight& Reference to the weight @@ -495,11 +496,11 @@ class LayerNode final : public ml::train::Layer, public GraphNode { NNTR_THROW_IF(!run_context, std::runtime_error) << __func__ << " layer needs to be finalized first!"; if (run_context->weightHasGradient(idx)) { - return Weight(run_context->getWeight(idx), - run_context->getWeightGrad(idx), - run_context->getWeightName(idx)); + return Weight( + run_context->getWeight(idx), run_context->getWeightGrad(idx), + run_context->getWeightFP32(idx), run_context->getWeightName(idx)); } else { - return Weight(run_context->getWeight(idx), Tensor(), + return Weight(run_context->getWeight(idx), Tensor(), Tensor(), run_context->getWeightName(idx)); } } @@ -819,7 +820,8 @@ class LayerNode final : public ml::train::Layer, public GraphNode { void configureRunContext(const std::vector &weights, const std::vector &inputs, const std::vector &outputs, - const std::vector &tensors); + const std::vector &tensors, + float loss_scale); /** * @brief Preset modes for printing summary for the layer diff --git a/nntrainer/layers/loss/loss_layer.cpp b/nntrainer/layers/loss/loss_layer.cpp index 40f74717f8..ab2ccf8be2 100644 --- a/nntrainer/layers/loss/loss_layer.cpp +++ b/nntrainer/layers/loss/loss_layer.cpp @@ -22,7 +22,7 @@ void LossLayer::finalize(InitLayerContext &context) { d.setDataType( str_converter::from_string("FP32")); - + context.setOutputDimensions(output_dim); } @@ -36,6 +36,13 @@ void LossLayer::updateLoss(RunLayerContext &context, const Tensor &l) { context.setLoss(loss_sum / (float)l.batch()); } +void LossLayer::applyLossScale(RunLayerContext &context, Tensor &ret_deriv) { + + float loss_scale = context.getLossScale(); + if (loss_scale != 1.0) + ret_deriv.multiply_i(loss_scale); +} + /** * @copydoc Layer::setProperty(const std::vector &values) */ diff --git a/nntrainer/layers/loss/loss_layer.h b/nntrainer/layers/loss/loss_layer.h index 00b520f6e6..581e9477a8 100644 --- a/nntrainer/layers/loss/loss_layer.h +++ b/nntrainer/layers/loss/loss_layer.h @@ -60,6 +60,13 @@ class LossLayer : public Layer { */ void updateLoss(RunLayerContext &context, const Tensor &l); + /** + * @brief update return derivative with loss scale + * @param context Run context to update + * @param return_dev Tensor data to calculate + */ + void applyLossScale(RunLayerContext &context, Tensor &l); + Tensor l; /**< loss tensor 
to store intermediate value to calculate loss value */
};
diff --git a/nntrainer/layers/loss/mse_loss_layer.cpp b/nntrainer/layers/loss/mse_loss_layer.cpp
index 7f7bd1626f..ed4390655d 100644
--- a/nntrainer/layers/loss/mse_loss_layer.cpp
+++ b/nntrainer/layers/loss/mse_loss_layer.cpp
@@ -20,7 +20,16 @@
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 void MSELossLayer::forwarding(RunLayerContext &context, bool training) {
   Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
-  Tensor &y = context.getInput(SINGLE_INOUT_IDX);
+
+  Tensor empty_tensor;
+  Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() ==
+                  ml::train::TensorDim::DataType::FP32
+                ? context.getInput(SINGLE_INOUT_IDX)
+                : empty_tensor;
+
+  if (y.empty())
+    y = context.getInput(SINGLE_INOUT_IDX)
+          .clone(ml::train::TensorDim::DataType::FP32);
 
   // hidden_ <- y2 - y;
   if (context.isLabelAvailable(SINGLE_INOUT_IDX)) {
@@ -41,9 +50,28 @@
 }
 
 void MSELossLayer::calcDerivative(RunLayerContext &context) {
-  Tensor &ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
+  Tensor empty_tensor;
+
+  Tensor &ret_derivative =
+    context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() ==
+        ml::train::TensorDim::DataType::FP32
+      ? context.getOutgoingDerivative(SINGLE_INOUT_IDX)
+      : empty_tensor;
+
+  if (ret_derivative.empty())
+    ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX)
+                       .clone(ml::train::TensorDim::DataType::FP32);
+
+  Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() ==
+                  ml::train::TensorDim::DataType::FP32
+                ? context.getInput(SINGLE_INOUT_IDX)
+                : empty_tensor;
+
+  if (y.empty())
+    y = context.getInput(SINGLE_INOUT_IDX)
+          .clone(ml::train::TensorDim::DataType::FP32);
+
   const Tensor &y2 = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  Tensor &y = context.getInput(SINGLE_INOUT_IDX);
   y.subtract(y2, ret_derivative);
   float divider = ((float)y.size()) / 2;
@@ -51,6 +79,16 @@
     throw std::runtime_error(
       "[MSELossLayer::calcDerivative] Error when calculating loss");
   }
+
+  // Loss scale needs full precision of ret_derivative. Therefore,
+  // ret_derivative should be FP32 while the scale is applied, and afterwards
+  // it needs to be converted back to the original type for backpropagation.
+ + LossLayer::applyLossScale(context, ret_derivative); + + if (context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() != + ml::train::TensorDim::DataType::FP32) + context.getOutgoingDerivative(SINGLE_INOUT_IDX).copyData(ret_derivative); } } // namespace nntrainer diff --git a/nntrainer/layers/time_dist.cpp b/nntrainer/layers/time_dist.cpp index 80451416df..779010065a 100644 --- a/nntrainer/layers/time_dist.cpp +++ b/nntrainer/layers/time_dist.cpp @@ -256,8 +256,8 @@ void TimeDistLayer::forwarding(RunLayerContext &context, bool training) { RunLayerContext dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->forwarding(dist_context, training); } @@ -303,8 +303,8 @@ void TimeDistLayer::calcDerivative(RunLayerContext &context) { RunLayerContext dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->calcDerivative(dist_context); } @@ -354,8 +354,8 @@ void TimeDistLayer::calcGradient(RunLayerContext &context) { RunLayerContext dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->calcGradient(dist_context); } @@ -396,8 +396,8 @@ void TimeDistLayer::setBatch(RunLayerContext &context, unsigned int batch) { RunLayerContext dist_context(context.getName(), context.getTrainable(), context.getLoss(), context.executeInPlace(), - getWeightsForContext(), {&in_var}, {&out_var}, - getTensorsForContext()); + context.getLossScale(), getWeightsForContext(), + {&in_var}, {&out_var}, getTensorsForContext()); dist_layer->setBatch(dist_context, batch); diff --git a/nntrainer/models/model_common_properties.h b/nntrainer/models/model_common_properties.h index 3776afefca..3435d18e96 100644 --- a/nntrainer/models/model_common_properties.h +++ b/nntrainer/models/model_common_properties.h @@ -217,7 +217,7 @@ class ModelTensorDataType final : public EnumProperty { */ class LossScale : public Property { public: - LossScale(float value = 0.0f); + LossScale(float value = 1.0f); static constexpr const char *key = "loss_scale"; /**< unique key to access */ using prop_tag = float_prop_tag; /**< property type */ }; diff --git a/nntrainer/models/neuralnet.cpp b/nntrainer/models/neuralnet.cpp index d0e542825f..afc560603e 100644 --- a/nntrainer/models/neuralnet.cpp +++ b/nntrainer/models/neuralnet.cpp @@ -412,9 +412,21 @@ void NeuralNetwork::backwarding(int iteration, NNTR_THROW_IF(!opt, std::invalid_argument) << "optimizer is null!"; #endif - std::function, int)> backwarding_op = + std::function, bool)> forwarding_op = + [this, stop_cb, userdata](std::shared_ptr node, + bool training) -> void { + (void)this; + PROFILE_MEM_ANNOTATE("Forwarding for layer: " + node->getName()); + + auto f = std::get<0>(node->getExecutionOrder()); + model_graph.flushCacheExcept(f); + + node->forwarding(training); + }; + + std::function, int)> backwarding_op = [this, stop_cb, userdata](std::shared_ptr node, - int iteration) -> void { + int iteration) 
-> bool { /** * Do not change this order: * 1. calcGradient @@ -448,19 +460,29 @@ void NeuralNetwork::backwarding(int iteration, /** If gradient must be applied and its not gradient mode, calculate * gradient */ - if (!dynamic_training_opt.isGradientMode() && apply_gradient) + if (!dynamic_training_opt.isGradientMode() && apply_gradient) { node->calcGradient(); + + RunLayerContext &rc = node->getRunContext(); + if (rc.isMixedPrecision()) { + for (auto w : rc.getWeights()) { + if (!w->getGradientRef().isValid()) + return false; + } + } + } } model_graph.flushCacheExcept(std::get<2>(node->getExecutionOrder())); PROFILE_MEM_ANNOTATE("CalcDerivative: " + node->getName()); if (stop_cb(userdata)) { - return; + return true; } - if (node->needsCalcDerivative()) + if (node->needsCalcDerivative()) { node->calcDerivative(); + } model_graph.flushCacheExcept(std::get<3>(node->getExecutionOrder())); PROFILE_MEM_ANNOTATE("ApplyGradient: " + node->getName()); @@ -476,9 +498,10 @@ void NeuralNetwork::backwarding(int iteration, opt_->applyGradient(opt_context); }); } + return true; }; - std::function apply_grad_clip_op = + std::function lazy_apply_grad_op = [opt_ = opt.get()](Weight &w, int iteration) -> void { w.calcRegularizationGradient(); w.calcWeightDecayGradient(); @@ -487,8 +510,13 @@ void NeuralNetwork::backwarding(int iteration, opt_->applyGradient(opt_context); }; - model_graph.backwarding(iteration, backwarding_op, apply_grad_clip_op, - stop_cb, userdata); + // return false if the gradient is not valid + bool ret = false; + + while (!ret) { + ret = model_graph.backwarding(iteration, forwarding_op, backwarding_op, + lazy_apply_grad_op, stop_cb, userdata); + } } void NeuralNetwork::save(const std::string &file_path, diff --git a/nntrainer/optimizers/adam.cpp b/nntrainer/optimizers/adam.cpp index 18c0a0fcc1..530e7fdf31 100644 --- a/nntrainer/optimizers/adam.cpp +++ b/nntrainer/optimizers/adam.cpp @@ -36,7 +36,15 @@ Adam::~Adam() {} enum AdamParams { wm, wv }; std::vector Adam::getOptimizerVariableDim(const TensorDim &dim) { - return {dim, dim}; + /** + * @note We assume the optimizer parameters should be full precsion to + * maintain the accuracy even in mixed precision training. + */ + TensorDim wm_dim(dim); + TensorDim wv_dim(dim); + wm_dim.setDataType(ml::train::TensorDim::DataType::FP32); + wv_dim.setDataType(ml::train::TensorDim::DataType::FP32); + return {wm_dim, wv_dim}; } void Adam::exportTo(Exporter &exporter, @@ -64,7 +72,15 @@ double Adam::getUpdatedLearningRate(unsigned int iteration, double ll) const { } void Adam::applyGradient(RunOptimizerContext &context) { - Tensor &x_grad = context.getGradient(); + Tensor empty_tensor; + + Tensor &x_grad = + context.getGradient().getDataType() == ml::train::TensorDim::DataType::FP32 + ? 
context.getGradient() + : empty_tensor; + + if (x_grad.empty()) + x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32); auto &beta1 = std::get(adam_props).get(); auto &beta2 = std::get(adam_props).get(); @@ -91,7 +107,7 @@ void Adam::applyGradient(RunOptimizerContext &context) { denom.add_i(epsilon); wm.divide(denom, x_grad); - context.applyGradient(context.getLearningRate() / biasCorrection1); + context.applyGradient(context.getLearningRate() / biasCorrection1, x_grad); } else { std::function sqrtEps = [epsilon](double f) { @@ -100,8 +116,9 @@ void Adam::applyGradient(RunOptimizerContext &context) { x_grad = wv.apply(sqrtEps, x_grad); x_grad.multiply_i(wm); - context.applyGradient(getUpdatedLearningRate(context.getIteration(), - context.getLearningRate())); + context.applyGradient( + getUpdatedLearningRate(context.getIteration(), context.getLearningRate()), + x_grad); } } diff --git a/nntrainer/optimizers/optimizer_context.cpp b/nntrainer/optimizers/optimizer_context.cpp index da4cd1f7e9..f70ab773a9 100644 --- a/nntrainer/optimizers/optimizer_context.cpp +++ b/nntrainer/optimizers/optimizer_context.cpp @@ -42,4 +42,11 @@ Tensor &RunOptimizerContext::getOptimizerVariable(unsigned int idx) const { void RunOptimizerContext::applyGradient(double lr) const { weight->applyGradient(lr); } + +/** + * @brief Apply the gradient with the given learning rate and gradient + */ +void RunOptimizerContext::applyGradient(double lr, Tensor &updated_grad) const { + weight->applyGradient(lr, updated_grad); +} } // namespace nntrainer diff --git a/nntrainer/optimizers/optimizer_context.h b/nntrainer/optimizers/optimizer_context.h index 62f9e0945d..6b4b983e35 100644 --- a/nntrainer/optimizers/optimizer_context.h +++ b/nntrainer/optimizers/optimizer_context.h @@ -35,9 +35,7 @@ class RunOptimizerContext { * */ RunOptimizerContext(Weight *w = nullptr, size_t iter = 0, double lr = 0.0) : - weight(w), - iteration(iter), - learning_rate(lr) {} + weight(w), iteration(iter), learning_rate(lr) {} /** * @brief Get the Weight tensor object @@ -75,6 +73,16 @@ class RunOptimizerContext { */ void applyGradient(double lr) const; + /** + * @brief Apply the gradient with the given learning rate and updated + * gradient + * + * @param lr learning rate + * @param updated_grad gradient tensor which is updated. 
(usually it could be + * fp32) + */ + void applyGradient(double lr, Tensor &updated_grad) const; + /** * @brief Get the current iteration value * diff --git a/nntrainer/tensor/blas_avx.cpp b/nntrainer/tensor/blas_avx.cpp index ce59583d6f..411dbcbb5d 100644 --- a/nntrainer/tensor/blas_avx.cpp +++ b/nntrainer/tensor/blas_avx.cpp @@ -20,6 +20,7 @@ namespace nntrainer::avx { +#ifdef ENABLE_FP16 void vcvt_f16_f32(size_t N, const void *input, float *output) { assert(N != 0); assert(input != NULL); @@ -114,4 +115,163 @@ void vcvt_f32_f16(size_t N, const float *input, void *output) { } } +bool isValid(const size_t N, const _Float16 *input) { + assert(N != 0); + assert(input != NULL); + + int temp = 0; + size_t idx = 0; + + const __m256 SIGN_MASK = _mm256_set1_ps(-0.0); + const __m256 INF = _mm256_set1_ps(std::numeric_limits::infinity()); + + // 16 single-precision check : ( X != X ) + for (; N - idx >= 16; idx += 16) { + __m256 vec0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)input)); + __m256 vec1 = + _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(input + 8))); + + input += 16; + + // check NaN in vec0 + __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + if (temp) + return false; + + // check infinity in vec0 + vec0 = _mm256_andnot_ps(SIGN_MASK, vec0); + vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec0); + if (temp) + return false; + + // check NaN in vec1 + __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res1); + + if (temp) + return false; + + // check infinity in vec1 + vec1 = _mm256_andnot_ps(SIGN_MASK, vec1); + vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec1); + + if (temp) + return false; + } + + // 8 single-precision check : ( X != X ) + for (; N - idx >= 8; idx += 8) { + __m256 vec = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)input)); + input += 8; + __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + + if (temp) + return false; + + // check infinity in vec1 + vec = _mm256_andnot_ps(SIGN_MASK, vec); + vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec); + + if (temp) + return false; + } + + // remain check : ( X != X || X == Inf ) + while (idx < N) { + if (*input != *input || *input == std::numeric_limits::infinity()) { + return false; + } + ++input; + ++idx; + } + + return true; +} +#endif + +bool isValid(const size_t N, const float *input) { + assert(N != 0); + assert(input != NULL); + + int temp = 0; + size_t idx = 0; + + const __m256 SIGN_MASK = _mm256_set1_ps(-0.0); + const __m256 INF = _mm256_set1_ps(std::numeric_limits::infinity()); + + // 16 single-precision check : ( X != X ) + for (; N - idx >= 16; idx += 16) { + __m256 vec0 = _mm256_loadu_ps(input); + __m256 vec1 = _mm256_loadu_ps(input + 8); + input += 16; + __m256 res = _mm256_cmp_ps(vec0, vec0, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + + if (temp) + return false; + + // check infinity in vec0 + vec0 = _mm256_andnot_ps(SIGN_MASK, vec0); + vec0 = _mm256_cmp_ps(vec0, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec0); + if (temp) + return false; + + __m256 res1 = _mm256_cmp_ps(vec1, vec1, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res1); + + if (temp) + return false; + + // check infinity in vec1 + vec1 = _mm256_andnot_ps(SIGN_MASK, vec1); + vec1 = _mm256_cmp_ps(vec1, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec1); + + if (temp) + 
return false; + } + + // 8 single-precision check : ( X != X ) + for (; N - idx >= 8; idx += 8) { + __m256 vec = _mm256_loadu_ps(input); + input += 8; + __m256 res = _mm256_cmp_ps(vec, vec, _CMP_NEQ_UQ); + temp = temp | _mm256_movemask_ps(res); + + if (temp) + return false; + + // check infinity in vec + vec = _mm256_andnot_ps(SIGN_MASK, vec); + vec = _mm256_cmp_ps(vec, INF, _CMP_EQ_OQ); + + temp = temp | _mm256_movemask_ps(vec); + + if (temp) + return false; + } + + // remain check : ( X != X ) + while (idx < N) { + if (*input != *input || *input == std::numeric_limits::infinity()) { + return false; + } + ++input; + ++idx; + } + + return true; +} + } // namespace nntrainer::avx diff --git a/nntrainer/tensor/blas_avx.h b/nntrainer/tensor/blas_avx.h index ab1270a208..5eabcbdb2c 100644 --- a/nntrainer/tensor/blas_avx.h +++ b/nntrainer/tensor/blas_avx.h @@ -20,6 +20,7 @@ namespace nntrainer::avx { +#ifdef ENABLE_FP16 /** * @brief Converts half-precision floating point values to single-precision * floating point values. @@ -40,6 +41,25 @@ void vcvt_f16_f32(size_t N, const void *input, float *output); */ void vcvt_f32_f16(size_t N, const float *input, void *output); +/** + * @brief check if the X has NaN value + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] X half-precision * for Vector X + * @param[out] false if it has NaN or inf + */ +bool isValid(const size_t N, const _Float16 *X); +#endif + +/** + * @brief check if the X has NaN value + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] X float * for Vector X + * @param[out] false if it has NaN or inf + */ +bool isValid(const size_t N, const float *X); + } // namespace nntrainer::avx #endif /* __cplusplus */ diff --git a/nntrainer/tensor/blas_interface.cpp b/nntrainer/tensor/blas_interface.cpp index 9be6fb9911..6219919fd8 100644 --- a/nntrainer/tensor/blas_interface.cpp +++ b/nntrainer/tensor/blas_interface.cpp @@ -1038,6 +1038,16 @@ static void ele_div_fallback(const unsigned int N, const float *X, } } +static bool is_valid_fallback(const size_t N, const float *X) { + for (size_t i = 0; i < N; ++i) { + if (*X != *X || *X == std::numeric_limits::infinity()) + return false; + ++X; + } + + return true; +} + void ele_mul(const unsigned int N, const float *X, const float *Y, float *Z, float alpha, float beta, unsigned int i_stride, unsigned int o_stride) { @@ -1090,4 +1100,30 @@ void ele_div(const unsigned int N, const float *X, const float *Y, float *Z, ele_div_fallback(N, X, Y, Z, alpha, beta, i_stride, o_stride); } +bool is_valid(const size_t N, ml::train::TensorDim::DataType d_type, + const void *X) { + if (d_type == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + const _FP16 *vec = (const _FP16 *)X; +#ifdef USE_NEON + return nntrainer::neon::isValid(N, vec); +#elif defined(USE_AVX) + return nntrainer::avx::isValid(N, vec); +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif +#endif + } else if (d_type == ml::train::TensorDim::DataType::FP32) { + const float *vec = (const float *)X; +#ifdef USE_NEON + return nntrainer::neon::isValid(N, vec); +#elif defined(USE_AVX) + return nntrainer::avx::isValid(N, vec); +#endif + + return is_valid_fallback(N, vec); + } + return false; +} + } // namespace nntrainer diff --git a/nntrainer/tensor/blas_interface.h b/nntrainer/tensor/blas_interface.h index 04a8a23018..2b5ef72922 100644 --- a/nntrainer/tensor/blas_interface.h +++ b/nntrainer/tensor/blas_interface.h @@ -478,6 +478,16 
@@ void ele_sub(const unsigned N, const float *X, const float *Y, float *Z, void ele_div(const unsigned N, const float *X, const float *Y, float *Z, float alpha = 1.f, float beta = 0.f, unsigned int i_stride = 1, unsigned int o_stride = 1); + +/** + * @brief check if X array has NaN or inf + * @param[in] N length of the vector + * @param[in] X float/fp16 * for Vector X + * @param[out] bool false if not valide else true + */ +bool is_valid(const size_t N, ml::train::TensorDim::DataType d_type, + const void *X); + } /* namespace nntrainer */ #endif /* __cplusplus */ #endif /* __BLAS_INTERFACE_H__ */ diff --git a/nntrainer/tensor/blas_neon.cpp b/nntrainer/tensor/blas_neon.cpp index 3609b6b8b5..9f6a3652c4 100644 --- a/nntrainer/tensor/blas_neon.cpp +++ b/nntrainer/tensor/blas_neon.cpp @@ -546,6 +546,36 @@ void ele_div(const unsigned N, const float *X, const float *Y, float *Z, } } +bool isValid(const size_t N, const float *X) { + size_t i = 0; + float inf_s = std::numeric_limits::infinity(); + float32x4_t inf = vdupq_n_f32(inf_s); + uint16x8_t zero = vdupq_n_f32(0); + + for (; N - i >= 4; i += 4) { + float32x4_t vec = vld1q_f32(&X[i]); + uint32x4_t vcmp = vceqq_f32(vec, vec); + + vcmp = vceqq_f32(vcmp, zero); + + if (vaddvq_u32(vcmp)) + return false; + + vcmp = vceqq_f32(vec, inf); + + if (vaddvq_u16(vcmp)) + return false; + } + + while (i < N) { + if (X[i] != X[i] || X[i] == std::numeric_limits::infinity()) + return false; + ++i; + } + + return true; +} + #ifdef ENABLE_FP16 void hgemv(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t M, uint32_t N, @@ -1994,5 +2024,40 @@ void inv_sqrt_inplace(const unsigned int N, __fp16 *X) { } } +bool isValid(const size_t N, const __fp16 *input) { + bool temp = 0; + size_t i = 0; + __fp16 inf_s = std::numeric_limits::infinity(); + float16x8_t inf = vdupq_n_f16(inf_s); + uint16x8_t zero = vdupq_n_f16(0); + + for (; N - i >= 8; i += 8) { + float16x8_t vec = vld1q_f16(&input[i]); + + uint16x8_t vcmp = vceqq_f16(vec, vec); + + vcmp = vceqq_f16(vcmp, zero); + + if (vaddvq_u16(vcmp)) { + return false; + } + + vcmp = vceqq_f16(vec, inf); + + if (vaddvq_u16(vcmp)) { + return false; + } + } + + while (i < N) { + if (input[i] != input[i] || + input[i] == std::numeric_limits::infinity()) { + return false; + } + ++i; + } + return true; +} + #endif } // namespace nntrainer::neon diff --git a/nntrainer/tensor/blas_neon.h b/nntrainer/tensor/blas_neon.h index db1b6a5ccc..978d3428f7 100644 --- a/nntrainer/tensor/blas_neon.h +++ b/nntrainer/tensor/blas_neon.h @@ -148,6 +148,15 @@ void ele_sub(const unsigned N, const float *X, const float *Y, float *Z, void ele_div(const unsigned N, const float *X, const float *Y, float *Z, float alpha = 1.f, float beta = 0.f); +/** + * @brief check if the X has NaN value or Inf + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] input float * for Vector X + * @param[out] false if it has NaN or Inf + */ +bool isValid(const size_t N, const float *input); + #ifdef ENABLE_FP16 /** * @brief hgemv computation with neon : Y = alpha*A*X + beta*Y @@ -380,6 +389,15 @@ void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, uint32_t M, * @param X __fp16 * for Vector X */ void inv_sqrt_inplace(const unsigned int N, __fp16 *X); + +/** + * @brief check if the X is valid: Check NaN or Inf + * @note it compare (x!=x || x == inf) + * @param[in] N length of the vector + * @param[in] X float * for Vector X + * @param[out] false if it has NaN or Inf + */ +bool isValid(const size_t N, const __fp16 *X); #endif } 
// namespace nntrainer::neon diff --git a/nntrainer/tensor/manager.cpp b/nntrainer/tensor/manager.cpp index 9a0d235ba9..14d710b3c0 100644 --- a/nntrainer/tensor/manager.cpp +++ b/nntrainer/tensor/manager.cpp @@ -414,7 +414,7 @@ std::vector Manager::requestWeights( // var_exec_order.push_back(TensorPool::PERSIST_END_ORDER); } - Tensor *var = nullptr, *grad = nullptr; + Tensor *var = nullptr, *grad = nullptr, *var32 = nullptr; bool is_dependent = !shared_names.empty(); if (is_dependent) { /// shared_name is used and the orignal name is discarded @@ -431,6 +431,17 @@ std::vector Manager::requestWeights( grad = tensor_pool.requestOrExtend(shared_name + Var_Grad::grad_suffix, dim_g, grad_exec_order, grad_ls, Tensor::Initializer::ZEROS); + + if (var->getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim_v); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + std::vector var32_exec_order; + var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER); + + var32 = weight_pool.requestOrExtend(shared_name + ":var32", var32_dim, + var32_exec_order, var_ls, + Tensor::Initializer::ZEROS); + } } } else { /** case requesting fresh weights */ @@ -448,11 +459,21 @@ std::vector Manager::requestWeights( grad = tensor_pool.request(name + Var_Grad::grad_suffix, dim_g, grad_exec_order, grad_ls, Tensor::Initializer::ZEROS, is_wgrad); + if (var->getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim_v); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + std::vector var32_exec_order; + var32_exec_order.push_back(TensorPool::PERSIST_END_ORDER); + var32 = + weight_pool.request(name + ":var32", var32_dim, var32_exec_order, + var_ls, Tensor::Initializer::ZEROS); + } } } weights_v2.emplace_back(std::make_unique( - var, grad, w_reg, w_reg_const, decay, is_dependent, clip_by_global_norm)); + var, grad, var32, w_reg, w_reg_const, decay, is_dependent, + clip_by_global_norm, axis, loss_scale)); } std::transform(weights_v2.begin() + current_size, weights_v2.end(), @@ -668,15 +689,15 @@ bool Manager::isSecondLastAccess(const std::string &name, */ std::vector Manager::requestWeightOptimizerVariables( const std::vector &dims, const std::string &name, - const TensorLifespan &lifespan, bool is_grad_clip, - Tensor::Initializer initializer) { + const std::string &suffix, const TensorLifespan &lifespan, bool is_grad_clip, + bool is_mixed_precision, Tensor::Initializer initializer) { std::vector ret; ret.reserve(dims.size()); std::vector exec; exec.reserve(1); - if (is_grad_clip) { + if (is_grad_clip || is_mixed_precision) { exec.emplace_back(TensorPool::PERSIST_END_ORDER); } else { exec.emplace_back(getMinMaxTensorExecutionOrder(name, true).second); @@ -685,7 +706,7 @@ std::vector Manager::requestWeightOptimizerVariables( /// @note this is assuming weight optimizer variables is treated as weight, if /// not, there is room to optimize below behavior for (unsigned int idx = 0; idx < dims.size(); idx++) - ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx), + ret.push_back(weight_pool.request(name + suffix + std::to_string(idx), dims[idx], exec, lifespan, initializer)); return ret; diff --git a/nntrainer/tensor/manager.h b/nntrainer/tensor/manager.h index ab1c018153..80ffb9d21d 100644 --- a/nntrainer/tensor/manager.h +++ b/nntrainer/tensor/manager.h @@ -224,7 +224,8 @@ class Manager { */ std::vector requestWeightOptimizerVariables( const std::vector &dims, const std::string &name, - const TensorLifespan &lifespan, bool is_grad_clip, + 
const std::string &suffix, const TensorLifespan &lifespan,
+    bool is_grad_clip, bool is_mixed_type,
     Tensor::Initializer initializer = Tensor::Initializer::NONE);
 
   /**
diff --git a/nntrainer/tensor/meson.build b/nntrainer/tensor/meson.build
index 0884dbd3b4..b14fa0ee85 100644
--- a/nntrainer/tensor/meson.build
+++ b/nntrainer/tensor/meson.build
@@ -44,6 +44,12 @@ cl_headers = [
 
 arch = host_machine.cpu_family()
 
+
+if get_option('enable-avx')
+  tensor_sources += 'blas_avx.cpp'
+  tensor_headers += 'blas_avx.h'
+endif
+
 if get_option('enable-fp16')
   if arch == 'arm'
     error ('FP16/ARM code (blas_neon.cpp) uses armv8.2 instructions. armv7 is not supported.')
@@ -55,9 +61,6 @@ if get_option('enable-fp16')
       nntrainer_inc += include_directories('hgemm')
       nntrainer_inc_abs += meson.current_source_dir() / 'hgemm'
     endif
-  elif get_option('enable-avx')
-    tensor_sources += 'blas_avx.cpp'
-    tensor_headers += 'blas_avx.h'
   endif
 endif
diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index 4f1e8e0721..827ba7e979 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -3065,6 +3065,18 @@ Tensor Tensor::clone() const {
   return t;
 }
 
+Tensor Tensor::clone(ml::train::TensorDim::DataType type) const {
+  if (getDataType() == type)
+    return clone();
+
+  TensorDim dim = getDim();
+  dim.setDataType(type);
+  Tensor t(dim, true);
+  t.copyData(*this);
+  t.name = name;
+  return t;
+}
+
 void Tensor::reshape(const TensorDim &d) {
 
   NNTR_THROW_IF(!contiguous, std::invalid_argument)
@@ -3808,6 +3820,18 @@ void Tensor::dequantize(Tensor &output, unsigned int axis) const {
   return;
 }
 
+bool Tensor::isValid() const {
+  if (getDataType() == Tdatatype::FP16) {
+#ifdef ENABLE_FP16
+    return is_valid(dim.getDataLen(), Tdatatype::FP16, getData<_FP16>());
+#else
+    throw std::invalid_argument("enable-fp16 is not set");
+#endif
+  } else {
+    return is_valid(dim.getDataLen(), Tdatatype::FP32, getData());
+  }
+}
+
 // namespace nntrainer
 
 } /* namespace nntrainer */
diff --git a/nntrainer/tensor/tensor.h b/nntrainer/tensor/tensor.h
index 211334da40..ad3781526f 100644
--- a/nntrainer/tensor/tensor.h
+++ b/nntrainer/tensor/tensor.h
@@ -1680,6 +1680,13 @@ class Tensor {
    */
   Tensor clone() const;
 
+  /**
+   * @brief Convenient wrapper for inplace copy of @a this.
+   * @param[in] type output tensor data type
+   * @retval Copied version of this
+   */
+  Tensor clone(ml::train::TensorDim::DataType type) const;
+
   /**
    * @brief Save the Tensor into file
   * @param[in] file output file stream
@@ -2031,6 +2038,12 @@ class Tensor {
 
   static constexpr float epsilon = 1e-5;
 
+  /**
+   * @brief check if there is NaN or Inf element
+   * @param[out] bool false if there is NaN or Inf, else true
+   */
+  bool isValid() const;
+
 private:
   /**< handle the data as a std::shared_ptr type */
   TensorDim dim;
diff --git a/nntrainer/tensor/weight.cpp b/nntrainer/tensor/weight.cpp
index f98c8c8356..0e9879540a 100644
--- a/nntrainer/tensor/weight.cpp
+++ b/nntrainer/tensor/weight.cpp
@@ -34,6 +34,28 @@ Weight::Weight(const TensorDim &dim, const Tensor::Initializer init,
     throw std::invalid_argument("Weight initializer cannot be none");
   if (regularizer == WeightRegularizer::UNKNOWN)
     throw std::invalid_argument("Weight regularizer unknown");
+
+  std::string var32_suffix = ":fp32";
+  std::string var32_name = name + var32_suffix;
+
+  /**
+   * @note We assume if the Weight Data Type is not FP32, then FP32 Weight is
+   * necessary to maintain the accuracy.
+ * We could think it can be other data type and if there is the case to + * support other data type, then the code below needs to be udpated. + * + * Also, the loss_scale is not used in Weight but leave as it is for later + * usage. + */ + + if (train && dim.getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + + var32 = std::make_shared(var32_dim, alloc_now_, init, var32_name); + } else { + var32 = std::make_shared(var32_name); + } } Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, @@ -52,6 +74,94 @@ Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g, throw std::invalid_argument("Weight initializer cannot be none"); if (regularizer == WeightRegularizer::UNKNOWN) throw std::invalid_argument("Weight regularizer unknown"); + + std::string var32_suffix = ":fp32"; + std::string var32_name = name + var32_suffix; + + if (train && dim_v.getDataType() != ml::train::TensorDim::DataType::FP32) { + TensorDim var32_dim(dim_v); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + std::string var32_suffix = ":fp32"; + std::string var32_name = name + var32_suffix; + + var32 = std::make_shared(var32_dim, alloc_now_, init, var32_name); + } else { + var32 = std::make_shared(var32_name); + } +} + +Weight::Weight(const Tensor &v, const Tensor &g, const Tensor &v32, + const std::string &n, bool is_dependent, + unsigned int output_axis_) : + Var_Grad(v, g, n, is_dependent), + regularizer(WeightRegularizer::NONE), + regularizer_constant(1.0f), + decay(0.0f), + clip_by_global_norm(0.0f), + output_axis(output_axis_), + loss_scale(1.0), + var32(std::make_shared(n + ":fp32")) { + + if (!g.empty() && isMixedPrecision()) { + TensorDim var32_dim(v.getDim()); + var32_dim.setDataType(ml::train::TensorDim::DataType::FP32); + if (!v32.empty()) + var32 = std::make_shared( + v32.getSharedDataTensor(var32_dim, 0, false, n + ":fp32")); + } +} + +Weight::Weight(Tensor *v, Tensor *g, Tensor *v32, const WeightRegularizer reg, + const float reg_const, const float decay, bool is_dependent, + const float max_norm, unsigned int output_axis_, + float loss_scale_) : + Var_Grad(v, g, is_dependent), + regularizer(reg), + regularizer_constant(reg_const), + decay(decay), + clip_by_global_norm(max_norm), + output_axis(output_axis_), + loss_scale(loss_scale_), + var32(std::shared_ptr(v32, [](void *) {})) { + if (!v32) + var32 = std::make_shared(); +} + +void Weight::applyGradient(double lr, Tensor &updated_grad) { + if (isMixedPrecision() && + updated_grad.getDataType() == ml::train::TensorDim::DataType::FP32) { + updated_grad.divide(loss_scale); + var32->add_i(updated_grad, -lr); + quantizeWeight(); + return; + } + + return applyGradient(lr); +} + +void Weight::quantizeWeight() { + if (!isMixedPrecision()) + return; + + Tensor &var = getVariableRef(); + ml::train::TensorDim::DataType type = var.getDataType(); + switch (type) { + case ml::train::TensorDim::DataType::QINT4: + // NYI + break; + case ml::train::TensorDim::DataType::QINT8: + // NYI + break; + case ml::train::TensorDim::DataType::FP16: + getVariableRef().copyData(getVariableFP32Ref()); + break; + case ml::train::TensorDim::DataType::FP32: + break; + default: + break; + } + + return; } } // namespace nntrainer diff --git a/nntrainer/tensor/weight.h b/nntrainer/tensor/weight.h index 552f6d5739..8ac3aa0190 100644 --- a/nntrainer/tensor/weight.h +++ b/nntrainer/tensor/weight.h @@ -46,7 +46,7 @@ class Weight : public Var_Grad { decay(0.0f), 
clip_by_global_norm(0.0f), output_axis(3), - loss_scale(0.0) {} + loss_scale(1.0) {} /** * @brief Construct a new Weight object @@ -66,7 +66,7 @@ class Weight : public Var_Grad { const float reg_const = 1.0f, const float decay = 0.0f, const float clip_by_global_norm = 0.0f, bool ng = true, bool alloc_now = false, std::string name = "", unsigned int axis = 3, - float loss_scale_ = 0.0); + float loss_scale_ = 1.0); /** * @brief Construct a new Weight object @@ -87,7 +87,7 @@ class Weight : public Var_Grad { const float reg_const = 1.0f, const float decay = 0.0f, const float clip_by_global_norm = 0.0f, bool ng = true, bool alloc_now = false, std::string name = "", unsigned int axis = 3, - float loss_scale_ = 0.0); + float loss_scale_ = 1.0); /** * @brief Construct a new Weight object @@ -114,6 +114,7 @@ class Weight : public Var_Grad { * * @param v Already created variable object * @param g Already created gradient object + * @param v32 Already created gradient object * @param n Name for this Weight * * @note This is primarily used to created wrapper of variable extracted from @@ -123,35 +124,24 @@ class Weight : public Var_Grad { * uses only, as Weight does not own the tensors v and g, and can go invalid * if the owner of these tensors free the tensors. */ - explicit Weight(const Tensor &v, const Tensor &g, const std::string &n = "", - bool is_dependent = false, unsigned int output_axis_ = 3) : - Var_Grad(v, g, n, is_dependent), - regularizer(WeightRegularizer::NONE), - regularizer_constant(1.0f), - decay(0.0f), - clip_by_global_norm(0.0f), - output_axis(output_axis_), - loss_scale(0.0) {} + explicit Weight(const Tensor &v, const Tensor &g, const Tensor &v32, + const std::string &n = "", bool is_dependent = false, + unsigned int output_axis_ = 3); /** * @brief Construct a new Weight object * * @param v ptr to already created variable tensor * @param g ptr to already created gradient tensor + * @param v32 ptr to already created variable32 tensor * @param reg Regularizer for the weight * @param reg_const Constant multiplier for regularizer */ - explicit Weight(Tensor *v, Tensor *g, const WeightRegularizer reg, - const float reg_const, const float decay, - bool is_dependent = false, const float max_norm = 0.0f, - unsigned int output_axis_ = 3, float loss_scale_ = 0.0f) : - Var_Grad(v, g, is_dependent), - regularizer(reg), - regularizer_constant(reg_const), - decay(decay), - clip_by_global_norm(max_norm), - output_axis(output_axis_), - loss_scale(loss_scale_) {} + explicit Weight(Tensor *v, Tensor *g, Tensor *v32, + const WeightRegularizer reg, const float reg_const, + const float decay, bool is_dependent = false, + const float max_norm = 0.0f, unsigned int output_axis_ = 3, + float loss_scale_ = 1.0f); /** * @brief Swap for weight @@ -170,6 +160,7 @@ class Weight : public Var_Grad { swap(lhs.output_axis, rhs.output_axis); swap(lhs.opt_vars, rhs.opt_vars); swap(lhs.loss_scale, rhs.loss_scale); + swap(lhs.var32, rhs.var32); } /** @@ -213,6 +204,8 @@ class Weight : public Var_Grad { w.var = std::make_shared(this->var->clone()); if (!this->grad->empty()) w.grad = std::make_shared(this->grad->clone()); + if (!this->var32->empty()) + w.var32 = std::make_shared(this->var32->clone()); return w; } @@ -294,6 +287,13 @@ class Weight : public Var_Grad { */ void applyGradient(double lr) { var->add_i(*grad.get(), -lr); } + /** + * @brief Apply the gradient to the weight with updated gradient + * @param[in] updated_grad gradient tensor which is updated in optimizer + * it might be different data type with 
diff --git a/packaging/nntrainer.spec b/packaging/nntrainer.spec
index 36ba371d22..2f1dc57f68 100644
--- a/packaging/nntrainer.spec
+++ b/packaging/nntrainer.spec
@@ -65,6 +65,13 @@
 %define neon_support -Denable-neon=false
 %endif # arch aarch64
 
+%ifarch x86_64
+%define enable_avx 1
+%define avx_support -Denable-avx=true
+%else
+%define avx_support -Denable-avx=false
+%endif # arch x86_64
+
 Name: nntrainer
 Summary: Software framework for training neural networks
 
@@ -410,7 +417,7 @@ meson --buildtype=plain --prefix=%{_prefix} --sysconfdir=%{_sysconfdir} \
       %{enable_reduce_tolerance} %{configure_subplugin_install_path} %{enable_debug} \
       -Dml-api-support=enabled -Denable-nnstreamer-tensor-filter=enabled \
       -Denable-nnstreamer-tensor-trainer=enabled -Denable-capi=enabled \
-      %{fp16_support} %{neon_support} build
+      %{fp16_support} %{neon_support} %{avx_support} build
 
 ninja -C build %{?_smp_mflags}
 
@@ -563,6 +570,10 @@ cp -r result %{buildroot}%{_datadir}/nntrainer/unittest/
 %{_includedir}/nntrainer/util_simd_neon.h
 %endif
 
+%if 0%{?enable_avx}
+%{_includedir}/nntrainer/blas_avx.h
+%endif
+
 %files devel-static
 %{_libdir}/libnntrainer*.a
 %exclude %{_libdir}/libcapi*.a
diff --git a/test/unittest/layers/layers_golden_tests.cpp b/test/unittest/layers/layers_golden_tests.cpp
index 64400e6ecd..c71d653c05 100644
--- a/test/unittest/layers/layers_golden_tests.cpp
+++ b/test/unittest/layers/layers_golden_tests.cpp
@@ -156,7 +156,7 @@ static RunLayerContext prepareRunContext(const TensorPacks &packs) {
   };
 
   auto rc =
-    RunLayerContext("golden", true, 0.0f, false, create_view(weights),
+    RunLayerContext("golden", true, 0.0f, false, 1.0, create_view(weights),
                     create_view(ins), create_view(outs), create_view(tensors));
 
   auto num_outputs = rc.getNumOutputs();
diff --git a/test/unittest/layers/unittest_layer_node.cpp b/test/unittest/layers/unittest_layer_node.cpp
index 3b41f02f30..37287f7ce5 100644
--- a/test/unittest/layers/unittest_layer_node.cpp
+++ b/test/unittest/layers/unittest_layer_node.cpp
@@ -131,7 +131,7 @@ TEST(nntrainer_LayerNode, finalize_05_n) {
                     nntrainer::createLayerNode(nntrainer::IdentityLayer::type));
   EXPECT_NO_THROW(lnode->setProperty({"input_shape=1:1:1", "name=abc"}));
   EXPECT_NO_THROW(lnode->finalize());
-  EXPECT_NO_THROW(lnode->configureRunContext({}, {&input}, {}, {}));
+  EXPECT_NO_THROW(lnode->configureRunContext({}, {&input}, {}, {}, 1.0));
   EXPECT_THROW(lnode->finalize(), std::runtime_error);
 }
 
@@ -298,7 +298,7 @@ TEST(nntrainer_LayerNode, setWeights_02_n) {
   EXPECT_NO_THROW(lnode =
                     nntrainer::createLayerNode(nntrainer::IdentityLayer::type));
   EXPECT_NO_THROW(lnode->setProperty({"input_shape=1:1:1", "name=abc"}));
-  EXPECT_NO_THROW(lnode->configureRunContext({&weight}, {&input}, {}, {}));
+  EXPECT_NO_THROW(lnode->configureRunContext({&weight}, {&input}, {}, {}, 1.0));
   EXPECT_THROW(lnode->setWeights(new_weights), std::runtime_error);
 }
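RunLayerContext and LayerNode::configureRunContext now take the loss scale explicitly, so the layer unit tests above pass 1.0 to keep the golden results unchanged. A rough sketch of how a test helper might parameterize it; the argument roles are inferred from the call sites above, so verify them against layer_context.h before reusing this.

auto make_context = [&](float loss_scale) {
  // same argument order as prepareRunContext() above; only the loss scale varies
  return nntrainer::RunLayerContext("golden", true, 0.0f, false, loss_scale,
                                    create_view(weights), create_view(ins),
                                    create_view(outs), create_view(tensors));
};
auto rc_fp32 = make_context(1.0f);    // neutral scale, matches existing goldens
auto rc_mixed = make_context(128.0f); // hypothetical mixed-precision variant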
diff --git a/test/unittest/models/meson.build b/test/unittest/models/meson.build
index 7166fc41ff..3f17369f94 100644
--- a/test/unittest/models/meson.build
+++ b/test/unittest/models/meson.build
@@ -1,4 +1,5 @@
 test_name = 'unittest_models'
+mixed_test_name = 'unittest_mixed_models'
 
 test_target = []
 
@@ -11,6 +12,30 @@ models_targets = [
   # disable temperally
 ]
 
+mixed_test_targets = [
+  'models_test_utils.cpp',
+  'models_golden_test.cpp',
+  'unittest_models_mixed_precision.cpp',
+]
+
+if get_option('enable-fp16')
+  mixed_exe = executable(
+    mixed_test_name,
+    mixed_test_targets,
+    include_directories: include_directories('.'),
+    dependencies: [
+      nntrainer_test_main_deps, nntrainer_ccapi_dep
+    ],
+    install: get_option('enable-test'),
+    install_dir: application_install_dir
+  )
+
+  test(mixed_test_name, mixed_exe,
+    args: '--gtest_output=xml:@0@/@1@.xml'.format(meson.build_root(), mixed_test_name),
+    timeout: test_timeout
+  )
+endif
+
 test_target += models_targets
 exe = executable(
   test_name,
diff --git a/test/unittest/models/unittest_models_mixed_precision.cpp b/test/unittest/models/unittest_models_mixed_precision.cpp
new file mode 100644
index 0000000000..becf11ff44
--- /dev/null
+++ b/test/unittest/models/unittest_models_mixed_precision.cpp
@@ -0,0 +1,54 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Jijoong Moon <jijoong.moon@samsung.com>
+ *
+ * @file unittest_models_mixed_precision.cpp
+ * @date 3 May 2024
+ * @brief unittest models to cover mixed precision
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Jijoong Moon <jijoong.moon@samsung.com>
+ * @bug No known bugs except for NYI items
+ */
+
+#include <memory>
+
+#include <gtest/gtest.h>
+
+#include <ini_wrapper.h>
+#include <neuralnet.h>
+#include <nntrainer_test_util.h>
+
+#include <models_golden_test.h>
+
+using namespace nntrainer;
+
+static std::unique_ptr<NeuralNetwork> fc_mixed_training() {
+  std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
+  nn->setProperty(
+    {"batch_size=2", "model_tensor_type=FP16-FP16", "loss_scale=128"});
+
+  auto graph = makeGraph({
+    {"input", {"name=in", "input_shape=1:1:3"}},
+    {"Fully_connected", {"name=fc", "input_layers=in", "unit=10"}},
+    {"mse", {"name=loss", "input_layers=fc"}},
+  });
+  for (auto &node : graph) {
+    nn->addLayer(node);
+  }
+
+  nn->setOptimizer(ml::train::createOptimizer("adam", {"learning_rate = 0.1"}));
+
+  return nn;
+}
+
+GTEST_PARAMETER_TEST(
+  MixedPrecision, nntrainerModelTest,
+  ::testing::ValuesIn({
+    mkModelTc_V2(fc_mixed_training, "fc_mixed_training",
+                 ModelTestOption::NO_THROW_RUN_V2),
+    /** ModelTestOption::ALL_V2),
+     * Disabled for now to check
+     */
+  }),
+  [](const testing::TestParamInfo<nntrainerModelTest::ParamType> &info)
+    -> const auto & { return std::get<1>(info.param); });
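The new test drives mixed precision through the internal NeuralNetwork properties ("model_tensor_type=FP16-FP16" and "loss_scale=128"). A sketch of what an application-side equivalent could look like through the ccapi follows; whether these properties are accepted on that path is an assumption, not something this diff shows.

#include <memory>

#include <layer.h>
#include <model.h>
#include <optimizer.h>

// Sketch only: assumes model_tensor_type/loss_scale are accepted as model
// properties via the ccapi, mirroring the internal test above.
std::unique_ptr<ml::train::Model> make_mixed_precision_model() {
  auto model = ml::train::createModel(
    ml::train::ModelType::NEURAL_NET,
    {"batch_size=2", "model_tensor_type=FP16-FP16", "loss_scale=128"});

  model->addLayer(
    ml::train::createLayer("input", {"name=in", "input_shape=1:1:3"}));
  model->addLayer(ml::train::createLayer(
    "fully_connected", {"name=fc", "input_layers=in", "unit=10"}));
  model->addLayer(
    ml::train::createLayer("mse", {"name=loss", "input_layers=fc"}));

  model->setOptimizer(
    ml::train::createOptimizer("adam", {"learning_rate=0.1"}));
  return model;
}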
diff --git a/test/unittest/unittest_nntrainer_tensor.cpp b/test/unittest/unittest_nntrainer_tensor.cpp
index 94aa01836d..d5b6a028f9 100644
--- a/test/unittest/unittest_nntrainer_tensor.cpp
+++ b/test/unittest/unittest_nntrainer_tensor.cpp
@@ -4704,6 +4704,30 @@ TEST(nntrainer_Tensor, inv_sqrt_i_uncontiguous_p) {
   }
 }
 
+/**
+ * @brief fp32 tensor has NaN
+ */
+TEST(nntrainer_Tensor, is_valid_01) {
+  size_t batch = 1;
+  size_t channel = 3;
+  size_t height = 4;
+  size_t width = 5;
+
+  nntrainer::Tensor input(
+    {batch,
+     channel,
+     height,
+     width,
+     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32}},
+    true, nntrainer::Tensor::Initializer::ZEROS);
+
+  EXPECT_EQ(input.isValid(), true);
+
+  input.setValue(0, 0, 0, 0, std::nan("1"));
+
+  EXPECT_EQ(input.isValid(), false);
+}
+
 int main(int argc, char **argv) {
   int result = -1;
 
diff --git a/test/unittest/unittest_nntrainer_tensor_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_fp16.cpp
index 2b0d9c040d..58455757c5 100644
--- a/test/unittest/unittest_nntrainer_tensor_fp16.cpp
+++ b/test/unittest/unittest_nntrainer_tensor_fp16.cpp
@@ -6196,6 +6196,34 @@ TEST(nntrainer_Tensor, dequantize_06_p) {
   EXPECT_EQ(output, answer3);
 }
 
+/**
+ * @brief fp16 tensor has NaN
+ */
+TEST(nntrainer_Tensor, is_valid_01) {
+  size_t batch = 1;
+  size_t channel = 3;
+  size_t height = 4;
+  size_t width = 5;
+
+  nntrainer::Tensor input(
+    {batch,
+     channel,
+     height,
+     width,
+     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}},
+    true, nntrainer::Tensor::Initializer::ZEROS);
+
+  EXPECT_EQ(input.isValid(), true);
+
+  input.setValue(0, 0, 0, 0, std::nan("1"));
+
+  EXPECT_EQ(input.isValid(), false);
+
+  input.setValue(0, 0, 0, 0, std::numeric_limits<float>::infinity());
+
+  EXPECT_EQ(input.isValid(), false);
+}
+
 GTEST_API_ int main(int argc, char **argv) {
   int result = -1;
 
diff --git a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp
index e02eac1786..2b3952bb10 100644
--- a/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp
+++ b/test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp
@@ -994,6 +994,38 @@ TEST(nntrainer_Tensor, inv_sqrt_i_p) {
   EXPECT_EQ(flag, true);
 }
 
+/**
+ * @brief fp16 tensor has NaN
+ */
+TEST(nntrainer_Tensor, is_valid_01) {
+  size_t batch = 1;
+  size_t channel = 3;
+  size_t height = 4;
+  size_t width = 5;
+
+  nntrainer::Tensor input(
+    {batch,
+     channel,
+     height,
+     width,
+     {nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16}},
+    true, nntrainer::Tensor::Initializer::ZEROS);
+
+  EXPECT_EQ(input.isValid(), true);
+
+  input.setValue(0, 0, 0, 0, std::nan("1"));
+
+  EXPECT_EQ(input.isValid(), false);
+
+  input.setValue(0, 0, 0, 0, std::numeric_limits<float>::infinity());
+
+  EXPECT_EQ(input.isValid(), false);
+
+  input.setValue(0, 0, 0, 0, 1);
+
+  EXPECT_EQ(input.isValid(), true);
+}
+
 GTEST_API_ int main(int argc, char **argv) {
   int result = -1;
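Tensor::isValid(), exercised by the three tests above, is the finiteness check that mixed-precision training relies on to detect an overflowed FP16 gradient. A minimal sketch of how a training-step guard could use it; the scale-halving policy here is illustrative only, not the policy implemented by this diff.

#include <algorithm>
#include <vector>

#include <tensor.h> // nntrainer::Tensor with the isValid() check tested above

// Returns false and backs the loss scale off when any gradient contains
// NaN or Inf; the caller is expected to redo forward/backward in that case.
bool gradients_are_finite(const std::vector<nntrainer::Tensor *> &grads,
                          float &loss_scale) {
  for (auto *g : grads) {
    if (!g->isValid()) {
      loss_scale = std::max(1.0f, loss_scale * 0.5f);
      return false;
    }
  }
  return true;
}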