Skip to content

Commit

Permalink
Add zstd compression to set_task_result in compute service
Browse files Browse the repository at this point in the history
- Update env files to include zstandard

- Update set_task_result in compute api and client to handle base64
encoded data. Rather than JSON-serializing the ProtocolDAGResult (PDR)
and using this as the intermediate format, instead:

1) create a keyed chain representation of the PDR

2) JSON serialize this representation

3) compress the utf-8 encoded bytes with zstandard

4) encode with base64

- Use the above base64 encoded data as the intermediate format and
reverse the operations above to recover the PDR.
  • Loading branch information
ianmkenney committed Nov 25, 2024
1 parent cec1538 commit 67f7964
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 7 deletions.
14 changes: 11 additions & 3 deletions alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import json
from datetime import datetime, timedelta
import random
import base64

from fastapi import FastAPI, APIRouter, Body, Depends
from fastapi.middleware.gzip import GZipMiddleware
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
from gufe.tokenization import GufeTokenizable, JSON_HANDLER, KeyedChain
import zstandard as zstd

from ..base.api import (
QueryGUFEHandler,
Expand Down Expand Up @@ -328,8 +330,14 @@ def set_task_result(
task_sk = ScopedKey.from_str(task_scoped_key)
validate_scopes(task_sk.scope, token)

pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder)
pdr = GufeTokenizable.from_dict(pdr)
# decode b64 and decompress the zstd bytes back into json
protocoldagresult = base64.b64decode(protocoldagresult)
decompressor = zstd.ZstdDecompressor()
protocoldagresult = decompressor.decompress(protocoldagresult)

Check warning on line 336 in alchemiscale/compute/api.py

View check run for this annotation

Codecov / codecov/patch

alchemiscale/compute/api.py#L334-L336

Added lines #L334 - L336 were not covered by tests

pdr_keyed_chain_rep = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder)
pdr_keyed_chain = KeyedChain.from_keyed_chain_rep(pdr_keyed_chain_rep)
pdr = pdr_keyed_chain.to_gufe()

Check warning on line 340 in alchemiscale/compute/api.py

View check run for this annotation

Codecov / codecov/patch

alchemiscale/compute/api.py#L338-L340

Added lines #L338 - L340 were not covered by tests

tf_sk, _ = n4js.get_task_transformation(
task=task_scoped_key,
Expand Down
19 changes: 15 additions & 4 deletions alchemiscale/compute/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
import json
from urllib.parse import urljoin
from functools import wraps
import base64

import requests
from requests.auth import HTTPBasicAuth

from gufe.tokenization import GufeTokenizable, JSON_HANDLER
import zstandard as zstd

from gufe.tokenization import GufeTokenizable, JSON_HANDLER, KeyedChain
from gufe import Transformation
from gufe.protocols import ProtocolDAGResult

Expand Down Expand Up @@ -128,10 +131,18 @@ def set_task_result(
protocoldagresult: ProtocolDAGResult,
compute_service_id=Optional[ComputeServiceID],
) -> ScopedKey:

keyed_chain_rep = KeyedChain.from_gufe(protocoldagresult).to_keyed_chain_rep()
json_rep = json.dumps(keyed_chain_rep, cls=JSON_HANDLER.encoder)
json_bytes = json_rep.encode("utf-8")

compressor = zstd.ZstdCompressor()
compressed = compressor.compress(json_bytes)

base64_encoded = base64.b64encode(compressed).decode("utf-8")

data = dict(
protocoldagresult=json.dumps(
protocoldagresult.to_dict(), cls=JSON_HANDLER.encoder
),
protocoldagresult=base64_encoded,
compute_service_id=str(compute_service_id),
)

Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/alchemiscale-client.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- httpx
- pydantic<2.0
- async-lru
- zstandard

## user client
- rich
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/alchemiscale-compute.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ dependencies:
- httpx
- pydantic<2.0
- async-lru
- zstandard

# openmm protocols
- feflow=0.1.0
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/alchemiscale-server.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies:
# alchemiscale dependencies
- gufe=1.1.0
- openfe=1.2.0
- zstandard

- requests
- click
Expand Down
1 change: 1 addition & 0 deletions devtools/conda-envs/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- openfe>=1.2.0
- pydantic<2.0
- async-lru
- zstandard

## state store
- neo4j-python-driver
Expand Down

0 comments on commit 67f7964

Please sign in to comment.