Skip to content

Commit

Permalink
Experiment module complete
Browse files Browse the repository at this point in the history
  • Loading branch information
ashuping committed Oct 16, 2024
1 parent d2d8110 commit c881b79
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 32 deletions.
64 changes: 43 additions & 21 deletions exseos/experiment/Experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
objects.
"""

import functools
from typing import Any, Callable
from exseos.experiment.Constant import (
BasicExperimentConstant,
ExperimentConstant,
LambdaExperimentConstant,
ConstantResolutionError,
)
from exseos.experiment.optimizer.Optimizer import Optimizer, OptimizerIteration
from exseos.types.Option import Nothing, Option, Some
from exseos.types.Result import Okay, Result
from exseos.types.Result import Result, Okay, Fail
from exseos.types.Variable import (
BoundVariable,
Variable,
Expand Down Expand Up @@ -69,7 +69,11 @@ def optimizer(self) -> Optimizer:

@property
def constants(self) -> tuple[ExperimentConstant]:
return self.constants
return self.__constants

@property
def ui(self) -> UIManager:
return self.__ui

def copy(self, **delta) -> "Experiment":
params = {
Expand All @@ -83,12 +87,13 @@ def copy(self, **delta) -> "Experiment":

def _resolve_constants(
self, opt_inputs: tuple[Variable]
) -> Option[tuple[Variable]]:
) -> Result[Exception, Exception, tuple[Variable]]:
unresolved_constants = self.constants

vars = opt_inputs

while unresolved_constants:
print(vars)
newly_resolved = tuple(
[
BoundVariable(c.name, c.resolve(VariableSet(vars)))
Expand All @@ -102,30 +107,49 @@ def _resolve_constants(
)

if len(newly_resolved) == 0:
return Nothing() # Can't resolve all constants.
return Fail(
[ConstantResolutionError(unresolved_constants, vars)]
) # Can't resolve all constants.

vars += newly_resolved

return Some(vars)
return Okay(vars)

async def run(self) -> "Result[Exception, Exception, ExperimentResult]":
# iteration: int = 0
# history: tuple[OptimizerIteration] = {}
# while True:
# next_iteration_inputs = self.optimizer.next(iteration, 1, history)
iteration: int = 0
history: tuple[OptimizerIteration] = ()
status = Okay(None)

# if len(next_iteration_inputs) == 0:
# break
while True:
next_iteration_inputs = self.optimizer.next(iteration, 1, history)

# run_inputs = self._resolve_constants(next_iteration_inputs)
if len(next_iteration_inputs) == 0:
break

# if not run_inputs.has_val:
run_inputs = self._resolve_constants(
tuple(next_iteration_inputs[0].vars.values())
)

# run_res = await self.workflow.run(
status <<= run_inputs
if status.is_fail:
return status

run_res = await self.workflow.run(run_inputs.val, self.ui)

status <<= run_res
if status.is_fail:
return status

history += (
OptimizerIteration(
VariableSet(run_inputs.val),
run_res.val,
),
)

# )
iteration += 1

return Okay(None)
return status >> Okay(ExperimentResult(self.optimizer, history))


class ExperimentResult:
Expand Down Expand Up @@ -174,7 +198,7 @@ def optimizer(self) -> Optimizer:

@property
def constants(self) -> tuple[ExperimentConstant]:
return self.constants
return self.__constants

@property
def ui(self) -> UIManager:
Expand Down Expand Up @@ -228,9 +252,7 @@ def _inner(
[
LambdaExperimentConstant(
key,
functools.partial(
_inner, fn=val[0], dependencies=ensure_from_name_arr(val[1])
),
lambda vs: _inner(val[0], ensure_from_name_arr(val[1]), vs),
)
for key, val in calcs.items()
]
Expand Down
15 changes: 9 additions & 6 deletions exseos/workflow/Workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,11 @@ def ui(self) -> UIManager:
return self.__ui

async def run(
self, inputs: tuple[Variable]
self, inputs: tuple[Variable], ui: UIManager = None
) -> Result[Exception, Exception, VariableSet]:
if not ui:
ui = self.ui

if not self.is_runnable:
return MalformedWorkflowError(self, self.status)

Expand All @@ -187,21 +190,21 @@ async def run(

if run_result.is_fail:
log.error("Encountered an error while retrieving stage inputs!")
await self.ui.display(ResultMessage(run_result))
await ui.display(ResultMessage(run_result))
return run_result

try:
stage_res = await stage.run(stage_input_res.val, self.ui)
stage_res = await stage.run(stage_input_res.val, ui)
except Exception as e:
log.error(f"User code threw an exception: {type(e).__name__}!")
await self.ui.display(ResultMessage(run_result))
await ui.display(ResultMessage(run_result))
return run_result << Fail([e])

run_result <<= stage_res

if run_result.is_fail:
log.error("User code returned a failure result!")
await self.ui.display(ResultMessage(run_result))
await ui.display(ResultMessage(run_result))
return run_result

run_result <<= stage_res
Expand All @@ -213,7 +216,7 @@ async def run(

if run_result.is_fail:
log.error("Encountered an error while retrieving final outputs!")
await self.ui.display(ResultMessage(run_result))
await ui.display(ResultMessage(run_result))

return run_result.map(lambda _: final_outputs.val)

Expand Down
15 changes: 10 additions & 5 deletions test/experiment/test_Experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from exseos.experiment.optimizer.OptimizerParameter import ContinuousOptimizerParameter
from exseos.experiment.optimizer.OptimizerTarget import TargetMaximize
from exseos.types.Result import Fail, Result, Okay
from exseos.types.Variable import UnboundVariable, VariableSet
from exseos.types.Variable import UnboundVariable, VariableSet, Variable
from exseos.workflow.stage.Stage import Stage
from exseos.workflow.Workflow import MakeWorkflow

Expand All @@ -34,21 +34,27 @@ class MustBe2X(Stage):

output_vars = ()

async def run(self, inputs: VariableSet, _) -> Result:
async def run(
self, inputs: VariableSet, _
) -> Result[Exception, Exception, tuple[Variable]]:
res = inputs.check_all()
if res.is_fail:
return res

if abs(2 * inputs.x - inputs.must_be_2x) > 0.00001:
return Fail([ValueError("it wasn't 2x!!!")])

return Okay(())


class MysteryEquation(Stage):
input_vars = (UnboundVariable("x", float, "input to the calculation"),)

output_vars = UnboundVariable("y", float, "result of the calculation")
output_vars = (UnboundVariable("y", float, "result of the calculation"),)

async def run(self, inputs: VariableSet, _) -> Result:
async def run(
self, inputs: VariableSet, _
) -> Result[Exception, Exception, tuple[Variable]]:
res = inputs.check_all()
if res.is_fail:
return res
Expand All @@ -60,7 +66,6 @@ async def run(self, inputs: VariableSet, _) -> Result:

@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.xfail
async def test_experiment_integration():
workflow = (
MakeWorkflow("The Mystery Equation")
Expand Down

0 comments on commit c881b79

Please sign in to comment.