Skip to content

Commit

Permalink
Update um.py
Browse files Browse the repository at this point in the history
  • Loading branch information
monabraeunig authored Nov 13, 2024
1 parent c66334d commit 1ff768f
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions umbridge/um.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,11 @@ async def gradient(request):
if len(sens) != output_sizes[out_wrt]:
return error_response("InvalidInput", f"Sensitivity vector sens has invalid length! Expected {output_sizes[out_wrt]} but got {len(sens)}.", 400)

output_future = model_executor.submit(model.gradient, out_wrt, in_wrt, parameters, sens, config)
output = await asyncio.wrap_future(output_future)
try:
output_future = model_executor.submit(model.gradient, out_wrt, in_wrt, parameters, sens, config)
output = await asyncio.wrap_future(output_future)
except Exception as e:
return error_response("GradientComputationError", str(e), 500)

# Check if output is a list
if not isinstance(output, list):
Expand Down Expand Up @@ -316,8 +319,11 @@ async def applyjacobian(request):
if len(vec) != input_sizes[in_wrt]:
return error_response("InvalidInput", f"Vector vec has invalid length! Expected {input_sizes[in_wrt]} but got {len(vec)}.", 400)

output_future = model_executor.submit(model.apply_jacobian, out_wrt, in_wrt, parameters, vec, config)
output = await asyncio.wrap_future(output_future)
try:
output_future = model_executor.submit(model.apply_jacobian, out_wrt, in_wrt, parameters, vec, config)
output = await asyncio.wrap_future(output_future)
except Exception as e:
return error_response("JacobianComputationError", str(e), 500)

# Check if output is a list
if not isinstance(output, list):
Expand Down Expand Up @@ -369,8 +375,11 @@ async def applyhessian(request):
if in_wrt2 < 0 or in_wrt2 >= len(input_sizes):
return error_response("InvalidInput", "Invalid inWrt2 index! Expected between 0 and number of inputs minus one, but got " + str(in_wrt2), 400)

output_future = model_executor.submit(model.apply_hessian, out_wrt, in_wrt1, in_wrt2, parameters, sens, vec, config)
output = await asyncio.wrap_future(output_future)
try:
output_future = model_executor.submit(model.apply_hessian, out_wrt, in_wrt1, in_wrt2, parameters, sens, vec, config)
output = await asyncio.wrap_future(output_future)
except Exception as e:
return error_response("HessianComputationError", str(e), 500)

# Check if output is a list
if not isinstance(output, list):
Expand Down

0 comments on commit 1ff768f

Please sign in to comment.