From 3390bfbf98c4ea4324ebfc16bd04e84e66daf73f Mon Sep 17 00:00:00 2001 From: Jens Scheffler <95105677+jscheffl@users.noreply.github.com> Date: Tue, 24 Sep 2024 22:49:36 +0200 Subject: [PATCH] AIP-69: Add CLI to Edge Provider (#42050) * Add CLI to Edge Provider * Review feedback --- airflow/providers/edge/cli/__init__.py | 16 + airflow/providers/edge/cli/edge_command.py | 313 ++++++++++++++++++ tests/providers/edge/cli/__init__.py | 17 + tests/providers/edge/cli/test_edge_command.py | 259 +++++++++++++++ .../providers/edge/models/test_edge_worker.py | 29 ++ 5 files changed, 634 insertions(+) create mode 100644 airflow/providers/edge/cli/__init__.py create mode 100644 airflow/providers/edge/cli/edge_command.py create mode 100644 tests/providers/edge/cli/__init__.py create mode 100644 tests/providers/edge/cli/test_edge_command.py diff --git a/airflow/providers/edge/cli/__init__.py b/airflow/providers/edge/cli/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/airflow/providers/edge/cli/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/airflow/providers/edge/cli/edge_command.py b/airflow/providers/edge/cli/edge_command.py new file mode 100644 index 0000000000000..09998ffe80281 --- /dev/null +++ b/airflow/providers/edge/cli/edge_command.py @@ -0,0 +1,313 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging +import os +import platform +import signal +import sys +from dataclasses import dataclass +from datetime import datetime +from pathlib import Path +from subprocess import Popen +from time import sleep + +import psutil +from lockfile.pidlockfile import read_pid_from_pidfile, remove_existing_pidfile, write_pid_to_pidfile + +from airflow import __version__ as airflow_version, settings +from airflow.api_internal.internal_api_call import InternalApiConfig +from airflow.cli.cli_config import ARG_PID, ARG_VERBOSE, ActionCommand, Arg +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.edge import __version__ as edge_provider_version +from airflow.providers.edge.models.edge_job import EdgeJob +from airflow.providers.edge.models.edge_logs import EdgeLogs +from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState +from airflow.utils import cli as cli_utils +from airflow.utils.platform import IS_WINDOWS +from airflow.utils.providers_configuration_loader import providers_configuration_loaded +from airflow.utils.state import TaskInstanceState + +logger = logging.getLogger(__name__) +EDGE_WORKER_PROCESS_NAME = "edge-worker" +EDGE_WORKER_HEADER = "\n".join( + [ + r" ____ __ _ __ __", + r" / __/__/ /__ ____ | | /| / /__ ____/ /_____ ____", + r" / _// _ / _ `/ -_) | |/ |/ / _ \/ __/ '_/ -_) __/", + r"/___/\_,_/\_, /\__/ |__/|__/\___/_/ /_/\_\\__/_/", + r" /___/", + r"", + ] +) + + +@providers_configuration_loaded +def force_use_internal_api_on_edge_worker(): + """ + Ensure that the environment is configured for the internal API without needing to declare it outside. + + This is only required for an Edge worker and must to be done before the Click CLI wrapper is initiated. + That is because the CLI wrapper will attempt to establish a DB connection, which will fail before the + function call can take effect. In an Edge worker, we need to "patch" the environment before starting. + """ + if "airflow" in sys.argv[0] and sys.argv[1:3] == ["edge", "worker"]: + api_url = conf.get("edge", "api_url") + if not api_url: + raise SystemExit("Error: API URL is not configured, please correct configuration.") + logger.info("Starting worker with API endpoint %s", api_url) + # export Edge API to be used for internal API + os.environ["AIRFLOW_ENABLE_AIP_44"] = "True" + os.environ["AIRFLOW__CORE__INTERNAL_API_URL"] = api_url + InternalApiConfig.set_use_internal_api("edge-worker") + # Disable mini-scheduler post task execution and leave next task schedule to core scheduler + os.environ["AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION"] = "False" + + +force_use_internal_api_on_edge_worker() + + +def _hostname() -> str: + if IS_WINDOWS: + return platform.uname().node + else: + return os.uname()[1] + + +def _get_sysinfo() -> dict: + """Produce the sysinfo from worker to post to central site.""" + return { + "airflow_version": airflow_version, + "edge_provider_version": edge_provider_version, + } + + +def _pid_file_path(pid_file: str | None) -> str: + return cli_utils.setup_locations(process=EDGE_WORKER_PROCESS_NAME, pid=pid_file)[0] + + +@dataclass +class _Job: + """Holds all information for a task/job to be executed as bundle.""" + + edge_job: EdgeJob + process: Popen + logfile: Path + logsize: int + """Last size of log file, point of last chunk push.""" + + +class _EdgeWorkerCli: + """Runner instance which executes the Edge Worker.""" + + jobs: list[_Job] = [] + """List of jobs that the worker is running currently.""" + last_hb: datetime | None = None + """Timestamp of last heart beat sent to server.""" + drain: bool = False + """Flag if job processing should be completed and no new jobs fetched for a graceful stop/shutdown.""" + + def __init__( + self, + pid_file_path: Path, + hostname: str, + queues: list[str] | None, + concurrency: int, + job_poll_interval: int, + heartbeat_interval: int, + ): + self.pid_file_path = pid_file_path + self.job_poll_interval = job_poll_interval + self.hb_interval = heartbeat_interval + self.hostname = hostname + self.queues = queues + self.concurrency = concurrency + + @staticmethod + def signal_handler(sig, frame): + logger.info("Request to show down Edge Worker received, waiting for jobs to complete.") + _EdgeWorkerCli.drain = True + + def start(self): + """Start the execution in a loop until terminated.""" + try: + self.last_hb = EdgeWorker.register_worker( + self.hostname, EdgeWorkerState.STARTING, self.queues, _get_sysinfo() + ).last_update + except AirflowException as e: + if "404:NOT FOUND" in str(e): + raise SystemExit("Error: API endpoint is not ready, please set [edge] api_enabled=True.") + raise SystemExit(str(e)) + write_pid_to_pidfile(self.pid_file_path) + signal.signal(signal.SIGINT, _EdgeWorkerCli.signal_handler) + try: + while not _EdgeWorkerCli.drain or self.jobs: + self.loop() + + logger.info("Quitting worker, signal being offline.") + EdgeWorker.set_state(self.hostname, EdgeWorkerState.OFFLINE, 0, _get_sysinfo()) + finally: + remove_existing_pidfile(self.pid_file_path) + + def loop(self): + """Run a loop of scheduling and monitoring tasks.""" + new_job = False + if not _EdgeWorkerCli.drain and len(self.jobs) < self.concurrency: + new_job = self.fetch_job() + self.check_running_jobs() + + if _EdgeWorkerCli.drain or datetime.now().timestamp() - self.last_hb.timestamp() > self.hb_interval: + self.heartbeat() + self.last_hb = datetime.now() + + if not new_job: + self.interruptible_sleep() + + def fetch_job(self) -> bool: + """Fetch and start a new job from central site.""" + logger.debug("Attempting to fetch a new job...") + edge_job = EdgeJob.reserve_task(self.hostname, self.queues) + if edge_job: + logger.info("Received job: %s", edge_job) + env = os.environ.copy() + env["AIRFLOW__CORE__DATABASE_ACCESS_ISOLATION"] = "True" + env["AIRFLOW__CORE__INTERNAL_API_URL"] = conf.get("edge", "api_url") + env["_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK"] = "1" + process = Popen(edge_job.command, close_fds=True, env=env) + logfile = EdgeLogs.logfile_path(edge_job.key) + self.jobs.append(_Job(edge_job, process, logfile, 0)) + EdgeJob.set_state(edge_job.key, TaskInstanceState.RUNNING) + return True + + logger.info("No new job to process%s", f", {len(self.jobs)} still running" if self.jobs else "") + return False + + def check_running_jobs(self) -> None: + """Check which of the running tasks/jobs are completed and report back.""" + for i in range(len(self.jobs) - 1, -1, -1): + job = self.jobs[i] + job.process.poll() + if job.process.returncode is not None: + self.jobs.remove(job) + if job.process.returncode == 0: + logger.info("Job completed: %s", job.edge_job) + EdgeJob.set_state(job.edge_job.key, TaskInstanceState.SUCCESS) + else: + logger.error("Job failed: %s", job.edge_job) + EdgeJob.set_state(job.edge_job.key, TaskInstanceState.FAILED) + if job.logfile.exists() and job.logfile.stat().st_size > job.logsize: + with job.logfile.open("r") as logfile: + logfile.seek(job.logsize, os.SEEK_SET) + logdata = logfile.read() + EdgeLogs.push_logs( + task=job.edge_job.key, + log_chunk_time=datetime.now(), + log_chunk_data=logdata, + ) + job.logsize += len(logdata) + + def heartbeat(self) -> None: + """Report liveness state of worker to central site with stats.""" + state = ( + (EdgeWorkerState.TERMINATING if _EdgeWorkerCli.drain else EdgeWorkerState.RUNNING) + if self.jobs + else EdgeWorkerState.IDLE + ) + sysinfo = _get_sysinfo() + EdgeWorker.set_state(self.hostname, state, len(self.jobs), sysinfo) + + def interruptible_sleep(self): + """Sleeps but stops sleeping if drain is made.""" + drain_before_sleep = _EdgeWorkerCli.drain + for _ in range(0, self.job_poll_interval * 10): + sleep(0.1) + if drain_before_sleep != _EdgeWorkerCli.drain: + return + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def worker(args): + """Start Airflow Edge Worker.""" + print(settings.HEADER) + print(EDGE_WORKER_HEADER) + + edge_worker = _EdgeWorkerCli( + pid_file_path=_pid_file_path(args.pid), + hostname=args.edge_hostname or _hostname(), + queues=args.queues.split(",") if args.queues else None, + concurrency=args.concurrency, + job_poll_interval=conf.getint("edge", "job_poll_interval"), + heartbeat_interval=conf.getint("edge", "heartbeat_interval"), + ) + edge_worker.start() + + +@cli_utils.action_cli(check_db=False) +@providers_configuration_loaded +def stop(args): + """Stop a running Airflow Edge Worker.""" + pid = read_pid_from_pidfile(_pid_file_path(args.pid)) + # Send SIGINT + if pid: + logger.warning("Sending SIGINT to worker pid %i.", pid) + worker_process = psutil.Process(pid) + worker_process.send_signal(signal.SIGINT) + else: + logger.warning("Could not find PID of worker.") + + +ARG_CONCURRENCY = Arg( + ("-c", "--concurrency"), + type=int, + help="The number of worker processes", + default=conf.getint("edge", "worker_concurrency", fallback=8), +) +ARG_QUEUES = Arg( + ("-q", "--queues"), + help="Comma delimited list of queues to serve, serve all queues if not provided.", +) +ARG_EDGE_HOSTNAME = Arg( + ("-H", "--edge-hostname"), + help="Set the hostname of worker if you have multiple workers on a single machine", +) +EDGE_COMMANDS: list[ActionCommand] = [ + ActionCommand( + name=worker.__name__, + help=worker.__doc__, + func=worker, + args=( + ARG_CONCURRENCY, + ARG_QUEUES, + ARG_EDGE_HOSTNAME, + ARG_PID, + ARG_VERBOSE, + ), + ), + ActionCommand( + name=stop.__name__, + help=stop.__doc__, + func=stop, + args=( + ARG_PID, + ARG_VERBOSE, + ), + ), +] diff --git a/tests/providers/edge/cli/__init__.py b/tests/providers/edge/cli/__init__.py new file mode 100644 index 0000000000000..217e5db960782 --- /dev/null +++ b/tests/providers/edge/cli/__init__.py @@ -0,0 +1,17 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/providers/edge/cli/test_edge_command.py b/tests/providers/edge/cli/test_edge_command.py new file mode 100644 index 0000000000000..398c221db02f9 --- /dev/null +++ b/tests/providers/edge/cli/test_edge_command.py @@ -0,0 +1,259 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime +from pathlib import Path +from subprocess import Popen +from unittest.mock import patch + +import pytest +import time_machine + +from airflow.exceptions import AirflowException +from airflow.providers.edge.cli.edge_command import ( + _EdgeWorkerCli, + _get_sysinfo, + _Job, +) +from airflow.providers.edge.models.edge_job import EdgeJob +from airflow.providers.edge.models.edge_worker import EdgeWorker, EdgeWorkerState +from airflow.utils.state import TaskInstanceState +from tests.test_utils.config import conf_vars + +pytest.importorskip("pydantic", minversion="2.0.0") + +# Ignore the following error for mocking +# mypy: disable-error-code="attr-defined" + + +def test_get_sysinfo(): + sysinfo = _get_sysinfo() + assert "airflow_version" in sysinfo + assert "edge_provider_version" in sysinfo + + +class TestEdgeWorkerCli: + @pytest.fixture + def dummy_joblist(self, tmp_path: Path) -> list[_Job]: + logfile = tmp_path / "file.log" + logfile.touch() + + class MockPopen(Popen): + generated_returncode = None + + def __init__(self): + pass + + def poll(self): + pass + + @property + def returncode(self): + return self.generated_returncode + + return [ + _Job( + edge_job=EdgeJob( + dag_id="test", + task_id="test1", + run_id="test", + map_index=-1, + try_number=1, + state=TaskInstanceState.RUNNING, + queue="test", + command=["test", "command"], + queued_dttm=datetime.now(), + edge_worker=None, + last_update=None, + ), + process=MockPopen(), + logfile=logfile, + logsize=0, + ), + ] + + @pytest.fixture + def worker_with_job(self, tmp_path: Path, dummy_joblist: list[_Job]) -> _EdgeWorkerCli: + test_worker = _EdgeWorkerCli(tmp_path / "dummy.pid", "dummy", None, 8, 5, 5) + test_worker.jobs = dummy_joblist + return test_worker + + @pytest.mark.parametrize( + "reserve_result, fetch_result, expected_calls", + [ + pytest.param(None, False, (0, 0), id="no_job"), + pytest.param( + EdgeJob( + dag_id="test", + task_id="test", + run_id="test", + map_index=-1, + try_number=1, + state=TaskInstanceState.QUEUED, + queue="test", + command=["test", "command"], + queued_dttm=datetime.now(), + edge_worker=None, + last_update=None, + ), + True, + (1, 1), + id="new_job", + ), + ], + ) + @patch("airflow.providers.edge.models.edge_job.EdgeJob.reserve_task") + @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.logfile_path") + @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state") + @patch("subprocess.Popen") + def test_fetch_job( + self, + mock_popen, + mock_set_state, + mock_logfile_path, + mock_reserve_task, + reserve_result, + fetch_result, + expected_calls, + worker_with_job: _EdgeWorkerCli, + ): + logfile_path_call_count, set_state_call_count = expected_calls + mock_reserve_task.side_effect = [reserve_result] + mock_popen.side_effect = ["dummy"] + with conf_vars({("edge", "api_url"): "https://mock.server"}): + got_job = worker_with_job.fetch_job() + mock_reserve_task.assert_called_once() + assert got_job == fetch_result + assert mock_logfile_path.call_count == logfile_path_call_count + assert mock_set_state.call_count == set_state_call_count + + def test_check_running_jobs_running(self, worker_with_job: _EdgeWorkerCli): + worker_with_job.jobs[0].process.generated_returncode = None + with conf_vars({("edge", "api_url"): "https://mock.server"}): + worker_with_job.check_running_jobs() + assert len(worker_with_job.jobs) == 1 + + @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state") + def test_check_running_jobs_success(self, mock_set_state, worker_with_job: _EdgeWorkerCli): + job = worker_with_job.jobs[0] + job.process.generated_returncode = 0 + with conf_vars({("edge", "api_url"): "https://mock.server"}): + worker_with_job.check_running_jobs() + assert len(worker_with_job.jobs) == 0 + mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.SUCCESS) + + @patch("airflow.providers.edge.models.edge_job.EdgeJob.set_state") + def test_check_running_jobs_failed(self, mock_set_state, worker_with_job: _EdgeWorkerCli): + job = worker_with_job.jobs[0] + job.process.generated_returncode = 42 + with conf_vars({("edge", "api_url"): "https://mock.server"}): + worker_with_job.check_running_jobs() + assert len(worker_with_job.jobs) == 0 + mock_set_state.assert_called_once_with(job.edge_job.key, TaskInstanceState.FAILED) + + @time_machine.travel(datetime.now(), tick=False) + @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs") + def test_check_running_jobs_log_push(self, mock_push_logs, worker_with_job: _EdgeWorkerCli): + job = worker_with_job.jobs[0] + job.process.generated_returncode = None + job.logfile.write_text("some log content") + with conf_vars({("edge", "api_url"): "https://mock.server"}): + worker_with_job.check_running_jobs() + assert len(worker_with_job.jobs) == 1 + mock_push_logs.assert_called_once_with( + task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="some log content" + ) + + @time_machine.travel(datetime.now(), tick=False) + @patch("airflow.providers.edge.models.edge_logs.EdgeLogs.push_logs") + def test_check_running_jobs_log_push_increment(self, mock_push_logs, worker_with_job: _EdgeWorkerCli): + job = worker_with_job.jobs[0] + job.process.generated_returncode = None + job.logfile.write_text("hello ") + job.logsize = job.logfile.stat().st_size + job.logfile.write_text("hello world") + with conf_vars({("edge", "api_url"): "https://mock.server"}): + worker_with_job.check_running_jobs() + assert len(worker_with_job.jobs) == 1 + mock_push_logs.assert_called_once_with( + task=job.edge_job.key, log_chunk_time=datetime.now(), log_chunk_data="world" + ) + + @pytest.mark.parametrize( + "drain, jobs, expected_state", + [ + pytest.param(False, True, EdgeWorkerState.RUNNING, id="running_jobs"), + pytest.param(True, True, EdgeWorkerState.TERMINATING, id="shutting_down"), + pytest.param(False, False, EdgeWorkerState.IDLE, id="idle"), + ], + ) + @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state") + def test_heartbeat(self, mock_set_state, drain, jobs, expected_state, worker_with_job: _EdgeWorkerCli): + if not jobs: + worker_with_job.jobs = [] + _EdgeWorkerCli.drain = drain + with conf_vars({("edge", "api_url"): "https://mock.server"}): + worker_with_job.heartbeat() + assert mock_set_state.call_args.args[1] == expected_state + + @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker") + def test_start_missing_apiserver(self, mock_register_worker, worker_with_job: _EdgeWorkerCli): + mock_register_worker.side_effect = AirflowException( + "Something with 404:NOT FOUND means API is not active" + ) + with pytest.raises(SystemExit, match=r"API endpoint is not ready"): + worker_with_job.start() + + @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker") + def test_start_server_error(self, mock_register_worker, worker_with_job: _EdgeWorkerCli): + mock_register_worker.side_effect = AirflowException("Something other error not FourhundretFour") + with pytest.raises(SystemExit, match=r"Something other"): + worker_with_job.start() + + @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.register_worker") + @patch("airflow.providers.edge.cli.edge_command._EdgeWorkerCli.loop") + @patch("airflow.providers.edge.models.edge_worker.EdgeWorker.set_state") + def test_start_and_run_one( + self, mock_set_state, mock_loop, mock_register_worker, worker_with_job: _EdgeWorkerCli + ): + mock_register_worker.side_effect = [ + EdgeWorker( + worker_name="test", + state=EdgeWorkerState.STARTING, + queues=None, + first_online=datetime.now(), + last_update=datetime.now(), + jobs_active=0, + jobs_taken=0, + jobs_success=0, + jobs_failed=0, + sysinfo="", + ) + ] + + def stop_running(): + _EdgeWorkerCli.drain = True + worker_with_job.jobs = [] + + mock_loop.side_effect = stop_running + + worker_with_job.start() + + mock_register_worker.assert_called_once() + mock_loop.assert_called_once() + mock_set_state.assert_called_once() diff --git a/tests/providers/edge/models/test_edge_worker.py b/tests/providers/edge/models/test_edge_worker.py index 9eca293bafe3f..f0e0ac9dfa056 100644 --- a/tests/providers/edge/models/test_edge_worker.py +++ b/tests/providers/edge/models/test_edge_worker.py @@ -20,11 +20,14 @@ import pytest +from airflow.providers.edge.cli.edge_command import _get_sysinfo from airflow.providers.edge.models.edge_worker import ( EdgeWorker, EdgeWorkerModel, + EdgeWorkerState, EdgeWorkerVersionException, ) +from airflow.utils import timezone if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -63,3 +66,29 @@ def test_assert_version(self): EdgeWorker.assert_version( {"airflow_version": airflow_version, "edge_provider_version": edge_provider_version} ) + + def test_register_worker(self, session: Session): + EdgeWorker.register_worker( + "test_worker", EdgeWorkerState.STARTING, queues=None, sysinfo=_get_sysinfo() + ) + + worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() + assert len(worker) == 1 + assert worker[0].worker_name == "test_worker" + + def test_set_state(self, session: Session): + rwm = EdgeWorkerModel( + worker_name="test2_worker", + state=EdgeWorkerState.IDLE, + queues=["default"], + first_online=timezone.utcnow(), + ) + session.add(rwm) + session.commit() + + EdgeWorker.set_state("test2_worker", EdgeWorkerState.RUNNING, 1, _get_sysinfo()) + + worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all() + assert len(worker) == 1 + assert worker[0].worker_name == "test2_worker" + assert worker[0].state == EdgeWorkerState.RUNNING