Skip to content

Commit

Permalink
Add set_ee_info function to JobQueue
Browse files Browse the repository at this point in the history
  • Loading branch information
berland committed Nov 14, 2023
1 parent 41868ba commit 0c76269
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 54 deletions.
20 changes: 9 additions & 11 deletions src/ert/ensemble_evaluator/_builder/_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,23 +194,21 @@ async def _evaluate_inner( # pylint: disable=too-many-branches
)
]

# Tell queue to pass info to the jobs-file
# NOTE: This touches files on disk...
sema = threading.BoundedSemaphore(value=CONCURRENT_INTERNALIZATION)
self._job_queue.add_dispatch_information_to_jobs_file(
self._job_queue.set_ee_info(
ee_uri=self._config.dispatch_uri,
ens_id=self.id_,
dispatch_url=self._config.dispatch_uri,
cert=self._config.cert,
token=self._config.token,
ee_cert=self._config.cert,
ee_token=self._config.token,
)

# Tell queue to pass info to the jobs-file
# NOTE: This touches files on disk...
self._job_queue.add_dispatch_information_to_jobs_file()

sema = threading.BoundedSemaphore(value=CONCURRENT_INTERNALIZATION)
result: str = await self._job_queue.execute(
self._config.dispatch_uri,
self.id_,
sema,
queue_evaluators,
ee_cert=self._config.cert,
ee_token=self._config.token,
)

except Exception:
Expand Down
60 changes: 30 additions & 30 deletions src/ert/job_queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self, queue_config: QueueConfig):
self._max_submit = queue_config.max_submit
self._pool_sema = BoundedSemaphore(value=CONCURRENT_INTERNALIZATION)

self.ens_id: Optional[str] = None
self._ens_id: Optional[str] = None
self._ee_uri: Optional[str] = None
self._ee_cert: Optional[Union[str, bytes]] = None
self._ee_token: Optional[str] = None
Expand Down Expand Up @@ -214,6 +214,27 @@ def launch_jobs(self, pool_sema: Semaphore) -> None:
max_submit=self.max_submit,
)

def set_ee_info(
self,
ee_uri: str,
ens_id: str,
ee_cert: Optional[Union[str, bytes]] = None,
ee_token: Optional[str] = None,
verify_context: bool = True,
) -> None:
self._ens_id = ens_id
self._ee_token = ee_token

self._ee_uri = ee_uri
if ee_cert is not None:
self._ee_cert = ee_cert
self._ee_token = ee_token
self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
if verify_context:
self._ee_ssl_context.load_verify_locations(cadata=ee_cert)
else:
self._ee_ssl_context = True if ee_uri.startswith("wss") else None

@staticmethod
def _translate_change_to_cloudevent(
ens_id: str, real_id: int, status: str
Expand All @@ -232,10 +253,10 @@ def _translate_change_to_cloudevent(
async def _publish_changes(
self, changes: Dict[int, str], ee_connection: WebSocketClientProtocol
) -> None:
assert self.ens_id is not None # mypy
assert self._ens_id is not None # mypy
events = deque(
[
JobQueue._translate_change_to_cloudevent(self.ens_id, real_id, status)
JobQueue._translate_change_to_cloudevent(self._ens_id, real_id, status)
for real_id, status in changes.items()
]
)
Expand Down Expand Up @@ -274,30 +295,13 @@ async def _jobqueue_publisher(self) -> None:

async def execute(
self,
ee_uri: Optional[str] = None,
ens_id: Optional[str] = None,
pool_sema: Optional[threading.BoundedSemaphore] = None,
evaluators: Optional[Iterable[Callable[..., Any]]] = None,
ee_cert: Optional[Union[str, bytes]] = None,
ee_token: Optional[str] = None,
) -> str:
self.ens_id = ens_id
self._ee_token = ee_token

if pool_sema is not None:
self._pool_sema = pool_sema
if evaluators is None:
evaluators = []
if ee_uri is not None:
self._ee_uri = ee_uri
if ee_cert is not None:
self._ee_cert = ee_cert
self._ee_token = ee_token
self._ee_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self._ee_ssl_context.load_verify_locations(cadata=ee_cert)
else:
self._ee_ssl_context = True if ee_uri.startswith("wss") else None

if self._ee_uri is not None:
self._changes_to_publish = asyncio.Queue()
asyncio.create_task(self._jobqueue_publisher())
Expand Down Expand Up @@ -425,27 +429,23 @@ def changes_without_transition(self) -> Tuple[Dict[int, str], List[JobStatus]]:

def add_dispatch_information_to_jobs_file(
self,
ens_id: str,
dispatch_url: str,
cert: Optional[str],
token: Optional[str],
experiment_id: Optional[str] = None,
) -> None:
for q_index, q_node in enumerate(self.job_list):
cert_path = f"{q_node.run_path}/{CERT_FILE}"
if cert is not None:
if self._ee_cert is not None:
with open(cert_path, "w", encoding="utf-8") as cert_file:
cert_file.write(cert)
cert_file.write(self._ee_cert)

Check failure on line 438 in src/ert/job_queue/queue.py

View workflow job for this annotation

GitHub Actions / type-checking (3.11)

Argument 1 to "write" of "TextIOBase" has incompatible type "str | bytes"; expected "str"
with open(
f"{q_node.run_path}/{JOBS_FILE}", "r+", encoding="utf-8"
) as jobs_file:
data = json.load(jobs_file)

data["ens_id"] = ens_id
data["ens_id"] = self._ens_id
data["real_id"] = self._differ.qindex_to_iens(q_index)
data["dispatch_url"] = dispatch_url
data["ee_token"] = token
data["ee_cert_path"] = cert_path if cert is not None else None
data["dispatch_url"] = self._ee_uri
data["ee_token"] = self._ee_token
data["ee_cert_path"] = cert_path if self._ee_cert is not None else None
data["experiment_id"] = experiment_id

jobs_file.seek(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ async def test_happy_path(
for real in ensemble.reals:
queue.add_realization(real, callback_timeout=None)

await queue.execute(url, "ee_0", threading.BoundedSemaphore(value=10), None)
queue.set_ee_info(ee_uri=url, ens_id="ee_0")
await queue.execute(pool_sema=threading.BoundedSemaphore(value=10))
done.set_result(None)

await mock_ws_task
Expand Down
3 changes: 0 additions & 3 deletions tests/unit_tests/ensemble_evaluator/test_ensemble_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def test_run_and_cancel_legacy_ensemble(tmpdir, make_ensemble_builder, monkeypat
if cancel:
mon.signal_cancel()
cancel = False
print("tracking..")

assert evaluator._ensemble.status == state.ENSEMBLE_STATE_CANCELLED

Expand Down Expand Up @@ -102,9 +101,7 @@ def test_run_legacy_ensemble_exception(tmpdir, make_ensemble_builder, monkeypatc
state.ENSEMBLE_STATE_FAILED,
state.ENSEMBLE_STATE_STOPPED,
]:
print("FAILED OR STOPPED")
monitor.signal_done()
print("retracking")
assert evaluator._ensemble.status == state.ENSEMBLE_STATE_FAILED

# realisations should not finish, thus not creating a status-file
Expand Down
19 changes: 10 additions & 9 deletions tests/unit_tests/job_queue/test_job_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,14 @@ def test_add_dispatch_info(tmpdir, monkeypatch, simple_script):
runpaths = [Path(DUMMY_CONFIG["run_path"].format(iens)) for iens in range(10)]
for runpath in runpaths:
(runpath / "jobs.json").write_text(json.dumps({}), encoding="utf-8")
job_queue.add_dispatch_information_to_jobs_file(
job_queue.set_ee_info(
ee_uri=dispatch_url,
ens_id=ens_id,
dispatch_url=dispatch_url,
cert=cert,
token=token,
ee_cert=cert,
ee_token=token,
verify_context=False,
)
job_queue.add_dispatch_information_to_jobs_file(
experiment_id="experiment_id",
)

Expand All @@ -262,12 +265,10 @@ def test_add_dispatch_info_cert_none(tmpdir, monkeypatch, simple_script):
runpaths = [Path(DUMMY_CONFIG["run_path"].format(iens)) for iens in range(10)]
for runpath in runpaths:
(runpath / "jobs.json").write_text(json.dumps({}), encoding="utf-8")
job_queue.add_dispatch_information_to_jobs_file(
ens_id=ens_id,
dispatch_url=dispatch_url,
cert=cert,
token=token,
job_queue.set_ee_info(
ee_uri=dispatch_url, ens_id=ens_id, ee_cert=cert, ee_token=token
)
job_queue.add_dispatch_information_to_jobs_file()

for runpath in runpaths:
job_file_path = runpath / "jobs.json"
Expand Down

0 comments on commit 0c76269

Please sign in to comment.