Skip to content

Commit

Permalink
AtRes.input_data
Browse files Browse the repository at this point in the history
  • Loading branch information
loriab committed Nov 12, 2024
1 parent b9aebcb commit 9dde67b
Show file tree
Hide file tree
Showing 8 changed files with 200 additions and 39 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ New Features

Enhancements
++++++++++++
- (536b) ``v2.AtomicResult`` no longer inherits from ``v2.AtomicInput``. It gained a ``input_data`` field for the corresponding ``AtomicInput`` and independent ``id`` and ``molecule`` fields (the latter being equivalvent to ``v1.AtomicResult.molecule`` with the frame of the results; ``v2.AtomicResult.input_data.molecule`` is new, preserving the input frame). Gained independent ``extras``
- (536b) Both v1/v2 ``AtomicResult.convert_v()`` learned to handle the new ``input_data`` layout.
- (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``success`` field enforced to ``True``. Previously it could be set T/F. Now validation errors if not T. Likewise ``v2.FailedOperation.success`` is enforced to ``False``.
- (:pr:`357`, :issue:`536`) ``v2.AtomicResult``, ``v2.OptimizationResult``, and ``v2.TorsionDriveResult`` have the ``error`` field removed. This isn't used now that ``success=True`` and failure should be routed to ``FailedOperation``.
- (:pr:`357`) ``v1.Molecule`` had its schema_version changed to a Literal[2] (remember Mol is one-ahead of general numbering scheme) so new instances will be 2 even if another value is passed in. Ditto ``v2.BasisSet.schema_version=2``. Ditto ``v1.BasisSet.schema_version=1`` Ditto ``v1.QCInputSpecification.schema_version=1`` and ``v1.OptimizationSpecification.schema_version=1``.
Expand Down
4 changes: 3 additions & 1 deletion qcelemental/models/v1/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,10 @@ def convert_v(
if dself.pop("error", None):
pass

dself["input_specification"].pop("schema_version", None)
dself["optimization_spec"].pop("schema_version", None)
dself["optimization_history"] = {
(k, [opthist_class(**res).convert_v(version) for res in lst])
k: [opthist_class(**res).convert_v(version) for res in lst]
for k, lst in dself["optimization_history"].items()
}

Expand Down
12 changes: 12 additions & 0 deletions qcelemental/models/v1/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,18 @@ def convert_v(
if dself.pop("error", None):
pass

input_data = {
k: dself.pop(k) for k in list(dself.keys()) if k in ["driver", "keywords", "model", "protocols"]
}
input_data["molecule"] = dself["molecule"] # duplicate since input mol has been overwritten
# any input provenance has been overwritten
if dself["id"]:
input_data["id"] = dself["id"] # in/out should likely match
input_data["extras"] = {
k: dself["extras"].pop(k) for k in list(dself["extras"].keys()) if k in []
} # sep any merged extras
dself["input_data"] = input_data

self_vN = qcel.models.v2.AtomicResult(**dself)

return self_vN
Expand Down
12 changes: 12 additions & 0 deletions qcelemental/models/v2/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,12 @@ def convert_v(
if check_convertible_version(version, error="OptimizationResult") == "self":
return self

trajectory_class = self.trajectory[0].__class__
dself = self.model_dump()
if version == 1:
dself["trajectory"] = [trajectory_class(**atres).convert_v(version) for atres in dself["trajectory"]]
dself["input_specification"].pop("schema_version", None)

self_vN = qcel.models.v1.OptimizationResult(**dself)

return self_vN
Expand Down Expand Up @@ -297,6 +300,9 @@ def convert_v(

dself = self.model_dump()
if version == 1:
if dself["optimization_spec"].pop("extras", None):
pass

self_vN = qcel.models.v1.TorsionDriveInput(**dself)

return self_vN
Expand Down Expand Up @@ -348,11 +354,17 @@ def convert_v(
if check_convertible_version(version, error="TorsionDriveResult") == "self":
return self

opthist_class = next(iter(self.optimization_history.values()))[0].__class__
dself = self.model_dump()
if version == 1:
if dself["optimization_spec"].pop("extras", None):
pass

dself["optimization_history"] = {
k: [opthist_class(**res).convert_v(version) for res in lst]
for k, lst in dself["optimization_history"].items()
}

self_vN = qcel.models.v1.TorsionDriveResult(**dself)

return self_vN
44 changes: 33 additions & 11 deletions qcelemental/models/v2/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def convert_v(
return self_vN


class AtomicResult(AtomicInput):
class AtomicResult(ProtoModel):
r"""Results from a CMS program execution."""

schema_name: constr(strip_whitespace=True, pattern=r"^(qc\_?schema_output)$") = Field( # type: ignore
Expand All @@ -736,6 +736,9 @@ class AtomicResult(AtomicInput):
2,
description="The version number of :attr:`~qcelemental.models.AtomicResult.schema_name` to which this model conforms.",
)
id: Optional[str] = Field(None, description="The optional ID for the computation.")
input_data: AtomicInput = Field(..., description=str(AtomicInput.__doc__))
molecule: Molecule = Field(..., description="The molecule with frame and orientation of the results.")
properties: AtomicResultProperties = Field(..., description=str(AtomicResultProperties.__doc__))
wavefunction: Optional[WavefunctionProperties] = Field(None, description=str(WavefunctionProperties.__doc__))

Expand All @@ -755,6 +758,10 @@ class AtomicResult(AtomicInput):
True, description="The success of program execution. If False, other fields may be blank."
)
provenance: Provenance = Field(..., description=str(Provenance.__doc__))
extras: Dict[str, Any] = Field(
{},
description="Additional information to bundle with the computation. Use for schema development and scratch space.",
)

@field_validator("schema_name", mode="before")
@classmethod
Expand All @@ -774,12 +781,17 @@ def _version_stamp(cls, v):
@field_validator("return_result")
@classmethod
def _validate_return_result(cls, v, info):
if info.data["driver"] == "energy":
print(info)
# Do not propagate validation errors
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")
driver = info.data["input_data"].driver
if driver == "energy":
if isinstance(v, np.ndarray) and v.size == 1:
v = v.item(0)
elif info.data["driver"] == "gradient":
elif driver == "gradient":
v = np.asarray(v).reshape(-1, 3)
elif info.data["driver"] == "hessian":
elif driver == "hessian":
v = np.asarray(v)
nsq = int(v.size**0.5)
v.shape = (nsq, nsq)
Expand All @@ -800,8 +812,8 @@ def _wavefunction_protocol(cls, value, info):
raise ValueError("wavefunction must be None, a dict, or a WavefunctionProperties object.")

# Do not propagate validation errors
if "protocols" not in info.data:
raise ValueError("Protocols was not properly formed.")
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

# Handle restricted
restricted = wfn.get("restricted", None)
Expand All @@ -814,7 +826,7 @@ def _wavefunction_protocol(cls, value, info):
wfn.pop(k)

# Handle protocols
wfnp = info.data["protocols"].wavefunction
wfnp = info.data["input_data"].protocols.wavefunction
return_keep = None
if wfnp == "all":
pass
Expand Down Expand Up @@ -861,10 +873,10 @@ def _wavefunction_protocol(cls, value, info):
@classmethod
def _stdout_protocol(cls, value, info):
# Do not propagate validation errors
if "protocols" not in info.data:
raise ValueError("Protocols was not properly formed.")
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

outp = info.data["protocols"].stdout
outp = info.data["input_data"].protocols.stdout
if outp is True:
return value
elif outp is False:
Expand All @@ -875,7 +887,11 @@ def _stdout_protocol(cls, value, info):
@field_validator("native_files")
@classmethod
def _native_file_protocol(cls, value, info):
ancp = info.data["protocols"].native_files
# Do not propagate validation errors
if "input_data" not in info.data:
raise ValueError("Input_data was not properly formed.")

ancp = info.data["input_data"].protocols.native_files
if ancp == "all":
return value
elif ancp == "none":
Expand Down Expand Up @@ -905,6 +921,12 @@ def convert_v(

dself = self.model_dump()
if version == 1:
# input_data = self.input_data.convert_v(1) # TODO probably later
input_data = dself.pop("input_data")
input_data.pop("molecule", None) # discard
input_data.pop("provenance", None) # discard
dself["extras"] = {**input_data.pop("extras", {}), **dself.pop("extras", {})} # merge
dself = {**input_data, **dself}
self_vN = qcel.models.v1.AtomicResult(**dself)

return self_vN
Loading

0 comments on commit 9dde67b

Please sign in to comment.