diff --git a/nntrainer/graph/network_graph.cpp b/nntrainer/graph/network_graph.cpp
index 0f40eaf330..0bdc9ce3a5 100644
--- a/nntrainer/graph/network_graph.cpp
+++ b/nntrainer/graph/network_graph.cpp
@@ -455,6 +455,18 @@ void NetworkGraph::backwarding(
     loss_scale = scale;
   };
 
+  auto check_weights = [](std::vector<Weight *> &weights) {
+    bool valid = true;
+    for (auto &w : weights) {
+      auto grad = w->getGradient();
+      if (grad.checkDataValidation(false) == false) {
+        grad.setZero();
+        valid = false;
+      }
+    }
+    return valid;
+  };
+
   // check first layer's derivative is valid
   // loss scale is adjusted between 1.0f ~ 256.0f
   // @todo provide max scale property
@@ -465,13 +477,15 @@ void NetworkGraph::backwarding(
     ml_logd(
       "Derivative validation failed. Skip applying gradient. loss_scale(%f)",
       scale);
+    check_weights(clip_weights);
     update_loss_scale(scale);
     return;
   } else {
     for (unsigned int idx = 0; idx < clip_weights.size(); idx++) {
       auto const &w = clip_weights[idx];
       w->applyScaler(loss_scale);
-      if (w->getGradient().checkDataValidation(false) == false) {
+
+      if (!check_weights(clip_weights)) {
         float scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f;
         ml_loge("gradient validation failed. skip update. loss_scale(%f)",
                 scale);
diff --git a/nntrainer/layers/bn_layer.cpp b/nntrainer/layers/bn_layer.cpp
index e978b1ef59..3ca7628a3a 100644
--- a/nntrainer/layers/bn_layer.cpp
+++ b/nntrainer/layers/bn_layer.cpp
@@ -182,6 +182,11 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
   Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);
 
   if (training) {
+    t_reduced.setZero();
+    deviation.setZero();
+    invstd.setZero();
+    cvar.setZero();
+
     input_.average(axes_to_reduce, t_reduced);
     input_.subtract(t_reduced, deviation);
 
diff --git a/nntrainer/tensor/tensor.cpp b/nntrainer/tensor/tensor.cpp
index 9065192242..bc93212880 100644
--- a/nntrainer/tensor/tensor.cpp
+++ b/nntrainer/tensor/tensor.cpp
@@ -3379,10 +3379,13 @@ void Tensor::setZero() {
     apply_i([](float val) -> float { return 0; });
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    if (contiguous)
-      sscal(size(), 0, getData<_FP16>(), 1);
-    else
+    if (contiguous) {
+      _FP16 zero = (_FP16)0.0f;
+      scopy(size(), &zero, 0, getData<_FP16>(), 1,
+            ml::train::TensorDim::DataType::FP16);
+    } else {
       apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; });
+    }
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
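
Reviewer note on the network_graph.cpp hunks: check_weights zeroes any gradient whose data fails validation, so a NaN/Inf gradient cannot leak into a later optimizer step, while the caller lowers the loss scale before retrying. Below is a minimal standalone sketch of that recovery pattern; the Grad type and its is_valid/set_zero members are illustrative stand-ins, not nntrainer APIs.

#include <cmath>
#include <vector>

// Illustrative stand-in for a weight's gradient tensor.
struct Grad {
  std::vector<float> data;

  // Roughly what Tensor::checkDataValidation(false) reports:
  // false as soon as any element is NaN or Inf.
  bool is_valid() const {
    for (float v : data)
      if (!std::isfinite(v))
        return false;
    return true;
  }

  void set_zero() { data.assign(data.size(), 0.0f); }
};

// Mirror of the check_weights lambda in the patch: zero every
// invalid gradient and report whether all of them were usable.
bool check_weights(std::vector<Grad> &grads) {
  bool valid = true;
  for (auto &g : grads) {
    if (!g.is_valid()) {
      g.set_zero(); // leave no stale NaN/Inf for the next iteration
      valid = false;
    }
  }
  return valid;
}

// Same scale-backoff rule the patch applies after a failed step.
float next_scale_after_failure(float loss_scale) {
  return loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f;
}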
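
Reviewer note on the tensor.cpp hunk: sscal(size(), 0, ...) multiplies each element in place, and under IEEE 754 both 0 * NaN and 0 * Inf evaluate to NaN, so scaling by zero cannot clear a buffer that already holds invalid values; the stride-0 scopy of a zero constant overwrites every element unconditionally. A self-contained demonstration of the difference, with plain loops standing in for the BLAS calls:

#include <cstdio>
#include <limits>

int main() {
  float buf[3] = {1.0f, std::numeric_limits<float>::quiet_NaN(),
                  std::numeric_limits<float>::infinity()};

  // What sscal(n, 0, buf, 1) does: in-place multiply by zero.
  // IEEE 754: 0 * NaN == NaN and 0 * Inf == NaN, so junk survives.
  for (float &v : buf)
    v *= 0.0f;
  std::printf("scale by 0: %f %f %f\n", buf[0], buf[1], buf[2]);
  // typically prints: scale by 0: 0.000000 nan nan

  // What scopy(n, &zero, 0, buf, 1) does: broadcast-copy a zero
  // over every element, clearing the buffer regardless of contents.
  float zero = 0.0f;
  for (float &v : buf)
    v = zero;
  std::printf("copy zeros: %f %f %f\n", buf[0], buf[1], buf[2]);
  // prints: copy zeros: 0.000000 0.000000 0.000000
}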