-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
19 changed files
with
11,952 additions
and
74 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import requests | ||
|
||
from enum import Enum | ||
class OS(Enum): | ||
WIN = 0 | ||
LINUX = 1 | ||
MAC = 2 | ||
UNKNOWN = 3 | ||
|
||
def determine_os(): | ||
""" | ||
Get the operating system of the current machine. | ||
""" | ||
import platform | ||
if platform.system() == "Windows": | ||
return OS.WIN | ||
elif platform.system() == "Linux": | ||
return OS.LINUX | ||
elif platform.system() == "Darwin": | ||
return OS.MAC | ||
return OS.UNKNOWN | ||
|
||
def get_ausaxs(): | ||
_os = determine_os() | ||
url = "https://github.com/SasView/AUSAXS/releases/latest/download/" | ||
libs = None | ||
if _os == OS.WIN: | ||
libs = ["libausaxs.dll"] | ||
elif _os == OS.LINUX: | ||
libs = ["libausaxs.so"] | ||
elif _os == OS.MAC: | ||
libs = ["libausaxs.dylib"] | ||
if libs is not None: | ||
# we have to use a relative path since the package is not installed yet | ||
base_loc = "src/sas/sascalc/calculator/ausaxs/lib/" | ||
for lib in libs: | ||
response = requests.get(url+lib) | ||
with open(base_loc+lib, "wb") as f: | ||
f.write(response.content) | ||
|
||
def fetch_external_dependencies(): | ||
#surround with try/except to avoid breaking the build if the download fails | ||
try: | ||
get_ausaxs() | ||
except Exception as e: | ||
print("Download of external dependencies failed.", e) | ||
return |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from enum import Enum | ||
|
||
class Arch(Enum): | ||
NONE = 0 | ||
SSE4 = 1 | ||
AVX = 2 | ||
|
||
class OS(Enum): | ||
WIN = 0 | ||
LINUX = 1 | ||
MAC = 2 | ||
UNKNOWN = 3 | ||
|
||
def determine_cpu_support(): | ||
""" | ||
Get the highest level of CPU support for SIMD instructions. | ||
""" | ||
import cpufeature | ||
if cpufeature.CPUFeature["AVX"]: | ||
return Arch.AVX | ||
elif cpufeature.CPUFeature["SSE4"]: | ||
return Arch.SSE4 | ||
else: | ||
return Arch.NONE | ||
|
||
def determine_os(): | ||
""" | ||
Get the operating system of the current machine. | ||
""" | ||
import platform | ||
if platform.system() == "Windows": | ||
return OS.WIN | ||
elif platform.system() == "Linux": | ||
return OS.LINUX | ||
elif platform.system() == "Darwin": | ||
return OS.MAC | ||
return OS.UNKNOWN | ||
|
||
def get_shared_lib_extension(): | ||
""" | ||
Get the shared library extension for the current operating system, including the dot. | ||
If the operating system is unknown, return an empty string. | ||
""" | ||
if determine_os() == OS.WIN: | ||
return ".dll" | ||
elif determine_os() == OS.LINUX: | ||
return ".so" | ||
elif determine_os() == OS.MAC: | ||
return ".dylib" | ||
return "" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import ctypes as ct | ||
import numpy as np | ||
import logging | ||
from enum import Enum | ||
import importlib.resources as resources | ||
|
||
# we need to be able to differentiate between being uninitialized and failing to load | ||
class lib_state(Enum): | ||
UNINITIALIZED = 0 | ||
FAILED = 1 | ||
READY = 2 | ||
|
||
ausaxs_state = lib_state.UNINITIALIZED | ||
ausaxs = None | ||
|
||
def attach_hooks(): | ||
global ausaxs_state | ||
global ausaxs | ||
|
||
from sas.sascalc.calculator.ausaxs.architecture import OS, Arch, determine_os, determine_cpu_support | ||
sys = determine_os() | ||
arch = determine_cpu_support() | ||
|
||
# as_file extracts the dll if it is in a zip file and probably deletes it afterwards, | ||
# so we have to do all operations on the dll inside the with statement | ||
with resources.as_file(resources.files("sas.sascalc.calculator.ausaxs.lib")) as loc: | ||
if sys is OS.WIN: | ||
path = loc.joinpath("libausaxs.dll") | ||
elif sys is OS.LINUX: | ||
path = loc.joinpath("libausaxs.so") | ||
elif sys is OS.MAC: | ||
path = loc.joinpath("libausaxs.dylib") | ||
else: | ||
path = "" | ||
|
||
ausaxs_state = lib_state.READY | ||
try: | ||
# evaluate_sans_debye func | ||
ausaxs = ct.CDLL(str(path)) | ||
ausaxs.evaluate_sans_debye.argtypes = [ | ||
ct.POINTER(ct.c_double), # q vector | ||
ct.POINTER(ct.c_double), # x vector | ||
ct.POINTER(ct.c_double), # y vector | ||
ct.POINTER(ct.c_double), # z vector | ||
ct.POINTER(ct.c_double), # w vector | ||
ct.c_int, # nq (number of points in q) | ||
ct.c_int, # nc (number of points in x, y, z, w) | ||
ct.POINTER(ct.c_int), # status (0 = success, 1 = q range error, 2 = other error) | ||
ct.POINTER(ct.c_double) # Iq vector for return value | ||
] | ||
ausaxs.evaluate_sans_debye.restype = None # don't expect a return value | ||
ausaxs_state = lib_state.READY | ||
except Exception as e: | ||
ausaxs_state = lib_state.FAILED | ||
logging.warning("Failed to hook into AUSAXS library, using default Debye implementation") | ||
print(e) | ||
|
||
def ausaxs_available(): | ||
""" | ||
Check if the AUSAXS library is available. | ||
""" | ||
if ausaxs_state is lib_state.UNINITIALIZED: | ||
attach_hooks() | ||
return ausaxs_state is lib_state.READY | ||
|
||
def evaluate_sans_debye(q, coords, w): | ||
""" | ||
Compute I(q) for a set of points using Debye sums. | ||
This uses AUSAXS if available, otherwise it uses the default implementation. | ||
*q* is the q values for the calculation. | ||
*coords* are the sample points. | ||
*w* is the weight associated with each point. | ||
""" | ||
if ausaxs_state is lib_state.UNINITIALIZED: | ||
attach_hooks() | ||
if ausaxs_state is lib_state.FAILED: | ||
from sas.sascalc.calculator.ausaxs.sasview_sans_debye import sasview_sans_debye | ||
return sasview_sans_debye(q, coords, w) | ||
|
||
_Iq = (ct.c_double * len(q))() | ||
_nq = ct.c_int(len(q)) | ||
_nc = ct.c_int(len(w)) | ||
_q = q.ctypes.data_as(ct.POINTER(ct.c_double)) | ||
_x = coords[0:, :].ctypes.data_as(ct.POINTER(ct.c_double)) | ||
_y = coords[1:, :].ctypes.data_as(ct.POINTER(ct.c_double)) | ||
_z = coords[2:, :].ctypes.data_as(ct.POINTER(ct.c_double)) | ||
_w = w.ctypes.data_as(ct.POINTER(ct.c_double)) | ||
_status = ct.c_int() | ||
|
||
# do the call | ||
ausaxs.evaluate_sans_debye(_q, _x, _y, _z, _w, _nq, _nc, ct.byref(_status), _Iq) | ||
|
||
# check for errors | ||
if _status.value != 0: | ||
logging.error("AUSAXS calculator terminated unexpectedly. Using default Debye implementation instead.") | ||
from sas.sascalc.calculator.ausaxs.sasview_sans_debye import sasview_sans_debye | ||
return sasview_sans_debye(q, coords, w) | ||
|
||
return np.array(_Iq) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import numpy as np | ||
import logging | ||
|
||
def sasview_sans_debye(q, coords, weight, worksize=100000): | ||
""" | ||
Compute I(q) for a set of points using the full Debye formula. | ||
*q* is the q values for the calculation. | ||
*coords* are the sample points. | ||
*weight* is the weight associated with each point. | ||
""" | ||
Iq = np.zeros_like(q) | ||
q_pi = q/np.pi # Precompute q/pi since np.sinc = sin(pi x)/(pi x). | ||
batch_size = worksize // coords.shape[0] | ||
for batch in range(0, len(q), batch_size): | ||
_calc_Iq_batch(Iq[batch:batch+batch_size], q_pi[batch:batch+batch_size], | ||
coords, weight) | ||
return Iq | ||
|
||
def _calc_Iq_batch(Iq, q_pi, coords, weight): | ||
""" | ||
Helper function for _calc_Iq which operates on a batch of q values. | ||
*Iq* is accumulated within each batch, and should be initialized to zero. | ||
*q_pi* is q/pi, needed because np.sinc computes sin(pi x)/(pi x). | ||
*coords* are the sample points. | ||
*weight* is the weight associated with each point. | ||
""" | ||
for j in range(len(weight)): | ||
if j % 100 == 0: logging.info(f"\tprogress: {j/len(weight)*100:.0f}%") | ||
# Compute dx for one row of the upper triangle matrix. | ||
dx = coords[:, j:] - coords[:, j:j+1] | ||
# Find the length of each dx vector. | ||
r = np.sqrt(np.sum(dx**2, axis=0)) | ||
# Compute I_jk = rho_j rho_k j0(q ||x_j - x_k||) over all q in batch. | ||
bes = np.sinc(q_pi[:, None]*r[None, :]) | ||
I_jk = (weight[j:] * weight[j])[None, :] * bes | ||
# Accumulate terms I(j,j), I(j, k+1..n) and by symmetry I(k+1..n, j). | ||
# Don't double-count the diagonal. | ||
Iq += 2*np.sum(I_jk, axis=1) - I_jk[:, 0] |
Oops, something went wrong.