Skip to content

Commit

Permalink
router can directly serialise plotly Figures
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverlambson committed Aug 22, 2024
1 parent e6a05c3 commit c3ad52b
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 51 deletions.
13 changes: 6 additions & 7 deletions bored-charts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,18 @@ pip install bored-charts uvicorn
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
from boredcharts import BCRouter, boredcharts
from boredcharts.jinja import to_html
from fastapi.responses import HTMLResponse

pages = Path(__file__).parent.absolute() / "pages"
figure_router = BCRouter()


@figure_router.chart("usa_population")
async def usa_population() -> HTMLResponse:
df = px.data.gapminder().query("country=='United States'")
@figure_router.chart("population")
async def population(country: str) -> go.Figure:
df = px.data.gapminder().query(f"country=='{country}'")
fig = px.bar(df, x="year", y="pop")
return HTMLResponse(to_html(fig))
return fig


app = boredcharts(
Expand All @@ -52,7 +51,7 @@ pages/populations.md:

USA's population has been growing linearly for the last 70 years:

{{ figure("usa_population") }}
{{ figure("population", country="United States") }}
```

### Run your app
Expand Down
2 changes: 1 addition & 1 deletion bored-charts/boredcharts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.2.0"
__version__ = "0.3.0"

from boredcharts.router import BCRouter
from boredcharts.webapp import boredcharts
Expand Down
60 changes: 59 additions & 1 deletion bored-charts/boredcharts/router.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,74 @@
from collections.abc import Callable
from typing import Any

import plotly.graph_objects as go
from fastapi import APIRouter
from fastapi.responses import HTMLResponse
from fastapi.types import DecoratedCallable
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

from boredcharts.jinja import to_html


def validate_figure(fig: Any) -> go.Figure:
assert isinstance(fig, go.Figure)
return fig


class HTMLFigure(go.Figure): # type: ignore[misc]
"""A Plotly Figure that Pydantic can understand and serialize.
This allows us to return a Plotly Figure from a FastAPI route.
"""

@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.json_or_python_schema(
json_schema=core_schema.any_schema(),
python_schema=core_schema.union_schema(
[
core_schema.is_instance_schema(go.Figure),
core_schema.any_schema(),
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: to_html(instance)
),
)


class BCRouter(APIRouter):
"""A FastAPI APIRouter that is specifically designed for creating chart routes.
Usage:
```py
from boredcharts import BCRouter
import plotly.graph_objects as go
router = BCRouter()
@router.chart("my_chart")
async def my_chart() -> go.Figure:
return go.Figure()
```
"""

def chart(
self,
name: str,
) -> Callable[[DecoratedCallable], DecoratedCallable]:
"""
Creates a GET route for a chart, just a shorter form of the FastAPI get decorator,
your function still has to return a HTMLResponse
"""
path = f"/figure/{name}"
return self.api_route(
return self.get(
path=path,
name=name,
response_model=HTMLFigure,
response_class=HTMLResponse,
)
24 changes: 5 additions & 19 deletions examples/full/bcexample/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import numpy as np
import plotly.express as px
from boredcharts import BCRouter
from boredcharts.jinja import to_html
from fastapi.responses import HTMLResponse
from plotly.graph_objects import Figure

router = BCRouter()


@router.chart("population")
async def example(report_name: str, country: str) -> Figure:
df = px.data.gapminder().query(f"country=='{country}'")
fig = px.bar(df, x="year", y="pop")
Expand All @@ -30,17 +29,12 @@ async def example(report_name: str, country: str) -> Figure:
return fig


# TODO: pass functions into framework, auto generate these routes
@router.chart("example_simple_usa")
async def fig_example_simple(report_name: str) -> HTMLResponse:
return HTMLResponse(to_html(await example(report_name, "United States")))


@router.chart("example_params")
async def fig_example(report_name: str, country: str) -> HTMLResponse:
return HTMLResponse(to_html(await example(report_name, country)))
@router.chart("usa_population")
async def fig_example_simple(report_name: str) -> Figure:
return await example(report_name, "United States")


@router.chart("elasticity_vs_profit")
async def elasticity_vs_profit(
report_name: str, margin: float | None = None
) -> mplfig.Figure:
Expand Down Expand Up @@ -96,11 +90,3 @@ async def elasticity_vs_profit(
ax.grid(True)

return fig


# TODO: pass functions into framework, auto generate these routes
@router.chart("elasticity_vs_profit")
async def fig_elasticity_vs_profit(
report_name: str, margin: float | None = None
) -> HTMLResponse:
return HTMLResponse(to_html(await elasticity_vs_profit(report_name, margin)))
32 changes: 16 additions & 16 deletions examples/full/bcexample/pages/example.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,39 @@ The USA's population has been growing linearly:

<pre>
{%- raw %}
{{ figure("example_simple_usa") }}
{{ figure("usa_population") }}
{% endraw -%}
</pre>

{{ figure("example_simple_usa") }}
{{ figure("usa_population") }}

South Africa's growth is a bit weirder looking according to this chart:

<pre>
{%- raw %}
{{ figure("example_params", country="South Africa") }}
{{ figure("population", country="South Africa") }}
{% endraw -%}
</pre>

{{ figure("example_params", country="South Africa") }}
{{ figure("population", country="South Africa") }}

We can put two charts side by side:

<pre>
{%- raw %}
{{
row(
figure("example_params", country="United Kingdom"),
figure("example_params", country="France"),
figure("population", country="United Kingdom"),
figure("population", country="France"),
)
}}
{% endraw -%}
</pre>

{{
row(
figure("example_params", country="United Kingdom"),
figure("example_params", country="France"),
figure("population", country="United Kingdom"),
figure("population", country="France"),
)
}}

Expand All @@ -53,17 +53,17 @@ And we can add custom tailwind classes to the figures:
{%- raw %}
{{
row(
figure("example_params", country="Canada", class="h-[300px] min-w-[300px]"),
figure("example_params", country="Australia", class="h-[300px] min-w-[300px]"),
figure("population", country="Canada", class="h-[300px] min-w-[300px]"),
figure("population", country="Australia", class="h-[300px] min-w-[300px]"),
)
}}
{% endraw -%}
</pre>

{{
row(
figure("example_params", country="Canada", class="h-[300px] min-w-[300px]"),
figure("example_params", country="Australia", class="h-[300px] min-w-[300px]"),
figure("population", country="Canada", class="h-[300px] min-w-[300px]"),
figure("population", country="Australia", class="h-[300px] min-w-[300px]"),
)
}}

Expand All @@ -73,15 +73,15 @@ We can also dip into html when we need to
<pre>
{%- raw %}
&lt;div class="flex flex-wrap"&gt;
{{ figure("example_params", country="United Kingdom") }}
{{ figure("example_params", country="France") }}
{{ figure("population", country="United Kingdom") }}
{{ figure("population", country="France") }}
&lt/div&gt;
{% endraw -%}
</pre>

<div class="flex flex-wrap">
{{ figure("example_params", country="United Kingdom") }}
{{ figure("example_params", country="France") }}
{{ figure("population", country="United Kingdom") }}
{{ figure("population", country="France") }}
</div>

Or a matplotlib char
Expand Down
11 changes: 5 additions & 6 deletions examples/minimal/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
from boredcharts import BCRouter, boredcharts
from boredcharts.jinja import to_html
from fastapi.responses import HTMLResponse

pages = Path(__file__).parent.absolute() / "pages"
figure_router = BCRouter()


@figure_router.chart("usa_population")
async def usa_population() -> HTMLResponse:
df = px.data.gapminder().query("country=='United States'")
@figure_router.chart("population")
async def population(country: str) -> go.Figure:
df = px.data.gapminder().query(f"country=='{country}'")
fig = px.bar(df, x="year", y="pop")
return HTMLResponse(to_html(fig))
return fig


app = boredcharts(
Expand Down
2 changes: 1 addition & 1 deletion examples/minimal/pages/populations.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

USA's population has been growing linearly for the last 70 years:

{{ figure("usa_population") }}
{{ figure("population", country="United States") }}

0 comments on commit c3ad52b

Please sign in to comment.