Skip to content

Commit

Permalink
Merge pull request #67 from scipp/supermirror-efficiency
Browse files Browse the repository at this point in the history
Add 2nd order polynomial supermirror efficiency function
  • Loading branch information
SimonHeybrock authored Jul 8, 2024
2 parents 611e4eb + de63535 commit 44aeafc
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 27 deletions.
3 changes: 3 additions & 0 deletions docs/api-reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
HalfPolarizedWorkflow
PolarizationAnalysisWorkflow
SupermirrorWorkflow
```

## Classes

Expand Down Expand Up @@ -44,6 +45,8 @@
PolarizationCorrectedData
Polarized
Polarizer
SecondDegreePolynomialEfficiency
SupermirrorEfficiencyFunction
Up
```

Expand Down
101 changes: 91 additions & 10 deletions docs/user-guide/zoom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -258,26 +258,26 @@
"metadata": {},
"outputs": [],
"source": [
"pol_workflow = pol.he3.He3CellWorkflow(in_situ=False, incoming_polarized=True)\n",
"he3_workflow = pol.he3.He3CellWorkflow(in_situ=False, incoming_polarized=True)\n",
"# TODO Is plus correct here, this is period 0? Do we also have minus data?\n",
"pol_workflow[pol.he3.He3AnalyzerTransmissionFractionParallel] = transmission\n",
"he3_workflow[pol.he3.He3AnalyzerTransmissionFractionParallel] = transmission\n",
"# TODO Fake empty transmission for now, would need to load different period\n",
"pol_workflow[pol.he3.He3AnalyzerTransmissionFractionAntiParallel] = transmission[\n",
"he3_workflow[pol.he3.He3AnalyzerTransmissionFractionAntiParallel] = transmission[\n",
" 'time', 0:0\n",
"]\n",
"pol_workflow[\n",
"he3_workflow[\n",
" pol.he3.He3CellTransmissionFractionIncomingUnpolarized[\n",
" pol.Analyzer, pol.Depolarized\n",
" ]\n",
"] = transmission_depolarized\n",
"\n",
"# When in_situ=False, these params are used as starting guess for the fit\n",
"pol_workflow[pol.he3.He3CellLength[pol.Analyzer]] = 0.1 * sc.Unit('m')\n",
"pol_workflow[pol.he3.He3CellPressure[pol.Analyzer]] = 1.0 * sc.Unit('bar')\n",
"pol_workflow[pol.he3.He3CellTemperature[pol.Analyzer]] = 300.0 * sc.Unit('K')\n",
"he3_workflow[pol.he3.He3CellLength[pol.Analyzer]] = 0.1 * sc.Unit('m')\n",
"he3_workflow[pol.he3.He3CellPressure[pol.Analyzer]] = 1.0 * sc.Unit('bar')\n",
"he3_workflow[pol.he3.He3CellTemperature[pol.Analyzer]] = 300.0 * sc.Unit('K')\n",
"\n",
"pol_workflow[pol.he3.He3TransmissionEmptyGlass[pol.Analyzer]] = transmission_empty_glass\n",
"pol_workflow.visualize(\n",
"he3_workflow[pol.he3.He3TransmissionEmptyGlass[pol.Analyzer]] = transmission_empty_glass\n",
"he3_workflow.visualize(\n",
" pol.TransmissionFunction[pol.Analyzer], graph_attr={'rankdir': 'LR'}\n",
")"
]
Expand All @@ -297,7 +297,7 @@
"metadata": {},
"outputs": [],
"source": [
"func = pol_workflow.compute(pol.TransmissionFunction[pol.Analyzer])"
"func = he3_workflow.compute(pol.TransmissionFunction[pol.Analyzer])"
]
},
{
Expand Down Expand Up @@ -378,6 +378,87 @@
"source": [
"func.polarization_function.T1"
]
},
{
"cell_type": "markdown",
"id": "79f78366",
"metadata": {},
"source": [
"## Correction workflow\n",
"\n",
"In the previous section we have setup the workflow for the analyzer.\n",
"We also computed the transmission function there, but in production this will be done implicitly by running the entire workflow we will setup here.\n",
"We can combine this with the workflow for the polarizer to obtain the full correction workflow:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41bb77db",
"metadata": {},
"outputs": [],
"source": [
"supermirror_workflow = pol.SupermirrorWorkflow()\n",
"supermirror_workflow.visualize(pol.TransmissionFunction[pol.Polarizer])"
]
},
{
"cell_type": "markdown",
"id": "c9209c68",
"metadata": {},
"source": [
"We will use a second-order polynomial supermirror efficiency function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2ef8c40",
"metadata": {},
"outputs": [],
"source": [
"# Note that these coefficients are meaningless, please fill in correct values!\n",
"supermirror_workflow[pol.SupermirrorEfficiencyFunction[pol.Polarizer]] = (\n",
" pol.SecondDegreePolynomialEfficiency(\n",
" a=0.5 * sc.Unit('1/angstrom**2'),\n",
" b=0.4 * sc.Unit('1/angstrom'),\n",
" c=0.3 * sc.Unit('dimensionless'),\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c95094c7",
"metadata": {},
"outputs": [],
"source": [
"workflow = pol.PolarizationAnalysisWorkflow(\n",
" polarizer_workflow=supermirror_workflow,\n",
" analyzer_workflow=he3_workflow,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d5e6b0b6",
"metadata": {},
"source": [
"For a single channel, the complete workflow looks as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "91a76593",
"metadata": {},
"outputs": [],
"source": [
"workflow.visualize(\n",
" pol.PolarizationCorrectedData[pol.Up, pol.Up], graph_attr={'rankdir': 'LR'}\n",
")"
]
}
],
"metadata": {
Expand Down
8 changes: 7 additions & 1 deletion src/ess/polarization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
He3TransmissionFunction,
Polarized,
)
from .supermirror import SupermirrorWorkflow
from .supermirror import (
SecondDegreePolynomialEfficiency,
SupermirrorEfficiencyFunction,
SupermirrorWorkflow,
)
from .types import (
Analyzer,
Down,
Expand Down Expand Up @@ -73,6 +77,8 @@
"Polarizer",
"PolarizingElement",
"ReducedSampleDataBySpinChannel",
"SecondDegreePolynomialEfficiency",
"SupermirrorEfficiencyFunction",
"SupermirrorWorkflow",
"TransmissionFunction",
"NoAnalyzer",
Expand Down
54 changes: 38 additions & 16 deletions src/ess/polarization/supermirror.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Generic

Expand All @@ -10,10 +11,44 @@
from .types import PlusMinus, PolarizingElement, TransmissionFunction


class SupermirrorEfficiencyFunction(Generic[PolarizingElement]):
class SupermirrorEfficiencyFunction(Generic[PolarizingElement], ABC):
"""Base class for supermirror efficiency functions"""

@abstractmethod
def __call__(self, *, wavelength: sc.Variable) -> sc.DataArray:
"""Return the efficiency of a supermirror for a given wavelength"""


@dataclass
class SecondDegreePolynomialEfficiency(
SupermirrorEfficiencyFunction[PolarizingElement]
):
"""
Efficiency of a supermirror as a second-degree polynomial
The efficiency is given by a * wavelength^2 + b * wavelength + c
Parameters
----------
a:
Coefficient of the quadratic term, with unit of 1/angstrom^2
b:
Coefficient of the linear term, with unit of 1/angstrom
c:
Constant term, dimensionless
"""

a: sc.Variable
b: sc.Variable
c: sc.Variable

def __call__(self, *, wavelength: sc.Variable) -> sc.DataArray:
"""Return the efficiency of a supermirror for a given wavelength"""
raise NotImplementedError
return (
(self.a * wavelength**2).to(unit='', copy=False)
+ (self.b * wavelength).to(unit='', copy=False)
+ self.c.to(unit='', copy=False)
)


@dataclass
Expand All @@ -37,13 +72,6 @@ def apply(self, data: sc.DataArray, plus_minus: PlusMinus) -> sc.DataArray:
return self(wavelength=data.coords['wavelength'], plus_minus=plus_minus)


def get_supermirror_efficiency_function() -> (
SupermirrorEfficiencyFunction[PolarizingElement]
):
# TODO This will need some input parameters
return SupermirrorEfficiencyFunction[PolarizingElement]()


def get_supermirror_transmission_function(
efficiency_function: SupermirrorEfficiencyFunction[PolarizingElement],
) -> TransmissionFunction[PolarizingElement]:
Expand All @@ -52,14 +80,8 @@ def get_supermirror_transmission_function(
)


supermirror_providers = (
get_supermirror_efficiency_function,
get_supermirror_transmission_function,
)


def SupermirrorWorkflow() -> sciline.Pipeline:
"""
Workflow for computing transmission functions for supermirror polarizing elements.
"""
return sciline.Pipeline(supermirror_providers)
return sciline.Pipeline((get_supermirror_transmission_function,))
61 changes: 61 additions & 0 deletions tests/supermirror_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
import pytest
import scipp as sc

import ess.polarization as pol


def test_SecondDegreePolynomialEfficiency_raises_if_units_incompatible():
wav = sc.scalar(1.0, unit='m')
with pytest.raises(sc.UnitError, match=" to `dimensionless` is not valid"):
eff = pol.SecondDegreePolynomialEfficiency(
a=sc.scalar(1.0, unit='1/angstrom'),
b=sc.scalar(1.0, unit='1/angstrom'),
c=sc.scalar(1.0),
)
eff(wavelength=wav)
with pytest.raises(sc.UnitError, match=" to `dimensionless` is not valid"):
eff = pol.SecondDegreePolynomialEfficiency(
a=sc.scalar(1.0, unit='1/angstrom**2'),
b=sc.scalar(1.0, unit='1/angstrom**2'),
c=sc.scalar(1.0),
)
eff(wavelength=wav)
with pytest.raises(sc.UnitError, match=" to `dimensionless` is not valid"):
eff = pol.SecondDegreePolynomialEfficiency(
a=sc.scalar(1.0, unit='1/angstrom**2'),
b=sc.scalar(1.0, unit='1/angstrom'),
c=sc.scalar(1.0, unit='1/angstrom'),
)
eff(wavelength=wav)
with pytest.raises(sc.UnitError, match=" to `dimensionless` is not valid"):
eff = pol.SecondDegreePolynomialEfficiency(
a=sc.scalar(1.0, unit='1/angstrom**2'),
b=sc.scalar(1.0, unit='1/angstrom'),
c=sc.scalar(1.0),
)
eff(wavelength=wav / sc.scalar(1.0, unit='s'))


def test_SecondDegreePolynomialEfficiency_produces_correct_values():
a = sc.scalar(1.0, unit='1/angstrom**2')
b = sc.scalar(2.0, unit='1/angstrom')
c = sc.scalar(3.0)
f = pol.SecondDegreePolynomialEfficiency(a=a, b=b, c=c)
assert f(wavelength=sc.scalar(0.0, unit='angstrom')) == 3.0
assert f(wavelength=sc.scalar(1.0, unit='angstrom')) == 6.0
assert f(wavelength=sc.scalar(2.0, unit='angstrom')) == 11.0


def test_SecondDegreePolynomialEfficiency_converts_units():
a = sc.scalar(1.0, unit='1/angstrom**2')
b = sc.scalar(20.0, unit='1/nm')
c = sc.scalar(3.0)
f = pol.SecondDegreePolynomialEfficiency(a=a, b=b, c=c)
assert f(wavelength=sc.scalar(0.0, unit='angstrom')) == 3.0
assert f(wavelength=sc.scalar(1.0, unit='angstrom')) == 6.0
assert f(wavelength=sc.scalar(2.0, unit='angstrom')) == 11.0
assert f(wavelength=sc.scalar(0.0, unit='nm')) == 3.0
assert f(wavelength=sc.scalar(0.1, unit='nm')) == 6.0
assert f(wavelength=sc.scalar(0.2, unit='nm')) == 11.0

0 comments on commit 44aeafc

Please sign in to comment.