Refactor consolidate import from io in providers (apache#34378)
eumiro authored Oct 6, 2023
1 parent 25cd12d commit 8e26865
Showing 17 changed files with 70 additions and 74 deletions.
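The pattern is the same across all the touched files: a module-level import io is replaced with a direct from io import BytesIO (or StringIO), and buffers that were opened and closed by hand are wrapped in a with block. A minimal before/after sketch of the idea; upload() and the two function names are illustrative stand-ins, not code from the diff:

    import io
    from io import BytesIO

    def upload(fileobj) -> bytes:
        # Stand-in for the real call (e.g. the hook's _upload_file_obj); just reads the buffer.
        return fileobj.read()

    # Before: module-qualified name and manual close
    def load_bytes_old(data: bytes) -> bytes:
        file_obj = io.BytesIO(data)
        result = upload(file_obj)
        file_obj.close()
        return result

    # After: direct import, and the with block closes the buffer even if upload() raises
    def load_bytes_new(data: bytes) -> bytes:
        with BytesIO(data) as f:
            return upload(f)

    assert load_bytes_old(b"payload") == load_bytes_new(b"payload") == b"payload"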
12 changes: 4 additions & 8 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -21,7 +21,6 @@
import asyncio
import fnmatch
import gzip as gz
import io
import logging
import os
import re
@@ -1120,10 +1119,8 @@ def load_string(
if compression == "gzip":
bytes_data = gz.compress(bytes_data)

file_obj = io.BytesIO(bytes_data)

self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
file_obj.close()
with BytesIO(bytes_data) as f:
self._upload_file_obj(f, key, bucket_name, replace, encrypt, acl_policy)

@unify_bucket_name_and_key
@provide_bucket_name
@@ -1155,9 +1152,8 @@ def load_bytes(
:param acl_policy: The string to specify the canned ACL policy for the
object to be uploaded
"""
file_obj = io.BytesIO(bytes_data)
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
file_obj.close()
with BytesIO(bytes_data) as f:
self._upload_file_obj(f, key, bucket_name, replace, encrypt, acl_policy)

@unify_bucket_name_and_key
@provide_bucket_name
10 changes: 5 additions & 5 deletions airflow/providers/docker/operators/docker.py
@@ -453,11 +453,11 @@ def copy_from_docker(container_id, src):
# 0 byte file, it can't be anything else than None
return None
# no need to port to a file since we intend to deserialize
file_standin = BytesIO(b"".join(archived_result))
tar = tarfile.open(fileobj=file_standin)
file = tar.extractfile(stat["name"])
lib = getattr(self, "pickling_library", pickle)
return lib.loads(file.read())
with BytesIO(b"".join(archived_result)) as f:
tar = tarfile.open(fileobj=f)
file = tar.extractfile(stat["name"])
lib = getattr(self, "pickling_library", pickle)
return lib.load(file)

try:
return copy_from_docker(self.container["Id"], self.retrieve_output_path)
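The docker.py hunk keeps the same behaviour: the operator's output file is fetched from the container as an in-memory tar archive and unpickled, but the buffer now lives in a with block and the extracted member is passed to load() rather than loads(file.read()). A self-contained sketch of that flow; read_pickled_member and the member name are illustrative, and the chunks are assumed to come from something like docker-py's get_archive():

    import pickle
    import tarfile
    from io import BytesIO

    def read_pickled_member(archive_chunks, member_name):
        # Join the streamed tar chunks in memory; nothing touches disk because
        # the payload is deserialized immediately.
        with BytesIO(b"".join(archive_chunks)) as buf, tarfile.open(fileobj=buf) as tar:
            member = tar.extractfile(member_name)
            return pickle.load(member)

    # Round-trip demo: pack a pickled object into a tar archive and read it back.
    payload = pickle.dumps({"answer": 42})
    demo = BytesIO()
    with tarfile.open(fileobj=demo, mode="w") as tar:
        info = tarfile.TarInfo(name="result.pkl")
        info.size = len(payload)
        tar.addfile(info, BytesIO(payload))
    assert read_pickled_member([demo.getvalue()], "result.pkl") == {"answer": 42}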
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/operators/test_s3_file_transform.py
@@ -18,10 +18,10 @@
from __future__ import annotations

import errno
import io
import os
import shutil
import sys
from io import BytesIO
from tempfile import mkdtemp
from unittest import mock

@@ -39,7 +39,7 @@ def setup_method(self):
self.bucket = "bucket"
self.input_key = "foo"
self.output_key = "bar"
self.bio = io.BytesIO(self.content)
self.bio = BytesIO(self.content)
self.tmp_dir = mkdtemp(prefix="test_tmpS3FileTransform_")
self.transform_script = os.path.join(self.tmp_dir, "transform.py")
os.mknod(self.transform_script)
18 changes: 9 additions & 9 deletions tests/providers/amazon/aws/operators/test_s3_object.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

import io
from io import BytesIO
from unittest import mock

import boto3
@@ -49,7 +49,7 @@ def test_s3_copy_object_arg_combination_1(self):
conn = boto3.client("s3")
conn.create_bucket(Bucket=self.source_bucket)
conn.create_bucket(Bucket=self.dest_bucket)
conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=BytesIO(b"input"))

# there should be nothing found before S3CopyObjectOperator is executed
assert "Contents" not in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)
@@ -74,7 +74,7 @@ def test_s3_copy_object_arg_combination_2(self):
conn = boto3.client("s3")
conn.create_bucket(Bucket=self.source_bucket)
conn.create_bucket(Bucket=self.dest_bucket)
conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=self.source_bucket, Key=self.source_key, Fileobj=BytesIO(b"input"))

# there should be nothing found before S3CopyObjectOperator is executed
assert "Contents" not in conn.list_objects(Bucket=self.dest_bucket, Prefix=self.dest_key)
@@ -103,7 +103,7 @@ def test_s3_delete_single_object(self):

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=key, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=bucket, Key=key, Fileobj=BytesIO(b"input"))

# The object should be detected before the DELETE action is taken
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key)
@@ -125,7 +125,7 @@ def test_s3_delete_multiple_objects(self):
conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
for k in keys:
conn.upload_fileobj(Bucket=bucket, Key=k, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=bucket, Key=k, Fileobj=BytesIO(b"input"))

# The objects should be detected before the DELETE action is taken
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_pattern)
@@ -147,7 +147,7 @@ def test_s3_delete_prefix(self):
conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
for k in keys:
conn.upload_fileobj(Bucket=bucket, Key=k, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=bucket, Key=k, Fileobj=BytesIO(b"input"))

# The objects should be detected before the DELETE action is taken
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_pattern)
@@ -167,7 +167,7 @@ def test_s3_delete_empty_list(self):

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=BytesIO(b"input"))

# The object should be detected before the DELETE action is tested
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_of_test)
@@ -189,7 +189,7 @@ def test_s3_delete_empty_string(self):

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=BytesIO(b"input"))

# The object should be detected before the DELETE action is tested
objects_in_dest_bucket = conn.list_objects(Bucket=bucket, Prefix=key_of_test)
@@ -235,7 +235,7 @@ def test_validate_keys_and_prefix_in_execute(self, keys, prefix):

conn = boto3.client("s3")
conn.create_bucket(Bucket=bucket)
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=io.BytesIO(b"input"))
conn.upload_fileobj(Bucket=bucket, Key=key_of_test, Fileobj=BytesIO(b"input"))

# Set valid values for constructor, and change them later for emulate rendering template
op = S3DeleteObjectsOperator(
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/system/utils/test_helpers.py
@@ -20,9 +20,9 @@
"""
from __future__ import annotations

import io
import os
import sys
from io import StringIO
from unittest.mock import ANY, patch

import pytest
@@ -110,7 +110,7 @@ def test_fetch_variable_no_value_found_raises_exception(self):
@pytest.mark.parametrize("env_id, is_valid", ENV_ID_TEST_CASES)
def test_validate_env_id_success(self, env_id, is_valid):
if is_valid:
captured_output = io.StringIO()
captured_output = StringIO()
sys.stdout = captured_output

result = _validate_env_id(env_id)
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/utils/test_eks_get_token.py
@@ -17,10 +17,10 @@
from __future__ import annotations

import contextlib
import io
import json
import os
import runpy
from io import StringIO
from unittest import mock
from unittest.mock import ANY

@@ -72,7 +72,7 @@ def test_run(self, mock_eks_hook, args, expected_aws_conn_id, expected_region_na
mock_eks_hook.return_value.fetch_access_token_for_cluster.return_value
) = "k8s-aws-v1.aHR0cDovL2V4YW1wbGUuY29t"

with mock.patch("sys.argv", args), contextlib.redirect_stdout(io.StringIO()) as temp_stdout:
with mock.patch("sys.argv", args), contextlib.redirect_stdout(StringIO()) as temp_stdout:
os.chdir(AIRFLOW_MAIN_FOLDER)
# We are not using run_module because of https://github.com/pytest-dev/pytest/issues/9007
runpy.run_path("airflow/providers/amazon/aws/utils/eks_get_token.py", run_name="__main__")
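The eks_get_token test captures the script's printed output by redirecting stdout into an in-memory StringIO. The same idiom works anywhere printed text needs to be asserted on; a minimal sketch with a made-up token line:

    import contextlib
    from io import StringIO

    with contextlib.redirect_stdout(StringIO()) as temp_stdout:
        # Anything printed inside the block goes to the buffer, not the terminal.
        print("expirationTimestamp: 2023-10-06T12:00:00Z, token: k8s-aws-v1.example")

    captured = temp_stdout.getvalue()
    assert "token: k8s-aws-v1." in captured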
8 changes: 4 additions & 4 deletions tests/providers/apache/pinot/hooks/test_pinot.py
@@ -17,9 +17,9 @@
# under the License.
from __future__ import annotations

import io
import os
import subprocess
from io import BytesIO
from unittest import mock

import pytest
@@ -158,7 +158,7 @@ def test_upload_segment(self, mock_run_cli):
def test_run_cli_success(self, mock_popen):
mock_proc = mock.MagicMock()
mock_proc.returncode = 0
mock_proc.stdout = io.BytesIO(b"")
mock_proc.stdout = BytesIO(b"")
mock_popen.return_value.__enter__.return_value = mock_proc

params = ["foo", "bar", "baz"]
@@ -173,7 +173,7 @@ def test_run_cli_failure_error_message(self, mock_popen):
msg = b"Exception caught"
mock_proc = mock.MagicMock()
mock_proc.returncode = 0
mock_proc.stdout = io.BytesIO(msg)
mock_proc.stdout = BytesIO(msg)
mock_popen.return_value.__enter__.return_value = mock_proc
params = ["foo", "bar", "baz"]
with pytest.raises(AirflowException):
@@ -187,7 +187,7 @@ def test_run_cli_failure_error_message(self, mock_popen):
def test_run_cli_failure_status_code(self, mock_popen):
mock_proc = mock.MagicMock()
mock_proc.returncode = 1
mock_proc.stdout = io.BytesIO(b"")
mock_proc.stdout = BytesIO(b"")
mock_popen.return_value.__enter__.return_value = mock_proc

self.db_hook.pinot_admin_system_exit = True
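The pinot hook tests patch subprocess.Popen and give the fake process a BytesIO as its stdout, so code that reads the process output sees an ordinary byte stream. A stripped-down sketch of the same setup; run_cli here is an illustrative stand-in, not the Airflow hook's method:

    import subprocess
    from io import BytesIO
    from unittest import mock

    def run_cli(cmd):
        # Minimal stand-in for a hook's command runner: collect stdout, check the return code.
        with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
            output = b"".join(proc.stdout)
        if proc.returncode != 0:
            raise RuntimeError(output.decode())
        return output

    @mock.patch("subprocess.Popen")
    def test_run_cli_success(mock_popen):
        mock_proc = mock.MagicMock()
        mock_proc.returncode = 0
        mock_proc.stdout = BytesIO(b"Query ran successfully\n")  # stands in for the pipe
        mock_popen.return_value.__enter__.return_value = mock_proc

        assert run_cli(["pinot-admin.sh", "--help"]) == b"Query ran successfully\n"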
6 changes: 3 additions & 3 deletions tests/providers/apache/spark/hooks/test_spark_sql.py
@@ -17,8 +17,8 @@
# under the License.
from __future__ import annotations

import io
import itertools
from io import StringIO
from unittest.mock import call, patch

import pytest
@@ -85,8 +85,8 @@ def test_build_command(self):
@patch("airflow.providers.apache.spark.hooks.spark_sql.subprocess.Popen")
def test_spark_process_runcmd(self, mock_popen):
# Given
mock_popen.return_value.stdout = io.StringIO("Spark-sql communicates using stdout")
mock_popen.return_value.stderr = io.StringIO("stderr")
mock_popen.return_value.stdout = StringIO("Spark-sql communicates using stdout")
mock_popen.return_value.stderr = StringIO("stderr")
mock_popen.return_value.wait.return_value = 0

# When
14 changes: 7 additions & 7 deletions tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -17,8 +17,8 @@
# under the License.
from __future__ import annotations

import io
import os
from io import StringIO
from unittest.mock import call, patch

import pytest
@@ -241,8 +241,8 @@ def test_build_track_driver_status_command(self):
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
def test_spark_process_runcmd(self, mock_popen):
# Given
mock_popen.return_value.stdout = io.StringIO("stdout")
mock_popen.return_value.stderr = io.StringIO("stderr")
mock_popen.return_value.stdout = StringIO("stdout")
mock_popen.return_value.stderr = StringIO("stderr")
mock_popen.return_value.wait.return_value = 0

# When
@@ -694,8 +694,8 @@ def test_process_spark_driver_status_log_bad_response(self):
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
def test_yarn_process_on_kill(self, mock_popen, mock_renew_from_kt):
# Given
mock_popen.return_value.stdout = io.StringIO("stdout")
mock_popen.return_value.stderr = io.StringIO("stderr")
mock_popen.return_value.stdout = StringIO("stdout")
mock_popen.return_value.stderr = StringIO("stderr")
mock_popen.return_value.poll.return_value = None
mock_popen.return_value.wait.return_value = 0
log_lines = [
@@ -776,8 +776,8 @@ def test_standalone_cluster_process_on_kill(self):
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
def test_k8s_process_on_kill(self, mock_popen, mock_client_method):
# Given
mock_popen.return_value.stdout = io.StringIO("stdout")
mock_popen.return_value.stderr = io.StringIO("stderr")
mock_popen.return_value.stdout = StringIO("stdout")
mock_popen.return_value.stderr = StringIO("stderr")
mock_popen.return_value.poll.return_value = None
mock_popen.return_value.wait.return_value = 0
client = mock_client_method.return_value
2 changes: 1 addition & 1 deletion tests/providers/cncf/kubernetes/operators/test_pod.py
@@ -18,14 +18,14 @@

import re
from contextlib import contextmanager, nullcontext
from io import BytesIO
from unittest import mock
from unittest.mock import MagicMock, patch

import pendulum
import pytest
from kubernetes.client import ApiClient, V1PodSecurityContext, V1PodStatus, models as k8s
from urllib3 import HTTPResponse
from urllib3.packages.six import BytesIO

from airflow.exceptions import AirflowException, AirflowSkipException, TaskDeferred
from airflow.models import DAG, DagModel, DagRun, TaskInstance
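The only change in test_pod.py is where BytesIO comes from: the standard library replaces the copy of six vendored inside urllib3, which newer urllib3 releases no longer ship. The buffer is used the same way, for example as the body of a fake urllib3 response when stubbing Kubernetes API calls; a small sketch under that assumption:

    from io import BytesIO

    from urllib3 import HTTPResponse

    # An in-memory body stands in for a real socket-backed response.
    fake_log_response = HTTPResponse(body=BytesIO(b"container started\ncontainer finished\n"), status=200)
    assert fake_log_response.data == b"container started\ncontainer finished\n"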
4 changes: 2 additions & 2 deletions tests/providers/elasticsearch/log/test_es_task_handler.py
@@ -17,12 +17,12 @@
# under the License.
from __future__ import annotations

import io
import json
import logging
import os
import re
import shutil
from io import StringIO
from pathlib import Path
from unittest import mock
from urllib.parse import quote
@@ -602,7 +602,7 @@ def test_supports_external_link(self, frontend, expected):
self.es_task_handler.frontend = frontend
assert self.es_task_handler.supports_external_link == expected

@mock.patch("sys.__stdout__", new_callable=io.StringIO)
@mock.patch("sys.__stdout__", new_callable=StringIO)
def test_dynamic_offset(self, stdout_mock, ti, time_machine):
# arrange
handler = ElasticsearchTaskHandler(
6 changes: 3 additions & 3 deletions tests/providers/ftp/hooks/test_ftp.py
@@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

import io
from io import StringIO
from unittest import mock

from airflow.providers.ftp.hooks import ftp as fh
@@ -107,14 +107,14 @@ def test_get_size(self):
self.conn_mock.size.assert_called_once_with(path)

def test_retrieve_file(self):
_buffer = io.StringIO("buffer")
_buffer = StringIO("buffer")
with fh.FTPHook() as ftp_hook:
ftp_hook.retrieve_file(self.path, _buffer)
self.conn_mock.retrbinary.assert_called_once_with("RETR path", _buffer.write, 8192)

def test_retrieve_file_with_callback(self):
func = mock.Mock()
_buffer = io.StringIO("buffer")
_buffer = StringIO("buffer")
with fh.FTPHook() as ftp_hook:
ftp_hook.retrieve_file(self.path, _buffer, callback=func)
self.conn_mock.retrbinary.assert_called_once_with("RETR path", func, 8192)
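The ftp hook tests pass a StringIO only because the FTP connection is mocked; against a real server the same call shape pulls a remote file straight into memory, and a BytesIO is the natural buffer there since retrbinary writes bytes. A sketch assuming a configured ftp_default connection and a hypothetical remote path:

    from io import BytesIO

    from airflow.providers.ftp.hooks.ftp import FTPHook

    buffer = BytesIO()
    with FTPHook(ftp_conn_id="ftp_default") as hook:
        # Under the hood this issues RETR and calls buffer.write for each 8192-byte block.
        hook.retrieve_file("/remote/path/report.csv", buffer)

    print(buffer.getvalue()[:100])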
8 changes: 4 additions & 4 deletions tests/providers/google/cloud/hooks/test_gcs.py
@@ -18,11 +18,11 @@
from __future__ import annotations

import copy
import io
import logging
import os
import re
from datetime import datetime, timedelta
from io import BytesIO
from unittest import mock

import dateutil
@@ -699,7 +699,7 @@ def test_compose_without_destination_object(self, mock_service):
def test_download_as_bytes(self, mock_service):
test_bucket = "test_bucket"
test_object = "test_object"
test_object_bytes = io.BytesIO(b"input")
test_object_bytes = BytesIO(b"input")

download_method = mock_service.return_value.bucket.return_value.blob.return_value.download_as_bytes
download_method.return_value = test_object_bytes
@@ -713,7 +713,7 @@ def test_download_to_file(self, mock_service):
def test_download_to_file(self, mock_service):
test_bucket = "test_bucket"
test_object = "test_object"
test_object_bytes = io.BytesIO(b"input")
test_object_bytes = BytesIO(b"input")
test_file = "test_file"

download_filename_method = (
@@ -737,7 +737,7 @@ def test_provide_file(self, mock_service, mock_temp_file):
def test_provide_file(self, mock_service, mock_temp_file):
test_bucket = "test_bucket"
test_object = "test_object"
test_object_bytes = io.BytesIO(b"input")
test_object_bytes = BytesIO(b"input")
test_file = "test_file"

download_filename_method = (