Typing State for the class-based API does not work #400

Open
mdrideout opened this issue Oct 16, 2024 · 2 comments
Labels: enhancement (New feature or request), triage

Comments

@mdrideout
Contributor

Following the docs for action-level typing does not work for class-based actions.

ref: #386

Current behavior

Example first action:

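```python
import logging

from burr.core import Action

# ApplicationState is the reporter's typed state model (its definition is not shown in the issue).

logger = logging.getLogger(__name__)


class SetInitialPromptAction(Action):
    @property
    def reads(self) -> list[str]:
        return []

    def run(self, state: ApplicationState, prompt: str) -> dict:
        return {"initial_prompt": prompt}

    @property
    def writes(self) -> list[str]:
        return ["initial_prompt"]

    def update(self, result: dict, state: ApplicationState) -> ApplicationState:
        prompt = result["initial_prompt"]
        logger.info(f"Saving prompt to state: {prompt}")
        # Relies on `state` being an ApplicationState instance (typed state).
        state.initial_prompt = prompt
        return state

    @property
    def inputs(self) -> list[str]:
        return ["prompt"]
```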

Example second action:

```python
class ExtractSetAction(Action):
    @property
    def reads(self) -> list[str]:
        return ["initial_prompt"]

    def run(self, state: ApplicationState) -> dict:
        logger.info(f"ApplicationState: {state}")

        # Read prompt from state
        prompt = state.initial_prompt
        ...
```

Logs: ApplicationState: {'initial_prompt': None}
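The logged value is a plain mapping rather than an ApplicationState instance, which suggests the class-based path hands `run()` and `update()` Burr's key-value `State` object instead of the typed model. The library-free sketch below reproduces that mismatch. `KVState` is a hypothetical stand-in, not `burr.core.State`; it assumes only that the container is a read-only mapping whose `update()` returns a new copy. Under that assumption, the attribute assignment in `update()` never reaches the stored values (which would explain `initial_prompt` staying `None`), and the next action's attribute read raises `AttributeError`.

```python
from collections.abc import Mapping
from typing import Any, Iterator


class KVState(Mapping):
    """Hypothetical stand-in for a key-value state container (not burr.core.State)."""

    def __init__(self, values: dict[str, Any]):
        self._values = dict(values)

    def __getitem__(self, key: str) -> Any:
        return self._values[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._values)

    def __len__(self) -> int:
        return len(self._values)

    def update(self, **kwargs: Any) -> "KVState":
        # Returns a new object; the original is left unchanged.
        return KVState({**self._values, **kwargs})


state = KVState({"initial_prompt": None})

# Attribute assignment lands on the Python object, not in the stored values...
state.initial_prompt = "bicep curls"
print(state["initial_prompt"])  # None

# ...so a fresh state built from the stored values has no such attribute.
next_state = KVState(dict(state))
try:
    next_state.initial_prompt
except AttributeError as exc:
    print(type(exc).__name__, exc)

# Key-based access plus a returned copy is what a key-value state expects.
updated = state.update(initial_prompt="bicep curls")
print(updated["initial_prompt"])
```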

Stack Traces

api | ********************************************************************************
api | -------------------------------------------------------------------
api | Oh no an error! Need help with Burr?
api | Join our discord and ask for help! https://discord.gg/4FxBMyzW5n
api | -------------------------------------------------------------------
api | > Action: extract_set encountered an error!<
api | > State (at time of action):
api | {'__PRIOR_STEP': 'set_prompt',
api | '__SEQUENCE_ID': 1,
api | 'initial_prompt': None,
api | 'set_from_prompt': None}
api | > Inputs (at time of action):
api | {'prompt': 'bicep curls with 22 pound dumbells for 21 reps'}
api | ********************************************************************************
api | Traceback (most recent call last):
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
api | result = _run_function(
api | ^^^^^^^^^^^^^^
api | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
api | result = function.run(state_to_use, **inputs)
api | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | File "/app/app/ai/actions/extract_set/action.py", line 93, in run
api | prompt = state.initial_prompt
api | ^^^^^^^^^^^^^^^^^^^^
api | AttributeError: 'State' object has no attribute 'initial_prompt'
api | INFO: 192.168.65.1:35000 - "GET /api/extract-set?prompt=bicep%20curls%20with%2022%20pound%20dumbells%20for%2021%20reps HTTP/1.1" 500 Internal Server Error
api | ERROR: Exception in ASGI application
api | + Exception Group Traceback (most recent call last):
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 77, in collapse_excgroups
api | | yield
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 186, in __call__
api | | async with anyio.create_task_group() as task_group:
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 736, in __aexit__
api | | raise BaseExceptionGroup(
api | | ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
api | +-+---------------- 1 ----------------
api | | Traceback (most recent call last):
api | | File "/usr/local/lib/python3.11/site-packages/uvicorn/protocols/http/httptools_impl.py", line 401, in run_asgi
api | | result = await app( # type: ignore[func-returns-value]
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
api | | return await self.app(scope, receive, send)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/applications.py", line 1054, in __call__
api | | await super().__call__(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/applications.py", line 113, in __call__
api | | await self.middleware_stack(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 187, in __call__
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/errors.py", line 165, in __call__
api | | await self.app(scope, receive, _send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 185, in __call__
api | | with collapse_excgroups():
api | | File "/usr/local/lib/python3.11/contextlib.py", line 158, in __exit__
api | | self.gen.throw(typ, value, traceback)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_utils.py", line 83, in collapse_excgroups
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 187, in __call__
api | | response = await self.dispatch_func(request, call_next)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/main.py", line 36, in log_requests
api | | response = await call_next(request)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 163, in call_next
api | | raise app_exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/base.py", line 149, in coro
api | | await self.app(scope, receive_or_disconnect, send_no_error)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/middleware/exceptions.py", line 62, in __call__
api | | await wrap_app_handling_exceptions(self.app, conn)(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | | await app(scope, receive, sender)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 715, in __call__
api | | await self.middleware_stack(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 735, in app
api | | await route.handle(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 288, in handle
api | | await self.app(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 76, in app
api | | await wrap_app_handling_exceptions(app, request)(scope, receive, send)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 62, in wrapped_app
api | | raise exc
api | | File "/usr/local/lib/python3.11/site-packages/starlette/_exception_handler.py", line 51, in wrapped_app
api | | await app(scope, receive, sender)
api | | File "/usr/local/lib/python3.11/site-packages/starlette/routing.py", line 73, in app
api | | response = await f(request)
api | | ^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 301, in app
api | | raw_response = await run_endpoint_function(
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/fastapi/routing.py", line 214, in run_endpoint_function
api | | return await run_in_threadpool(dependant.call, **values)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/starlette/concurrency.py", line 39, in run_in_threadpool
api | | return await anyio.to_thread.run_sync(func, *args)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
api | | return await get_async_backend().run_sync_in_worker_thread(
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 2405, in run_sync_in_worker_thread
api | | return await future
api | | ^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/anyio/_backends/_asyncio.py", line 914, in run
api | | result = context.run(func, *args)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/api/routes.py", line 53, in extract_set
api | | action, result, state = application.run(
api | | ^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/telemetry.py", line 276, in wrapped_fn
api | | return call_fn(*args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | | return fn(app_self, *args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1168, in run
api | | next(gen)
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 1111, in iterate
api | | prior_action, result, state = self.step(inputs=inputs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 616, in wrapper_sync
api | | return fn(app_self, *args, **kwargs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 773, in step
api | | out = self._step(inputs=inputs, _run_hooks=True)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 826, in _step
api | | raise e
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 816, in _step
api | | result = _run_function(
api | | ^^^^^^^^^^^^^^
api | | File "/usr/local/lib/python3.11/site-packages/burr/core/application.py", line 123, in _run_function
api | | result = function.run(state_to_use, **inputs)
api | | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
api | | File "/app/app/ai/actions/extract_set/action.py", line 93, in run
api | | prompt = state.initial_prompt
api | | ^^^^^^^^^^^^^^^^^^^^
api | | AttributeError: 'State' object has no attribute 'initial_prompt'
api | +------------------------------------

Library & System Information

  • Python 3.11
  • Debian bookworm slim
  • Burr library:
burr = { extras = [
  "graphviz",
  "hamilton",
  "streamlit",
  "tracking-client",
  "tracking-server",
], version = "^0.31.1" }

Expected behavior

Typed state should work with class-based actions the same way it works with function-based actions.
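One way to approximate that behavior in a class-based action today is to convert at the boundary: build the typed model from the key-value state at the start of `run()` or `update()`, then hand back only the declared writes. The helpers below are hypothetical and not part of Burr. A dataclass stands in for ApplicationState to keep the sketch dependency-free, and its fields are assumed from the keys visible in the error output. In a real Burr `update()`, the resulting dict would be applied with `state.update(**writes)`, which returns a new `State`.

```python
from collections.abc import Mapping
from dataclasses import asdict, dataclass, fields
from typing import Any, Optional, Type, TypeVar

T = TypeVar("T")


@dataclass
class ApplicationState:
    # Hypothetical fields, inferred from the keys shown in the error output.
    initial_prompt: Optional[str] = None
    set_from_prompt: Optional[str] = None


def to_model(state: Mapping[str, Any], model_cls: Type[T]) -> T:
    """Build a typed view from a key-value state, ignoring keys the model lacks
    (e.g. internal bookkeeping keys such as __PRIOR_STEP)."""
    names = {f.name for f in fields(model_cls)}
    return model_cls(**{k: v for k, v in state.items() if k in names})


def writes_from_model(model: Any, writes: list[str]) -> dict[str, Any]:
    """Pick only the declared write keys off the typed model."""
    data = asdict(model)
    return {k: data[k] for k in writes}


# Usage: inside a class-based run()/update(), convert at the boundary.
raw = {"__PRIOR_STEP": "set_prompt", "__SEQUENCE_ID": 1, "initial_prompt": None}
typed = to_model(raw, ApplicationState)
typed.initial_prompt = "bicep curls"
print(writes_from_model(typed, ["initial_prompt"]))
```

Keeping the conversion in two small functions means `reads`/`writes` stay the single source of truth for what the action touches, while the body of the action can still use attribute access.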


@shun-liang
Contributor

shun-liang commented Dec 1, 2024

I would like to follow this issue too. I would love to see typed state treated as a first-class citizen in Burr, and making it work with class-based actions is important, IMHO.

@skrawcz added the enhancement (New feature or request) label on Dec 2, 2024
@elijahbenizzy
Contributor

Will scope this out; I think it's high value. That said, @shun-liang, you can always use centralized state, which lets you define the state model centrally with the application rather than decentrally with each class: https://burr.dagworks.io/concepts/state-typing/#application-level-typing

It will take a bit to scope, but I think we can build this out reasonably fast.
