scattering: improve the code
zerafachris committed May 27, 2024
1 parent 6b99a19 commit c9c0f11
Showing 1 changed file with 36 additions and 30 deletions.
66 changes: 36 additions & 30 deletions scatcluster/processing/scattering.py
@@ -5,7 +5,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-import scipy as sp
+from scipy import stats as sp_stats
 import xarray as xr
 from matplotlib import dates as mdates
 from obspy.clients.filesystem.sds import Client
@@ -22,24 +22,22 @@ def reduce_type(self):
         """
         Pooling operation performed on the last axis.
         """
-        pooling_options = [
-            ('avg', np.mean),
-            ('max', np.max),
-            ('median', np.median),
-            ('std', np.std),
-            ('gmean', sp.stats.gmean),
-            ('hmean', sp.stats.hmean),
-            ('pmean', sp.stats.pmean),
-            ('kurtosis', sp.stats.kurtosis),
-            ('skew', sp.stats.skew),
-            ('entropy', sp.stats.entropy),
-            ('sem', sp.stats.sem),
-            ('differential_entropy', sp.stats.differential_entropy),
-            ('median_abs_deviation', sp.stats.median_abs_deviation),
-        ]
-        for po in pooling_options:
-            if self.network_pooling == po[0]:
-                return po[1]
+        pooling_options = {
+            'avg': np.mean,
+            'max': np.max,
+            'median': np.median,
+            'std': np.std,
+            'gmean': sp_stats.gmean,
+            'hmean': sp_stats.hmean,
+            'pmean': sp_stats.pmean,
+            'kurtosis': sp_stats.kurtosis,
+            'skew': sp_stats.skew,
+            'entropy': sp_stats.entropy,
+            'sem': sp_stats.sem,
+            'differential_entropy': sp_stats.differential_entropy,
+            'median_abs_deviation': sp_stats.median_abs_deviation,
+        }
+        return pooling_options.get(self.network_pooling, None)
 
     def load_data_times(self):
         """
@@ -50,21 +48,29 @@ def load_data_times(self):
         `{self.network_name}_times.npy` and stores them in the `data_times` attribute.
         """
-        self.data_times = np.load(
-            f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_'
-            f'{self.network_name}_times.npy')
+        try:
+            file_path = f'{self.data_savepath}data/{self.data_network}_{self.data_station}_{self.data_location}_' \
+                        f'{self.network_name}_times.npy'
+            self.data_times = np.load(file_path)
+        except FileNotFoundError:
+            print(f"File not found: {file_path}")
+        except Exception as e:
+            print(f"An error occurred while loading data times: {e}")
 
     def build_day_list(self) -> None:
         """Build data_day_list object
         """
-        day_list = [
-            day_start for day_start in pd.date_range(
-                UTCDateTime(self.data_starttime).strftime('%Y%m%d'), (
-                    UTCDateTime(self.data_endtime) - (60 * 60 * 24)).strftime('%Y%m%d')).strftime('%Y-%m-%d').tolist()
-            if day_start not in [UTCDateTime(day_exc).strftime('%Y-%m-%d') for day_exc in self.data_exclude_days]
-        ]
-
-        self.data_day_list = day_list
+        try:
+            start_time = UTCDateTime(self.data_starttime)
+            end_time = UTCDateTime(self.data_endtime)
+            exclude_days = [UTCDateTime(day).strftime('%Y-%m-%d') for day in self.data_exclude_days]
+            day_list = [
+                day_start for day_start in pd.date_range(start_time.strftime('%Y%m%d'), end_time.strftime('%Y%m%d')).strftime('%Y-%m-%d').tolist()
+                if day_start not in exclude_days
+            ]
+            self.data_day_list = day_list
+        except Exception as e:
+            print(f"An error occurred while building day list: {e}")
 
     def build_channel_list(self) -> None:
         if self.sample_stream is None:
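
In the same hunk, `load_data_times` now guards `np.load` so a missing times file is reported instead of raised; on failure the committed method leaves `self.data_times` unset, so downstream code should check it exists before use. A standalone sketch of the same pattern, with a hypothetical file name:

import numpy as np

file_path = 'example_times.npy'  # hypothetical path, for illustration only
try:
    data_times = np.load(file_path)
except FileNotFoundError:
    print(f'File not found: {file_path}')
    data_times = None
except Exception as e:
    print(f'An error occurred while loading data times: {e}')
    data_times = None

Assigning an explicit `None` here makes the failure visible to callers; the committed code keeps the attribute untouched instead.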
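`build_day_list` now computes the date range once and filters against a precomputed exclusion list. One behavioural detail visible in the diff: the old code subtracted a day from `data_endtime`, while the new code does not, so the end day itself is now included. A standalone sketch with hypothetical dates:

import pandas as pd
from obspy import UTCDateTime

# Hypothetical configuration values, for illustration only
data_starttime = '2024-01-01'
data_endtime = '2024-01-05'
data_exclude_days = ['2024-01-03']

start_time = UTCDateTime(data_starttime)
end_time = UTCDateTime(data_endtime)
exclude_days = [UTCDateTime(day).strftime('%Y-%m-%d') for day in data_exclude_days]
day_list = [
    day for day in pd.date_range(start_time.strftime('%Y%m%d'),
                                 end_time.strftime('%Y%m%d')).strftime('%Y-%m-%d').tolist()
    if day not in exclude_days
]
print(day_list)  # ['2024-01-01', '2024-01-02', '2024-01-04', '2024-01-05']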
