Skip to content

Commit

Permalink
more tidy
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Adye authored and Tim Adye committed Nov 4, 2024
1 parent e262ea2 commit fae8e92
Showing 1 changed file with 58 additions and 22 deletions.
80 changes: 58 additions & 22 deletions Examples/Scripts/Python/full_chain_test.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
#!/usr/bin/env python3

import sys, os

# needed if this script is a symlink to another directory
sys.path.insert(0, os.path.dirname(__file__))
import sys, os, argparse, pathlib
import acts, acts.examples


def parse_args():
import argparse, pathlib
import seeding
from acts.examples.reconstruction import SeedingAlgorithm

parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -37,42 +33,49 @@ def parse_args():
"--events",
type=int,
default=100,
help="The number of events to process (default=100).",
help="The number of events to process (default=%(default)d).",
)
parser.add_argument(
"-s",
"--skip",
type=int,
default=0,
help="Number of events to skip (default=0)",
help="Number of events to skip (default=%(default)d)",
)
parser.add_argument(
"-N",
"--gen-nparticles",
type=int,
default=4,
help="Number of generated particles per vertex from the particle gun (default=4).",
help="Number of generated particles per vertex from the particle gun (default=%(default)d).",
)
parser.add_argument(
"-M",
"--gen-nvertices",
type=int,
default=200,
help="Number of vertices per event (multiplicity) from the particle gun; or number of pileup events (default=200)",
help="Number of vertices per event (multiplicity) from the particle gun; or number of pileup events (default=%(default)d)",
)
parser.add_argument(
"-j",
"--jobs",
"--threads",
type=int,
default=-1,
help="Number of parallel jobs, negative for automatic (default).",
help="Number of parallel threads, negative for automatic (default).",
)
parser.add_argument(
"-t",
"--ttbar-pu200",
action="store_true",
help="Generate ttbar + mu=200 pile-up using Pythia8",
)
parser.add_argument(
"-r",
"--random-seed",
type=int,
default=42,
help="Random number seed (default=%(default)d)",
)
parser.add_argument(
"-l",
"--loglevel",
Expand Down Expand Up @@ -105,9 +108,9 @@ def parse_args():
help="Directory to write outputs to",
)
parser.add_argument(
"-a",
"--algorithm",
action=seeding.EnumAction,
"-S",
"--seeding-algorithm",
action=EnumAction,
enum=SeedingAlgorithm,
default=SeedingAlgorithm.Default,
help="Select the seeding algorithm to use",
Expand Down Expand Up @@ -188,7 +191,7 @@ def parse_args():
type=str,
choices=["greedy", "scoring", "ML", "none"],
default="greedy",
help="Set which ambiguity solver to use (default=greedy)",
help="Set which ambiguity solver to use (default=%(default)s)",
)
parser.add_argument(
"--ambi-config",
Expand All @@ -209,8 +212,6 @@ def parse_args():
def full_chain(args):
# keep these in memory after we return the sequence
global detector, trackingGeometry, decorators, field, rnd
import pathlib
import acts, acts.examples

if args.dump_args_calls:
acts.examples.dump_args_calls(locals())
Expand Down Expand Up @@ -492,7 +493,7 @@ def full_chain(args):
VertexFinder,
)

if args.itk and args.algorithm == SeedingAlgorithm.Default:
if args.itk and args.seeding_algorithm == SeedingAlgorithm.Default:
seedingAlgConfig = itk.itkSeedingAlgConfig(
itk.InputSpacePointsType.PixelSpacePoints
)
Expand All @@ -507,7 +508,7 @@ def full_chain(args):
ParticleSmearingSigmas(
ptRel=0.01
), # only needed for SeedingAlgorithm.TruthSmeared
seedingAlgorithm=args.algorithm,
seedingAlgorithm=args.seeding_algorithm,
rnd=rnd, # only needed for SeedingAlgorithm.TruthSmeared
initialSigmas=[
1 * u.mm,
Expand Down Expand Up @@ -564,7 +565,7 @@ def full_chain(args):
**(dict(
seedDeduplication=True,
stayOnSeed=True,
) if not args.simple_ckf and args.algorithm != SeedingAlgorithm.TruthSmeared else {}),
) if not args.simple_ckf and args.seeding_algorithm != SeedingAlgorithm.TruthSmeared else {}),
**(dict(
pixelVolumes=[16, 17, 18],
stripVolumes=[23, 24, 25],
Expand Down Expand Up @@ -602,7 +603,7 @@ def full_chain(args):
**(dict(
seedDeduplication=True,
stayOnSeed=True,
) if args.algorithm != SeedingAlgorithm.TruthSmeared else {}),
) if args.seeding_algorithm != SeedingAlgorithm.TruthSmeared else {}),
# ITk volumes from Noemi's plot
pixelVolumes=[8, 9, 10, 13, 14, 15, 16, 18, 19, 20],
stripVolumes=[22, 23, 24],
Expand Down Expand Up @@ -697,4 +698,39 @@ def full_chain(args):
return s


# Graciously taken from https://stackoverflow.com/a/60750535/4280680 (via seeding.py)
class EnumAction(argparse.Action):
"""
Argparse action for handling Enums
"""

def __init__(self, **kwargs):
import enum

# Pop off the type value
enum_type = kwargs.pop("enum", None)

# Ensure an Enum subclass is provided
if enum_type is None:
raise ValueError("type must be assigned an Enum when using EnumAction")
if not issubclass(enum_type, enum.Enum):
raise TypeError("type must be an Enum when using EnumAction")

# Generate choices from the Enum
kwargs.setdefault("choices", tuple(e.name for e in enum_type))

super(EnumAction, self).__init__(**kwargs)

self._enum = enum_type

def __call__(self, parser, namespace, values, option_string=None):
for e in self._enum:
if e.name == values:
setattr(namespace, self.dest, e)
break
else:
raise ValueError("%s is not a validly enumerated algorithm." % values)


# main program: parse arguments, setup sequence, and run the full chain
full_chain(parse_args()).run()

0 comments on commit fae8e92

Please sign in to comment.