diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index d8757c7f9f04..6504541c3b5b 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1276,6 +1276,14 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
                                       NDArrayHandle* var_handles,
                                       uint32_t* reqs_array,
                                       NDArrayHandle* grad_handles);
+/*!
+ * \brief mark nonleaf NDArrays as variables during deferred computation
+ * \param nleaf_handles handles of the nonleaf NDArrays to be marked
+ * \param num_nleafs number of nonleaf NDArrays
+ * \param cnt_var count of existing marked nonleaf variables
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var);
 /*!
  * \brief unmark nonleaf NDArrays to free the memory
  * \param num_var number of variable NDArrays
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index 42876f7bf445..65653cc9a890 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -290,6 +290,8 @@ class Imperative {
   void MarkVariables(const std::vector<NDArray*>& variables,
                      const std::vector<uint32_t>& grad_reqs,
                      const std::vector<NDArray*>& gradients);
+  /*! \brief mark nonleaf variables during deferred compute for computing gradients. */
+  void MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars);
   /*! \brief unmark nonleaf variables to free the memory. */
   void DropGrads(const std::vector<NDArray*>& variables);
   /*! \brief compute the gradient of outputs w.r.t variables. */
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index bed166a9307e..51fe5a9c579a 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -351,6 +351,8 @@ class NDArray {
   bool fresh_out_grad() const;
   /*! \return updated grad state in autograd_entry_ */
   void set_fresh_out_grad(bool state) const;
+  /*! \brief copy the autograd_entry_ from the src NDArray */
+  void copy_autograd_entry_(const NDArray* src);
   /*! \brief Returns true if a sparse ndarray's aux_data and storage are initialized
    * Throws an exception if the indices array shape is inconsistent
    * Returns false if the indices array is empty(nnz = 0) for csr/row_sparse
diff --git a/python/mxnet/_ctypes/cached_op.py b/python/mxnet/_ctypes/cached_op.py
index 509484b7c3e4..fd5d6a9c0c1e 100644
--- a/python/mxnet/_ctypes/cached_op.py
+++ b/python/mxnet/_ctypes/cached_op.py
@@ -77,6 +77,7 @@ def __call__(self, *args, **kwargs):
         if not default_device:
             default_device = kwargs.pop('default_ctx', None)
         out = kwargs.pop('out', None)
+        nleaf_vars = [container.data() for container in kwargs.pop('_nleaf_vars', [])]
         if kwargs:
             raise TypeError(
                 "CachedOp.__call__ got unexpected keyword argument(s): " + \
@@ -93,7 +94,10 @@ def __call__(self, *args, **kwargs):
             *args,
             type_id,
             device_id,
-            *out_arg
+            len(out_arg),
+            *out_arg,
+            len(nleaf_vars),
+            *nleaf_vars
         )
         if out is not None:
             return out
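Note on the frontend change above: `CachedOp.__call__` now prefixes both the output buffers and the marked non-leaf arrays with their counts, so the C++ side can slice the flat FFI argument list unambiguously. A minimal sketch of the packing, with placeholder values standing in for NDArray handles (everything here is illustrative, not the library's internal names):

```python
# Illustrative only: strings stand in for NDArray handles.
inputs = ["in0", "in1"]                  # *args forwarded to the cached op
type_id, device_id = 2, 0                # default device type / id
out_arg = ["out0"]                       # user-supplied output buffers (may be empty)
nleaf_vars = ["nleaf0", "nleaf1"]        # .data() of each registered Intermediate

# New layout: each variable-length group is preceded by its own length.
ffi_args = [*inputs, type_id, device_id,
            len(out_arg), *out_arg,
            len(nleaf_vars), *nleaf_vars]
assert ffi_args == ["in0", "in1", 2, 0, 1, "out0", 2, "nleaf0", "nleaf1"]
```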
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index cff346b9f4aa..21b034a79c0b 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -33,13 +33,14 @@
 import json
 import numpy as np
 
-from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB
+from ..base import mx_real_t, MXNetError, NDArrayHandle, SymbolHandle, py_str, check_call, _LIB, \
+    _as_list
 from .. import symbol, ndarray, initializer, autograd, _deferred_compute as dc, name as _name, \
     profiler as _profiler, device as _device
 from ..symbol.numpy import _symbol as np_symbol
 from ..symbol import Symbol, fromjson
 from ..ndarray import NDArray, get_dtype_name
-from .parameter import Parameter, DeferredInitializationError
+from .parameter import Parameter, DeferredInitializationError, Intermediate
 from .utils import _indent, _brief_print_list, HookHandle, shape_is_known
 from .utils import _check_same_symbol_type, _check_all_np_ndarrays, _check_block_input_np_ndarrays
 from .. import numpy_extension as _mx_npx
@@ -1091,6 +1092,7 @@ def __init__(self):
         self._backend_opts = {}
         self._partition_if_dynamic = True
         self._first_forward = True
+        self._nleaf_vars = OrderedDict()
 
     def __setattr__(self, name, value):
         """Registers parameters."""
@@ -1302,7 +1304,7 @@ def _call_cached_op(self, *args):
         args_without_none = [ele for ele in args if ele is not None]
         cargs = [args_without_none[i] if is_arg else i.data()
                  for is_arg, name, i in self._cached_op_args]
-        out = self._cached_op(*cargs)
+        out = self._cached_op(*cargs, _nleaf_vars=self._nleaf_vars.values())
         if isinstance(out, NDArray):
             out = [out]
         return _regroup(out, self._out_format)
@@ -1678,6 +1680,49 @@ def reset_ctx(self, ctx):
         self.reset_device(ctx)
 
+    def intermediate(self, names, var_arrays_inp, grad_req='write'):
+        """Mark the intermediate variables.
+
+        Parameters
+        ----------
+        names : str or tuple[str], name(s) of the registered intermediate variable(s)
+        var_arrays_inp : ndarray or tuple[ndarray], the output of the expression to mark
+        grad_req : str, gradient request for the marked variables
+        """
+        if not self._active:
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+        else:
+            prev_val = dc.set_deferred_compute(False)
+            var_arrays = _as_list(var_arrays_inp)
+            names = _as_list(names)
+            # Prepare ctypes array types
+            import ctypes
+            var_handles_type = ctypes.c_void_p * len(var_arrays)
+            # Convert handles
+            var_handles = var_handles_type(*[arr.handle for arr in var_arrays])
+            check_call(_LIB.MXNDArrayMarkDCVariables(var_handles, len(var_arrays), len(self._nleaf_vars)))
+            self._nleaf_vars.update(
+                {name : Intermediate(name, array, grad_req) for name, array in zip(names, var_arrays)})
+            dc.set_deferred_compute(prev_val)
+        return var_arrays_inp
+
+    def attach_grad_intermediate(self):
+        """Attach gradient to all the intermediate variables.
+        """
+        for val in self._nleaf_vars.values():
+            val.data().attach_grad(grad_req=val.grad_req)
+
+    def get_intermediate(self, names):
+        """Get the intermediate variables by name.
+        """
+        if isinstance(names, list):
+            return [self._nleaf_vars[n] for n in names]
+        else:
+            return self._nleaf_vars[names]
+
 
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
     as feature extractors. For example, you may want to extract the output
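For reference, a minimal usage sketch of the marking API added to `HybridBlock` above; the block and variable names are made up for illustration, and the flow mirrors the updated unit test at the end of this diff:

```python
import mxnet as mx

class Net(mx.gluon.HybridBlock):
    def forward(self, a, b):
        h = self.intermediate('h', a * b)       # mark a non-leaf result
        return self.intermediate('out', h * a)

net = Net()
net.initialize()
net.hybridize()                                 # the eager (non-hybridized) path works too

x = mx.np.array([1., 2., 3.])
y = mx.np.array([4., 5., 6.])
x.attach_grad()
with mx.autograd.record():
    net(x, y)

net.attach_grad_intermediate()                  # attach grads to every marked variable
h = net.get_intermediate('h').data()
out = net.get_intermediate('out').data()
out.backward(retain_graph=True)
print(h.grad)                                   # d(out)/d(h) == a
print(x.grad)                                   # d(out)/d(x) == 2*x*y
```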
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 1b396490a7fb..8cb4ac56b008 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -773,3 +773,40 @@ def grad_req(self, req):
             warnings.warn('Constant parameter "{}" does not support '
                           'grad_req other than "null", and new value "{}" '
                           'is ignored.'.format(self.name, req))
+
+class Intermediate:
+    """A container holding marked intermediate variables of Blocks.
+
+    Parameters
+    ----------
+    name : str
+        Name of this intermediate variable. It is used to retrieve the marked variables.
+    grad_req : {'write', 'add', 'null'}, default 'write'
+        Specifies how to update gradient to grad arrays.
+
+        - ``'write'`` means every time gradient is written to grad :py:class:`NDArray`.
+        - ``'add'`` means every time gradient is added to the grad :py:class:`NDArray`. You need
+          to manually call ``zero_grad()`` to clear the gradient buffer before each
+          iteration when using this option.
+        - ``'null'`` means gradient is not requested for this variable. Gradient arrays
+          will not be allocated.
+    """
+    def __init__(self, name, data=None, grad_req='write'):
+        self._name = name
+        self._data = data
+        self._grad_req = grad_req
+
+    def __repr__(self):
+        s = 'Intermediate name={name}'
+        return s.format(name=self._name)
+
+    def data(self):
+        return self._data
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def grad_req(self):
+        return self._grad_req
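A short sketch of the `Intermediate` container on its own, assuming it is imported from `mxnet.gluon.parameter` as defined above (inside a block these objects are normally created by `intermediate()` and fetched with `get_intermediate()`):

```python
import mxnet as mx
from mxnet.gluon.parameter import Intermediate

arr = mx.np.zeros((2, 2))
inter = Intermediate('hidden', data=arr, grad_req='add')
print(inter)                        # Intermediate name=hidden
print(inter.name, inter.grad_req)   # hidden add

# This is what attach_grad_intermediate() does for each registered entry:
inter.data().attach_grad(grad_req=inter.grad_req)
```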
+ """ + def __init__(self, name, data=None, grad_req='write'): + self._name = name + self._data = data + self._grad_req = grad_req + + def __repr__(self): + s = 'Intermediate name={name}' + return s.format(name=self._name) + + def data(self): + return self._data + + @property + def name(self): + return self._name + + @property + def grad_req(self): + return self._grad_req diff --git a/src/api/cached_op_api.cc b/src/api/cached_op_api.cc index 79494ea80bcf..fad7e0a93cd8 100644 --- a/src/api/cached_op_api.cc +++ b/src/api/cached_op_api.cc @@ -44,16 +44,18 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke") ndinputs.push_back(static_cast(args[i])); } + int num_outputs = args[num_inputs + 4]; + int num_nleafs = args[num_inputs + num_outputs + 5]; std::vector ndoutputs; ndoutputs.reserve(op->num_outputs()); - if (args[num_inputs + 4].type_code() == kNull) { + if (args[num_inputs + 5].type_code() == kNull) { for (int i = 0; i < op->num_outputs(); ++i) ndoutputs.push_back(new NDArray()); } else { - int array_size = args_size - num_inputs - 4; + int array_size = args_size - num_inputs - num_nleafs - 6; CHECK_EQ(array_size, op->num_outputs()) << "CachedOp expects " << op->num_outputs() << " outputs, but " << array_size << " was given."; - for (int i = num_inputs + 4; i < array_size; ++i) { + for (int i = num_inputs + 5; i < num_inputs + num_outputs + 5; ++i) { ndoutputs.push_back(args[i].operator mxnet::NDArray*()); } } @@ -69,6 +71,13 @@ MXNET_REGISTER_GLOBAL("cached_op.invoke") default_dev_id = ctx.dev_id; } + std::vector nleafs; + nleafs.reserve(num_nleafs); + for (int i = 0; i < num_nleafs; ++i) { + nleafs.push_back(static_cast(args[i + num_inputs + num_outputs + 6])); + } + op->set_nleafs(nleafs); + // construct default context Context ctx = Context::Create(static_cast(default_dev_type), default_dev_id); diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index b91a997b7ce1..bab985c0a8c1 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -495,3 +495,15 @@ int MXNDArrayGetDeferredComputeSymbol(NDArrayHandle* output_handles, *out = s; API_END_HANDLE_ERROR(delete s;); } + +int MXNDArrayMarkDCVariables(NDArrayHandle* nleaf_handles, int num_nleafs, int cnt_var) { + API_BEGIN(); + std::vector nleafs; + nleafs.reserve(num_nleafs); + for (int i = 0; i < num_nleafs; ++i) { + NDArray* array = reinterpret_cast(nleaf_handles[i]); + nleafs.emplace_back(array); + } + Imperative::Get()->MarkDCVariables(nleafs, cnt_var); + API_END(); +} diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 894ef09a1d16..2660ece5221f 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -801,7 +801,8 @@ OpStatePtr CachedOp::DynamicForward(const Context& default_ctx, recording && inlining_, nullptr, monitor_callback_, - monitor_all_); + monitor_all_, + nleafs_); } else { mxnet::ShapeVector shapes = g.GetAttr("shape"); NaiveRunGraph(false, @@ -1063,6 +1064,7 @@ void CachedOp::StaticBackward(const bool retain_graph, if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); + state.array_reqs[eid] = reqs[iter->second]; // An input and an output may share the same array. INIT_DETACHED(outputs[iter->second], arrays[eid]); arrays[eid] = outputs[iter->second]; @@ -1073,6 +1075,7 @@ void CachedOp::StaticBackward(const bool retain_graph, if (!idx.exist(entry.node.get())) continue; auto eid = idx.entry_id(entry); + state.array_reqs[eid] = reqs[i]; // An input and an output may share the same array. 
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index 079a56e20a12..2d4c693b59a1 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -491,6 +491,9 @@ class CachedOp {
   const std::unordered_set<uint32_t>& mutable_input_nodes() const {
     return fwd_graph_.indexed_graph().mutable_input_nodes();
   }
+  void set_nleafs(const std::vector<NDArray*>& nleafs) {
+    nleafs_ = nleafs;
+  }
   virtual std::vector<nnvm::NodeEntry> Gradient(const nnvm::ObjectPtr& node,
                                                 const std::vector<nnvm::NodeEntry>& ograds) const;
   virtual OpStatePtr Forward(const std::shared_ptr<CachedOp>& op_ptr,
@@ -649,6 +652,7 @@ class CachedOp {
   std::vector<uint32_t> bwd_in_dep_, bwd_out_dep_, bwd_ograd_dep_;
   std::vector<bool> save_inputs_, save_outputs_;
   std::vector<OpReqType> bwd_output_reqs_;
+  std::vector<NDArray*> nleafs_;
 
   std::function<void(const char*, const char*, void*)> monitor_callback_{nullptr};
   bool monitor_all_{false};
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index fb123c18c9fc..ec8cdc59cdcd 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -171,6 +171,18 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
   }
 }
 
+void Imperative::MarkDCVariables(const std::vector<NDArray*>& nleafs, int cnt_vars) {
+  for (NDArray* nleaf : nleafs) {
+    if (Imperative::DCInfo::IsNone(*nleaf)) {
+      LOG(WARNING) << "The marked node doesn't have deferred compute history.";
+    } else {
+      nnvm::ObjectPtr node = nleaf->deferredcompute_entry_.node;
+      node->attrs.dict["mark_id"] = std::to_string(cnt_vars);
+    }
+    cnt_vars++;
+  }
+}
+
 // Unmark the variables to free the memory.
 void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
   for (auto variable : variables) {
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
index e3a58804d8ac..40755ef4d85f 100644
--- a/src/imperative/imperative_utils.cc
+++ b/src/imperative/imperative_utils.cc
@@ -138,7 +138,8 @@ void RunGraph(const bool retain_graph,
               bool recording,
              mxnet::ShapeVector* shapes,
              const imperative::CachedOpMonCallback& callback,
-             const bool monitor_all) {
+             const bool monitor_all,
+             const std::vector<NDArray*>& nleafs) {
   CHECK(shapes == nullptr);
   for (size_t i = node_start; i < node_end; ++i) {
     const nnvm::IndexedGraph::Node& node = idx[i];
@@ -166,6 +167,15 @@ void RunGraph(const bool retain_graph,
     if (callback) {
       mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
     }
+    // set the autograd_entry_ in marked nleafs
+    if (nleafs.size()) {
+      auto it = node.source->attrs.dict.find("mark_id");
+      if (it != node.source->attrs.dict.end()) {
+        int mark_id = std::stoi(it->second);
+        CHECK_LT(mark_id, nleafs.size()) << "Mark_id exceeds the nonleaf list size.";
+        nleafs[mark_id]->copy_autograd_entry_(ndoutputs[0]);
+      }
+    }
   }
 }
 
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 7f90528f4793..5fbd5d1c4cef 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -1386,7 +1386,8 @@ void RunGraph(const bool retain_graph,
               bool recording,
              mxnet::ShapeVector* shapes = nullptr,
              const CachedOpMonCallback& callback = nullptr,
-             const bool monitor_all_ = false);
+             const bool monitor_all_ = false,
+             const std::vector<NDArray*>& nleafs = std::vector<NDArray*>());
 
 void NaiveRunGraph(const bool retain_graph,
                    const Context& default_ctx,
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 902880fb1d52..f4cf02c15633 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -513,6 +513,10 @@ void NDArray::set_fresh_out_grad(bool state) const {
   info.fresh_out_grad = state;
 }
 
+void NDArray::copy_autograd_entry_(const NDArray* src) {
+  autograd_entry_ = nnvm::NodeEntry{src->autograd_entry_.node, 0, 0};
+}
+
 #if MXNET_USE_ONEDNN == 1
 
 bool NDArray::Chunk::IsDNNL() const {
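Putting the backend pieces together: `MarkDCVariables` stamps each marked node's attribute dict with a running `mark_id`, and `RunGraph` later uses that index to copy the freshly created autograd entry onto the i-th array registered by the frontend. A toy sketch of that bookkeeping (plain dicts stand in for node attribute maps; everything here is illustrative):

```python
registered = []                       # frontend-side order of marked arrays (nleafs_)

def mark_dc_variables(nodes, cnt_var):
    # Imperative::MarkDCVariables: stamp each node with its running index.
    for node in nodes:
        node["mark_id"] = str(cnt_var)
        cnt_var += 1

first = [{"name": "out1_0"}, {"name": "out1_1"}]
mark_dc_variables(first, cnt_var=len(registered))
registered += first

second = [{"name": "out2"}]
mark_dc_variables(second, cnt_var=len(registered))   # numbering continues at 2
registered += second

# RunGraph side: a node whose attrs carry "mark_id" == i hands its first output's
# autograd entry to registered[i].
assert [n["mark_id"] for n in registered] == ["0", "1", "2"]
```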
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index c48d20479f15..a103d8e917a0 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -533,7 +533,7 @@ def test_retain_grad_drop_grad():
     z.attach_grad()
     out_grad = nd.array([10, 10, 10, 10])
     z.backward(out_grad, retain_graph=True)
-    
+
     assert (u.grad == out_grad * x).asnumpy().all()
     assert (z.grad == out_grad).asnumpy().all()
     assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
@@ -548,39 +548,47 @@
     assert u.grad is None and z.grad is None and y.grad is None
     assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
 
-def test_retain_grad_drop_grad_gluon():
-    class CompBlock(mx.gluon.HybridBlock):
+@pytest.mark.parametrize('hybridize', [True, False])
+def test_retain_grad_drop_grad_gluon(hybridize):
+    class CompBlock(mx.HybridBlock):
         def __init__(self):
             super().__init__()
-            self.marked_var = None
-        def forward(self, a, b):
-            out1 = a*b
-            out2 = out1 * a
-            self.marked_var = out1
+
+        def forward(self, a, b, c):
+            out1 = self.intermediate(('out1_0', 'out1_1'), ((a+b)*c, a*b), grad_req='write')
+            out2 = self.intermediate('out2', out1[1] * a)
             return out2
+
     x = mx.np.array([1,2,3,4])
     y = mx.np.array([5,6,7,8])
+    w = mx.np.array([0.1, 0.1, 0.1, 0.1])
     x.attach_grad()
     y.attach_grad()
+    w.attach_grad()
     block2 = CompBlock()
     block2.initialize()
-    # block2.hybridize()
+    if hybridize:
+        block2.hybridize()
     with mx.autograd.record():
-        z = block2(x, y)
-        u = block2.marked_var
-        u.attach_grad()
-        z.attach_grad()
+        z = block2(x, y, w)
+
+    block2.attach_grad_intermediate()
+    u0 = block2.get_intermediate('out1_0').data()
+    u = block2.get_intermediate('out1_1').data()
+    z = block2.get_intermediate('out2').data()
     z.backward(retain_graph=True)
 
     assert (u.grad == x).all()
+    assert (u0.grad == mx.np.array([0, 0, 0, 0])).all()
    assert (z.grad == mx.np.array([1,1,1,1])).all()
     assert (x.grad == 2 * x * y).all()
     assert (y.grad == x*x).all()
 
     u.drop_grad()
+    u0.drop_grad()
     z.drop_grad()
     y.drop_grad()
     z.backward()
 
-    assert u.grad is None and z.grad is None and y.grad is None
+    assert u.grad is None and u0.grad is None and y.grad is None and z.grad is None
     assert (x.grad == 2 * x * y).all()