Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lvardt cell CvMembList.ml.size() should be the number of contiguous regions of the fixed step thread Memb_list #3299

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion src/nrncvode/cvodeobj.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ struct model_sorted_token;
* contiguous
* - with ml.size() >= 1 and ml[i].nodecount == 1 when non-contiguous instances need to be processed
*
* generic configurations with ml.size() and ml[i].nodecount both larger than one are not supported.
* generic configurations with ml.size() and ml[i].nodecount both larger than one are only
* supported for the local variable time step method.
*/
struct CvMembList {
CvMembList(int type)
Expand Down
112 changes: 80 additions & 32 deletions src/nrncvode/netcvode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1615,6 +1615,20 @@ bool NetCvode::init_global() {
}
}

// Modified to also count the nodes and set the offsets for
// each CvMembList.ml[contig_region].
// The sum of the ml[i].nodecount must equal the mechanism
// nodecount for the cell and each ml[i] data must be contiguous.
// Ideally the node permutation would be such that each cell
// is contiguous. So only needing a ml[0]. That is sadly not
// the case with the default permutation. The cell root nodes are
// all at the beginning, and thereafter only Section nodes are
// contiguous. It would be easy to permute nodes so that each cell
// is contiguous (except root node). This would result in a
// CvMembList.ml.size() == 1 almost always with an exception of
// size() == 2 only for extracellular and for POINT_PROCESSes
// located both in the root node and other cell nodes.

for (NrnThreadMembList* tml = _nt->tml; tml; tml = tml->next) {
i = tml->index;
const Memb_func& mf = memb_func[i];
Expand All @@ -1625,42 +1639,57 @@ bool NetCvode::init_global() {
// singly linked list built below
int j;
for (j = 0; j < ml->nodecount; ++j) {
auto offset = ml->get_storage_offset() + j;
// for each Memb_list instance constructed, keep
// track of its initial storage offset (i.e. offset)

int inode = ml->nodelist[j]->v_node_index;
int icell = cellnum[inode];
Cvode& cv = d.lcv_[cellnum[inode]];
CvodeThreadData& z = cv.ctd_[0];
if (!z.cv_memb_list_) {

// Circumstances for creating a new CvMembList
// or (due to non-contiguity of a cell),
// appending a Memb_list instance to cml->ml
if (!z.cv_memb_list_) { // initialize the first
cml = new CvMembList{i};
cml->next = nullptr;
assert(cml->ml.size() == 1);
cml->ml[0].nodecount = 0;
z.cv_memb_list_ = cml;
cml->next = nullptr;
last[cellnum[inode]] = cml;
}
if (last[cellnum[inode]]->index == i) {
assert(last[cellnum[inode]]->ml.size() == 1);
++last[cellnum[inode]]->ml[0].nodecount;
} else {
assert(cml->ml.size() == 1);
assert(cml->ml[0].nodecount == 0);
} else if (last[cellnum[inode]]->index != i) { // initialize next
cml = new CvMembList{i};
last[cellnum[inode]]->next = cml;
cml->next = nullptr;
last[cellnum[inode]] = cml;
assert(cml->ml.size() == 1);
cml->ml[0].nodecount = 1;
assert(cml->ml[0].nodecount == 0);
} else { // if non-contiguous, append Memb_list
cml = last[cellnum[inode]];
auto& cvml = cml->ml.back();
auto cvml_offset = cvml.get_storage_offset() + cvml.nodecount;
if (cvml_offset != offset) {
// not contiguous, add another Memb_list
// instance to cml->ml
cml->ml.emplace_back(cml->index);
assert(cml->ml.back().nodecount == 0);
}
}

auto& cvml = cml->ml.back();
if (cvml.nodecount == 0) { // first time for this Memb_List
cvml.set_storage_offset(offset);
}
// Increment count of last Memb_list in cml->ml.
++cvml.nodecount;
}
}
}
// allocate and re-initialize count

std::vector<CvMembList*> cvml(d.nlcv_);
for (i = 0; i < d.nlcv_; ++i) {
cvml[i] = d.lcv_[i].ctd_[0].cv_memb_list_;
for (cml = cvml[i]; cml; cml = cml->next) {
// non-contiguous mode, so we're going to create a lot of 1-element Memb_list
// inside cml->ml
cml->ml.reserve(cml->ml[0].nodecount);
// remove the single entry from contiguous mode
cml->ml.clear();
}
cvml[i] = d.lcv_[i].ctd_[0].cv_memb_list_; // whole cell in thread
}
// fill pointers (and nodecount)
// now list order is from 0 to n_memb_func
Expand All @@ -1670,24 +1699,43 @@ bool NetCvode::init_global() {
Memb_list* ml = tml->ml;
if (ml->nodecount && (mf.current || mf.ode_count || mf.ode_matsol || mf.ode_spec ||
mf.state || i == CAP || ba_candidate.count(i) == 1)) {
for (int j = 0; j < ml->nodecount; ++j) {
int increment = 1; // newml.nodecount is handled in the newml loop below
for (int j = 0; j < ml->nodecount; j += increment) {
int icell = cellnum[ml->nodelist[j]->v_node_index];
if (cvml[icell]->index != i) {
cvml[icell] = cvml[icell]->next;
assert(cvml[icell] && cvml[icell]->index);
assert(cvml[icell] && cvml[icell]->index == i);
}
cml = cvml[icell];
auto& newml = cml->ml.emplace_back(cml->index /* mechanism type */);
newml.nodecount = 1;
newml.nodelist = new Node*[1];
newml.nodelist[0] = ml->nodelist[j];
newml.nodeindices = new int[1]{ml->nodeindices[j]};
newml.prop = new Prop* [1] { ml->prop[j] };
if (!mf.hoc_mech) {
newml.set_storage_offset(ml->get_storage_offset() + j);
newml.pdata = new Datum* [1] { ml->pdata[j] };
auto& cml = cvml[icell];
increment = 1;
for (auto& newml: cml->ml) {
if (!newml.nodelist) {
auto nodecount = newml.nodecount;
// do nodecount of these for ml and then
// skip forward by nodecount in the outer
// ml->nodecount j loop (i.e. a contiguity
// region)
increment = nodecount;
newml.nodelist = new Node*[nodecount];
newml.nodeindices = new int[nodecount];
newml.prop = new Prop*[nodecount];
if (!mf.hoc_mech) {
newml.pdata = new Datum*[nodecount];
}
for (int k = 0; k < nodecount; ++k) {
newml.nodelist[k] = ml->nodelist[j + k];
newml.nodeindices[k] = ml->nodeindices[j + k];
assert(cellnum[newml.nodeindices[k]] ==
cellnum[ml->nodeindices[j]]);
newml.prop[k] = ml->prop[j + k];
if (!mf.hoc_mech) {
newml.pdata[k] = ml->pdata[j + k];
}
}
newml._thread = ml->_thread;
break;
}
}
newml._thread = ml->_thread;
}
}
}
Expand Down
20 changes: 11 additions & 9 deletions src/nrncvode/occvode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,6 @@ printf("%d Cvode::init_eqn id=%d neq_v_=%d #nonvint=%d #nonvint_extra=%d nvsize=
zneq_cap_v = 0;
if (z.cmlcap_) {
for (auto& ml: z.cmlcap_->ml) {
// support `1 x n` and `n x 1` but not `n x m`
assert(z.cmlcap_->ml.size() == 1 || ml.nodecount == 1);
zneq_cap_v += ml.nodecount;
}
}
Expand Down Expand Up @@ -207,13 +205,17 @@ printf("%d Cvode::init_eqn id=%d neq_v_=%d #nonvint=%d #nonvint_extra=%d nvsize=
// sentinal values for determining no_cap
NODERHS(z.v_node_[i]) = 1.;
}
for (i = 0; i < zneq_cap_v; ++i) {
auto* const node = z.cmlcap_->ml.size() == 1 ? z.cmlcap_->ml[0].nodelist[i]
: z.cmlcap_->ml[i].nodelist[0];
z.pv_[i] = node->v_handle();
z.pvdot_[i] = node->rhs_handle();
*z.pvdot_[i] = 0.; // only ones = 1 are no_cap
}
i = 0;
if (zneq_cap_v)
for (auto& ml: z.cmlcap_->ml) {
for (int j = 0; j < ml.nodecount; ++j) {
auto* const node = ml.nodelist[j];
z.pv_[i] = node->v_handle();
z.pvdot_[i] = node->rhs_handle();
*z.pvdot_[i] = 0.; // only ones = 1 are no_cap
++i;
}
}

// the remainder are no_cap nodes
if (z.no_cap_node_) {
Expand Down
17 changes: 17 additions & 0 deletions test/cover/test_netcvode.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
cv = h.CVode()
pc = h.ParallelContext()


# remove address info from cv.debug_event output
def debug_event_filter(s):
s = re.sub(r"cvode_0x[0-9abcdef]* ", "cvode_0x... ", s)
Expand Down Expand Up @@ -762,6 +763,21 @@ def nc_event_before_init():
expect_err("nc.event(0)") # nrn_assert triggered if outside of finitialize


def contiguous():
# cover a couple of lines in the lvardt part of init_global().
# Need same POINT_PROCESS type in rootnode and somewhere else on cell
net = Net(8)
syns = [h.ExpSyn(seg) for cell in net.cells for seg in cell.soma.allseg()]
cv.active(1)
cv.use_local_dt(1)
h.finitialize(-65)
pc.nthread(2)
h.finitialize(-65)
pc.nthread(1)
cv.use_local_dt(0)
cv.active(0)


def test_netcvode_cover():
nrn_use_daspk()
node()
Expand All @@ -774,6 +790,7 @@ def test_netcvode_cover():
scatter_gather()
playrecord()
interthread()
contiguous()
nc_event_before_init()


Expand Down
Loading