Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix arithmetic bug #368

Merged
merged 13 commits into from
Dec 1, 2023
4 changes: 3 additions & 1 deletion .github/workflows/ci-cmake_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ jobs:
- name: Build
shell: bash -l {0}
working-directory: ${{github.workspace}}/build
run: cmake --build . -j `nproc`
run: |
cmake --version
cmake --build . -j `nproc`

- name: Run CTest
shell: bash -l {0}
Expand Down
2 changes: 1 addition & 1 deletion Install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ echo ${FLAG}
mkdir build
cd build
cmake ../ ${FLAG} #-DDEV_MODE=on
make -j`nproc`
make -j${nproc}
make install
#if DRUN_TESTS=ON, run tests
shopt -s nocasematch
Expand Down
21 changes: 16 additions & 5 deletions src/linalg/Add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,12 +612,17 @@ namespace cytnx {
//============================================

cytnx::UniTensor Add(const cytnx::UniTensor &Lt, const cytnx::UniTensor &Rt) {
UniTensor out = Lt.clone();
UniTensor out;
if (Lt.dtype() > Rt.dtype()) {
out = Rt.clone();
out.Add_(Lt);
} else {
out = Lt.clone();
out.Add_(Rt);
}
out.set_labels(vec_range<std::string>(Lt.rank()));
out.set_name("");

jeffry1829 marked this conversation as resolved.
Show resolved Hide resolved
out.Add_(Rt);

return out;
}

Expand All @@ -628,11 +633,17 @@ namespace cytnx {
// cytnx_error_msg(Rt.is_tag(),"[ERROR] cannot perform arithmetic on tagged
// unitensor.%s","\n");

UniTensor out = Rt.clone();
UniTensor out;
if (Scalar(lc).dtype() < Rt.dtype()) {
out = Rt.astype(Scalar(lc).dtype());
out.Add_(lc);
} else {
out = Rt.clone();
out.Add_(lc);
}
// out.set_labels(vec_range<cytnx_int64>(Rt.rank()));
out.set_name("");

out.Add_(lc);
return out;
}

Expand Down
23 changes: 19 additions & 4 deletions src/linalg/Div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,6 +901,9 @@
//===============
cytnx::UniTensor Div(const cytnx::UniTensor &Lt, const cytnx::UniTensor &Rt) {
UniTensor out = Lt.clone();
if (Lt.dtype() > Rt.dtype()) {
out = out.astype(Rt.dtype());
}
out.set_labels(vec_range<std::string>(Lt.rank()));
out.set_name("");

Expand All @@ -914,11 +917,17 @@
// cytnx_error_msg(Rt.is_tag(),"[ERROR] cannot perform arithmetic on tagged
// unitensor.%s","\n");

UniTensor out = Rt.clone();
UniTensor out;

Check warning on line 920 in src/linalg/Div.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Div.cpp#L920

Added line #L920 was not covered by tests
if (Scalar(lc).dtype() < Rt.dtype()) {
out = Rt.astype(Scalar(lc).dtype());
out._impl->lDiv_(lc);
} else {
out = Rt.clone();
out._impl->lDiv_(lc);
}
// out.set_labels(vec_range<cytnx_int64>(Rt.rank()));
out.set_name("");

out._impl->lDiv_(lc);
return out;
}

Expand All @@ -942,11 +951,17 @@
// cytnx_error_msg(Lt.is_tag(),"[ERROR] cannot perform arithmetic on tagged
// unitensor.%s","\n");

UniTensor out = Lt.clone();
UniTensor out;
if (Lt.dtype() > Scalar(rc).dtype()) {
out = Lt.astype(Scalar(rc).dtype());
out.Div_(rc);
} else {
out = Lt.clone();
out.Div_(rc);
}
// out.set_labels(vec_range<cytnx_int64>(Lt.rank()));
out.set_name("");

out.Div_(rc);
return out;
}

Expand Down
21 changes: 16 additions & 5 deletions src/linalg/Mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,12 +661,17 @@ namespace cytnx {
//============================================

UniTensor Mul(const UniTensor &Lt, const UniTensor &Rt) {
UniTensor out = Lt.clone();
UniTensor out;
if (Lt.dtype() > Rt.dtype()) {
out = Rt.clone();
out.Mul_(Lt);
} else {
out = Lt.clone();
out.Mul_(Rt);
}
out.set_labels(vec_range<std::string>(Lt.rank()));
out.set_name("");

out.Mul_(Rt);

return out;
}

Expand All @@ -677,11 +682,17 @@ namespace cytnx {
// cytnx_error_msg(Rt.is_tag(),"[ERROR] cannot perform arithmetic on tagged
// unitensor.%s","\n");

UniTensor out = Rt.clone();
UniTensor out;
if (Scalar(lc).dtype() < Rt.dtype()) {
out = Rt.astype(Scalar(lc).dtype());
out.Mul_(lc);
} else {
out = Rt.clone();
out.Mul_(lc);
}
// out.set_labels(vec_range<cytnx_int64>(Rt.rank()));
out.set_name("");

out.Mul_(lc);
return out;
}

Expand Down
23 changes: 19 additions & 4 deletions src/linalg/Sub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,9 @@
//===============
cytnx::UniTensor Sub(const cytnx::UniTensor &Lt, const cytnx::UniTensor &Rt) {
UniTensor out = Lt.clone();
if (Lt.dtype() > Rt.dtype()) {
out = out.astype(Rt.dtype());
}
out.set_labels(vec_range<std::string>(Lt.rank()));
out.set_name("");

Expand All @@ -926,11 +929,17 @@
// cytnx_error_msg(Rt.is_tag(),"[ERROR] cannot perform arithmetic on tagged
// unitensor.%s","\n");

UniTensor out = Rt.clone();
UniTensor out;

Check warning on line 932 in src/linalg/Sub.cpp

View check run for this annotation

Codecov / codecov/patch

src/linalg/Sub.cpp#L932

Added line #L932 was not covered by tests
if (Scalar(lc).dtype() < Rt.dtype()) {
out = Rt.astype(Scalar(lc).dtype());
out._impl->lSub_(lc);
} else {
out = Rt.clone();
out._impl->lSub_(lc);
}
// out.set_labels(vec_range<cytnx_int64>(Rt.rank()));
out.set_name("");

out._impl->lSub_(lc);
return out;
}

Expand All @@ -954,11 +963,17 @@
// cytnx_error_msg(Lt.is_tag(),"[ERROR] cannot perform arithmetic on tagged
// unitensor.%s","\n");

UniTensor out = Lt.clone();
UniTensor out;
if (Lt.dtype() > Scalar(rc).dtype()) {
out = Lt.astype(Scalar(rc).dtype());
out.Sub_(rc);
} else {
out = Lt.clone();
out.Sub_(rc);
}
// out.set_labels(vec_range<cytnx_int64>(Lt.rank()));
out.set_name("");

out.Sub_(rc);
return out;
}

Expand Down
Loading