Skip to content

Commit

Permalink
Apply refurb style suggestions (#266)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukeshingles authored Nov 20, 2024
1 parent 56f2fd1 commit 80000f6
Show file tree
Hide file tree
Showing 15 changed files with 114 additions and 101 deletions.
39 changes: 21 additions & 18 deletions artistools/atomic/_atomic_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,24 +315,27 @@ def get_levels(
derived_transitions_columns: Sequence[str] | None = None,
use_rust_reader: bool | None = None,
) -> pd.DataFrame:
pldf = get_levels_polars(
modelpath,
ionlist=ionlist,
get_transitions=get_transitions,
get_photoionisations=get_photoionisations,
quiet=quiet,
derived_transitions_columns=derived_transitions_columns,
use_rust_reader=use_rust_reader,
)
pldf = pldf.with_columns(
levels=pl.col("levels").map_elements(
lambda x: x.to_pandas(use_pyarrow_extension_array=True), return_dtype=pl.Object
),
transitions=pl.col("transitions").map_elements(
lambda x: x.collect().to_pandas(use_pyarrow_extension_array=True), return_dtype=pl.Object
),
"""Return a pandas DataFrame of energy levels."""
return (
get_levels_polars(
modelpath,
ionlist=ionlist,
get_transitions=get_transitions,
get_photoionisations=get_photoionisations,
quiet=quiet,
derived_transitions_columns=derived_transitions_columns,
use_rust_reader=use_rust_reader,
)
.with_columns(
levels=pl.col("levels").map_elements(
lambda x: x.to_pandas(use_pyarrow_extension_array=True), return_dtype=pl.Object
),
transitions=pl.col("transitions").map_elements(
lambda x: x.collect().to_pandas(use_pyarrow_extension_array=True), return_dtype=pl.Object
),
)
.to_pandas(use_pyarrow_extension_array=True)
)
return pldf.to_pandas(use_pyarrow_extension_array=True)


def parse_recombratefile(frecomb: io.TextIOBase) -> t.Generator[tuple[int, int, pl.DataFrame], None, None]:
Expand All @@ -342,7 +345,7 @@ def parse_recombratefile(frecomb: io.TextIOBase) -> t.Generator[tuple[int, int,
arr_log10t = []
arr_rrc_low_n = []
arr_rrc_total = []
for _ in range(int(t_count)):
for _ in range(t_count):
log10t, rrc_low_n, rrc_total = (float(x) for x in frecomb.readline().split())

arr_log10t.append(log10t)
Expand Down
4 changes: 1 addition & 3 deletions artistools/estimators/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,9 +400,7 @@ def scan_estimators(
# for older files with no deposition data, take heating part of deposition and heating fraction
pldflazy = pldflazy.with_columns(total_dep=pl.col("heating_dep") / pl.col("heating_heating_dep/total_dep"))

pldflazy = pldflazy.with_columns(nntot=pl.sum_horizontal(cs.starts_with("nnelement_")))

return pldflazy.fill_null(0)
return pldflazy.with_columns(nntot=pl.sum_horizontal(cs.starts_with("nnelement_"))).fill_null(0)


def read_estimators(
Expand Down
2 changes: 1 addition & 1 deletion artistools/estimators/plot3destimators_classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def plot_Te_vs_time_lineofsight_3d_model(modelpath, modeldata, estimators, reado
for mgi in readonly_mgi:
associated_modeldata_row_for_mgi = modeldata.loc[modeldata["inputcellid"] == assoc_cells[mgi][0]]

Te = [estimators[timestep, mgi]["Te"] for timestep, _ in enumerate(times)]
Te = [estimators[timestep, mgi]["Te"] for timestep in range(len(times))]
plt.scatter(times, Te, label=f"vel={associated_modeldata_row_for_mgi['vel_y_mid'].to_numpy()[0] / CLIGHT}")

plt.xlabel("time [days]")
Expand Down
2 changes: 1 addition & 1 deletion artistools/gsinetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def plot_qdot_abund_modelcells(
arr_time_artis_s_alltimesteps = np.array([t * 8.640000e04 for t in arr_time_artis_days_alltimesteps])
# no completed timesteps yet, so display full set of timesteps that artis will compute
if not arr_time_artis_days:
arr_time_artis_days = list(arr_time_artis_days_alltimesteps)
arr_time_artis_days = arr_time_artis_days_alltimesteps.copy()

arr_time_gsi_s = np.array([modelmeta["t_model_init_days"] * 86400, *arr_time_artis_s_alltimesteps])

Expand Down
2 changes: 1 addition & 1 deletion artistools/inputmodel/make1dslicefrom3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def append_cell_to_output(
cell, outcellid: int, t_model: str | float, listout: list[str], xlist: list[float], ylists: list[list[float]]
) -> None:
dist = math.sqrt(float(cell["pos_x_min"]) ** 2 + float(cell["pos_y_min"]) ** 2 + float(cell["pos_z_min"]) ** 2)
velocity = float(dist) / float(t_model) / 86400.0 / 1.0e5
velocity = dist / float(t_model) / 86400.0 / 1.0e5

listout.append(
f"{outcellid:6d} {velocity:8.2f} {math.log10(max(float(cell['rho']), 1e-100)):8.5f} "
Expand Down
2 changes: 1 addition & 1 deletion artistools/inputmodel/recombinationenergy.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def get_particles_recomb_nuc_energy(traj_root, dfbinding):
nuclear_released_en_list.append(nuc_en_released)
# print(particleid, ye, elecbinding_en)
except FileNotFoundError:
pass
# print(f' WARNING particle {particleid} not found! ')
pass

dfrecomb = pd.DataFrame({
"ye": ye_list,
Expand Down
128 changes: 71 additions & 57 deletions artistools/lightcurve/plotlightcurve.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,10 @@ def plot_deposition_thermalisation(
axis.plot(
depdata["tmid_days"],
gammadep_lsun * 3.826e33,
**{
**plotkwargs,
"label": plotkwargs["label"] + r" $\dot{E}_{dep,\gamma}$",
"linestyle": "dashed",
"color": color_gamma,
},
**(
plotkwargs
| {"label": plotkwargs["label"] + r" $\dot{E}_{dep,\gamma}$", "linestyle": "dashed", "color": color_gamma}
),
)

color_beta = axis._get_lines.get_next_color() # type: ignore[attr-defined] # noqa: SLF001
Expand All @@ -105,24 +103,28 @@ def plot_deposition_thermalisation(
axis.plot(
depdata["tmid_days"],
depdata["eps_elec_Lsun"] * 3.826e33,
**{
**plotkwargs,
"label": plotkwargs["label"] + r" $\dot{E}_{rad,\beta^-}$",
"linestyle": "dotted",
"color": color_beta,
},
**(
plotkwargs
| {
"label": plotkwargs["label"] + r" $\dot{E}_{rad,\beta^-}$",
"linestyle": "dotted",
"color": color_beta,
}
),
)

if "elecdep_Lsun" in depdata:
axis.plot(
depdata["tmid_days"],
depdata["elecdep_Lsun"] * 3.826e33,
**{
**plotkwargs,
"label": plotkwargs["label"] + r" $\dot{E}_{dep,\beta^-}$",
"linestyle": "dashed",
"color": color_beta,
},
**(
plotkwargs
| {
"label": plotkwargs["label"] + r" $\dot{E}_{dep,\beta^-}$",
"linestyle": "dashed",
"color": color_beta,
}
),
)

# c23modelpath = Path(
Expand Down Expand Up @@ -153,35 +155,41 @@ def plot_deposition_thermalisation(
axis.plot(
depdata["tmid_days"],
depdata["eps_alpha_ana_Lsun"] * 3.826e33,
**{
**plotkwargs,
"label": plotkwargs["label"] + r" $\dot{E}_{rad,\alpha}$ analytical",
"linestyle": "solid",
"color": color_alpha,
},
**(
plotkwargs
| {
"label": plotkwargs["label"] + r" $\dot{E}_{rad,\alpha}$ analytical",
"linestyle": "solid",
"color": color_alpha,
}
),
)

if "eps_alpha_Lsun" in depdata:
axis.plot(
depdata["tmid_days"],
depdata["eps_alpha_Lsun"] * 3.826e33,
**{
**plotkwargs,
"label": plotkwargs["label"] + r" $\dot{E}_{rad,\alpha}$",
"linestyle": "dashed",
"color": color_alpha,
},
**(
plotkwargs
| {
"label": plotkwargs["label"] + r" $\dot{E}_{rad,\alpha}$",
"linestyle": "dashed",
"color": color_alpha,
}
),
)

axis.plot(
depdata["tmid_days"],
depdata["alphadep_Lsun"] * 3.826e33,
**{
**plotkwargs,
"label": plotkwargs["label"] + r" $\dot{E}_{dep,\alpha}$",
"linestyle": "dotted",
"color": color_alpha,
},
**(
plotkwargs
| {
"label": plotkwargs["label"] + r" $\dot{E}_{dep,\alpha}$",
"linestyle": "dotted",
"color": color_alpha,
}
),
)

if args.plotthermalisation:
Expand All @@ -190,37 +198,43 @@ def plot_deposition_thermalisation(
axistherm.plot(
depdata["tmid_days"],
f_gamma,
**{
**plotkwargs,
"label": modelname + r" $\left(\dot{E}_{dep,\gamma} \middle/ \dot{E}_{rad,\gamma}\right)$",
"linestyle": "solid",
"color": color_gamma,
},
**(
plotkwargs
| {
"label": modelname + r" $\left(\dot{E}_{dep,\gamma} \middle/ \dot{E}_{rad,\gamma}\right)$",
"linestyle": "solid",
"color": color_gamma,
}
),
)

f_beta = depdata["elecdep_Lsun"] / depdata["eps_elec_Lsun"]
axistherm.plot(
depdata["tmid_days"],
f_beta,
**{
**plotkwargs,
"label": modelname + r" $\left(\dot{E}_{dep,\beta^-} \middle/ \dot{E}_{rad,\beta^-}\right)$",
"linestyle": "solid",
"color": color_beta,
},
**(
plotkwargs
| {
"label": modelname + r" $\left(\dot{E}_{dep,\beta^-} \middle/ \dot{E}_{rad,\beta^-}\right)$",
"linestyle": "solid",
"color": color_beta,
}
),
)

f_alpha = depdata["alphadep_Lsun"] / depdata["eps_alpha_Lsun"]

axistherm.plot(
depdata["tmid_days"],
f_alpha,
**{
**plotkwargs,
"label": modelname + r" $\left(\dot{E}_{dep,\alpha} \middle/ \dot{E}_{rad,\alpha}\right)$",
"linestyle": "solid",
"color": color_alpha,
},
**(
plotkwargs
| {
"label": modelname + r" $\left(\dot{E}_{dep,\alpha} \middle/ \dot{E}_{rad,\alpha}\right)$",
"linestyle": "solid",
"color": color_alpha,
}
),
)

ejecta_ke_erg: float = dfmodel.select("kinetic_en_erg").sum().collect().item()
Expand All @@ -241,7 +255,7 @@ def plot_deposition_thermalisation(
axistherm.plot(
depdata["tmid_days"],
barnes_f_gamma,
**{**plotkwargs, "label": r"Barnes+2016 $f_\gamma$", "linestyle": "dashed", "color": color_gamma},
**(plotkwargs | {"label": r"Barnes+2016 $f_\gamma$", "linestyle": "dashed", "color": color_gamma}),
)

e0_beta_mev = 0.5
Expand All @@ -255,7 +269,7 @@ def plot_deposition_thermalisation(
axistherm.plot(
depdata["tmid_days"],
barnes_f_beta,
**{**plotkwargs, "label": r"Barnes+2016 $f_\beta$", "linestyle": "dashed", "color": color_beta},
**(plotkwargs | {"label": r"Barnes+2016 $f_\beta$", "linestyle": "dashed", "color": color_beta}),
)

e0_alpha_mev = 6.0
Expand All @@ -269,7 +283,7 @@ def plot_deposition_thermalisation(
axistherm.plot(
depdata["tmid_days"],
barnes_f_alpha,
**{**plotkwargs, "label": r"Barnes+2016 $f_\alpha$", "linestyle": "dashed", "color": color_alpha},
**(plotkwargs | {"label": r"Barnes+2016 $f_\alpha$", "linestyle": "dashed", "color": color_alpha}),
)


Expand Down
2 changes: 1 addition & 1 deletion artistools/lightcurve/viewingangleanalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def make_peak_colour_viewing_angle_plot(args: argparse.Namespace) -> None:

bands = [args.filter[0], args.filter[1]]

datafilename = bands[0] + "band_" + str(modelname) + "_viewing_angle_data.txt"
datafilename = bands[0] + "band_" + modelname + "_viewing_angle_data.txt"
viewing_angle_plot_data = pd.read_csv(datafilename, delimiter=" ")
data = {f"{bands[0]}max": viewing_angle_plot_data["peak_mag_polyfit"].to_numpy()}
data[f"time_{bands[0]}max"] = viewing_angle_plot_data["risetime_polyfit"].to_numpy()
Expand Down
2 changes: 1 addition & 1 deletion artistools/linefluxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def make_emitting_regions_plot(args: argparse.Namespace) -> None:
refdatakeys[refdataindex] = refdatakeys_thisseries
refdatatimes[refdataindex] = np.array([float(t) for t in refdatakeys_thisseries])
refdatapoints[refdataindex] = [floers_te_nne[t] for t in refdatakeys_thisseries]
print(f"{refdatafilename} data available for times: {list(refdatakeys_thisseries)}")
print(f"{refdatafilename} data available for times: {refdatakeys_thisseries}")

times_days = (np.array(args.timebins_tstart) + np.array(args.timebins_tend)) / 2.0

Expand Down
8 changes: 4 additions & 4 deletions artistools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def showtimesteptimes(modelpath: Path | None = None, numberofcolumns: int = 5) -
strline += "\t"
newindex = rownum + colnum * indexendofcolumnone
if newindex + 1 < len(times):
strline += f"{newindex:4d}: {float(times[newindex + 1]):.3f}d"
strline += f"{newindex:4d}: {times[newindex + 1]:.3f}d"
print(strline)


Expand Down Expand Up @@ -528,10 +528,10 @@ def get_time_range(
timestepmax = timesteplast
if time_days_lower is None:
assert timestepmin is not None
time_days_lower = float(tstarts[timestepmin]) if clamp_to_timesteps else timemin
time_days_lower = tstarts[timestepmin] if clamp_to_timesteps else timemin
if time_days_upper is None:
assert timestepmax is not None
time_days_upper = float(tends[timestepmax]) if clamp_to_timesteps else timemax
time_days_upper = tends[timestepmax] if clamp_to_timesteps else timemax
assert timestepmin is not None
assert timestepmax is not None

Expand Down Expand Up @@ -1348,7 +1348,7 @@ def get_dfrankassignments(modelpath: Path | str) -> pd.DataFrame | None:
if filerankassignments.is_file():
dfrankassignments = pd.read_csv(filerankassignments, sep=r"\s+")
return dfrankassignments.rename(
columns={dfrankassignments.columns[0]: str(dfrankassignments.columns[0]).lstrip("#")}
columns={dfrankassignments.columns[0]: dfrankassignments.columns[0].lstrip("#")}
)
return None

Expand Down
6 changes: 3 additions & 3 deletions artistools/nltepops/plotnltepops.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def plot_populations_with_time_or_velocity(
# == ionlevel]['n_LTE'].values[0])

for ionlevel in ionlevels:
plottimesteps = np.array([int(ts) for ts, level, mgi in populations if level == ionlevel])
plottimesteps = np.array([ts for ts, level, mgi in populations if level == ionlevel])
timedays = [at.get_timestep_time(modelpath, ts) for ts in plottimesteps]
plotpopulations = np.array([
float(populations[ts, level, mgi]) for ts, level, mgi in populations if level == ionlevel
Expand Down Expand Up @@ -610,8 +610,8 @@ def make_plot(modelpath, atomic_number, ion_stages_displayed, mgilist, timestep,
nne = estimators[timestep, modelgridindex]["nne"]
W = estimators[timestep, modelgridindex]["W"]

subplot_title = str(modelname)
if len(modelname) > 10:
subplot_title = modelname
if len(subplot_title) > 10:
subplot_title += "\n"
velocity = at.inputmodel.get_modeldata_tuple(modelpath)[0]["vel_r_max_kmps"][modelgridindex]
subplot_title += f" {velocity:.0f} km/s at"
Expand Down
2 changes: 1 addition & 1 deletion artistools/packets/packets.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ def add_derived_columns_lazy(
dfpackets = dfpackets.with_columns(
em_modelgridindex=(
pl.col("emission_velocity")
.cut(breaks=list(velbins), labels=[str(x) for x in range(-1, len(velbins))])
.cut(breaks=velbins, labels=[str(x) for x in range(-1, len(velbins))])
.cast(str)
.cast(pl.Int32)
)
Expand Down
Loading

0 comments on commit 80000f6

Please sign in to comment.