diff --git a/examples/ex18.py b/examples/ex18.py index 772f8e96..1c089ceb 100644 --- a/examples/ex18.py +++ b/examples/ex18.py @@ -3,8 +3,15 @@ This is a version of Example 18 with a simple adaptive mesh refinement loop. See c++ version in the MFEM library for more detail + + Sample runs: + + python ex18.py -p 1 -r 2 -o 1 -s 3 + python ex18.py -p 1 -r 1 -o 3 -s 4 + python ex18.py -p 1 -r 0 -o 5 -s 6 + python ex18.py -p 2 -r 1 -o 1 -s 3 -mf + python ex18.py -p 2 -r 0 -o 3 -s 3 -mf ''' -from ex18_common import FE_Evolution, InitialCondition, RiemannSolver, DomainIntegrator, FaceIntegrator from mfem.common.arg_parser import ArgParser import mfem.ser as mfem from mfem.ser import intArray @@ -13,32 +20,39 @@ from numpy import sqrt, pi, cos, sin, hypot, arctan2 from scipy.special import erfc -# Equation constant parameters.(using globals to share them with ex18_common) -import ex18_common +from ex18_common import (EulerMesh, + EulerInitialCondition, + DGHyperbolicConservationLaws) def run(problem=1, ref_levels=1, order=3, ode_solver_type=4, - t_final=0.5, + t_final=2.0, dt=-0.01, cfl=0.3, visualization=True, vis_steps=50, + preassembleWeakDiv=False, meshfile=''): - ex18_common.num_equation = 4 - ex18_common.specific_heat_ratio = 1.4 - ex18_common.gas_constant = 1.0 - ex18_common.problem = problem - num_equation = ex18_common.num_equation + specific_heat_ratio = 1.4 + gas_constant = 1.0 + IntOrderOffset = 1 # 2. Read the mesh from the given mesh file. This example requires a 2D # periodic mesh, such as ../data/periodic-square.mesh. - meshfile = expanduser(join(dirname(__file__), '..', 'data', meshfile)) - mesh = mfem.Mesh(meshfile, 1, 1) + + mesh = EulerMesh(meshfile, problem) dim = mesh.Dimension() + num_equation = dim + 2 + + # Refine the mesh to increase the resolution. In this example we do + # 'ref_levels' of uniform refinement, where 'ref_levels' is a + # command-line parameter. + for lev in range(ref_levels): + mesh.UniformRefinement() # 3. Define the ODE solver used for time integration. Several explicit # Runge-Kutta methods are available. @@ -48,7 +62,7 @@ def run(problem=1, elif ode_solver_type == 2: ode_solver = mfem.RK2Solver(1.0) elif ode_solver_type == 3: - ode_solver = mfem.RK3SSolver() + ode_solver = mfem.RK3SSPSolver() elif ode_solver_type == 4: ode_solver = mfem.RK4Solver() elif ode_solver_type == 6: @@ -57,13 +71,7 @@ def run(problem=1, print("Unknown ODE solver type: " + str(ode_solver_type)) exit - # 4. Refine the mesh to increase the resolution. In this example we do - # 'ref_levels' of uniform refinement, where 'ref_levels' is a - # command-line parameter. - for lev in range(ref_levels): - mesh.UniformRefinement() - - # 5. Define the discontinuous DG finite element space of the given + # 4. Define the discontinuous DG finite element space of the given # polynomial order on the refined mesh. fec = mfem.DG_FECollection(order, dim) @@ -78,70 +86,70 @@ def run(problem=1, assert fes.GetOrdering() == mfem.Ordering.byNODES, "Ordering must be byNODES" print("Number of unknowns: " + str(vfes.GetVSize())) - # 6. Define the initial conditions, save the corresponding mesh and grid - # functions to a file. This can be opened with GLVis with the -gc option. - # The solution u has components {density, x-momentum, y-momentum, energy}. - # These are stored contiguously in the BlockVector u_block. - - offsets = [k*vfes.GetNDofs() for k in range(num_equation+1)] - offsets = mfem.intArray(offsets) - u_block = mfem.BlockVector(offsets) - mom = mfem.GridFunction(dfes, u_block, offsets[1]) - - # - # Define coefficient using VecotrPyCoefficient and PyCoefficient - # A user needs to define EvalValue method - # - u0 = InitialCondition(num_equation) - sol = mfem.GridFunction(vfes, u_block.GetData()) + # 5. Define the initial conditions, save the corresponding mesh and grid + # functions to files. These can be opened with GLVis using: + # "glvis -m euler-mesh.mesh -g euler-1-init.gf" (for x-momentum). + u0 = EulerInitialCondition(problem, + specific_heat_ratio, + gas_constant) + sol = mfem.GridFunction(vfes) sol.ProjectCoefficient(u0) - mesh.Print("vortex.mesh", 8) + # (Python note): GridFunction pointing to the subset of vector FES. + # sol is Vector with dim*fes.GetNDofs() + # Since sol.GetDataArray() returns numpy array pointing to the data, we make + # Vector from a sub-vector of the returned numpy array and pass it to GridFunction + # constructor. + + mom = mfem.GridFunction(dfes, mfem.Vector( + sol.GetDataArray()[fes.GetNDofs():])) + mesh.Print("euler-mesh.mesh", 8) + for k in range(num_equation): - uk = mfem.GridFunction(fes, u_block.GetBlock(k).GetData()) - sol_name = "vortex-" + str(k) + "-init.gf" + uk = mfem.GridFunction(fes, mfem.Vector( + sol.GetDataArray()[k*fes.GetNDofs():])) + sol_name = "euler-" + str(k) + "-init.gf" uk.Save(sol_name, 8) - # 7. Set up the nonlinear form corresponding to the DG discretization of the - # flux divergence, and assemble the corresponding mass matrix. - Aflux = mfem.MixedBilinearForm(dfes, fes) - Aflux.AddDomainIntegrator(DomainIntegrator(dim)) - Aflux.Assemble() + # 6. Set up the nonlinear form with euler flux and numerical flux + flux = mfem.EulerFlux(dim, specific_heat_ratio) + numericalFlux = mfem.RusanovFlux(flux) + formIntegrator = mfem.HyperbolicFormIntegrator( + numericalFlux, IntOrderOffset) - A = mfem.NonlinearForm(vfes) - rsolver = RiemannSolver() - ii = FaceIntegrator(rsolver, dim) - A.AddInteriorFaceIntegrator(ii) - - # 8. Define the time-dependent evolution operator describing the ODE - # right-hand side, and perform time-integration (looping over the time - # iterations, ti, with a time-step dt). - euler = FE_Evolution(vfes, A, Aflux.SpMat()) + euler = DGHyperbolicConservationLaws(vfes, formIntegrator, + preassembleWeakDivergence=preassembleWeakDiv) + # 7. Visualize momentum with its magnitude if (visualization): sout = mfem.socketstream("localhost", 19916) sout.precision(8) sout << "solution\n" << mesh << mom + sout << "window_title 'momentum, t = 0'\n" + sout << "view 0 0\n" # view from top + sout << "keys jlm\n" # turn off perspective and light, show mesh sout << "pause\n" sout.flush() print("GLVis visualization paused.") print(" Press space (in the GLVis window) to resume it.") - # Determine the minimum element size. - hmin = 0 + # 8. Time integration + hmin = np.inf if (cfl > 0): hmin = min([mesh.GetElementSize(i, 1) for i in range(mesh.GetNE())]) + # Find a safe dt, using a temporary vector. Calling Mult() computes the + # maximum char speed at all quadrature points on all faces (and all + # elements with -mf). + z = mfem.Vector(sol.Size()) + euler.Mult(sol, z) + + max_char_speed = euler.GetMaxCharSpeed() + dt = cfl * hmin / max_char_speed / (2 * order + 1) + t = 0.0 euler.SetTime(t) ode_solver.Init(euler) - if (cfl > 0): - # Find a safe dt, using a temporary vector. Calling Mult() computes the - # maximum char speed at all quadrature points on all faces. - z = mfem.Vector(A.Width()) - A.Mult(sol, z) - - dt = cfl * hmin / ex18_common.max_char_speed / (2*order+1) # Integrate in time. done = False @@ -152,23 +160,29 @@ def run(problem=1, t, dt_real = ode_solver.Step(sol, t, dt_real) if (cfl > 0): - dt = cfl * hmin / ex18_common.max_char_speed / (2*order+1) + max_char_speed = euler.GetMaxCharSpeed() + dt = cfl * hmin / max_char_speed / (2*order+1) ti = ti+1 done = (t >= t_final - 1e-8*dt) if (done or ti % vis_steps == 0): print("time step: " + str(ti) + ", time: " + "{:g}".format(t)) if (visualization): - sout << "solution\n" << mesh << mom << flush - - # 9. Save the final solution. This output can be viewed later using GLVis: - # "glvis -m vortex.mesh -g vortex-1-final.gf". + sout << "window_title 'momentum, t = " << "{:g}".format( + t) << "'\n" + sout << "solution\n" << mesh << mom + sout.flush() + + # 8. Save the final solution. This output can be viewed later using GLVis: + # "glvis -m euler.mesh -g euler-1-final.gf". + mesh.Print("euler-mesh-final.mesh", 8) for k in range(num_equation): - uk = mfem.GridFunction(fes, u_block.GetBlock(k).GetData()) - sol_name = "vortex-" + str(k) + "-final.gf" + uk = mfem.GridFunction(fes, mfem.Vector( + sol.GetDataArray()[k*fes.GetNDofs():])) + sol_name = "euler-" + str(k) + "-final.gf" uk.Save(sol_name, 8) print(" done") - # 10. Compute the L2 solution error summed for all components. + # 9. Compute the L2 solution error summed for all components. if (t_final == 2.0): error = sol.ComputeLpError(2., u0) print("Solution error: " + "{:g}".format(error)) @@ -178,7 +192,7 @@ def run(problem=1, parser = ArgParser(description='Ex18') parser.add_argument('-m', '--mesh', - default='periodic-square.mesh', + default='', action='store', type=str, help='Mesh file to use.') parser.add_argument('-p', '--problem', @@ -203,15 +217,39 @@ def run(problem=1, parser.add_argument('-c', '--cfl_number', action='store', default=0.3, type=float, help="CFL number for timestep calculation.") - parser.add_argument('-vis', '--visualization', - action='store_true', - help='Enable GLVis visualization') + parser.add_argument('-novis', '--no_visualization', + action='store_true', default=False, + help='Disable GLVis visualization') + parser.add_argument("-ea", "--element-assembly-divergence", + action='store_true', default=False, + help="Weak divergence assembly level\n" + + " ea - Element assembly with interpolated") + parser.add_argument("-mf", "--matrix-free-divergence", + action='store_true', default=False, + help="Weak divergence assembly level\n" + + " mf - Nonlinear assembly in matrix-free manner") parser.add_argument('-vs', '--visualization-steps', action='store', default=50, type=float, help="Visualize every n-th timestep.") args = parser.parse_args() + visualization = not args.no_visualization + + if (not args.matrix_free_divergence and + not args.element_assembly_divergence): + args.element_assembly_divergence = True + args.matrix_free_divergence = False + preassembleWeakDiv = True + + elif args.element_assembly_divergence: + args.matrix_free_divergence = False + preassembleWeakDiv = True + + elif args.matrix_free_divergence: + args.element_assembly_divergence = False + preassembleWeakDiv = False + parser.print_options(args) run(problem=args.problem, @@ -221,6 +259,7 @@ def run(problem=1, t_final=args.t_final, dt=args.time_step, cfl=args.cfl_number, - visualization=args.visualization, + visualization=visualization, vis_steps=args.visualization_steps, + preassembleWeakDiv=preassembleWeakDiv, meshfile=args.mesh) diff --git a/examples/ex18_common.py b/examples/ex18_common.py index 9664f718..5c6fe1ed 100644 --- a/examples/ex18_common.py +++ b/examples/ex18_common.py @@ -3,12 +3,6 @@ This is a python translation of ex18.hpp - note: following variabls are set from ex18 or ex18p - problem - num_equation - max_char_speed - specific_heat_ratio; - gas_constant; ''' import numpy as np @@ -19,332 +13,200 @@ else: import mfem.par as mfem -num_equation = 0 -specific_heat_ratio = 0 -gas_constant = 0 -problem = 0 -max_char_speed = 0 - - -class FE_Evolution(mfem.TimeDependentOperator): - def __init__(self, vfes, A, A_flux): - self.dim = vfes.GetFE(0).GetDim() - self.vfes = vfes - self.A = A - self.Aflux = A_flux - self.Me_inv = mfem.DenseTensor(vfes.GetFE(0).GetDof(), - vfes.GetFE(0).GetDof(), - vfes.GetNE()) - - self.state = mfem.Vector(num_equation) - self.f = mfem.DenseMatrix(num_equation, self.dim) - self.flux = mfem.DenseTensor(vfes.GetNDofs(), self.dim, num_equation) - self.z = mfem.Vector(A.Height()) - - dof = vfes.GetFE(0).GetDof() - Me = mfem.DenseMatrix(dof) - inv = mfem.DenseMatrixInverse(Me) - mi = mfem.MassIntegrator() - for i in range(vfes.GetNE()): - mi.AssembleElementMatrix(vfes.GetFE( - i), vfes.GetElementTransformation(i), Me) - inv.Factor() - inv.GetInverseMatrix(self.Me_inv(i)) - super(FE_Evolution, self).__init__(A.Height()) - - def GetFlux(self, x, flux): - state = self.state - dof = self.flux.SizeI() - dim = self.flux.SizeJ() - - flux_data = [] - for i in range(dof): - for k in range(num_equation): - self.state[k] = x[i, k] - ComputeFlux(state, dim, self.f) - - flux_data.append(self.f.GetDataArray().transpose().copy()) - # flux[i].Print() - # print(self.f.GetDataArray()) - # for d in range(dim): - # for k in range(num_equation): - # flux[i, d, k] = self.f[k, d] - - mcs = ComputeMaxCharSpeed(state, dim) - if (mcs > globals()['max_char_speed']): - globals()['max_char_speed'] = mcs - - flux.Assign(np.stack(flux_data)) - #print("max char speed", globals()['max_char_speed']) +from os.path import expanduser, join, dirname - def Mult(self, x, y): - globals()['max_char_speed'] = 0. - num_equation = globals()['num_equation'] - # 1. Create the vector z with the face terms -. - self.A.Mult(x, self.z) - - # 2. Add the element terms. - # i. computing the flux approximately as a grid function by interpolating - # at the solution nodes. - # ii. multiplying this grid function by a (constant) mixed bilinear form for - # each of the num_equation, computing (F(u), grad(w)) for each equation. - - xmat = mfem.DenseMatrix( - x.GetData(), self.vfes.GetNDofs(), num_equation) - self.GetFlux(xmat, self.flux) - - for k in range(num_equation): - fk = mfem.Vector(self.flux[k].GetData(), - self.dim * self.vfes.GetNDofs()) - o = k * self.vfes.GetNDofs() - zk = self.z[o: o+self.vfes.GetNDofs()] - self.Aflux.AddMult(fk, zk) - - # 3. Multiply element-wise by the inverse mass matrices. - zval = mfem.Vector() - vdofs = mfem.intArray() - dof = self.vfes.GetFE(0).GetDof() - zmat = mfem.DenseMatrix() - ymat = mfem.DenseMatrix(dof, num_equation) - for i in range(self.vfes.GetNE()): - # Return the vdofs ordered byNODES - vdofs = mfem.intArray(self.vfes.GetElementVDofs(i)) - self.z.GetSubVector(vdofs, zval) - zmat.UseExternalData(zval.GetData(), dof, num_equation) - mfem.Mult(self.Me_inv[i], zmat, ymat) - y.SetSubVector(vdofs, ymat.GetData()) - - -class DomainIntegrator(mfem.PyBilinearFormIntegrator): - def __init__(self, dim): - num_equation = globals()['num_equation'] - self.flux = mfem.DenseMatrix(num_equation, dim) - self.shape = mfem.Vector() - self.dshapedr = mfem.DenseMatrix() - self.dshapedx = mfem.DenseMatrix() - super(DomainIntegrator, self).__init__() - - def AssembleElementMatrix2(self, trial_fe, test_fe, Tr, elmat): - # Assemble the form (vec(v), grad(w)) - - # Trial space = vector L2 space (mesh dim) - # Test space = scalar L2 space - - dof_trial = trial_fe.GetDof() - dof_test = test_fe.GetDof() - dim = trial_fe.GetDim() - - self.shape.SetSize(dof_trial) - self.dshapedr.SetSize(dof_test, dim) - self.dshapedx.SetSize(dof_test, dim) - - elmat.SetSize(dof_test, dof_trial * dim) - elmat.Assign(0.0) - - maxorder = max(trial_fe.GetOrder(), test_fe.GetOrder()) - intorder = 2 * maxorder - ir = mfem.IntRules.Get(trial_fe.GetGeomType(), intorder) - - for i in range(ir.GetNPoints()): - ip = ir.IntPoint(i) - - # Calculate the shape functions - trial_fe.CalcShape(ip, self.shape) - self.shape *= ip.weight - - # Compute the physical gradients of the test functions - Tr.SetIntPoint(ip) - test_fe.CalcDShape(ip, self.dshapedr) - mfem.Mult(self.dshapedr, Tr.AdjugateJacobian(), self.dshapedx) - - for d in range(dim): - for j in range(dof_test): - for k in range(dof_trial): - elmat[j, k + d * dof_trial] += self.shape[k] * \ - self.dshapedx[j, d] - - -class FaceIntegrator(mfem.PyNonlinearFormIntegrator): - def __init__(self, rsolver, dim): - self.rsolver = rsolver - self.shape1 = mfem.Vector() - self.shape2 = mfem.Vector() - self.funval1 = mfem.Vector(num_equation) - self.funval2 = mfem.Vector(num_equation) - self.nor = mfem.Vector(dim) - self.fluxN = mfem.Vector(num_equation) - self.eip1 = mfem.IntegrationPoint() - self.eip2 = mfem.IntegrationPoint() - super(FaceIntegrator, self).__init__() - - self.fluxNA = np.atleast_2d(self.fluxN.GetDataArray()) - - def AssembleFaceVector(self, el1, el2, Tr, elfun, elvect): - num_equation = globals()['num_equation'] - # Compute the term on the interior faces. - dof1 = el1.GetDof() - dof2 = el2.GetDof() - - self.shape1.SetSize(dof1) - self.shape2.SetSize(dof2) - - elvect.SetSize((dof1 + dof2) * num_equation) - elvect.Assign(0.0) - - elfun1_mat = mfem.DenseMatrix(elfun.GetData(), dof1, num_equation) - elfun2_mat = mfem.DenseMatrix( - elfun[dof1*num_equation:].GetData(), dof2, num_equation) - - elvect1_mat = mfem.DenseMatrix(elvect.GetData(), dof1, num_equation) - elvect2_mat = mfem.DenseMatrix( - elvect[dof1*num_equation:].GetData(), dof2, num_equation) - - # Integration order calculation from DGTraceIntegrator - if (Tr.Elem2No >= 0): - intorder = (min(Tr.Elem1.OrderW(), Tr.Elem2.OrderW()) + - 2*max(el1.GetOrder(), el2.GetOrder())) +class DGHyperbolicConservationLaws(mfem.TimeDependentOperator): + def __init__(self, vfes_, formIntegrator_, preassembleWeakDivergence=True): + + super(DGHyperbolicConservationLaws, self).__init__( + vfes_.GetTrueVSize()) + self.num_equations = formIntegrator_.num_equations + self.vfes = vfes_ + self.dim = vfes_.GetMesh().SpaceDimension() + self.formIntegrator = formIntegrator_ + + self.z = mfem.Vector(vfes_.GetTrueVSize()) + + self.weakdiv = None + self.max_char_speed = None + + self.ComputeInvMass() + + if mfem_mode == 'serial': + self.nonlinearForm = mfem.NonlinearForm(self.vfes) + else: + if isinstance(self.vfes, mfem.ParFiniteElementSpace): + self.nonlinearForm = mfem.ParNonlinearForm(self.vfes) + else: + self.nonlinearForm = mfem.NonlinearForm(self.vfes) + + if preassembleWeakDivergence: + self.ComputeWeakDivergence() else: - intorder = Tr.Elem1.OrderW() + 2*el1.GetOrder() + self.nonlinearForm.AddDomainIntegrator(self.formIntegrator) + + self.nonlinearForm.AddInteriorFaceIntegrator(self.formIntegrator) + self.nonlinearForm.UseExternalIntegrators() + + def GetMaxCharSpeed(self): + return self.max_char_speed - if (el1.Space() == mfem.FunctionSpace().Pk): - intorder += 1 + def ComputeInvMass(self): + inv_mass = mfem.InverseIntegrator(mfem.MassIntegrator()) - ir = mfem.IntRules.Get(Tr.GetGeometryType(), int(intorder)) + self.invmass = [None]*self.vfes.GetNE() + for i in range(self.vfes.GetNE()): + dof = self.vfes.GetFE(i).GetDof() + self.invmass[i] = mfem.DenseMatrix(dof) + inv_mass.AssembleElementMatrix(self.vfes.GetFE(i), + self.vfes.GetElementTransformation( + i), + self.invmass[i]) + + def ComputeWeakDivergence(self): + weak_div = mfem.TransposeIntegrator(mfem.GradientIntegrator()) + + weakdiv_bynodes = mfem.DenseMatrix() - mat1A = elvect1_mat.GetDataArray() - mat2A = elvect2_mat.GetDataArray() - shape1A = np.atleast_2d(self.shape1.GetDataArray()) - shape2A = np.atleast_2d(self.shape2.GetDataArray()) + self.weakdiv = [None]*self.vfes.GetNE() - for i in range(ir.GetNPoints()): - ip = ir.IntPoint(i) - Tr.Loc1.Transform(ip, self.eip1) - Tr.Loc2.Transform(ip, self.eip2) + for i in range(self.vfes.GetNE()): + dof = self.vfes.GetFE(i).GetDof() + weakdiv_bynodes.SetSize(dof, dof*self.dim) + weak_div.AssembleElementMatrix2(self.vfes.GetFE(i), + self.vfes.GetFE(i), + self.vfes.GetElementTransformation( + i), + weakdiv_bynodes) + self.weakdiv[i] = mfem.DenseMatrix() + self.weakdiv[i].SetSize(dof, dof*self.dim) + + # Reorder so that trial space is ByDim. + # This makes applying weak divergence to flux value simpler. + for j in range(dof): + for d in range(self.dim): + self.weakdiv[i].SetCol( + j*self.dim + d, weakdiv_bynodes.GetColumn(d*dof + j)) - # Calculate basis functions on both elements at the face - el1.CalcShape(self.eip1, self.shape1) - el2.CalcShape(self.eip2, self.shape2) + def Mult(self, x, y): + # 0. Reset wavespeed computation before operator application. + self.formIntegrator.ResetMaxCharSpeed() + + # 1. Apply Nonlinear form to obtain an auxiliary result + # z = - _e + # If weak-divergence is not preassembled, we also have weak-divergence + # z = - _e + (F(u_h), ∇v) + self.nonlinearForm.Mult(x, self.z) + #print("!!!!", self.weakdiv) + if self.weakdiv is not None: # if weak divergence is pre-assembled + # Apply weak divergence to F(u_h), and inverse mass to z_loc + weakdiv_loc + + current_state = mfem.Vector() # view of current state at a node + current_flux = mfem.DenseMatrix() # flux of current state + + # element flux value. Whose column is ordered by dim. + flux = mfem.DenseMatrix() + # view of current states in an element, dof x num_eq + current_xmat = mfem.DenseMatrix() + # view of element auxiliary result, dof x num_eq + current_zmat = mfem.DenseMatrix() + current_ymat = mfem.DenseMatrix() # view of element result, dof x num_eq + + fluxFunction = self.formIntegrator.GetFluxFunction() + + xval = mfem.Vector() + zval = mfem.Vector() + flux_vec = mfem.Vector() + + for i in range(self.vfes.GetNE()): + Tr = self.vfes.GetElementTransformation(i) + dof = self.vfes.GetFE(i).GetDof() + vdofs = mfem.intArray(self.vfes.GetElementVDofs(i)) + + x.GetSubVector(vdofs, xval) + current_xmat.UseExternalData( + xval.GetData(), dof, self.num_equations) + + # + # Python Note: + # C++ code access to array data with offset is done bu GetData() + offset + # In Python, the same can be done by using numpy array generated from Vector::GetDataArray(), + # + # array = vec.GetDataArray() + # new_data_pointer = mfem.Vector(array[10:]).GetData() + # + # note that the above does not work if mfem.Vector is replaced by mfem.DenseMatrix + # This is because, while MFEM stores data in colume-major, Python numpy store raw-major. + # + + flux.SetSize(self.num_equations, self.dim*dof) + flux_vec = mfem.Vector( + flux.GetData(), self.num_equations*self.dim*dof) + data = flux_vec.GetDataArray() + + for j in range(dof): # compute flux for all nodes in the element + current_xmat.GetRow(j, current_state) + + data_ptr = mfem.Vector( + data[self.num_equations*self.dim*j:]).GetData() + current_flux = mfem.DenseMatrix(data_ptr, + self.num_equations, dof) + fluxFunction.ComputeFlux(current_state, Tr, current_flux) + + # Compute weak-divergence and add it to auxiliary result, z + # Recalling that weakdiv is reordered by dim, we can apply + # weak-divergence to the transpose of flux. + self.z.GetSubVector(vdofs, zval) + current_zmat.UseExternalData( + zval.GetData(), dof, self.num_equations) + mfem.AddMult_a_ABt(1.0, self.weakdiv[i], flux, current_zmat) + + # Apply inverse mass to auxiliary result to obtain the final result + current_ymat.SetSize(dof, self.num_equations) + mfem.Mult(self.invmass[i], current_zmat, current_ymat) + y.SetSubVector(vdofs, current_ymat.GetData()) - # Interpolate elfun at the point - elfun1_mat.MultTranspose(self.shape1, self.funval1) - elfun2_mat.MultTranspose(self.shape2, self.funval2) - Tr.Face.SetIntPoint(ip) - - # Get the normal vector and the flux on the face - - mfem.CalcOrtho(Tr.Face.Jacobian(), self.nor) - - mcs = self.rsolver.Eval( - self.funval1, self.funval2, self.nor, self.fluxN) - - # Update max char speed - if mcs > globals()['max_char_speed']: - globals()['max_char_speed'] = mcs - - self.fluxN *= ip.weight - - # - mat1A -= shape1A.transpose().dot(self.fluxNA) - mat2A += shape2A.transpose().dot(self.fluxNA) - ''' - for k in range(num_equation): - for s in range(dof1): - elvect1_mat[s, k] -= self.fluxN[k] * self.shape1[s] - for s in range(dof2): - elvect2_mat[s, k] += self.fluxN[k] * self.shape2[s] - ''' - - -class RiemannSolver(object): - def __init__(self): - num_equation = globals()['num_equation'] - self.flux1 = mfem.Vector(num_equation) - self.flux2 = mfem.Vector(num_equation) - - def Eval(self, state1, state2, nor, flux): - - # NOTE: nor in general is not a unit normal - dim = nor.Size() - - assert StateIsPhysical(state1, dim), "" - assert StateIsPhysical(state2, dim), "" - - maxE1 = ComputeMaxCharSpeed(state1, dim) - maxE2 = ComputeMaxCharSpeed(state2, dim) - maxE = max(maxE1, maxE2) - - ComputeFluxDotN(state1, nor, self.flux1) - ComputeFluxDotN(state2, nor, self.flux2) - - #normag = np.sqrt(np.sum(nor.GetDataArray()**2)) - normag = nor.Norml2() - - ''' - for i in range(num_equation): - flux[i] = (0.5 * (self.flux1[i] + self.flux2[i]) - - 0.5 * maxE * (state2[i] - state1[i]) * normag) - ''' - f = (0.5 * (self.flux1.GetDataArray() + self.flux2.GetDataArray()) - - 0.5 * maxE * (state2.GetDataArray() - state1.GetDataArray()) * normag) - flux.Assign(f) - - return maxE - - -def StateIsPhysical(state, dim): - specific_heat_ratio = globals()["specific_heat_ratio"] - - den = state[0] - #den_vel = state.GetDataArray()[1:1+dim] - den_energy = state[1 + dim] - - if (den < 0): - print("Negative density: " + str(state.GetDataArray())) - return False - if (den_energy <= 0): - print("Negative energy: " + str(state.GetDataArray())) - return False - - #den_vel2 = np.sum(den_vel**2)/den - den_vel2 = (state[1:1+dim].Norml2())**2/den - pres = (specific_heat_ratio - 1.0) * (den_energy - 0.5 * den_vel2) - if (pres <= 0): - print("Negative pressure: " + str(state.GetDataArray())) - return False - return True - - -class InitialCondition(mfem.VectorPyCoefficient): - def __init__(self, dim): - mfem.VectorPyCoefficient.__init__(self, dim) - - def EvalValue(self, x): - dim = x.shape[0] - assert dim == 2, "" - problem = globals()['problem'] - if (problem == 1): - # "Fast vortex" - radius = 0.2 - Minf = 0.5 - beta = 1. / 5. - elif (problem == 2): - # "Slow vortex" - radius = 0.2 - Minf = 0.05 - beta = 1. / 50. else: - assert False, "Cannot recognize problem. Options are: 1 - fast vortex, 2 - slow vortex" + # Apply block inverse mass + zval = mfem.Vector() # / z_loc, dof*num_eq + + # view of element auxiliary result, dof x num_eq + current_zmat = mfem.DenseMatrix() + current_ymat = mfem.DenseMatrix() # view of element result, dof x num_eq + + for i in range(self.vfes.GetNE()): + dof = self.vfes.GetFE(i).GetDof() + vdofs = mfem.intArray(self.vfes.GetElementVDofs(i)) + self.z.GetSubVector(vdofs, zval) + current_zmat.UseExternalData( + zval.GetData(), dof, self.num_equations) + current_ymat.SetSize(dof, self.num_equations) + mfem.Mult(self.invmass[i], current_zmat, current_ymat) + y.SetSubVector(vdofs, current_ymat.GetData()) + + self.max_char_speed = self.formIntegrator.GetMaxCharSpeed() + + def Update(self): + self.nonlinearForm.Update() + height = self.nonlinearForm.Height() + width = height + self.z.SetSize(height) + + ComputeInvMass() + if self.weakdiv is None: + self.ComputeWeakDivergence() + + +def GetMovingVortexInit(radius, Minf, beta, gas_constant, specific_heat_ratio): + def func(x, y): xc = 0.0 yc = 0.0 # Nice units vel_inf = 1. den_inf = 1. - specific_heat_ratio = globals()["specific_heat_ratio"] - gas_constant = globals()["gas_constant"] - pres_inf = (den_inf / specific_heat_ratio) * \ (vel_inf / Minf) * (vel_inf / Minf) temp_inf = pres_inf / (den_inf * gas_constant) @@ -370,74 +232,80 @@ def EvalValue(self, x): pres = den * gas_constant * temp energy = shrinv1 * pres / den + 0.5 * vel2 - y = np.array([den, den * velX, den * velY, den * energy]) - return y - - -def ComputePressure(state, dim): - den = state[0] - #den_vel = state.GetDataArray()[1:1+dim] - den_energy = state[1 + dim] - - specific_heat_ratio = globals()["specific_heat_ratio"] - #den_vel2 = np.sum(den_vel**2)/den - den_vel2 = (state[1:1+dim].Norml2())**2/den - pres = (specific_heat_ratio - 1.0) * (den_energy - 0.5 * den_vel2) - - return pres - - -def ComputeFlux(state, dim, flux): - den = state[0] - den_vel = state.GetDataArray()[1:1+dim] - den_energy = state[1 + dim] - - assert StateIsPhysical(state, dim), "" - - pres = ComputePressure(state, dim) - - den_vel2 = np.atleast_2d(den_vel) - fluxA = flux.GetDataArray() - fluxA[0, :] = den_vel - fluxA[1:1+dim, :] = den_vel2.transpose().dot(den_vel2) / den - for d in range(dim): - fluxA[1+d, d] += pres + y[0] = den + y[1] = den * velX + y[2] = den * velY + y[3] = den * energy - H = (den_energy + pres) / den - flux.GetDataArray()[1+dim, :] = den_vel * H + return func -def ComputeFluxDotN(state, nor, fluxN): - # NOTE: nor in general is not a unit normal - dim = nor.Size() - nor = nor.GetDataArray() - fluxN = fluxN.GetDataArray() +def EulerMesh(meshfile, problem): + if meshfile == '': + if problem in (1, 2, 3): + meshfile = "periodic-square.mesh" - den = state[0] - den_vel = state.GetDataArray()[1:1+dim] - den_energy = state[1 + dim] + elif problem == 4: + meshfile = "periodic-segment.mesh" - assert StateIsPhysical(state, dim), "" - - pres = ComputePressure(state, dim) - - den_velN = den_vel.dot(nor) - - fluxN[0] = den_velN - fluxN[1:1+dim] = den_velN * den_vel / den + pres * nor - - H = (den_energy + pres) / den - fluxN[1+dim] = den_velN * H - - -def ComputeMaxCharSpeed(state, dim): - specific_heat_ratio = globals()["specific_heat_ratio"] - - den = state[0] - den_vel2 = (state[1:1+dim].Norml2())**2/den - pres = ComputePressure(state, dim) - - sound = np.sqrt(specific_heat_ratio * pres / den) - vel = np.sqrt(den_vel2 / den) - - return vel + sound + else: + assert False, "Default mesh file not given for problem = " + \ + str(problem) + + meshfile = expanduser(join(dirname(__file__), '..', 'data', meshfile)) + + return mfem.Mesh(meshfile, 1, 1) + +# Initial condition + + +def EulerInitialCondition(problem, specific_heat_ratio, gas_constant): + + if problem == 1: + # fast moving vortex + func = GetMovingVortexInit(0.2, 0.5, 1. / 5., gas_constant, + specific_heat_ratio) + return mfem.jit.vector(vdim=4, interface="c++")(func) + + elif problem == 2: + # slow moving vortex + func = GetMovingVortexInit(0.2, 0.05, 1. / 50., gas_constant, + specific_heat_ratio) + return mfem.jit.vector(vdim=4, interface="c++")(func) + + elif problem == 3: + # moving sine wave + @ mfem.jit.vector(vdim=4, interface="c++") + def func(x, y): + assert len(x) > 2, "2D is not supportd for this probl" + density = 1.0 + 0.2 * np.sin(np.pi*(x[0]+x[1])) + velocity_x = 0.7 + velocity_y = 0.3 + pressure = 1.0 + energy = (pressure / (1.4 - 1.0) + + density * 0.5 * (velocity_x * velocity_x + velocity_y * velocity_y)) + + y[0] = density + y[1] = density * velocity_x + y[2] = density * velocity_y + y[3] = energy + + return func + + elif problem == 4: + @ mfem.jit.vector(vdim=3, interface="c++") + def func(x, y): + density = 1.0 + 0.2 * np.sin(np.pi * 2 * x[0]) + velocity_x = 1.0 + pressure = 1.0 + energy = pressure / (1.4 - 1.0) + density * \ + 0.5 * (velocity_x * velocity_x) + + y[0] = density + y[1] = density * velocity_x + y[2] = energy + + return func + + else: + assert False, "Problem Undefined" diff --git a/examples/ex18p.py b/examples/ex18p.py index ac57fe23..3253af54 100644 --- a/examples/ex18p.py +++ b/examples/ex18p.py @@ -3,10 +3,17 @@ This is a version of Example 18 with a simple adaptive mesh refinement loop. See c++ version in the MFEM library for more detail + Sample runs: + + mpirun -np 4 python ex18p.py -p 1 -rs 2 -rp 1 -o 1 -s 3 + mpirun -np 4 python ex18p.py -p 1 -rs 1 -rp 1 -o 3 -s 4 + mpirun -np 4 python ex18p.py -p 1 -rs 1 -rp 1 -o 5 -s 6 + mpirun -np 4 python ex18p.py -p 2 -rs 1 -rp 1 -o 1 -s 3 -mf + mpirun -np 4 python ex18p.py -p 2 -rs 1 -rp 1 -o 3 -s 3 -mf + ''' import mfem.par as mfem -from ex18_common import FE_Evolution, InitialCondition, RiemannSolver, DomainIntegrator, FaceIntegrator from mfem.common.arg_parser import ArgParser from os.path import expanduser, join, dirname @@ -14,8 +21,9 @@ from numpy import sqrt, pi, cos, sin, hypot, arctan2 from scipy.special import erfc -# Equation constant parameters.(using globals to share them with ex18_common) -import ex18_common +from ex18_common import (EulerMesh, + EulerInitialCondition, + DGHyperbolicConservationLaws) # 1. Initialize MPI.from mpi4py import MPI @@ -23,231 +31,287 @@ num_procs = MPI.COMM_WORLD.size myid = MPI.COMM_WORLD.rank -parser = ArgParser(description='Ex18p') -parser.add_argument('-m', '--mesh', - default='periodic-square.mesh', - action='store', type=str, - help='Mesh file to use.') -parser.add_argument('-p', '--problem', - action='store', default=1, type=int, - help='Problem setup to use. See options in velocity_function().') -parser.add_argument('-rs', '--refine_serial', - action='store', default=0, type=int, - help="Number of times to refine the mesh uniformly before parallel.") -parser.add_argument('-rp', '--refine_parallel', - action='store', default=1, type=int, - help="Number of times to refine the mesh uniformly after parallel.") -parser.add_argument('-o', '--order', - action='store', default=3, type=int, - help="Finite element order (polynomial degree)") -parser.add_argument('-s', '--ode_solver', - action='store', default=4, type=int, - help="ODE solver: 1 - Forward Euler,\n\t" + - " 2 - RK2 SSP, 3 - RK3 SSP, 4 - RK4, 6 - RK6.") -parser.add_argument('-tf', '--t_final', - action='store', default=2.0, type=float, - help="Final time; start time is 0.") -parser.add_argument("-dt", "--time_step", - action='store', default=-0.01, type=float, - help="Time step.") -parser.add_argument('-c', '--cfl_number', - action='store', default=0.3, type=float, - help="CFL number for timestep calculation.") -parser.add_argument('-vis', '--visualization', - action='store_true', - help='Enable GLVis visualization') -parser.add_argument('-vs', '--visualization-steps', - action='store', default=50, type=float, - help="Visualize every n-th timestep.") - -args = parser.parse_args() -mesh = args.mesh -ser_ref_levels = args.refine_serial -par_ref_levels = args.refine_parallel -order = args.order -ode_solver_type = args.ode_solver -t_final = args.t_final -dt = args.time_step -cfl = args.cfl_number -visualization = args.visualization -vis_steps = args.visualization_steps - -if myid == 0: - parser.print_options(args) - -device = mfem.Device('cpu') -if myid == 0: - device.Print() - -ex18_common.num_equation = 4 -ex18_common.specific_heat_ratio = 1.4 -ex18_common.gas_constant = 1.0 -ex18_common.problem = args.problem -num_equation = ex18_common.num_equation - - -# 3. Read the mesh from the given mesh file. This example requires a 2D -# periodic mesh, such as ../data/periodic-square.mesh. -meshfile = expanduser(join(dirname(__file__), '..', 'data', mesh)) -mesh = mfem.Mesh(meshfile, 1, 1) -dim = mesh.Dimension() - -# 4. Define the ODE solver used for time integration. Several explicit -# Runge-Kutta methods are available. -ode_solver = None -if ode_solver_type == 1: - ode_solver = mfem.ForwardEulerSolver() -elif ode_solver_type == 2: - ode_solver = mfem.RK2Solver(1.0) -elif ode_solver_type == 3: - ode_solver = mfem.RK3SSolver() -elif ode_solver_type == 4: - ode_solver = mfem.RK4Solver() -elif ode_solver_type == 6: - ode_solver = mfem.RK6Solver() -else: - print("Unknown ODE solver type: " + str(ode_solver_type)) - exit - -# 5. Refine the mesh in serial to increase the resolution. In this example -# we do 'ser_ref_levels' of uniform refinement, where 'ser_ref_levels' is -# a command-line parameter. -for lev in range(ser_ref_levels): - mesh.UniformRefinement() - -# 6. Define a parallel mesh by a partitioning of the serial mesh. Refine -# this mesh further in parallel to increase the resolution. Once the -# parallel mesh is defined, the serial mesh can be deleted. - -pmesh = mfem.ParMesh(MPI.COMM_WORLD, mesh) -del mesh -for lev in range(par_ref_levels): - pmesh.UniformRefinement() - -# 7. Define the discontinuous DG finite element space of the given -# polynomial order on the refined mesh. -fec = mfem.DG_FECollection(order, dim) -# Finite element space for a scalar (thermodynamic quantity) -fes = mfem.ParFiniteElementSpace(pmesh, fec) -# Finite element space for a mesh-dim vector quantity (momentum) -dfes = mfem.ParFiniteElementSpace(pmesh, fec, dim, mfem.Ordering.byNODES) -# Finite element space for all variables together (total thermodynamic state) -vfes = mfem.ParFiniteElementSpace( - pmesh, fec, num_equation, mfem.Ordering.byNODES) - -assert fes.GetOrdering() == mfem.Ordering.byNODES, "Ordering must be byNODES" -glob_size = vfes.GlobalTrueVSize() -if myid == 0: - print("Number of unknowns: " + str(glob_size)) - -# 8. Define the initial conditions, save the corresponding mesh and grid -# functions to a file. This can be opened with GLVis with the -gc option. -# The solution u has components {density, x-momentum, y-momentum, energy}. -# These are stored contiguously in the BlockVector u_block. - -offsets = [k*vfes.GetNDofs() for k in range(num_equation+1)] -offsets = mfem.intArray(offsets) -u_block = mfem.BlockVector(offsets) - -# Momentum grid function on dfes for visualization. -mom = mfem.ParGridFunction(dfes, u_block, offsets[1]) - -# Initialize the state. -u0 = InitialCondition(num_equation) -sol = mfem.ParGridFunction(vfes, u_block.GetData()) -sol.ProjectCoefficient(u0) - -smyid = '{:0>6d}'.format(myid) -pmesh.Print("vortex-mesh."+smyid, 8) -for k in range(num_equation): - uk = mfem.ParGridFunction(fes, u_block.GetBlock(k).GetData()) - sol_name = "vortex-" + str(k) + "-init."+smyid - uk.Save(sol_name, 8) - -# 9. Set up the nonlinear form corresponding to the DG discretization of the -# flux divergence, and assemble the corresponding mass matrix. -Aflux = mfem.MixedBilinearForm(dfes, fes) -Aflux.AddDomainIntegrator(DomainIntegrator(dim)) -Aflux.Assemble() - -A = mfem.ParNonlinearForm(vfes) -rsolver = RiemannSolver() -ii = FaceIntegrator(rsolver, dim) -A.AddInteriorFaceIntegrator(ii) - -# 10. Define the time-dependent evolution operator describing the ODE -# right-hand side, and perform time-integration (looping over the time -# iterations, ti, with a time-step dt). -euler = FE_Evolution(vfes, A, Aflux.SpMat()) - -if (visualization): - MPI.COMM_WORLD.Barrier() - sout = mfem.socketstream("localhost", 19916) - sout.send_text("parallel " + str(num_procs) + " " + str(myid)) - sout.precision(8) - sout.send_solution(pmesh, mom) - sout.send_text("pause") - sout.flush() + +def run(problem=1, + ser_ref_levels=0, + par_ref_levels=1, + order=3, + ode_solver_type=4, + t_final=2.0, + dt=-0.01, + cfl=0.3, + visualization=True, + vis_steps=50, + preassembleWeakDiv=False, + meshfile=''): + + specific_heat_ratio = 1.4 + gas_constant = 1.0 + IntOrderOffset = 1 + + device = mfem.Device('cpu') if myid == 0: - print("GLVis visualization paused.") - print(" Press space (in the GLVis window) to resume it.") - -# Determine the minimum element size. -my_hmin = 0 -if (cfl > 0): - my_hmin = min([pmesh.GetElementSize(i, 1) for i in range(pmesh.GetNE())]) -hmin = MPI.COMM_WORLD.allreduce(my_hmin, op=MPI.MIN) - -t = 0.0 -euler.SetTime(t) -ode_solver.Init(euler) -if (cfl > 0): - # Find a safe dt, using a temporary vector. Calling Mult() computes the - # maximum char speed at all quadrature points on all faces. - z = mfem.Vector(A.Width()) - A.Mult(sol, z) - max_char_speed = MPI.COMM_WORLD.allreduce( - ex18_common.max_char_speed, op=MPI.MAX) - ex18_common.max_char_speed = max_char_speed - dt = cfl * hmin / ex18_common.max_char_speed / (2*order+1) - -# Integrate in time. -done = False -ti = 0 -while not done: - dt_real = min(dt, t_final - t) - t, dt_real = ode_solver.Step(sol, t, dt_real) + device.Print() + + # 2. Read the mesh from the given mesh file. When the user does not provide + # mesh file, use the default mesh file for the problem. + + mesh = EulerMesh(meshfile, problem) + dim = mesh.Dimension() + num_equation = dim + 2 + + # Refine the mesh to increase the resolution. In this example we do + # 'ser_ref_levels' of uniform refinement, where 'ser_ref_levels' is a + # command-line parameter. + for lev in range(ser_ref_levels): + mesh.UniformRefinement() + + # Define a parallel mesh by a partitioning of the serial mesh. Refine this + # mesh further in parallel to increase the resolution. Once the parallel + # mesh is defined, the serial mesh can be deleted. + pmesh = mfem.ParMesh(MPI.COMM_WORLD, mesh) + del mesh + + # Refine the mesh to increase the resolution. In this example we do + # 'par_ref_levels' of uniform refinement, where 'par_ref_levels' is a + # command-line parameter. + for lev in range(par_ref_levels): + pmesh.UniformRefinement() + + # 3. Define the ODE solver used for time integration. Several explicit + # Runge-Kutta methods are available. + ode_solver = None + if ode_solver_type == 1: + ode_solver = mfem.ForwardEulerSolver() + elif ode_solver_type == 2: + ode_solver = mfem.RK2Solver(1.0) + elif ode_solver_type == 3: + ode_solver = mfem.RK3SSPSolver() + elif ode_solver_type == 4: + ode_solver = mfem.RK4Solver() + elif ode_solver_type == 6: + ode_solver = mfem.RK6Solver() + else: + print("Unknown ODE solver type: " + str(ode_solver_type)) + exit + + # 4. Define the discontinuous DG finite element space of the given + # polynomial order on the refined mesh. + fec = mfem.DG_FECollection(order, dim) + # Finite element space for a scalar (thermodynamic quantity) + fes = mfem.ParFiniteElementSpace(pmesh, fec) + # Finite element space for a mesh-dim vector quantity (momentum) + dfes = mfem.ParFiniteElementSpace(pmesh, fec, dim, mfem.Ordering.byNODES) + # Finite element space for all variables together (total thermodynamic state) + vfes = mfem.ParFiniteElementSpace( + pmesh, fec, num_equation, mfem.Ordering.byNODES) + + assert fes.GetOrdering() == mfem.Ordering.byNODES, "Ordering must be byNODES" + glob_size = vfes.GlobalTrueVSize() + if myid == 0: + print("Number of unknowns: " + str(glob_size)) + + # 5. Define the initial conditions, save the corresponding mesh and grid + # functions to files. These can be opened with GLVis using: + # "glvis -np 4 -m euler-mesh -g euler-1-init" (for x-momentum). + + # Initialize the state. + u0 = EulerInitialCondition(problem, + specific_heat_ratio, + gas_constant) + sol = mfem.ParGridFunction(vfes) + sol.ProjectCoefficient(u0) + + # (Python note): GridFunction pointing to the subset of vector FES. + # sol is Vector with dim*fes.GetNDofs() + # Since sol.GetDataArray() returns numpy array pointing to the data, we make + # Vector from a sub-vector of the returned numpy array and pass it to GridFunction + # constructor. + mom = mfem.GridFunction(dfes, mfem.Vector( + sol.GetDataArray()[fes.GetNDofs():])) + + # Output the initial solution. + smyid = '{:0>6d}'.format(myid) + pmesh.Print("euler-mesh."+smyid, 8) + for k in range(num_equation): + uk = mfem.ParGridFunction(fes, mfem.Vector( + sol.GetDataArray()[k*fes.GetNDofs():])) + sol_name = "euler-" + str(k) + "-init."+smyid + uk.Save(sol_name, 8) + + # 6. Set up the nonlinear form with euler flux and numerical flux + flux = mfem.EulerFlux(dim, specific_heat_ratio) + numericalFlux = mfem.RusanovFlux(flux) + formIntegrator = mfem.HyperbolicFormIntegrator( + numericalFlux, IntOrderOffset) + + euler = DGHyperbolicConservationLaws(vfes, formIntegrator, + preassembleWeakDivergence=preassembleWeakDiv) + + # 7. Visualize momentum with its magnitude + if (visualization): + MPI.COMM_WORLD.Barrier() + sout = mfem.socketstream("localhost", 19916) + sout.precision(8) + sout << "parallel " << str(num_procs) << " " << str(myid) << "\n" + sout << "solution\n" << pmesh << mom + sout << "window_title 'momentum, t = 0'\n" + sout << "view 0 0\n" # view from top + sout << "keys jlm\n" # turn off perspective and light, show mesh + sout << "pause\n" + sout.flush() + + if myid == 0: + print("GLVis visualization paused.") + print(" Press space (in the GLVis window) to resume it.") + # 8. Time integration + my_hmin = np.inf if (cfl > 0): + my_hmin = min([pmesh.GetElementSize(i, 1) + for i in range(pmesh.GetNE())]) + + hmin = MPI.COMM_WORLD.allreduce(my_hmin, op=MPI.MIN) + + # Find a safe dt, using a temporary vector. Calling Mult() computes the + # maximum char speed at all quadrature points on all faces (and all + # elements with -mf). + z = mfem.Vector(sol.Size()) + euler.Mult(sol, z) + + my_max_char_speed = euler.GetMaxCharSpeed() max_char_speed = MPI.COMM_WORLD.allreduce( - ex18_common.max_char_speed, op=MPI.MAX) - ex18_common.max_char_speed = max_char_speed - dt = cfl * hmin / ex18_common.max_char_speed / (2*order+1) + my_max_char_speed, op=MPI.MAX) + + dt = cfl * hmin / max_char_speed / (2 * order + 1) + + t = 0.0 + euler.SetTime(t) + ode_solver.Init(euler) + + # Integrate in time. + done = False + ti = 0 + while not done: + dt_real = min(dt, t_final - t) + t, dt_real = ode_solver.Step(sol, t, dt_real) + + if (cfl > 0): + my_max_char_speed = euler.GetMaxCharSpeed() + max_char_speed = MPI.COMM_WORLD.allreduce( + my_max_char_speed, op=MPI.MAX) + + dt = cfl * hmin / max_char_speed / (2*order+1) + + ti = ti+1 + done = (t >= t_final - 1e-8*dt) + if (done or ti % vis_steps == 0): + if myid == 0: + print("time step: " + str(ti) + ", time: " + "{:g}".format(t)) + if (visualization): + sout << "window_title 'momentum, t = " << "{:g}".format( + t) << "'\n" + sout << "parallel " << str( + num_procs) << " " << str(myid) << "\n" + sout << "solution\n" << pmesh << mom + sout.flush() - ti = ti+1 - done = (t >= t_final - 1e-8*dt) - if (done or ti % vis_steps == 0): + if myid == 0: + print("done") + + # 9. Save the final solution. This output can be viewed later using GLVis: + # "glvis -np 4 -m euler-mesh-final -g euler-1-final" (for x-momentum). + pmesh.Print("euler-mesh-final."+smyid, 8) + for k in range(num_equation): + uk = mfem.ParGridFunction(fes, mfem.Vector( + sol.GetDataArray()[k*fes.GetNDofs():])) + sol_name = "euler-" + str(k) + "-final."+smyid + uk.Save(sol_name, 8) + + # 10. Compute the L2 solution error summed for all components. + if True: + error = sol.ComputeLpError(2., u0) if myid == 0: - print("time step: " + str(ti) + ", time: " + "{:g}".format(t)) - if (visualization): - sout.send_text("parallel " + str(num_procs) + " " + str(myid)) - sout.send_solution(pmesh, mom) - sout.flush() - -if myid == 0: - print("done") - -# 11. Save the final solution. This output can be viewed later using GLVis: -# "glvis -np 4 -m vortex-mesh -g vortex-1-final". - -for k in range(num_equation): - uk = mfem.ParGridFunction(fes, u_block.GetBlock(k).GetData()) - sol_name = "vortex-" + str(k) + "-final."+smyid - uk.Save(sol_name, 8) - -# 12. Compute the L2 solution error summed for all components. -# if (t_final == 2.0): -if True: - error = sol.ComputeLpError(2., u0) + print("Solution error: " + "{:g}".format(error)) + + +if __name__ == "__main__": + + parser = ArgParser(description='Ex18p') + parser.add_argument('-m', '--mesh', + default='', + action='store', type=str, + help='Mesh file to use.') + parser.add_argument('-p', '--problem', + action='store', default=1, type=int, + help='Problem setup to use. See options in velocity_function().') + parser.add_argument('-rs', '--refine_serial', + action='store', default=0, type=int, + help="Number of times to refine the mesh uniformly before parallel.") + parser.add_argument('-rp', '--refine_parallel', + action='store', default=1, type=int, + help="Number of times to refine the mesh uniformly after parallel.") + parser.add_argument('-o', '--order', + action='store', default=3, type=int, + help="Finite element order (polynomial degree)") + parser.add_argument('-s', '--ode_solver', + action='store', default=4, type=int, + help="ODE solver: 1 - Forward Euler,\n\t" + + " 2 - RK2 SSP, 3 - RK3 SSP, 4 - RK4, 6 - RK6.") + parser.add_argument('-tf', '--t_final', + action='store', default=2.0, type=float, + help="Final time; start time is 0.") + parser.add_argument("-dt", "--time_step", + action='store', default=-0.01, type=float, + help="Time step.") + parser.add_argument('-c', '--cfl_number', + action='store', default=0.3, type=float, + help="CFL number for timestep calculation.") + parser.add_argument('-novis', '--no_visualization', + action='store_true', default=False, + help='Disable GLVis visualization') + parser.add_argument("-ea", "--element-assembly-divergence", + action='store_true', default=False, + help="Weak divergence assembly level\n" + + " ea - Element assembly with interpolated") + parser.add_argument("-mf", "--matrix-free-divergence", + action='store_true', default=False, + help="Weak divergence assembly level\n" + + " mf - Nonlinear assembly in matrix-free manner") + parser.add_argument('-vs', '--visualization-steps', + action='store', default=50, type=float, + help="Visualize every n-th timestep.") + + args = parser.parse_args() + + visualization = not args.no_visualization + + if (not args.matrix_free_divergence and + not args.element_assembly_divergence): + args.element_assembly_divergence = True + args.matrix_free_divergence = False + preassembleWeakDiv = True + + elif args.element_assembly_divergence: + args.matrix_free_divergence = False + preassembleWeakDiv = True + + elif args.matrix_free_divergence: + args.element_assembly_divergence = False + preassembleWeakDiv = False + if myid == 0: - print("Solution error: " + "{:g}".format(error)) + parser.print_options(args) + + run(problem=args.problem, + ser_ref_levels=args.refine_serial, + par_ref_levels=args.refine_parallel, + order=args.order, + ode_solver_type=args.ode_solver, + t_final=args.t_final, + dt=args.time_step, + cfl=args.cfl_number, + visualization=visualization, + vis_steps=args.visualization_steps, + preassembleWeakDiv=preassembleWeakDiv, + meshfile=args.mesh) diff --git a/mfem/_par/hyperbolic.i b/mfem/_par/hyperbolic.i new file mode 100644 index 00000000..a5627157 --- /dev/null +++ b/mfem/_par/hyperbolic.i @@ -0,0 +1,36 @@ +%module(package="mfem._par") hyperbolic +%feature("autodoc", "1"); + +%{ +#include "mfem.hpp" +#include "numpy/arrayobject.h" +#include "../common/io_stream.hpp" +#include "../common/pyoperator.hpp" +#include "../common/pycoefficient.hpp" +#include "../common/pyintrules.hpp" +#include "../common/pynonlininteg.hpp" +%} + +%include "../common/existing_mfem_headers.i" +#ifdef FILE_EXISTS_FEM_HYPERBOLIC + +%init %{ +import_array(); +%} + +%include "exception.i" +%include "std_string.i" +%include "../common/exception.i" + +%import "array.i" +%import "vector.i" +%import "densemat.i" +%import "eltrans.i" +%import "nonlininteg.i" + +%include "fem/hyperbolic.hpp" + +#endif + + + diff --git a/mfem/_par/setup.py b/mfem/_par/setup.py index f79bb32b..8240df93 100644 --- a/mfem/_par/setup.py +++ b/mfem/_par/setup.py @@ -122,7 +122,9 @@ def get_extensions(): "quadinterpolator", "quadinterpolator_face", "submesh", "transfermap", "staticcond","sidredatacollection", "psubmesh", "ptransfermap", "enzyme", - "attribute_sets", "arrays_by_name"] + "attribute_sets", "arrays_by_name", + "hyperbolic"] + if add_pumi == '1': from setup_local import puminc, pumilib modules.append("pumi") diff --git a/mfem/_ser/hyperbolic.i b/mfem/_ser/hyperbolic.i new file mode 100644 index 00000000..9047fab0 --- /dev/null +++ b/mfem/_ser/hyperbolic.i @@ -0,0 +1,36 @@ +%module(package="mfem._ser") hyperbolic +%feature("autodoc", "1"); + +%{ +#include "mfem.hpp" +#include "numpy/arrayobject.h" +#include "../common/io_stream.hpp" +#include "../common/pyoperator.hpp" +#include "../common/pycoefficient.hpp" +#include "../common/pyintrules.hpp" +#include "../common/pynonlininteg.hpp" +%} + +%include "../common/existing_mfem_headers.i" +#ifdef FILE_EXISTS_FEM_HYPERBOLIC + +%init %{ +import_array(); +%} + +%include "exception.i" +%include "std_string.i" +%include "../common/exception.i" + +%import "array.i" +%import "vector.i" +%import "densemat.i" +%import "eltrans.i" +%import "nonlininteg.i" + +%include "fem/hyperbolic.hpp" + +#endif + + + diff --git a/mfem/_ser/setup.py b/mfem/_ser/setup.py index 3bf2954c..3fe1afbd 100644 --- a/mfem/_ser/setup.py +++ b/mfem/_ser/setup.py @@ -104,7 +104,8 @@ def get_extensions(): "quadinterpolator", "quadinterpolator_face", "submesh", "transfermap", "staticcond", "sidredatacollection", "enzyme", - "attribute_sets", "arrays_by_name"] + "attribute_sets", "arrays_by_name", + "hyperbolic"] if add_cuda == '1': from setup_local import cudainc diff --git a/mfem/common/bilininteg_ext.i b/mfem/common/bilininteg_ext.i index b22c7f79..99821135 100644 --- a/mfem/common/bilininteg_ext.i +++ b/mfem/common/bilininteg_ext.i @@ -255,6 +255,10 @@ namespace mfem { self._coeff = args %} +%pythonappend ElasticityComponentIntegrator::ElasticityComponentIntegrator %{ + self._coeff = parent_ +%} + %pythonappend DGTraceIntegrator::DGTraceIntegrator %{ self._coeff = args %} diff --git a/mfem/par.py b/mfem/par.py index 47b667c7..f929a496 100644 --- a/mfem/par.py +++ b/mfem/par.py @@ -93,6 +93,7 @@ from mfem._par.psubmesh import * from mfem._par.transfermap import * from mfem._par.ptransfermap import * +from mfem._par.hyperbolic import * try: from mfem._par.gslib import * diff --git a/mfem/ser.py b/mfem/ser.py index 3ff0d3ca..fcf04fcc 100644 --- a/mfem/ser.py +++ b/mfem/ser.py @@ -72,9 +72,9 @@ from mfem._ser.fe_nurbs import * from mfem._ser.doftrans import * from mfem._ser.std_vectors import * - from mfem._ser.submesh import * from mfem._ser.transfermap import * +from mfem._ser.hyperbolic import * try: from mfem._ser.gslib import *