Skip to content

Commit

Permalink
blackify code
Browse files Browse the repository at this point in the history
  • Loading branch information
jvshields committed Jul 25, 2024
1 parent d4e97e0 commit c5ac1f7
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 33 deletions.
4 changes: 3 additions & 1 deletion stardis/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def parse_config_to_model(config_fname, add_config_keys=None, add_config_vals=No
logging.info("Reading model")
if config.model.type == "marcs":
raw_marcs_model = read_marcs_model(
Path(config.model.fname), gzipped=config.model.gzipped, spherical=config.model.spherical
Path(config.model.fname),
gzipped=config.model.gzipped,
spherical=config.model.spherical,
)
stellar_model = raw_marcs_model.to_stellar_model(
adata, final_atomic_number=config.model.final_atomic_number
Expand Down
14 changes: 9 additions & 5 deletions stardis/io/model/marcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def to_geometry(self):
-self.data.depth.values[::-1] * u.cm
) # Flip data to move from innermost stellar point to surface
if self.spherical:
r += self.metadata['radius']
r += self.metadata["radius"]

return Radial1DGeometry(r)

def to_composition(self, atom_data, final_atomic_number):
Expand Down Expand Up @@ -250,7 +250,11 @@ def read_marcs_metadata(fpath, gzipped=True, spherical=False):
"radius",
"radius_units",
),
(r"\s+(\d+\.\d+(?:E[+-]?\d+)?) Luminosity \[(.+)\]", "luminosity", "luminosity_units"),
(
r"\s+(\d+\.\d+(?:E[+-]?\d+)?) Luminosity \[(.+)\]",
"luminosity",
"luminosity_units",
),
(
r" (\d+.\d+) (\d+.\d+) (\d+.\d+) (\d+.\d+) are the convection parameters: alpha, nu, y and beta",
"conv_alpha",
Expand Down Expand Up @@ -290,8 +294,8 @@ def read_marcs_metadata(fpath, gzipped=True, spherical=False):
contents = file.readlines(BYTES_THROUGH_METADATA)

lines = list(contents)
#Check each line against the regex patterns and add the matched values to the metadata dictionary

# Check each line against the regex patterns and add the matched values to the metadata dictionary
for i in range(len(metadata_re_str)):
line = lines[i]
metadata_re_match = metadata_re[i].match(line)
Expand Down
76 changes: 49 additions & 27 deletions stardis/radiation_field/radiation_field_solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,15 @@ def single_theta_trace_parallel(
# Need to calculate a mean opacity for the traversal between points. Linearly interporlating. Van Noort paper suggests interpolating
# alphas in log space. We could have a choice for interpolation scheme here.
mean_alphas = np.exp((np.log(alphas[1:]) + np.log(alphas[:-1])) * 0.5)

taus = np.zeros_like(mean_alphas, dtype=np.float64)
for gap_index in numba.prange(taus.shape[0]):
for nu_index in range(taus.shape[1]):
taus[gap_index, nu_index] = mean_alphas[gap_index, nu_index] * ray_dist_to_next_depth_point[gap_index]


taus[gap_index, nu_index] = (
mean_alphas[gap_index, nu_index]
* ray_dist_to_next_depth_point[gap_index]
)

no_of_depth_gaps = len(ray_dist_to_next_depth_point)

source = source_function(tracing_nus, temps)
Expand Down Expand Up @@ -195,8 +197,8 @@ def single_theta_trace(
tracing_nus,
thetas,
source_function,
spherical = False,
reference_radius = 2.5e11,
spherical=False,
reference_radius=2.5e11,
):
"""
Performs ray tracing at an angle following van Noort 2001 eq 14.
Expand Down Expand Up @@ -227,8 +229,10 @@ def single_theta_trace(
mean_alphas = np.exp((np.log(alphas[1:]) + np.log(alphas[:-1])) * 0.5)
if spherical:
pass

taus = (mean_alphas[:,:, np.newaxis] * ray_dist_to_next_depth_point[:, np.newaxis, :])

taus = (
mean_alphas[:, :, np.newaxis] * ray_dist_to_next_depth_point[:, np.newaxis, :]
)
no_of_depth_gaps = len(ray_dist_to_next_depth_point)

source = source_function(tracing_nus, temps)[:, :, np.newaxis]
Expand Down Expand Up @@ -278,7 +282,14 @@ def single_theta_trace(
return I_nu_theta


def raytrace(stellar_model, stellar_radiation_field, no_of_thetas=20, n_threads=1, spherical=False, reference_radius=2.5e11):
def raytrace(
stellar_model,
stellar_radiation_field,
no_of_thetas=20,
n_threads=1,
spherical=False,
reference_radius=2.5e11,
):
"""
Raytraces over many angles and integrates to get flux using the midpoint
rule.
Expand All @@ -299,23 +310,26 @@ def raytrace(stellar_model, stellar_radiation_field, no_of_thetas=20, n_threads=
"""

if spherical:
#Calculate photosphere correction - apply it later to F_nu
# Calculate photosphere correction - apply it later to F_nu
pass
else:
pass
dtheta = (np.pi / 2) / no_of_thetas #Korg uses Gauss-Legendre quadrature here
dtheta = (np.pi / 2) / no_of_thetas # Korg uses Gauss-Legendre quadrature here
start_theta = dtheta / 2
end_theta = (np.pi / 2) - (dtheta / 2)
thetas = np.linspace(start_theta, end_theta, no_of_thetas)
weights = 2 * np.pi * dtheta * np.sin(thetas) * np.cos(thetas)

if True:
ray_distances, ray_deepest_point_mask = calculate_spherical_ray(thetas, stellar_model.geometry.r)
ray_distances, ray_deepest_point_mask = calculate_spherical_ray(
thetas, stellar_model.geometry.r
)
# print(ray_distances.shape)
else:
ray_distances = stellar_model.geometry.dist_to_next_depth_point.reshape(-1,1) / np.cos(thetas)
else:
ray_distances = stellar_model.geometry.dist_to_next_depth_point.reshape(
-1, 1
) / np.cos(thetas)
# print(ray_distances.shape)


###TODO: Thetas should probably be held by the model? Then can be passed in from there.
if n_threads == 1: # Single threaded
Expand All @@ -334,7 +348,9 @@ def raytrace(stellar_model, stellar_radiation_field, no_of_thetas=20, n_threads=

else: # Parallel threaded
for theta_index, theta in enumerate(thetas):
stellar_radiation_field.F_nu += weights[theta_index] * single_theta_trace_parallel(
stellar_radiation_field.F_nu += weights[
theta_index
] * single_theta_trace_parallel(
ray_distances[:, theta_index],
stellar_model.temperatures.value.reshape(-1, 1),
stellar_radiation_field.opacities.total_alphas,
Expand All @@ -345,28 +361,34 @@ def raytrace(stellar_model, stellar_radiation_field, no_of_thetas=20, n_threads=

return stellar_radiation_field.F_nu


def calculate_spherical_ray(thetas, depth_points_radii):
###NOTE: This will need to be revisited to handle some rays more carefully if they don't go through the star
ray_distance_to_next_depth_point = np.zeros((len(depth_points_radii) - 1, len(thetas)))
###NOTE: This will need to be revisited to handle some rays more carefully if they don't go through the star
ray_distance_to_next_depth_point = np.zeros(
(len(depth_points_radii) - 1, len(thetas))
)
ray_deepest_layer_mask = np.zeros((len(depth_points_radii), len(thetas)))

for theta_index, theta in enumerate(thetas):
b = depth_points_radii[-1] * np.sin(theta)
ray_depth_selection_mask = b < depth_points_radii #mask for the depth points that the ray will pass through.
ray_depth_selection_mask = (
b < depth_points_radii
) # mask for the depth points that the ray will pass through.
ray_z_coordinate_grid = np.sqrt(depth_points_radii**2 - b**2)

ray_distance_to_next_depth_point[:, theta_index] = np.diff(ray_z_coordinate_grid)

ray_distance_to_next_depth_point[:, theta_index] = np.diff(
ray_z_coordinate_grid
)
ray_deepest_layer_mask[:, theta_index] = ray_depth_selection_mask
if ray_distance_to_next_depth_point.any() == 0:
print(f"NaN in ray_distance_to_next_depth_point, theta is {theta}")
print(ray_distance_to_next_depth_point)



# b = depth_points_radii[-1] * np.sin(thetas) #impact parameter
# ray_depth_selection_mask = b < depth_points_radii #mask for the depth points that the ray will pass through.
# ray_depth_selection_mask = b < depth_points_radii #mask for the depth points that the ray will pass through.
# #The layers the ray doesn't pass through will not contribute to the outgoing flux
# ray_distances = np.zeros_like(b)
# print(ray_distances.shape)
# ray_distances = np.sqrt(depth_points_radii[ray_depth_selection_mask]**2 - b**2)
# ray_distance_to_next_depth_point = np.diff(ray_distances)
return(ray_distance_to_next_depth_point, ray_deepest_layer_mask)
return (ray_distance_to_next_depth_point, ray_deepest_layer_mask)

0 comments on commit c5ac1f7

Please sign in to comment.