Skip to content

Commit

Permalink
Improve direction_id handling via trip, if available
Browse files Browse the repository at this point in the history
  • Loading branch information
vingerha committed Feb 12, 2025
1 parent df03c60 commit 334a918
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 84 deletions.
34 changes: 10 additions & 24 deletions custom_components/gtfs2/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,20 @@
from homeassistant.helpers import selector

from .const import (
ATTR_API_KEY_LOCATIONS,
DOMAIN,
DEFAULT_PATH,
DOMAIN,
DEFAULT_API_KEY_LOCATION,
DEFAULT_REFRESH_INTERVAL,
DEFAULT_LOCAL_STOP_REFRESH_INTERVAL,
DEFAULT_LOCAL_STOP_TIMERANGE,
DEFAULT_LOCAL_STOP_RADIUS,
DEFAULT_OFFSET,
DEFAULT_ACCEPT_HEADER_PB,
DEFAULT_API_KEY_NAME,
DEFAULT_MAX_LOCAL_STOPS,
DEFAULT_STOP_LIST,
CONF_API_KEY_LOCATION,
CONF_API_KEY,
CONF_API_KEY_NAME,
CONF_ACCEPT_HEADER_PB,
DEFAULT_ACCEPT_HEADER_PB,
DEFAULT_API_KEY_NAME,
CONF_VEHICLE_POSITION_URL,
CONF_TRIP_UPDATE_URL,
CONF_ALERTS_URL,
Expand All @@ -51,8 +48,9 @@
CONF_OFFSET,
CONF_REAL_TIME,
CONF_SOURCE_TIMEZONE_CORRECTION,
CONF_MAX_LOCAL_STOPS,
CONF_STOP_LIST
ATTR_API_KEY_LOCATIONS,
DEFAULT_MAX_LOCAL_STOPS,
CONF_MAX_LOCAL_STOPS
)

from .gtfs_helper import (
Expand All @@ -64,8 +62,7 @@
remove_datasource,
check_datasource_index,
get_agency_list,
get_local_stop_list,
get_stop_range_list
get_local_stop_list
)

_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -470,11 +467,6 @@ async def async_step_init(
) -> FlowResult:
"""Manage the options."""
errors: dict[str, str] = {}
data = self.config_entry.data
options = self.config_entry.options
self._pygtfs = get_gtfs(
self.hass, DEFAULT_PATH, data, False
)
if user_input is not None:
if self.config_entry.data.get(CONF_DEVICE_TRACKER_ID, None):
_data = user_input
Expand All @@ -486,7 +478,6 @@ async def async_step_init(
stop_limit = await _check_stop_list(self, _data)
if stop_limit :
return self.async_abort(reason=stop_limit)

if user_input.get(CONF_REAL_TIME,None):
self._user_inputs.update(user_input)
_LOGGER.debug(f"UserInputs Options Init with realtime: {self._user_inputs}")
Expand All @@ -497,24 +488,19 @@ async def async_step_init(
return self.async_create_entry(title="", data=self._user_inputs)

if self.config_entry.data.get(CONF_DEVICE_TRACKER_ID, None):
stop_list1 = [ selector.SelectOptionDict(value=r, label=r.split(':')[1]) for r in get_stop_range_list(self.hass, self._pygtfs, data, options) ]
_LOGGER.debug("Stop_list 1: %s", stop_list1)
stop_list = get_stop_range_list(self.hass, self._pygtfs, data, options)
_LOGGER.debug("Stop_list: %s", stop_list)
opt1_schema = {
vol.Optional(CONF_LOCAL_STOP_REFRESH_INTERVAL, default=self.config_entry.options.get(CONF_LOCAL_STOP_REFRESH_INTERVAL, DEFAULT_LOCAL_STOP_REFRESH_INTERVAL)): int,
vol.Optional(CONF_RADIUS, default=self.config_entry.options.get(CONF_RADIUS, DEFAULT_LOCAL_STOP_RADIUS)): vol.All(vol.Coerce(int), vol.Range(min=50, max=5000)),
vol.Optional(CONF_TIMERANGE, default=self.config_entry.options.get(CONF_TIMERANGE, DEFAULT_LOCAL_STOP_TIMERANGE)): vol.All(vol.Coerce(int), vol.Range(min=15, max=120)),
vol.Optional(CONF_OFFSET, default=self.config_entry.options.get(CONF_OFFSET, DEFAULT_OFFSET)): int,
vol.Required(CONF_MAX_LOCAL_STOPS, default=self.config_entry.options.get(CONF_MAX_LOCAL_STOPS, DEFAULT_MAX_LOCAL_STOPS)): int,
vol.Optional(CONF_REAL_TIME, default=self.config_entry.options.get(CONF_REAL_TIME)): selector.BooleanSelector(),
vol.Optional(CONF_STOP_LIST, default = self.config_entry.options.get(CONF_STOP_LIST,DEFAULT_STOP_LIST )): selector.SelectSelector(selector.SelectSelectorConfig(options=stop_list1, translation_key="stop_range_list",custom_value=True, multiple=True, mode=selector.SelectSelectorMode.LIST,)),
}
vol.Optional(CONF_REAL_TIME, default=self.config_entry.options.get(CONF_REAL_TIME)): selector.BooleanSelector()
}
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(opt1_schema),
errors = errors
)
)

else:
opt1_schema = {
Expand Down
2 changes: 0 additions & 2 deletions custom_components/gtfs2/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
DEFAULT_LOCAL_STOP_TIMERANGE_HISTORY = 15
DEFAULT_LOCAL_STOP_RADIUS = 200
DEFAULT_MAX_LOCAL_STOPS = 15
DEFAULT_STOP_LIST = ['All']

DEFAULT_NAME = "GTFS Sensor2"
DEFAULT_PATH = "gtfs2"
Expand Down Expand Up @@ -290,7 +289,6 @@
CONF_REAL_TIME = "real_time"
CONF_SOURCE_TIMEZONE_CORRECTION = "source_timezone_correction"
CONF_MAX_LOCAL_STOPS = "max_local_stops"
CONF_STOP_LIST = "stop_list"

# gtfs_rt specific
CONF_API_KEY = "api_key"
Expand Down
2 changes: 1 addition & 1 deletion custom_components/gtfs2/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ async def _async_update_data(self) -> dict[str, str]:
self._stop_sequence = self._data["next_departure"]["origin_stop_sequence"]
self._destination_id = data["destination"].split(": ")[0]
self._trip_id = self._data.get('next_departure', {}).get('trip_id', None)
self._direction = data["direction"]
self._direction = self._data.get('next_departure', {}).get('trip_direction_id', data["direction"])
self._relative = False
try:
self._get_rt_alerts = await self.hass.async_add_executor_job(get_rt_alerts, self)
Expand Down
65 changes: 8 additions & 57 deletions custom_components/gtfs2/gtfs_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
CONF_API_KEY_LOCATION,
CONF_API_KEY_NAME,
CONF_ACCEPT_HEADER_PB,
CONF_RADIUS,
DEFAULT_LOCAL_STOP_TIMERANGE,
DEFAULT_LOCAL_STOP_TIMERANGE_HISTORY,
DEFAULT_LOCAL_STOP_RADIUS,
Expand Down Expand Up @@ -97,8 +96,8 @@ def get_next_departure(self):
tomorrow_calendar_date_where = f"AND (calendar_date_today.date = date('now') or calendar_date_today.date = date('now','+1 day') )"
tomorrow_select2 = f"CASE WHEN date('now') < calendar_date_today.date THEN '1' else '0' END as tomorrow,"
sql_query = f"""
SELECT trip.trip_id, trip.route_id,trip.trip_headsign,
route.route_long_name,route.route_short_name,
SELECT trip.trip_id, trip.route_id,trip.trip_headsign, trip.direction_id,
route.route_long_name,route.route_short_name,
start_station.stop_id as origin_stop_id,
start_station.stop_name as origin_stop_name,
start_station.stop_timezone as origin_stop_timezone,
Expand Down Expand Up @@ -156,7 +155,7 @@ def get_next_departure(self):
AND calendar.end_date >= date('now')
AND trip.service_id not in (select service_id from calendar_dates where date = date('now') and exception_type = 2)
UNION ALL
SELECT trip.trip_id, trip.route_id,trip.trip_headsign,
SELECT trip.trip_id, trip.route_id,trip.trip_headsign, trip.direction_id,
route.route_long_name,route.route_short_name,
start_station.stop_id as origin_stop_id,
start_station.stop_name as origin_stop_name,
Expand Down Expand Up @@ -427,6 +426,7 @@ def get_next_departure(self):
data_returned = {
"trip_id": item["trip_id"],
"route_id": item["route_id"],
"trip_direction_id": item["direction_id"],
"day": item["day"],
"first": item["first"],
"last": item["last"],
Expand Down Expand Up @@ -641,37 +641,6 @@ def get_stop_list(schedule, route_id, direction):
stops.append(val)
_LOGGER.debug(f"stops: {stops}")
return stops

def get_stop_range_list(hass, schedule, data, options):
_LOGGER.debug("Getting local stop list with data: %s", data)
device_tracker = hass.states.get(data['device_tracker_id'])
latitude = device_tracker.attributes.get("latitude", None)
longitude = device_tracker.attributes.get("longitude", None)
radius = options.get(CONF_RADIUS, DEFAULT_LOCAL_STOP_RADIUS) / 111111

sql_stops = f"""
SELECT stop.stop_id, stop.stop_name
FROM stops stop
where abs(stop.stop_lat - :latitude) < :radius and abs(stop.stop_lon - :longitude) < :radius
""" # noqa: S608
result = schedule.engine.connect().execute(
text(sql_stops),
{
"latitude": latitude,
"longitude": longitude,
"radius": radius
},
)
stops_range_list = []
stops = []
for row_cursor in result:
row = row_cursor._asdict()
stops_range_list.append(list(row_cursor))
for x in stops_range_list:
val = x[0] + ": " + x[1]
stops.append(val)
_LOGGER.debug(f"stops in range: {stops}")
return stops

def get_agency_list(schedule, data):
_LOGGER.debug("Getting agencies with data: %s", data)
Expand Down Expand Up @@ -906,12 +875,7 @@ def get_local_stop_list(hass, schedule, data):


def get_local_stops_next_departures(self):
_LOGGER.debug("Get local stop departure with _data: %s", self._data)
_LOGGER.debug("Get local stop departure with options: %s", self.config_entry.options)
_LOGGER.debug("Stopslist 1:: %s", self.config_entry.options.get('stop_list', None))



_LOGGER.debug("Get local stop departure with data: %s", self._data)
if check_extracting(self.hass, self._data['gtfs_dir'],self._data['file']):
_LOGGER.warning("Cannot get next depurtures on this datasource as still unpacking: %s", self._data["file"])
return {}
Expand Down Expand Up @@ -942,18 +906,8 @@ def get_local_stops_next_departures(self):
tomorrow_name = tomorrow.strftime("%A").lower()
tomorrow_select = f"calendar.{tomorrow_name} AS tomorrow,"
tomorrow_calendar_date_where = f"AND (calendar_date_today.date = date(:now_offset) or calendar_date_today.date = date(:now_offset,'+1 day'))"
tomorrow_select2 = f"CASE WHEN date(:now_offset) < calendar_date_today.date THEN '1' else '0' END as tomorrow,"
stop_list_where = 'AND 1=1'
my_stop_list = self.config_entry.options.get('stop_list',[])
_LOGGER.debug("Local stops list: %s", my_stop_list)
if my_stop_list and my_stop_list != 'All':
# get the stop_id, in a mysql acceptable list with ()
my_stops = ', '.join(["'" + str(i.split(':')[0]) + "'" for i in my_stop_list])
my_stops = '(' + my_stops + ')'
stop_list_where = f"AND stop.stop_id in {my_stops}"

_LOGGER.debug("Query values, Latitude %s - Longitude %s - Timerange %s - Timerange_history %s - Radius %s - Now: %s", latitude, longitude, time_range, time_range_history, radius, now)
_LOGGER.debug("Query where statements, tomorrow_calendar_date_where: %s - stop_list_where: %s",tomorrow_calendar_date_where,stop_list_where)
tomorrow_select2 = f"CASE WHEN date(:now_offset) < calendar_date_today.date THEN '1' else '0' END as tomorrow,"
_LOGGER.debug("Query params: Latitude %s - Longitude %s - Timerange %s - Timerange_history %s - Radius %s - Now: %s", latitude, longitude, time_range, time_range_history, radius, now)
sql_query = f"""
SELECT * FROM (
SELECT stop.stop_id, stop.stop_name,stop.stop_lat as latitude, stop.stop_lon as longitude, stop.stop_timezone as stop_timezone, agency.agency_timezone as agency_timezone, trip.trip_id, trip.trip_headsign, trip.direction_id, time(st.departure_time) as departure_time,st.stop_sequence as stop_sequence,
Expand Down Expand Up @@ -985,8 +939,7 @@ def get_local_stops_next_departures(self):
and ((datetime(date(:now_offset) || ' ' || time(st.departure_time) ) between datetime(:now_offset,:timerange_history) and datetime(:now_offset,:timerange))
or (datetime(date(:now_offset,'+1 day') || ' ' || time(st.departure_time) ) between datetime(:now_offset,:timerange_history) and datetime(:now_offset,:timerange)))
AND calendar.start_date <= date(:now_offset)
AND calendar.end_date >= date(:now_offset)
{stop_list_where}
AND calendar.end_date >= date(:now_offset)
)
UNION ALL
SELECT * FROM (
Expand Down Expand Up @@ -1019,9 +972,7 @@ def get_local_stops_next_departures(self):
and ((datetime(date(:now_offset) || ' ' || time(st.departure_time) ) between datetime(:now_offset,:timerange_history) and datetime(:now_offset,:timerange))
or (datetime(date(:now_offset,'+1 day') || ' ' || time(st.departure_time) ) between datetime(:now_offset,:timerange_history) and datetime(:now_offset,:timerange)))
{tomorrow_calendar_date_where}
{stop_list_where}
)
order by stop_id, tomorrow, departure_time
""" # noqa: S608
result = schedule.engine.connect().execute(
Expand Down

0 comments on commit 334a918

Please sign in to comment.