Skip to content

Commit

Permalink
[Model Monitoring] Fix credentials exposure in YAML (mlrun#7203)
Browse files Browse the repository at this point in the history
  • Loading branch information
alxtkr77 authored Feb 4, 2025
1 parent de71511 commit e22ea2f
Show file tree
Hide file tree
Showing 12 changed files with 54 additions and 91 deletions.
29 changes: 0 additions & 29 deletions mlrun/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,35 +1366,6 @@ def is_ce_mode(self) -> bool:
ver in mlrun.mlconf.ce.mode for ver in ["lite", "full"]
)

def get_s3_storage_options(self) -> dict[str, typing.Any]:
    """
    Build the fsspec storage options used for S3 paths. The model monitoring stream graph calls this
    to obtain the storage options for its S3 parquet target path.

    Every value is looked up with ``mlrun.get_secret_or_env``.

    :return: A dictionary that always holds ``anon``, ``key`` and ``secret``; ``client_kwargs`` (wrapping the
             endpoint URL) and ``profile`` are added only when the corresponding lookup returns a truthy value.
    """
    access_key_id = mlrun.get_secret_or_env("AWS_ACCESS_KEY_ID")
    secret_access_key = mlrun.get_secret_or_env("AWS_SECRET_ACCESS_KEY")
    non_anonymous_flag = mlrun.get_secret_or_env("S3_NON_ANONYMOUS")
    aws_profile = mlrun.get_secret_or_env("AWS_PROFILE")
    s3_endpoint = mlrun.get_secret_or_env("S3_ENDPOINT_URL")

    # Anonymous access is the fallback: it is switched off either by the explicit flag or by
    # having both halves of the key pair.
    has_key_pair = access_key_id and secret_access_key
    options = {
        "anon": not (non_anonymous_flag or has_key_pair),
        "key": access_key_id,
        "secret": secret_access_key,
    }

    if s3_endpoint:
        options["client_kwargs"] = {"endpoint_url": s3_endpoint}

    if aws_profile:
        options["profile"] = aws_profile

    return options

def is_explicit_ack_enabled(self) -> bool:
return self.httpdb.nuclio.explicit_ack == "enabled" and (
not self.nuclio_version
Expand Down
23 changes: 20 additions & 3 deletions mlrun/datastore/storeytargets.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,21 @@ def get_url_and_storage_options(path, external_storage_options=None):


class TDEngineStoreyTarget(storey.TDEngineTarget):
    """
    ``storey.TDEngineTarget`` wrapper whose ``url`` may be a ``ds://`` datastore profile reference.

    A ``ds://`` URL is resolved through ``datastore_profile_read`` and replaced by the profile's DSN
    before the base class is initialised; any other URL is forwarded unchanged.
    """

    def __init__(self, *args, url: str, **kwargs):
        """
        :param args:   Positional arguments forwarded to ``storey.TDEngineTarget``.
        :param url:    Keyword-only connection URL, or a ``ds://`` reference to a datastore profile.
        :param kwargs: Keyword arguments forwarded to ``storey.TDEngineTarget``.
        :raises ValueError: If a ``ds://`` URL resolves to a profile that is not a
                            ``TDEngineDatastoreProfile``.
        """
        if url.startswith("ds://"):
            datastore_profile = (
                mlrun.datastore.datastore_profile.datastore_profile_read(url)
            )
            if not isinstance(
                datastore_profile,
                mlrun.datastore.datastore_profile.TDEngineDatastoreProfile,
            ):
                # The trailing space keeps the two adjacent literals from running together.
                raise ValueError(
                    f"Unexpected datastore profile type:{datastore_profile.type}. "
                    "Only TDEngineDatastoreProfile is supported"
                )
            # Swap the profile reference for the actual connection string.
            url = datastore_profile.dsn()
        super().__init__(*args, url=url, **kwargs)


class StoreyTargetUtils:
Expand All @@ -69,7 +81,12 @@ def process_args_and_kwargs(args, kwargs):

class ParquetStoreyTarget(storey.ParquetTarget):
    """
    ``storey.ParquetTarget`` wrapper that can swap the v3io access key in its storage options.

    The optional ``alternative_v3io_access_key`` keyword names a secret / environment variable. It is
    removed from ``kwargs`` before anything else runs, so it is never forwarded to the base class.
    """

    def __init__(self, *args, **kwargs):
        secret_name = kwargs.pop("alternative_v3io_access_key", None)
        args, kwargs = StoreyTargetUtils.process_args_and_kwargs(args, kwargs)

        options = kwargs.get("storage_options", {})
        current_key = options.get("v3io_access_key") if options else None

        # Only override a key that is already present, and only when the named secret resolves
        # to a truthy value. The dict is updated in place, so the base class receives the new key.
        if current_key and secret_name:
            replacement = mlrun.get_secret_or_env(secret_name)
            if replacement:
                options["v3io_access_key"] = replacement

        super().__init__(*args, **kwargs)


Expand Down
7 changes: 4 additions & 3 deletions mlrun/model_monitoring/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,9 +250,10 @@ def __init__(self) -> None:

self.model_monitoring_access_key = self._get_model_monitoring_access_key()
self.v3io_access_key = mlrun.mlconf.get_v3io_access_key()
self.storage_options = None
if mlrun.mlconf.artifact_path.startswith("s3://"):
self.storage_options = mlrun.mlconf.get_s3_storage_options()
store, _, _ = mlrun.store_manager.get_or_create_store(
mlrun.mlconf.artifact_path
)
self.storage_options = store.get_storage_options()

@staticmethod
def _get_model_monitoring_access_key() -> Optional[str]:
Expand Down
14 changes: 9 additions & 5 deletions mlrun/model_monitoring/db/tsdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import mlrun.datastore.datastore_profile
import mlrun.errors
import mlrun.model_monitoring.helpers
from mlrun.datastore.datastore_profile import DatastoreProfile

from .base import TSDBConnector

Expand All @@ -29,10 +30,13 @@ class ObjectTSDBFactory(enum.Enum):
v3io_tsdb = "v3io-tsdb"
tdengine = "tdengine"

def to_tsdb_connector(self, project: str, **kwargs) -> TSDBConnector:
def to_tsdb_connector(
self, project: str, profile: DatastoreProfile, **kwargs
) -> TSDBConnector:
"""
Return a TSDBConnector object based on the provided enum value.
:param project: The name of the project.
:param profile: Datastore profile containing DSN and credentials for TSDB connection
:return: `TSDBConnector` object.
"""

Expand All @@ -51,7 +55,7 @@ def to_tsdb_connector(self, project: str, **kwargs) -> TSDBConnector:

from .tdengine.tdengine_connector import TDEngineConnector

return TDEngineConnector(project=project, **kwargs)
return TDEngineConnector(project=project, profile=profile, **kwargs)

@classmethod
def _missing_(cls, value: typing.Any):
Expand Down Expand Up @@ -87,12 +91,10 @@ def get_tsdb_connector(
kwargs = {}
if isinstance(profile, mlrun.datastore.datastore_profile.DatastoreProfileV3io):
tsdb_connector_type = mlrun.common.schemas.model_monitoring.TSDBTarget.V3IO_TSDB
kwargs["v3io_access_key"] = profile.v3io_access_key
elif isinstance(
profile, mlrun.datastore.datastore_profile.TDEngineDatastoreProfile
):
tsdb_connector_type = mlrun.common.schemas.model_monitoring.TSDBTarget.TDEngine
kwargs["connection_string"] = profile.dsn()
else:
extra_message = (
""
Expand All @@ -109,4 +111,6 @@ def get_tsdb_connector(
tsdb_connector_factory = ObjectTSDBFactory(tsdb_connector_type)

# Convert into TSDB connector object
return tsdb_connector_factory.to_tsdb_connector(project=project, **kwargs)
return tsdb_connector_factory.to_tsdb_connector(
project=project, profile=profile, **kwargs
)
18 changes: 8 additions & 10 deletions mlrun/model_monitoring/db/tsdb/tdengine/tdengine_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import mlrun.common.schemas.model_monitoring as mm_schemas
import mlrun.model_monitoring.db.tsdb.tdengine.schemas as tdengine_schemas
import mlrun.model_monitoring.db.tsdb.tdengine.stream_graph_steps
from mlrun.datastore.datastore_profile import DatastoreProfile
from mlrun.model_monitoring.db import TSDBConnector
from mlrun.model_monitoring.helpers import get_invocations_fqn
from mlrun.utils import logger
Expand All @@ -40,16 +41,13 @@ class TDEngineConnector(TSDBConnector):
def __init__(
self,
project: str,
profile: DatastoreProfile,
database: typing.Optional[str] = None,
**kwargs,
):
super().__init__(project=project)
if "connection_string" not in kwargs:
raise mlrun.errors.MLRunInvalidArgumentError(
"connection_string is a required parameter for TDEngineConnector."
)

self._tdengine_connection_string = kwargs.get("connection_string")
self._tdengine_connection_profile = profile
self.database = (
database
or f"{tdengine_schemas._MODEL_MONITORING_DATABASE}_{mlrun.mlconf.system_id}"
Expand All @@ -70,7 +68,7 @@ def connection(self) -> TDEngineConnection:
def _create_connection(self) -> TDEngineConnection:
"""Establish a connection to the TSDB server."""
logger.debug("Creating a new connection to TDEngine", project=self.project)
conn = TDEngineConnection(self._tdengine_connection_string)
conn = TDEngineConnection(self._tdengine_connection_profile.dsn())
conn.run(
statements=f"CREATE DATABASE IF NOT EXISTS {self.database}",
timeout=self._timeout,
Expand Down Expand Up @@ -200,10 +198,10 @@ def apply_process_before_tsdb():

def apply_tdengine_target(name, after):
graph.add_step(
"storey.TDEngineTarget",
"mlrun.datastore.storeytargets.TDEngineStoreyTarget",
name=name,
after=after,
url=self._tdengine_connection_string,
url=f"ds://{self._tdengine_connection_profile.name}",
supertable=self.tables[
mm_schemas.TDEngineSuperTables.PREDICTIONS
].super_table,
Expand Down Expand Up @@ -242,10 +240,10 @@ def handle_model_error(
after="ForwardError",
)
graph.add_step(
"storey.TDEngineTarget",
"mlrun.datastore.storeytargets.TDEngineStoreyTarget",
name="tsdb_error",
after="error_extractor",
url=self._tdengine_connection_string,
url=f"ds://{self._tdengine_connection_profile.name}",
supertable=self.tables[mm_schemas.TDEngineSuperTables.ERRORS].super_table,
table_col=mm_schemas.EventFieldType.TABLE_COLUMN,
time_col=mm_schemas.EventFieldType.TIME,
Expand Down
15 changes: 0 additions & 15 deletions mlrun/model_monitoring/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,21 +246,6 @@ def get_monitoring_drift_measures_data(project: str, endpoint_id: str) -> "DataI
)


def get_tsdb_connection_string(
    secret_provider: Optional[Callable[[str], str]] = None,
) -> str:
    """Look up the TSDB connection string.

    The value is resolved by ``mlrun.get_secret_or_env`` under the
    ``ProjectSecretKeys.TSDB_CONNECTION`` key.

    :param secret_provider: An optional secret provider, forwarded as-is to the lookup.
    :return:                The value the lookup yields for the TSDB connection key.
    """
    lookup = mlrun.get_secret_or_env
    tsdb_connection_key = mm_constants.ProjectSecretKeys.TSDB_CONNECTION
    return lookup(key=tsdb_connection_key, secret_provider=secret_provider)


def _get_profile(
project: str,
secret_provider: Optional[Callable[[str], str]],
Expand Down
10 changes: 2 additions & 8 deletions mlrun/model_monitoring/stream_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,11 @@ def __init__(
parquet_batching_max_events=self.parquet_batching_max_events,
)

self.storage_options = None
self.tsdb_configurations = {}
if not mlrun.mlconf.is_ce_mode():
self._initialize_v3io_configurations(
model_monitoring_access_key=model_monitoring_access_key
)
elif self.parquet_path.startswith("s3://"):
self.storage_options = mlrun.mlconf.get_s3_storage_options()

def _initialize_v3io_configurations(
self,
Expand All @@ -95,9 +92,6 @@ def _initialize_v3io_configurations(
or os.environ.get(ProjectSecretKeys.ACCESS_KEY)
or self.v3io_access_key
)
self.storage_options = dict(
v3io_access_key=self.model_monitoring_access_key, v3io_api=self.v3io_api
)

# TSDB path and configurations
tsdb_path = mlrun.mlconf.get_model_monitoring_file_target_path(
Expand Down Expand Up @@ -248,12 +242,12 @@ def apply_process_before_parquet():
# Write the Parquet target file, partitioned by key (endpoint_id) and time.
def apply_parquet_target():
graph.add_step(
"storey.ParquetTarget",
"mlrun.datastore.storeytargets.ParquetStoreyTarget",
alternative_v3io_access_key=mlrun.common.schemas.model_monitoring.ProjectSecretKeys.ACCESS_KEY,
name="ParquetTarget",
after="ProcessBeforeParquet",
graph_shape="cylinder",
path=self.parquet_path,
storage_options=self.storage_options,
max_events=self.parquet_batching_max_events,
flush_after_seconds=self.parquet_batching_timeout_secs,
attributes={"infer_columns_from_data": True},
Expand Down
8 changes: 1 addition & 7 deletions mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -3671,7 +3671,6 @@ def export(self, filepath=None, include_files: Optional[str] = None):

def set_model_monitoring_credentials(
self,
access_key: Optional[str] = None,
stream_path: Optional[str] = None, # Deprecated
tsdb_connection: Optional[str] = None, # Deprecated
replace_creds: bool = False,
Expand All @@ -3684,11 +3683,6 @@ def set_model_monitoring_credentials(
infrastructure functions. Important to note that you have to set the credentials before deploying any
model monitoring or serving function.
:param access_key: Model monitoring access key for managing user permissions.
* None - will be set from the system configuration.
* v3io - for v3io endpoint store, pass `v3io` and the system will generate the
exact path.
:param stream_path: (Deprecated) This argument is deprecated. Use ``stream_profile_name`` instead.
Path to the model monitoring stream. By default, None. Options:
Expand Down Expand Up @@ -3791,7 +3785,7 @@ def set_model_monitoring_credentials(
db.set_model_monitoring_credentials(
project=self.name,
credentials={
"access_key": access_key,
"access_key": None,
"tsdb_profile_name": tsdb_profile_name,
"stream_profile_name": stream_profile_name,
},
Expand Down
1 change: 1 addition & 0 deletions server/py/services/api/crud/model_monitoring/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def deploy_model_monitoring_stream_processing(
db_session=self.db_session, project=self.project
)
)

fn = self._initial_model_monitoring_stream_processing_function(
stream_image=stream_image, parquet_target=parquet_target
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,6 @@ def test_credentials(
)

secrets = monitoring_deployment._get_monitoring_mandatory_project_secrets()
assert (
secrets[
mlrun.common.schemas.model_monitoring.ProjectSecretKeys.TSDB_CONNECTION
]
== "v3io"
)

monitoring_deployment.set_credentials(
tsdb_connection="v3io",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
ModelEndpointMonitoringMetric,
ModelEndpointMonitoringMetricType,
)
from mlrun.datastore.datastore_profile import TDEngineDatastoreProfile
from mlrun.model_monitoring.db.tsdb.tdengine import TDEngineConnector

project = "test-tdengine-connector"
Expand All @@ -41,11 +42,12 @@ def is_tdengine_defined() -> bool:

@pytest.fixture
def connector() -> Iterator[TDEngineConnector]:
connection = taosws.connect()
connection = taosws.connect(connection_string)
drop_database(connection)
conn = TDEngineConnector(
project, connection_string=connection_string, database=database
profile = TDEngineDatastoreProfile.from_dsn(
profile_name="mm-profile", dsn=connection_string
)
conn = TDEngineConnector(project, profile=profile, database=database)
try:
yield conn
finally:
Expand Down
6 changes: 4 additions & 2 deletions tests/model_monitoring/test_tdengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dateutil import parser

import mlrun.common.schemas
from mlrun.datastore.datastore_profile import TDEngineDatastoreProfile
from mlrun.model_monitoring.db.tsdb.tdengine import TDEngineConnector
from mlrun.model_monitoring.db.tsdb.tdengine.schemas import (
_MODEL_MONITORING_DATABASE,
Expand Down Expand Up @@ -493,9 +494,10 @@ def test_get_records_with_interval_query(
class TestTDEngineConnector:
@pytest.fixture
def connector(self):
return TDEngineConnector(
project="test-project", connection_string="taosws://localhost:6041"
profile = TDEngineDatastoreProfile(
name="mm-profile", host="localhost", port=6041, user="root"
)
return TDEngineConnector(project="test-project", profile=profile)

def test_get_last_request(self, connector):
df = pd.DataFrame(
Expand Down

0 comments on commit e22ea2f

Please sign in to comment.