diff --git a/dont_fret/web/datamanager.py b/dont_fret/web/datamanager.py
index b1bbb46..c63e552 100644
--- a/dont_fret/web/datamanager.py
+++ b/dont_fret/web/datamanager.py
@@ -3,16 +3,12 @@
 import asyncio
 import dataclasses
 import json
-import threading
 import uuid
-from collections import OrderedDict
 from concurrent.futures import ThreadPoolExecutor
 from typing import (
     Callable,
-    Generic,
+    Dict,
     Optional,
-    ParamSpec,
-    TypeVar,
 )
 
 import numpy as np
@@ -24,219 +20,13 @@
 from dont_fret.web.methods import get_duration, get_info
 from dont_fret.web.models import BurstNode, PhotonNode
 
-T = TypeVar("T")
-R = TypeVar("R")
-
-
-class DaskDataManager:
-    pass
-
-
-K = TypeVar("K")
-V = TypeVar("V")
-
-
-# TODO make async, lock per key
-class Cache(Generic[K, V]):
-    """superclass for caches"""
-
-    def __init__(self):
-        self._cache: OrderedDict[K, V] = OrderedDict()
-        self.lock = threading.Lock()
-
-    def __contains__(self, key: K) -> bool:
-        with self.lock:
-            return key in self._cache
-
-    def __getitem__(self, key: K) -> V:
-        with self.lock:
-            return self._cache[key]
-
-    def get(self, key: K) -> Optional[V]:
-        with self.lock:
-            if key in self._cache:
-                self._cache.move_to_end(key)
-                return self._cache[key]
-            return None
-
-    def set(self, key: K, value: V) -> None:
-        with self.lock:
-            self._cache[key] = value
-            self._cache.move_to_end(key)  # in case it already existed
-
-    def clear(self) -> None:
-        with self.lock:
-            self._cache.clear()
-
-
-# altnernative for threadpool
-# however, we do need control over the number of workers
-
-# Current implementation with ThreadPoolExecutor
-# import asyncio
-# from concurrent.futures import ThreadPoolExecutor
-
-# class ThreadedDataManager:
-#     def __init__(self, max_workers: int = 4) -> None:
-#         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-#         # ... other initializations ...
-
-#     def run(self, func, *args):
-#         return self.loop.run_in_executor(self.executor, func, *args)
-
-#     async def get_photons(self, node: PhotonNode) -> PhotonData:
-#         # ... (lock handling code) ...
-#         photons = await self.run(PhotonData.from_file, PhotonFile(node.file_path))
-#         # ... (cache update code) ...
-#         return photons
-
-# # Equivalent implementation using threading.Thread
-# import threading
-
-# class ThreadBasedDataManager:
-#     def __init__(self) -> None:
-#         self.active_threads = set()
-#         # ... other initializations ...
-
-#     def run(self, func, *args):
-#         future = asyncio.Future()
-
-#         def thread_func():
-#             try:
-#                 result = func(*args)
-#                 self.loop.call_soon_threadsafe(future.set_result, result)
-#             except Exception as e:
-#                 self.loop.call_soon_threadsafe(future.set_exception, e)
-#             finally:
-#                 self.active_threads.remove(thread)
-
-#         thread = threading.Thread(target=thread_func)
-#         self.active_threads.add(thread)
-#         thread.start()
-#         return future
-
-#     async def get_photons(self, node: PhotonNode) -> PhotonData:
-#         # ... (lock handling code) ...
-#         photons = await self.run(PhotonData.from_file, PhotonFile(node.file_path))
-#         # ... (cache update code) ...
-#         return photons
-
-#     def __del__(self):
-#         for thread in self.active_threads:
-#             thread.join()
-
-# per-key lock will also solve the problem of two threads asking for the same photon item object
-# claude suggestion for cache:
-# class AsyncCache(Generic[K, V]):
-#     def __init__(self):
-#         self._data: Dict[K, V] = {}
-#         self._locks: Dict[K, asyncio.Lock] = {}
-
-#     async def get(self, key: K) -> V | None:
-#         if key not in self._locks:
-#             self._locks[key] = asyncio.Lock()
-
-#         async with self._locks[key]:
-#             return self._data.get(key)
-
-#     async def set(self, key: K, value: V) -> None:
-#         if key not in self._locks:
-#             self._locks[key] = asyncio.Lock()
-
-#         async with self._locks[key]:
-#             self._data[key] = value
-
-# claude suggestiosn for locks on photon items:
-# import asyncio
-# import uuid
-# from concurrent.futures import ThreadPoolExecutor
-# from typing import Dict
-
-# from your_module import Cache, PhotonData, PhotonNode, PhotonFile
-
-# class ThreadedDataManager:
-#     def __init__(self, max_workers: int = 4) -> None:
-#         self.photon_cache = Cache[uuid.UUID, PhotonData]()
-#         self.burst_cache = Cache[tuple[uuid.UUID, str], Bursts]()
-#         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-#         self._loop = None
-#         self.photon_locks: Dict[uuid.UUID, asyncio.Lock] = {}
-
-#     @property
-#     def loop(self) -> asyncio.AbstractEventLoop:
-#         if self._loop is None:
-#             self._loop = asyncio.get_event_loop()
-#         return self._loop
-
-#     def run(self, func, *args):
-#         return self.loop.run_in_executor(self.executor, func, *args)
-
-#     async def get_photons(self, node: PhotonNode) -> PhotonData:
-#         if node.id not in self.photon_locks:
-#             self.photon_locks[node.id] = asyncio.Lock()
-
-#         async with self.photon_locks[node.id]:
-#             photons = self.photon_cache.get(node.id)
-#             if photons is None:
-#                 photons = await self.run(PhotonData.from_file, PhotonFile(node.file_path))
-#                 self.photon_cache.set(node.id, photons)
-
-#         return photons
-
-#     ... (rest of the class remains the same)
-
-# testing:
-# import asyncio
-
-# class SyncDataManager:
-#     def __init__(self):
-#         self.async_manager = AsyncDataManager()
-#         self.loop = asyncio.new_event_loop()
-#         asyncio.set_event_loop(self.loop)
-
-#     def __getattr__(self, name):
-#         async_attr = getattr(self.async_manager, name)
-#         if callable(async_attr):
-#             def sync_wrapper(*args, **kwargs):
-#                 return self.loop.run_until_complete(async_attr(*args, **kwargs))
-#             return sync_wrapper
-#         return async_attr
-
-#     def __del__(self):
-#         self.loop.close()
-
-# # Usage in tests:
-# def test_get_photons():
-#     manager = SyncDataManager()
-#     node = PhotonNode(id=uuid.uuid4(), file_path="test_file.photon")
-#     photons = manager.get_photons(node)
-#     assert isinstance(photons, PhotonData)
-#     # ... more assertions
-
-# move functional approach:
-# import asyncio
-# import uuid
-# from typing import Dict, Callable, Any
-
-# # Type aliases for clarity
-# PhotonCache = Dict[uuid.UUID, PhotonData]
-# BurstCache = Dict[tuple[uuid.UUID, str], Bursts]
-
-# async def get_photons(photon_cache: PhotonCache, node: PhotonNode) -> PhotonData:
-#     if node.id not in photon_cache:
-#         photons = await PhotonData.from_file_async(PhotonFile(node.file_path))
-#         photon_cache[node.id] = photons
-#     return photon_cache[node.id]
-
-T = TypeVar("T")
-P = ParamSpec("P")
-
 
 class ThreadedDataManager:
     def __init__(self) -> None:
-        self.photon_cache = Cache[uuid.UUID, PhotonData]()
-        self.burst_cache = Cache[tuple[uuid.UUID, str], Bursts]()
+        self.photon_cache: Dict[uuid.UUID, asyncio.Future[PhotonData]] = {}
+        self.burst_cache: Dict[tuple[uuid.UUID, str], asyncio.Future[Bursts]] = {}
         self.executor = ThreadPoolExecutor(max_workers=4)  # todo config wokers
+        self.running_jobs = {}
 
     # todo allow passing loop to init
     @property
@@ -244,18 +34,28 @@ def loop(self) -> asyncio.AbstractEventLoop:
         return asyncio.get_running_loop()
 
     def run(self, func, *args):
-        # TODO typing / implement
+        # TODO typing
         return self.loop.run_in_executor(self.executor, func, *args)
 
     async def get_photons(self, node: PhotonNode) -> PhotonData:
-        # TODO if another thread is already generating these photons it should wait for
-        # that task to finish
-        photons = self.photon_cache.get(node.id)
-        if photons is None:
-            photons = await self.run(PhotonData.from_file, PhotonFile(node.file_path))
-            self.photon_cache.set(node.id, photons)
-
-        return photons
+        # Check if we have a future in the cache
+        if node.id not in self.photon_cache:
+            # Create a future for this job
+            future = self.loop.create_future()
+            self.photon_cache[node.id] = future
+
+            try:
+                # Run the actual data loading in a thread
+                photons = await self.run(PhotonData.from_file, PhotonFile(node.file_path))
+                future.set_result(photons)
+            except Exception as e:
+                # If there's an error, remove the future from cache and propagate the error
+                self.photon_cache.pop(node.id)
+                future.set_exception(e)
+                raise
+
+        # Wait for and return the result
+        return await self.photon_cache[node.id]
 
     async def get_info(self, node: PhotonNode) -> dict:
         if node.info is not None:
@@ -277,10 +77,24 @@
         burst_colors: list[BurstColor],
     ) -> Bursts:
         key = self.burst_key(photon_node, burst_colors)
-        bursts = self.burst_cache.get(key)
-        if bursts is None:
-            bursts = await self.search(photon_node, burst_colors)
-            self.burst_cache.set(key, bursts)
+
+        if key not in self.burst_cache:
+            future = self.loop.create_future()
+            self.burst_cache[key] = future
+
+            try:
+                bursts = await self.search(photon_node, burst_colors)
+                future.set_result(bursts)
+            except Exception as e:
+                self.burst_cache.pop(key)
+                future.set_exception(e)
+                raise
+
+        return await self.burst_cache[key]
+
+    async def search(self, node: PhotonNode, colors: list[BurstColor]) -> Bursts:
+        photon_data = await self.get_photons(node)
+        bursts = photon_data.burst_search(colors)
 
         return bursts
 
@@ -306,15 +120,6 @@ async def get_bursts_batch(
         return results
 
-    async def search(self, node: PhotonNode, colors: list[BurstColor]) -> Bursts:
-        """performs burst search and stores the result in the burst cache"""
-        photon_data = await self.get_photons(node)
-        bursts = photon_data.burst_search(colors)
-
-        self.burst_cache.set(self.burst_key(node, colors), bursts)
-        return bursts
-
-    # then get the results and concatenate
     async def get_dataframe(
         self,
         photon_nodes: list[PhotonNode],
         burst_colors: list[BurstColor],
@@ -324,28 +129,16 @@ async def get_dataframe(
         on_progress = on_progress or (lambda _: None)
         on_progress(True)
 
-        # check for missing bursts, do search for those
-        # TODO we dont need to check for missing keys, we can just use get_bursts which will try
-        # to find it in the cache
-        all_keys = [self.burst_key(ph_node, burst_colors) for ph_node in photon_nodes]
-        todo_keys = [k for k in all_keys if k not in self.burst_cache]
-        todo_nodes = [photon_nodes[all_keys.index(k)] for k in todo_keys]
-
-        if todo_nodes:
-            await self.get_bursts_batch(todo_nodes, burst_colors, on_progress)
+        results = await self.get_bursts_batch(photon_nodes, burst_colors, on_progress)
 
         on_progress(True)
 
-        # gather all bursts, combine into a dataframe
-        all_bursts = [self.burst_cache[key] for key in all_keys]
-        assert not any(burst is None for burst in all_bursts), "some burst data is missing"
-
         names = [ph_node.name for ph_node in photon_nodes]
-        lens = [len(burst) for burst in all_bursts]
+        lens = [len(burst) for burst in results]
         dtype = pl.Enum(categories=names)
         filenames = pl.Series(name="filename", values=np.repeat(names, lens), dtype=dtype)
 
-        df = pl.concat([b.burst_data for b in all_bursts], how="vertical_relaxed").with_columns(
+        df = pl.concat([b.burst_data for b in results], how="vertical_relaxed").with_columns(
             filenames
         )
 
@@ -374,93 +167,3 @@ async def get_burst_node(
         )
 
         return node
-
-
-import asyncio
-import threading
-from functools import wraps
-from typing import Callable, Optional
-
-
-class SyncDataManager:
-    def __init__(self):
-        self._async_manager = ThreadedDataManager()
-        self._loop = None
-        self._loop_thread = None
-
-    def _ensure_loop(self):
-        if self._loop is None:
-            self._loop = asyncio.new_event_loop()
-            self._loop_thread = threading.Thread(target=self._run_loop, daemon=True)
-            self._loop_thread.start()
-
-    def _run_loop(self):
-        asyncio.set_event_loop(self._loop)
-        self._loop.run_forever()
-
-    def _run_async(self, coro):
-        self._ensure_loop()
-        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
-        return future.result()
-
-    @staticmethod
-    def _sync_wrapper(func):
-        @wraps(func)
-        def wrapper(self, *args, **kwargs):
-            return self._run_async(getattr(self._async_manager, func.__name__)(*args, **kwargs))
-
-        return wrapper
-
-    @_sync_wrapper
-    def get_photons(self, node: PhotonNode) -> PhotonData:
-        pass
-
-    @_sync_wrapper
-    def get_info(self, node: PhotonNode) -> dict:
-        pass
-
-    @_sync_wrapper
-    def get_bursts(self, photon_node: PhotonNode, burst_colors: list[BurstColor]) -> Bursts:
-        pass
-
-    @_sync_wrapper
-    def get_bursts_batch(
-        self,
-        photon_nodes: list[PhotonNode],
-        burst_colors: list[BurstColor],
-        on_progress: Optional[Callable[[float | bool], None]] = None,
-    ) -> list[Bursts]:
-        pass
-
-    @_sync_wrapper
-    def search(self, node: PhotonNode, colors: list[BurstColor]) -> Bursts:
-        pass
-
-    @_sync_wrapper
-    def get_dataframe(
-        self,
-        photon_nodes: list[PhotonNode],
-        burst_colors: list[BurstColor],
-        on_progress: Optional[Callable[[float | bool], None]] = None,
-    ) -> "pl.DataFrame":
-        pass
-
-    @_sync_wrapper
-    def get_burst_node(
-        self,
-        photon_nodes: list[PhotonNode],
-        burst_colors: list[BurstColor],
-        name: str = "",
-        on_progress: Optional[Callable[[float | bool], None]] = None,
-    ) -> "BurstNode":
-        pass
-
-    def burst_key(self, node: PhotonNode, burst_colors: list[BurstColor]) -> tuple[uuid.UUID, str]:
-        return self._async_manager.burst_key(node, burst_colors)
-
-    def __del__(self):
-        if self._loop and not self._loop.is_closed():
-            self._loop.call_soon_threadsafe(self._loop.stop)
-        if self._loop_thread:
-            self._loop_thread.join()
-        self._loop.close()
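# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): the datamanager.py change above
# replaces the lock-based `Cache` with plain dicts of `asyncio.Future`s. The
# first caller for a key creates and publishes a future, then resolves it
# from a worker thread; concurrent callers for the same key find the future
# and await the same result, so no work is duplicated. A minimal,
# self-contained sketch of that pattern under those assumptions follows; the
# names `load` and `get` are illustrative stand-ins, not the project API.

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Dict

executor = ThreadPoolExecutor(max_workers=4)
cache: Dict[str, "asyncio.Future[str]"] = {}


def load(key: str) -> str:
    # stand-in for blocking work such as PhotonData.from_file
    return key.upper()


async def get(key: str) -> str:
    loop = asyncio.get_running_loop()
    if key not in cache:
        future: "asyncio.Future[str]" = loop.create_future()
        # publish the future *before* running the job, so callers that
        # arrive while it runs wait on the same result
        cache[key] = future
        try:
            result = await loop.run_in_executor(executor, load, key)
            future.set_result(result)
        except Exception as e:
            cache.pop(key)  # drop failed jobs so they can be retried
            future.set_exception(e)  # wake any waiters with the error
            raise
    return await cache[key]


async def main() -> None:
    # two concurrent requests for the same key share a single load() call
    a, b = await asyncio.gather(get("k"), get("k"))
    assert a == b == "K"


asyncio.run(main())
# ---------------------------------------------------------------------------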
diff --git a/dont_fret/web/home/methods.py b/dont_fret/web/home/methods.py
index 4a22eca..7624805 100644
--- a/dont_fret/web/home/methods.py
+++ b/dont_fret/web/home/methods.py
@@ -7,6 +7,7 @@
 from dont_fret.web.utils import has_bursts
 
+# todo pass callable to add item on done
 # todo pass callable to add item on done
 @solara.lab.task(prefer_threaded=False)  # type: ignore
 async def task_burst_search(name: str, photon_nodes: list[PhotonNode], burst_store) -> None:
@@ -16,22 +17,19 @@ def on_progress(progress: float | bool) -> None:
         task_burst_search.progress = progress
 
     burst_colors = list(state.burst_settings.value[name])
-    df = await state.data_manager.get_dataframe(photon_nodes, burst_colors, on_progress)
-    if len(df) == 0:
-        state.snackbar.warning("No bursts found", timeout=0)
-    else:
-        # getting info should be fast / non_blocking since photons are cached
-        info_list = [await state.data_manager.get_info(ph_node) for ph_node in photon_nodes]
-        duration = get_duration(info_list)
-        burst_node = BurstNode(
-            name=name, df=df, colors=burst_colors, photon_nodes=photon_nodes, duration=duration
+    try:
+        burst_node = await state.data_manager.get_burst_node(
+            photon_nodes, burst_colors, name, on_progress
         )
-        burst_store.append(burst_node)
-        state.disable_burst_page.set(not has_bursts(state.fret_nodes.items))
+    except ValueError:
+        state.snackbar.warning("No bursts found", timeout=0)
+        task_burst_search.progress = False
+        return
 
-        state.snackbar.success(
-            f"Burst search completed, found {len(burst_node.df)} bursts", timeout=0
-        )
+    burst_store.append(burst_node)
+    state.disable_burst_page.set(not has_bursts(state.fret_nodes.items))
+
+    state.snackbar.success(f"Burst search completed, found {len(burst_node.df)} bursts", timeout=0)
 
     task_burst_search.progress = False
diff --git a/dont_fret/web/methods.py b/dont_fret/web/methods.py
index 54893d7..ec791cc 100644
--- a/dont_fret/web/methods.py
+++ b/dont_fret/web/methods.py
@@ -12,7 +12,7 @@
 from dont_fret.web.models import BurstFilterItem
 
 if TYPE_CHECKING:
-    from dont_fret.web.new_models import FRETNode
+    from dont_fret.web.datamanager import FRETNode
 
 
 def chain_filters(filters: list[BurstFilterItem]) -> Union[pl.Expr, Literal[True]]:
diff --git a/templates/05_web_models.py b/templates/05_web_models.py
index 54f6440..4f5bca0 100644
--- a/templates/05_web_models.py
+++ b/templates/05_web_models.py
@@ -6,57 +6,40 @@
 
 from __future__ import annotations
 
-import dataclasses
-import threading
-import time
-import uuid
+import asyncio
 from pathlib import Path
-from typing import Literal, Optional, TypeVar
 
-import numpy as np
-import polars as pl
 import solara
 
-import dont_fret.web.state as state
 from dont_fret.config import cfg
-from dont_fret.config.config import BurstColor
 from dont_fret.web.bursts.components import BurstFigure
-from dont_fret.web.methods import batch_burst_search, get_duration
-from dont_fret.web.models import BurstNode, PhotonNode
-from dont_fret.web.new_models import (
-    FRETNode,
-    FRETStore,
-    ListStore,
-    SelectorNode,
-    SyncDataManager,
-    ThreadedDataManager,
-)
-from dont_fret.web.utils import find_object
+from dont_fret.web.datamanager import ThreadedDataManager
+
+# SyncDataManager,
+from dont_fret.web.models import FRETNode, ListStore, PhotonNode
 
 # %%
 ROOT = Path(__file__).parent.parent
 pth = ROOT / "tests" / "test_data" / "input" / "ds2"
 photon_nodes = [PhotonNode(file_path=ptu_pth) for ptu_pth in pth.glob("*.ptu")]
-data_manager = ThreadedDataManager()
-sync_manager = SyncDataManager()
-
-fret_store = FRETStore([])
+dm = ThreadedDataManager()
 
 # %%
-node_1 = FRETNode(
-    name=solara.Reactive("my_node"),
-)
+
+# %%
 
 burst_settings = ["DCBS", "APBS"]
-nodes = [
-    sync_manager.get_burst_node(photon_nodes, cfg.burst_search[name], name=name)
+burst_nodes = [
+    await dm.get_burst_node(photon_nodes, cfg.burst_search[name], name=name)
     for name in burst_settings
 ]
 
-node_1.bursts.extend(nodes)
-node_1.photons.extend(photon_nodes)
-fret_store.append(node_1)
+node_1 = FRETNode(
+    name=solara.Reactive("my_node"),
+    photons=ListStore(photon_nodes),
+    bursts=ListStore(burst_nodes),
+)
 
 # %%
@@ -65,160 +48,14 @@
     name=solara.Reactive("my_node_2"),
 )
 node_2.photons.extend(photon_nodes[2:])
-fret_store.append(node_2)
-# %%
-
-
-fret_node = node_1
-
-# %%
-
-
-df = fret_node.bursts.items[0].df
-
-# %%
-
-source = df
-
-import altair as alt
-
-chart = (
-    alt.Chart(df)
-    .mark_bar()
-    .encode(
-        alt.X(
-            "n_photons:Q",
-            bin=alt.Bin(step=22.3, nice=False),
-            axis=alt.Axis(),  # This removes the bin-aligned ticks
-        ),
-        y="count()",
-    )
-)
-spec = chart.to_dict()
-spec["encoding"]
-
-
-# %%
-def fd_bin_width(data):
-    """
-    Calculate bin width using the Freedman-Diaconis rule:
-    bin width = 2 * IQR * n^(-1/3)
-    where IQR is the interquartile range and n is the number of observations
-    """
-    q75, q25 = np.percentile(data, [75, 25])
-    iqr = q75 - q25
-    n = len(data)
-    return 2 * iqr * (n ** (-1 / 3))
-
-
-# %%
-field = "nanotimes_AA"
-fd_bin_width(df[field].drop_nans())
-
-# %%
-
-filtered_df = df.filter(pl.col(field).is_null())
-filtered_df
+fret_nodes = [node_1, node_2]
 
-# %%
-
-
-# %%
-
-q75, q25 = np.percentile(data, [75, 25])
-iqr = q75 - q25
-iqr
-
-# %%
-np.any(np.isnan(np.asarray(data)))
-# %%
-
-
-# %%
-
-f_item = state.filters.items[0]
-f_item.as_expr()
+fret_nodes
+node_1.bursts
 
 # %%
-df_f = df.filter(f_item.as_expr())
-len(df), len(df_f)
+BurstFigure(fret_nodes)
 
-# %%
-field = "nanotimes_AA"
-selection = alt.selection_interval(name="range", encodings=["x"])
-
-
-def make_chart(df, color="blue", opacity=1.0):
-    chart = (
-        alt.Chart(df.select(pl.col(field)))
-        .mark_rect(opacity=opacity)
-        .transform_bin(as_=["x", "x2"], field=field, bin=alt.Bin(step=fd_bin_width(df[field])))
-        .encode(
-            x=alt.X(
-                "x:Q",
-                scale={"zero": False},
-                title=field,
-            ),
-            x2="x2:Q",
-            y=alt.Y("count():Q", title="count"),
-            tooltip=[
-                alt.Tooltip("x:Q", title="Center", format=".2f"),
-                alt.Tooltip("x2:Q", title="Start", format=".2f"),
-                alt.Tooltip("bin_center:Q", title="End", format=".2f"),
-                alt.Tooltip("count():Q", title="Count", format=","),
-            ],
-            color=alt.value(color),
-        )
-        .transform_calculate(
-            bin_center="(datum.x + datum.x2) / 2"  # Calculate bin center
-        )
-        # .add_params(selection)
-    )
-    return chart
-
-
-chart1 = make_chart(df, color="#1f77b4", opacity=0.5)
-chart1
-
-chart2 = make_chart(df_f, color="#1f77b4").add_params(selection)
-
-chart = chart1 + chart2
-chart
-
-chart
-# %%
-type(df.select(pl.col(field)))
-
-type(df[field])
 
 # %%
-
-jchart = alt.JupyterChart(chart, embed_options={"actions": False})
-jchart
-
-# %%
-
-
-def on_select(value):
-    print(value)
-
-
-jchart.selections.observe(on_select, "range")
-
-
-# %%
-None.value
-
-# %%
-
-spec.keys()
-r_spec = spec.copy()
-r_spec.pop("datasets", None)
-
-r_spec
-# %%
-jchart = alt.JupyterChart(chart, spec=spec, embed_options={"actions": False})
-jchart
-
-jchart
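# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): with SyncDataManager removed, the
# template above relies on top-level `await`, which only works in
# Jupyter-style # %% cells. A hedged sketch of one way to drive the async
# manager from plain synchronous code (e.g. a script or pytest), assuming no
# event loop is already running:

import asyncio


def run_sync(coro):
    # asyncio.run creates a fresh event loop, runs the coroutine to
    # completion, and closes the loop; inside Jupyter, `await` the
    # coroutine directly instead
    return asyncio.run(coro)

# hypothetical usage, mirroring the template:
# dm = ThreadedDataManager()
# node = run_sync(dm.get_burst_node(photon_nodes, cfg.burst_search["DCBS"], name="DCBS"))
# ---------------------------------------------------------------------------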