From 8c59736bc03999c7c5eb71dd6efb51cac6239660 Mon Sep 17 00:00:00 2001 From: Ken Lauer <152229072+ken-lauer@users.noreply.github.com> Date: Wed, 11 Sep 2024 14:55:31 -0700 Subject: [PATCH] WIP: swap axes for now --- docs/examples/wavefront.ipynb | 63 ++++++++++++++++++----------------- pmd_beamphysics/wavefront.py | 20 +++++++---- 2 files changed, 46 insertions(+), 37 deletions(-) diff --git a/docs/examples/wavefront.ipynb b/docs/examples/wavefront.ipynb index 4db23a9..1bf9896 100644 --- a/docs/examples/wavefront.ipynb +++ b/docs/examples/wavefront.ipynb @@ -41,28 +41,14 @@ "outputs": [], "source": [ "W = Wavefront.gaussian_pulse(\n", - " dims=(801, 101, 101),\n", + " dims=(101, 101, 801),\n", " wavelength=1.35e-8,\n", - " grid_spacing=(0.0625, 6e-6, 6e-6),\n", - " pad=(40, 100, 100),\n", + " grid_spacing=(6e-6, 6e-6, 0.0625),\n", + " pad=(100, 100, 40),\n", " nphotons=1e12,\n", " zR=2.0,\n", " sigma_t=5,\n", - ")\n", - "\n", - "\n", - "## TODO: above is actually zxy, so move Z to end (+ grid spacing etc)\n", - "## TODO: Genesis is actually t, x, y - so does xyz make sense?" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "40fe9987-33dd-4509-a570-66ef6e9182b4", - "metadata": {}, - "outputs": [], - "source": [ - "W.metadata.mesh.axis_labels" + ")" ] }, { @@ -105,6 +91,23 @@ "print(W.rmesh.shape)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "dce5b957-e23d-4f28-8fd8-cfa9a2101cf4", + "metadata": {}, + "outputs": [], + "source": [ + "def get_longitudinal_slice(arr):\n", + " nx, ny, _ = arr.shape\n", + " return arr[int(nx / 2), int(ny / 2), :]\n", + "\n", + "\n", + "def get_transverse_slice(arr):\n", + " _, _, nz = arr.shape\n", + " return arr[:, :, int(nz / 2)]" + ] + }, { "cell_type": "code", "execution_count": null, @@ -113,7 +116,7 @@ "outputs": [], "source": [ "plt.figure(figsize=(4, 2))\n", - "plt.plot(W.rspace_domain[0], np.abs(W.rmesh[:, int(nx / 2), int(ny / 2)]), color=\"b\")\n", + "plt.plot(W.rspace_domain[2], np.abs(get_longitudinal_slice(W.rmesh)), color=\"b\")\n", "plt.xlabel(\"Time [fs]\");" ] }, @@ -125,7 +128,7 @@ "outputs": [], "source": [ "plt.figure(figsize=(4, 4))\n", - "plt.imshow(np.abs(W.rmesh[int(nt / 2), :, :]) ** 2)\n", + "plt.imshow(np.abs(get_transverse_slice(W.rmesh) ** 2))\n", "plt.colorbar();" ] }, @@ -136,9 +139,8 @@ "metadata": {}, "outputs": [], "source": [ - "nw, nkx, nky = W.kmesh.shape\n", "plt.figure(figsize=(4, 2))\n", - "kspace_slice = np.abs(W.kmesh[:, int(nkx / 2), int(nky / 2)])\n", + "kspace_slice = np.abs(get_longitudinal_slice(W.kmesh))\n", "plt.plot(kspace_slice, \".-\", color=\"b\")\n", "peak_x = np.argmax(kspace_slice)\n", "plt.xlim(peak_x - 20, peak_x + 20);" @@ -162,7 +164,7 @@ "outputs": [], "source": [ "plt.figure(figsize=(4, 4))\n", - "kspace_image = np.abs(W.kmesh[int(nw / 2), :, :]) ** 2\n", + "kspace_image = np.abs(get_transverse_slice(W.kmesh)) ** 2\n", "\n", "center = np.argmax(kspace_image, axis=0)[0]\n", "plt.imshow(kspace_image)\n", @@ -188,8 +190,8 @@ "drifted_w = W.drift(3.0)\n", "fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))\n", "\n", - "rspace_abs_orig = np.abs(W.rmesh[int(nt / 2), :, :]) ** 2\n", - "rspace_abs_prop = np.abs(drifted_w.rmesh[int(nt / 2), :, :]) ** 2\n", + "rspace_abs_orig = np.abs(get_transverse_slice(W.rmesh)) ** 2\n", + "rspace_abs_prop = np.abs(get_transverse_slice(drifted_w.rmesh)) ** 2\n", "vmin = np.min((np.min(rspace_abs_orig), np.min(rspace_abs_prop)))\n", "vmax = np.min((np.max(rspace_abs_orig), np.max(rspace_abs_prop)))\n", "im1 = ax1.imshow(rspace_abs_orig, vmin=vmin, vmax=vmax)\n", @@ -236,10 +238,10 @@ "source": [ "zR = 2.0\n", "X = Wavefront.gaussian_pulse(\n", - " dims=(801, 101, 101),\n", + " dims=(101, 101, 801),\n", " wavelength=1.35e-8,\n", - " grid_spacing=(0.0625, 6e-6, 6e-6),\n", - " pad=(40, 100, 100),\n", + " grid_spacing=(6e-6, 6e-6, 0.0625),\n", + " pad=(100, 100, 40),\n", " nphotons=1e12,\n", " zR=zR,\n", " sigma_t=5,\n", @@ -249,7 +251,8 @@ "zgrid = 5\n", "dz = 0.25\n", "\n", - "wfz = np.zeros((mz * zgrid, X.pad.grid[1], X.pad.grid[2]))\n", + "nx, ny = get_transverse_slice(X.rmesh).shape\n", + "wfz = np.zeros((mz * zgrid, nx, ny))\n", "fwhmz_fit = np.zeros(mz * zgrid)\n", "\n", "w0 = np.sqrt(zR * X.wavelength / np.pi)\n", @@ -269,7 +272,7 @@ " if zi > 0:\n", " print(\"Propagating to: \", zi * dz)\n", " X.drift(dz, inplace=True)\n", - " wf = np.abs(X.rmesh[int(nt / 2), :, :]) ** 2\n", + " wf = np.abs(get_transverse_slice(X.rmesh)) ** 2\n", "\n", " popt_gaussian, ydata_fit, FWHM, roots = gaussian_fit(\n", " domain_x,\n", diff --git a/pmd_beamphysics/wavefront.py b/pmd_beamphysics/wavefront.py index b43750b..fea2179 100644 --- a/pmd_beamphysics/wavefront.py +++ b/pmd_beamphysics/wavefront.py @@ -369,7 +369,7 @@ def conversion_coeffs(wavelength: float, dim: int) -> Tuple[float, ...]: """ k0 = calculate_k0(wavelength) hbar = scipy.constants.hbar / scipy.constants.e * 1.0e15 # fs-eV - return tuple([2.0 * np.pi * hbar] + [2.0 * np.pi / k0] * (dim - 1)) + return tuple([2.0 * np.pi / k0] * (dim - 1) + [2.0 * np.pi * hbar]) def cartesian_domain( @@ -421,11 +421,12 @@ def drift_propagator_paraxial( wavelength: float, ): """Fresnel propagator in paraxial approximation to distance z [m].""" - return kmesh * drift_kernel_paraxial( + kernel = drift_kernel_paraxial( transverse_kspace_grid=transverse_kspace_grid, z=z, wavelength=wavelength, ) + return kmesh * kernel[:, :, np.newaxis] def thin_lens_kernel_xy( @@ -454,7 +455,7 @@ def thin_lens_kernel_xy( np.ndarray """ k0 = calculate_k0(wavelength) - xx, yy = nd_space_mesh(ranges[1:], grid[1:]) + xx, yy = nd_space_mesh(ranges[:2], grid[:2]) return np.exp(-1j * k0 / 2.0 * (xx**2 / f_lens_x + yy**2 / f_lens_y)) @@ -513,7 +514,9 @@ def create_gaussian_pulse_3d_with_q( eta = 2.0 * k0 * zR * sigma_t / np.sqrt(np.pi) pulse = np.sqrt(eta) * np.sqrt(nphotons) * ux * uy * ut - return pulse.astype(dtype) + + # TODO a bit of swapping without changing the gaussian pulse + return np.moveaxis(pulse.astype(dtype), 0, -1) def transverse_divergence_padding_factor( @@ -860,13 +863,16 @@ def gaussian_pulse( ------- Wavefront """ + # TODO a bit of swapping without changing the gaussian pulse + nx, ny, nz = dims + gx, gy, gz = grid_spacing pulse = create_gaussian_pulse_3d_with_q( wavelength=wavelength, nphotons=nphotons, zR=zR, sigma_t=sigma_t, - grid_spacing=grid_spacing, - grid=dims, + grid_spacing=(gz, gx, gy), + grid=(nz, nx, ny), dtype=dtype, ) return cls( @@ -874,7 +880,7 @@ def gaussian_pulse( wavelength=wavelength, grid_spacing=grid_spacing, pad=pad, - axis_labels="zxy", + axis_labels="xyz", longitudinal_axis="z", )