diff --git a/pyproject.toml b/pyproject.toml index 8ba642b9..5179cd55 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,8 +39,8 @@ license = { file = "LICENSE" } requires-python = ">=3.7" dependencies = [ "numpy ~= 1.24", - "scipy ~= 1.10", - "mfem ~= 4.6.1.0", + "scipy ~= 1.11.3", + "mfem ~= 4.7.0.1", "scikit-learn ~= 1.4", "tqdm ~= 4.66", "pyvista ~= 0.43", @@ -48,6 +48,7 @@ dependencies = [ "nibabel ~= 5.1.0", "numba ~= 0.59.1", "scikit-image ~= 0.24.0", + "meshio ~= 2.3.5", # "dolfin", # "meshio", # "snakemake", diff --git a/src/kesi/fem_utils/pyvista_resampling.py b/src/kesi/fem_utils/pyvista_resampling.py index 7a478ad3..6a21fd8d 100644 --- a/src/kesi/fem_utils/pyvista_resampling.py +++ b/src/kesi/fem_utils/pyvista_resampling.py @@ -1,6 +1,8 @@ import argparse import glob import os +import tempfile +from io import StringIO import nibabel import numpy as np @@ -11,6 +13,70 @@ from tqdm import tqdm from kesi.fem_utils.grid_utils import load_or_create_grid +import mfem.ser as mfem + +from kesi.mfem_solver.mfem_piecewise_solver import prepare_fespace + + + +def convert_mfem_to_pyvista(mesh, solutions, names): + """ + mesh - mfem mesh object + solutions - list gridfunctions + names - list of gridfunction names to use + """ + + assert len(solutions) == len(names) + with tempfile.NamedTemporaryFile(suffix='.vtk') as fp: + output = StringIO() + print("saving output mesh") + mesh.PrintVTK(output, 0) + with open(fp.name, 'a') as vtk_file: + vtk_file.write(output.getvalue()) + del output + + fespace = prepare_fespace(mesh) + x = mfem.GridFunction(fespace) + # setting initial values in all points, boundary elements will enforce this value + for name, sol in zip(names, solutions): + output = StringIO() + + x.Assign(sol.GetDataArray()) + x.SaveVTK(output, name, 0) # this crashes when using passed solutions, missing fespace??? + + with open(fp.name, 'a') as vtk_file: + vtk_file.write(output.getvalue()) + del output + + pyvista_mesh = pyvista.read(fp.name) + return pyvista_mesh + + +def pyvista_sample_points(pyvista_mesh, points): + point_cloud = pyvista.PolyData(points) + sampled = point_cloud.sample(pyvista_mesh, progress_bar=True, snap_to_closest_point=True) + return sampled + + +def pyvista_sample_grid(pyvista_mesh, grid): + """ + grid - meshgrid [x, y, z, 3] + """ + dimensions = grid.shape[0:4] + spacing = [ + grid[1, 0, 0, 0] - grid[0, 0, 0, 0], + grid[0, 1, 0, 1] - grid[0, 0, 0, 1], + grid[0, 0, 1, 2] - grid[0, 0, 0, 2], + + ] + origin = [grid[0, 0, 0, 0], grid[0, 0, 0, 1], grid[0, 0, 0, 2]] + + str_grid = pyvista.ImageData(dimensions=dimensions, + spacing=spacing, + origin=origin + ) + sampled = str_grid.sample(pyvista_mesh, progress_bar=True, snap_to_closest_point=True) + return sampled def main(): diff --git a/src/kesi/mfem_solver/forward_solver.py b/src/kesi/mfem_solver/forward_solver.py index 78c8bbf6..8f8a1d4a 100644 --- a/src/kesi/mfem_solver/forward_solver.py +++ b/src/kesi/mfem_solver/forward_solver.py @@ -4,26 +4,10 @@ from scipy.interpolate import LinearNDInterpolator from scipy.spatial import Delaunay +from kesi.fem_utils.pyvista_resampling import convert_mfem_to_pyvista, pyvista_sample_points, pyvista_sample_grid from kesi.mfem_solver.mfem_piecewise_solver import mfem_solve_mesh, csd_distribution_coefficient, prepare_mesh -@lru_cache -def _cachedDelaunay(verts): - """verts as tuple of tuples""" - verts = np.array(verts) - verts_triangulation = Delaunay(verts) - return verts_triangulation - - -def cachedDelaunay(verts): - """ - Calculates delaunay triangulation of given vertices, with caching of the result in RAM - - verts - numpy array of verts (N, 3) - """ - return _cachedDelaunay(tuple(map(tuple, verts))) - - class CSDForwardSolver: def __init__(self, meshfile, conductivities, boundary_value=0, additional_refinement=False, sampling_points=None, @@ -53,11 +37,7 @@ def solve_coeff(self, coeff): conductivities=self.conductivities) self.solution = solution - verts = np.array(self.mesh.GetVertexArray()) - sol = self.solution.GetDataArray() - verts_triangulation = cachedDelaunay(verts) - - self.solution_interpolated = self.interpolator(verts_triangulation, sol) + self.pyvista_mesh_solution = convert_mfem_to_pyvista(self.mesh, [solution], ["pot"]) return solution def solve(self, xyz, csd): @@ -67,12 +47,21 @@ def solve(self, xyz, csd): coeff = csd_distribution_coefficient(xyz, csd) return self.solve_coeff(coeff) - def sample_solution_probe(self, x, y, z): - return self.solution_interpolated([x, y, z])[0] + points = np.array([[x, y, z]]) + cloud = self.sample_solution(points) + return cloud.get_array("pot") def sample_solution(self, positions): - """positions - array (N, 3)""" - assert self.solution_interpolated is not None - pos_arr = np.array(positions) - return self.solution_interpolated(pos_arr) + """positions - array (N, 3) or meshgrid stack [X, Y, Z, 3]""" + + cloud = pyvista_sample_points(self.pyvista_mesh_solution, positions) + data = cloud.get_array("pot") + return data + + def sample_grid(self, grid): + sampled = pyvista_sample_grid(self.pyvista_mesh_solution, grid) + data = sampled.get_array("pot") + sampled_grid = data.reshape(grid.shape[0:3], order="F") + return sampled_grid + diff --git a/src/kesi/mfem_solver/mfem_piecewise_solver.py b/src/kesi/mfem_solver/mfem_piecewise_solver.py index b655a4d8..49c8c6cd 100644 --- a/src/kesi/mfem_solver/mfem_piecewise_solver.py +++ b/src/kesi/mfem_solver/mfem_piecewise_solver.py @@ -1,6 +1,6 @@ import argparse import os -from functools import partial +from functools import partial, lru_cache import mfem.ser as mfem import numpy as np import pandas as pd @@ -37,8 +37,9 @@ def refine_around_electrodes(mesh, electrode_positions): return mesh +@lru_cache def prepare_mesh(meshfile, refinement, electrode_positions=None): - "if electrode positions are given, perform additional refinement around electrodes positions, array (N, 3)" + "if electrode positions are given, perform additional refinement around electrodes positions, tuple of tuples of length 3" # to create run # gmsh -3 -format msh22 four_spheres_in_air_with_plane.geo print("Loading mesh...") @@ -53,7 +54,6 @@ def prepare_mesh(meshfile, refinement, electrode_positions=None): if electrode_positions is not None: mesh = refine_around_electrodes(mesh, electrode_positions) - return mesh @@ -122,6 +122,7 @@ def mfem_solve_mesh(csd_coefficient, mesh, boundary_potential, conductivities): b = mfem.LinearForm(fespace) conductivities_vector = mfem.Vector(list(1.0 / (4 * np.pi * conductivities))) + # conductivities_vector = mfem.Vector(list(conductivities)) conductivities_coeff = mfem.PWConstCoefficient(conductivities_vector)