Skip to content

Commit

Permalink
feat: Store server-side state for rehydration (#338)
Browse files Browse the repository at this point in the history
- deephaven.ui widgets now require the client to set the initial render
state
  - On first widget open, an empty state is used
- Server exports the current current state along with the rendered
component
- Only primitive state values (`int`, `bool`, `str`, `float`) are
stored. To have something like a `Table` that is restored, the table
should be derived from a primitive state variable
- Client can then set the state again next time opening the widget to
resume where they left off
- Update deephaven.ui widget architecture diagram with new
initialization details

---------

Co-authored-by: Joe <[email protected]>
  • Loading branch information
mofojed and jnumainville authored Mar 14, 2024
1 parent 7724e55 commit bb28df3
Show file tree
Hide file tree
Showing 12 changed files with 427 additions and 124 deletions.
49 changes: 27 additions & 22 deletions plugins/ui/DESIGN.md
Original file line number Diff line number Diff line change
Expand Up @@ -1964,31 +1964,35 @@ The above examples are all in Python, and particularly take some advantage of la

##### Rendering

When you call a function decorated by `@ui.component`, it will return a `UiNode` object that has a reference to the function it is decorated; that is to say, the function does _not_ get run immediately. The function is only run when the `UiNode` is rendered by the client, and the result is sent back to the client. This allows the `@ui.component` decorator to execute the function with the appropriate rendering context, and also allows for memoization of the function (e.g. if the function is called multiple times with the same arguments, it will only be executed once - akin to a [memoized component](https://react.dev/reference/react/memo) or PureComponent in React).
When you call a function decorated by `@ui.component`, it will return an `Element` object that has a reference to the function it is decorated; that is to say, the function does _not_ get run immediately. The function is only run when the `Element` is rendered by the client, and the result is sent back to the client. This allows the `@ui.component` decorator to execute the function with the appropriate rendering context. The client must also set the initial state before rendering, allowing the client to persist the state and re-render in the future.

Let's say we execute the following, where a table is filtered based on the value of a text input:

```python
from deephaven import ui


@ui.component
def text_filter_table(source, column, initial_value=""):
value, set_value = use_state(initial_value)
ti = ui.text_field(value, on_change=set_value)
value, set_value = ui.use_state(initial_value)
ti = ui.text_field(value=value, on_change=set_value)
tt = source.where(f"{column}=`{value}`")
return [ti, tt]


# This will render two panels, one filtering the table by Sym, and the other by Exchange
@ui.component
def sym_exchange(source):
tft1 = text_filter_table(source, "Sym")
tft2 = text_filter_table(source, "Exchange")
return ui.flex(tft1, tft2, direction="row")
def double_text_filter_table(source):
tft1 = text_filter_table(source, "sym")
tft2 = text_filter_table(source, "exchange")
return ui.panel(tft1, title="Sym"), ui.panel(tft2, title="Exchange")


import deephaven.plot.express as dx

t = dx.data.stocks()
_stocks = dx.data.stocks()

tft = text_filter_table(t, "sym")
tft = double_text_filter_table(_stocks)
```

Which should result in a UI like this:
Expand All @@ -2013,21 +2017,21 @@ sequenceDiagram
W->>UIP: Open tft
UIP->>C: Export tft
C-->>UIP: tft (UiNode)
C-->>UIP: tft (Element)
Note over UIP: UI knows about object tft<br/>sym_exchange not executed yet
Note over UIP: UI knows about object tft<br/>double_text_filter_table not executed yet
UIP->>SP: Render tft
SP->>SP: Run sym_exchange
Note over SP: sym_exchange executes, running text_filter_table twice
SP-->>UIP: Result (document=flex([tft1, tft2]), exported_objects=[tft1, tft2])
UIP->>SP: Render tft (initialState)
SP->>SP: Run double_text_filter_table
Note over SP: double_text_filter_table executes, running text_filter_table twice
SP-->>UIP: Result (document=[panel(tft1), pane(tft2)], exported_objects=[tft1, tft2])
UIP-->>W: Display Result
U->>UIP: Change text input 1
UIP->>SP: Change state
SP->>SP: Run sym_exchange
Note over SP: sym_exchange executes, text_filter_table only <br/>runs once for the one changed input<br/>only exports the new table, as client already has previous tables
SP-->>UIP: Result (document=flex([tft1', tft2], exported_objects=[tft1']))
SP->>SP: Run double_text_filter_table
Note over SP: double_text_filter_table executes, text_filter_table only <br/>runs once for the one changed input<br/>only exports the new table, as client already has previous tables
SP-->>UIP: Result (document=[panel(tft1'), panel(tft2)], state={}, exported_objects=[tft1'])
UIP-->>W: Display Result
```

Expand All @@ -2040,14 +2044,15 @@ sequenceDiagram
participant UIP as UI Plugin
participant SP as Server Plugin
UIP->>SP: obj.getDataAsString()
Note over UIP, SP: Uses json-rpc
SP-->>UIP: documentUpdated(Document)
Note over UIP, SP: Uses JSON-RPC
UIP->>SP: setState(initialState)
SP-->>UIP: documentUpdated(Document, State)
loop Callback
UIP->>SP: foo(params)
SP-->>UIP: foo result
SP->>UIP: documentUpdated(Document)
SP->>UIP: documentUpdated(Document, State)
Note over UIP: Client can store State to restore the same state later
end
```

Expand Down
86 changes: 82 additions & 4 deletions plugins/ui/src/deephaven/ui/_internal/RenderContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from typing import (
Any,
Callable,
Dict,
Optional,
TypeVar,
Union,
Generator,
Generic,
cast,
Set,
)
from functools import partial
from deephaven import DHError
Expand Down Expand Up @@ -48,11 +48,16 @@
A function that takes the old value and returns the new value for a state.
"""

ContextKey = Union[str, int]
ContextKey = str
"""
The key for a child context.
"""

ChildrenContextDict = Dict[ContextKey, "RenderContext"]
"""
The child contexts for a RenderContext.
"""


@dataclass
class ValueWithLiveness(Generic[T]):
Expand All @@ -62,6 +67,17 @@ class ValueWithLiveness(Generic[T]):
liveness_scope: Union[LivenessScope, None]


ContextState = Dict[StateKey, ValueWithLiveness[Any]]
"""
The state for a context.
"""

ExportedRenderState = Dict[str, Any]
"""
The serializable state of a RenderContext. Used to serialize the state for the client.
"""


def _value_or_call(
value: T | None | Callable[[], T | None]
) -> ValueWithLiveness[T | None]:
Expand All @@ -83,6 +99,19 @@ def _value_or_call(
return ValueWithLiveness(value=value, liveness_scope=None)


def _should_retain_value(value: ValueWithLiveness[T | None]) -> bool:
"""
Determine if the given value should be retained by the current context.
Args:
value: The value to check.
Returns:
True if the value should be retained, False otherwise.
"""
return value.liveness_scope is None and isinstance(value.value, (str, int, float))


_local_data = threading.local()


Expand Down Expand Up @@ -133,12 +162,12 @@ class RenderContext:
Count of hooks used in the render. Should only be set after initial render.
"""

_state: dict[StateKey, ValueWithLiveness[Any]]
_state: ContextState
"""
The state for this context.
"""

_children_context: dict[ContextKey, "RenderContext"]
_children_context: ChildrenContextDict
"""
The child contexts for this context.
"""
Expand Down Expand Up @@ -354,3 +383,52 @@ def manage(self, liveness_scope: LivenessScope) -> None:
"""
assert self is get_context()
self._collected_scopes.add(cast(LivenessScope, liveness_scope.j_scope))

def export_state(self) -> ExportedRenderState:
"""
Export the state of this context. This is used to serialize the state for the client.
Returns:
The exported serializable state of this context.
"""
exported_state: ExportedRenderState = {}

# We need to iterate through all of our state and export anything that doesn't have a LivenessScope right now (anything serializable)
def retained_values(state: ContextState):
for key, value in state.items():
if _should_retain_value(value):
yield key, value.value

if len(state := dict(retained_values(self._state))) > 0:
exported_state["state"] = state

# Now iterate through all the children contexts, and only include them in the export if they're not empty
def retained_children(children: ChildrenContextDict):
for key, child in children.items():
if len(child_state := child.export_state()) > 0:
yield key, child_state

if len(children_state := dict(retained_children(self._children_context))) > 0:
exported_state["children"] = children_state

return exported_state

def import_state(self, state: dict[str, Any]) -> None:
"""
Import the state of this context. This is used to deserialize the state from the client.
Args:
state: The state to import.
"""
self._state.clear()
self._children_context.clear()
if "state" in state:
for key, value in state["state"].items():
# When python dict is converted to JSON, all keys are converted to strings. We convert them back to int here.
self._state[int(key)] = ValueWithLiveness(
value=value, liveness_scope=None
)
if "children" in state:
for key, child_state in state["children"].items():
self.get_child_context(key).import_state(child_state)
logger.debug("New state is %s", self._state)
1 change: 1 addition & 0 deletions plugins/ui/src/deephaven/ui/_internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
get_context,
NoContextException,
ValueWithLiveness,
ExportedRenderState,
)
from .utils import (
get_component_name,
Expand Down
93 changes: 60 additions & 33 deletions plugins/ui/src/deephaven/ui/object_types/ElementMessageStream.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .._internal import wrap_callable
from ..elements import Element
from ..renderer import NodeEncoder, Renderer, RenderedNode
from .._internal import RenderContext, StateUpdateCallable
from .._internal import RenderContext, StateUpdateCallable, ExportedRenderState

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -132,7 +132,7 @@ def __init__(self, element: Element, connection: MessageStream):
self._connection = connection
self._message_id = 0
self._manager = JSONRPCResponseManager()
self._dispatcher = Dispatcher()
self._dispatcher = self._make_dispatcher()
self._encoder = NodeEncoder(separators=(",", ":"))
self._context = RenderContext(self._queue_state_update, self._queue_callable)
self._renderer = Renderer(self._context)
Expand All @@ -155,40 +155,43 @@ def _render(self) -> None:

try:
node = self._renderer.render(self._element)
state = self._context.export_state()
self._send_document_update(node, state)
except Exception as e:
logger.exception("Error rendering %s", self._element.name)
raise e

self._send_document_update(node)

def _process_callable_queue(self) -> None:
"""
Process any queued callables, then re-renders the element if it is dirty.
"""
with self._exec_context:
with self._render_lock:
self._render_thread = threading.current_thread()
self._render_state = _RenderState.RENDERING

while not self._callable_queue.empty():
item = self._callable_queue.get()
with liveness_scope():
try:
item()
except Exception as e:
logger.exception(e)

if self._is_dirty:
self._render()

with self._render_lock:
self._render_thread = None
if not self._callable_queue.empty() or self._is_dirty:
# There are still callables to process, so queue up another render
self._render_state = _RenderState.QUEUED
submit_task("concurrent", self._process_callable_queue)
else:
self._render_state = _RenderState.IDLE
try:
with self._exec_context:
with self._render_lock:
self._render_thread = threading.current_thread()
self._render_state = _RenderState.RENDERING

while not self._callable_queue.empty():
item = self._callable_queue.get()
with liveness_scope():
try:
item()
except Exception as e:
logger.exception(e)

if self._is_dirty:
self._render()

with self._render_lock:
self._render_thread = None
if not self._callable_queue.empty() or self._is_dirty:
# There are still callables to process, so queue up another render
self._render_state = _RenderState.QUEUED
submit_task("concurrent", self._process_callable_queue)
else:
self._render_state = _RenderState.IDLE
except Exception as e:
logger.exception(e)

def _mark_dirty(self) -> None:
"""
Expand Down Expand Up @@ -232,9 +235,9 @@ def _queue_callable(self, callable: Callable[[], None]) -> None:

def start(self) -> None:
"""
Start the message stream. This will start the render loop and queue up the initial render.
Start the message stream. All we do is send a blank message to start. Client will respond with the initial state.
"""
self._mark_dirty()
self._connection.on_data(b"", [])

def on_close(self) -> None:
pass
Expand Down Expand Up @@ -302,12 +305,31 @@ def _make_request(self, method: str, *params: Any) -> dict[str, Any]:
"id": self._get_next_message_id(),
}

def _send_document_update(self, root: RenderedNode) -> None:
def _make_dispatcher(self) -> Dispatcher:
dispatcher = Dispatcher()
dispatcher["setState"] = self._set_state
return dispatcher

def _set_state(self, state: ExportedRenderState) -> None:
"""
Set the state of the element. This is called by the client on initial load.
Args:
state: The state to set
"""
logger.debug("Setting state: %s", state)
self._context.import_state(state)
self._mark_dirty()

def _send_document_update(
self, root: RenderedNode, state: ExportedRenderState
) -> None:
"""
Send a document update to the client. Currently just sends the entire document for each update.
Args:
root: The root node of the document to send
state: The state of the node to preserve
"""

# TODO(#67): Send a diff of the document instead of the entire document.
Expand All @@ -316,11 +338,16 @@ def _send_document_update(self, root: RenderedNode) -> None:
new_objects = encoder_result["new_objects"]
callable_id_dict = encoder_result["callable_id_dict"]

request = self._make_notification("documentUpdated", encoded_document)
logger.debug("Exported state: %s", state)
encoded_state = json.dumps(state)

request = self._make_notification(
"documentUpdated", encoded_document, encoded_state
)
payload = json.dumps(request)
logger.debug(f"Sending payload: {payload}")

dispatcher = Dispatcher()
dispatcher = self._make_dispatcher()
for callable, callable_id in callable_id_dict.items():
logger.debug("Registering callable %s", callable_id)
dispatcher[callable_id] = wrap_callable(callable)
Expand Down
Loading

0 comments on commit bb28df3

Please sign in to comment.