Skip to content

Commit

Permalink
WIP: swap axes for now
Browse files Browse the repository at this point in the history
  • Loading branch information
ken-lauer committed Sep 11, 2024
1 parent 9a0ef40 commit 8c59736
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 37 deletions.
63 changes: 33 additions & 30 deletions docs/examples/wavefront.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,14 @@
"outputs": [],
"source": [
"W = Wavefront.gaussian_pulse(\n",
" dims=(801, 101, 101),\n",
" dims=(101, 101, 801),\n",
" wavelength=1.35e-8,\n",
" grid_spacing=(0.0625, 6e-6, 6e-6),\n",
" pad=(40, 100, 100),\n",
" grid_spacing=(6e-6, 6e-6, 0.0625),\n",
" pad=(100, 100, 40),\n",
" nphotons=1e12,\n",
" zR=2.0,\n",
" sigma_t=5,\n",
")\n",
"\n",
"\n",
"## TODO: above is actually zxy, so move Z to end (+ grid spacing etc)\n",
"## TODO: Genesis is actually t, x, y - so does xyz make sense?"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "40fe9987-33dd-4509-a570-66ef6e9182b4",
"metadata": {},
"outputs": [],
"source": [
"W.metadata.mesh.axis_labels"
")"
]
},
{
Expand Down Expand Up @@ -105,6 +91,23 @@
"print(W.rmesh.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dce5b957-e23d-4f28-8fd8-cfa9a2101cf4",
"metadata": {},
"outputs": [],
"source": [
"def get_longitudinal_slice(arr):\n",
" nx, ny, _ = arr.shape\n",
" return arr[int(nx / 2), int(ny / 2), :]\n",
"\n",
"\n",
"def get_transverse_slice(arr):\n",
" _, _, nz = arr.shape\n",
" return arr[:, :, int(nz / 2)]"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -113,7 +116,7 @@
"outputs": [],
"source": [
"plt.figure(figsize=(4, 2))\n",
"plt.plot(W.rspace_domain[0], np.abs(W.rmesh[:, int(nx / 2), int(ny / 2)]), color=\"b\")\n",
"plt.plot(W.rspace_domain[2], np.abs(get_longitudinal_slice(W.rmesh)), color=\"b\")\n",
"plt.xlabel(\"Time [fs]\");"
]
},
Expand All @@ -125,7 +128,7 @@
"outputs": [],
"source": [
"plt.figure(figsize=(4, 4))\n",
"plt.imshow(np.abs(W.rmesh[int(nt / 2), :, :]) ** 2)\n",
"plt.imshow(np.abs(get_transverse_slice(W.rmesh) ** 2))\n",
"plt.colorbar();"
]
},
Expand All @@ -136,9 +139,8 @@
"metadata": {},
"outputs": [],
"source": [
"nw, nkx, nky = W.kmesh.shape\n",
"plt.figure(figsize=(4, 2))\n",
"kspace_slice = np.abs(W.kmesh[:, int(nkx / 2), int(nky / 2)])\n",
"kspace_slice = np.abs(get_longitudinal_slice(W.kmesh))\n",
"plt.plot(kspace_slice, \".-\", color=\"b\")\n",
"peak_x = np.argmax(kspace_slice)\n",
"plt.xlim(peak_x - 20, peak_x + 20);"
Expand All @@ -162,7 +164,7 @@
"outputs": [],
"source": [
"plt.figure(figsize=(4, 4))\n",
"kspace_image = np.abs(W.kmesh[int(nw / 2), :, :]) ** 2\n",
"kspace_image = np.abs(get_transverse_slice(W.kmesh)) ** 2\n",
"\n",
"center = np.argmax(kspace_image, axis=0)[0]\n",
"plt.imshow(kspace_image)\n",
Expand All @@ -188,8 +190,8 @@
"drifted_w = W.drift(3.0)\n",
"fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 4))\n",
"\n",
"rspace_abs_orig = np.abs(W.rmesh[int(nt / 2), :, :]) ** 2\n",
"rspace_abs_prop = np.abs(drifted_w.rmesh[int(nt / 2), :, :]) ** 2\n",
"rspace_abs_orig = np.abs(get_transverse_slice(W.rmesh)) ** 2\n",
"rspace_abs_prop = np.abs(get_transverse_slice(drifted_w.rmesh)) ** 2\n",
"vmin = np.min((np.min(rspace_abs_orig), np.min(rspace_abs_prop)))\n",
"vmax = np.min((np.max(rspace_abs_orig), np.max(rspace_abs_prop)))\n",
"im1 = ax1.imshow(rspace_abs_orig, vmin=vmin, vmax=vmax)\n",
Expand Down Expand Up @@ -236,10 +238,10 @@
"source": [
"zR = 2.0\n",
"X = Wavefront.gaussian_pulse(\n",
" dims=(801, 101, 101),\n",
" dims=(101, 101, 801),\n",
" wavelength=1.35e-8,\n",
" grid_spacing=(0.0625, 6e-6, 6e-6),\n",
" pad=(40, 100, 100),\n",
" grid_spacing=(6e-6, 6e-6, 0.0625),\n",
" pad=(100, 100, 40),\n",
" nphotons=1e12,\n",
" zR=zR,\n",
" sigma_t=5,\n",
Expand All @@ -249,7 +251,8 @@
"zgrid = 5\n",
"dz = 0.25\n",
"\n",
"wfz = np.zeros((mz * zgrid, X.pad.grid[1], X.pad.grid[2]))\n",
"nx, ny = get_transverse_slice(X.rmesh).shape\n",
"wfz = np.zeros((mz * zgrid, nx, ny))\n",
"fwhmz_fit = np.zeros(mz * zgrid)\n",
"\n",
"w0 = np.sqrt(zR * X.wavelength / np.pi)\n",
Expand All @@ -269,7 +272,7 @@
" if zi > 0:\n",
" print(\"Propagating to: \", zi * dz)\n",
" X.drift(dz, inplace=True)\n",
" wf = np.abs(X.rmesh[int(nt / 2), :, :]) ** 2\n",
" wf = np.abs(get_transverse_slice(X.rmesh)) ** 2\n",
"\n",
" popt_gaussian, ydata_fit, FWHM, roots = gaussian_fit(\n",
" domain_x,\n",
Expand Down
20 changes: 13 additions & 7 deletions pmd_beamphysics/wavefront.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def conversion_coeffs(wavelength: float, dim: int) -> Tuple[float, ...]:
"""
k0 = calculate_k0(wavelength)
hbar = scipy.constants.hbar / scipy.constants.e * 1.0e15 # fs-eV
return tuple([2.0 * np.pi * hbar] + [2.0 * np.pi / k0] * (dim - 1))
return tuple([2.0 * np.pi / k0] * (dim - 1) + [2.0 * np.pi * hbar])


def cartesian_domain(
Expand Down Expand Up @@ -421,11 +421,12 @@ def drift_propagator_paraxial(
wavelength: float,
):
"""Fresnel propagator in paraxial approximation to distance z [m]."""
return kmesh * drift_kernel_paraxial(
kernel = drift_kernel_paraxial(
transverse_kspace_grid=transverse_kspace_grid,
z=z,
wavelength=wavelength,
)
return kmesh * kernel[:, :, np.newaxis]


def thin_lens_kernel_xy(
Expand Down Expand Up @@ -454,7 +455,7 @@ def thin_lens_kernel_xy(
np.ndarray
"""
k0 = calculate_k0(wavelength)
xx, yy = nd_space_mesh(ranges[1:], grid[1:])
xx, yy = nd_space_mesh(ranges[:2], grid[:2])
return np.exp(-1j * k0 / 2.0 * (xx**2 / f_lens_x + yy**2 / f_lens_y))


Expand Down Expand Up @@ -513,7 +514,9 @@ def create_gaussian_pulse_3d_with_q(
eta = 2.0 * k0 * zR * sigma_t / np.sqrt(np.pi)

pulse = np.sqrt(eta) * np.sqrt(nphotons) * ux * uy * ut
return pulse.astype(dtype)

# TODO a bit of swapping without changing the gaussian pulse
return np.moveaxis(pulse.astype(dtype), 0, -1)


def transverse_divergence_padding_factor(
Expand Down Expand Up @@ -860,21 +863,24 @@ def gaussian_pulse(
-------
Wavefront
"""
# TODO a bit of swapping without changing the gaussian pulse
nx, ny, nz = dims
gx, gy, gz = grid_spacing
pulse = create_gaussian_pulse_3d_with_q(
wavelength=wavelength,
nphotons=nphotons,
zR=zR,
sigma_t=sigma_t,
grid_spacing=grid_spacing,
grid=dims,
grid_spacing=(gz, gx, gy),
grid=(nz, nx, ny),
dtype=dtype,
)
return cls(
rmesh=pulse,
wavelength=wavelength,
grid_spacing=grid_spacing,
pad=pad,
axis_labels="zxy",
axis_labels="xyz",
longitudinal_axis="z",
)

Expand Down

0 comments on commit 8c59736

Please sign in to comment.