From a8fccc3c8699c195cede6e96f883b708a022c183 Mon Sep 17 00:00:00 2001 From: SaltyChiang Date: Wed, 11 Dec 2024 03:33:37 +0800 Subject: [PATCH] Rename `cb2` to `evenodd`. Bug fix for GPT interface. --- pyquda_core/pyquda/__init__.py | 51 ++++++++++++++---------- pyquda_core/pyquda/_version.py | 2 +- pyquda_core/pyquda/field.py | 35 +++++++++------- pyquda_utils/core.py | 5 ++- pyquda_utils/deprecated.py | 8 ++-- pyquda_utils/gpt.py | 15 ++++--- pyquda_utils/io/__init__.py | 32 +++++++-------- pyquda_utils/phase.py | 5 +-- pyquda_utils/quasi_axial_gauge_fixing.py | 5 +-- pyquda_utils/source.py | 2 +- tests/test.dslash.py | 2 +- tests/test.laplace.py | 2 +- 12 files changed, 91 insertions(+), 73 deletions(-) diff --git a/pyquda_core/pyquda/__init__.py b/pyquda_core/pyquda/__init__.py index 5bee466..be65fd7 100644 --- a/pyquda_core/pyquda/__init__.py +++ b/pyquda_core/pyquda/__init__.py @@ -6,7 +6,6 @@ from mpi4py import MPI from ._version import __version__ # noqa: F401 -from . import pyquda as quda from .field import LatticeInfo @@ -141,7 +140,7 @@ def _getDefaultGrid(mpi_size: int, latt_size: List[int]): return min(min_grid) -def _initEnviron(**kwargs): +def _setEnviron(**kwargs): def _setEnviron(env, key, value): if value is not None: if env in environ: @@ -154,7 +153,7 @@ def _setEnviron(env, key, value): _setEnviron(f"QUDA_{key.upper()}", key, kwargs[key]) -def _initEnvironWarn(**kwargs): +def _setEnvironWarn(**kwargs): def _setEnviron(env, key, value): if value is not None: if env in environ: @@ -172,9 +171,8 @@ def _setEnviron(env, key, value): def initGPU(backend: Literal["numpy", "cupy", "torch"] = None, gpuid: int = -1): global _CUDA_BACKEND, _HIP, _GPUID, _COMPUTE_CAPABILITY - if isGridInitialized(): - _MPI_LOGGER.critical("initGPU should be called before init", RuntimeError) + _MPI_LOGGER.critical("initGPU should be called before initGrid", RuntimeError) if _GPUID < 0: from platform import node as gethostname @@ -239,17 +237,31 @@ def initGPU(backend: Literal["numpy", "cupy", "torch"] = None, gpuid: int = -1): _MPI_LOGGER.warning("GPU is already initialized", RuntimeWarning) -def initQUDA(grid_size: List[int], gpuid: int): +def initGrid(grid_size: List[int]): + global _GRID_SIZE, _GRID_COORD + if _GRID_SIZE is None: + Gx, Gy, Gz, Gt = grid_size + if _MPI_SIZE != Gx * Gy * Gz * Gt: + _MPI_LOGGER.critical(f"The MPI size {_MPI_SIZE} does not match the grid size {grid_size}", ValueError) + _GRID_SIZE = [Gx, Gy, Gz, Gt] + _GRID_COORD = getCoordFromRank(_MPI_RANK, _GRID_SIZE) + _MPI_LOGGER.info(f"Using the grid size {_GRID_SIZE}") + else: + _MPI_LOGGER.warning("Grid is already initialized", RuntimeWarning) + + +def initQUDA(grid_size: List[int], gpuid: int, use_quda_allocator: bool = False): import atexit + from . import pyquda as quda, malloc_pyquda - # if _CUDA_BACKEND == "cupy": - # import cupy - # from . import malloc_pyquda + if use_quda_allocator: + if _CUDA_BACKEND == "cupy": + import cupy - # allocator = cupy.cuda.PythonFunctionAllocator( - # malloc_pyquda.pyquda_device_malloc, malloc_pyquda.pyquda_device_free - # ) - # cupy.cuda.set_allocator(allocator.malloc) + allocator = cupy.cuda.PythonFunctionAllocator( + malloc_pyquda.pyquda_device_malloc, malloc_pyquda.pyquda_device_free + ) + cupy.cuda.set_allocator(allocator.malloc) quda.initCommsGridQuda(4, grid_size) quda.initQuda(gpuid) @@ -293,7 +305,7 @@ def init( """ Initialize MPI along with the QUDA library. """ - global _GRID_SIZE, _GRID_COORD, _DEFAULT_LATTICE + global _DEFAULT_LATTICE if _GRID_SIZE is None: initGPU(backend) @@ -301,20 +313,15 @@ def init( use_default_latt = latt_size is not None and t_boundary is not None and anisotropy is not None if use_default_grid: grid_size = _getDefaultGrid(_MPI_SIZE, latt_size) - Gx, Gy, Gz, Gt = grid_size if grid_size is not None else [1, 1, 1, 1] - if _MPI_SIZE != Gx * Gy * Gz * Gt: - _MPI_LOGGER.critical(f"The MPI size {_MPI_SIZE} does not match the grid size {grid_size}", ValueError) - _GRID_SIZE = [Gx, Gy, Gz, Gt] - _GRID_COORD = getCoordFromRank(_MPI_RANK, _GRID_SIZE) - _MPI_LOGGER.info(f"Using the grid size {_GRID_SIZE}") + initGrid(grid_size if grid_size is not None else [1, 1, 1, 1]) if use_default_grid and not use_default_latt: _MPI_LOGGER.info(f"Using the lattice size {latt_size} only for getting the default grid size {_GRID_SIZE}") if use_default_latt: _DEFAULT_LATTICE = LatticeInfo(latt_size, t_boundary, anisotropy) _MPI_LOGGER.info(f"Using the default lattice LatticeInfo({latt_size}, {t_boundary}, {anisotropy})") - _initEnvironWarn(resource_path=resource_path if resource_path != "" else None) - _initEnviron( + _setEnvironWarn(resource_path=resource_path if resource_path != "" else None) + _setEnviron( rank_verbosity=",".join(rank_verbosity) if rank_verbosity != [0] else None, enable_mps="1" if enable_mps else None, enable_gdr="1" if enable_gdr else None, diff --git a/pyquda_core/pyquda/_version.py b/pyquda_core/pyquda/_version.py index c0984d5..43e38a8 100644 --- a/pyquda_core/pyquda/_version.py +++ b/pyquda_core/pyquda/_version.py @@ -1 +1 @@ -__version__ = "0.9.10" +__version__ = "0.9.11" diff --git a/pyquda_core/pyquda/field.py b/pyquda_core/pyquda/field.py index 44a3549..cf9ead4 100644 --- a/pyquda_core/pyquda/field.py +++ b/pyquda_core/pyquda/field.py @@ -130,40 +130,47 @@ def lexico(data: numpy.ndarray, axes: List[int], dtype=None): Npre = int(numpy.prod(shape[: axes[0]])) Nsuf = int(numpy.prod(shape[axes[-1] + 1 :])) dtype = data.dtype if dtype is None else dtype - data_cb2 = data.reshape(Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf) + data_evenodd = data.reshape(Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf) data_lexico = numpy.zeros((Npre, Lt, Lz, Ly, Lx, Nsuf), dtype) for t in range(Lt): for z in range(Lz): for y in range(Ly): eo = (t + z + y) % 2 if eo == 0: - data_lexico[:, t, z, y, 0::2] = data_cb2[:, 0, t, z, y, :] - data_lexico[:, t, z, y, 1::2] = data_cb2[:, 1, t, z, y, :] + data_lexico[:, t, z, y, 0::2] = data_evenodd[:, 0, t, z, y, :] + data_lexico[:, t, z, y, 1::2] = data_evenodd[:, 1, t, z, y, :] else: - data_lexico[:, t, z, y, 1::2] = data_cb2[:, 0, t, z, y, :] - data_lexico[:, t, z, y, 0::2] = data_cb2[:, 1, t, z, y, :] + data_lexico[:, t, z, y, 1::2] = data_evenodd[:, 0, t, z, y, :] + data_lexico[:, t, z, y, 0::2] = data_evenodd[:, 1, t, z, y, :] return data_lexico.reshape(*shape[: axes[0]], Lt, Lz, Ly, Lx, *shape[axes[-1] + 1 :]) -def cb2(data: numpy.ndarray, axes: List[int], dtype=None): +def evenodd(data: numpy.ndarray, axes: List[int], dtype=None): shape = data.shape Lt, Lz, Ly, Lx = [shape[axis] for axis in axes] Npre = int(numpy.prod(shape[: axes[0]])) Nsuf = int(numpy.prod(shape[axes[-1] + 1 :])) dtype = data.dtype if dtype is None else dtype data_lexico = data.reshape(Npre, Lt, Lz, Ly, Lx, Nsuf) - data_cb2 = numpy.zeros((Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf), dtype) + data_evenodd = numpy.zeros((Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf), dtype) for t in range(Lt): for z in range(Lz): for y in range(Ly): eo = (t + z + y) % 2 if eo == 0: - data_cb2[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 0::2] - data_cb2[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 1::2] + data_evenodd[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 0::2] + data_evenodd[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 1::2] else: - data_cb2[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 1::2] - data_cb2[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 0::2] - return data_cb2.reshape(*shape[: axes[0]], 2, Lt, Lz, Ly, Lx // 2, *shape[axes[-1] + 1 :]) + data_evenodd[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 1::2] + data_evenodd[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 0::2] + return data_evenodd.reshape(*shape[: axes[0]], 2, Lt, Lz, Ly, Lx // 2, *shape[axes[-1] + 1 :]) + + +def cb2(data: numpy.ndarray, axes: List[int], dtype=None): + from . import getLogger + + getLogger().warning("cb2 is deprecated, use evenodd instead", DeprecationWarning) + return evenodd(data, axes, dtype) def checksum(latt_info: Union[LatticeInfo, GeneralInfo], data: numpy.ndarray) -> Tuple[int, int]: @@ -675,9 +682,9 @@ def load( if Nc is not None: latt_info.Nc = Nc if not issubclass(cls, MultiField): - retval = cls(latt_info, cb2(value, [0, 1, 2, 3])) + retval = cls(latt_info, evenodd(value, [0, 1, 2, 3])) else: - retval = cls(latt_info, len(label), numpy.asarray([cb2(data, [0, 1, 2, 3]) for data in value])) + retval = cls(latt_info, len(label), numpy.asarray([evenodd(data, [0, 1, 2, 3]) for data in value])) secs = perf_counter() - s getLogger().debug(f"Loaded {filename} in {secs:.3f} secs, {gbytes / secs:.3f} GB/s") return retval diff --git a/pyquda_utils/core.py b/pyquda_utils/core.py index de2dd4a..0ba13d4 100644 --- a/pyquda_utils/core.py +++ b/pyquda_utils/core.py @@ -4,6 +4,7 @@ from pyquda import ( initGPU, + initGrid, initQUDA, init, getCoordFromRank, @@ -15,6 +16,7 @@ getGridCoord, setDefaultLattice, getDefaultLattice, + getCUDABackend, getLogger, setLoggerLevel, dirac as fermion, @@ -41,7 +43,8 @@ LatticePropagator, LatticeStaggeredPropagator, lexico, - cb2, + evenodd, + evenodd as cb2, ) from pyquda.dirac.abstract import Multigrid, FermionDirac, StaggeredFermionDirac diff --git a/pyquda_utils/deprecated.py b/pyquda_utils/deprecated.py index d345804..53c6539 100644 --- a/pyquda_utils/deprecated.py +++ b/pyquda_utils/deprecated.py @@ -1,6 +1,6 @@ from typing import List -from pyquda import getLogger, getGridSize, quda, enum_quda +from pyquda import getLogger, getGridSize, pyquda as quda, enum_quda from pyquda.field import LatticeFermion, LatticeGauge, LatticeInfo, LatticePropagator, Nc, Ns from pyquda.dirac.abstract import FermionDirac @@ -101,11 +101,11 @@ def getDslash( latt_info = LatticeInfo([Lx, Ly, Lz, Lt], t_boundary, xi) if clover_csw != 0.0: - from .dirac.clover_wilson import CloverWilsonDirac + from pyquda.dirac.clover_wilson import CloverWilsonDirac return CloverWilsonDirac(latt_info, mass, tol, maxiter, clover_csw, clover_xi, geo_block_size) else: - from .dirac.wilson import WilsonDirac + from pyquda.dirac.wilson import WilsonDirac return WilsonDirac(latt_info, mass, tol, maxiter, geo_block_size) @@ -131,6 +131,6 @@ def getStaggeredDslash( t_boundary = 1 latt_info = LatticeInfo([Lx, Ly, Lz, Lt], t_boundary, 1.0) - from .dirac.hisq import HISQDirac + from pyquda.dirac.hisq import HISQDirac return HISQDirac(latt_info, mass, tol, maxiter, naik_epsilon, None) diff --git a/pyquda_utils/gpt.py b/pyquda_utils/gpt.py index a5d4f3c..951e765 100644 --- a/pyquda_utils/gpt.py +++ b/pyquda_utils/gpt.py @@ -1,15 +1,17 @@ from typing import List import numpy -from pyquda import getSublatticeSize, getGridSize -from pyquda.field import cb2, LatticeGauge, LatticeInfo, LatticePropagator +from .core import evenodd, getGridSize, LatticeGauge, LatticeInfo, LatticePropagator import gpt as g def LatticeInfoGPT(grid: g.grid, gen_simd_width: int): assert getGridSize() == grid.mpi - sublatt_size = getSublatticeSize(grid.fdimensions, grid.mpi) + GLx, GLy, GLz, GLt = grid.fdimensions + Gx, Gy, Gz, Gt = grid.mpi + Lx, Ly, Lz, Lt = GLx // Gx, GLy // Gy, GLz // Gz, GLt // Gt + sublatt_size = [Lx, Ly, Lz, Lt] Nd = len(sublatt_size) precision = grid.precision.nbytes n_simd = gen_simd_width // (2 * precision) @@ -32,7 +34,7 @@ def LatticeGaugeGPT(lattice: List[g.lattice], gen_simd_width: int, gauge: Lattic value = [] for index in range(latt_info.Nd): value.append( - cb2( + evenodd( numpy.asarray(lattice[index].mview()[0]) .view(f"xab", v, cupy.exp(1j * (gi * Li) / GLi * w)) for i in range(1, Li): rotate[i] = contract("xba,xbc,xc->xac", gauge_prod[i - 1].conj(), v, cupy.exp(1j * (i + gi * Li) / GLi * w)) - rotate = LatticeLink(gauge.latt_info, cb2(rotate.reshape(*axes_shape).transpose(*axes).get(), [0, 1, 2, 3])) + rotate = LatticeLink(gauge.latt_info, evenodd(rotate.reshape(*axes_shape).transpose(*axes).get(), [0, 1, 2, 3])) rotate.toDevice() rotate_ = LatticeFermion(gauge.latt_info) rotate.pack(rotate_) diff --git a/pyquda_utils/source.py b/pyquda_utils/source.py index f966cd7..430ceb6 100644 --- a/pyquda_utils/source.py +++ b/pyquda_utils/source.py @@ -1,6 +1,6 @@ from typing import List, Literal, Union -from pyquda import quda, getGridSize, getLogger +from pyquda import pyquda as quda, getGridSize, getLogger from pyquda.enum_quda import QudaDslashType, QudaParity from pyquda.field import ( Ns, diff --git a/tests/test.dslash.py b/tests/test.dslash.py index d5b314e..f595fdb 100644 --- a/tests/test.dslash.py +++ b/tests/test.dslash.py @@ -3,7 +3,7 @@ from check_pyquda import weak_field -from pyquda import init, quda +from pyquda import init, pyquda as quda from pyquda.field import Ns, Nc from pyquda.enum_quda import QudaParity from pyquda_utils import core diff --git a/tests/test.laplace.py b/tests/test.laplace.py index c368e55..1ac13ea 100644 --- a/tests/test.laplace.py +++ b/tests/test.laplace.py @@ -7,7 +7,7 @@ from check_pyquda import weak_field -from pyquda import enum_quda, quda +from pyquda import enum_quda, pyquda as quda from pyquda.field import LatticeGauge, LatticeInfo, LatticeStaggeredFermion, MultiLatticeStaggeredFermion, Nc from pyquda.dirac import setGlobalPrecision from pyquda_utils import core, io