Skip to content

Commit

Permalink
Merge branch 'master' into mfem_46_dev
Browse files Browse the repository at this point in the history
  • Loading branch information
sshiraiwa authored Jan 9, 2024
2 parents cf91ff9 + ce1ba0b commit 5e7e156
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions mfem/common/array_instantiation_macro.i
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ INSTANTIATE_ARRAY0(XXX, XXX, 0)
%define INSTANTIATE_ARRAY_BOOL
%template(boolArray) mfem::Array<bool>;
%extend mfem::Array<bool> {

PyObject * __getitem__(PyObject* param) {
int len = self->Size();
if (PySlice_Check(param)) {
Expand Down Expand Up @@ -306,22 +307,24 @@ INSTANTIATE_ARRAY0(XXX, XXX, 0)
%define INSTANTIATE_ARRAY_NUMPYARRAY(XXX, YYY, NP_TYPE)
%template(##XXX##Array) mfem::Array<YYY>;
%extend mfem::Array<YYY> {

PyObject * __getitem__(PyObject* param) {
int len = self->Size();
int len = self->Size();
if (PySlice_Check(param)) {
long start = 0, stop = 0, step = 0, slicelength = 0;
int check;

//%#ifdef TARGET_PY3
check = PySlice_GetIndicesEx(param, len, &start, &stop, &step,
&slicelength);
//%#ifdef TARGET_PY3
check = PySlice_GetIndicesEx(param, len, &start, &stop, &step,
&slicelength);
//%#else
//check = PySlice_GetIndicesEx((PySliceObject*)param, len, &start, &stop, &step,
// &slicelength);
//%#endif
//check = PySlice_GetIndicesEx((PySliceObject*)param, len, &start, &stop, &step,
// &slicelength);
//%#endif

if (check == -1) {
if (check == -1) {
PyErr_SetString(PyExc_ValueError, "Slicing mfem::Array<bool> failed.");

return NULL;
}
if (step == 1) {
Expand All @@ -332,12 +335,13 @@ INSTANTIATE_ARRAY0(XXX, XXX, 0)
PyErr_SetString(PyExc_ValueError, "Slicing mfem::Array<T> with stride>1 not supported.");
return NULL;
}

} else {
PyErr_Clear();
long idx = PyInt_AsLong(param);
if (PyErr_Occurred()) {
PyErr_SetString(PyExc_ValueError, "Argument must be either int or slice");
return NULL;
return NULL;
}
PyObject *np_val = NULL;
PyArray_Descr *descr = NULL;
Expand All @@ -352,10 +356,18 @@ INSTANTIATE_ARRAY0(XXX, XXX, 0)
} else {
data_ptr = &(self->operator[](idx+len));
}

np_val = PyArray_Scalar(data_ptr, descr, NULL);
return np_val;
}
}
PyObject* GetDataArray(void) const{
const YYY * A = self->GetData();
int L = self->Size();
npy_intp dims[] = {L};
return PyArray_SimpleNewFromData(1, dims, NP_TYPE, (void *)A);
}

};
%enddef

Expand Down

0 comments on commit 5e7e156

Please sign in to comment.