From 2a1d81ab2c31fb6d3b8e2d63f6f5716bf758db7d Mon Sep 17 00:00:00 2001
From: Leo Meyerovich <leo@graphistry.com>
Date: Thu, 18 Jul 2024 23:33:17 -0400
Subject: [PATCH 1/4] refactor(lazy import): centralize, optimize, CPU fallback
 when broken GPU envs

---
 CHANGELOG.md                             |   9 ++
 graphistry/Engine.py                     |  29 ++--
 graphistry/compute/cluster.py            |  38 +----
 graphistry/dgl_utils.py                  |  28 +---
 graphistry/embed_utils.py                |  31 ++--
 graphistry/feature_utils.py              |  67 ++-------
 graphistry/networks.py                   |  42 ++----
 graphistry/tests/test_compute_cluster.py |  10 +-
 graphistry/tests/test_dgl_utils.py       |   4 +-
 graphistry/tests/test_embed_utils.py     |   5 +-
 graphistry/tests/test_feature_utils.py   |   8 +-
 graphistry/tests/test_text_utils.py      |   9 +-
 graphistry/tests/test_umap_utils.py      |  14 +-
 graphistry/umap_utils.py                 |  57 ++------
 graphistry/utils/lazy_import.py          | 179 +++++++++++++++++++++++
 15 files changed, 286 insertions(+), 244 deletions(-)
 create mode 100644 graphistry/utils/lazy_import.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fc4664493f..03e9505bb3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,15 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
 
 ## [Development]
 
+### Fixed
+
+* Graceful CPU fallbacks: When lazy GPU dependency imports throw `ImportError`, commonly seen due to broken CUDA environments or having CUDA libraries but no GPU, warn and fall back to CPU.
+
+### Changed
+
+* Centralize lazy imports into `graphistry.utils.lazy_import`
+* Lazy imports distinguish `ModuleNotFoundError` (=> `False`) from `ImportError` (warn + `False`)
+
 ## [0.34.0 - 2024-07-17]
 
 ### Infra
diff --git a/graphistry/Engine.py b/graphistry/Engine.py
index 8bc2bc2b1d..e514a69195 100644
--- a/graphistry/Engine.py
+++ b/graphistry/Engine.py
@@ -1,6 +1,8 @@
+from inspect import getmodule
 import pandas as pd
 from typing import Any, Optional, Union
 from enum import Enum
+from graphistry.utils.lazy_import import lazy_cudf_import
 
 
 class Engine(Enum):
@@ -21,18 +23,6 @@ class EngineAbstract(Enum):
 DataframeLocalLike = Any  # pdf, cudf
 GraphistryLke = Any
 
-#TODO use new importer when it lands (this is copied from umap_utils)
-def lazy_cudf_import_has_dependancy():
-    try:
-        import warnings
-
-        warnings.filterwarnings("ignore")
-        import cudf  # type: ignore
-
-        return True, "ok", cudf
-    except ModuleNotFoundError as e:
-        return False, e, None
-
 def resolve_engine(
     engine: Union[EngineAbstract, str],
     g_or_df: Optional[Any] = None,
@@ -58,14 +48,15 @@ def resolve_engine(
         if isinstance(g_or_df, pd.DataFrame):
             return Engine.PANDAS
 
-        has_cudf_dependancy_, _, _ = lazy_cudf_import_has_dependancy()
-        if has_cudf_dependancy_:
-            import cudf
-            if isinstance(g_or_df, cudf.DataFrame):
-                return Engine.CUDF
-            raise ValueError(f'Expected cudf dataframe, got: {type(g_or_df)}')
+        if 'cudf.core.dataframe' in str(getmodule(g_or_df)):
+            has_cudf_dependancy_, _, _ = lazy_cudf_import()
+            if has_cudf_dependancy_:
+                import cudf
+                if isinstance(g_or_df, cudf.DataFrame):
+                    return Engine.CUDF
+                raise ValueError(f'Expected cudf dataframe, got: {type(g_or_df)}')
     
-    has_cudf_dependancy_, _, _ = lazy_cudf_import_has_dependancy()
+    has_cudf_dependancy_, _, _ = lazy_cudf_import()
     if has_cudf_dependancy_:
         return Engine.CUDF
     return Engine.PANDAS
diff --git a/graphistry/compute/cluster.py b/graphistry/compute/cluster.py
index 585b17acd8..2d742b422b 100644
--- a/graphistry/compute/cluster.py
+++ b/graphistry/compute/cluster.py
@@ -10,6 +10,7 @@
 from graphistry.constants import CUML, UMAP_LEARN, DBSCAN  # noqa type: ignore
 from graphistry.features import ModelDict
 from graphistry.feature_utils import get_matrix_by_column_parts
+from graphistry.utils.lazy_import import lazy_cudf_import, lazy_dbscan_import
 
 logger = logging.getLogger("compute.cluster")
 
@@ -22,37 +23,6 @@
 DBSCANEngine = Literal[DBSCANEngineConcrete, "auto"]
 
 
-def lazy_dbscan_import_has_dependency():
-    has_min_dependency = True
-    DBSCAN = None
-    try:
-        from sklearn.cluster import DBSCAN
-    except ImportError:
-        has_min_dependency = False
-        logger.info("Please install sklearn for CPU DBSCAN")
-
-    has_cuml_dependency = True
-    cuDBSCAN = None
-    try:
-        from cuml import DBSCAN as cuDBSCAN
-    except ImportError:
-        has_cuml_dependency = False
-        logger.info("Please install cuml for GPU DBSCAN")
-
-    return has_min_dependency, DBSCAN, has_cuml_dependency, cuDBSCAN
-
-def lazy_cudf_import_has_dependancy():
-    try:
-        import warnings
-
-        warnings.filterwarnings("ignore")
-        import cudf  # type: ignore
-
-        return True, "ok", cudf
-    except ModuleNotFoundError as e:
-        return False, e, None
-
-
 def resolve_cpu_gpu_engine(
     engine: DBSCANEngine,
 ) -> DBSCANEngineConcrete:  # noqa
@@ -64,7 +34,7 @@ def resolve_cpu_gpu_engine(
             _,
             has_cuml_dependency,
             _,
-        ) = lazy_dbscan_import_has_dependency()
+        ) = lazy_dbscan_import()
         if has_cuml_dependency:
             return "cuml"
         if has_min_dependency:
@@ -90,7 +60,7 @@ def safe_cudf(X, y):
                 new_kwargs[key] = value
         return new_kwargs['X'], new_kwargs['y']
 
-    has_cudf_dependancy_, _, cudf = lazy_cudf_import_has_dependancy()
+    has_cudf_dependancy_, _, cudf = lazy_cudf_import()
     if has_cudf_dependancy_:
         # print('DBSCAN CUML Matrices')
         return safe_cudf(X, y)
@@ -209,7 +179,7 @@ def _cluster_dbscan(
     ):
         """DBSCAN clustering on cpu or gpu infered by .engine flag
         """
-        _, DBSCAN, _, cuDBSCAN = lazy_dbscan_import_has_dependency()
+        _, DBSCAN, _, cuDBSCAN = lazy_dbscan_import()
 
         if engine_dbscan in [CUML]:
             print('`g.transform_dbscan(..)` not supported for engine=cuml, will return `g.transform_umap(..)` instead')
diff --git a/graphistry/dgl_utils.py b/graphistry/dgl_utils.py
index 56b5670f33..b7225c0fdc 100644
--- a/graphistry/dgl_utils.py
+++ b/graphistry/dgl_utils.py
@@ -5,6 +5,10 @@
 import numpy as np
 import pandas as pd
 
+from graphistry.utils.lazy_import import (
+    lazy_dgl_import,
+    lazy_torch_import_has_dependency
+)
 from . import constants as config
 from .feature_utils import (
     FeatureEngine,
@@ -34,26 +38,6 @@
     MIXIN_BASE = object
 
 
-def lazy_dgl_import_has_dependency():
-    try:
-        import warnings
-        warnings.filterwarnings('ignore')
-        import dgl  # noqa: F811
-        return True, 'ok', dgl
-    except ModuleNotFoundError as e:
-        return False, e, None
-
-
-def lazy_torch_import_has_dependency():
-    try:
-        import warnings
-        warnings.filterwarnings('ignore')
-        import torch  # noqa: F811
-        return True, 'ok', torch
-    except ModuleNotFoundError as e:
-        return False, e, None
-
-
 logger = setup_logger(name=__name__)
 
 
@@ -181,7 +165,7 @@ def pandas_to_dgl_graph(
         sp_mat: sparse scipy matrix
         ordered_nodes_dict: dict ordered from most common src and dst nodes
     """
-    _, _, dgl = lazy_dgl_import_has_dependency()  # noqa: F811
+    _, _, dgl = lazy_dgl_import()  # noqa: F811
     sp_mat, ordered_nodes_dict = pandas_to_sparse_adjacency(df, src, dst, weight_col)
     g = dgl.from_scipy(sp_mat, device=device)  # there are other ways too
     logger.info(f"Graph Type: {type(g)}") 
@@ -225,7 +209,7 @@ def dgl_lazy_init(self, train_split: float = 0.8, device: str = "cpu"):
         """
 
         if not self.dgl_initialized:
-            lazy_dgl_import_has_dependency()
+            lazy_dgl_import()
             lazy_torch_import_has_dependency()
             self.train_split = train_split
             self.device = device
diff --git a/graphistry/embed_utils.py b/graphistry/embed_utils.py
index 81fc45fe8d..1803f9d1ca 100644
--- a/graphistry/embed_utils.py
+++ b/graphistry/embed_utils.py
@@ -3,24 +3,11 @@
 import pandas as pd
 from typing import Optional, Union, Callable, List, TYPE_CHECKING, Any, Tuple
 
+from graphistry.utils.lazy_import import lazy_embed_import
 from .PlotterBase import Plottable
 from .compute.ComputeMixin import ComputeMixin
 
 
-def lazy_embed_import_dep():
-    try:
-        import torch
-        import torch.nn as nn
-        import dgl
-        from dgl.dataloading import GraphDataLoader
-        import torch.nn.functional as F
-        from .networks import HeteroEmbed
-        from tqdm import trange
-        return True, torch, nn, dgl, GraphDataLoader, HeteroEmbed, F, trange
-
-    except:
-        return False, None, None, None, None, None, None, None
-
 def check_cudf():
     try:
         import cudf
@@ -30,7 +17,7 @@ def check_cudf():
         
 
 if TYPE_CHECKING:
-    _, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
+    _, torch, _, _, _, _, _, _ = lazy_embed_import()
     TT = torch.Tensor
     MIXIN_BASE = ComputeMixin
 else:
@@ -147,7 +134,7 @@ def _preprocess_embedding_data(self, res, train_split:Union[float, int] = 0.8) -
         return res
 
     def _build_graph(self, res) -> Plottable:
-        _, _, _, dgl, _, _, _, _ = lazy_embed_import_dep()
+        _, _, _, dgl, _, _, _, _ = lazy_embed_import()
         s, r, t = res._triplets.T
 
         if res._train_idx is not None:
@@ -169,7 +156,7 @@ def _build_graph(self, res) -> Plottable:
 
 
     def _init_model(self, res, batch_size:int, sample_size:int, num_steps:int, device):
-        _, _, _, _, GraphDataLoader, HeteroEmbed, _, _ = lazy_embed_import_dep()
+        _, _, _, _, GraphDataLoader, HeteroEmbed, _, _ = lazy_embed_import()
         g_iter = SubgraphIterator(res._kg_dgl, sample_size, num_steps)
         g_dataloader = GraphDataLoader(
             g_iter, batch_size=batch_size, collate_fn=lambda x: x[0]
@@ -188,7 +175,7 @@ def _init_model(self, res, batch_size:int, sample_size:int, num_steps:int, devic
         return model, g_dataloader
 
     def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_size:int, num_steps:int, device) -> Plottable:
-        _, torch, nn, _, _, _, _, trange = lazy_embed_import_dep()
+        _, torch, nn, _, _, _, _, trange = lazy_embed_import()
         log('Training embedding')
         model, g_dataloader = res._init_model(res, batch_size, sample_size, num_steps, device)
         if hasattr(res, "_embed_model") and not res._build_new_embedding_model:
@@ -232,7 +219,7 @@ def _train_embedding(self, res, epochs:int, batch_size:int, lr:float, sample_siz
 
     @property
     def _gcn_node_embeddings(self):
-        _, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
+        _, torch, _, _, _, _, _, _ = lazy_embed_import()
         g_dgl = self._kg_dgl.to(self._device)
         em = self._embed_model(g_dgl).detach()
         torch.cuda.empty_cache()
@@ -540,7 +527,7 @@ def fetch_triplets_for_inference(x_r):
         
 
     def _score(self, triplets: Union[np.ndarray, TT]) -> TT:  # type: ignore
-        _, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
+        _, torch, _, _, _, _, _, _ = lazy_embed_import()
         emb = self._kg_embeddings.clone().detach()
         if not isinstance(triplets, torch.Tensor):
             triplets = torch.tensor(triplets)
@@ -571,7 +558,7 @@ def __len__(self) -> int:
         return self.num_steps
 
     def __getitem__(self, i:int):
-        _, torch, nn, dgl, GraphDataLoader, _, F, _ = lazy_embed_import_dep()
+        _, torch, nn, dgl, GraphDataLoader, _, F, _ = lazy_embed_import()
         eids = torch.from_numpy(np.random.choice(self.eids, self.sample_size))
 
         src, dst = self.g.find_edges(eids)
@@ -593,7 +580,7 @@ def __getitem__(self, i:int):
 
     @staticmethod
     def _sample_neg(triplets:np.ndarray, num_nodes:int) -> Tuple[TT, TT]:  # type: ignore
-        _, torch, _, _, _, _, _, _ = lazy_embed_import_dep()
+        _, torch, _, _, _, _, _, _ = lazy_embed_import()
         triplets = torch.tensor(triplets)
         h, r, t = triplets.T
         h_o_t = torch.randint(high=2, size=h.size())
diff --git a/graphistry/feature_utils.py b/graphistry/feature_utils.py
index 84d9c8a817..c96daa9768 100644
--- a/graphistry/feature_utils.py
+++ b/graphistry/feature_utils.py
@@ -20,6 +20,13 @@
 
 from graphistry.compute.ComputeMixin import ComputeMixin
 from graphistry.config import config as graphistry_config
+from graphistry.utils.lazy_import import (
+    lazy_sentence_transformers_import,
+    lazy_import_has_min_dependancy,
+    lazy_dirty_cat_import,
+    assert_imported_text,
+    assert_imported
+)
 from . import constants as config
 from .PlotterBase import WeakValueDictionary, Plottable
 from .util import setup_logger, check_set_memoize
@@ -67,56 +74,6 @@
     TransformerMixin = Any
 
 
-#@check_set_memoize
-def lazy_import_has_dependancy_text():
-    import warnings
-    warnings.filterwarnings("ignore")
-    try:
-        from sentence_transformers import SentenceTransformer
-        return True, 'ok', SentenceTransformer
-    except ModuleNotFoundError as e:
-        return False, e, None
-
-def lazy_import_has_min_dependancy():
-    import warnings
-    warnings.filterwarnings("ignore")
-    try:
-        import scipy.sparse  # noqa
-        from scipy import __version__ as scipy_version
-        from sklearn import __version__ as sklearn_version
-        logger.debug(f"SCIPY VERSION: {scipy_version}")
-        logger.debug(f"sklearn VERSION: {sklearn_version}")
-        return True, 'ok'
-    except ModuleNotFoundError as e:
-        return False, e
-
-def lazy_import_has_dirty_cat():
-    import warnings
-    warnings.filterwarnings("ignore")
-    try:
-        import dirty_cat 
-        return True, 'ok', dirty_cat
-    except ModuleNotFoundError as e:
-        return False, e, None
-
-def assert_imported_text():
-    has_dependancy_text_, import_text_exn, _ = lazy_import_has_dependancy_text()
-    if not has_dependancy_text_:
-        logger.error(  # noqa
-            "AI Package sentence_transformers not found,"
-            "trying running `pip install graphistry[ai]`"
-        )
-        raise import_text_exn
-
-
-def assert_imported():
-    has_min_dependancy_, import_min_exn = lazy_import_has_min_dependancy()
-    if not has_min_dependancy_:
-        logger.error(  # noqa
-                     "AI Packages not found, trying running"  # noqa
-                     "`pip install graphistry[ai]`"  # noqa
-        )
-        raise import_min_exn
 
 
 # ############################################################################
@@ -154,7 +111,7 @@ def resolve_feature_engine(
         return feature_engine  # type: ignore
 
     if feature_engine == "auto":
-        has_dependancy_text_, _, _ = lazy_import_has_dependancy_text()
+        has_dependancy_text_, _, _ = lazy_sentence_transformers_import()
         if has_dependancy_text_:
             return "torch"
         has_min_dependancy_, _ = lazy_import_has_min_dependancy()
@@ -708,7 +665,7 @@ def encode_textual(
     max_df: float = 0.2,
     min_df: int = 3,
 ) -> Tuple[pd.DataFrame, List, Any]:
-    _, _, SentenceTransformer = lazy_import_has_dependancy_text()
+    _, _, SentenceTransformer = lazy_sentence_transformers_import()
 
     t = time()
     text_cols = get_textual_columns(
@@ -886,7 +843,7 @@ def process_dirty_dataframes(
     :return: Encoded data matrix and target (if not None),
             the data encoder, and the label encoder.
     """
-    has_dirty_cat, _, dirty_cat = lazy_import_has_dirty_cat()
+    has_dirty_cat, _, dirty_cat = lazy_dirty_cat_import()
     if has_dirty_cat:
         from dirty_cat import SuperVectorizer, GapEncoder, SimilarityEncoder
     from sklearn.preprocessing import FunctionTransformer
@@ -1126,7 +1083,7 @@ def process_nodes_dataframes(
     text_cols: List[str] = []
     text_model: Any = None
     text_enc = pd.DataFrame([])
-    has_deps_text, import_text_exn, _ = lazy_import_has_dependancy_text()
+    has_deps_text, import_text_exn, _ = lazy_sentence_transformers_import()
     if has_deps_text and (feature_engine in ["torch", "auto"]):
         text_enc, text_cols, text_model = encode_textual(
             df,
@@ -1497,7 +1454,7 @@ def transform_text(
     text_cols: Union[List, str],
 ) -> pd.DataFrame:
     from sklearn.pipeline import Pipeline
-    _, _, SentenceTransformer = lazy_import_has_dependancy_text()
+    _, _, SentenceTransformer = lazy_sentence_transformers_import()
 
     logger.debug("Transforming text using:")
     if isinstance(text_model, Pipeline):
diff --git a/graphistry/networks.py b/graphistry/networks.py
index 8d59263efd..b5fb04b438 100644
--- a/graphistry/networks.py
+++ b/graphistry/networks.py
@@ -1,25 +1,15 @@
 from typing import TYPE_CHECKING, Any
+from graphistry.utils.lazy_import import lazy_networks_import
+
 from . import constants as config
 
 import logging
 
 logger = logging.getLogger(__name__)
 
-def lazy_import_networks():  # noqa
-    try:
-        import dgl
-        import dgl.nn as dglnn
-        import dgl.function as fn
-        import torch
-        import torch.nn as nn
-        import torch.nn.functional as F
-        Module = nn.Module
-        return nn, dgl, dglnn, fn, torch, F, Module
-    except:
-        return Any, Any, Any, Any, Any, Any, object
 
 if TYPE_CHECKING:  # noqa
-    _, dgl, dglnn, fn, torch, F, Module = lazy_import_networks()
+    _, dgl, dglnn, fn, torch, F, Module = lazy_networks_import()
 else:
     nn = Any 
     dgl = Any
@@ -40,12 +30,12 @@ def lazy_import_networks():  # noqa
 class GCN(Module):  # type: ignore
     def __init__(self, in_feats, h_feats, num_classes):
         super(GCN, self).__init__()
-        _, _, dglnn, _, _, _, _ = lazy_import_networks()
+        _, _, dglnn, _, _, _, _ = lazy_networks_import()
         self.conv1 = dglnn.GraphConv(in_feats, h_feats)
         self.conv2 = dglnn.GraphConv(h_feats, num_classes)
 
     def forward(self, g, in_feat):
-        _, _, _, _, _, F, _ = lazy_import_networks()
+        _, _, _, _, _, F, _ = lazy_networks_import()
         h = self.conv1(g, in_feat)
         h = F.relu(h)
         h = self.conv2(g, h)
@@ -65,7 +55,7 @@ class RGCN(Module):  # type: ignore
 
     def __init__(self, in_feats, hid_feats, out_feats, rel_names):
         super().__init__()
-        _, _, dglnn, _, _, _, _ = lazy_import_networks()        
+        _, _, dglnn, _, _, _, _ = lazy_networks_import()        
         
         self.conv1 = dglnn.HeteroGraphConv(
             {rel: dglnn.GraphConv(in_feats, hid_feats) for rel in rel_names},
@@ -88,7 +78,7 @@ def forward(self, graph, inputs):
 class HeteroClassifier(Module):  # type: ignore
     def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
         super().__init__()
-        nn, _, _, _, _, _, _ = lazy_import_networks()
+        nn, _, _, _, _, _, _ = lazy_networks_import()
         self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
         self.classify = nn.Linear(hidden_dim, n_classes)
 
@@ -111,7 +101,7 @@ class MLPPredictor(Module):  # type: ignore
 
     def __init__(self, in_features, out_classes):
         super().__init__()
-        nn, _, _, _, _, _, _ = lazy_import_networks()
+        nn, _, _, _, _, _, _ = lazy_networks_import()
         self.W = nn.Linear(in_features * 2, out_classes)
 
     def apply_edges(self, edges):
@@ -133,7 +123,7 @@ def forward(self, graph, h):
 class SAGE(Module):  # type: ignore
     def __init__(self, in_feats, hid_feats, out_feats):
         super().__init__()
-        _, _, dglnn, _, _, _, _ = lazy_import_networks()
+        _, _, dglnn, _, _, _, _ = lazy_networks_import()
         self.conv1 = dglnn.SAGEConv(
             in_feats=in_feats, out_feats=hid_feats, aggregator_type="mean"
         )
@@ -152,7 +142,7 @@ def forward(self, graph, inputs):
 
 class DotProductPredictor(Module):  # type: ignore
     def forward(self, graph, h):
-        _, _, _, fn, _, _, _ = lazy_import_networks()
+        _, _, _, fn, _, _, _ = lazy_networks_import()
 
         # h contains the node representations computed from the GNN defined
         # in the node classification section (Section 5.1).
@@ -176,7 +166,7 @@ def forward(self, g, x):
 
 class LinkPredModelMultiOutput(Module):  # type: ignore
     def __init__(self, in_features, hidden_features, out_features, out_classes):
-        _, _, dglnn, _, _, _, _ = lazy_import_networks()
+        _, _, dglnn, _, _, _, _ = lazy_networks_import()
         super().__init__()
         self.sage = SAGE(in_features, hidden_features, out_features)
         self.pred = MLPPredictor(out_features, out_classes)
@@ -197,7 +187,7 @@ class RGCNEmbed(Module):  # type: ignore
     def __init__(self, d, num_nodes, num_rels, hidden=None, device='cpu'):
         super().__init__()
 
-        nn, _, dglnn, _, torch, _, _ = lazy_import_networks()
+        nn, _, dglnn, _, torch, _, _ = lazy_networks_import()
         self.node_ids = torch.tensor(range(num_nodes))
         
         self.node_ids = self.node_ids.to(device)
@@ -212,7 +202,7 @@ def __init__(self, d, num_nodes, num_rels, hidden=None, device='cpu'):
 
     def forward(self, g, node_features=None):
 
-        _, dgl, _, _, torch, F, _ = lazy_import_networks()
+        _, dgl, _, _, torch, F, _ = lazy_networks_import()
 
         x = self.emb(self.node_ids)
         x = self.rgc1(g, x, g.edata[dgl.ETYPE], g.edata['norm'])
@@ -236,7 +226,7 @@ def __init__(
         reg = 0.01
     ):
         super().__init__()
-        nn, _, _, _, torch, _, _ = lazy_import_networks()
+        nn, _, _, _, torch, _, _ = lazy_networks_import()
         self.reg = reg
         self.proto = proto
         self.node_features = node_features
@@ -267,7 +257,7 @@ def score(self, node_embedding, triplets):
         return score
 
     def loss(self, node_embedding, triplets, labels):
-        _, _, _, _, torch, F, _ = lazy_import_networks()
+        _, _, _, _, torch, F, _ = lazy_networks_import()
         score = self.score(node_embedding, triplets)
 
         # binary crossentropy loss
@@ -290,7 +280,7 @@ def loss(self, node_embedding, triplets, labels):
 #ACC = metrics.accuracy_score
    
 def train_link_pred(model, G, epochs=100, use_cross_entropy_loss = False):
-    _, _, _, _, torch, F, _ = lazy_import_networks()
+    _, _, _, _, torch, F, _ = lazy_networks_import()
     # take the node features out
     node_features = G.ndata["feature"].float()
     # we are predicting edges
diff --git a/graphistry/tests/test_compute_cluster.py b/graphistry/tests/test_compute_cluster.py
index 0afe003fe7..b9bcc77844 100644
--- a/graphistry/tests/test_compute_cluster.py
+++ b/graphistry/tests/test_compute_cluster.py
@@ -4,11 +4,13 @@
 import graphistry
 from graphistry.constants import DBSCAN
 from graphistry.util import ModelDict
-from graphistry.compute.cluster import lazy_dbscan_import_has_dependency
-from graphistry.umap_utils import lazy_umap_import_has_dependancy
+from graphistry.utils.lazy_import import (
+    lazy_dbscan_import,
+    lazy_umap_import
+)
 
-has_dbscan, _, has_gpu_dbscan, _ = lazy_dbscan_import_has_dependency()
-has_umap, _, _ = lazy_umap_import_has_dependancy()
+has_dbscan, _, has_gpu_dbscan, _ = lazy_dbscan_import()
+has_umap, _, _ = lazy_umap_import()
 
 
 ndf = edf = pd.DataFrame({'src': [1, 2, 1, 4], 'dst': [4, 5, 6, 1], 'label': ['a', 'b', 'b', 'c']})
diff --git a/graphistry/tests/test_dgl_utils.py b/graphistry/tests/test_dgl_utils.py
index 760045eee6..cf8f24bd91 100644
--- a/graphistry/tests/test_dgl_utils.py
+++ b/graphistry/tests/test_dgl_utils.py
@@ -3,10 +3,10 @@
 import graphistry
 import pandas as pd
 from graphistry.util import setup_logger
+from graphistry.utils.lazy_import import lazy_dgl_import
 
-from graphistry.dgl_utils import lazy_dgl_import_has_dependency
 
-has_dgl, _, dgl = lazy_dgl_import_has_dependency()
+has_dgl, _, dgl = lazy_dgl_import()
 
 if has_dgl:
     import torch
diff --git a/graphistry/tests/test_embed_utils.py b/graphistry/tests/test_embed_utils.py
index 307bdd0266..5bb92a49aa 100644
--- a/graphistry/tests/test_embed_utils.py
+++ b/graphistry/tests/test_embed_utils.py
@@ -5,12 +5,13 @@
 import graphistry
 import numpy as np
 
-from graphistry.embed_utils import lazy_embed_import_dep, check_cudf
+from graphistry.embed_utils import check_cudf
+from graphistry.utils.lazy_import import lazy_embed_import
 
 import logging
 logger = logging.getLogger(__name__)
 
-dep_flag, _, _, _, _, _, _, _ = lazy_embed_import_dep()
+dep_flag, _, _, _, _, _, _, _ = lazy_embed_import()
 has_cudf, cudf = check_cudf()
 
 # enable tests if has cudf and env didn't explicitly disable
diff --git a/graphistry/tests/test_feature_utils.py b/graphistry/tests/test_feature_utils.py
index fa4333737a..5bf2ae5d58 100644
--- a/graphistry/tests/test_feature_utils.py
+++ b/graphistry/tests/test_feature_utils.py
@@ -14,18 +14,20 @@
     process_dirty_dataframes,
     process_nodes_dataframes,
     resolve_feature_engine,
-    lazy_import_has_min_dependancy,
-    lazy_import_has_dependancy_text,
     FastEncoder
 )
 
 from graphistry.features import topic_model, ngrams_model
 from graphistry.constants import SCALERS
+from graphistry.utils.lazy_import import (
+    lazy_import_has_min_dependancy,
+    lazy_sentence_transformers_import
+)
 
 np.random.seed(137)
 
 has_min_dependancy, _ = lazy_import_has_min_dependancy()
-has_min_dependancy_text, _, _ = lazy_import_has_dependancy_text()
+has_min_dependancy_text, _, _ = lazy_sentence_transformers_import()
 
 logger = logging.getLogger(__name__)
 warnings.filterwarnings("ignore")
diff --git a/graphistry/tests/test_text_utils.py b/graphistry/tests/test_text_utils.py
index 649d74f89f..5b930f553f 100644
--- a/graphistry/tests/test_text_utils.py
+++ b/graphistry/tests/test_text_utils.py
@@ -10,13 +10,14 @@
 from graphistry.tests.test_feature_utils import (
     ndf_reddit,
     edge_df,
-    lazy_import_has_min_dependancy,
 )
-
-from graphistry.umap_utils import lazy_umap_import_has_dependancy
+from graphistry.utils.lazy_import import (
+    lazy_umap_import,
+    lazy_import_has_min_dependancy
+)
 
 has_dependancy, _ = lazy_import_has_min_dependancy()
-has_umap, _, _ = lazy_umap_import_has_dependancy()
+has_umap, _, _ = lazy_umap_import()
 
 logger = logging.getLogger(__name__)
 
diff --git a/graphistry/tests/test_umap_utils.py b/graphistry/tests/test_umap_utils.py
index 86b37e304e..0624a2a11d 100644
--- a/graphistry/tests/test_umap_utils.py
+++ b/graphistry/tests/test_umap_utils.py
@@ -25,17 +25,17 @@
     lazy_import_has_min_dependancy,
     check_allclose_fit_transform_on_same_data,
 )
-from graphistry.umap_utils import (
-    lazy_umap_import_has_dependancy,
-    lazy_cuml_import_has_dependancy,
-    lazy_cudf_import_has_dependancy,
+from graphistry.utils.lazy_import import (
+    lazy_cudf_import,
+    lazy_cuml_import,
+    lazy_umap_import,
 )
 from graphistry.util import cache_coercion_helper
 
 has_dependancy, _ = lazy_import_has_min_dependancy()
-has_cuml, _, _ = lazy_cuml_import_has_dependancy()
-has_umap, _, _ = lazy_umap_import_has_dependancy()
-has_cudf, _, cudf = lazy_cudf_import_has_dependancy()
+has_cuml, _, _ = lazy_cuml_import()
+has_umap, _, _ = lazy_umap_import()
+has_cudf, _, cudf = lazy_cudf_import()
 
 # print('has_dependancy', has_dependancy)
 # print('has_cuml', has_cuml)
diff --git a/graphistry/umap_utils.py b/graphistry/umap_utils.py
index d2561739df..8292710f7a 100644
--- a/graphistry/umap_utils.py
+++ b/graphistry/umap_utils.py
@@ -5,6 +5,11 @@
 
 import pandas as pd
 
+from graphistry.utils.lazy_import import (
+    lazy_cudf_import,
+    lazy_umap_import,
+    lazy_cuml_import,
+)
 from . import constants as config
 from .constants import CUML, UMAP_LEARN
 from .feature_utils import (FeatureMixin, Literal, XSymbolic, YSymbolic,
@@ -26,51 +31,15 @@
 ###############################################################################
 
 
-def lazy_umap_import_has_dependancy():
-    try:
-        import warnings
-
-        warnings.filterwarnings("ignore")
-        import umap  # noqa
-
-        return True, "ok", umap
-    except ModuleNotFoundError as e:
-        return False, e, None
-
-
-def lazy_cuml_import_has_dependancy():
-    try:
-        import warnings
-
-        warnings.filterwarnings("ignore")
-        with warnings.catch_warnings():
-            warnings.filterwarnings("ignore")
-            import cuml  # type: ignore
-
-        return True, "ok", cuml
-    except ModuleNotFoundError as e:
-        return False, e, None
-
-def lazy_cudf_import_has_dependancy():
-    try:
-        import warnings
-
-        warnings.filterwarnings("ignore")
-        import cudf  # type: ignore
-
-        return True, "ok", cudf
-    except ModuleNotFoundError as e:
-        return False, e, None
-
 def assert_imported():
-    has_dependancy_, import_exn, _ = lazy_umap_import_has_dependancy()
+    has_dependancy_, import_exn, _ = lazy_umap_import()
     if not has_dependancy_:
         logger.error("UMAP not found, trying running " "`pip install graphistry[ai]`")
         raise import_exn
 
 
 def assert_imported_cuml():
-    has_cuml_dependancy_, import_cuml_exn, _ = lazy_cuml_import_has_dependancy()
+    has_cuml_dependancy_, import_cuml_exn, _ = lazy_cuml_import()
     if not has_cuml_dependancy_:
         logger.warning("cuML not found, trying running " "`pip install cuml`")
         raise import_cuml_exn
@@ -99,10 +68,10 @@ def resolve_umap_engine(
     if engine in [CUML, UMAP_LEARN]:
         return engine  # type: ignore
     if engine in ["auto"]:
-        has_cuml_dependancy_, _, _ = lazy_cuml_import_has_dependancy()
+        has_cuml_dependancy_, _, _ = lazy_cuml_import()
         if has_cuml_dependancy_:
             return 'cuml'
-        has_umap_dependancy_, _, _ = lazy_umap_import_has_dependancy()
+        has_umap_dependancy_, _, _ = lazy_umap_import()
         if has_umap_dependancy_:
             return 'umap_learn'
 
@@ -134,7 +103,7 @@ def safe_cudf(X, y):
                 new_kwargs[key] = value
         return new_kwargs['X'], new_kwargs['y']
 
-    has_cudf_dependancy_, _, cudf = lazy_cudf_import_has_dependancy()
+    has_cudf_dependancy_, _, cudf = lazy_cudf_import()
     if has_cudf_dependancy_:
         return safe_cudf(X, y)
     else:
@@ -203,9 +172,9 @@ def umap_lazy_init(
         engine_resolved = resolve_umap_engine(engine)
         # FIXME remove as set_new_kwargs will always replace?
         if engine_resolved == UMAP_LEARN:
-            _, _, umap_engine = lazy_umap_import_has_dependancy()
+            _, _, umap_engine = lazy_umap_import()
         elif engine_resolved == CUML:
-            _, _, umap_engine = lazy_cuml_import_has_dependancy()
+            _, _, umap_engine = lazy_cuml_import()
         else:
             raise ValueError(
                 "No umap engine, ensure 'auto', 'umap_learn', or 'cuml', and the library is installed"
@@ -554,7 +523,7 @@ def umap(
         logger.debug("umap_kwargs: %s", umap_kwargs)
 
         # temporary until we have full cudf support in feature_utils.py
-        has_cudf, _, cudf = lazy_cudf_import_has_dependancy()
+        has_cudf, _, cudf = lazy_cudf_import()
 
         if has_cudf:
             flag_nodes_cudf = isinstance(self._nodes, cudf.DataFrame)
diff --git a/graphistry/utils/lazy_import.py b/graphistry/utils/lazy_import.py
new file mode 100644
index 0000000000..f7de35bdbf
--- /dev/null
+++ b/graphistry/utils/lazy_import.py
@@ -0,0 +1,179 @@
+from typing import Any
+import warnings
+from graphistry .util import setup_logger, check_set_memoize
+logger = setup_logger(__name__)
+
+
+#TODO use new importer when it lands (this is copied from umap_utils)
+def lazy_cudf_import():
+    """Try to import cudf; return (ok, "ok" or the exception, module or None)."""
+    try:
+        warnings.filterwarnings("ignore")  # NOTE: installs a process-wide ignore filter
+        import cudf  # type: ignore
+        return True, "ok", cudf
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning("Unexpected exn during lazy import", exc_info=e)
+        return False, e, None
+
+def lazy_cuml_import():
+    """Try to import cuml; return (ok, "ok" or the exception, module or None)."""
+    try:
+        warnings.filterwarnings("ignore")  # NOTE: process-wide; outlives the scoped block below
+        with warnings.catch_warnings():
+            warnings.filterwarnings("ignore")
+            import cuml  # type: ignore
+        return True, "ok", cuml
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning("Unexpected exn during lazy import", exc_info=e)
+        return False, e, None
+
+def lazy_dbscan_import():
+    """Return (has_sklearn, sklearn DBSCAN or None, has_cuml, cuml DBSCAN or None)."""
+    has_min_dependency = True
+    DBSCAN = None
+    try:
+        from sklearn.cluster import DBSCAN
+    except ModuleNotFoundError:
+        has_min_dependency = False
+        logger.info("Please install sklearn for CPU DBSCAN")
+    except Exception as e:
+        logger.warning("Unexpected exn during lazy import", exc_info=e)
+        return False, None, False, None  # bails out here without probing cuml
+
+    has_cuml_dependency = True
+    cuDBSCAN = None
+    try:
+        from cuml import DBSCAN as cuDBSCAN
+    except ModuleNotFoundError:
+        has_cuml_dependency = False
+        logger.info("Please install cuml for GPU DBSCAN")
+    except Exception as e:  # cuml found but its import raised: treat as unavailable
+        has_cuml_dependency = False
+        logger.warning("Unexpected exn during lazy import", exc_info=e)
+    return has_min_dependency, DBSCAN, has_cuml_dependency, cuDBSCAN
+
+def lazy_dgl_import():  # -> (ok, 'ok' or the exception, dgl module or None)
+    try:
+        warnings.filterwarnings('ignore')  # NOTE: installs a process-wide ignore filter
+        import dgl  # noqa: F811
+        return True, 'ok', dgl
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning("Unexpected exn during lazy import", exc_info=e)
+        return False, e, None
+
+def lazy_dirty_cat_import():  # -> (ok, 'ok' or the exception, dirty_cat module or None)
+    warnings.filterwarnings("ignore")  # NOTE: installs a process-wide ignore filter
+    try:
+        import dirty_cat
+        return True, 'ok', dirty_cat
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning('Unexpected exn during lazy import', exc_info=e)
+        return False, e, None
+
+def lazy_embed_import():  # -> (ok, torch, nn, dgl, GraphDataLoader, HeteroEmbed, F, trange)
+    try:  # all-or-nothing: any failing import yields (False, None x 7)
+        import torch
+        import torch.nn as nn
+        import dgl
+        from dgl.dataloading import GraphDataLoader
+        import torch.nn.functional as F
+        from graphistry.networks import HeteroEmbed
+        from tqdm import trange
+        return True, torch, nn, dgl, GraphDataLoader, HeteroEmbed, F, trange
+    except ModuleNotFoundError:  # a dependency is not installed: expected, stay quiet
+        return False, None, None, None, None, None, None, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning('Unexpected exn during lazy import', exc_info=e)
+        return False, None, None, None, None, None, None, None
+
+def lazy_networks_import():  # noqa
+    try:  # returns (nn, dgl, dglnn, fn, torch, F, Module); seven Nones if any import fails
+        import dgl
+        import dgl.nn as dglnn
+        import dgl.function as fn
+        import torch
+        import torch.nn as nn
+        import torch.nn.functional as F
+        Module = nn.Module
+        return nn, dgl, dglnn, fn, torch, F, Module
+    except ModuleNotFoundError:  # a dependency is not installed: expected, stay quiet
+        return None, None, None, None, None, None, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning('Unexpected exn during lazy import', exc_info=e)
+        return None, None, None, None, None, None, None
+
+def lazy_torch_import_has_dependency():  # -> (ok, 'ok' or the exception, torch module or None)
+    try:
+        warnings.filterwarnings('ignore')  # NOTE: installs a process-wide ignore filter
+        import torch  # noqa: F811
+        return True, 'ok', torch
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning('Unexpected exn during lazy import', exc_info=e)
+        return False, e, None
+
+def lazy_umap_import():
+    """Try to import umap; return (ok, "ok" or the exception, module or None)."""
+    try:
+        warnings.filterwarnings("ignore")  # NOTE: installs a process-wide ignore filter
+        import umap  # noqa
+        return True, "ok", umap
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning('Unexpected exn during lazy import', exc_info=e)
+        return False, e, None
+
+#@check_set_memoize
+def lazy_sentence_transformers_import():  # -> (ok, 'ok' or the exception, SentenceTransformer class or None)
+    warnings.filterwarnings("ignore")  # NOTE: installs a process-wide ignore filter
+    try:
+        from sentence_transformers import SentenceTransformer
+        return True, 'ok', SentenceTransformer
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e, None
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning('Unexpected exn during lazy import', exc_info=e)
+        return False, e, None
+
+def lazy_import_has_min_dependancy():  # -> (ok, 'ok' or the exception); always a 2-tuple
+    warnings.filterwarnings("ignore")  # NOTE: installs a process-wide ignore filter
+    try:
+        import scipy.sparse  # noqa
+        from scipy import __version__ as scipy_version
+        from sklearn import __version__ as sklearn_version
+        logger.debug("SCIPY VERSION: %s", scipy_version)
+        logger.debug("sklearn VERSION: %s", sklearn_version)
+        return True, 'ok'
+    except ModuleNotFoundError as e:  # not installed: expected, stay quiet
+        return False, e
+    except Exception as e:  # installed but failing to import: log, report unavailable
+        logger.warning('Unexpected exn during lazy import', exc_info=e)
+        return False, e  # same arity as the other branches so 2-value unpacking works
+
+def assert_imported_text():  # re-raises the captured import exception when the import did not succeed
+    has_dependancy_text_, import_text_exn, _ = lazy_sentence_transformers_import()
+    if not has_dependancy_text_:
+        logger.error(  # noqa
+            "AI Package sentence_transformers not found, "
+            "try running `pip install graphistry[ai]`"
+        )
+        raise import_text_exn
+
+def assert_imported():  # re-raises the captured import exception when the scipy/sklearn check fails
+    has_min_dependancy_, import_min_exn = lazy_import_has_min_dependancy()
+    if not has_min_dependancy_:
+        logger.error(  # noqa
+                     "AI Packages not found, try running "  # noqa
+                     "`pip install graphistry[ai]`"  # noqa
+        )
+        raise import_min_exn

From c8e674fbd314cf956ae4288026043170e6723c21 Mon Sep 17 00:00:00 2001
From: Leo Meyerovich <leo@graphistry.com>
Date: Fri, 19 Jul 2024 01:14:33 -0400
Subject: [PATCH 2/4] fix(ring layouts): handle filtered DFs

---
 CHANGELOG.md                          | 2 ++
 graphistry/layout/ring/categorical.py | 2 ++
 graphistry/layout/ring/continuous.py  | 2 ++
 graphistry/layout/ring/time.py        | 2 ++
 4 files changed, 8 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 03e9505bb3..1d29b058af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
 
 * Graceful CPU fallbacks: When lazy GPU dependency imports throw `ImportError`, commonly seen due to broken CUDA environments or having CUDA libraries but no GPU, warn and fall back to CPU.
 
+* Ring layouts now support filtered inputs, giving expected positions
+
 ### Changed
 
 * Centralize lazy imports into `graphistry.utils.lazy_import`
diff --git a/graphistry/layout/ring/categorical.py b/graphistry/layout/ring/categorical.py
index 61cbedd2ed..f491f797e2 100644
--- a/graphistry/layout/ring/categorical.py
+++ b/graphistry/layout/ring/categorical.py
@@ -170,6 +170,8 @@ def ring_categorical(
     if g._nodes is None:
         raise ValueError('Missing nodes')
 
+    g = g.nodes(g._nodes.reset_index(drop=True))
+
     engine_concrete = resolve_engine(engine, g._nodes)
 
     if ring_col is None or not isinstance(ring_col, str):
diff --git a/graphistry/layout/ring/continuous.py b/graphistry/layout/ring/continuous.py
index d0fe0d2074..e9bc5fafb0 100644
--- a/graphistry/layout/ring/continuous.py
+++ b/graphistry/layout/ring/continuous.py
@@ -182,6 +182,8 @@ def ring_continuous(
     if g._nodes is None:
         raise ValueError('Missing nodes')
 
+    g = g.nodes(g._nodes.reset_index(drop=True))
+
     engine_concrete = resolve_engine(engine, g._nodes)
 
     if ring_col is None:
diff --git a/graphistry/layout/ring/time.py b/graphistry/layout/ring/time.py
index f358b06cd3..cf10197d08 100644
--- a/graphistry/layout/ring/time.py
+++ b/graphistry/layout/ring/time.py
@@ -318,6 +318,8 @@ def time_ring(
 
     if g._nodes is None:
         raise ValueError('Expected nodes table')
+    
+    g = g.nodes(g._nodes.reset_index(drop=True))
 
     engine_concrete = resolve_engine(engine, g._nodes)
 

From 856839d7fa6b21bec4924fe8d09b422bc8c7f9b4 Mon Sep 17 00:00:00 2001
From: Leo Meyerovich <leo@graphistry.com>
Date: Fri, 19 Jul 2024 01:35:07 -0400
Subject: [PATCH 3/4] fix(encode_axis): functional updates

---
 CHANGELOG.md              |  2 ++
 graphistry/PlotterBase.py | 10 ++++------
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d29b058af..d74f16957a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,8 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm
 
 * Ring layouts now support filtered inputs, giving expected positions
 
+* `encode_axis()` updates are now functional, not inplace
+
 ### Changed
 
 * Centralize lazy imports into `graphistry.utils.lazy_import`
diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py
index 8e2ed75b3b..34c947e5cc 100644
--- a/graphistry/PlotterBase.py
+++ b/graphistry/PlotterBase.py
@@ -373,14 +373,12 @@ def encode_axis(self, rows: List[Dict] = []) -> Plottable:
 
         """
 
-        complex_encodings = self._complex_encodings or {}
+        complex_encodings = {**self._complex_encodings} if self._complex_encodings else {}
         if 'node_encodings' not in complex_encodings:
             complex_encodings['node_encodings'] = {}
-        node_encodings = complex_encodings['node_encodings']
-        if 'current' not in node_encodings:
-            node_encodings['current'] = {}
-        if 'default' not in node_encodings:
-            node_encodings['default'] = {}
+        node_encodings = {**complex_encodings['node_encodings']}
+        node_encodings['current'] = {**node_encodings['current']} if 'current' in node_encodings else {}
+        node_encodings['default'] = {**node_encodings['default']} if 'default' in node_encodings else {}
         node_encodings['default']["pointAxisEncoding"] = {
             "graphType": "point",
             "encodingType": "axis",

From ac11e56a033b7440d630efa261abf3c35e362430 Mon Sep 17 00:00:00 2001
From: Leo Meyerovich <leo@graphistry.com>
Date: Fri, 19 Jul 2024 01:40:17 -0400
Subject: [PATCH 4/4] fix(encode_axis): functional updates

---
 graphistry/PlotterBase.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/graphistry/PlotterBase.py b/graphistry/PlotterBase.py
index 34c947e5cc..ac39d37517 100644
--- a/graphistry/PlotterBase.py
+++ b/graphistry/PlotterBase.py
@@ -374,9 +374,8 @@ def encode_axis(self, rows: List[Dict] = []) -> Plottable:
         """
 
         complex_encodings = {**self._complex_encodings} if self._complex_encodings else {}
-        if 'node_encodings' not in complex_encodings:
-            complex_encodings['node_encodings'] = {}
-        node_encodings = {**complex_encodings['node_encodings']}
+        node_encodings = {**complex_encodings['node_encodings']} if 'node_encodings' in complex_encodings else {}
+        complex_encodings['node_encodings'] = node_encodings  # store the copy, not the dict shared with self
         node_encodings['current'] = {**node_encodings['current']} if 'current' in node_encodings else {}
         node_encodings['default'] = {**node_encodings['default']} if 'default' in node_encodings else {}
         node_encodings['default']["pointAxisEncoding"] = {