From 08bc0f44904fe0d8bc8779e0e892e4d42def3983 Mon Sep 17 00:00:00 2001
From: Andreas Poehlmann
Date: Tue, 20 Feb 2024 10:53:49 +0100
Subject: [PATCH] Update ObjectStoragePath for universal_pathlib>=v0.2.1
 (#37524)

This updates ObjectStoragePath to be compatible with universal_pathlib
>= 0.2.1, which in turn makes it compatible with Python 3.12+.

---
 airflow/io/path.py                           | 169 ++++++------------
 airflow/providers/common/io/xcom/backend.py  |   4 +-
 pyproject.toml                               |   9 +-
 tests/io/test_path.py                        | 135 +++++++++-----
 .../providers/common/io/xcom/test_backend.py |   2 +-
 5 files changed, 154 insertions(+), 165 deletions(-)
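Before the file-by-file changes, a minimal sketch of the user-facing behaviour this migration preserves. This is an illustration, not part of the patch: "aws_default" is a made-up connection id, and the snippet assumes an environment where the s3 store can be attached.

    from airflow.io.path import ObjectStoragePath

    # conn_id can be passed as a keyword argument ...
    p1 = ObjectStoragePath("s3://bucket/key", conn_id="aws_default")
    # ... or embedded as userinfo in the URL; both end up in storage_options.
    p2 = ObjectStoragePath("s3://aws_default@bucket/key")

    assert p1.storage_options.get("conn_id") == p2.storage_options.get("conn_id")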
diff --git a/airflow/io/path.py b/airflow/io/path.py
index d65d837e7e5db..cb4c48c47618b 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -17,24 +17,20 @@
 from __future__ import annotations
 
 import contextlib
-import functools
 import os
 import shutil
 import typing
-from pathlib import PurePath
+from typing import Any, Mapping
 from urllib.parse import urlsplit
 
-from fsspec.core import split_protocol
 from fsspec.utils import stringify_path
-from upath.implementations.cloud import CloudPath, _CloudAccessor
+from upath.implementations.cloud import CloudPath
 from upath.registry import get_upath_class
 
 from airflow.io.store import attach
 from airflow.io.utils.stat import stat_result
 
 if typing.TYPE_CHECKING:
-    from urllib.parse import SplitResult
-
     from fsspec import AbstractFileSystem
@@ -43,124 +39,68 @@
 default = "file"
 
 
-class _AirflowCloudAccessor(_CloudAccessor):
-    __slots__ = ("_store",)
-
-    def __init__(
-        self,
-        parsed_url: SplitResult | None,
-        conn_id: str | None = None,
-        **kwargs: typing.Any,
-    ) -> None:
-        # warning: we are not calling super().__init__ here
-        # as it will try to create a new fs from a different
-        # set of registered filesystems
-        if parsed_url and parsed_url.scheme:
-            self._store = attach(parsed_url.scheme, conn_id)
-        else:
-            self._store = attach("file", conn_id)
-
-    @property
-    def _fs(self) -> AbstractFileSystem:
-        return self._store.fs
-
-    def __eq__(self, other):
-        return isinstance(other, _AirflowCloudAccessor) and self._store == other._store
-
-
 class ObjectStoragePath(CloudPath):
     """A path-like object for object storage."""
 
-    _accessor: _AirflowCloudAccessor
-
     __version__: typing.ClassVar[int] = 1
 
-    _default_accessor = _AirflowCloudAccessor
+    _protocol_dispatch = False
 
     sep: typing.ClassVar[str] = "/"
     root_marker: typing.ClassVar[str] = "/"
 
-    _bucket: str
-    _key: str
-    _protocol: str
-    _hash: int | None
-
-    __slots__ = (
-        "_bucket",
-        "_key",
-        "_conn_id",
-        "_protocol",
-        "_hash",
-    )
-
-    def __new__(
-        cls: type[PT],
-        *args: str | os.PathLike,
-        scheme: str | None = None,
-        conn_id: str | None = None,
-        **kwargs: typing.Any,
-    ) -> PT:
-        args_list = list(args)
-
-        if args_list:
-            other = args_list.pop(0) or "."
-        else:
-            other = "."
-
-        if isinstance(other, PurePath):
-            _cls: typing.Any = type(other)
-            drv, root, parts = _cls._parse_args(args_list)
-            drv, root, parts = _cls._flavour.join_parsed_parts(
-                other._drv,  # type: ignore[attr-defined]
-                other._root,  # type: ignore[attr-defined]
-                other._parts,  # type: ignore[attr-defined]
-                drv,
-                root,
-                parts,  # type: ignore
-            )
-
-            _kwargs = getattr(other, "_kwargs", {})
-            _url = getattr(other, "_url", None)
-            other_kwargs = _kwargs.copy()
-            if _url and _url.scheme:
-                other_kwargs["url"] = _url
-            new_kwargs = _kwargs.copy()
-            new_kwargs.update(kwargs)
-
-            return _cls(_cls._format_parsed_parts(drv, root, parts, **other_kwargs), **new_kwargs)
-
-        url = stringify_path(other)
-        parsed_url: SplitResult = urlsplit(url)
-
-        if scheme:  # allow override of protocol
-            parsed_url = parsed_url._replace(scheme=scheme)
-
-        if not parsed_url.path:  # ensure path has root
-            parsed_url = parsed_url._replace(path="/")
-
-        if not parsed_url.scheme and not split_protocol(url)[0]:
-            args_list.insert(0, url)
-        else:
-            args_list.insert(0, parsed_url.path)
+    __slots__ = ("_hash_cached",)
+
+    @classmethod
+    def _transform_init_args(
+        cls,
+        args: tuple[str | os.PathLike, ...],
+        protocol: str,
+        storage_options: dict[str, Any],
+    ) -> tuple[tuple[str | os.PathLike, ...], str, dict[str, Any]]:
+        """Extract conn_id from the URL and set it as a storage option."""
+        if args:
+            arg0 = args[0]
+            parsed_url = urlsplit(stringify_path(arg0))
+            userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
+            if have_info:
+                storage_options.setdefault("conn_id", userinfo or None)
+                parsed_url = parsed_url._replace(netloc=hostinfo)
+                args = (parsed_url.geturl(),) + args[1:]
+            protocol = protocol or parsed_url.scheme
+        return args, protocol, storage_options
 
-        # This matches the parsing logic in urllib.parse; see:
-        # https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203
-        userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
-        if have_info:
-            conn_id = conn_id or userinfo or None
-            parsed_url = parsed_url._replace(netloc=hostinfo)
+    @classmethod
+    def _parse_storage_options(
+        cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
+    ) -> dict[str, Any]:
+        fs = attach(protocol or "file", conn_id=storage_options.get("conn_id")).fs
+        pth_storage_options = type(fs)._get_kwargs_from_urls(urlpath)
+        return {**pth_storage_options, **storage_options}
 
-        return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs)  # type: ignore
+    @classmethod
+    def _fs_factory(
+        cls, urlpath: str, protocol: str, storage_options: Mapping[str, Any]
+    ) -> AbstractFileSystem:
+        return attach(protocol or "file", storage_options.get("conn_id")).fs
 
-    @functools.lru_cache
     def __hash__(self) -> int:
-        return hash(str(self))
+        self._hash_cached: int
+        try:
+            return self._hash_cached
+        except AttributeError:
+            self._hash_cached = hash(str(self))
+            return self._hash_cached
 
     def __eq__(self, other: typing.Any) -> bool:
         return self.samestore(other) and str(self) == str(other)
 
     def samestore(self, other: typing.Any) -> bool:
-        return isinstance(other, ObjectStoragePath) and self._accessor == other._accessor
+        return (
+            isinstance(other, ObjectStoragePath)
+            and self.protocol == other.protocol
+            and self.storage_options.get("conn_id") == other.storage_options.get("conn_id")
+        )
 
     @property
     def container(self) -> str:
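The userinfo handling in _transform_init_args above can be shown in isolation. This is an illustration of the parsing only, not part of the patch; "my_conn" is a made-up connection id.

    from urllib.parse import urlsplit

    parsed = urlsplit("s3://my_conn@warehouse/data/file.parquet")
    # rpartition("@") splits the netloc into userinfo and host parts.
    userinfo, have_info, hostinfo = parsed.netloc.rpartition("@")
    assert (userinfo, hostinfo) == ("my_conn", "warehouse")

    # The conn_id becomes a storage option, and the URL is rebuilt without it:
    storage_options = {"conn_id": userinfo or None}
    assert parsed._replace(netloc=hostinfo).geturl() == "s3://warehouse/data/file.parquet"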
@@ -186,12 +126,17 @@ def key(self) -> str:
     def namespace(self) -> str:
         return f"{self.protocol}://{self.bucket}" if self.bucket else self.protocol
 
+    def open(self, mode="r", **kwargs):
+        """Open the file pointed to by this path."""
+        kwargs.setdefault("block_size", kwargs.pop("buffering", None))
+        return self.fs.open(self.path, mode=mode, **kwargs)
+
     def stat(self) -> stat_result:  # type: ignore[override]
         """Call ``stat`` and return the result."""
         return stat_result(
-            self._accessor.stat(self),
+            self.fs.stat(self.path),
             protocol=self.protocol,
-            conn_id=self._accessor._store.conn_id,
+            conn_id=self.storage_options.get("conn_id"),
         )
 
     def samefile(self, other_path: typing.Any) -> bool:
@@ -368,7 +313,11 @@ def copy(self, dst: str | ObjectStoragePath, recursive: bool = False, **kwargs)
             if path == self.path:
                 continue
 
-            src_obj = ObjectStoragePath(path, conn_id=self._accessor._store.conn_id)
+            src_obj = ObjectStoragePath(
+                path,
+                protocol=self.protocol,
+                conn_id=self.storage_options.get("conn_id"),
+            )
 
             # skip directories, empty directories will not be created
             if src_obj.is_dir():
@@ -401,7 +350,7 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
             self.unlink()
 
     def serialize(self) -> dict[str, typing.Any]:
-        _kwargs = self._kwargs.copy()
+        _kwargs = {**self.storage_options}
         conn_id = _kwargs.pop("conn_id", None)
 
         return {
diff --git a/airflow/providers/common/io/xcom/backend.py b/airflow/providers/common/io/xcom/backend.py
index 6e995c30e1a24..3028a49be20a9 100644
--- a/airflow/providers/common/io/xcom/backend.py
+++ b/airflow/providers/common/io/xcom/backend.py
@@ -132,7 +132,7 @@ def serialize_value(
         if not p.parent.exists():
             p.parent.mkdir(parents=True, exist_ok=True)
 
-        with p.open("wb", compression=compression) as f:
+        with p.open(mode="wb", compression=compression) as f:
             f.write(s_val)
 
         return BaseXCom.serialize_value(str(p))
@@ -152,7 +152,7 @@ def deserialize_value(
 
         try:
             p = ObjectStoragePath(path) / XComObjectStoreBackend._get_key(data)
-            return json.load(p.open("rb", compression="infer"), cls=XComDecoder)
+            return json.load(p.open(mode="rb", compression="infer"), cls=XComDecoder)
         except TypeError:
             return data
         except ValueError:
diff --git a/pyproject.toml b/pyproject.toml
index f53c1002a3a88..42265978a7878 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -148,14 +148,7 @@ dependencies = [
     # We should also remove "licenses/LICENSE-unicodecsv.txt" file when we remove this dependency
     "unicodecsv>=0.14.1",
     # The Universal Pathlib provides Pathlib-like interface for FSSPEC
-    # In 0.1.* it was not very well defined for extension, so the way we used it for 0.1.*
-    # relied on a lot of private methods and attributes that were not defined in the interface,
-    # and they are broken with version 0.2.0, which is much better suited for extension and
-    # supports Python 3.12. We should limit it until we migrate to 0.2.0.
-    # See: https://github.com/fsspec/universal_pathlib/pull/173#issuecomment-1937090528
-    # This is a prerequisite to make Airflow compatible with Python 3.12
-    # Tracked in https://github.com/apache/airflow/pull/36755
-    "universal-pathlib>=0.1.4,<0.2.0",
+    "universal-pathlib>=0.2.1",
     # Werkzeug 3 breaks Flask-Login 0.6.2, also connexion needs to be updated to >= 3.0
     # we should remove this limitation when FAB supports Flask 2.3 and we migrate connexion to 3+
     "werkzeug>=2.0,<3",
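With the constraint relaxed, the installed version can be sanity-checked at runtime. A small sketch, assuming the packaging distribution is available in the environment:

    from importlib.metadata import version

    from packaging.version import Version

    # Mirrors the new lower bound declared in pyproject.toml above.
    assert Version(version("universal-pathlib")) >= Version("0.2.1")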
"key/part1/part2" - assert path.protocol == "file" + assert path.protocol == "s3" assert path.path == "bucket/key/part1/part2" path2 = ObjectStoragePath(path / "part3") assert path2.bucket == "bucket" assert path2.key == "key/part1/part2/part3" - assert path2.protocol == "file" + assert path2.protocol == "s3" assert path2.path == "bucket/key/part1/part2/part3" path3 = ObjectStoragePath(path2 / "2023") assert path3.bucket == "bucket" assert path3.key == "key/part1/part2/part3/2023" - assert path3.protocol == "file" + assert path3.protocol == "s3" assert path3.path == "bucket/key/part1/part2/part3/2023" def test_read_write(self): @@ -116,49 +152,57 @@ def test_ls(self): assert not o.exists() - @pytest.fixture() - def fake_fs(self): - fs = mock.Mock() - fs._strip_protocol.return_value = "/" - fs.conn_id = "fake" - return fs - - def test_objectstoragepath_init_conn_id_in_uri(self, fake_fs): - fake_fs.stat.return_value = {"stat": "result"} - attach(protocol="fake", conn_id="fake", fs=fake_fs) + def test_objectstoragepath_init_conn_id_in_uri(self): + attach(protocol="fake", conn_id="fake", fs=FakeRemoteFileSystem(conn_id="fake")) p = ObjectStoragePath("fake://fake@bucket/path") - assert p.stat() == {"stat": "result", "conn_id": "fake", "protocol": "fake"} + p.touch() + fsspec_info = p.fs.info(p.path) + assert p.stat() == {**fsspec_info, "conn_id": "fake", "protocol": "fake"} + + @pytest.fixture + def fake_local_files(self): + obj = FakeLocalFileSystem() + obj.touch(FOO) + try: + yield + finally: + FakeLocalFileSystem.store.clear() + FakeLocalFileSystem.pseudo_dirs[:] = [""] @pytest.mark.parametrize( "fn, args, fn2, path, expected_args, expected_kwargs", [ - ("checksum", {}, "checksum", FOO, FakeRemoteFileSystem._strip_protocol(BAR), {}), - ("size", {}, "size", FOO, FakeRemoteFileSystem._strip_protocol(BAR), {}), + ("checksum", {}, "checksum", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}), + ("size", {}, "size", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}), ( "sign", {"expiration": 200, "extra": "xtra"}, "sign", FOO, - FakeRemoteFileSystem._strip_protocol(BAR), + FakeLocalFileSystem._strip_protocol(BAR), {"expiration": 200, "extra": "xtra"}, ), - ("ukey", {}, "ukey", FOO, FakeRemoteFileSystem._strip_protocol(BAR), {}), + ("ukey", {}, "ukey", FOO, FakeLocalFileSystem._strip_protocol(BAR), {}), ( "read_block", {"offset": 0, "length": 1}, "read_block", FOO, - FakeRemoteFileSystem._strip_protocol(BAR), + FakeLocalFileSystem._strip_protocol(BAR), {"delimiter": None, "length": 1, "offset": 0}, ), ], ) - def test_standard_extended_api(self, fake_fs, fn, args, fn2, path, expected_args, expected_kwargs): - store = attach(protocol="file", conn_id="fake", fs=fake_fs) - o = ObjectStoragePath(path, conn_id="fake") + def test_standard_extended_api( + self, fake_local_files, fn, args, fn2, path, expected_args, expected_kwargs + ): + fs = FakeLocalFileSystem() + with mock.patch.object(fs, fn2) as method: + attach(protocol="file", conn_id="fake", fs=fs) + o = ObjectStoragePath(path, conn_id="fake") - getattr(o, fn)(**args) - getattr(store.fs, fn2).assert_called_once_with(expected_args, **expected_kwargs) + getattr(o, fn)(**args) + method.assert_called_once_with(expected_args, **expected_kwargs) def test_stat(self): with NamedTemporaryFile() as f: @@ -168,6 +212,8 @@ def test_stat(self): assert S_ISDIR(o.parent.stat().st_mode) def test_bucket_key_protocol(self): + attach(protocol="s3", fs=FakeRemoteFileSystem()) + bucket = "bkt" key = "yek" protocol = "s3" @@ -227,24 +273,23 @@ def 
@@ -227,24 +273,23 @@ def test_move_remote(self):
         _to.unlink()
 
     def test_copy_remote_remote(self):
-        # foo = xxx added to prevent same fs token
-        attach("ffs", fs=FakeRemoteFileSystem(auto_mkdir=True, foo="bar"))
-        attach("ffs2", fs=FakeRemoteFileSystem(auto_mkdir=True, foo="baz"))
+        attach("ffs", fs=FakeRemoteFileSystem(skip_instance_cache=True))
+        attach("ffs2", fs=FakeRemoteFileSystem(skip_instance_cache=True))
 
-        dir_src = f"/tmp/{str(uuid.uuid4())}"
-        dir_dst = f"/tmp/{str(uuid.uuid4())}"
+        dir_src = f"bucket1/{str(uuid.uuid4())}"
+        dir_dst = f"bucket2/{str(uuid.uuid4())}"
         key = "foo/bar/baz.txt"
 
-        # note we are dealing with object storage characteristics
-        # while working on a local filesystem, so it might feel not intuitive
         _from = ObjectStoragePath(f"ffs://{dir_src}")
         _from_file = _from / key
         _from_file.touch()
+        assert _from.bucket == "bucket1"
         assert _from_file.exists()
 
         _to = ObjectStoragePath(f"ffs2://{dir_dst}")
         _from.copy(_to)
 
+        assert _to.bucket == "bucket2"
         assert _to.exists()
         assert _to.is_dir()
         assert (_to / _from.key / key).exists()
@@ -254,7 +299,7 @@ def test_copy_remote_remote(self):
         _to.rmdir(recursive=True)
 
     def test_serde_objectstoragepath(self):
-        path = "file://bucket/key/part1/part2"
+        path = "file:///bucket/key/part1/part2"
         o = ObjectStoragePath(path)
         s = o.serialize()
 
@@ -312,6 +357,8 @@ def test_backwards_compat(self):
             _register_filesystems.cache_clear()
 
     def test_dataset(self):
+        attach("s3", fs=FakeRemoteFileSystem())
+
         p = "s3"
         f = "/tmp/foo"
         i = Dataset(uri=f"{p}://{f}", extra={"foo": "bar"})
diff --git a/tests/providers/common/io/xcom/test_backend.py b/tests/providers/common/io/xcom/test_backend.py
index fce5ed985e21f..0641e18fe04a0 100644
--- a/tests/providers/common/io/xcom/test_backend.py
+++ b/tests/providers/common/io/xcom/test_backend.py
@@ -181,7 +181,7 @@ def test_value_storage(self, task_instance, session):
             run_id=task_instance.run_id,
             session=session,
         )
-        assert self.path in qry.first().value
+        assert str(p) == qry.first().value
 
     @pytest.mark.db_test
     def test_clear(self, task_instance, session):
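As a closing illustration of the equality semantics implemented in path.py above (samestore plus string comparison), a sketch that assumes a filesystem is attached for the "ffs" protocol as in the tests; "conn_a" and "conn_b" are made-up connection ids:

    from airflow.io.path import ObjectStoragePath

    a = ObjectStoragePath("ffs://bucket/key", conn_id="conn_a")
    b = ObjectStoragePath("ffs://bucket/key", conn_id="conn_a")
    c = ObjectStoragePath("ffs://bucket/key", conn_id="conn_b")

    assert a == b              # same protocol, conn_id and path
    assert a != c              # same path, different connection, different store
    assert hash(a) == hash(b)  # hash of str(self), cached per instance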