Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

getting rid of delaunay #32

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@ license = { file = "LICENSE" }
requires-python = ">=3.7"
dependencies = [
"numpy ~= 1.24",
"scipy ~= 1.10",
"mfem ~= 4.6.1.0",
"scipy ~= 1.11.3",
"mfem ~= 4.7.0.1",
"scikit-learn ~= 1.4",
"tqdm ~= 4.66",
"pyvista ~= 0.43",
"pandas ~= 2.0",
"nibabel ~= 5.1.0",
"numba ~= 0.59.1",
"scikit-image ~= 0.24.0",
"meshio ~= 2.3.5",
# "dolfin",
# "meshio",
# "snakemake",
Expand Down
66 changes: 66 additions & 0 deletions src/kesi/fem_utils/pyvista_resampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import argparse
import glob
import os
import tempfile
from io import StringIO

import nibabel
import numpy as np
Expand All @@ -11,6 +13,70 @@
from tqdm import tqdm

from kesi.fem_utils.grid_utils import load_or_create_grid
import mfem.ser as mfem

from kesi.mfem_solver.mfem_piecewise_solver import prepare_fespace



def convert_mfem_to_pyvista(mesh, solutions, names):
"""
mesh - mfem mesh object
solutions - list gridfunctions
names - list of gridfunction names to use
"""

assert len(solutions) == len(names)
with tempfile.NamedTemporaryFile(suffix='.vtk') as fp:
output = StringIO()
print("saving output mesh")
mesh.PrintVTK(output, 0)
with open(fp.name, 'a') as vtk_file:
vtk_file.write(output.getvalue())
del output

fespace = prepare_fespace(mesh)
x = mfem.GridFunction(fespace)
# setting initial values in all points, boundary elements will enforce this value
for name, sol in zip(names, solutions):
output = StringIO()

x.Assign(sol.GetDataArray())
x.SaveVTK(output, name, 0) # this crashes when using passed solutions, missing fespace???

with open(fp.name, 'a') as vtk_file:
vtk_file.write(output.getvalue())
del output

pyvista_mesh = pyvista.read(fp.name)
return pyvista_mesh


def pyvista_sample_points(pyvista_mesh, points):
point_cloud = pyvista.PolyData(points)
sampled = point_cloud.sample(pyvista_mesh, progress_bar=True, snap_to_closest_point=True)
return sampled


def pyvista_sample_grid(pyvista_mesh, grid):
"""
grid - meshgrid [x, y, z, 3]
"""
dimensions = grid.shape[0:4]
spacing = [
grid[1, 0, 0, 0] - grid[0, 0, 0, 0],
grid[0, 1, 0, 1] - grid[0, 0, 0, 1],
grid[0, 0, 1, 2] - grid[0, 0, 0, 2],

]
origin = [grid[0, 0, 0, 0], grid[0, 0, 0, 1], grid[0, 0, 0, 2]]

str_grid = pyvista.ImageData(dimensions=dimensions,
spacing=spacing,
origin=origin
)
sampled = str_grid.sample(pyvista_mesh, progress_bar=True, snap_to_closest_point=True)
return sampled


def main():
Expand Down
45 changes: 17 additions & 28 deletions src/kesi/mfem_solver/forward_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,10 @@
from scipy.interpolate import LinearNDInterpolator
from scipy.spatial import Delaunay

from kesi.fem_utils.pyvista_resampling import convert_mfem_to_pyvista, pyvista_sample_points, pyvista_sample_grid
from kesi.mfem_solver.mfem_piecewise_solver import mfem_solve_mesh, csd_distribution_coefficient, prepare_mesh


@lru_cache
def _cachedDelaunay(verts):
"""verts as tuple of tuples"""
verts = np.array(verts)
verts_triangulation = Delaunay(verts)
return verts_triangulation


def cachedDelaunay(verts):
"""
Calculates delaunay triangulation of given vertices, with caching of the result in RAM
verts - numpy array of verts (N, 3)
"""
return _cachedDelaunay(tuple(map(tuple, verts)))


class CSDForwardSolver:
def __init__(self, meshfile, conductivities, boundary_value=0, additional_refinement=False,
sampling_points=None,
Expand Down Expand Up @@ -53,11 +37,7 @@ def solve_coeff(self, coeff):
conductivities=self.conductivities)
self.solution = solution

verts = np.array(self.mesh.GetVertexArray())
sol = self.solution.GetDataArray()
verts_triangulation = cachedDelaunay(verts)

self.solution_interpolated = self.interpolator(verts_triangulation, sol)
self.pyvista_mesh_solution = convert_mfem_to_pyvista(self.mesh, [solution], ["pot"])
return solution

def solve(self, xyz, csd):
Expand All @@ -67,12 +47,21 @@ def solve(self, xyz, csd):
coeff = csd_distribution_coefficient(xyz, csd)
return self.solve_coeff(coeff)


def sample_solution_probe(self, x, y, z):
return self.solution_interpolated([x, y, z])[0]
points = np.array([[x, y, z]])
cloud = self.sample_solution(points)
return cloud.get_array("pot")

def sample_solution(self, positions):
"""positions - array (N, 3)"""
assert self.solution_interpolated is not None
pos_arr = np.array(positions)
return self.solution_interpolated(pos_arr)
"""positions - array (N, 3) or meshgrid stack [X, Y, Z, 3]"""

cloud = pyvista_sample_points(self.pyvista_mesh_solution, positions)
data = cloud.get_array("pot")
return data

def sample_grid(self, grid):
sampled = pyvista_sample_grid(self.pyvista_mesh_solution, grid)
data = sampled.get_array("pot")
sampled_grid = data.reshape(grid.shape[0:3], order="F")
return sampled_grid

7 changes: 4 additions & 3 deletions src/kesi/mfem_solver/mfem_piecewise_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
import os
from functools import partial
from functools import partial, lru_cache
import mfem.ser as mfem
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -37,8 +37,9 @@ def refine_around_electrodes(mesh, electrode_positions):
return mesh


@lru_cache
def prepare_mesh(meshfile, refinement, electrode_positions=None):
"if electrode positions are given, perform additional refinement around electrodes positions, array (N, 3)"
"if electrode positions are given, perform additional refinement around electrodes positions, tuple of tuples of length 3"
# to create run
# gmsh -3 -format msh22 four_spheres_in_air_with_plane.geo
print("Loading mesh...")
Expand All @@ -53,7 +54,6 @@ def prepare_mesh(meshfile, refinement, electrode_positions=None):
if electrode_positions is not None:
mesh = refine_around_electrodes(mesh, electrode_positions)


return mesh


Expand Down Expand Up @@ -122,6 +122,7 @@ def mfem_solve_mesh(csd_coefficient, mesh, boundary_potential, conductivities):
b = mfem.LinearForm(fespace)

conductivities_vector = mfem.Vector(list(1.0 / (4 * np.pi * conductivities)))
# conductivities_vector = mfem.Vector(list(conductivities))

conductivities_coeff = mfem.PWConstCoefficient(conductivities_vector)

Expand Down