From 5e1c7f4f702574c30c0cfc3b1a27f10f5e196115 Mon Sep 17 00:00:00 2001 From: Wenjun Si Date: Mon, 26 Aug 2024 17:07:08 +0800 Subject: [PATCH] Releases v0.11.6.5 (#248) --- odps/_version.py | 2 +- odps/accounts.py | 212 +++++++++++++++++-------- odps/conftest.py | 8 - odps/core.py | 27 ++-- odps/df/tests/test_delay.py | 2 +- odps/errors.py | 12 +- odps/models/instance.py | 4 +- odps/models/table.py | 3 +- odps/models/tests/test_schemas.py | 40 +++-- odps/models/tests/test_tasks.py | 3 +- odps/rest.py | 7 + odps/tests/core.py | 19 ++- odps/tests/test_accounts.py | 209 +++++++++++++++++++----- odps/tunnel/tabletunnel.py | 20 ++- odps/tunnel/tests/test_tabletunnel.py | 3 +- odps/tunnel/tests/test_volumetunnel.py | 42 +++-- 16 files changed, 436 insertions(+), 177 deletions(-) diff --git a/odps/_version.py b/odps/_version.py index 7a59602..8c5a80c 100644 --- a/odps/_version.py +++ b/odps/_version.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -version_info = (0, 11, 6, 4) +version_info = (0, 11, 6, 5) _num_index = max(idx if isinstance(v, int) else 0 for idx, v in enumerate(version_info)) __version__ = '.'.join(map(str, version_info[:_num_index + 1])) + \ diff --git a/odps/accounts.py b/odps/accounts.py index ffc8a78..af9d5b3 100644 --- a/odps/accounts.py +++ b/odps/accounts.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""A couple of authentication types in ODPS. -""" +"""A couple of authentication types in ODPS.""" import base64 -import hmac +import calendar import hashlib +import hmac +import json import logging import os import threading @@ -27,13 +28,12 @@ import requests -from .compat import six, cgi, urlparse, unquote, parse_qsl from . import options, utils - +from .compat import cgi, parse_qsl, six, unquote, urlparse logger = logging.getLogger(__name__) -DEFAULT_BEARER_TOKEN_HOURS = 5 +DEFAULT_TEMP_ACCOUNT_HOURS = 5 class BaseAccount(object): @@ -138,20 +138,6 @@ def sign_request(self, req, endpoint, region_name=None): logger.debug('headers after signing: %r', req.headers) -class StsAccount(AliyunAccount): - """ - Account of sts - """ - def __init__(self, access_id, secret_access_key, sts_token): - super(StsAccount, self).__init__(access_id, secret_access_key) - self.sts_token = sts_token - - def sign_request(self, req, endpoint, region_name=None): - super(StsAccount, self).sign_request(req, endpoint, region_name=region_name) - if self.sts_token: - req.headers['authorization-sts-token'] = self.sts_token - - class AppAccount(BaseAccount): """ Account for applications. @@ -352,77 +338,152 @@ def sign_request(self, req, endpoint, region_name=None): ) -class BearerTokenAccount(BaseAccount): +class TempAccountMixin(object): + def __init__(self, expired_hours=DEFAULT_TEMP_ACCOUNT_HOURS): + self._last_modified_time = datetime.now() + if expired_hours is not None: + self._expired_time = timedelta(hours=expired_hours) + else: + self._expired_time = None + self.reload() + + def _is_account_valid(self): + raise NotImplementedError + + def _reload_account(self): + raise NotImplementedError + + def reload(self, force=False): + t = datetime.now() + if ( + force + or not self._is_account_valid() + or ( + self._last_modified_time is not None + and self._expired_time is not None + and (t - self._last_modified_time) > self._expired_time + ) + ): + self._last_modified_time = self._reload_account() or datetime.now() + + +class StsAccount(TempAccountMixin, AliyunAccount): + """ + Account of sts + """ + def __init__( - self, token=None, expired_hours=DEFAULT_BEARER_TOKEN_HOURS, get_bearer_token_fun=None + self, + access_id, + secret_access_key, + sts_token, + expired_hours=DEFAULT_TEMP_ACCOUNT_HOURS, ): - self._get_bearer_token = get_bearer_token_fun or self.get_default_bearer_token - self._token = token or self._get_bearer_token() - self._reload_bearer_token_time() + self.sts_token = sts_token + AliyunAccount.__init__(self, access_id, secret_access_key) + TempAccountMixin.__init__(self, expired_hours=expired_hours) - self._expired_time = timedelta(hours=expired_hours) + @classmethod + def from_environments(cls): + expired_hours = int( + os.getenv("ODPS_STS_TOKEN_HOURS", str(DEFAULT_TEMP_ACCOUNT_HOURS)) + ) + if "ODPS_STS_ACCOUNT_FILE" in os.environ or "ODPS_STS_TOKEN" in os.environ: + if "ODPS_STS_ACCOUNT_FILE" not in os.environ: + expired_hours = None + return cls(None, None, None, expired_hours=expired_hours) + return None + + def sign_request(self, req, endpoint, region_name=None): + self.reload() + super(StsAccount, self).sign_request(req, endpoint, region_name=region_name) + if self.sts_token: + req.headers["authorization-sts-token"] = self.sts_token + + def _is_account_valid(self): + return self.sts_token is not None + + def _resolve_expiration(self, exp_data): + if exp_data is None or self._expired_time is None: + return None + try: + ts = calendar.timegm(time.strptime(exp_data, "%Y-%m-%dT%H:%M:%SZ")) + return ts - self._expired_time.total_seconds() + except: + return None + + def _reload_account(self): + ts = None + if "ODPS_STS_ACCOUNT_FILE" in os.environ: + token_file_name = os.getenv("ODPS_STS_ACCOUNT_FILE") + if token_file_name and os.path.exists(token_file_name): + with open(token_file_name, "r") as token_file: + token_json = json.load(token_file) + self.access_id = token_json["accessKeyId"] + self.secret_access_key = token_json["accessKeySecret"] + self.sts_token = token_json["securityToken"] + ts = self._resolve_expiration(token_json.get("expiration")) + elif "ODPS_STS_ACCESS_KEY_ID" in os.environ: + self.access_id = os.getenv("ODPS_STS_ACCESS_KEY_ID") + self.secret_access_key = os.getenv("ODPS_STS_ACCESS_KEY_SECRET") + self.sts_token = os.getenv("ODPS_STS_TOKEN") + + return datetime.fromtimestamp(ts) if ts is not None else None + + +class BearerTokenAccount(TempAccountMixin, BaseAccount): + def __init__( + self, token=None, expired_hours=DEFAULT_TEMP_ACCOUNT_HOURS, get_bearer_token_fun=None + ): + self.token = token + self._custom_bearer_token_func = get_bearer_token_fun + TempAccountMixin.__init__(self, expired_hours=expired_hours) @classmethod def from_environments(cls): - expired_hours = int(os.getenv('ODPS_BEARER_TOKEN_HOURS', str(DEFAULT_BEARER_TOKEN_HOURS))) + expired_hours = int(os.getenv('ODPS_BEARER_TOKEN_HOURS', str(DEFAULT_TEMP_ACCOUNT_HOURS))) kwargs = {"expired_hours": expired_hours} - if 'ODPS_BEARER_TOKEN' in os.environ: - return cls(os.environ['ODPS_BEARER_TOKEN'], **kwargs) - elif 'ODPS_BEARER_TOKEN_FILE' in os.environ: + if "ODPS_BEARER_TOKEN_FILE" in os.environ: return cls(**kwargs) + elif "ODPS_BEARER_TOKEN" in os.environ: + kwargs["expired_hours"] = None + return cls(os.environ["ODPS_BEARER_TOKEN"], **kwargs) return None - @staticmethod - def get_default_bearer_token(): + def _get_bearer_token(self): + if self._custom_bearer_token_func is not None: + return self._custom_bearer_token_func() + token_file_name = os.getenv("ODPS_BEARER_TOKEN_FILE") if token_file_name and os.path.exists(token_file_name): with open(token_file_name, "r") as token_file: return token_file.read().strip() + else: # pragma: no cover + from cupid.runtime import context, RuntimeContext - from cupid.runtime import context, RuntimeContext - - if not RuntimeContext.is_context_ready(): - return - cupid_context = context() - return cupid_context.get_bearer_token() - - def get_bearer_token_and_timestamp(self): - self._check_bearer_token() - return self._token, self._last_modified_time.timestamp() - - def _reload_bearer_token_time(self): - if "ODPS_BEARER_TOKEN_TIMESTAMP_FILE" in os.environ: - with open(os.getenv("ODPS_BEARER_TOKEN_TIMESTAMP_FILE"), "r") as ts_file: - self._last_modified_time = datetime.fromtimestamp(float(ts_file.read())) - else: - self._last_modified_time = datetime.now() - - def _check_bearer_token(self): - t = datetime.now() - if self._last_modified_time is None: - token = self._get_bearer_token() - if token is None: - return - if token != self._token: - self._token = token - self._reload_bearer_token_time() - elif (t - self._last_modified_time) > self._expired_time: - token = self._get_bearer_token() - if token is None: + if not RuntimeContext.is_context_ready(): return - self._token = token - self._reload_bearer_token_time() + cupid_context = context() + return cupid_context.get_bearer_token() - @property - def token(self): - return self._token + def _is_account_valid(self): + return self.token is not None + + def _reload_account(self): + token = self._get_bearer_token() + self.token = token + try: + resolved_token_parts = base64.b64decode(token).decode().split(",") + return datetime.fromtimestamp(int(resolved_token_parts[2])) + except: + return None def sign_request(self, req, endpoint, region_name=None): - self._check_bearer_token() + self.reload() url = req.url[len(endpoint):] url_components = urlparse(unquote(url), allow_fragments=False) self._build_canonical_str(url_components, req) - req.headers['x-odps-bearer-token'] = self._token + req.headers['x-odps-bearer-token'] = self.token logger.debug('headers after signing: %r', req.headers) @@ -432,7 +493,10 @@ def __init__(self, credential_provider): super(CredentialProviderAccount, self).__init__(None, None, None) def sign_request(self, req, endpoint, region_name=None): - credential = self.provider.get_credentials() + get_cred_method = getattr(self.provider, "get_credential", None) or getattr( + self.provider, "get_credentials" + ) + credential = get_cred_method() self.access_id = credential.get_access_key_id() self.secret_access_key = credential.get_access_key_secret() @@ -440,3 +504,11 @@ def sign_request(self, req, endpoint, region_name=None): return super(CredentialProviderAccount, self).sign_request( req, endpoint, region_name=region_name ) + + +def from_environments(): + for account_cls in (StsAccount, BearerTokenAccount): + account = account_cls.from_environments() + if account is not None: + break + return account diff --git a/odps/conftest.py b/odps/conftest.py index cc8dec9..c10962a 100644 --- a/odps/conftest.py +++ b/odps/conftest.py @@ -33,14 +33,6 @@ def odps_with_schema(): pytest.skip("ODPS project with schema not defined") -@pytest.fixture(scope="session") -def odps_with_schema_tenant(): - try: - return get_config().odps_with_schema_tenant - except AttributeError: - pytest.skip("ODPS project with schema configured on tenants not defined") - - @pytest.fixture(scope="session") def odps_with_tunnel_quota(): try: diff --git a/odps/core.py b/odps/core.py index b8f43d7..ba8116e 100644 --- a/odps/core.py +++ b/odps/core.py @@ -130,12 +130,14 @@ def _init( if account is None: if access_id is not None: self.account = self._build_account(access_id, secret_access_key) - elif accounts.BearerTokenAccount.from_environments(): - self.account = accounts.BearerTokenAccount.from_environments() elif options.account is not None: self.account = options.account else: - raise TypeError('`access_id` and `secret_access_key` should be provided.') + self.account = accounts.from_environments() + if self.account is None: + raise TypeError( + "`access_id` and `secret_access_key` should be provided." + ) else: self.account = account self.endpoint = ( @@ -2544,14 +2546,17 @@ def from_global(cls): @classmethod def from_environments(cls): try: - account = accounts.BearerTokenAccount.from_environments() - if not account: - raise KeyError('ODPS_BEARER_TOKEN') - project = os.getenv('ODPS_PROJECT_NAME') - endpoint = os.environ['ODPS_ENDPOINT'] - tunnel_endpoint = os.getenv('ODPS_TUNNEL_ENDPOINT') - return cls(None, None, account=account, project=project, - endpoint=endpoint, tunnel_endpoint=tunnel_endpoint) + project = os.getenv("ODPS_PROJECT_NAME") + endpoint = os.environ["ODPS_ENDPOINT"] + tunnel_endpoint = os.getenv("ODPS_TUNNEL_ENDPOINT") + return cls( + None, + None, + account=accounts.from_environments(), + project=project, + endpoint=endpoint, + tunnel_endpoint=tunnel_endpoint, + ) except KeyError: return None diff --git a/odps/df/tests/test_delay.py b/odps/df/tests/test_delay.py index ebfaf22..734df84 100644 --- a/odps/df/tests/test_delay.py +++ b/odps/df/tests/test_delay.py @@ -60,7 +60,7 @@ def test_async_execute(setup): def make_filter(df, cnt): def waiter(val, c): import time - time.sleep(5 * c) + time.sleep(30 * c) return val f_df = df[df.value == cnt] diff --git a/odps/errors.py b/odps/errors.py index 0f1f2ae..fd0ee78 100644 --- a/odps/errors.py +++ b/odps/errors.py @@ -225,6 +225,10 @@ class NoSuchObject(ServerDefinedException): pass +class NoSuchProject(NoSuchObject): + pass + + class NoSuchPartition(NoSuchObject): pass @@ -241,6 +245,10 @@ class InvalidArgument(ServerDefinedException): pass +class AuthenticationRequestExpired(ServerDefinedException): + pass + + class AuthorizationRequired(ServerDefinedException): pass @@ -363,10 +371,6 @@ class SecurityQueryError(ODPSError): pass -class NoSuchProject(ODPSError): - pass - - class OSSSignUrlError(ODPSError): def __init__(self, err): if isinstance(err, six.string_types): diff --git a/odps/models/instance.py b/odps/models/instance.py index 0ed9bf6..eaa3274 100644 --- a/odps/models/instance.py +++ b/odps/models/instance.py @@ -1014,7 +1014,9 @@ def open_reader(self, *args, **kwargs): timeout = timeout if timeout is not None else options.tunnel.legacy_fallback_timeout kwargs["timeout"] = timeout - result_fallback_errors = (errors.InvalidProjectTable, errors.InvalidArgument) + result_fallback_errors = ( + errors.InvalidProjectTable, errors.InvalidArgument, errors.NoSuchProject + ) if use_tunnel: # for compatibility if 'limit_enabled' in kwargs: diff --git a/odps/models/table.py b/odps/models/table.py index 62970f5..17cf3a7 100644 --- a/odps/models/table.py +++ b/odps/models/table.py @@ -207,8 +207,9 @@ class Table(LazyLoad): class Type(Enum): MANAGED_TABLE = "MANAGED_TABLE" - VIRTUAL_VIEW = "VIRTUAL_VIEW" EXTERNAL_TABLE = "EXTERNAL_TABLE" + OBJECT_TABLE = "OBJECT_TABLE" + VIRTUAL_VIEW = "VIRTUAL_VIEW" MATERIALIZED_VIEW = "MATERIALIZED_VIEW" name = serializers.XMLNodeField('Name') diff --git a/odps/models/tests/test_schemas.py b/odps/models/tests/test_schemas.py index 2b6fac8..d61ade4 100644 --- a/odps/models/tests/test_schemas.py +++ b/odps/models/tests/test_schemas.py @@ -1,4 +1,4 @@ -# Copyright 1999-2022 Alibaba Group Holding Ltd. +# Copyright 1999-2024 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,12 +15,13 @@ import os import time +import mock import pytest from ... import ODPS, options from ...compat import BytesIO from ...errors import NoSuchObject -from ...tests.core import tn, force_drop_schema +from ...tests.core import force_drop_schema, tn try: import pyarrow as pa @@ -110,12 +111,14 @@ def test_schemas(odps_with_schema, legacy): assert any(s.name == schema_name for s in odps_with_schema.list_schemas()) for idx in range(5): - odps_with_schema.delete_schema(schema) try: + odps_with_schema.delete_schema(schema) _assert_schema_deleted(odps_with_schema, schema_name) except AssertionError: if idx >= 5: raise + except NoSuchObject: + break else: break @@ -123,12 +126,14 @@ def test_schemas(odps_with_schema, legacy): assert odps_with_schema.exist_schema(schema_name2) for idx in range(5): - schema.drop() try: + schema.drop() _assert_schema_deleted(odps_with_schema, schema_name2) except AssertionError: if idx >= 5: raise + except NoSuchObject: + break else: break @@ -245,22 +250,27 @@ def test_get_table_with_schema_opt(odps_with_schema): options.always_enable_schema = False -def test_table_tenant_config(odps_with_schema_tenant): - odps = odps_with_schema_tenant +def test_table_tenant_config(odps_with_schema): + odps = odps_with_schema test_table_name = tn("pyodps_test_table_with_schema3") - assert odps.is_schema_namespace_enabled() + def _new_get_parameter(self, key, default=None): + assert key == "odps.namespace.schema" + return "true" - odps.delete_table(test_table_name, if_exists=True) - tb = odps.create_table(test_table_name, "col1 string", lifecycle=1) - assert tb.get_schema().name == "default" + with mock.patch("odps.models.tenant.Tenant.get_parameter", new=_new_get_parameter): + assert odps.is_schema_namespace_enabled() - tb = odps.get_table("default." + test_table_name) - tb.reload() - assert tb.name == test_table_name - assert tb.get_schema().name == "default" + odps.delete_table(test_table_name, if_exists=True) + tb = odps.create_table(test_table_name, "col1 string", lifecycle=1) + assert tb.get_schema().name == "default" - tb.drop() + tb = odps.get_table("default." + test_table_name) + tb.reload() + assert tb.name == test_table_name + assert tb.get_schema().name == "default" + + tb.drop() def test_file_resource_with_schema(odps_with_schema): diff --git a/odps/models/tests/test_tasks.py b/odps/models/tests/test_tasks.py index af3bf26..67c979e 100644 --- a/odps/models/tests/test_tasks.py +++ b/odps/models/tests/test_tasks.py @@ -325,5 +325,4 @@ def test_ray_cluster_init(odps): assert to_text(to_xml) == to_text(right_xml) task = Task.parse(None, to_xml) assert isinstance(task, MaxFrameTask) - assert task.command == MaxFrameTask.CommandType.RAY_CLUSTER_INIT - + assert task.command == MaxFrameTask.CommandType.RAY_CLUSTER_INIT \ No newline at end of file diff --git a/odps/rest.py b/odps/rest.py index 32f7212..70771be 100644 --- a/odps/rest.py +++ b/odps/rest.py @@ -192,6 +192,8 @@ def request(self, url, method, stream=False, **kwargs): if self._endpoint in self._endpoints_without_v4_sign or not options.enable_v4_sign: sign_region_name = None + auth_expire_retried = False + while True: kwargs["region_name"] = sign_region_name try: @@ -215,6 +217,11 @@ def request(self, url, method, stream=False, **kwargs): raise self._endpoints_without_v4_sign.add(self._endpoint) sign_region_name = None + except errors.AuthenticationRequestExpired: + if not hasattr(self.account, "reload") or auth_expire_retried: + raise + self.account.reload(True) + auth_expire_retried = True def _request(self, url, method, stream=False, **kwargs): self.upload_survey_log() diff --git a/odps/tests/core.py b/odps/tests/core.py index 6ca5981..e6f2d07 100644 --- a/odps/tests/core.py +++ b/odps/tests/core.py @@ -24,6 +24,7 @@ import types import pytest + try: from flaky import flaky as _raw_flaky except ImportError: @@ -78,10 +79,21 @@ def _load_config_odps(config, section_name, overwrite_global=True): except ConfigParser.NoSectionError: return - access_id = config.get(section_name, "access_id") - secret_access_key = config.get(section_name, "secret_access_key") project = config.get(section_name, "project") - endpoint = config.get(section_name, "endpoint") + + try: + access_id = config.get(section_name, "access_id") + except ConfigParser.NoOptionError: + access_id = config.get("odps", "access_id") + try: + secret_access_key = config.get(section_name, "secret_access_key") + except ConfigParser.NoOptionError: + secret_access_key = config.get("odps", "secret_access_key") + try: + endpoint = config.get(section_name, "endpoint") + except ConfigParser.NoOptionError: + endpoint = config.get("odps", "endpoint") + try: seahawks_url = config.get(section_name, "seahawks_url") except (ConfigParser.NoSectionError, ConfigParser.NoOptionError): @@ -127,7 +139,6 @@ def get_config(): _load_config_odps(config, "odps_daily", overwrite_global=False) _load_config_odps(config, "odps_with_storage_tier", overwrite_global=False) _load_config_odps(config, "odps_with_schema", overwrite_global=False) - _load_config_odps(config, "odps_with_schema_tenant", overwrite_global=False) _load_config_odps(config, "odps_with_tunnel_quota", overwrite_global=False) # make sure main config overrides other configs _load_config_odps(config, "odps") diff --git a/odps/tests/test_accounts.py b/odps/tests/test_accounts.py index d047471..422f935 100644 --- a/odps/tests/test_accounts.py +++ b/odps/tests/test_accounts.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# Copyright 1999-2022 Alibaba Group Holding Ltd. +# Copyright 1999-2024 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import datetime +import json import os import shutil import tempfile @@ -23,14 +25,18 @@ import mock import pytest +import requests from .. import ODPS, errors, options from ..accounts import ( + DEFAULT_TEMP_ACCOUNT_HOURS, BearerTokenAccount, CredentialProviderAccount, SignServer, SignServerAccount, SignServerError, + StsAccount, + from_environments, ) from ..rest import RestClient from .core import tn @@ -83,18 +89,64 @@ def test_tokenized_sign_server_account(odps): server.stop() +def test_sts_account(odps): + tmp_path = tempfile.mkdtemp(prefix="tmp_pyodps_") + req = requests.Request(method="GET", url=odps.get_project().resource()) + try: + token_account = StsAccount( + odps.account.access_id, odps.account.secret_access_key, "token" + ) + cp_req = copy.deepcopy(req) + token_account.sign_request(cp_req, odps.endpoint) + assert "token" == cp_req.headers["authorization-sts-token"] + + os.environ["ODPS_STS_ACCESS_KEY_ID"] = odps.account.access_id + os.environ["ODPS_STS_ACCESS_KEY_SECRET"] = odps.account.secret_access_key + os.environ["ODPS_STS_TOKEN"] = "token" + account = from_environments() + assert isinstance(account, StsAccount) + cp_req = copy.deepcopy(req) + token_account.sign_request(cp_req, odps.endpoint) + assert "token" == cp_req.headers["authorization-sts-token"] + + os.environ.pop("ODPS_STS_ACCESS_KEY_ID", None) + os.environ.pop("ODPS_STS_ACCESS_KEY_SECRET", None) + os.environ.pop("ODPS_STS_TOKEN", None) + + sts_file_name = os.path.join(tmp_path, "sts_file") + os.environ["ODPS_STS_ACCOUNT_FILE"] = sts_file_name + exp_time = int(time.time() + 3 * 3600) + account_data = { + "accessKeyId": odps.account.access_id, + "accessKeySecret": odps.account.secret_access_key, + "securityToken": "token", + "expiration": datetime.datetime.utcfromtimestamp(exp_time).strftime( + "%Y-%m-%dT%H:%M:%SZ" + ), + } + with open(sts_file_name, "w") as out_file: + out_file.write(json.dumps(account_data)) + account = from_environments() + assert isinstance(account, StsAccount) + assert account._last_modified_time == datetime.datetime.fromtimestamp( + exp_time + ) - datetime.timedelta(hours=DEFAULT_TEMP_ACCOUNT_HOURS) + + cp_req = copy.deepcopy(req) + token_account.sign_request(cp_req, odps.endpoint) + assert "token" == cp_req.headers["authorization-sts-token"] + finally: + shutil.rmtree(tmp_path) + os.environ.pop("ODPS_STS_ACCESS_KEY_ID", None) + os.environ.pop("ODPS_STS_ACCESS_KEY_SECRET", None) + os.environ.pop("ODPS_STS_TOKEN", None) + os.environ.pop("ODPS_STS_ACCOUNT_FILE", None) + + @pytest.mark.skipif(cupid_context is None, reason="cannot import cupid context") def test_bearer_token_account(odps): - odps.delete_table(tn('test_bearer_token_account_table'), if_exists=True) - t = odps.create_table(tn('test_bearer_token_account_table'), 'col string', lifecycle=1) - with t.open_writer() as writer: - records = [['val1'], ['val2'], ['val3']] - writer.write(records) - - inst = odps.execute_sql( - 'select count(*) from {0}'.format(tn('test_bearer_token_account_table')), async_=True - ) - inst.wait_for_success() + inst = odps.run_sql("select count(*) from dual") + inst.wait_for_completion() task_name = inst.get_task_names()[0] logview_address = inst.get_logview_address() @@ -112,41 +164,67 @@ def test_bearer_token_account(odps): bearer_token_odps.create_table(tn('test_bearer_token_account_table_test1'), 'col string', lifecycle=1) + +def test_fake_bearer_token(odps): fake_token_account = BearerTokenAccount(token='fake-token') bearer_token_odps = ODPS( - None, None, odps.project, odps.endpoint, account=fake_token_account + None, + None, + odps.project, + odps.endpoint, + account=fake_token_account, + overwrite_global=False, ) with pytest.raises(errors.ODPSError): bearer_token_odps.create_table(tn('test_bearer_token_account_table_test2'), 'col string', lifecycle=1) + +def test_bearer_token_load_and_update(odps): + token = "fake-token" tmp_path = tempfile.mkdtemp(prefix="tmp_pyodps_") + os.environ["ODPS_BEARER_TOKEN_HOURS"] = "0" try: token_file_name = os.path.join(tmp_path, "token_file") with open(token_file_name, "w") as token_file: token_file.write(token) os.environ["ODPS_BEARER_TOKEN_FILE"] = token_file_name - token_ts_file_name = os.path.join(tmp_path, "token_ts_file") create_timestamp = int(time.time()) - with open(token_ts_file_name, "w") as token_ts_file: - token_ts_file.write(str(create_timestamp)) - os.environ["ODPS_BEARER_TOKEN_TIMESTAMP_FILE"] = token_ts_file_name + options.account = None env_odps = ODPS(project=odps.project, endpoint=odps.endpoint) assert isinstance(env_odps.account, BearerTokenAccount) assert env_odps.account.token == token - assert env_odps.account._last_modified_time == datetime.datetime.fromtimestamp(create_timestamp) + assert env_odps.account._last_modified_time > datetime.datetime.fromtimestamp(create_timestamp) + + last_timestamp = env_odps.account._last_modified_time + env_odps.account.reload() + assert env_odps.account._last_modified_time > last_timestamp + + inst = odps.run_sql("select count(*) from dual") + logview_address = inst.get_logview_address() + token = logview_address[logview_address.find("token=") + len("token=") :] + with open(token_file_name, "w") as token_file: + token_file.write(token) + + last_timestamp = env_odps.account._last_modified_time + env_odps.account.reload() + assert env_odps.account._last_modified_time != last_timestamp + + last_timestamp = env_odps.account._last_modified_time + env_odps.account.reload() + assert env_odps.account._last_modified_time == last_timestamp finally: shutil.rmtree(tmp_path) - os.environ.pop("ODPS_BEARER_TOKEN_FILE") - os.environ.pop("ODPS_BEARER_TOKEN_TIMESTAMP_FILE") + os.environ.pop("ODPS_BEARER_TOKEN_HOURS", None) + os.environ.pop("ODPS_BEARER_TOKEN_FILE", None) + os.environ.pop("ODPS_BEARER_TOKEN_TIMESTAMP_FILE", None) -def test_v4_signature_fallback(odps_daily): - odps = odps_daily - odps.delete_table(tn('test_sign_account_table'), if_exists=True) +def test_v4_signature_fallback(odps): + odps.delete_table(tn("test_sign_account_table"), if_exists=True) assert odps.endpoint not in RestClient._endpoints_without_v4_sign def _new_is_ok(self, resp): @@ -158,7 +236,7 @@ def _new_is_ok2(self, resp): if odps.endpoint not in self._endpoints_without_v4_sign: raise errors.InternalServerError( "ODPS-0010000:System internal error - Error occurred while getting access key for " - "'%s', AliyunV4 request need ak v3 support" % odps_daily.account.access_id + "'%s', AliyunV4 request need ak v3 support" % odps.account.access_id ) return resp.ok @@ -189,26 +267,79 @@ def _new_is_ok3(self, resp): options.enable_v4_sign = old_enable_v4_sign -def test_credential_provider_account(odps): - class MockCredentials(object): - @classmethod - def get_access_key_id(cls): - return odps.account.access_id +def test_auth_expire_reload(odps): + inst = odps.run_sql("select count(*) from dual") + inst.wait_for_completion() + + tmp_path = tempfile.mkdtemp(prefix="tmp_pyodps_") + try: + logview_address = inst.get_logview_address() + token = logview_address[logview_address.find("token=") + len("token=") :] + + token_file = os.path.join(tmp_path, "token_ts_file") + os.environ["ODPS_BEARER_TOKEN_FILE"] = token_file + with open(token_file, "w") as token_file_obj: + token_file_obj.write("invalid_token") + + token_odps = ODPS( + account=BearerTokenAccount(), project=odps.project, endpoint=odps.endpoint + ) + + retrial_counts = [0] + + def _new_is_ok(self, resp): + if not retrial_counts[0]: + with open(token_file, "w") as token_file_obj: + token_file_obj.write(token) + retrial_counts[0] += 1 + raise errors.AuthenticationRequestExpired("mock auth expired") + return resp.ok + + with mock.patch("odps.rest.RestClient.is_ok", new=_new_is_ok): + token_inst = token_odps.get_instance(inst.id) + token_inst.reload() + assert retrial_counts[0] == 1 + assert token_odps.account.token is not None + finally: + shutil.rmtree(tmp_path) + os.environ.pop("ODPS_BEARER_TOKEN_FILE", None) - @classmethod - def get_access_key_secret(cls): - return odps.account.secret_access_key - @classmethod - def get_security_token(cls): - return None # kept empty to skip sts token check +class MockCredentials(object): + def __init__(self, odps): + self._odps = odps - class MockCredentialProvider(object): - @classmethod - def get_credentials(cls): - return MockCredentials() + def get_access_key_id(self): + return self._odps.account.access_id - account = CredentialProviderAccount(MockCredentialProvider()) + def get_access_key_secret(self): + return self._odps.account.secret_access_key + + def get_security_token(cls): + return None # kept empty to skip sts token check + + +class MockCredentialProvider(object): + def __init__(self, odps): + self._odps = odps + + def get_credentials(self): + return MockCredentials(self._odps) + + +class MockCredentialProvider2(object): + def __init__(self, odps): + self._odps = odps + + def get_credential(self): + return MockCredentials(self._odps) + + +@pytest.mark.parametrize( + "provider_cls", [MockCredentialProvider, MockCredentialProvider2] +) +def test_credential_provider_account(odps, provider_cls): + account = CredentialProviderAccount(provider_cls(odps)) cred_odps = ODPS( account, None, odps.project, odps.endpoint ) diff --git a/odps/tunnel/tabletunnel.py b/odps/tunnel/tabletunnel.py index 4cc7fbf..b812282 100644 --- a/odps/tunnel/tabletunnel.py +++ b/odps/tunnel/tabletunnel.py @@ -76,6 +76,7 @@ class BaseTableTunnelSession(serializers.JSONSerializableModel): def get_common_headers(content_length=None, chunked=False): header = { "odps-tunnel-date-transform": TUNNEL_DATA_TRANSFORM_VERSION, + "odps-tunnel-sdk-support-schema-evolution": "true", "x-odps-tunnel-version": TUNNEL_VERSION, } if content_length is not None: @@ -586,9 +587,16 @@ def __iter__(self): slots = serializers.JSONNodeField( 'slots', parse_callback=lambda val: TableStreamUploadSession.Slots(val)) quota_name = serializers.JSONNodeField('QuotaName') + schema_version = serializers.JSONNodeField("schema_version") def __init__( - self, client, table, partition_spec, compress_option=None, quota_name=None + self, + client, + table, + partition_spec, + compress_option=None, + quota_name=None, + schema_version=None, ): super(TableStreamUploadSession, self).__init__() @@ -597,6 +605,7 @@ def __init__( self._partition_spec = self.normalize_partition_spec(partition_spec) self._quota_name = quota_name + self.schema_version = schema_version self._init() self._compress_option = compress_option @@ -620,11 +629,15 @@ def _init(self): params = self.get_common_params() headers = self.get_common_headers(content_length=0) + if self.schema_version is not None: + params["schema_version"] = str(self.schema_version) + url = self._get_resource() resp = self._client.post(url, {}, params=params, headers=headers) self.check_tunnel_response(resp) self.parse(resp, obj=self) + self._quota_name = self.quota_name if self.schema is not None: self.schema.build_snapshot() @@ -640,6 +653,7 @@ def reload(self): self.check_tunnel_response(resp) self.parse(resp, obj=self) + self._quota_name = self.quota_name if self.schema is not None: self.schema.build_snapshot() @@ -919,7 +933,7 @@ def _build_compress_option(compress_algo=None, level=None, strategy=None): compress_algo=compress_algo, level=level, strategy=strategy ) - def create_download_session(self, table, async_mode=False, partition_spec=None, + def create_download_session(self, table, async_mode=True, partition_spec=None, download_id=None, compress_option=None, compress_algo=None, compress_level=None, compress_strategy=None, schema=None, timeout=None, **kw): @@ -978,6 +992,7 @@ def create_stream_upload_session( compress_level=None, compress_strategy=None, schema=None, + schema_version=None, ): table = self._get_tunnel_table(table, schema) compress_option = compress_option or self._build_compress_option( @@ -989,6 +1004,7 @@ def create_stream_upload_session( partition_spec, compress_option=compress_option, quota_name=self._quota_name, + schema_version=schema_version, ) def create_upsert_session( diff --git a/odps/tunnel/tests/test_tabletunnel.py b/odps/tunnel/tests/test_tabletunnel.py index 15d31d1..b650865 100644 --- a/odps/tunnel/tests/test_tabletunnel.py +++ b/odps/tunnel/tests/test_tabletunnel.py @@ -384,8 +384,9 @@ def test_upload_and_download_by_raw_tunnel(setup): @py_and_c_deco -def test_stream_upload_and_download_tunnel(setup): +def test_stream_upload_and_download_tunnel(odps, setup): test_table_name = tn('pyodps_test_stream_upload_' + get_code_mode()) + odps.delete_table(test_table_name, if_exists=True) setup.create_table(test_table_name) data = setup.gen_data() diff --git a/odps/tunnel/tests/test_volumetunnel.py b/odps/tunnel/tests/test_volumetunnel.py index 2d3e1f2..5a6bdc6 100644 --- a/odps/tunnel/tests/test_volumetunnel.py +++ b/odps/tunnel/tests/test_volumetunnel.py @@ -19,12 +19,11 @@ import pytest from ...compat import irange, six -from ...tests.core import tn +from ...errors import NoSuchObject +from ...tests.core import get_test_unique_name, tn from .. import CompressOption from ..volumetunnel import VolumeTunnel -TEST_PARTED_VOLUME_NAME = tn('pyodps_test_p_volume') -TEST_FS_VOLUME_NAME = tn('pyodps_test_fs_volume') TEST_PARTITION_NAME = 'pyodps_test_partition' TEST_FILE_NAME = 'test_output_file' @@ -34,20 +33,23 @@ @pytest.fixture def setup(odps): + test_parted_vol_name = tn("pyodps_test_p_volume_" + get_test_unique_name(5)) + test_fs_vol_name = tn("pyodps_test_fs_volume" + get_test_unique_name(5)) + def gen_byte_block(): return bytes(bytearray([iid % TEST_MODULUS for iid in irange(TEST_BLOCK_SIZE)])) def get_test_partition(): - if odps.exist_volume(TEST_PARTED_VOLUME_NAME): - odps.delete_volume(TEST_PARTED_VOLUME_NAME) - odps.create_parted_volume(TEST_PARTED_VOLUME_NAME) - return odps.get_volume_partition(TEST_PARTED_VOLUME_NAME, TEST_PARTITION_NAME) + if odps.exist_volume(test_parted_vol_name): + odps.delete_volume(test_parted_vol_name) + odps.create_parted_volume(test_parted_vol_name) + return odps.get_volume_partition(test_parted_vol_name, TEST_PARTITION_NAME) def get_test_fs(): - if odps.exist_volume(TEST_FS_VOLUME_NAME): - odps.delete_volume(TEST_FS_VOLUME_NAME) - odps.create_fs_volume(TEST_FS_VOLUME_NAME) - return odps.get_volume(TEST_FS_VOLUME_NAME) + if odps.exist_volume(test_fs_vol_name): + odps.delete_volume(test_fs_vol_name) + odps.create_fs_volume(test_fs_vol_name) + return odps.get_volume(test_fs_vol_name) def wrap_fun(func): @six.wraps(func) @@ -66,14 +68,20 @@ def wrapped(self, *args, **kwargs): ) VolumeTunnel.create_upload_session = wrap_fun(_old_create_upload_session) - tn = namedtuple("TN", "gen_byte_block, get_test_partition, get_test_fs") + test_funcs = namedtuple( + "TestFuncs", "gen_byte_block, get_test_partition, get_test_fs" + ) try: - yield tn(gen_byte_block, get_test_partition, get_test_fs) + yield test_funcs(gen_byte_block, get_test_partition, get_test_fs) finally: - if odps.exist_volume(TEST_PARTED_VOLUME_NAME): - odps.delete_volume(TEST_PARTED_VOLUME_NAME) - if odps.exist_volume(TEST_FS_VOLUME_NAME): - odps.delete_volume(TEST_FS_VOLUME_NAME) + try: + odps.delete_volume(test_parted_vol_name) + except NoSuchObject: + pass + try: + odps.delete_volume(test_fs_vol_name) + except NoSuchObject: + pass VolumeTunnel.create_download_session = _old_create_download_session VolumeTunnel.create_upload_session = _old_create_upload_session