[Wait for #2580] [ Mixed Precision ] Enable Mixed Precision #2581

Closed
wants to merge 8 commits
[ Mixed ] Create weight with var32 tensor
This PR creates the variable fp32 (var32) tensor when we create the Weight and
the optimizer weight (a short sketch of the idea follows the list below).

. update the manager to create the Weight with a var32 tensor requested
from the weight pool.
. update the weight requests with the Weight Spec and the already-created
var, grad, and var32 tensors.
. add a Tensor::clone() overload that takes a specific data type in tensor.h
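
A minimal sketch of the idea, using the Tensor/TensorDim API visible in this
diff (the helper name makeVar32 is illustrative, not code from this PR): the
weight variable stays in its low-precision type, while a same-shaped fp32 copy
is created so accumulation can happen in full precision.

```cpp
// Sketch only: building an fp32 master copy (var32) for a low-precision weight.
#include <tensor.h>
#include <tensor_dim.h>

nntrainer::Tensor makeVar32(const nntrainer::Tensor &var) {
  // Same shape as the variable, but stored as fp32.
  ml::train::TensorDim dim = var.getDim();
  dim.setDataType(ml::train::TensorDim::DataType::FP32);

  nntrainer::Tensor var32(dim, true, nntrainer::Tensor::Initializer::NONE,
                          var.getName() + ":fp32");
  var32.copyData(var); // upcast the current values into the master copy
  return var32;
}
```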

Resolves:

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
jijoongmoon committed May 7, 2024
commit b6ad1d06d683cc1a00ccca7e7bbd02764fa240df
7 changes: 4 additions & 3 deletions nntrainer/graph/network_graph.cpp
@@ -768,6 +768,7 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
* node is going to be used with in-place optimizations.
*/
auto out_specs = init_context.getOutSpecs();

/// @note try move inplace control to finalize
bool shared_var = false, shared_grad = false;
if (lnode->executeInPlace() != InPlace::NONE) {
@@ -1556,16 +1557,16 @@ void NetworkGraph::requestOptimizerVariable(
const TensorDim &dim = w->getDim();
std::vector<TensorDim> dims = cb(dim);
w->setOptimizerVariables(tensor_manager->requestWeightOptimizerVariables(
dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
dims, w->getName(), ":opt", TensorLifespan::MAX_LIFESPAN,
w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
Tensor::Initializer::ZEROS));

if (dim.getDataType() != ml::train::TensorDim::DataType::FP32) {
if (w->isMixedPrecision()) {
for (auto &dim : dims)
dim.setDataType(ml::train::TensorDim::DataType::FP32);
w->setOptimizerVariables32(
tensor_manager->requestWeightOptimizerVariables(
dims, w->getName(), TensorLifespan::MAX_LIFESPAN,
dims, w->getName(), ":opt32:", TensorLifespan::MAX_LIFESPAN,
w->isGradientClipByGlobalNorm(), w->isMixedPrecision(),
Tensor::Initializer::ZEROS));
}
13 changes: 13 additions & 0 deletions nntrainer/layers/layer_context.cpp
@@ -169,6 +169,19 @@ Tensor &RunLayerContext::getWeightGrad(unsigned int idx) const {
return weights[idx]->getGradientRef();
}

/**
* @brief Get the Weight variable FP32 tensor object
*
* @param idx Identifier of the weight
* @return Tensor& Reference to the fp32 copy of the weight variable
*/
Tensor &RunLayerContext::getWeightFP32(unsigned int idx) const {
if (!weights[idx]->hasGradient())
throw std::invalid_argument(
"Requesting fp32 variable for a non-trainable weight.");
return weights[idx]->getVariableFP32Ref();
}

/**
* @brief Get the Weight Optimizer Variable tensor object
*
9 changes: 9 additions & 0 deletions nntrainer/layers/layer_context.h
@@ -463,6 +463,15 @@ class RunLayerContext {
Tensor &getWeightGrad(unsigned int idx) const;

/**
* @brief Get the Weight variable FP32 tensor object
*
* @param idx Identifier of the weight
* @return Tensor& Reference to the fp32 copy of the weight variable
*/
Tensor &getWeightFP32(unsigned int idx) const;

/**
* @brief Get the Weight Optimizer Variable tensor object
*
* @param idx Identifier of the weight
8 changes: 4 additions & 4 deletions nntrainer/layers/layer_node.h
@@ -496,11 +496,11 @@ class LayerNode final : public ml::train::Layer, public GraphNode {
NNTR_THROW_IF(!run_context, std::runtime_error)
<< __func__ << " layer needs to be finalized first!";
if (run_context->weightHasGradient(idx)) {
return Weight(run_context->getWeight(idx),
run_context->getWeightGrad(idx),
run_context->getWeightName(idx));
return Weight(
run_context->getWeight(idx), run_context->getWeightGrad(idx),
run_context->getWeightFP32(idx), run_context->getWeightName(idx));
} else {
return Weight(run_context->getWeight(idx), Tensor(),
return Weight(run_context->getWeight(idx), Tensor(), Tensor(),
run_context->getWeightName(idx));
}
}
6 changes: 3 additions & 3 deletions nntrainer/tensor/manager.cpp
@@ -689,8 +689,8 @@ bool Manager::isSecondLastAccess(const std::string &name,
*/
std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
const std::vector<TensorDim> &dims, const std::string &name,
const TensorLifespan &lifespan, bool is_grad_clip, bool is_mixed_precision,
Tensor::Initializer initializer) {
const std::string &suffix, const TensorLifespan &lifespan, bool is_grad_clip,
bool is_mixed_precision, Tensor::Initializer initializer) {

std::vector<Tensor *> ret;
ret.reserve(dims.size());
@@ -706,7 +706,7 @@ std::vector<Tensor *> Manager::requestWeightOptimizerVariables(
/// @note this is assuming weight optimizer variables are treated as weights; if
/// not, there is room to optimize the behavior below
for (unsigned int idx = 0; idx < dims.size(); idx++)
ret.push_back(weight_pool.request(name + ":opt" + std::to_string(idx),
ret.push_back(weight_pool.request(name + suffix + std::to_string(idx),
dims[idx], exec, lifespan, initializer));

return ret;
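
For illustration only (the weight name below is hypothetical): the pool entries
are now named `name + suffix + index`, so the fp32 optimizer variables
requested with the ":opt32:" suffix stay distinct from the regular ":opt" ones.

```cpp
// Illustration of the names produced by name + suffix + std::to_string(idx).
#include <iostream>
#include <string>

int main() {
  std::string name = "fc0:weight"; // hypothetical weight name
  for (unsigned int idx = 0; idx < 2; ++idx) {
    std::cout << name + ":opt" + std::to_string(idx) << '\n';    // fc0:weight:opt0, ...
    std::cout << name + ":opt32:" + std::to_string(idx) << '\n'; // fc0:weight:opt32:0, ...
  }
  return 0;
}
```
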
3 changes: 2 additions & 1 deletion nntrainer/tensor/manager.h
@@ -224,7 +224,8 @@ class Manager {
*/
std::vector<Tensor *> requestWeightOptimizerVariables(
const std::vector<TensorDim> &dims, const std::string &name,
const TensorLifespan &lifespan, bool is_grad_clip, bool is_mixed_type,
const std::string &suffix, const TensorLifespan &lifespan,
bool is_grad_clip, bool is_mixed_type,
Tensor::Initializer initializer = Tensor::Initializer::NONE);

/**
12 changes: 12 additions & 0 deletions nntrainer/tensor/tensor.cpp
@@ -3065,6 +3065,18 @@ Tensor Tensor::clone() const {
return t;
}

Tensor Tensor::clone(ml::train::TensorDim::DataType type) const {
if (getDataType() == type)
return clone();

TensorDim dim = getDim();
dim.setDataType(type);
Tensor t(dim, true);
t.copyData(*this);
t.name = name;
return t;
}

void Tensor::reshape(const TensorDim &d) {

NNTR_THROW_IF(!contiguous, std::invalid_argument)
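
A short usage sketch of the new overload (the variable names are illustrative;
it assumes the source tensor is already allocated):

```cpp
// Sketch: copying a tensor into a different data type via the new clone overload.
#include <tensor.h>
#include <tensor_dim.h>

void cloneExample(const nntrainer::Tensor &fp16_w) {
  // Same shape and values, stored as fp32; the tensor name is carried over.
  nntrainer::Tensor fp32_master =
    fp16_w.clone(ml::train::TensorDim::DataType::FP32);

  // Requesting the same data type falls back to the plain clone() path.
  nntrainer::Tensor same_type = fp16_w.clone(fp16_w.getDataType());
}
```
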
7 changes: 7 additions & 0 deletions nntrainer/tensor/tensor.h
@@ -1680,6 +1680,13 @@ class Tensor {
*/
Tensor clone() const;

/**
* @brief Convenient wrapper to copy @a this into a tensor of the given data type.
* @param[in] type output tensor data type
* @retval Copied version of this
*/
Tensor clone(ml::train::TensorDim::DataType type) const;

/**
* @brief Save the Tensor into file
* @param[in] file output file stream
28 changes: 9 additions & 19 deletions nntrainer/tensor/weight.cpp
@@ -90,34 +90,24 @@ Weight::Weight(const TensorDim &dim_v, const TensorDim &dim_g,
}
}

Weight::Weight(const Tensor &v, const Tensor &g, const std::string &n,
bool is_dependent, unsigned int output_axis_) :
Weight::Weight(const Tensor &v, const Tensor &g, const Tensor &v32,
const std::string &n, bool is_dependent,
unsigned int output_axis_) :
Var_Grad(v, g, n, is_dependent),
regularizer(WeightRegularizer::NONE),
regularizer_constant(1.0f),
decay(0.0f),
clip_by_global_norm(0.0f),
output_axis(output_axis_),
loss_scale(0.0) {
loss_scale(0.0),
var32(std::make_shared<Tensor>(n + ":fp32")) {

std::string var32_suffix = ":fp32";
std::string var32_name = n + var32_suffix;

/**
* @note We assume here that Weight is created with variable and gradient
* tensor. It is not copy or clone and, therefore, we do need create var32 if
* it is trainable. For now, We haven't seen the case create wieght with var,
* grad and var32. But we will add weight constructor if there is the cases.
*/

if (!g.empty() && v.getDataType() != ml::train::TensorDim::DataType::FP32) {
if (!g.empty() && isMixedPrecision()) {
TensorDim var32_dim(v.getDim());
var32_dim.setDataType(ml::train::TensorDim::DataType::FP32);

var32 = std::make_shared<Tensor>(var32_dim, true, Tensor::Initializer::NONE,
var32_name);
} else {
var32 = std::make_shared<Tensor>(var32_name);
if (!v32.empty())
var32 = std::make_shared<Tensor>(
v32.getSharedDataTensor(var32_dim, 0, false, n + ":fp32"));
}
}
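
To tie the pieces together, a hedged sketch of how a mixed-precision weight
built this way might be used during an update (not code from this PR; the
optimizer step itself is elided, and it relies on the getVariableFP32Ref()
accessor added in weight.h below):

```cpp
// Sketch only: run the update on the fp32 master copy, then write it back
// to the low-precision variable. Illustrative, not part of this commit.
#include <weight.h>

void applyMasterUpdate(nntrainer::Weight &w) {
  if (!w.isMixedPrecision())
    return;

  nntrainer::Tensor &master = w.getVariableFP32Ref(); // fp32 copy of the variable
  // ... optimizer applies the gradient to `master` in full precision here ...

  // Copy the updated values back into the low-precision weight variable.
  w.getVariableRef().copyData(master);
}
```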

15 changes: 12 additions & 3 deletions nntrainer/tensor/weight.h
@@ -114,6 +114,7 @@ class Weight : public Var_Grad {
*
* @param v Already created variable object
* @param g Already created gradient object
* @param v32 Already created variable fp32 object
* @param n Name for this Weight
*
* @note This is primarily used to create a wrapper of variable extracted from
@@ -123,8 +124,9 @@ class Weight : public Var_Grad {
* uses only, as Weight does not own the tensors v and g, and can go invalid
* if the owner of these tensors frees the tensors.
*/
explicit Weight(const Tensor &v, const Tensor &g, const std::string &n = "",
bool is_dependent = false, unsigned int output_axis_ = 3);
explicit Weight(const Tensor &v, const Tensor &g, const Tensor &v32,
const std::string &n = "", bool is_dependent = false,
unsigned int output_axis_ = 3);

/**
* @brief Construct a new Weight object
@@ -324,7 +326,7 @@ class Weight : public Var_Grad {
* @return false otherwise
*/
bool isMixedPrecision() const {
return var->getDataType() == ml::train::TensorDim::DataType::FP32;
return var->getDataType() != ml::train::TensorDim::DataType::FP32;
}

/**
@@ -337,6 +339,13 @@ class Weight : public Var_Grad {
grad->multiply_i(clip_by_global_norm / (global_norm + epsilon));
}

/**
* @brief Get the variable FP32 tensor (by reference)
*
* @return Tensor Variable FP32 tensor
*/
Tensor &getVariableFP32Ref() { return *var32.get(); }

private:
static constexpr float epsilon = 1e-6; /**< epsilon for zero comparison */
static constexpr float epsilon_decay =