Skip to content

Commit

Permalink
WIP: openpmd wavefront output
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Aug 22, 2024
1 parent 3a89b43 commit c05fc3f
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 37 deletions.
89 changes: 89 additions & 0 deletions pmd_beamphysics/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

import datetime
import platform
import getpass
import dataclasses

from typing import Dict, Sequence, Union
from typing_extensions import Literal

from . import tools


@dataclasses.dataclass
class BaseMetadata:
data_index: int = 0
object_index: int = 0

# Base pmd wavefront file attrs:
spec_version: str = "2.0.0"

author: str = dataclasses.field(default_factory=getpass.getuser)
machine: str = dataclasses.field(default_factory=platform.node)
comment: str = dataclasses.field(default="")

software: str = dataclasses.field(default="pmd_beamphysics")
software_version: str = dataclasses.field(
default_factory=tools.get_version, metadata={"pmd_key": "softwareVersion"}
)
software_dependencies: str = dataclasses.field(
default="", metadata={"pmd_key": "softwareDependencies"}
)
iteration_encoding: Literal["fileBased", "groupBased"] = dataclasses.field(
default="groupBased", metadata={"pmd_key": "iterationEncoding"}
)
iteration_format: str = dataclasses.field(
default="/data/%T/", metadata={"pmd_key": "iterationFormat"}
)
date: datetime.datetime = dataclasses.field(
default_factory=tools.current_date_with_tzinfo,
metadata={"pmd_key": "date"},
)

# Per iteration
iteration_time: float = dataclasses.field(default=0.0, metadata={"pmd_key": "time"})
iteration_dt: float = dataclasses.field(default=0.0, metadata={"pmd_key": "dt"})
iteration_time_unit_si: float = dataclasses.field(
default=1.0, metadata={"pmd_key": "timeUnitSI"}
)

def _get_pmd_dict(self, attrs: Sequence[str]) -> Dict[str, Union[str, float, None]]:
attr_to_field = {
fld.name: fld.metadata.get("pmd_key", fld.name)
for fld in dataclasses.fields(self)
}
return {
attr_to_field[attr]: getattr(self, attr)
for attr in attrs
if getattr(self, attr) is not None
}

@property
def iteration_attrs(self):
return self._get_pmd_dict(
[
"iteration_time",
"iteration_dt",
"iteration_time_unit_si",
]
)

@property
def base_attrs(self):
res = self._get_pmd_dict(
[
"author",
"machine",
"comment",
"software",
"software_version",
"software_dependencies",
"iteration_format",
"iteration_encoding",
]
)
return {
**res,
"date": tools.pmd_format_date(self.date),
}
63 changes: 40 additions & 23 deletions pmd_beamphysics/tools.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,96 @@
import datetime
import numpy as np


def fstr(s):
"""
Makes a fixed string for h5 files
"""
return np.bytes_(s)



def data_are_equal(d1, d2):
"""
Simple utility to compare data in dicts
Returns True only if all keys are the same, and all np.all data are the same
"""

if set(d1) != set(d2):
return False

for k in d1:
if not np.all(d1[k]==d2[k]):
if not np.all(d1[k] == d2[k]):
return False

return True


return True


#-----------------------------------------
# -----------------------------------------
# HDF5 utilities



def decode_attr(a):
"""
Decodes:
ASCII strings and arrays of them to str and arrays of str
single-length arrays to scalar (Bmad writes this)
"""
if isinstance(a, bytes):
return a.decode('utf-8')
return a.decode("utf-8")

if isinstance(a, np.ndarray):
if a.dtype.type is np.bytes_:
a = a.astype(str)
if len(a) == 1:
return a[0]

return a

return a


def decode_attrs(attrs):
return {k:decode_attr(v) for k,v in attrs.items()}
return {k: decode_attr(v) for k, v in attrs.items()}


def encode_attr(a):
"""
Encodes attribute
See the inverse function:
decode_attr
"""

if isinstance(a, str):
a = fstr(a)

if isinstance(a, list) or isinstance(a, tuple):
a = np.array(a)

if isinstance(a, np.ndarray):
if a.dtype.type is np.str_:
a = a.astype(np.bytes_)

return a


def encode_attrs(attrs):
return {k:encode_attr(v) for k,v in attrs.items()}
return {k: encode_attr(v) for k, v in attrs.items()}


def get_version() -> str:
"""Get the installed pmd-beamphysics version."""
from . import __version__

return __version__


def current_date_with_tzinfo() -> datetime.datetime:
from dateutil.tz import tzlocal

return datetime.datetime.now(tzlocal())


def pmd_format_date(dt: datetime.datetime) -> str:
return dt.strftime("%Y-%m-%d %H:%M:%S %z")
Loading

0 comments on commit c05fc3f

Please sign in to comment.