Skip to content

Commit

Permalink
Adding telemetry reporting to client side libraries.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 349468416
  • Loading branch information
SinaChavoshi authored and Tensorflow Cloud maintainers committed Dec 29, 2020
1 parent 94f3785 commit 8644935
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
"""Tests for optimizer_client."""

from googleapiclient import errors
from googleapiclient import http as googleapiclient_http
import httplib2
import mock
import tensorflow as tf
from tensorflow_cloud import version
from tensorflow_cloud.tuner import optimizer_client
from tensorflow_cloud.utils import google_api_client


class OptimizerClientTest(tf.test.TestCase):
Expand Down Expand Up @@ -511,18 +508,6 @@ def test_delete_study_with_404_raises_ValueError(self):
.format(self._trial_parent)):
self._client.delete_study()

def test_cloud_tuner_request_header(self):
    """Checks that TFCloudHttpRequest sets the TF Cloud user-agent header."""
    # A canned transport that would answer a single request with 200 / "{}".
    mock_http = googleapiclient_http.HttpMockSequence(
        [({"status": "200"}, "{}")])
    expected_headers = {"user-agent": "tf-cloud/" + version.__version__}

    request = google_api_client.TFCloudHttpRequest(
        mock_http, object(), "fake_uri")

    # The request must still be a regular googleapiclient HttpRequest.
    self.assertIsInstance(request, googleapiclient_http.HttpRequest)
    self.assertEqual(expected_headers, request.headers)


# Run the test cases through tf.test.main() when this file is executed
# directly as a script.
if __name__ == "__main__":
tf.test.main()
147 changes: 144 additions & 3 deletions src/python/tensorflow_cloud/utils/google_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,42 @@
# limitations under the License.
"""Utilities for Google API client."""

import enum
import json
import os
import sys
import time
from typing import Text
from typing import Dict, Text
from .. import version
from absl import logging
from googleapiclient import discovery
from googleapiclient import errors
from googleapiclient import http as googleapiclient_http

# User-agent product token of the form "tf-cloud/<version>".
_USER_AGENT_FOR_TF_CLOUD_TRACKING = "tf-cloud/" + version.__version__
# Same product token; TFCloudHttpRequest uses it as the prefix of the
# user-agent header, followed by a parenthesized "key:value;" comment section.
_TF_CLOUD_USER_AGENT_HEADER = "tf-cloud/" + version.__version__
# NOTE(review): presumably the delay between job status polls; its use is not
# visible in this part of the file -- confirm against the polling code.
_POLL_INTERVAL_IN_SECONDS = 30
# Per-user JSON file that stores the telemetry consent state. It is read and
# written by get_or_set_consent_status() and written by
# optout_metrics_reporting().
_LOCAL_CONFIG_PATH = os.path.expanduser(
"~/.config/tf_cloud/tf_cloud_config.json")
# Notice that get_or_set_consent_status() logs and prints before it records
# the current version in the config file.
_PRIVACY_NOTICE = """
This application reports technical and operational details of your usage of
Cloud Services in accordance with Google privacy policy, for more information
please refer to https://cloud.google.com/terms/cloud-privacy-notice. If you wish
to opt-out, you may do so by running
tensorflow_cloud.utils.google_api_client.optout_metrics_reporting().
"""

# Config key: a truthy value means the user opted out of telemetry reporting.
_TELEMETRY_REJECTED_CONFIG = "telemetry_rejected"
# Config key: the tensorflow_cloud version recorded when the privacy notice
# was last shown.
_TELEMETRY_VERSION_CONFIG = "notification_version"


class ClientEnvironment(enum.Enum):
    """Client environments distinguished for telemetry reporting."""

    # Fallback reported when no other environment is detected.
    UNKNOWN = 0

    # Environments recognised by get_client_environment_name().
    KAGGLE_NOTEBOOK = 1
    HOSTED_NOTEBOOK = 2
    DLVM = 3
    DL_CONTAINER = 4
    COLAB = 5


class TFCloudHttpRequest(googleapiclient_http.HttpRequest):
    """HttpRequest subclass that sets a TF Cloud user-agent header.

    This is used to track the usage of the TF Cloud.
    """

    # Class property for passing additional telemetry fields to constructor.
    # The dict is shared by every request, so it must never be mutated in
    # place; replace it through set_telemetry_dict() instead.
    _telemetry_dict = {}

    def __init__(self, *args, **kwargs):
        """Construct a HttpRequest.

        The "user-agent" entry of the request headers is set to
        "tf-cloud/<version> (<comments>)", where <comments> is a sequence of
        "key:value;" pairs. The pairs are the class-level telemetry fields
        plus "client_environment" when the user has not opted out of
        telemetry, and empty otherwise.

        Args:
            *args: Positional arguments to pass to the base class constructor.
            **kwargs: Keyword arguments to pass to the base class constructor.
                A "headers" dict is created if absent, and any existing
                "user-agent" entry in it is overwritten.
        """
        headers = kwargs.setdefault("headers", {})

        comments = {}
        if get_or_set_consent_status():
            # Work on a copy: adding per-request fields to the class-level
            # dict would leak them into every later request and into
            # get_telemetry_dict().
            comments = dict(self._telemetry_dict)

            # Add the local environment to the user agent header comment
            # field.
            comments["client_environment"] = get_client_environment_name()

        # Construct the comment string from the comments dict.
        comment_text = "".join(
            f"{key}:{value};" for key, value in comments.items())
        headers["user-agent"] = (
            f"{_TF_CLOUD_USER_AGENT_HEADER} ({comment_text})")

        super(TFCloudHttpRequest, self).__init__(*args, **kwargs)

    # @classmethod @property chain is only supported in python 3.9+, see
    # https://docs.python.org/3/howto/descriptor.html#id27. Using class
    # getter and setter instead.
    @classmethod
    def get_telemetry_dict(cls) -> Dict[Text, Text]:
        """Returns a copy of the class-level telemetry fields."""
        return cls._telemetry_dict.copy()

    @classmethod
    def set_telemetry_dict(cls, telemetry_dict: Dict[Text, Text]):
        """Replaces the class-level telemetry fields with a copy of the input.

        Args:
            telemetry_dict: Fields to include in the user-agent comment of
                requests constructed afterwards.
        """
        cls._telemetry_dict = telemetry_dict.copy()


# TODO(b/176097105) Use get_client_environment_name in tfc.run and cloud_fit
def get_client_environment_name() -> Text:
    """Determines which kind of environment tensorflow_cloud is running in.

    The checks run in a fixed order and the first match wins:
    KAGGLE_NOTEBOOK, COLAB, the DL_PATH based environments (HOSTED_NOTEBOOK or
    DLVM), DL_CONTAINER and finally UNKNOWN.

    Returns:
        The name of the matching ClientEnvironment member.
    """
    if os.getenv("KAGGLE_CONTAINER_NAME"):
        logging.info("Kaggle client environment detected.")
        detected = ClientEnvironment.KAGGLE_NOTEBOOK
    elif "google.colab" in sys.modules:
        logging.info("Detected running in COLAB environment.")
        detected = ClientEnvironment.COLAB
    elif os.getenv("DL_PATH"):
        # TODO(b/171720710) Update logic based on the resolution of the issue.
        if os.getenv("USER") == "jupyter":
            logging.info("Detected running in HOSTED_NOTEBOOK environment.")
            detected = ClientEnvironment.HOSTED_NOTEBOOK
        else:
            # TODO(b/175815580) Update logic based on the resolution of the
            # issue.
            logging.info("Detected running in DLVM environment.")
            detected = ClientEnvironment.DLVM
    elif "google" in sys.modules:
        # TODO(b/175815580) Update logic based on the resolution of the issue.
        logging.info("Detected running in DL_CONTAINER environment.")
        detected = ClientEnvironment.DL_CONTAINER
    else:
        logging.info("Detected running in UNKNOWN environment.")
        detected = ClientEnvironment.UNKNOWN

    return detected.name


def get_or_set_consent_status() -> bool:
    """Gets or sets the user consent status for telemetry collection.

    The consent state lives in the JSON file at _LOCAL_CONFIG_PATH. If no
    file exists, or it was written for a different tensorflow_cloud version,
    the privacy notice is logged and printed and the file is (re)written with
    the current version.

    Returns:
        False if the user has rejected client side telemetry collection, or
        if the config file exists but cannot be read or parsed (consent is
        then unknown, so nothing is reported and the file is left untouched
        to avoid discarding a possible opt-out). True otherwise.

    Raises:
        OSError: If the config directory or file cannot be written.
    """
    # Verify if user consent exists and if it is valid for current version of
    # tensorflow_cloud.
    try:
        with open(_LOCAL_CONFIG_PATH) as config_json:
            config_data = json.load(config_json)
    except FileNotFoundError:
        # No consent record yet; fall through to notify the user.
        config_data = {}
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        # subclasses, so this covers unreadable as well as corrupt files.
        logging.warning(
            "Could not read telemetry configuration at %s; telemetry "
            "reporting is disabled until the file is fixed or removed.",
            _LOCAL_CONFIG_PATH)
        return False

    if not isinstance(config_data, dict):
        # Valid JSON but not an object (e.g. `null` or a list).
        logging.warning(
            "Unexpected telemetry configuration format at %s; telemetry "
            "reporting is disabled until the file is fixed or removed.",
            _LOCAL_CONFIG_PATH)
        return False

    if config_data.get(_TELEMETRY_REJECTED_CONFIG):
        logging.info("User has opt-out of telemetry reporting.")
        return False
    if config_data.get(_TELEMETRY_VERSION_CONFIG) == version.__version__:
        return True

    # Either user has not been notified of telemetry collection or a different
    # version of the tensorflow_cloud has been installed since the last
    # notification. Notify the user and update the configuration.
    logging.info(_PRIVACY_NOTICE)
    print(_PRIVACY_NOTICE)

    config_data = {_TELEMETRY_VERSION_CONFIG: version.__version__}

    # Create the config path if it does not already exist
    os.makedirs(os.path.dirname(_LOCAL_CONFIG_PATH), exist_ok=True)

    with open(_LOCAL_CONFIG_PATH, "w") as config_json:
        json.dump(config_data, config_json)
    return True


def optout_metrics_reporting():
    """Set configuration to opt-out of client side metric reporting.

    Overwrites the config file at _LOCAL_CONFIG_PATH with a single entry that
    marks telemetry as rejected, creating the parent directory if needed.

    Raises:
        OSError: If the config directory or file cannot be written.
    """
    # Use the same key constant that get_or_set_consent_status() reads, so
    # the writer and the reader cannot drift apart.
    config_data = {_TELEMETRY_REJECTED_CONFIG: True}

    # Create the config path if it does not already exist
    os.makedirs(os.path.dirname(_LOCAL_CONFIG_PATH), exist_ok=True)

    with open(_LOCAL_CONFIG_PATH, "w") as config_json:
        json.dump(config_data, config_json)

    logging.info("Client side metrics reporting has been disabled.")


# TODO(b/170436896) change wait_for_api_.. to wait_for_aip_..
def wait_for_api_training_job_completion(job_id: Text, project_id: Text)->bool:
Expand Down

0 comments on commit 8644935

Please sign in to comment.