Skip to content

Commit

Permalink
Transform and agg (#13)
Browse files Browse the repository at this point in the history
* use transform / aggregate in two steps instead of hooks

* remove constructors

* update build workflow

* delete old hooks

* update template file

* activate venv before running dont-fret

* fix name

* try claude's suggestion for pypi test workflow
  • Loading branch information
Jhsmit authored Dec 4, 2024
1 parent c47ec0d commit abab8df
Show file tree
Hide file tree
Showing 21 changed files with 711 additions and 278 deletions.
23 changes: 10 additions & 13 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,22 +22,19 @@ jobs:
- name: Install uv
run: pip install uv

- name: Create uv venv
- name: Create and activate venv, install package
run: |
uv venv
. .venv/bin/activate
source .venv/bin/activate
uv pip install dist/*.tar.gz
which dont-fret
- name: Install package using uv
- name: Download test file
run: |
uv pip install dist/*.tar.gz
wget "https://filedn.eu/loRXwzWCNnU4XoFPGbllt1y/datafile_1.ptu" -O tests/test_data/input/ds1/datafile_1.ptu
- name: Run dont fret serve
- name: Run dont fret process
run: |
timeout 10s dont-fret serve || exit_code=$?
if [ $exit_code -eq 124 ]; then
echo "ran for 10 seconds without error"
exit 0
else
echo "failed or exited too quickly"
exit 1
fi
source .venv/bin/activate
dont-fret process tests/test_data/input/ds1/datafile_1.ptu
23 changes: 22 additions & 1 deletion .github/workflows/pypi_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,37 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: "3.10"

- name: Configure Git
run: |
git config --global user.email "[email protected]"
git config --global user.name "GitHub Actions"
- name: Create test version tag
run: |
# Get number of commits in current branch
COMMIT_COUNT=$(git rev-list --count HEAD)
# Get short SHA
SHA=$(git rev-parse --short HEAD)
# Create a PEP 440 compliant version number
VERSION="0.2.1.dev${COMMIT_COUNT}"
# Create the annotated tag locally (this step does not push it)
git tag -a "v${VERSION}" -m "Test release ${VERSION}"
echo "Created tag v${VERSION}"
- name: Install Hatch
run: pip install hatch

- name: Build
run: hatch build

- name: Publish distribution 📦 to Test PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository-url: https://test.pypi.org/legacy/
repository-url: https://test.pypi.org/legacy/
38 changes: 31 additions & 7 deletions default_testing.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,37 @@ burst_search:
M: 100
T: 500.e-6

hooks:
alex_2cde:
tau: 50.e-6 # make sure to format as float
aggregations:
parse_expr:
exprs:
n_photons: len()
# same as parse expr above
# length:
# name: n_photons
stream_length: {}
stream_mean:
column: nanotimes
column_stats: # get timestamps min/max/mean
column: timestamps
stream_asymmetry:
lhs_streams: [DD, DA]
rhs_streams: [AA]
column: timestamps

# dict of transforms to apply in order
transforms:
"alex_2cde:75":
tau: 75.e-6
"alex_2cde:150":
tau: 150.e-6
fret_2cde:
tau: 50.e-6
tau: 45.e-6
with_columns:
exprs:
E_app: "n_DD / (n_DD + n_DA)"
S_app: "(n_DD + n_DA) / (n_DD + n_DA + n_AA)"
timestamps_length: "timestamps_max - timestamps_min"


# settings related to dont-fret's web interface
web:
Expand All @@ -53,6 +79,4 @@ web:
protect_filebrowser: false # true to prevent navigation above default_dir
burst_filters: # default filters to apply to burst search filters
- name: n_photons
min: 150
- name: alex_2cde
max: 100
min: 150
98 changes: 58 additions & 40 deletions dont_fret/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import yaml

Check failure on line 6 in dont_fret/__main__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

dont_fret/__main__.py:6:8: F401 `yaml` imported but unused
from solara.__main__ import run

from dont_fret.config import CONFIG_HOME, cfg
from dont_fret.config import CONFIG_DEFAULT_DIR, CONFIG_HOME_DIR, cfg, update_config_from_yaml

Check failure on line 9 in dont_fret/__main__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

dont_fret/__main__.py:9:67: F401 `dont_fret.config.cfg` imported but unused
from dont_fret.process import batch_search_and_save, search_and_save

ROOT = Path(__file__).parent
Expand All @@ -19,52 +19,43 @@ def cli():
pass


def find_config_file(config_path: Path) -> Optional[Path]:
    """Locate a configuration file, returning the first candidate that exists.

    The path is tried exactly as given first, then relative to
    ``CONFIG_HOME_DIR``, and finally relative to ``CONFIG_DEFAULT_DIR``.

    Returns:
        The first existing candidate path, or ``None`` when none of them exist.
    """
    if config_path.exists():
        return config_path

    # Fallback locations, checked in priority order.
    for base_dir in (CONFIG_HOME_DIR, CONFIG_DEFAULT_DIR):
        candidate = base_dir / config_path
        if candidate.exists():
            return candidate

    return None


def load_config(config_path: Path) -> None:
    """Resolve ``config_path`` and hand the result to ``update_config_from_yaml``.

    The path is resolved with ``find_config_file``; once the update has been
    applied, the resolved location is echoed to the terminal.

    Raises:
        click.BadParameter: if ``find_config_file`` finds no matching file.
    """
    located = find_config_file(Path(config_path))
    if located is None:
        raise click.BadParameter(f"Configuration file '{config_path}' not found")

    update_config_from_yaml(located)
    message = "Loading config file at: " + str(located)
    click.echo(message)


@cli.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True))
@click.option("--config", default=None, help="Configuration file to use")
@click.argument("solara_args", nargs=-1, type=click.UNPROCESSED)
def serve(config: Optional[str] = None, solara_args=None):
"""Run the don't fret web application."""
if config is not None:
data = yaml.safe_load(Path(config).read_text())
cfg.update(data)
load_config(Path(config))
else:
update_config_from_yaml(CONFIG_DEFAULT_DIR / "default_web.yaml")

solara_args = solara_args or tuple()
args = [str(APP_PATH), *solara_args]

run(args)


@cli.command()
@click.option(
"--global", "is_global", is_flag=True, help="Create config file in user's home directory"
)
def config(is_global: bool):
"""Create a local or global default configuration file."""
src = ROOT / "config" / "default.yaml"
if is_global:
(CONFIG_HOME / "dont_fret").mkdir(exist_ok=True, parents=True)
output = CONFIG_HOME / "dont_fret" / "dont_fret.yaml"
else:
output = Path.cwd() / "dont_fret.yaml"

if output.exists():
click.echo(f"Configuration file already exists at '{str(output)}'")
return

else:
output.write_text(src.read_text())

click.echo(f"Configuration file created at '{str(output)}'")


SUPPORTED_SUFFIXES = {
".ptu",
}


@cli.command()
@click.argument("input_path", type=click.Path(exists=True))
@click.option("--burst-colors", default=None, multiple=True, help="Burst colors to process")
@click.option("--config", default=None, help="Configuration file to use")
@click.option(
"--write-photons/--no-write-photons", default=False, help="Whether to write photon data"
)
Expand All @@ -74,15 +65,14 @@ def config(is_global: bool):
@click.option("--max-workers", type=int, default=None, help="Maximum number of worker threads")
def process(
input_path: str,
burst_colors: Optional[list[str]],
write_photons: bool,
output_type: Literal[".pq", ".csv"],
max_workers: Optional[int],
config: Optional[str] = None,
write_photons: bool = False,
output_type: Literal[".pq", ".csv"] = ".pq",
max_workers: Optional[int] = None,
):
"""Process photon file(s) and perform burst search."""

pth = Path(input_path)

if pth.is_file():
files = [pth]
elif pth.is_dir():
Expand All @@ -96,22 +86,24 @@ def process(

click.echo(f"Found {len(files)} file(s) to process.")

if config is not None:
load_config(Path(config))
else:
update_config_from_yaml(CONFIG_DEFAULT_DIR / "default.yaml")

# Convert burst_colors to the expected format
burst_colors_param = list(burst_colors) if burst_colors else None

if len(files) == 1:
click.echo(f"Processing file: {files[0]}")
search_and_save(
files[0],
burst_colors=burst_colors_param,
write_photons=write_photons,
output_type=output_type,
)
else:
click.echo("Processing files in batch mode.")
batch_search_and_save(
files,
burst_colors=burst_colors_param,
write_photons=write_photons,
output_type=output_type,
max_workers=max_workers,
Expand All @@ -122,3 +114,29 @@ def process(

@cli.command()
@click.option("--user", "user", is_flag=True, help="Create config file in user's home directory")
def config(user: bool):
    """Create a local or global default configuration file."""
    # Template shipped with the package; its text is copied verbatim.
    src = ROOT / "config" / "default.yaml"
    if user:
        CONFIG_HOME_DIR.mkdir(exist_ok=True, parents=True)
        output = CONFIG_HOME_DIR / "dont_fret.yaml"
    else:
        output = Path.cwd() / "dont_fret.yaml"

    # Never overwrite an existing configuration file.
    if output.exists():
        click.echo(f"Configuration file already exists at '{str(output)}'")
        return

    output.write_text(src.read_text())
    click.echo(f"Configuration file created at '{str(output)}'")


SUPPORTED_SUFFIXES = {
    ".ptu",
}


# Keep the entry-point guard as the last statement of the module. When the
# file is executed as a script (``python -m dont_fret``) ``cli()`` dispatches
# immediately, so every command and module-level constant must already be
# defined at this point; anything placed below the guard would not exist yet.
if __name__ == "__main__":
    cli()
85 changes: 85 additions & 0 deletions dont_fret/aggregation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from functools import wraps
from typing import Callable, Iterable, Optional

import polars as pl

from dont_fret.expr import parse_expression
from dont_fret.models import Bursts, PhotonData

Check failure on line 7 in dont_fret/aggregation.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

dont_fret/aggregation.py:7:30: F401 `dont_fret.models.Bursts` imported but unused

Check failure on line 7 in dont_fret/aggregation.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

dont_fret/aggregation.py:7:38: F401 `dont_fret.models.PhotonData` imported but unused
from dont_fret.utils import suffice

# photon aggregation hooks


# Global registry of aggregations: registered name -> aggregation function
aggregation_registry: dict[str, Callable] = {}


def aggregate(_func=None, *, name: Optional[str] = None) -> Callable:
    """
    Decorator that registers an aggregation function in ``aggregation_registry``.

    Can be applied bare (``@aggregate``), which registers the function under
    its own ``__name__``, or with a keyword (``@aggregate(name="custom_name")``).
    The registry stores the undecorated function; the decorated name is bound
    to a thin pass-through wrapper around it.
    """

    def register(target: Callable) -> Callable:
        key = name if name else target.__name__
        aggregation_registry[key] = target

        @wraps(target)
        def passthrough(*args, **kwargs):
            return target(*args, **kwargs)

        return passthrough

    # Bare usage passes the function positionally; parametrized usage does not.
    if _func is not None:
        return register(_func)
    return register


@aggregate
def length(name="n_photons", suffix: str = "") -> list[pl.Expr]:
    """Build the row-count aggregation.

    Returns a one-element list holding ``pl.len()`` aliased to
    ``suffice(name, suffix)``.
    """
    row_count = pl.len()
    output_name = suffice(name, suffix)
    return [row_count.alias(output_name)]


@aggregate
def stream_length(streams: Iterable[str], suffix: str = "") -> list[pl.Expr]:
    """Build one count expression per stream.

    For every entry in ``streams`` the expression sums the boolean mask
    ``pl.col("stream") == stream`` (i.e. counts the matching rows) and is
    aliased to ``suffice(f"n_{stream}", suffix)``.
    """
    counts: list[pl.Expr] = []
    for stream in streams:
        in_stream = pl.col("stream") == stream
        counts.append(in_stream.sum().alias(suffice(f"n_{stream}", suffix)))
    return counts


@aggregate
def stream_mean(streams: Iterable[str], column: str, suffix: str = "") -> list[pl.Expr]:
    """Build one mean expression per stream.

    For every entry in ``streams`` the expression takes ``column``, keeps only
    the rows where ``pl.col("stream") == stream``, averages them, and is
    aliased to ``suffice(f"{column}_{stream}", suffix)``.
    """
    means: list[pl.Expr] = []
    for stream in streams:
        selected = pl.col(column).filter(pl.col("stream") == stream)
        means.append(selected.mean().alias(suffice(f"{column}_{stream}", suffix)))
    return means


@aggregate
def column_stats(
    column: str, stat_funcs: Iterable[str] = ("mean", "min", "max"), suffix: str = ""
) -> list[pl.Expr]:
    """Build one summary-statistic expression per entry in ``stat_funcs``.

    Args:
        column: Name of the column to summarise.
        stat_funcs: Names of zero-argument methods looked up on
            ``pl.col(column)`` and called. The default is an immutable tuple
            rather than a list, so the default object cannot be mutated and
            leak state between calls.
        suffix: Passed to ``suffice`` together with each output name.

    Returns:
        One expression per statistic, aliased to
        ``suffice(f"{column}_{stat}", suffix)``.

    Raises:
        AttributeError: if an entry of ``stat_funcs`` is not an attribute of
            the column expression.
    """
    return [
        getattr(pl.col(column), stat)().alias(suffice(f"{column}_{stat}", suffix))
        for stat in stat_funcs
    ]


@aggregate
def parse_expr(exprs: dict[str, str], suffix: str = "") -> list[pl.Expr]:
    """Convert a mapping of output name -> expression string into expressions.

    Each value is passed to ``parse_expression`` and the result is aliased to
    ``suffice(key, suffix)``; the output follows the mapping's iteration order.
    """
    parsed: list[pl.Expr] = []
    for out_name, source in exprs.items():
        parsed.append(parse_expression(source).alias(suffice(out_name, suffix)))
    return parsed


@aggregate
def stream_asymmetry(
    lhs_streams: list[str], rhs_streams: list[str], column: str, suffix=""
) -> list[pl.Expr]:
    """Build the difference between two stream-group means of ``column``.

    The mean of ``column`` over rows whose ``stream`` is in ``lhs_streams``
    minus the mean over rows whose ``stream`` is in ``rhs_streams``, returned
    as a one-element list aliased to ``suffice("asymmetry", suffix)``.
    """

    def _group_mean(streams: list[str]) -> pl.Expr:
        # Mean of `column` restricted to rows belonging to the given streams.
        return pl.col(column).filter(pl.col("stream").is_in(streams)).mean()

    difference = _group_mean(lhs_streams) - _group_mean(rhs_streams)
    return [difference.alias(suffice("asymmetry", suffix))]
20 changes: 18 additions & 2 deletions dont_fret/config/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
from .config import CONFIG_HOME, cfg
from .config import (
CONFIG_DEFAULT,
CONFIG_DEFAULT_DIR,
CONFIG_HOME_DIR,
BurstColor,
DontFRETConfig,
cfg,
update_config_from_yaml,
)

__all__ = ["cfg", "CONFIG_HOME"]
__all__ = [
"CONFIG_DEFAULT",
"CONFIG_DEFAULT_DIR",
"CONFIG_HOME_DIR",
"BurstColor",
"DontFRETConfig",
"cfg",
"update_config_from_yaml",
]
Loading

0 comments on commit abab8df

Please sign in to comment.