[Mixed] Reset for invalid values
Either an internal tensor or a gradient may end up holding an invalid
value (NaN or Inf). This patch validates the data and resets it to zero
when it is invalid.

Also, the sscal API is replaced with scopy in setZero, because scaling
by zero still yields an invalid value when the stored value is already
invalid.
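The failure is plain IEEE 754 arithmetic: scaling by zero cannot clear a
NaN, while copying a literal zero can. A minimal standalone sketch of the
difference, using ordinary C++ floats rather than nntrainer's _FP16 and
BLAS wrappers:

  #include <cassert>
  #include <cmath>

  int main() {
    float x = std::nanf("");  // simulate a corrupted tensor element
    x = 0.0f * x;             // per-element effect of sscal(n, 0, data, 1)
    assert(std::isnan(x));    // 0 * NaN is still NaN

    float zero = 0.0f;
    x = zero;                 // per-element effect of scopy(n, &zero, 0, data, 1)
    assert(x == 0.0f);        // now genuinely zero
    return 0;
  }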

Signed-off-by: Jiho Chu <[email protected]>
jihochu committed Mar 22, 2024
1 parent 0e5db84 commit 8fd0e06
Showing 3 changed files with 26 additions and 4 deletions.
16 changes: 15 additions & 1 deletion nntrainer/graph/network_graph.cpp
@@ -455,6 +455,18 @@ void NetworkGraph::backwarding(
     loss_scale = scale;
   };

+  auto check_weights = [](std::vector<Weight *> &weights) {
+    bool valid = true;
+    for (auto &w : weights) {
+      auto grad = w->getGradient();
+      if (grad.checkDataValidation(false) == false) {
+        grad.setZero();
+        valid = false;
+      }
+    }
+    return valid;
+  };
+
   // check first layer's derivative is valid
   // loss scale is adjusted between 1.0f ~ 256.0f
   // @todo provide max scale property
@@ -465,13 +477,15 @@ void NetworkGraph::backwarding(
     ml_logd(
       "Derivative validation failed. Skip applying gradient. loss_scale(%f)",
       scale);
+    check_weights(clip_weights);
     update_loss_scale(scale);
     return;
   } else {
     for (unsigned int idx = 0; idx < clip_weights.size(); idx++) {
       auto const &w = clip_weights[idx];
       w->applyScaler(loss_scale);
-      if (w->getGradient().checkDataValidation(false) == false) {
+
+      if (!check_weights(clip_weights)) {
         float scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f;
         ml_loge("gradient validation failed. skip update. loss_scale(%f)",
                 scale);
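The new check_weights helper scans every clipped weight's gradient, zeroes
any gradient that fails validation, and reports whether everything was
finite, so the loss scale can be reduced and the update skipped.
Tensor::checkDataValidation is presumably a finiteness scan; a hypothetical
stand-in to make the semantics concrete (not nntrainer's actual
implementation):

  #include <cmath>
  #include <cstddef>

  // Hypothetical equivalent of Tensor::checkDataValidation(false):
  // true iff no element is NaN or Inf.
  bool data_is_valid(const float *buf, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i)
      if (!std::isfinite(buf[i]))
        return false;
    return true;
  }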
5 changes: 5 additions & 0 deletions nntrainer/layers/bn_layer.cpp
@@ -182,6 +182,11 @@ void BatchNormalizationLayer::forwarding(RunLayerContext &context,
   Tensor &cvar = context.getTensor(wt_idx[BNParams::cvar]);

   if (training) {
+    t_reduced.setZero();
+    deviation.setZero();
+    invstd.setZero();
+    cvar.setZero();
+
     input_.average(axes_to_reduce, t_reduced);
     input_.subtract(t_reduced, deviation);

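These four tensors are reused intermediates of the batch-norm forward pass;
zeroing them up front guarantees that a NaN or Inf left behind by an earlier
iteration cannot survive into operations that accumulate into, rather than
plainly overwrite, their destination. The hazard in miniature (illustrative
only):

  #include <cassert>
  #include <cmath>

  int main() {
    float acc = std::nanf("");  // stale garbage in a reused buffer
    acc += 1.0f;                // accumulation keeps it: NaN + 1 == NaN
    assert(std::isnan(acc));

    acc = 0.0f;                 // setZero() first ...
    acc += 1.0f;
    assert(acc == 1.0f);        // ... and the result is sane
    return 0;
  }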
9 changes: 6 additions & 3 deletions nntrainer/tensor/tensor.cpp
@@ -3379,10 +3379,13 @@ void Tensor::setZero() {
     apply_i<float>([](float val) -> float { return 0; });
   } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
 #ifdef ENABLE_FP16
-    if (contiguous)
-      sscal(size(), 0, getData<_FP16>(), 1);
-    else
+    if (contiguous) {
+      _FP16 zero = (_FP16)0.0f;
+      scopy(size(), &zero, 0, getData<_FP16>(), 1,
+            ml::train::TensorDim::DataType::FP16);
+    } else {
       apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; });
+    }
 #else
     throw std::invalid_argument("Error: enable-fp16 is not enabled");
 #endif
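The replacement relies on the BLAS copy idiom of a zero source increment:
scopy keeps re-reading the same scalar, so every destination element is
overwritten with a literal zero regardless of its previous contents, whereas
sscal computes dst[i] *= 0 and therefore preserves NaN. A sketch of that
semantics in plain C++ (a stand-in, not nntrainer's scopy wrapper):

  #include <cstddef>

  // What scopy(n, &zero, 0, dst, 1) amounts to: with incx == 0 the
  // source index never advances, so src[0] is broadcast over dst.
  void scopy_like(std::size_t n, const float *src, int incx,
                  float *dst, int incy) {
    for (std::size_t i = 0; i < n; ++i)
      dst[i * incy] = src[i * incx];  // incx == 0 -> always src[0]
  }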
