From abab8dfc7828bb28391bdd4757d39000dfca7ba8 Mon Sep 17 00:00:00 2001 From: Jochem Smit Date: Wed, 4 Dec 2024 15:05:38 +0100 Subject: [PATCH] Transform and agg (#13) * use transform / aggregate in two steps instead of hooks * remove constructors * update build workflow * delete old hooks * update template file * activate venv before running dont-fret * fix name * try claude's suggestion for pypi test workflow --- .github/workflows/build.yml | 23 ++--- .github/workflows/pypi_test.yml | 23 ++++- default_testing.yaml | 38 ++++++-- dont_fret/__main__.py | 98 ++++++++++++-------- dont_fret/aggregation.py | 85 +++++++++++++++++ dont_fret/config/__init__.py | 20 +++- dont_fret/config/config.py | 22 ++++- dont_fret/config/default.yaml | 44 +++++++-- dont_fret/config/default_web.yaml | 79 ++++++++++++++++ dont_fret/config/hooks.py | 17 ---- dont_fret/expr.py | 7 ++ dont_fret/models.py | 112 +++++------------------ dont_fret/process.py | 137 ++++++++++++++++++++-------- dont_fret/transform.py | 141 +++++++++++++++++++++++++++++ dont_fret/utils.py | 10 +- dont_fret/web/bursts/components.py | 4 +- dont_fret/web/datamanager.py | 12 +-- dont_fret/web/methods.py | 17 ++-- templates/01_load_datafile_1.py | 26 +++--- tests/test_models.py | 72 ++++++++++----- tests/test_web.py | 2 +- 21 files changed, 711 insertions(+), 278 deletions(-) create mode 100644 dont_fret/aggregation.py create mode 100644 dont_fret/config/default_web.yaml delete mode 100644 dont_fret/config/hooks.py create mode 100644 dont_fret/transform.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index c41d45e..e48d70b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -22,22 +22,19 @@ jobs: - name: Install uv run: pip install uv - - name: Create uv venv + - name: Create and activate venv, install package run: | uv venv - . .venv/bin/activate + source .venv/bin/activate + uv pip install dist/*.tar.gz + which dont-fret - - name: Install package using uv + - name: Download test file run: | - uv pip install dist/*.tar.gz + wget "https://filedn.eu/loRXwzWCNnU4XoFPGbllt1y/datafile_1.ptu" -O tests/test_data/input/ds1/datafile_1.ptu + - - name: Run dont fret serve + - name: Run dont fret process run: | - timeout 10s dont-fret serve || exit_code=$? - if [ $exit_code -eq 124 ]; then - echo "ran for 10 seconds without error" - exit 0 - else - echo "failed or exited too quickly" - exit 1 - fi \ No newline at end of file + source .venv/bin/activate + dont-fret process tests/test_data/input/ds1/datafile_1.ptu \ No newline at end of file diff --git a/.github/workflows/pypi_test.yml b/.github/workflows/pypi_test.yml index 6344faf..b7d46b1 100644 --- a/.github/workflows/pypi_test.yml +++ b/.github/workflows/pypi_test.yml @@ -11,16 +11,37 @@ jobs: uses: actions/checkout@v3 with: fetch-depth: 0 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: python-version: "3.10" + + - name: Configure Git + run: | + git config --global user.email "github-actions@github.com" + git config --global user.name "GitHub Actions" + + - name: Create test version tag + run: | + # Get number of commits in current branch + COMMIT_COUNT=$(git rev-list --count HEAD) + # Get short SHA + SHA=$(git rev-parse --short HEAD) + # Create a PEP 440 compliant version number + VERSION="0.2.1.dev${COMMIT_COUNT}" + # Create and push tag + git tag -a "v${VERSION}" -m "Test release ${VERSION}" + echo "Created tag v${VERSION}" + - name: Install Hatch run: pip install hatch + - name: Build run: hatch build + - name: Publish distribution 📦 to Test PyPI uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.TEST_PYPI_API_TOKEN }} - repository-url: https://test.pypi.org/legacy/ + repository-url: https://test.pypi.org/legacy/ \ No newline at end of file diff --git a/default_testing.yaml b/default_testing.yaml index d5ac65c..38d0de7 100644 --- a/default_testing.yaml +++ b/default_testing.yaml @@ -40,11 +40,37 @@ burst_search: M: 100 T: 500.e-6 -hooks: - alex_2cde: - tau: 50.e-6 # make sure to format as float +aggregations: + parse_expr: + exprs: + n_photons: len() + # same as parse expr above + # length: + # name: n_photons + stream_length: {} + stream_mean: + column: nanotimes + column_stats: # get timestamps min/max/mean + column: timestamps + stream_asymmetry: + lhs_streams: [DD, DA] + rhs_streams: [AA] + column: timestamps + +# dict of transforms to apply in order +transforms: + "alex_2cde:75": + tau: 75.e-6 + "alex_2cde:150": + tau: 150.e-6 fret_2cde: - tau: 50.e-6 + tau: 45.e-6 + with_columns: + exprs: + E_app: "n_DD / (n_DD + n_DA)" + S_app: "(n_DD + n_DA) / (n_DD + n_DA + n_AA)" + timestamps_length: "timestamps_max - timestamps_min" + # settings related to dont-fret's web interface web: @@ -53,6 +79,4 @@ web: protect_filebrowser: false # true to prevent navigation above default_dir burst_filters: # default filters to apply to burst search filters - name: n_photons - min: 150 - - name: alex_2cde - max: 100 + min: 150 \ No newline at end of file diff --git a/dont_fret/__main__.py b/dont_fret/__main__.py index 39c81f7..8719746 100644 --- a/dont_fret/__main__.py +++ b/dont_fret/__main__.py @@ -6,7 +6,7 @@ import yaml from solara.__main__ import run -from dont_fret.config import CONFIG_HOME, cfg +from dont_fret.config import CONFIG_DEFAULT_DIR, CONFIG_HOME_DIR, cfg, update_config_from_yaml from dont_fret.process import batch_search_and_save, search_and_save ROOT = Path(__file__).parent @@ -19,14 +19,33 @@ def cli(): pass +def find_config_file(config_path: Path) -> Optional[Path]: + if config_path.exists(): + return config_path + elif (pth := CONFIG_HOME_DIR / config_path).exists(): + return pth + elif (pth := CONFIG_DEFAULT_DIR / config_path).exists(): + return pth + + +def load_config(config_path: Path) -> None: + resolved_cfg_path = find_config_file(Path(config_path)) + if not resolved_cfg_path: + raise click.BadParameter(f"Configuration file '{config_path}' not found") + + update_config_from_yaml(resolved_cfg_path) + click.echo("Loading config file at: " + str(resolved_cfg_path)) + + @cli.command(context_settings=dict(ignore_unknown_options=True, allow_extra_args=True)) @click.option("--config", default=None, help="Configuration file to use") @click.argument("solara_args", nargs=-1, type=click.UNPROCESSED) def serve(config: Optional[str] = None, solara_args=None): """Run the don't fret web application.""" if config is not None: - data = yaml.safe_load(Path(config).read_text()) - cfg.update(data) + load_config(Path(config)) + else: + update_config_from_yaml(CONFIG_DEFAULT_DIR / "default_web.yaml") solara_args = solara_args or tuple() args = [str(APP_PATH), *solara_args] @@ -34,37 +53,9 @@ def serve(config: Optional[str] = None, solara_args=None): run(args) -@cli.command() -@click.option( - "--global", "is_global", is_flag=True, help="Create config file in user's home directory" -) -def config(is_global: bool): - """Create a local or global default configuration file.""" - src = ROOT / "config" / "default.yaml" - if is_global: - (CONFIG_HOME / "dont_fret").mkdir(exist_ok=True, parents=True) - output = CONFIG_HOME / "dont_fret" / "dont_fret.yaml" - else: - output = Path.cwd() / "dont_fret.yaml" - - if output.exists(): - click.echo(f"Configuration file already exists at '{str(output)}'") - return - - else: - output.write_text(src.read_text()) - - click.echo(f"Configuration file created at '{str(output)}'") - - -SUPPORTED_SUFFIXES = { - ".ptu", -} - - @cli.command() @click.argument("input_path", type=click.Path(exists=True)) -@click.option("--burst-colors", default=None, multiple=True, help="Burst colors to process") +@click.option("--config", default=None, help="Configuration file to use") @click.option( "--write-photons/--no-write-photons", default=False, help="Whether to write photon data" ) @@ -74,15 +65,14 @@ def config(is_global: bool): @click.option("--max-workers", type=int, default=None, help="Maximum number of worker threads") def process( input_path: str, - burst_colors: Optional[list[str]], - write_photons: bool, - output_type: Literal[".pq", ".csv"], - max_workers: Optional[int], + config: Optional[str] = None, + write_photons: bool = False, + output_type: Literal[".pq", ".csv"] = ".pq", + max_workers: Optional[int] = None, ): """Process photon file(s) and perform burst search.""" pth = Path(input_path) - if pth.is_file(): files = [pth] elif pth.is_dir(): @@ -96,14 +86,17 @@ def process( click.echo(f"Found {len(files)} file(s) to process.") + if config is not None: + load_config(Path(config)) + else: + update_config_from_yaml(CONFIG_DEFAULT_DIR / "default.yaml") + # Convert burst_colors to the expected format - burst_colors_param = list(burst_colors) if burst_colors else None if len(files) == 1: click.echo(f"Processing file: {files[0]}") search_and_save( files[0], - burst_colors=burst_colors_param, write_photons=write_photons, output_type=output_type, ) @@ -111,7 +104,6 @@ def process( click.echo("Processing files in batch mode.") batch_search_and_save( files, - burst_colors=burst_colors_param, write_photons=write_photons, output_type=output_type, max_workers=max_workers, @@ -122,3 +114,29 @@ def process( if __name__ == "__main__": cli() + + +@cli.command() +@click.option("--user", "user", is_flag=True, help="Create config file in user's home directory") +def config(user: bool): + """Create a local or global default configuration file.""" + src = ROOT / "config" / "default.yaml" + if user: + (CONFIG_HOME_DIR).mkdir(exist_ok=True, parents=True) + output = CONFIG_HOME_DIR / "dont_fret.yaml" + else: + output = Path.cwd() / "dont_fret.yaml" + + if output.exists(): + click.echo(f"Configuration file already exists at '{str(output)}'") + return + + else: + output.write_text(src.read_text()) + + click.echo(f"Configuration file created at '{str(output)}'") + + +SUPPORTED_SUFFIXES = { + ".ptu", +} diff --git a/dont_fret/aggregation.py b/dont_fret/aggregation.py new file mode 100644 index 0000000..a61e73e --- /dev/null +++ b/dont_fret/aggregation.py @@ -0,0 +1,85 @@ +from functools import wraps +from typing import Callable, Iterable, Optional + +import polars as pl + +from dont_fret.expr import parse_expression +from dont_fret.models import Bursts, PhotonData +from dont_fret.utils import suffice + +# photon aggregation hooks + + +# Global registry of transforms +aggregation_registry: dict[str, Callable] = {} + + +def aggregate(_func=None, *, name: Optional[str] = None) -> Callable: + """ + Decorator to register a aggregate function. + + Can be used as @aggregate or @aggregate(name="custom_name") + """ + + def decorator(func: Callable) -> Callable: + agg_name = name or func.__name__ + aggregation_registry[agg_name] = func + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + if _func is None: + return decorator + return decorator(_func) + + +@aggregate +def length(name="n_photons", suffix: str = "") -> list[pl.Expr]: + return [pl.len().alias(suffice(name, suffix))] + + +@aggregate +def stream_length(streams: Iterable[str], suffix: str = "") -> list[pl.Expr]: + """length of each stream (number of elements)""" + return [ + (pl.col("stream") == stream).sum().alias(suffice(f"n_{stream}", suffix)) + for stream in streams + ] + + +@aggregate +def stream_mean(streams: Iterable[str], column: str, suffix: str = "") -> list[pl.Expr]: + return [ + pl.col(column) + .filter(pl.col("stream") == stream) + .mean() + .alias(suffice(f"{column}_{stream}", suffix)) + for stream in streams + ] + + +@aggregate +def column_stats( + column: str, stat_funcs: list[str] = ["mean", "min", "max"], suffix: str = "" +) -> list[pl.Expr]: + return [ + getattr(pl.col(column), d)().alias(suffice(f"{column}_{d}", suffix)) for d in stat_funcs + ] + + +@aggregate +def parse_expr(exprs: dict[str, str], suffix: str = "") -> list[pl.Expr]: + return [parse_expression(v).alias(suffice(k, suffix)) for k, v in exprs.items()] + + +@aggregate +def stream_asymmetry( + lhs_streams: list[str], rhs_streams: list[str], column: str, suffix="" +) -> list[pl.Expr]: + value_lhs = pl.col(column).filter(pl.col("stream").is_in(lhs_streams)).mean() + value_rhs = pl.col(column).filter(pl.col("stream").is_in(rhs_streams)).mean() + + return [(value_lhs - value_rhs).alias(suffice("asymmetry", suffix))] diff --git a/dont_fret/config/__init__.py b/dont_fret/config/__init__.py index 93a8f16..b7e4422 100644 --- a/dont_fret/config/__init__.py +++ b/dont_fret/config/__init__.py @@ -1,3 +1,19 @@ -from .config import CONFIG_HOME, cfg +from .config import ( + CONFIG_DEFAULT, + CONFIG_DEFAULT_DIR, + CONFIG_HOME_DIR, + BurstColor, + DontFRETConfig, + cfg, + update_config_from_yaml, +) -__all__ = ["cfg", "CONFIG_HOME"] +__all__ = [ + "CONFIG_DEFAULT", + "CONFIG_DEFAULT_DIR", + "CONFIG_HOME_DIR", + "BurstColor", + "DontFRETConfig", + "cfg", + "update_config_from_yaml", +] diff --git a/dont_fret/config/config.py b/dont_fret/config/config.py index fbbca5c..faf64ab 100644 --- a/dont_fret/config/config.py +++ b/dont_fret/config/config.py @@ -12,7 +12,8 @@ from dont_fret.utils import clean_types -CONFIG_HOME = Path(os.getenv("XDG_CONFIG_HOME", Path.home() / ".config")) +CONFIG_HOME_DIR = Path(os.getenv("XDG_CONFIG_HOME", Path.home() / ".config")) / "dont-fret" +CONFIG_DEFAULT_DIR = Path(__file__).parent @dataclass @@ -66,7 +67,8 @@ class DontFRETConfig: channels: dict[str, Channel] streams: dict[str, list[str]] burst_search: dict[str, list[BurstColor]] - hooks: dict[str, dict[str, Any]] = field(default_factory=dict) + aggregations: dict[str, dict[str, Any]] = field(default_factory=dict) + transforms: dict[str, dict[str, Any]] = field(default_factory=dict) web: Web = field(default_factory=Web) @classmethod @@ -90,12 +92,24 @@ def update(self, data: Data): new_cfg = DontFRETConfig.from_dict(new_data) vars(self).update(vars(new_cfg)) + def copy(self) -> DontFRETConfig: + return DontFRETConfig.from_dict(asdict(self)) + + +def update_config_from_yaml(config_path: Path) -> None: + """Updates the global configuration object with settings from a YAML file.""" + + data = yaml.safe_load(config_path.read_text()) + cfg.update(data) + cfg_file_paths = [ - CONFIG_HOME / "dont_fret" / "dont_fret.yaml", - Path(__file__).parent / "default.yaml", + CONFIG_HOME_DIR / "dont_fret.yaml", + CONFIG_DEFAULT_DIR / "default.yaml", ] +CONFIG_DEFAULT = DontFRETConfig.from_yaml(CONFIG_DEFAULT_DIR / "default.yaml") + # take the first one which exists cfg_fpath = next((p for p in cfg_file_paths if p.exists()), None) assert cfg_fpath diff --git a/dont_fret/config/default.yaml b/dont_fret/config/default.yaml index 39ca05e..58e5766 100644 --- a/dont_fret/config/default.yaml +++ b/dont_fret/config/default.yaml @@ -34,20 +34,44 @@ burst_search: M: 100 T: 500.e-6 -# post-burst search hooks to apply -# hooks are of the form my_hook(burst_data, photon_data, **kwargs) -# kwargs are as specified here -hooks: - alex_2cde: - tau: 150.e-6 # make sure to format as float +# dict of aggregations to apply +aggregations: + parse_expr: + exprs: + n_photons: len() + # same as parse expr above + # length: + # name: n_photons + stream_length: {} + stream_mean: + column: nanotimes + column_stats: # get timestamps min/max/mean + column: timestamps + stream_asymmetry: + lhs_streams: [DD, DA] + rhs_streams: [AA] + column: timestamps + +# dict of transforms to apply in order +transforms: + "alex_2cde:75": + tau: 75.e-6 + "alex_2cde:150": + tau: 150.e-6 fret_2cde: - tau: 50.e-6 + tau: 45.e-6 + with_columns: + exprs: + E_app: "n_DA / (n_DD + n_DA)" + S_app: "(n_DD + n_DA) / (n_DD + n_DA + n_AA)" + timestamps_length: "timestamps_max - timestamps_min" + # settings related to dont-fret's web interface web: - password: null # set to null to disable password protection - default_dir: "~" # default directory show in the file browser - protect_filebrowser: true # true to prevent navigation above default_dir + password: null + default_dir: tests\test_data\input\ds2 # default directory show in the file browser + protect_filebrowser: false # true to prevent navigation above default_dir burst_filters: # default filters to apply to burst search filters - name: n_photons min: 150 diff --git a/dont_fret/config/default_web.yaml b/dont_fret/config/default_web.yaml new file mode 100644 index 0000000..fe4bc90 --- /dev/null +++ b/dont_fret/config/default_web.yaml @@ -0,0 +1,79 @@ +channels: # refactor channels in code to channel_identifiers + laser_D: + target: nanotimes + value: [ 0, 1000 ] + laser_A: + target: nanotimes + value: [ 1000, 2000 ] # intervals are inclusive, exclusive + det_D: + target: detectors + value: 1 + det_A: + target: detectors + value: 0 + +streams: + DD: [laser_D, det_D] + DA: [laser_D, det_A] + AA: [laser_A, det_A] + AD: [laser_A, det_D] + +burst_search: + DCBS: # name of the burst search + - streams: [DD, DA] # photons streams to use + L: 50 + M: 35 + T: 500.e-6 + - streams: [AA] + L: 50 + M: 35 + T: 500.e-6 + APBS: + - streams: [DD, DA, AA] + L: 50 + M: 100 + T: 500.e-6 + +# dict of aggregations to apply +aggregations: + parse_expr: + exprs: + n_photons: len() + # same as parse expr above + # length: + # name: n_photons + stream_length: {} + stream_mean: + column: nanotimes + column_stats: # get timestamps min/max/mean + column: timestamps + stream_asymmetry: + lhs_streams: [DD, DA] + rhs_streams: [AA] + column: timestamps + +# dict of transforms to apply in order +transforms: + "alex_2cde:75": + tau: 75.e-6 + "alex_2cde:150": + tau: 150.e-6 + fret_2cde: + tau: 45.e-6 + with_columns: + exprs: + E_app: "n_DA / (n_DD + n_DA)" + S_app: "(n_DD + n_DA) / (n_DD + n_DA + n_AA)" + timestamps_length: "timestamps_max - timestamps_min" + convert_timestamps: {} # converts timestamps columns to seconds + convert_nanotimes: {} # convert nanotime columns to nanoseconds + +# settings related to dont-fret's web interface +web: + password: null + default_dir: tests\test_data\input\ds2 # default directory show in the file browser + protect_filebrowser: false # true to prevent navigation above default_dir + burst_filters: # default filters to apply to burst search filters + - name: n_photons + min: 150 + diff --git a/dont_fret/config/hooks.py b/dont_fret/config/hooks.py deleted file mode 100644 index ad61f35..0000000 --- a/dont_fret/config/hooks.py +++ /dev/null @@ -1,17 +0,0 @@ -from dont_fret.models import Bursts, PhotonData - - -def fret_2cde( - bursts: Bursts, - photon_data: PhotonData, - tau: float = 50e-6, - dem_stream: str = "DD", - aem_stream: str = "DA", -) -> Bursts: - return bursts.fret_2cde(photon_data, tau=tau, dem_stream=dem_stream, aem_stream=aem_stream) - - -def alex_2cde( - bursts: Bursts, photon_data: PhotonData, tau: float = 50e-6, dex_streams=None, aex_streams=None -) -> Bursts: - return bursts.alex_2cde(photon_data, tau=tau, dex_streams=dex_streams, aex_streams=aex_streams) diff --git a/dont_fret/expr.py b/dont_fret/expr.py index a175424..a2e27e3 100644 --- a/dont_fret/expr.py +++ b/dont_fret/expr.py @@ -58,6 +58,13 @@ def evaluate_node(node): return pl.col(node.id) elif isinstance(node, ast.Constant): return pl.lit(node.n) + # args/kwargs not supported + elif isinstance(node, ast.Call): # example: len() => pl.len() + try: + func = getattr(pl, node.func.id) + except AttributeError: + raise ValueError(f"Unsupported function: {node.func.id}") + return func() elif isinstance(node, ast.BinOp): left = evaluate_node(node.left) right = evaluate_node(node.right) diff --git a/dont_fret/models.py b/dont_fret/models.py index 89a2a63..c19be98 100644 --- a/dont_fret/models.py +++ b/dont_fret/models.py @@ -23,14 +23,6 @@ class PhotonData: """Base object for timestamp data - - does not have identified channels - and access the structured array with properties - - timestamps: ndarray int timestamps (global resolution) - detectors: ndarray int - nanotimes, optional: ndarray int - metadata: dict; whatever contents """ @@ -201,7 +193,7 @@ def to_file(self): """write to photon-hdf5 file""" ... - def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts: + def burst_search(self, colors: Union[str, list[BurstColor]]) -> pl.DataFrame: """ Search for bursts in the photon data. @@ -260,14 +252,14 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts: if len(final_times) == 0: # No overlap found burst_photons = pl.DataFrame({k: [] for k in self.data.columns + ["burst_index"]}) - indices = pl.DataFrame({"imin": [], "imax": []}) + # indices = pl.DataFrame({"imin": [], "imax": []}) else: tmin, tmax = np.array(final_times).T # Convert back to indices imin = np.searchsorted(self.timestamps, tmin) imax = np.searchsorted(self.timestamps, tmax) - indices = pl.DataFrame({"imin": imin, "imax": imax}) + # indices = pl.DataFrame({"imin": imin, "imax": imax}) # take all photons (up to and including? edges need to be checked!) b_num = int(2 ** np.ceil(np.log2((np.log2(len(imin)))))) index_dtype = getattr(pl, f"UInt{b_num}", pl.Int32) @@ -279,9 +271,7 @@ def burst_search(self, colors: Union[str, list[BurstColor]]) -> Bursts: ] burst_photons = pl.concat(bursts) - bs = Bursts.from_photons(burst_photons, metadata=self.metadata) - - return bs + return burst_photons class BinnedPhotonData: @@ -376,81 +366,6 @@ class Bursts: metadata: Optional[dict] = None cfg: Optional[DontFRETConfig] = None - @classmethod - def from_photons( - cls, - photon_data: pl.DataFrame, - metadata: Optional[dict] = None, - cfg: DontFRETConfig = global_cfg, - ) -> Bursts: - # implement as hooks - - # number of photons per stream per burst - agg = [(pl.col("stream") == stream).sum().alias(f"n_{stream}") for stream in cfg.streams] - - # mean nanotimes per stream per burst - agg.extend( - [ - pl.when(pl.col("stream") == k) - .then(pl.col("nanotimes")) - .mean() - .alias(f"nanotimes_{k}") - for k in cfg.streams - ] - ) - - agg.append(pl.col("timestamps").mean().alias("timestamps_mean")) - agg.append(pl.col("timestamps").min().alias("timestamps_min")) - agg.append(pl.col("timestamps").max().alias("timestamps_max")) - - t_unit = metadata.get("timestamps_unit", None) - if t_unit is not None: - acceptor_streams = ["DA", "AA"] - donor_streams = ["DD"] - - ts_acceptor = ( - pl.col("timestamps").filter(pl.col("stream").is_in(acceptor_streams)).mean() - ) - ts_donor = pl.col("timestamps").filter(pl.col("stream").is_in(donor_streams)).mean() - asymmetry = ((ts_acceptor - ts_donor) * t_unit).alias("asymmetry") - - agg.append(asymmetry) - - # TODO configure via hooks - columns = [ - (pl.col("n_DA") / (pl.col("n_DD") + pl.col("n_DA"))).alias("E_app"), - ( - (pl.col("n_DA") + pl.col("n_DD")) - / (pl.col("n_DD") + pl.col("n_DA") + pl.col("n_AA")) - ).alias("S_app"), - photon_data["burst_index"].unique_counts().alias("n_photons"), - ] - - # yaml config via eval? - if t_unit is not None: - columns.extend( - [ - (pl.col("timestamps_mean") * t_unit).alias("time_mean"), - ((pl.col("timestamps_max") - pl.col("timestamps_min")) * t_unit).alias( - "time_length" - ), - ] - ) - nanotimes_unit = metadata.get("nanotimes_unit", None) - if nanotimes_unit is not None: - columns.extend( - [ - pl.col(f"nanotimes_{stream}").mul(nanotimes_unit).alias(f"tau_{stream}") - for stream in cfg.streams - ] - ) - - burst_data = ( - photon_data.group_by("burst_index", maintain_order=True).agg(agg).with_columns(columns) - ) - - return Bursts(burst_data, photon_data, metadata=metadata, cfg=cfg) - @classmethod def load(cls, directory: Path) -> Bursts: burst_data = pl.read_parquet(directory / "burst_data.pq") @@ -480,6 +395,7 @@ def fret_2cde( tau: float = 50e-6, dem_stream: str = "DD", aem_stream: str = "DA", + alias="fret_2cde", ) -> Bursts: if self.burst_data.is_empty(): burst_data = self.burst_data.with_columns(pl.lit(None).alias("fret_2cde")) @@ -499,7 +415,7 @@ def fret_2cde( ) fret_2cde = compute_fret_2cde(self.photon_data, kde_data) - burst_data = self.burst_data.with_columns(pl.lit(fret_2cde).alias("fret_2cde")) + burst_data = self.burst_data.with_columns(pl.lit(fret_2cde).alias(alias)) return Bursts(burst_data, self.photon_data, self.metadata, self.cfg) @@ -509,6 +425,7 @@ def alex_2cde( tau: float = 50e-6, dex_streams: Optional[list[str]] = None, aex_streams: Optional[list[str]] = None, + alias="alex_2cde", ) -> Bursts: if self.burst_data.is_empty(): burst_data = self.burst_data.with_columns(pl.lit(None).alias("alex_2cde")) @@ -532,10 +449,23 @@ def alex_2cde( ) alex_2cde = compute_alex_2cde(self.photon_data, kde_data) - burst_data = self.burst_data.with_columns(pl.lit(alex_2cde).alias("alex_2cde")) + burst_data = self.burst_data.with_columns(pl.lit(alex_2cde).alias(alias)) return Bursts(burst_data, self.photon_data, self.metadata, self.cfg) + def with_columns(self, columns: list[pl.Expr]) -> Bursts: + return Bursts( + self.burst_data.with_columns(columns), self.photon_data, self.metadata, self.cfg + ) + + def drop(self, columns: list[str]) -> Bursts: + return Bursts( + self.burst_data.drop(columns), + self.photon_data, + self.metadata, + self.cfg, + ) + def __len__(self) -> int: """Number of bursts""" return len(self.burst_data) diff --git a/dont_fret/process.py b/dont_fret/process.py index 2d2ed1a..291f667 100644 --- a/dont_fret/process.py +++ b/dont_fret/process.py @@ -2,59 +2,50 @@ import importlib from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import asdict +from inspect import signature from pathlib import Path from typing import Literal, Optional import polars as pl from tqdm.auto import tqdm -from dont_fret.config import cfg -from dont_fret.config.config import BurstColor +from dont_fret.aggregation import aggregation_registry +from dont_fret.config import CONFIG_DEFAULT, CONFIG_DEFAULT_DIR, BurstColor, DontFRETConfig, cfg from dont_fret.fileIO import PhotonFile from dont_fret.models import Bursts, PhotonData +from dont_fret.transform import transform_registry + +# TODO pass cfg; need also to pass to `from_file` +# not sure if we should go down this road def search_and_save( file: Path, - burst_colors: str | list[str] | None = None, write_photons: bool = True, output_type: Literal[".pq", ".csv"] = ".pq", ) -> None: """ Performs burst search on the supplied file and saves burst search output to disk. + Uses global cfg object """ + photons = PhotonData.from_file(PhotonFile(file)) - if burst_colors is None: - colors = cfg.burst_search.keys() - elif isinstance(burst_colors, str): - colors = [burst_colors] - elif isinstance(burst_colors, list): - colors = burst_colors output_dir = file.parent - for color in colors: - bursts = photons.burst_search(color) + for bs_name, burst_colors in cfg.burst_search.items(): + bursts = process_photon_data(photons, burst_colors) + if write_photons: write_dataframe( - bursts.photon_data, output_dir / f"{file.stem}_{color}_photon_data{output_type}" + bursts.photon_data, output_dir / f"{file.stem}_{bs_name}_photon_data{output_type}" ) write_dataframe( - bursts.burst_data, output_dir / f"{file.stem}_{color}_burst_data{output_type}" + bursts.burst_data, output_dir / f"{file.stem}_{bs_name}_burst_data{output_type}" ) -def write_dataframe(df: pl.DataFrame, path: Path) -> None: - """Write a dataframe to disk. Writer used depends on path suffix.""" - if path.suffix == ".pq": - df.write_parquet(path) - elif path.suffix == ".csv": - df.write_csv(path) - else: - raise ValueError(f"Unsupported output type: {path.suffix}") - - def batch_search_and_save( files: list[Path], - burst_colors: str | list[str] | None = None, write_photons: bool = True, output_type: Literal[".pq", ".csv"] = ".pq", max_workers: Optional[int] = None, @@ -66,28 +57,96 @@ def batch_search_and_save( futures = [] with ThreadPoolExecutor(max_workers=max_workers) as executor: for f in files: - fut = executor.submit(search_and_save, f, burst_colors, write_photons, output_type) + fut = executor.submit(search_and_save, f, write_photons, output_type) futures.append(fut) for f in tqdm(as_completed(futures), total=len(futures)): f.result() +def write_dataframe(df: pl.DataFrame, path: Path) -> None: + """Write a dataframe to disk. Writer used depends on path suffix.""" + if path.suffix == ".pq": + df.write_parquet(path) + elif path.suffix == ".csv": + df.write_csv(path) + else: + raise ValueError(f"Unsupported output type: {path.suffix}") + + +def apply_aggregations( + burst_photons: pl.DataFrame, + aggregations: dict = cfg.aggregations, +) -> pl.DataFrame: + agg_fixtures = {"streams": cfg.streams} + + agg = [] + for agg_name, agg_params in aggregations.items(): + split = agg_name.split(":", 1) + base_name = split[0] + suffix = split[1] if len(split) > 1 else "" + + agg_fn = aggregation_registry[base_name] + kwargs = agg_params + + sig = signature(agg_fn) + matching_fixtures = {k: agg_fixtures[k] for k in sig.parameters if k in agg_fixtures} + kwargs = {**agg_params, **matching_fixtures} + + agg_expr = agg_fn(**kwargs) + if isinstance(agg_expr, list): + agg.extend(agg_expr) + else: + agg.append(agg_expr) + + burst_data = burst_photons.group_by("burst_index", maintain_order=True).agg(agg) + return burst_data + + +def apply_transformations( + bursts: Bursts, + photon_data: PhotonData, + transforms: dict = cfg.transforms, +) -> Bursts: + # one global dict for fixtures? + trs_fixtures = {"photon_data": photon_data} + for trs_name, trs_params in transforms.items(): + split = trs_name.split(":", 1) + base_name = split[0] + + suffix = split[1] if len(split) > 1 else "" + transform_fn = transform_registry[base_name] + sig = signature(transform_fn) + matching_fixtures = {k: trs_fixtures[k] for k in sig.parameters if k in trs_fixtures} + kwargs = {**trs_params, **matching_fixtures} + + bursts = transform_fn(bursts, suffix=suffix, **kwargs) + + return bursts + + def process_photon_data( - photon_data: PhotonData, burst_colors: list[BurstColor], hooks: dict = cfg.hooks + photon_data: PhotonData, + burst_colors: list[BurstColor], + aggregations: Optional[dict] = None, + transforms: Optional[dict] = None, ) -> Bursts: - """search and apply hooks""" - bursts = photon_data.burst_search(burst_colors) - - for hook_name, hook_params in hooks.items(): - try: - hook = getattr(importlib.import_module("dont_fret.config.hooks"), hook_name) - except AttributeError: - try: - hook = getattr(importlib.import_module("hooks"), hook_name) - except (ImportError, AttributeError): - raise ValueError(f"Hook '{hook_name}' not found") - - bursts = hook(bursts, photon_data, **hook_params) + """search and apply agg/trfms to burst data + by default uses config from photon_data object + + """ + local_cfg_data = asdict(photon_data.cfg) + if aggregations is not None: + local_cfg_data["aggregations"] = aggregations + if transforms is not None: + local_cfg_data["transforms"] = transforms + local_cfg = DontFRETConfig.from_dict(local_cfg_data) + burst_photons = photon_data.burst_search(burst_colors) + + # evaluate aggregations + burst_data = apply_aggregations(burst_photons, local_cfg.aggregations) + bursts = Bursts(burst_data, burst_photons, photon_data.metadata, local_cfg) + + bursts = apply_transformations(bursts, photon_data, local_cfg.transforms) return bursts diff --git a/dont_fret/transform.py b/dont_fret/transform.py new file mode 100644 index 0000000..0eb85f8 --- /dev/null +++ b/dont_fret/transform.py @@ -0,0 +1,141 @@ +# transforms.py +from functools import wraps +from typing import Any, Callable, Concatenate, Optional, ParamSpec, Protocol, Tuple + +import polars as pl + +from dont_fret.config.config import DontFRETConfig +from dont_fret.expr import parse_expression +from dont_fret.models import Bursts, PhotonData +from dont_fret.utils import suffice + +# Global registry of transforms +transform_registry: dict[str, Callable] = {} + + +TIME_UNITS = { + "ps": 1e12, + "ns": 1e9, + "us": 1e6, + "ms": 1e3, + "s": 1, + "min": 60, + "h": 3600, +} + + +def transform(_func=None, *, name: Optional[str] = None) -> Callable: + """ + Decorator to register a transform function. + + Can be used as @transform or @transform(name="custom_name") + """ + + def decorator(func: Callable) -> Callable: + transform_name = name or func.__name__ + transform_registry[transform_name] = func + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + if _func is None: + return decorator + return decorator(_func) + + +@transform +def fret_2cde( + bursts: Bursts, + photon_data: PhotonData, + tau: float = 45e-6, + dem_stream: str = "DD", + aem_stream: str = "DA", + suffix="", +) -> Bursts: + alias = suffice("fret_2cde", suffix) + return bursts.fret_2cde( + photon_data, tau=tau, dem_stream=dem_stream, aem_stream=aem_stream, alias=alias + ) + + +@transform +def alex_2cde( + bursts: Bursts, + photon_data: PhotonData, + tau: float = 75e-6, + dex_streams=None, + aex_streams=None, + suffix="", +) -> Bursts: + alias = suffice("alex_2cde", suffix) + + return bursts.alex_2cde( + photon_data, + tau=tau, + dex_streams=dex_streams, + aex_streams=aex_streams, + alias=alias, + ) + + +@transform +def convert_nanotimes( + bursts: Bursts, time_unit: str = "ns", keep_columns: bool = False, suffix: str = "" +) -> Bursts: + assert bursts.metadata + timestamps_unit = bursts.metadata["nanotimes_unit"] + factor = timestamps_unit * TIME_UNITS[time_unit] + + burst_data = bursts.burst_data + timestamps_columns = [col for col in burst_data.columns if col.startswith("nanotimes")] + conversion = [ + (pl.col(col) * factor).alias(suffice(col.replace("nanotimes", "tau"), suffix)) + for col in timestamps_columns + ] + burst_data = burst_data.with_columns(conversion) + if not keep_columns: + burst_data = burst_data.drop(timestamps_columns) + + return Bursts(burst_data, bursts.photon_data, bursts.metadata, bursts.cfg) + + +@transform +def convert_timestamps( + bursts: Bursts, time_unit: str = "s", keep_columns: bool = False, suffix: str = "" +) -> Bursts: + assert bursts.metadata + timestamps_unit = bursts.metadata["timestamps_unit"] + factor = timestamps_unit * TIME_UNITS[time_unit] + + burst_data = bursts.burst_data + timestamps_columns = [col for col in burst_data.columns if col.startswith("timestamps")] + conversion = [ + (pl.col(col) * factor).alias(suffice(col.replace("timestamps", "time"), suffix)) + for col in timestamps_columns + ] + burst_data = burst_data.with_columns(conversion) + if not keep_columns: + burst_data = burst_data.drop(timestamps_columns) + + return Bursts(burst_data, bursts.photon_data, bursts.metadata, bursts.cfg) + + +@transform +def with_columns(bursts: Bursts, exprs: dict[str, pl.Expr | str], suffix="") -> Bursts: + parsed_exprs = [] + for k, v in exprs.items(): + alias = suffice(k, suffix) + if isinstance(v, str): + parsed_exprs.append(parse_expression(v).alias(alias)) + else: + parsed_exprs.append(v.alias(alias)) + + return bursts.with_columns(parsed_exprs) + + +@transform +def drop(bursts: Bursts, columns: list[str], suffix="") -> Bursts: + return bursts.drop(columns) diff --git a/dont_fret/utils.py b/dont_fret/utils.py index 132e8f8..b6a0713 100644 --- a/dont_fret/utils.py +++ b/dont_fret/utils.py @@ -1,9 +1,17 @@ from __future__ import annotations + from collections import OrderedDict from pathlib import Path +from typing import Any import numpy as np -from typing import Any + + +def suffice(name: str, suffix: str = "") -> str: + if suffix: + return f"{name}_{suffix}" + else: + return name def clean_types(d: Any) -> Any: diff --git a/dont_fret/web/bursts/components.py b/dont_fret/web/bursts/components.py index a8bcb09..46913c8 100644 --- a/dont_fret/web/bursts/components.py +++ b/dont_fret/web/bursts/components.py @@ -269,7 +269,9 @@ def PlotSettingsEditDialog( duration: Optional[float] = None, ): copy = solara.use_reactive(plot_settings.value) - img, set_img = solara.use_state(cast(Optional[BinnedImage], None)) + img, set_img = solara.use_state( + cast(Optional[BinnedImage], None) + ) # BinnedImage.from_settings(df, copy.value) ? items = order_columns(df.columns) drop_cols = ["filename", "burst_index"] diff --git a/dont_fret/web/datamanager.py b/dont_fret/web/datamanager.py index cfd90a8..33e8f36 100644 --- a/dont_fret/web/datamanager.py +++ b/dont_fret/web/datamanager.py @@ -79,7 +79,11 @@ async def get_bursts( try: photon_data = await self.get_photons(photon_node) bursts = await self.run( - process_photon_data, photon_data, burst_colors, self.cfg.hooks + process_photon_data, + photon_data, + burst_colors, + self.cfg.aggregations, + self.cfg.transforms, ) future.set_result(bursts) @@ -90,12 +94,6 @@ async def get_bursts( return await self.burst_cache[key] - async def search(self, node: PhotonNode, colors: list[BurstColor]) -> Bursts: - photon_data = await self.get_photons(node) - bursts = photon_data.burst_search(colors) - - return bursts - async def get_bursts_batch( self, photon_nodes: list[PhotonNode], diff --git a/dont_fret/web/methods.py b/dont_fret/web/methods.py index 0cb95c2..581ec74 100644 --- a/dont_fret/web/methods.py +++ b/dont_fret/web/methods.py @@ -46,24 +46,23 @@ def make_burst_dataframe( return concat -# hooks? def make_burst_nodes( photon_nodes: list[PhotonNode], burst_settings: dict[str, list[BurstColor]], - hooks: Optional[dict[str, dict[str, Any]]] = None, + aggregations: Optional[dict] = None, + transforms: Optional[dict] = None, ) -> list[BurstNode]: photons = [PhotonData.from_file(PhotonFile(node.file_path)) for node in photon_nodes] burst_nodes = [] # todo tqdm? - hooks = hooks or {} for name, burst_colors in burst_settings.items(): - bursts = [process_photon_data(photon_data, burst_colors, hooks) for photon_data in photons] - # bursts = [photons.burst_search(burst_colors) for photons in photons] - # if alex_2cde: - # bursts = [b.alex_2cde(photons) for b, photons in zip(bursts, photons)] - # if fret_2cde: - # bursts = [b.fret_2cde(photons) for b, photons in zip(bursts, photons)] + bursts = [ + process_photon_data( + photon_data, burst_colors, aggregations=aggregations, transforms=transforms + ) + for photon_data in photons + ] infos = [get_info(photons) for photons in photons] duration = get_duration(infos) diff --git a/templates/01_load_datafile_1.py b/templates/01_load_datafile_1.py index 60b291b..75ade36 100644 --- a/templates/01_load_datafile_1.py +++ b/templates/01_load_datafile_1.py @@ -1,26 +1,28 @@ +# %% + +# %load_ext autoreload +# %autoreload 2 + # %% from pathlib import Path +# %% +from dont_fret.config import CONFIG_DEFAULT_DIR, cfg, update_config_from_yaml from dont_fret.fileIO import PhotonFile from dont_fret.models import PhotonData +from dont_fret.process import process_photon_data # %% cwd = Path(__file__).parent test_data_dir = cwd.parent / "tests" / "test_data" / "input" / "ds1" +output_data_dir = cwd.parent / "tests" / "test_data" / "output" ptu_file = "datafile_1.ptu" # %% -photons = PhotonData.from_file(PhotonFile(test_data_dir / ptu_file)) - # %% +# select a config +update_config_from_yaml(CONFIG_DEFAULT_DIR / "default_web.yaml") -bursts = photons.burst_search("DCBS") -bursts.burst_data - -# %% - -bursts = bursts.alex_2cde(photons).fret_2cde(photons) - -bursts.burst_data - -# %% +# process photon data to get DCBS bursts +photons = PhotonData.from_file(PhotonFile(test_data_dir / ptu_file)) +bursts = process_photon_data(photons, cfg.burst_search["DCBS"], cfg.aggregations, cfg.transforms) diff --git a/tests/test_models.py b/tests/test_models.py index f951e9e..0d0f9c8 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,8 +3,11 @@ import numpy as np import polars as pl +import polars.testing as pl_test import pytest +import yaml +from dont_fret.config import cfg from dont_fret.config.config import BurstColor from dont_fret.fileIO import PhotonFile from dont_fret.models import BinnedPhotonData, Bursts, PhotonData @@ -32,12 +35,14 @@ def ph_ds1() -> PhotonData: @pytest.fixture def dcbs_bursts(ph_ds1: PhotonData) -> Bursts: - return ph_ds1.burst_search("DCBS") + bursts = process_photon_data(ph_ds1, cfg.burst_search["DCBS"]) + return bursts @pytest.fixture def apbs_bursts(ph_ds1: PhotonData) -> Bursts: - return ph_ds1.burst_search(APBS_TEST) + bursts = process_photon_data(ph_ds1, APBS_TEST) + return bursts @pytest.fixture @@ -88,39 +93,60 @@ def test_load_save_bursts(dcbs_bursts: Bursts, tmp_path: Path): def test_burst_search(ph_ds1: PhotonData): - search_args = ["DCBS", APBS_TEST] + search_args = [cfg.burst_search["DCBS"], APBS_TEST] reference_files = ["dcbs_bursts.csv", "apbs_bursts.csv"] + # only calculate E, S, time; no cde's + transforms = { + "with_columns": { + "exprs": { + "E_app": "n_DA / (n_DD + n_DA)", + "S_app": "(n_DD + n_DA) / (n_DD + n_DA + n_AA)", + "timestamps_length": "timestamps_max - timestamps_min", + } + } + } + for bs_arg, ref_file in zip(search_args, reference_files): - bs = ph_ds1.burst_search(bs_arg) + bursts = process_photon_data(ph_ds1, bs_arg, transforms=transforms) pth = output_data_dir / "ds1" / ref_file - # bs = ph_ds1.burst_search("DCBS") - # pth = output_data_dir / "ds1" / "dcbs_bursts.csv" - df_ref = pl.read_csv(pth) - df_test = bs.burst_data.filter(pl.col("n_photons") > 50) + df_test = bursts.burst_data.filter(pl.col("n_photons") > 50) for k in ["n_photons", "E_app", "S_app"]: - assert (df_ref[k] == df_test[k]).all() + pl_test.assert_series_equal( + df_test[k], df_ref[k], check_dtypes=False, check_names=False + ) - time_length = ( - df_test["timestamps_max"] - df_test["timestamps_min"] - ) * ph_ds1.timestamps_unit - assert (df_ref["time_length"] == time_length).all() - assert (df_ref["time_min"] == df_test["timestamps_min"] * ph_ds1.timestamps_unit).all() - assert (df_ref["time_max"] == df_test["timestamps_max"] * ph_ds1.timestamps_unit).all() + time_length = df_test["timestamps_length"] * ph_ds1.timestamps_unit + time_min = df_test["timestamps_min"] * ph_ds1.timestamps_unit + time_max = df_test["timestamps_max"] * ph_ds1.timestamps_unit + pl_test.assert_series_equal(time_length, df_ref["time_length"], check_names=False) + pl_test.assert_series_equal(time_min, df_ref["time_min"], check_names=False) + pl_test.assert_series_equal(time_max, df_ref["time_max"], check_names=False) def test_process_photon_data(ph_ds1: PhotonData): - hooks = { - "alex_2cde": {}, - "fret_2cde": {}, - } - - bursts = process_photon_data(ph_ds1, APBS_TEST, hooks=hooks) - assert "alex_2cde" in bursts.burst_data.columns - assert "fret_2cde" in bursts.burst_data.columns + s = """ + "alex_2cde:75": + tau: 75.e-6 + "alex_2cde:150": + tau: 150.e-6 + fret_2cde: + tau: 45.e-6 + with_columns: + exprs: + E_app: "n_DA / (n_DD + n_DA)" + S_app: "(n_DD + n_DA) / (n_DD + n_DA + n_AA)" + timestamps_length: "timestamps_max - timestamps_min" + """ + + transforms = yaml.safe_load(s) + + bursts = process_photon_data(ph_ds1, APBS_TEST, transforms=transforms) + assert "alex_2cde_75" in bursts.burst_data.columns + assert "alex_2cde_150" in bursts.burst_data.columns def test_binning(ph_ds1): diff --git a/tests/test_web.py b/tests/test_web.py index 44abba9..38f7eaa 100644 --- a/tests/test_web.py +++ b/tests/test_web.py @@ -72,7 +72,7 @@ async def test_burst_search(): assert new_node.bursts burst_item = new_node.bursts[0] assert burst_item.name == "DCBS" - assert burst_item.df.shape == (72, 24) + assert burst_item.df.shape == (72, 21) assert burst_item.df["filename"].unique()[0] == "datafile_1.ptu" await asyncio.sleep(0)