Skip to content

Commit

Permalink
getting rid of delaunay
Browse files Browse the repository at this point in the history
  • Loading branch information
Marian Dovgialo committed Sep 17, 2024
1 parent 7d5054e commit e55569e
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 33 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,16 @@ license = { file = "LICENSE" }
requires-python = ">=3.7"
dependencies = [
"numpy ~= 1.24",
"scipy ~= 1.10",
"mfem ~= 4.6.1.0",
"scipy ~= 1.11.3",
"mfem ~= 4.7.0.1",
"scikit-learn ~= 1.4",
"tqdm ~= 4.66",
"pyvista ~= 0.43",
"pandas ~= 2.0",
"nibabel ~= 5.1.0",
"numba ~= 0.59.1",
"scikit-image ~= 0.24.0",
"meshio ~= 2.3.5",
# "dolfin",
# "meshio",
# "snakemake",
Expand Down
66 changes: 66 additions & 0 deletions src/kesi/fem_utils/pyvista_resampling.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import argparse
import glob
import os
import tempfile
from io import StringIO

import nibabel
import numpy as np
Expand All @@ -11,6 +13,70 @@
from tqdm import tqdm

from kesi.fem_utils.grid_utils import load_or_create_grid
import mfem.ser as mfem

from kesi.mfem_solver.mfem_piecewise_solver import prepare_fespace



def convert_mfem_to_pyvista(mesh, solutions, names):
"""
mesh - mfem mesh object
solutions - list gridfunctions
names - list of gridfunction names to use
"""

assert len(solutions) == len(names)
with tempfile.NamedTemporaryFile(suffix='.vtk') as fp:
output = StringIO()
print("saving output mesh")
mesh.PrintVTK(output, 0)
with open(fp.name, 'a') as vtk_file:
vtk_file.write(output.getvalue())
del output

fespace = prepare_fespace(mesh)
x = mfem.GridFunction(fespace)
# setting initial values in all points, boundary elements will enforce this value
for name, sol in zip(names, solutions):
output = StringIO()

x.Assign(sol.GetDataArray())
x.SaveVTK(output, name, 0) # this crashes when using passed solutions, missing fespace???

with open(fp.name, 'a') as vtk_file:
vtk_file.write(output.getvalue())
del output

pyvista_mesh = pyvista.read(fp.name)
return pyvista_mesh


def pyvista_sample_points(pyvista_mesh, points):
point_cloud = pyvista.PolyData(points)
sampled = point_cloud.sample(pyvista_mesh, progress_bar=True, snap_to_closest_point=True)
return sampled


def pyvista_sample_grid(pyvista_mesh, grid):
"""
grid - meshgrid [x, y, z, 3]
"""
dimensions = grid.shape[0:4]
spacing = [
grid[1, 0, 0, 0] - grid[0, 0, 0, 0],
grid[0, 1, 0, 1] - grid[0, 0, 0, 1],
grid[0, 0, 1, 2] - grid[0, 0, 0, 2],

]
origin = [grid[0, 0, 0, 0], grid[0, 0, 0, 1], grid[0, 0, 0, 2]]

str_grid = pyvista.ImageData(dimensions=dimensions,
spacing=spacing,
origin=origin
)
sampled = str_grid.sample(pyvista_mesh, progress_bar=True, snap_to_closest_point=True)
return sampled


def main():
Expand Down
45 changes: 17 additions & 28 deletions src/kesi/mfem_solver/forward_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,10 @@
from scipy.interpolate import LinearNDInterpolator
from scipy.spatial import Delaunay

from kesi.fem_utils.pyvista_resampling import convert_mfem_to_pyvista, pyvista_sample_points, pyvista_sample_grid
from kesi.mfem_solver.mfem_piecewise_solver import mfem_solve_mesh, csd_distribution_coefficient, prepare_mesh


@lru_cache
def _cachedDelaunay(verts):
"""verts as tuple of tuples"""
verts = np.array(verts)
verts_triangulation = Delaunay(verts)
return verts_triangulation


def cachedDelaunay(verts):
"""
Calculates delaunay triangulation of given vertices, with caching of the result in RAM
verts - numpy array of verts (N, 3)
"""
return _cachedDelaunay(tuple(map(tuple, verts)))


class CSDForwardSolver:
def __init__(self, meshfile, conductivities, boundary_value=0, additional_refinement=False,
sampling_points=None,
Expand Down Expand Up @@ -53,11 +37,7 @@ def solve_coeff(self, coeff):
conductivities=self.conductivities)
self.solution = solution

verts = np.array(self.mesh.GetVertexArray())
sol = self.solution.GetDataArray()
verts_triangulation = cachedDelaunay(verts)

self.solution_interpolated = self.interpolator(verts_triangulation, sol)
self.pyvista_mesh_solution = convert_mfem_to_pyvista(self.mesh, [solution], ["pot"])
return solution

def solve(self, xyz, csd):
Expand All @@ -67,12 +47,21 @@ def solve(self, xyz, csd):
coeff = csd_distribution_coefficient(xyz, csd)
return self.solve_coeff(coeff)


def sample_solution_probe(self, x, y, z):
return self.solution_interpolated([x, y, z])[0]
points = np.array([[x, y, z]])
cloud = self.sample_solution(points)
return cloud.get_array("pot")

def sample_solution(self, positions):
"""positions - array (N, 3)"""
assert self.solution_interpolated is not None
pos_arr = np.array(positions)
return self.solution_interpolated(pos_arr)
"""positions - array (N, 3) or meshgrid stack [X, Y, Z, 3]"""

cloud = pyvista_sample_points(self.pyvista_mesh_solution, positions)
data = cloud.get_array("pot")
return data

def sample_grid(self, grid):
sampled = pyvista_sample_grid(self.pyvista_mesh_solution, grid)
data = sampled.get_array("pot")
sampled_grid = data.reshape(grid.shape[0:3], order="F")
return sampled_grid

7 changes: 4 additions & 3 deletions src/kesi/mfem_solver/mfem_piecewise_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
import os
from functools import partial
from functools import partial, lru_cache
import mfem.ser as mfem
import numpy as np
import pandas as pd
Expand Down Expand Up @@ -37,8 +37,9 @@ def refine_around_electrodes(mesh, electrode_positions):
return mesh


@lru_cache
def prepare_mesh(meshfile, refinement, electrode_positions=None):
"if electrode positions are given, perform additional refinement around electrodes positions, array (N, 3)"
"if electrode positions are given, perform additional refinement around electrodes positions, tuple of tuples of length 3"
# to create run
# gmsh -3 -format msh22 four_spheres_in_air_with_plane.geo
print("Loading mesh...")
Expand All @@ -53,7 +54,6 @@ def prepare_mesh(meshfile, refinement, electrode_positions=None):
if electrode_positions is not None:
mesh = refine_around_electrodes(mesh, electrode_positions)


return mesh


Expand Down Expand Up @@ -122,6 +122,7 @@ def mfem_solve_mesh(csd_coefficient, mesh, boundary_potential, conductivities):
b = mfem.LinearForm(fespace)

conductivities_vector = mfem.Vector(list(1.0 / (4 * np.pi * conductivities)))
# conductivities_vector = mfem.Vector(list(conductivities))

conductivities_coeff = mfem.PWConstCoefficient(conductivities_vector)

Expand Down

0 comments on commit e55569e

Please sign in to comment.