Commit

PEP8
ddobie committed Nov 20, 2023
1 parent 37840aa commit c0c4817
Showing 4 changed files with 66 additions and 51 deletions.
21 changes: 13 additions & 8 deletions vast_pipeline/image/main.py
@@ -34,6 +34,7 @@ class Image(object):
path (str): The system path to the image.
"""

def __init__(self, path: str) -> None:
"""
Initiliase an image object.
@@ -81,7 +82,7 @@ class FitsImage(Image):

entire_image = True

def __init__(self, path: str, hdu_index: int=0) -> None:
def __init__(self, path: str, hdu_index: int = 0) -> None:
"""
Initialise a FitsImage object.
@@ -117,7 +118,7 @@ def __get_header(self, hdu_index: int) -> fits.Header:
Returns:
The FITS header as an astropy.io.fits.Header object.
"""

try:
with open_fits(self.path) as hdulist:
hdu = hdulist[hdu_index]
@@ -225,7 +226,8 @@ def __get_radius_pixels(
The radius of the image in pixels.
"""
if self.entire_image:
# a large circle that *should* include the whole image (and then some)
# a large circle that *should* include the whole image
# (and then some)
diameter = np.hypot(header[fits_naxis1], header[fits_naxis2])
else:
# We simply place the largest circle we can in the centre.
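
Note: np.hypot of the two axis lengths above is the image diagonal, i.e. the diameter of a circle guaranteed to circumscribe the whole image. A toy check with invented dimensions, not taken from the repository:

    import numpy as np

    # Diagonal of an invented 100 x 80 pixel image: the diameter of the
    # smallest circle that contains every pixel.
    print(np.hypot(100, 80))   # 128.06...
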
@@ -246,10 +248,11 @@ def __get_frequency(self, header: fits.Header) -> None:
self.freq_eff = None
self.freq_bw = None
try:
if ('ctype3' in header) and (header['ctype3'] in ('FREQ', 'VOPT')):
freq_keys = ('FREQ', 'VOPT')
if ('ctype3' in header) and (header['ctype3'] in freq_keys):
self.freq_eff = header['crval3']
self.freq_bw = header['cdelt3'] if 'cdelt3' in header else 0.0
elif ('ctype4' in header) and (header['ctype4'] in ('FREQ', 'VOPT')):
elif ('ctype4' in header) and (header['ctype4'] in freq_keys):
self.freq_eff = header['crval4']
self.freq_bw = header['cdelt4'] if 'cdelt4' in header else 0.0
else:
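
For illustration, a minimal sketch of the kind of FITS header the refactored freq_keys check above matches; the header values are invented and only demonstrate the lookup, they are not taken from the pipeline:

    from astropy.io import fits

    hdr = fits.Header()
    hdr['CTYPE3'] = 'FREQ'      # third axis carries frequency
    hdr['CRVAL3'] = 887.5e6     # invented effective frequency (Hz)
    hdr['CDELT3'] = 288e6       # invented bandwidth (Hz)

    freq_keys = ('FREQ', 'VOPT')
    if 'ctype3' in hdr and hdr['ctype3'] in freq_keys:
        freq_eff = hdr['crval3']                              # 887.5 MHz
        freq_bw = hdr['cdelt3'] if 'cdelt3' in hdr else 0.0   # 288 MHz

Astropy header access is case-insensitive, so the lowercase keys used in the diff work as written.
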
@@ -273,6 +276,7 @@ class SelavyImage(FitsImage):
associated with the image.
config (Dict): The image configuration settings.
"""

def __init__(
self,
path: str,
@@ -315,7 +319,8 @@ def read_selavy(self, dj_image: models.Image) -> pd.DataFrame:
Dataframe containing the cleaned and processed Selavy components.
"""
# TODO: improve with loading only the cols we need and set datatype
if self.selavy_path.endswith(".xml") or self.selavy_path.endswith(".vot"):
if self.selavy_path.endswith(
".xml") or self.selavy_path.endswith(".vot"):
df = Table.read(
self.selavy_path, format="votable", use_names_over_ids=True
).to_pandas()
@@ -462,12 +467,12 @@ def read_selavy(self, dj_image: models.Image) -> pd.DataFrame:
.agg('sum')
)

df['flux_int_isl_ratio'] = (
df['flux_int_isl_ratio'] = (
df['flux_int'].values
/ island_flux_totals.loc[df['island_id']]['flux_int'].values
)

df['flux_peak_isl_ratio'] = (
df['flux_peak_isl_ratio'] = (
df['flux_peak'].values
/ island_flux_totals.loc[df['island_id']]['flux_peak'].values
)
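
Aside: the two ratio blocks above divide each component's flux by its island total. A self-contained toy version of the same pandas pattern (column names follow the diff; the data and the exact aggregation are invented):

    import pandas as pd

    df = pd.DataFrame({
        'island_id': ['island_1', 'island_1', 'island_2'],
        'flux_int': [1.0, 3.0, 2.0],
    })
    island_flux_totals = df.groupby('island_id').agg('sum')

    df['flux_int_isl_ratio'] = (
        df['flux_int'].values
        / island_flux_totals.loc[df['island_id']]['flux_int'].values
    )
    print(df['flux_int_isl_ratio'].tolist())   # [0.25, 0.75, 1.0]
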
30 changes: 17 additions & 13 deletions vast_pipeline/image/utils.py
@@ -83,7 +83,7 @@ def calc_error_radius(ra, ra_err, dec, dec_err) -> float:
np.deg2rad(i),
dec_1,
np.deg2rad(j)
)) for i,j in zip(ra_offsets, dec_offsets)
)) for i, j in zip(ra_offsets, dec_offsets)
]

seps = np.column_stack(seps)
Expand Down Expand Up @@ -192,7 +192,7 @@ def calc_condon_flux_errors(
(1. + (theta_B / major)**2)**alpha_maj2 *
(1. + (theta_b / minor)**2)**alpha_min2 *
snr**2)
rho_sq3 = ((major * minor / (4.* theta_B * theta_b)) *
rho_sq3 = ((major * minor / (4. * theta_B * theta_b)) *
(1. + (theta_B / major)**2)**alpha_maj3 *
(1. + (theta_b / minor)**2)**alpha_min3 *
snr**2)
@@ -212,9 +212,9 @@

# ra and dec errors
errorra = np.sqrt((error_par_major * np.sin(theta))**2 +
(error_par_minor * np.cos(theta))**2)
(error_par_minor * np.cos(theta))**2)
errordec = np.sqrt((error_par_major * np.cos(theta))**2 +
(error_par_minor * np.sin(theta))**2)
(error_par_minor * np.sin(theta))**2)

errormajor = np.sqrt(2) * major / rho1
errorminor = np.sqrt(2) * minor / rho2
@@ -240,35 +240,39 @@ def calc_condon_flux_errors(
help1 = (errormajor / major)**2
help2 = (errorminor / minor)**2
help3 = theta_B * theta_b / (major * minor)
errorflux = np.abs(flux_int) * np.sqrt(errorpeaksq / flux_peak**2 + help3 * (help1 + help2))
help4 = np.sqrt(errorpeaksq / flux_peak**2 + help3 * (help1 + help2))
errorflux = np.abs(flux_int) * help4

# need to return flux_peak if used.
return errorpeak, errorflux, errormajor, errorminor, errortheta, errorra, errordec

except Exception as e:
logger.debug("Error in the calculation of Condon errors for a source", exc_info=True)
logger.debug(
"Error in the calculation of Condon errors for a source",
exc_info=True)
return 0., 0., 0., 0., 0., 0., 0.

def open_fits(fits_path: Union[str, Path], memmap: Optional[bool]=True):

def open_fits(fits_path: Union[str, Path], memmap: Optional[bool] = True):
"""
This function opens both compressed and uncompressed fits files.
Args:
fits_path: Path to the fits file
memmap: Open the fits file with mmap.
Returns:
HDUList loaded from the fits file
Raises:
ValueError: File extension must be .fits or .fits.fz
"""

if type(fits_path) == Path:
if isinstance(fits_path, Path):
fits_path = str(fits_path)

hdul = fits.open(fits_path, memmap=memmap)

if fits_path.endswith('.fits'):
return hdul
elif fits_path.endswith('.fits.fz'):
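
Aside: the isinstance change in open_fits above is more than a style fix. Path() actually returns a platform-specific subclass (PosixPath or WindowsPath), so a strict type(...) == Path comparison is False even for genuine Path objects, while isinstance handles subclasses correctly. A quick illustration with a toy path, not from the repository:

    from pathlib import Path

    p = Path('image.fits')         # concretely a PosixPath or WindowsPath
    print(type(p) == Path)         # False: the concrete type is a subclass
    print(isinstance(p, Path))     # True: isinstance respects subclassing
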
45 changes: 25 additions & 20 deletions vast_pipeline/pipeline/forced_extraction.py
@@ -69,7 +69,8 @@ def get_data_from_parquet(
Args:
file_and_image_id:
a tuple containing the path of the measurements parquet file and the image ID.
a tuple containing the path of the measurements parquet file and
the image ID.
p_run_path:
Pipeline run path to get forced parquet in case of add mode.
add_mode:
@@ -108,26 +109,27 @@
def _forcedphot_preload(image: str,
background: str,
noise: str,
memmap: Optional[bool]=False
memmap: Optional[bool] = False
):
"""
Load the relevant image, background and noisemap files.
Args:
image: a string with the path of the image file
background: a string with the path of the background map
noise: a string with the path of the noise map
Returns:
A tuple containing the HDU lists
"""

image_hdul = open_fits(image, memmap=memmap)
background_hdul = open_fits(background, memmap=memmap)
noise_hdul = open_fits(noise, memmap=memmap)

return image_hdul, background_hdul, noise_hdul



def extract_from_image(
df: pd.DataFrame,
image: str,
@@ -172,12 +174,13 @@ def extract_from_image(
unit=(u.deg, u.deg)
)
# load the image, background and noisemaps into memory
# a dedicated function may seem unneccesary, but will be useful if we split the load to a separate thread.
# a dedicated function may seem unneccesary, but will be useful if we
# split the load to a separate thread.
forcedphot_input = _forcedphot_preload(image,
background,
noise,
memmap=False
)
)
FP = ForcedPhot(*forcedphot_input)

flux, flux_err, chisq, DOF, cluster_id = FP.measure(
Expand All @@ -197,7 +200,7 @@ def finalise_forced_dfs(
df: pd.DataFrame, prefix: str, max_id: int, beam_bmaj: float,
beam_bmin: float, beam_bpa: float, id: int, datetime: datetime.datetime,
image: str
) -> pd.DataFrame:
) -> pd.DataFrame:
"""
Compute populate leftover columns for the dataframe with forced
photometry data given the input parameters
@@ -254,7 +257,7 @@ def parallel_extraction(
df: pd.DataFrame, df_images: pd.DataFrame, df_sources: pd.DataFrame,
min_sigma: float, edge_buffer: float, cluster_threshold: float,
allow_nan: bool, add_mode: bool, p_run_path: str
) -> pd.DataFrame:
) -> pd.DataFrame:
"""
Parallelize forced extraction with Dask
@@ -291,7 +294,7 @@
"""
# explode the lists in 'img_diff' column (this will make a copy of the df)
out = (
df.rename(columns={'img_diff':'image', 'source':'source_tmp_id'})
df.rename(columns={'img_diff': 'image', 'source': 'source_tmp_id'})
# merge the rms_min column from df_images
.merge(
df_images[['rms_min']],
@@ -316,8 +319,8 @@
out['max_snr'] = out['flux_peak'].values / out['image_rms_min'].values
out = out[out['max_snr'] > min_sigma].reset_index(drop=True)
logger.debug("Min forced sigma dropped %i sources",
predrop_shape - out.shape[0]
)
predrop_shape - out.shape[0]
)

# drop some columns that are no longer needed and the df should look like
# out
@@ -340,7 +343,8 @@
# create a list of dictionaries with image file paths and dataframes
# with data related to each images
def image_data_func(image_name: str) -> Dict[str, Any]:
nonlocal out # `out` refers to the `out` declared in nearest enclosing scope
# `out` refers to the `out` declared in nearest enclosing scope
nonlocal out
return {
'image_id': df_images.at[image_name, 'id'],
'image': df_images.at[image_name, 'path'],
Expand Down Expand Up @@ -415,7 +419,7 @@ def image_data_func(image_name: str) -> Dict[str, Any]:
pd.concat(intermediate_df, axis=0, sort=False)
.rename(
columns={
'wavg_ra':'ra', 'wavg_dec':'dec', 'image_name': 'image'
'wavg_ra': 'ra', 'wavg_dec': 'dec', 'image_name': 'image'
}
)
)
Expand All @@ -424,7 +428,7 @@ def image_data_func(image_name: str) -> Dict[str, Any]:


def write_group_to_parquet(
df: pd.DataFrame, fname: str, add_mode: bool) -> None:
df: pd.DataFrame, fname: str, add_mode: bool) -> None:
'''
Write a dataframe correpondent to a single group/image
to a parquet file.
@@ -451,7 +455,7 @@


def parallel_write_parquet(
df: pd.DataFrame, run_path: str, add_mode: bool = False) -> None:
df: pd.DataFrame, run_path: str, add_mode: bool = False) -> None:
'''
Parallelize writing parquet files for forced measurements.
@@ -467,9 +471,10 @@ def parallel_write_parquet(
None
'''
images = df['image'].unique().tolist()
get_fname = lambda n: os.path.join(

def get_fname(n): return os.path.join(
run_path,
'forced_measurements_' + n.replace('.','_') + '.parquet'
'forced_measurements_' + n.replace('.', '_') + '.parquet'
)
dfs = list(map(lambda x: (df[df['image'] == x], get_fname(x)), images))
n_cpu = cpu_count() - 1
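
Note: the get_fname rewrite above follows PEP 8, which recommends a def statement over binding a lambda to a name (flake8 flags the latter as E731); a def gives the function a real __name__ for tracebacks and profiling. A minimal before/after sketch with an invented path template:

    import os

    # Discouraged: the resulting function's __name__ is '<lambda>'
    get_fname = lambda n: os.path.join('run', n.replace('.', '_') + '.parquet')  # noqa: E731

    # Preferred: identical behaviour, but a properly named function
    def get_fname(n):
        return os.path.join('run', n.replace('.', '_') + '.parquet')

    print(get_fname('image.fits'))   # run/image_fits.parquet on POSIX
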
21 changes: 11 additions & 10 deletions vast_pipeline/pipeline/new_sources.py
@@ -102,7 +102,7 @@ def get_image_rms_measurements(

npix = round(
(nbeam / 2. * bmaj.to('arcsec') /
pixelscale).value
pixelscale).value
)

npix = int(round(npix * edge_buffer))
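
For illustration, a worked example of the cutout-radius arithmetic in the hunk above. All of the values (beam size, pixel scale, nbeam, edge_buffer) are invented and are not the pipeline's defaults:

    from astropy import units as u

    nbeam = 3                      # invented number of beam widths
    bmaj = 12.0 * u.arcsec         # invented beam major axis
    pixelscale = 3.0 * u.arcsec    # invented pixel scale
    edge_buffer = 1.2              # invented buffer factor

    npix = round((nbeam / 2. * bmaj.to('arcsec') / pixelscale).value)
    npix = int(round(npix * edge_buffer))
    print(npix)   # 1.5 * 12 / 3 = 6 pixels, times 1.2 and rounded -> 7
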
@@ -156,7 +156,7 @@ def get_image_rms_measurements(
nan_valid = []

# Get slices of each source and check NaN is not included.
for i,j in zip(array_coords[0], array_coords[1]):
for i, j in zip(array_coords[0], array_coords[1]):
sl = tuple((
slice(i - acceptable_no_nan_dist, i + acceptable_no_nan_dist),
slice(j - acceptable_no_nan_dist, j + acceptable_no_nan_dist)
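
Aside: the loop above builds a small square cutout around each source position; the NaN test itself falls outside this hunk, but the idea is a slice-and-isnan check. A toy, self-contained version with invented data:

    import numpy as np

    data = np.ones((10, 10))
    data[4, 5] = np.nan                     # one invented NaN pixel
    acceptable_no_nan_dist = 2              # illustrative half-width

    i, j = 4, 7                             # a made-up source position
    sl = tuple((
        slice(i - acceptable_no_nan_dist, i + acceptable_no_nan_dist),
        slice(j - acceptable_no_nan_dist, j + acceptable_no_nan_dist)
    ))
    print(np.any(np.isnan(data[sl])))       # True: the NaN lies inside the cutout
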
@@ -245,10 +245,10 @@ def new_sources(
min_sigma: float, edge_buffer: float, p_run: Run
) -> pd.DataFrame:
"""
Processes the new sources detected to check that they are valid new sources.
This involves checking to see that the source *should* be seen at all in
the images where it is not detected. For valid new sources the snr
value the source would have in non-detected images is also calculated.
Processes the new sources detected to check that they are valid new
sources. This involves checking to see that the source *should* be seen at
all in the images where it is not detected. For valid new sources the
snr value the source would have in non-detected images is also calculated.
Args:
sources_df:
@@ -353,7 +353,7 @@ def new_sources(
left_on='detection',
right_on='name',
how='left'
).rename(columns={'datetime':'detection_time'})
).rename(columns={'datetime': 'detection_time'})

new_sources_df = new_sources_df.merge(
images_df[[
@@ -364,7 +364,7 @@
right_on='name',
how='left'
).rename(columns={
'datetime':'img_diff_time',
'datetime': 'img_diff_time',
'rms_min': 'img_diff_rms_min',
'rms_median': 'img_diff_rms_median',
'noise_path': 'img_diff_rms_path'
@@ -439,10 +439,11 @@ def new_sources(
new_sources_df
.drop_duplicates('source')
.set_index('source')
.rename(columns={'true_sigma':'new_high_sigma'})
.rename(columns={'true_sigma': 'new_high_sigma'})
)

# moving forward only the new_high_sigma columns is needed, drop all others.
# moving forward only the new_high_sigma columns is needed, drop all
# others.
new_sources_df = new_sources_df[['new_high_sigma']]

logger.info(
