Skip to content

Commit

Permalink
fix(types)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmeyerov committed Nov 28, 2024
1 parent 904fa93 commit 2096127
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 14 deletions.
1 change: 1 addition & 0 deletions graphistry/Plottable.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@


RenderModesConcrete = Literal["g", "url", "ipython", "databricks", "browser"]
RENDER_MODE_CONCRETE_VALUES: Set[RenderModesConcrete] = set(["g", "url", "ipython", "databricks", "browser"])
RenderModes = Union[Literal["auto"], RenderModesConcrete]
RENDER_MODE_VALUES: Set[RenderModes] = set(["auto", "g", "url", "ipython", "databricks", "browser"])

Expand Down
5 changes: 3 additions & 2 deletions graphistry/PlotterBase.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ def description(self, description):
return res


def edges(self, edges: Union[Callable, Any], source=None, destination=None, edge=None, *args, **kwargs) -> Plottable:
def edges(self: Plottable, edges: Union[Callable, Any], source=None, destination=None, edge=None, *args, **kwargs) -> Plottable:
"""Specify edge list data and associated edge attribute values.
If a callable, will be called with current Plotter and whatever positional+named arguments
Expand Down Expand Up @@ -1514,6 +1514,7 @@ def plot(
logger.debug("4. @PloatterBase plot: PyGraphistry.org_name(): {}".format(PyGraphistry.org_name()))

dataset = self._plot_dispatch_arrow(g, n, name, description, self._style, memoize)
assert dataset is not None
if skip_upload:
return dataset
dataset.token = PyGraphistry.api_token()
Expand Down Expand Up @@ -2269,7 +2270,7 @@ def infer_labels(self):
raise ValueError('Could not find a label-like node column and no g._node id fallback set')


def cypher(self, query: str, params: Dict[str, Any] = {}) -> 'PlotterBase':
def cypher(self, query: str, params: Dict[str, Any] = {}) -> Plottable:
"""
Execute a Cypher query against a Neo4j, Memgraph, or Amazon Neptune database and retrieve the results.
Expand Down
16 changes: 10 additions & 6 deletions graphistry/compute/chain_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ def chain_remote_generic(
}

if node_col_subset is not None:
request_body["node_col_subset"] = node_col_subset
request_body["node_col_subset"] = node_col_subset # type: ignore
if edge_col_subset is not None:
request_body["edge_col_subset"] = edge_col_subset
request_body["edge_col_subset"] = edge_col_subset # type: ignore
if df_export_args is not None:
request_body["df_export_args"] = df_export_args
if engine is not None:
request_body["engine"] = engine
request_body["engine"] = engine # type: ignore

base_url = ""
url = f"{base_url}/api/v2/etl/datasets/{dataset_id}/gfql/"
Expand Down Expand Up @@ -107,7 +107,7 @@ def chain_remote_generic(

if output_type == "shape":
if format == "json":
return pd.DataFrame(response.json)
return pd.DataFrame(response.json())
elif format == "csv":
return read_csv(BytesIO(response.content))
elif format == "parquet":
Expand Down Expand Up @@ -184,7 +184,7 @@ def chain_remote_shape(
print(shape_df)
"""

return chain_remote_generic(
out_df = chain_remote_generic(
self,
chain,
api_token,
Expand All @@ -197,6 +197,8 @@ def chain_remote_shape(
engine,
validate
)
assert isinstance(out_df, pd.DataFrame)
return out_df

def chain_remote(
self: Plottable,
Expand Down Expand Up @@ -281,7 +283,7 @@ def chain_remote(

assert output_type != "shape", 'Method chain_remote() does not support output_type="shape", call instead chain_remote_shape()'

return chain_remote_generic(
g = chain_remote_generic(
self,
chain,
api_token,
Expand All @@ -294,3 +296,5 @@ def chain_remote(
engine,
validate
)
assert isinstance(g, Plottable)
return g
10 changes: 4 additions & 6 deletions graphistry/render/resolve_render_mode.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional, Union

from graphistry.Plottable import RENDER_MODE_VALUES, Plottable, RenderModes, RenderModesConcrete
from graphistry.Plottable import RENDER_MODE_CONCRETE_VALUES, Plottable, RenderModes, RenderModesConcrete
from graphistry.util import in_databricks, in_ipython


Expand All @@ -15,14 +15,12 @@ def resolve_render_mode(

# => RenderMode
if isinstance(render, bool):
render = "auto" if render else "url"

if render not in RENDER_MODE_VALUES:
raise ValueError(f'Invalid render mode: {render}, expected one of {RENDER_MODE_VALUES}')
render = "auto" if render else "url"

# => RenderModeConcrete
if render != "auto":
return render
assert render in RENDER_MODE_CONCRETE_VALUES
return render # type: ignore

if in_ipython():
return "ipython"
Expand Down

0 comments on commit 2096127

Please sign in to comment.