[ Mixed Precision ] Enable Mixed Precision
This PR enables mixed precision training. For now, only FP16-FP32 is considered; additional test cases will be added.

. Add getSortedLayerIdx to set the graph order for forwarding.
. Change clip_weights to lazy_weights (and apply_grad_clip_op to lazy_apply_grad_op) so both gradient clipping and mixed precision are handled.
. Add forwarding_op to re-run forwarding from the layer whose gradient contains NaN.
. Add a while loop to re-run backwarding after the loss scale is reset (see the sketch after this list).
. Add setLossScale to RunLayerContext.
. Add a gradient NaN check when mixed precision is enabled.
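
For reference, a minimal self-contained sketch of the loss-scale retry policy described above. All names are illustrative (not the actual nntrainer API); the constants follow this diff: the scale is lowered by 0.5 when a NaN gradient appears and raised by 2.0 after more than 10 successful iterations.

```cpp
// Toy model of the NaN-retry / dynamic loss-scaling loop added by this commit.
#include <iostream>

// Stand-in for "run backwarding and report whether any scaled gradient
// overflowed to NaN/Inf".
bool backward_produces_nan(float loss_scale) {
  return loss_scale > 1024.0f; // pretend gradients overflow above this scale
}

int main() {
  float loss_scale = 1026.0f;
  unsigned int clean_iters = 0; // tracked as nan_count in the diff

  for (int iteration = 0; iteration < 3; ++iteration) {
    // Retry until every gradient is finite, lowering the scale each time.
    // The real code also re-runs forwarding from the failing layer.
    while (backward_produces_nan(loss_scale)) {
      loss_scale = loss_scale > 1.5f ? loss_scale - 0.5f : 1.0f;
      std::cout << "NaN gradient, loss scale lowered to " << loss_scale << '\n';
    }
    std::cout << "iteration " << iteration << ": gradients applied at scale "
              << loss_scale << '\n';
    if (++clean_iters > 10) { // cautiously grow the scale back
      loss_scale += 2.0f;
      clean_iters = 0;
    }
  }
  return 0;
}
```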

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <[email protected]>
jijoongmoon committed May 13, 2024
1 parent adc2f2a commit 929eab9
Showing 10 changed files with 207 additions and 44 deletions.
9 changes: 9 additions & 0 deletions nntrainer/graph/graph_core.cpp
@@ -35,6 +35,10 @@ GraphCore::getSortedNode(unsigned int ith) const {
return Sorted.at(ith);
}

const unsigned int GraphCore::getSortedNodeIdx(const std::string &name) const {
return sorted_node_map.at(name);
}

void GraphCore::makeAdjacencyList(
std::vector<std::list<std::shared_ptr<GraphNode>>> &adj) {
/** initialize the adj list */
@@ -93,6 +97,11 @@ void GraphCore::topologicalSort() {

if (Sorted.size() != node_list.size())
throw std::runtime_error("Internal error in topologicalSort");
unsigned int idx = 0;
for (auto n : Sorted) {
sorted_node_map[n->getName()] = idx;
idx++;
}
}

const std::shared_ptr<GraphNode> &
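
As an aside, the name-to-sorted-index bookkeeping above boils down to the following sketch (illustrative types, not nntrainer's GraphNode):

```cpp
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Node {
  std::string name;
};

// After the topological sort, remember each node's position so that
// forwarding can later be restarted from an arbitrary layer looked up by name.
std::unordered_map<std::string, unsigned int>
build_sorted_index(const std::vector<std::shared_ptr<Node>> &sorted) {
  std::unordered_map<std::string, unsigned int> index_of;
  unsigned int idx = 0;
  for (const auto &n : sorted)
    index_of[n->name] = idx++;
  return index_of;
}
```
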
8 changes: 8 additions & 0 deletions nntrainer/graph/graph_core.h
@@ -91,6 +91,13 @@ class GraphCore {
*/
const std::shared_ptr<GraphNode> &getSortedNode(unsigned int ith) const;

/**
* @brief getter of Sorted GraphNode index with name
* @param[in] name layer name
* @return index of the node in the sorted list
*/
const unsigned int getSortedNodeIdx(const std::string &name) const;

/**
* @brief getter of GraphNode with node name
* @param[in] node name
@@ -252,6 +259,7 @@ class GraphCore {
std::vector<std::shared_ptr<GraphNode>>
node_list; /**< Unordered Node List */
std::unordered_map<std::string, int> node_map; /**< Unordered Node map */
std::unordered_map<std::string, int> sorted_node_map; /**< Node name to sorted index map */
std::vector<std::shared_ptr<GraphNode>> Sorted; /**< Ordered Node List */
bool sorted; /** if the node_list is sorted */

113 changes: 86 additions & 27 deletions nntrainer/graph/network_graph.cpp
@@ -337,7 +337,7 @@ void NetworkGraph::applyGradients(
continue;
}

if (rc.isGradientClipByGlobalNorm(i)) {
if (rc.isGradientClipByGlobalNorm(i) || rc.isMixedPrecision(i)) {
/**
* @note the weights whose gradient are to be clipped by global norm will
* be clipped at once at the end of iteration and applied then.
@@ -393,56 +393,100 @@ sharedConstTensors NetworkGraph::incremental_forwarding(
return out;
}

void NetworkGraph::backwarding(
bool NetworkGraph::backwarding(
int iteration,
std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
std::function<void(Weight &, int)> &apply_grad_clip_op,
std::function<bool(void *userdata)> stop_cb, void *userdata) const {
std::function<void(std::shared_ptr<LayerNode>, bool)> &forwarding_op,
std::function<bool(std::shared_ptr<LayerNode>, int)> &backwarding_op,
std::function<void(Weight &, int)> &lazy_apply_grad_op,
std::function<bool(void *userdata)> stop_cb, void *userdata) {
/**
* last layer backwarding is run out of this loop
*/
auto iter_begin = getBackwardingBeginIter();
auto iter_end = getBackwardingEndIter();
bool has_nan = false;

/// there is no layer to train, so backwarding is essentially noop
if (iter_begin == iter_end) {
return;
return true;
}

auto const &lptr_begin = (*iter_begin);
// graph_const_reverse_iterator
auto iter_ = iter_begin;

if (lptr_begin->requireLabel() == false)
throw std::runtime_error(
"Error: last layer does not accept label, we can't train");

for (auto iter = iter_begin; iter != iter_end && !stop_cb(userdata); iter++) {
auto &ln = *iter;
for (iter_ = iter_begin; iter_ != iter_end && !stop_cb(userdata); iter_++) {
auto &ln = *iter_;
PROFILE_TIME_START(profile_keys.at(ln->getType()));
backwarding_op(ln, iteration);
has_nan = backwarding_op(ln, iteration);
PROFILE_TIME_END(profile_keys.at(ln->getType()));

if (has_nan) {
std::cout << "Gradient has NaN" << std::endl;
break;
}
}

/** perform clipping of the gradients by global norm if any */
if (clip_weights.empty())
return;
if (has_nan) {
/** if has NaN
* 1. reset the loss scale.
* 2. run forwarding from cur_iter to cend() && !stop_cb(userdata);
* 3. return false --> run backwarding again;
*/
float scale = (*iter_)->getRunContext().getLossScale();
float s = scale > 1.5f ? scale - 0.5f : 1.0f;

resetLossScale(s);

/** calculate the global norm */
Tensor global_norm_t(
TensorDim({1u, 1u, 1u, (unsigned int)clip_weights.size()}));
float *global_norm_data = global_norm_t.getData();
for (unsigned int idx = 0; idx < clip_weights.size(); idx++) {
auto const &w = clip_weights[idx];
global_norm_data[idx] = w->getGradientNorm();
auto f_iter = cbegin() + graph.getSortedNodeIdx((*iter_)->getName());

for (auto iter = f_iter; iter != cend() && !stop_cb(userdata); iter++) {
auto &ln = *iter;
PROFILE_TIME_START(profile_keys.at(ln->getType()));
forwarding_op(*iter, true);
PROFILE_TIME_END(profile_keys.at(ln->getType()));
}

return false;
}
float global_norm = global_norm_t.l2norm();
/** apply the gradient with the above global norm */
for (auto w : clip_weights) {
w->clipGradientByGlobalNorm(global_norm);

/** perform clipping of the gradients by global norm if any */
if (lazy_weights.empty())
return true;

if (is_clip_grad) {
/** calculate the global norm */
Tensor global_norm_t(
TensorDim({1u, 1u, 1u, (unsigned int)lazy_weights.size()}));
float *global_norm_data = global_norm_t.getData();
for (unsigned int idx = 0; idx < lazy_weights.size(); idx++) {
auto const &w = lazy_weights[idx];
global_norm_data[idx] = w->getGradientNorm();
}
float global_norm = global_norm_t.l2norm();
/** apply the gradient with the above global norm */
for (auto w : lazy_weights) {
w->clipGradientByGlobalNorm(global_norm);
}
}
/** apply the gradient with the above global norm */
for (auto w : clip_weights) {
apply_grad_clip_op(*w, iteration);
for (auto w : lazy_weights) {
lazy_apply_grad_op(*w, iteration);
}
nan_count++;

if (nan_count > 10) {
float scale = (*iter_)->getRunContext().getLossScale();
float s = scale + 2.0f;
resetLossScale(s);
nan_count = 0;
}

return true;
}

LayerNode *NetworkGraph::computeBackwardEnd() {
@@ -1290,11 +1334,19 @@ int NetworkGraph::initialize(ExecutionMode mode,

/** select weights which would require clipping of the gradients by global
* norm if any */
clip_weights = tensor_manager->getWeights([](const Weight *w) {
lazy_weights = tensor_manager->getWeights([](const Weight *w) {
return w->hasGradient() && w->isGradientLastAccess() &&
w->isGradientClipByGlobalNorm();
(w->isGradientClipByGlobalNorm() || w->isMixedPrecision());
});

is_clip_grad = false;
for (auto w : lazy_weights) {
if (w->isGradientClipByGlobalNorm()) {
is_clip_grad = true;
break;
}
}

return ML_ERROR_NONE;
}

@@ -1566,4 +1618,11 @@ void NetworkGraph::requestOptimizerVariable(
}
}

void NetworkGraph::resetLossScale(float scale) {
for (auto iter = cbegin(); iter != cend(); iter++) {
auto &ln = *iter;
ln->getRunContext().setLossScale(scale);
}
}

} /* namespace nntrainer */
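
For the lazy gradient path above, here is a small self-contained sketch of clip-by-global-norm on plain float vectors (an assumed simplification; the real code works on Weight gradients via getGradientNorm() and clipGradientByGlobalNorm()):

```cpp
#include <cmath>
#include <vector>

// Scale every gradient so the combined L2 norm does not exceed max_norm,
// mirroring the global_norm_t.l2norm() + clipGradientByGlobalNorm() pair above.
void clip_by_global_norm(std::vector<std::vector<float>> &grads, float max_norm) {
  float sq_sum = 0.0f;
  for (const auto &g : grads)
    for (float v : g)
      sq_sum += v * v;
  const float global_norm = std::sqrt(sq_sum);
  if (global_norm <= max_norm || global_norm == 0.0f)
    return;
  const float factor = max_norm / global_norm;
  for (auto &g : grads)
    for (float &v : g)
      v *= factor;
}
```
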
28 changes: 21 additions & 7 deletions nntrainer/graph/network_graph.h
@@ -51,7 +51,9 @@ class NetworkGraph {
optimize_memory(true),
exec_mode(ExecutionMode::TRAIN),
tensor_format("NCHW"),
tensor_dtype(split("FP32-FP32", getRegex("\\-"))) {}
tensor_dtype(split("FP32-FP32", getRegex("\\-"))) {
nan_count = 0;
}

/**
* @brief Constructor of NeuralNetwork Graph Class
@@ -73,7 +75,9 @@
optimize_memory(true),
exec_mode(ExecutionMode::TRAIN),
tensor_format(tensor_format_),
tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) {}
tensor_dtype(split(tensor_dtype_, getRegex("\\-"))) {
nan_count = 0;
}

/**
* @brief Destructor of the NeuralNetwork Graph class
@@ -206,13 +210,14 @@ class NetworkGraph {
* @param[in] backwarding_op operation for the backwarding
* @param[in] apply_grad_clip_op operation for applying the clip gradients
*/
void backwarding(
bool backwarding(
int iteration,
std::function<void(std::shared_ptr<LayerNode>, int)> &backwarding_op,
std::function<void(Weight &, int)> &apply_grad_clip_op,
std::function<void(std::shared_ptr<LayerNode>, bool)> &forwarding_op,
std::function<bool(std::shared_ptr<LayerNode>, int)> &backwarding_op,
std::function<void(Weight &, int)> &lazy_apply_grad_op,
std::function<bool(void *userdata)> stop_cb =
[](void *user_data) { return false; },
void *user_data = nullptr) const;
void *user_data = nullptr);

/**
* @brief get begin iterator for the graph
@@ -444,6 +449,12 @@
getLayerExecutionOrders(const std::shared_ptr<LayerNode> &lnode);
#endif // ENABLE_TEST

/**
* @brief reset the loss scale
* @param[in] scale
*/
void resetLossScale(float scale);

private:
std::map<std::string, std::string> sub_in_out; /** This is map to identify
input and output layer name of subgraph */
@@ -480,7 +491,10 @@
std::unordered_map<std::string, int>
profile_keys; /**< profile keys based on the layer type */
std::vector<Weight *>
clip_weights; /**< weights with global norm based clipping enabled */
lazy_weights; /**< weights applied lazily (global norm clipping or mixed precision) */
bool is_clip_grad;

unsigned int nan_count;

/**
* @brief topological sort
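
Given the signature change above, a hedged sketch of the caller-side pattern (the actual call site lives in code not expanded on this page, so the wrapper below is an assumption): backwarding() now returns false when a NaN gradient forced a loss-scale reset and a forward re-run, and the caller simply retries.

```cpp
// Graph is any type exposing the backwarding() signature declared above.
template <typename Graph, typename FwdOp, typename BwdOp, typename ApplyOp>
void backward_until_clean(Graph &graph, int iteration, FwdOp &forwarding_op,
                          BwdOp &backwarding_op, ApplyOp &lazy_apply_grad_op) {
  while (!graph.backwarding(iteration, forwarding_op, backwarding_op,
                            lazy_apply_grad_op)) {
    // Retry: the graph has already lowered the loss scale and re-run
    // forwarding from the layer whose gradient contained NaN.
  }
}
```
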
11 changes: 11 additions & 0 deletions nntrainer/layers/layer_context.cpp
@@ -416,6 +416,17 @@ bool RunLayerContext::isGradientClipByGlobalNorm(unsigned int idx) const {
return weights[idx]->isGradientClipByGlobalNorm();
}

bool RunLayerContext::isMixedPrecision(unsigned int idx) const {
return weights[idx]->isMixedPrecision();
}

bool RunLayerContext::isMixedPrecision() const {
for (auto w : weights)
if (w->isMixedPrecision())
return true;
return false;
}

/**
* @brief Get the tensor name
*
26 changes: 26 additions & 0 deletions nntrainer/layers/layer_context.h
@@ -689,6 +689,20 @@ class RunLayerContext {
*/
bool isGradientClipByGlobalNorm(unsigned int idx) const;

/**
* @brief check if the weight is mixed precision
*
* @param idx index
* @return bool true if it is mixed precision
*/
bool isMixedPrecision(unsigned int idx) const;

/**
* @brief check if the weight is mixed precision
* @return bool true if it is mixed precision
*/
bool isMixedPrecision() const;

/**
* @brief Get the tensor name
*
@@ -910,6 +924,18 @@
*/
float getLossScale() { return loss_scale; }

/**
* @brief set Loss_Scale.
*
* @param[in] scale loss scale value to set
*/
void setLossScale(float scale) {
loss_scale = scale;
for (auto w : weights) {
w->setLossScale(scale);
}
}

private:
std::tuple<props::Name, props::Trainable> props; /**< props of the layer */
float loss; /**< loss of the layer */
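
Finally, a short sketch of where the loss scale enters the arithmetic (assumed FP16-FP32 convention, not the exact nntrainer kernels): the backward pass runs on loss * scale so small FP16 gradients do not flush to zero, and the FP32 master gradients are divided by the same scale before the optimizer applies them.

```cpp
#include <vector>

// Forward side: scale the loss so tiny FP16 gradients stay representable.
float scale_loss(float loss, float loss_scale) { return loss * loss_scale; }

// Before the optimizer step: undo the scaling on the FP32 master gradients so
// the update magnitude is unchanged.
void unscale_gradients(std::vector<float> &fp32_grads, float loss_scale) {
  for (float &g : fp32_grads)
    g /= loss_scale;
}
```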