Skip to content

Commit

Permalink
Allow storage options to be passed (apache#35820)
Browse files Browse the repository at this point in the history
This allows storage options to be passed to the fsspec backend as part 
of the kwargs to ObjectStorage.

---------

Co-authored-by: Andrey Anshin <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
3 people authored Dec 8, 2023
1 parent a8333b7 commit aba58ad
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 55 deletions.
20 changes: 14 additions & 6 deletions airflow/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,24 @@
if TYPE_CHECKING:
from fsspec import AbstractFileSystem

from airflow.io.typedef import Properties


log = logging.getLogger(__name__)


def _file(_: str | None) -> LocalFileSystem:
return LocalFileSystem()
def _file(_: str | None, storage_options: Properties) -> LocalFileSystem:
return LocalFileSystem(**storage_options)


# builtin supported filesystems
_BUILTIN_SCHEME_TO_FS: dict[str, Callable[[str | None], AbstractFileSystem]] = {
_BUILTIN_SCHEME_TO_FS: dict[str, Callable[[str | None, Properties], AbstractFileSystem]] = {
"file": _file,
}


@cache
def _register_filesystems() -> dict[str, Callable[[str | None], AbstractFileSystem]]:
def _register_filesystems() -> dict[str, Callable[[str | None, Properties], AbstractFileSystem]]:
scheme_to_fs = _BUILTIN_SCHEME_TO_FS.copy()
with Stats.timer("airflow.io.load_filesystems") as timer:
manager = ProvidersManager()
Expand All @@ -65,20 +68,25 @@ def _register_filesystems() -> dict[str, Callable[[str | None], AbstractFileSyst
return scheme_to_fs


def get_fs(scheme: str, conn_id: str | None = None) -> AbstractFileSystem:
def get_fs(
scheme: str, conn_id: str | None = None, storage_options: Properties | None = None
) -> AbstractFileSystem:
"""
Get a filesystem by scheme.
:param scheme: the scheme to get the filesystem for
:return: the filesystem method
:param conn_id: the airflow connection id to use
:param storage_options: the storage options to pass to the filesystem
"""
filesystems = _register_filesystems()
try:
fs = filesystems[scheme]
except KeyError:
raise ValueError(f"No filesystem registered for scheme {scheme}") from None
return fs(conn_id)

options = storage_options or {}
return fs(conn_id, options)


def has_fs(scheme: str) -> bool:
Expand Down
22 changes: 17 additions & 5 deletions airflow/io/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@

PT = typing.TypeVar("PT", bound="ObjectStoragePath")

default = "file"


class _AirflowCloudAccessor(_CloudAccessor):
__slots__ = ("_store",)
Expand Down Expand Up @@ -70,7 +72,7 @@ class ObjectStoragePath(CloudPath):

__version__: typing.ClassVar[int] = 1

_default_accessor: type[_CloudAccessor] = _AirflowCloudAccessor
_default_accessor = _AirflowCloudAccessor

sep: typing.ClassVar[str] = "/"
root_marker: typing.ClassVar[str] = "/"
Expand Down Expand Up @@ -149,7 +151,10 @@ def __new__(

@functools.lru_cache
def __hash__(self) -> int:
return hash(self._bucket)
return hash(str(self))

def __eq__(self, other: typing.Any) -> bool:
return self.samestore(other) and str(self) == str(other)

def samestore(self, other: typing.Any) -> bool:
return isinstance(other, ObjectStoragePath) and self._accessor == other._accessor
Expand Down Expand Up @@ -386,16 +391,23 @@ def move(self, path: str | ObjectStoragePath, recursive: bool = False, **kwargs)
self.copy(path, recursive=recursive, **kwargs)
self.unlink()

def serialize(self) -> dict[str, str]:
def serialize(self) -> dict[str, typing.Any]:
_kwargs = self._kwargs.copy()
conn_id = _kwargs.pop("conn_id", None)

return {
"path": str(self),
**self._kwargs,
"conn_id": conn_id,
"kwargs": _kwargs,
}

@classmethod
def deserialize(cls, data: dict, version: int) -> ObjectStoragePath:
if version > cls.__version__:
raise ValueError(f"Cannot deserialize version {version} with version {cls.__version__}.")

_kwargs = data.pop("kwargs")
path = data.pop("path")
return ObjectStoragePath(path, **data)
conn_id = data.pop("conn_id", None)

return ObjectStoragePath(path, conn_id=conn_id, **_kwargs)
18 changes: 15 additions & 3 deletions airflow/io/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
if TYPE_CHECKING:
from fsspec import AbstractFileSystem

from airflow.io.typedef import Properties


class ObjectStore:
"""Manages a filesystem or object storage."""
Expand All @@ -33,13 +35,21 @@ class ObjectStore:
method: str
conn_id: str | None
protocol: str
storage_options: Properties | None

_fs: AbstractFileSystem | None = None

def __init__(self, protocol: str, conn_id: str | None, fs: AbstractFileSystem | None = None):
def __init__(
self,
protocol: str,
conn_id: str | None,
fs: AbstractFileSystem | None = None,
storage_options: Properties | None = None,
):
self.conn_id = conn_id
self.protocol = protocol
self._fs = fs
self.storage_options = storage_options

def __str__(self):
return f"{self.protocol}-{self.conn_id}" if self.conn_id else self.protocol
Expand Down Expand Up @@ -69,6 +79,7 @@ def serialize(self):
"protocol": self.protocol,
"conn_id": self.conn_id,
"filesystem": qualname(self._fs) if self._fs else None,
"storage_options": self.storage_options,
}

@classmethod
Expand All @@ -90,7 +101,7 @@ def deserialize(cls, data: dict[str, str], version: int):
f"protocol {data['protocol']}. Please use attach() for this protocol and filesystem."
)

return attach(protocol=protocol, conn_id=conn_id)
return attach(protocol=protocol, conn_id=conn_id, storage_options=data["storage_options"])

def _connect(self) -> AbstractFileSystem:
if self._fs is None:
Expand All @@ -110,6 +121,7 @@ def attach(
alias: str | None = None,
encryption_type: str | None = "",
fs: AbstractFileSystem | None = None,
**kwargs,
) -> ObjectStore:
"""
Attach a filesystem or object storage.
Expand All @@ -134,6 +146,6 @@ def attach(
if store := _STORE_CACHE.get(alias):
return store

_STORE_CACHE[alias] = store = ObjectStore(protocol=protocol, conn_id=conn_id, fs=fs)
_STORE_CACHE[alias] = store = ObjectStore(protocol=protocol, conn_id=conn_id, fs=fs, **kwargs)

return store
21 changes: 21 additions & 0 deletions airflow/io/typedef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import Dict

Properties = Dict[str, str]
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/fs/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class SignError(Exception):
"""Raises when unable to sign a S3 request."""


def get_fs(conn_id: str | None) -> AbstractFileSystem:
def get_fs(conn_id: str | None, storage_options: dict[str, str] | None = None) -> AbstractFileSystem:
try:
from s3fs import S3FileSystem
except ImportError:
Expand All @@ -60,6 +60,8 @@ def get_fs(conn_id: str | None) -> AbstractFileSystem:
endpoint_url = s3_hook.conn_config.get_service_endpoint_url(service_name="s3")

config_kwargs: dict[str, Any] = s3_hook.conn_config.extra_config.get("config_kwargs", {})
config_kwargs.update(storage_options or {})

register_events: dict[str, Callable[[Properties], None]] = {}

s3_service_config = s3_hook.service_config
Expand Down
29 changes: 16 additions & 13 deletions airflow/providers/google/cloud/fs/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
schemes = ["gs", "gcs"]


def get_fs(conn_id: str | None) -> AbstractFileSystem:
def get_fs(conn_id: str | None, storage_options: dict[str, str] | None = None) -> AbstractFileSystem:
# https://gcsfs.readthedocs.io/en/latest/api.html#gcsfs.core.GCSFileSystem
from gcsfs import GCSFileSystem

Expand All @@ -49,15 +49,18 @@ def get_fs(conn_id: str | None) -> AbstractFileSystem:
g = GoogleBaseHook(gcp_conn_id=conn_id)
creds = g.get_credentials()

return GCSFileSystem(
project=g.project_id,
access=g.extras.get(GCS_ACCESS, "full_control"),
token=creds.token,
consistency=g.extras.get(GCS_CONSISTENCY, "none"),
cache_timeout=g.extras.get(GCS_CACHE_TIMEOUT),
requester_pays=g.extras.get(GCS_REQUESTER_PAYS, False),
session_kwargs=g.extras.get(GCS_SESSION_KWARGS, {}),
endpoint_url=g.extras.get(GCS_ENDPOINT),
default_location=g.extras.get(GCS_DEFAULT_LOCATION),
version_aware=g.extras.get(GCS_VERSION_AWARE, "false").lower() == "true",
)
options = {
"project": g.project_id,
"access": g.extras.get(GCS_ACCESS, "full_control"),
"token": creds.token,
"consistency": g.extras.get(GCS_CONSISTENCY, "none"),
"cache_timeout": g.extras.get(GCS_CACHE_TIMEOUT),
"requester_pays": g.extras.get(GCS_REQUESTER_PAYS, False),
"session_kwargs": g.extras.get(GCS_SESSION_KWARGS, {}),
"endpoint_url": g.extras.get(GCS_ENDPOINT),
"default_location": g.extras.get(GCS_DEFAULT_LOCATION),
"version_aware": g.extras.get(GCS_VERSION_AWARE, "false").lower() == "true",
}
options.update(storage_options or {})

return GCSFileSystem(**options)
35 changes: 12 additions & 23 deletions airflow/providers/microsoft/azure/fs/adls.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from airflow.hooks.base import BaseHook
from airflow.providers.microsoft.azure.utils import get_field
Expand All @@ -27,7 +27,7 @@
schemes = ["abfs", "abfss", "adl"]


def get_fs(conn_id: str | None) -> AbstractFileSystem:
def get_fs(conn_id: str | None, storage_options: dict[str, Any] | None = None) -> AbstractFileSystem:
from adlfs import AzureBlobFileSystem

if conn_id is None:
Expand All @@ -36,24 +36,13 @@ def get_fs(conn_id: str | None) -> AbstractFileSystem:
conn = BaseHook.get_connection(conn_id)
extras = conn.extra_dejson

connection_string = get_field(
conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name="connection_string"
)
account_name = get_field(
conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name="account_name"
)
account_key = get_field(
conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name="account_key"
)
sas_token = get_field(conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name="sas_token")
tenant = get_field(conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name="tenant")

return AzureBlobFileSystem(
connection_string=connection_string,
account_name=account_name,
account_key=account_key,
sas_token=sas_token,
tenant_id=tenant,
client_id=conn.login,
client_secret=conn.password,
)
options = {}
fields = ["connection_string", "account_name", "account_key", "sas_token", "tenant"]
for field in fields:
options[field] = get_field(
conn_id=conn_id, conn_type="azure_data_lake", extras=extras, field_name=field
)

options.update(storage_options or {})

return AzureBlobFileSystem(**options)
9 changes: 7 additions & 2 deletions tests/io/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,17 @@ def test_serde_objectstoragepath(self):

o = ObjectStoragePath(path, my_setting="foo")
s = o.serialize()
assert s["my_setting"] == "foo"
assert "my_setting" in s["kwargs"]
d = ObjectStoragePath.deserialize(s, 1)
assert o == d

store = attach("filex", conn_id="mock")
o = ObjectStoragePath(path, store=store)
s = o.serialize()
assert s["store"] == store
assert s["kwargs"]["store"] == store

d = ObjectStoragePath.deserialize(s, 1)
assert o == d

def test_serde_store(self):
store = attach("file", conn_id="mock")
Expand Down
9 changes: 7 additions & 2 deletions tests/providers/amazon/aws/fs/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import os
from typing import cast
from unittest.mock import patch

import pytest
Expand All @@ -36,11 +37,15 @@

class TestFilesystem:
def test_get_s3fs(self):
import s3fs

from airflow.providers.amazon.aws.fs.s3 import get_fs

fs = get_fs(conn_id=TEST_CONN)
fs = get_fs(conn_id=TEST_CONN, storage_options={"key": "value"})
fs = cast(s3fs.S3FileSystem, fs)

assert "s3" in fs.protocol
assert fs.config_kwargs["key"] == "value"

@patch("s3fs.S3FileSystem", autospec=True)
def test_get_s3fs_anonymous(self, s3fs, monkeypatch):
Expand All @@ -51,7 +56,7 @@ def test_get_s3fs_anonymous(self, s3fs, monkeypatch):
if env_name.startswith("AWS"):
monkeypatch.delenv(env_name, raising=False)

get_fs(conn_id=None)
get_fs(conn_id=None, storage_options=None)

assert s3fs.call_args.kwargs["anon"] is True

Expand Down

0 comments on commit aba58ad

Please sign in to comment.