From 67f7964cdd60ffe15965b2ac89d310b233178c16 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 25 Nov 2024 12:35:29 -0700 Subject: [PATCH] Add zstd compression to set_task_result in compute service - Update env files to include zstandard - Update set_task_result in compute api and client to handle base64 encoded data. Rather than JSON serialize the ProtocolDAGResult (PDR) and use this as the intermediate format, instead: 1) create a keyed chain representation of the PDR 2) JSON serialize this representation 3) compress the utf-8 encoded bytes with zstandard 4) encode with base64 - Use the above base64 encoded data as the intermediate format and reverse the operations above to recover the PDR. --- alchemiscale/compute/api.py | 14 +++++++++++--- alchemiscale/compute/client.py | 19 +++++++++++++++---- devtools/conda-envs/alchemiscale-client.yml | 1 + devtools/conda-envs/alchemiscale-compute.yml | 1 + devtools/conda-envs/alchemiscale-server.yml | 1 + devtools/conda-envs/test.yml | 1 + 6 files changed, 30 insertions(+), 7 deletions(-) diff --git a/alchemiscale/compute/api.py b/alchemiscale/compute/api.py index 9337055b..a75854a2 100644 --- a/alchemiscale/compute/api.py +++ b/alchemiscale/compute/api.py @@ -9,10 +9,12 @@ import json from datetime import datetime, timedelta import random +import base64 from fastapi import FastAPI, APIRouter, Body, Depends from fastapi.middleware.gzip import GZipMiddleware -from gufe.tokenization import GufeTokenizable, JSON_HANDLER +from gufe.tokenization import GufeTokenizable, JSON_HANDLER, KeyedChain +import zstandard as zstd from ..base.api import ( QueryGUFEHandler, @@ -328,8 +330,14 @@ def set_task_result( task_sk = ScopedKey.from_str(task_scoped_key) validate_scopes(task_sk.scope, token) - pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder) - pdr = GufeTokenizable.from_dict(pdr) + # decode b64 and decompress the zstd bytes back into json + protocoldagresult = base64.b64decode(protocoldagresult) + decompressor = 
zstd.ZstdDecompressor() + protocoldagresult = decompressor.decompress(protocoldagresult) + + pdr_keyed_chain_rep = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder) + pdr_keyed_chain = KeyedChain.from_keyed_chain_rep(pdr_keyed_chain_rep) + pdr = pdr_keyed_chain.to_gufe() tf_sk, _ = n4js.get_task_transformation( task=task_scoped_key, diff --git a/alchemiscale/compute/client.py b/alchemiscale/compute/client.py index 901a7516..b703459b 100644 --- a/alchemiscale/compute/client.py +++ b/alchemiscale/compute/client.py @@ -9,11 +9,14 @@ import json from urllib.parse import urljoin from functools import wraps +import base64 import requests from requests.auth import HTTPBasicAuth -from gufe.tokenization import GufeTokenizable, JSON_HANDLER +import zstandard as zstd + +from gufe.tokenization import GufeTokenizable, JSON_HANDLER, KeyedChain from gufe import Transformation from gufe.protocols import ProtocolDAGResult @@ -128,10 +131,18 @@ def set_task_result( protocoldagresult: ProtocolDAGResult, compute_service_id=Optional[ComputeServiceID], ) -> ScopedKey: + + keyed_chain_rep = KeyedChain.from_gufe(protocoldagresult).to_keyed_chain_rep() + json_rep = json.dumps(keyed_chain_rep, cls=JSON_HANDLER.encoder) + json_bytes = json_rep.encode("utf-8") + + compressor = zstd.ZstdCompressor() + compressed = compressor.compress(json_bytes) + + base64_encoded = base64.b64encode(compressed).decode("utf-8") + data = dict( - protocoldagresult=json.dumps( - protocoldagresult.to_dict(), cls=JSON_HANDLER.encoder - ), + protocoldagresult=base64_encoded, compute_service_id=str(compute_service_id), ) diff --git a/devtools/conda-envs/alchemiscale-client.yml b/devtools/conda-envs/alchemiscale-client.yml index 6f2ae9be..81cfd63f 100644 --- a/devtools/conda-envs/alchemiscale-client.yml +++ b/devtools/conda-envs/alchemiscale-client.yml @@ -15,6 +15,7 @@ dependencies: - httpx - pydantic<2.0 - async-lru + - zstandard ## user client - rich diff --git a/devtools/conda-envs/alchemiscale-compute.yml 
b/devtools/conda-envs/alchemiscale-compute.yml index f93cd1f3..cd39cce2 100644 --- a/devtools/conda-envs/alchemiscale-compute.yml +++ b/devtools/conda-envs/alchemiscale-compute.yml @@ -15,6 +15,7 @@ dependencies: - httpx - pydantic<2.0 - async-lru + - zstandard # openmm protocols - feflow=0.1.0 diff --git a/devtools/conda-envs/alchemiscale-server.yml b/devtools/conda-envs/alchemiscale-server.yml index ae871cca..00102ab5 100644 --- a/devtools/conda-envs/alchemiscale-server.yml +++ b/devtools/conda-envs/alchemiscale-server.yml @@ -10,6 +10,7 @@ dependencies: # alchemiscale dependencies - gufe=1.1.0 - openfe=1.2.0 + - zstandard - requests - click diff --git a/devtools/conda-envs/test.yml b/devtools/conda-envs/test.yml index fc447422..d55ac617 100644 --- a/devtools/conda-envs/test.yml +++ b/devtools/conda-envs/test.yml @@ -11,6 +11,7 @@ dependencies: - openfe>=1.2.0 - pydantic<2.0 - async-lru + - zstandard ## state store - neo4j-python-driver