Skip to content

Commit

Permalink
Merge branch 'main' into gx2f-update-params
Browse files Browse the repository at this point in the history
  • Loading branch information
AJPfleger authored Nov 11, 2024
2 parents 89ff33a + bf3faa3 commit b081d14
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 31 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/checks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,16 @@ jobs:
- name: Check
run: >
CI/check_spelling
math_macros:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Check
run: >
CI/check_math_macros.py . --exclude "thirdparty/*"
missing_includes:
runs-on: ubuntu-latest
steps:
Expand Down
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,11 @@ repos:
name: Leftover conflict markers
language: system
entry: git diff --staged --check

- repo: local
hooks:
- id: math_macros
name: math_macros
language: system
entry: CI/check_math_macros.py
files: \.(cpp|hpp|ipp|cu|cuh)$
120 changes: 120 additions & 0 deletions CI/check_math_macros.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/usr/bin/env python3

from pathlib import Path
import os
import argparse
from fnmatch import fnmatch
import re
import sys


math_constants = [
("M_PI", "std::numbers::pi"),
("M_PI_2", "std::numbers::pi / 2."),
("M_PI_4", "std::numbers::pi / 4."),
("M_1_PI", "std::numbers::inv_pi"),
("M_2_PI", "2. * std::numbers::inv_pi"),
("M_2_SQRTPI", "2. * std::numbers::inv_sqrtpi"),
("M_E", "std::numbers::e"),
("M_LOG2E", "std::numbers::log2e"),
("M_LOG10E", "std::numbers::log10e"),
("M_LN2", "std::numbers::ln2"),
("M_LN10", "std::numbers::ln10"),
("M_SQRT2", "std::numbers::sqrt2"),
("M_SQRT1_2", "1. / std::numbers::sqrt2"),
("M_SQRT3", "std::numbers::sqrt3"),
("M_INV_SQRT3", "std::numbers::inv_sqrt3"),
("M_EGAMMA", "std::numbers::egamma"),
("M_PHI", "std::numbers::phi"),
]


github = "GITHUB_ACTIONS" in os.environ


def handle_file(
file: Path, fix: bool, math_const: tuple[str, str]
) -> list[tuple[int, str]]:
ex = re.compile(rf"(?<!\w){math_const[0]}(?!\w)")

content = file.read_text()
lines = content.splitlines()

changed_lines = []

for i, oline in enumerate(lines):
line, n_subs = ex.subn(rf"{math_const[1]}", oline)
lines[i] = line
if n_subs > 0:
changed_lines.append((i, oline))

if fix and len(changed_lines) > 0:
file.write_text("\n".join(lines) + "\n")

return changed_lines


def main():
p = argparse.ArgumentParser()
p.add_argument("input", nargs="+")
p.add_argument("--fix", action="store_true", help="Attempt to fix M_* macros.")
p.add_argument("--exclude", "-e", action="append", default=[])

args = p.parse_args()

exit_code = 0

inputs = []

if len(args.input) == 1 and os.path.isdir(args.input[0]):
# walk over all files
for root, _, files in os.walk(args.input[0]):
root = Path(root)
for filename in files:
# get the full path of the file
filepath = root / filename
if filepath.suffix not in (
".hpp",
".cpp",
".ipp",
".h",
".C",
".c",
".cu",
".cuh",
):
continue

if any([fnmatch(str(filepath), e) for e in args.exclude]):
continue

inputs.append(filepath)
else:
for file in args.input:
inputs.append(Path(file))

for filepath in inputs:
for math_const in math_constants:
changed_lines = handle_file(
file=filepath, fix=args.fix, math_const=math_const
)
if len(changed_lines) > 0:
exit_code = 1
print()
print(filepath)
for i, oline in changed_lines:
print(f"{i}: {oline}")

if github:
print(
f"::error file={filepath},line={i+1},title=Do not use macro {math_const[0]}::Replace {math_const[0]} with std::{math_const[1]}"
)

if exit_code == 1 and github:
print(f"::info You will need in each flagged file #include <numbers>")

return exit_code


if "__main__" == __name__:
sys.exit(main())
84 changes: 53 additions & 31 deletions Core/include/Acts/TrackFitting/GlobalChiSquareFitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,10 +392,8 @@ void addMeasurementToGx2fSums(Gx2fSystem& extendedSystem,
}

// Create an extended Jacobian. This one contains only eBoundSize rows,
// because the rest is irrelevant. We fill it in the next steps
// because the rest is irrelevant. We fill it in the next steps.
// TODO make dimsExtendedParams template with unrolling

// We create an empty Jacobian and fill it in the next steps
Eigen::MatrixXd extendedJacobian =
Eigen::MatrixXd::Zero(eBoundSize, extendedSystem.nDims());

Expand Down Expand Up @@ -566,10 +564,11 @@ void addMaterialToGx2fSums(
///
/// @tparam track_proxy_t The type of the track proxy
///
/// @param track A mutable track proxy to operate on
/// @param track A constant track proxy to inspect
/// @param extendedSystem All parameters of the current equation system
/// @param multipleScattering Flag to consider multiple scattering in the calculation
/// @param scatteringMap Map of geometry identifiers to scattering properties, containing all scattering angles and covariances
/// @param scatteringMap Map of geometry identifiers to scattering properties,
/// containing scattering angles and validation status
/// @param geoIdVector A vector to store geometry identifiers for tracking processed elements
/// @param logger A logger instance
template <TrackProxyConcept track_proxy_t>
Expand Down Expand Up @@ -650,6 +649,51 @@ void fillGx2fSystem(
}
}

/// @brief Count the valid material states in a track for scattering calculations.
///
/// This function counts the valid material surfaces encountered in a track
/// by examining each track state. The count is based on the presence of
/// material flags and the availability of scattering information for each
/// surface.
///
/// @tparam track_proxy_t The type of the track proxy
///
/// @param track A constant track proxy to inspect
/// @param scatteringMap Map of geometry identifiers to scattering properties,
/// containing scattering angles and validation status
/// @param logger A logger instance
template <TrackProxyConcept track_proxy_t>
std::size_t countMaterialStates(
const track_proxy_t track,
const std::unordered_map<GeometryIdentifier, ScatteringProperties>&
scatteringMap,
const Logger& logger) {
std::size_t nMaterialSurfaces = 0;
ACTS_DEBUG("Count the valid material surfaces.");
for (const auto& trackState : track.trackStates()) {
const auto typeFlags = trackState.typeFlags();
const bool stateHasMaterial = typeFlags.test(TrackStateFlag::MaterialFlag);

if (!stateHasMaterial) {
continue;
}

// Get and store geoId for the current material surface
const GeometryIdentifier geoId = trackState.referenceSurface().geometryId();

const auto scatteringMapId = scatteringMap.find(geoId);
assert(scatteringMapId != scatteringMap.end() &&
"No scattering angles found for material surface.");
if (!scatteringMapId->second.materialIsValid()) {
continue;
}

nMaterialSurfaces++;
}

return nMaterialSurfaces;
}

/// @brief Update parameters (and scattering angles if applicable)
///
/// @param params Parameters to be updated
Expand Down Expand Up @@ -1310,32 +1354,10 @@ class Gx2Fitter {

// Count the material surfaces, to set up the system. In the multiple
// scattering case, we need to extend our system.
std::size_t nMaterialSurfaces = 0;
if (multipleScattering) {
ACTS_DEBUG("Count the valid material surfaces.");
for (const auto& trackState : track.trackStates()) {
const auto typeFlags = trackState.typeFlags();
const bool stateHasMaterial =
typeFlags.test(TrackStateFlag::MaterialFlag);

if (!stateHasMaterial) {
continue;
}

// Get and store geoId for the current material surface
const GeometryIdentifier geoId =
trackState.referenceSurface().geometryId();

const auto scatteringMapId = scatteringMap.find(geoId);
assert(scatteringMapId != scatteringMap.end() &&
"No scattering angles found for material surface.");
if (!scatteringMapId->second.materialIsValid()) {
continue;
}

nMaterialSurfaces++;
}
}
const std::size_t nMaterialSurfaces =
multipleScattering
? countMaterialStates(track, scatteringMap, *m_addToSumLogger)
: 0u;

// We need 6 dimensions for the bound parameters and 2 * nMaterialSurfaces
// dimensions for the scattering angles.
Expand Down

0 comments on commit b081d14

Please sign in to comment.