[WIP] Implement backward_all as a counterpart of forward_all #340

Open · wants to merge 6 commits into master
7 changes: 7 additions & 0 deletions include/nbla/computation_graph/computation_graph.hpp
@@ -55,5 +55,12 @@ NBLA_API void steal_variable_from_to(CgVariablePtr from, CgVariablePtr to);
*/
NBLA_API void forward_all(const vector<CgVariablePtr> variables,
bool clear_no_need_grad = false);

/** Backward all given variables in a single call.
*/
NBLA_API void backward_all(
const vector<CgVariablePtr> variables, bool clear_buffer = false,
const vector<CommunicatorBackwardCallbackPtr> communicator_callbacks = {});
}
#endif
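For orientation, a minimal usage sketch of the new API through the Python binding added later in this diff. This is illustrative only: it assumes an nnabla build with this PR applied, the two-output graph is made up for the example, and since this WIP backward_all does not appear to fill the root gradients the way Variable.backward does, the sketch fills them explicitly.

import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF

# Two outputs sharing one affine layer.
x = nn.Variable([4, 8], need_grad=True)
x.d = np.random.randn(4, 8)
h = PF.affine(x, 16, name='fc')
y1 = F.mean(F.relu(h))
y2 = F.mean(F.sigmoid(h))

# One forward pass over the shared graph for both outputs.
nn.forward_all([y1, y2])

# Gradients accumulate in nnabla, so zero them before backward.
x.grad.zero()
for p in nn.get_parameters().values():
    p.grad.zero()

# Fill the root gradients, then backprop both outputs in a single traversal.
y1.grad.fill(1)
y2.grad.fill(1)
nn.backward_all([y1, y2], clear_buffer=True)
print(x.g)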
31 changes: 27 additions & 4 deletions include/nbla/computation_graph/variable.hpp
@@ -42,6 +42,13 @@ struct CommunicatorBackwardCallback {
typedef shared_ptr<CommunicatorBackwardCallback>
CommunicatorBackwardCallbackPtr;

/** Visit functions backward from the given root functions to compute gradients.
*/
void visit_function_backward(
vector<CgFunctionPtr> roots,
std::function<void(CgFunctionPtr)> backward_callback,
vector<CommunicatorBackwardCallbackPtr> communicator_callbacks);

/** Computation graph variable.

A Variable object is held in this object as a data container. In addition,
@@ -75,10 +82,6 @@ class CgVariable {
unordered_set<CgFunctionPtr> &fclosed,
std::function<void(CgFunctionPtr)> forward_callback);

void visit_function_backward(
CgFunctionPtr func, std::function<void(CgFunctionPtr)> backward_callback,
vector<CommunicatorBackwardCallbackPtr> communicator_callbacks);

public:
typedef shared_ptr<CgVariable> Ptr;

@@ -298,5 +301,25 @@ class CgVariable {
/** shared_ptr typedef of CGVariable
*/
typedef CgVariable::Ptr CgVariablePtr;

/** Callback invoked for each function during backward graph traversal.
*/
class BackwardCallback {
bool clear_buffer_;
unordered_map<CgVariablePtr, bool> vseen_;
vector<bool> get_accum(const vector<CgVariablePtr> &inputs,
const vector<bool> &first_visit_flags);
void force_zero_grad_if_unseen(vector<CgVariablePtr> outputs,
const vector<bool> &first_visit);
void clear_output_buffers(CgFunctionPtr func,
const vector<bool> &prohibit_clear);
pair<vector<bool>, vector<bool>> query_outputs_flags(
const vector<CgVariablePtr> &outputs);
vector<bool> query_input_flags(const vector<CgVariablePtr> &inputs,
CgFunctionPtr func);
public:
BackwardCallback(vector<CgFunctionPtr> roots, bool clear_buffer);
void operator()(CgFunctionPtr f);
};
}
#endif
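With this change, visit_function_backward becomes a free function that starts from several root functions at once, and BackwardCallback bundles the per-function bookkeeping (accumulation flags, zeroing gradients of outputs not seen before, clearing buffers). As a rough sketch of such a multi-root, dependency-ordered visit, the following Python is conceptual only: it assumes each function exposes .inputs and each variable .parent, and it omits the rank-based scheduling and communicator callbacks of the real C++.

from collections import defaultdict, deque

def visit_function_backward(roots, backward_callback):
    # Collect every function reachable from the root functions toward the inputs.
    funcs, stack = set(), [f for f in roots if f is not None]
    while stack:
        f = stack.pop()
        if f in funcs:
            continue
        funcs.add(f)
        for inp in f.inputs:
            if inp.parent is not None:
                stack.append(inp.parent)

    # pending[f] = number of collected functions that consume an output of f.
    # f may only be backpropagated after all of those consumers are done.
    pending = defaultdict(int)
    for f in funcs:
        for inp in f.inputs:
            if inp.parent in funcs:
                pending[inp.parent] += 1

    ready = deque(f for f in funcs if pending[f] == 0)
    while ready:
        f = ready.popleft()
        backward_callback(f)  # in the C++ code this role is played by BackwardCallback::operator()
        for inp in f.inputs:
            p = inp.parent
            if p in funcs:
                pending[p] -= 1
                if pending[p] == 0:
                    ready.append(p)

# Minimal demo with stand-in node classes (hypothetical, for illustration only).
class Var:
    def __init__(self, parent=None):
        self.parent = parent

class Func:
    def __init__(self, name, inputs):
        self.name, self.inputs = name, inputs

x = Var()
f1 = Func('affine', [x]); h = Var(f1)
f2 = Func('relu', [h]); y1 = Var(f2)
f3 = Func('sigmoid', [h]); y2 = Var(f3)
visit_function_backward([y1.parent, y2.parent], lambda f: print('backward:', f.name))
# Prints relu and sigmoid (in some order) before affine.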
2 changes: 1 addition & 1 deletion python/src/nnabla/__init__.py
@@ -37,7 +37,7 @@
from .context import (
context_scope, set_default_context, get_current_context)
from .auto_forward import auto_forward, set_auto_forward, get_auto_forward
from ._computation_graph import forward_all
from ._computation_graph import forward_all, backward_all

# Prefer cached array by default for performance.
prefer_cached_array(True)
4 changes: 4 additions & 0 deletions python/src/nnabla/_computation_graph.pxd
@@ -28,3 +28,7 @@ cdef extern from "nbla/computation_graph/computation_graph.hpp" namespace "nbla"
cpp_bool) except+
void steal_variable_from_to(CgVariablePtr f, CgVariablePtr t) except+
void forward_all(const vector[CgVariablePtr] &, cpp_bool) nogil except+
void backward_all(
const vector[CgVariablePtr] &,
cpp_bool,
vector[CommunicatorBackwardCallbackPtr]) nogil except+
24 changes: 24 additions & 0 deletions python/src/nnabla/_computation_graph.pyx
@@ -15,7 +15,11 @@
from libcpp cimport bool as cpp_bool
from libcpp.vector cimport vector
from _variable cimport Variable as _Variable
from _variable cimport CommunicatorBackwardCallback
from _variable cimport CommunicatorBackwardCallbackPtr
from _variable cimport CgVariable as _CgVariable
from _computation_graph cimport forward_all as cforward_all
from _computation_graph cimport backward_all as cbackward_all


def forward_all(variables, cpp_bool clear_no_need_grad=False):
@@ -28,3 +32,23 @@ def forward_all(variables, cpp_bool clear_no_need_grad=False):
cg_variables[i] = (<_Variable?> variables[i]).var
with nogil:
cforward_all(cg_variables, clear_no_need_grad)


def backward_all(variables, cpp_bool clear_buffer=False, communicator_callbacks=None):
cdef vector[CommunicatorBackwardCallbackPtr] callback_list
if type(communicator_callbacks) == list:
for x in communicator_callbacks:
callback_list.push_back((< CommunicatorBackwardCallback?> x).var)
elif type(communicator_callbacks) != type(None):
callback_list.push_back((< CommunicatorBackwardCallback?> communicator_callbacks).var)

cdef vector[CgVariablePtr] cg_variables
cdef int i
cdef int size
size = len(variables)
cg_variables.resize(size)
for i in range(size):
cg_variables[i] = (<_Variable?> variables[i]).var

with nogil:
cbackward_all(cg_variables, clear_buffer, callback_list)
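The wrapper follows Variable.backward in how it accepts communicator_callbacks: None, a single callback object, or a list of callbacks. A plain-Python sketch of that normalization, with a hypothetical helper name used only to make the dispatch explicit:

def normalize_callbacks(communicator_callbacks):
    # None -> empty list; single object -> one-element list; list -> as given.
    if communicator_callbacks is None:
        return []
    if isinstance(communicator_callbacks, list):
        return list(communicator_callbacks)
    return [communicator_callbacks]

assert normalize_callbacks(None) == []
assert normalize_callbacks('cb') == ['cb']
assert normalize_callbacks(['a', 'b']) == ['a', 'b']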
@@ -52,11 +52,20 @@ def test_graph_logreg(seed):
L2 = F.mean(l2)
nn.forward_all([L1, L2])

def zero_grad():
x.g = 0
w1.g = 0
w2.g = 0
b1.g = 0
b2.g = 0

def backup_grads():
grads = [x.g, w1.g, w2.g, b1.g, b2.g]
# Copy eagerly; a lazy map would take the copies only after the
# gradients have been overwritten by later backward calls.
return [v.copy() for v in grads]

# Backprop for z1
# Diff should be initialized since they are always accumulated
x.g = 0
w1.g = 0
b1.g = 0
zero_grad()
L1.backward(clear_buffer=True)

inputs = [x, w1, b1]
@@ -68,9 +77,7 @@

# Backprop for z2
# Diff should be initialized since they are always accumulated
x.g = 0
w2.g = 0
b2.g = 0
zero_grad()
L2.backward(clear_buffer=True)

inputs = [x, w2, b2]
@@ -80,6 +87,16 @@
agrad, ngrad = grads(L2, inputs, 1e-3, False)
assert np.allclose(ngrad, agrad, atol=1e-2)

zero_grad()
L1.backward(clear_buffer=True)
L2.backward(clear_buffer=True)
grad1 = backup_grads()
zero_grad()
nn.backward_all([L1, L2], clear_buffer=True)
grad2 = backup_grads()
for g1, g2 in zip(grad1, grad2):
assert np.allclose(g1, g2)


@pytest.mark.parametrize("seed", [311])
@pytest.mark.parametrize("model", ["mlp", "recurrent", "convolution"])
@@ -157,6 +174,15 @@ def test_graph_model(model, seed):
agrad, ngrad = grads(L2, inputs, 1e-3, False)
assert np.allclose(ngrad, agrad, atol=1.05e-2)

# test backward_all
initialize_grad(parameters)
L1.backward(clear_buffer=False)
L2.backward(clear_buffer=True)
backup_grads = {k: v.g.copy() for k, v in parameters.items()}
initialize_grad(parameters)
nn.backward_all([L1, L2], clear_buffer=True)
for k, g in backup_grads.items():
assert np.allclose(parameters[k].g, g)


@pytest.mark.parametrize("seed", [311])
def test_graph_unlink_backward(seed):
@@ -186,6 +212,11 @@ def test_graph_unlink_backward(seed):
assert np.all(x0.g == 0)
assert not np.all(x1.g == 0)

# test backward_all
nn.backward_all([y1, y2], clear_buffer=True)
assert np.all(x0.g == 0)
assert not np.all(x1.g == 0)


@pytest.mark.parametrize("seed", [311])
def test_graph_clear_buffer(seed):
@@ -224,11 +255,7 @@ def test_graph_clear_buffer(seed):
for v in nn.get_parameters().values():
v.grad.zero()
nn.forward_all([L1, L2], clear_no_need_grad=cnng)

# for now, the first backward cannot be
# called with clear_buffer=True
L1.backward(clear_buffer=False)
L2.backward(clear_buffer=cb)
nn.backward_all([L1, L2], clear_buffer=cb)
if not first:
first = True
g = list(nn.get_parameters().values())[0].g.copy()
@@ -308,3 +335,14 @@ def backup_params():
assert np.allclose(xa.d, xc.d)
for b, c in zip(gb, gc):
assert np.allclose(b, c)

# test backward_all
zero_grad()
nn.backward_all([yb1, yb2], clear_buffer=True)
gb = backup_params()
zero_grad()
nn.backward_all([yc1, yc2], clear_buffer=True)
gc = backup_params()
assert np.allclose(xa.d, xc.d)
for b, c in zip(gb, gc):
assert np.allclose(b, c)
26 changes: 26 additions & 0 deletions src/nbla/computation_graph/computation_graph.cpp
@@ -142,4 +142,30 @@ void forward_all(const vector<CgVariablePtr> variables,
variables[i]->forward(false, clear_no_need_grad, &fclosed);
}
}

void backward_all(const vector<CgVariablePtr> variables, bool clear_buffer,
const vector<CommunicatorBackwardCallbackPtr> communicator_callbacks) {
// Back up the root gradients and restore them when this scope exits,
// even if an exception is thrown during backward.
vector<NdArrayPtr> bak_grads;
DestructorCallback at_scope_exit([&]() {
for (size_t i = 0; i < bak_grads.size(); ++i) {
variables[i]->variable()->set_grad(bak_grads[i]);
}
});
vector<CgFunctionPtr> roots;
for (auto v : variables) {
// Back up the gradient array of each root variable.
bak_grads.push_back(v->variable()->grad());
// Collect the parent function of each root as a traversal root.
roots.push_back(v->parent());
}

// Create callback
BackwardCallback backward_callback(roots, clear_buffer);

// Visit backward
visit_function_backward(
roots, [&backward_callback](CgFunctionPtr f) { backward_callback(f); },
communicator_callbacks);
}
}
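The DestructorCallback guard backs up the gradient array of each root variable and puts it back when backward_all returns, including on an exception path. The same scope-exit pattern in generic Python, with stand-in classes rather than nnabla API:

from contextlib import contextmanager

@contextmanager
def restore_attr(objs, name):
    # Back up attribute `name` on entry and restore it on exit, even if the body raises.
    backups = [(o, getattr(o, name)) for o in objs]
    try:
        yield
    finally:
        for o, value in backups:
            setattr(o, name, value)

class Box:  # stand-in for a variable holding a gradient buffer
    def __init__(self, grad):
        self.grad = grad

roots = [Box([1.0]), Box([2.0])]
with restore_attr(roots, 'grad'):
    roots[0].grad = [0.0]   # whatever happens to the root grads inside...
print(roots[0].grad)        # ...the original arrays are back afterwards: [1.0]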