-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #852 from robfitzgerald/rjf/incremental-inference
Rjf/incremental inference
- Loading branch information
Showing
33 changed files
with
2,116 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
*.swp | ||
*debug.log | ||
.DS_Store | ||
.vscode | ||
|
||
CFC_WebApp/config.json | ||
CFC_WebApp/keys.json | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
{ | ||
"model_type": "greedy", | ||
"model_storage": "document_database", | ||
"minimum_trips": 14, | ||
"model_parameters": { | ||
"greedy": { | ||
"metric": "od_similarity", | ||
"similarity_threshold_meters": 500, | ||
"apply_cutoff": false, | ||
"incremental_evaluation": false | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
64 changes: 64 additions & 0 deletions
64
emission/analysis/modelling/similarity/confirmed_trip_feature_extraction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
from typing import List | ||
import emission.core.wrapper.confirmedtrip as ecwc | ||
|
||
|
||
def origin_features(trip: ecwc.Confirmedtrip) -> List[float]:
    """Return the value stored at ``data.start_loc.coordinates`` on a trip.

    :param trip: trip to extract features from
    :return: the origin coordinates, exactly as stored on the trip
    :raises KeyError: if any key on the path data.start_loc.coordinates
        is missing; the original lookup error is chained as the cause
    """
    try:
        start_loc = trip['data']['start_loc']
        coordinates = start_loc["coordinates"]
    except KeyError as missing:
        raise KeyError(
            'Confirmedtrip expected to have path data.start_loc.coordinates'
        ) from missing
    return coordinates
|
||
def destination_features(trip: ecwc.Confirmedtrip) -> List[float]:
    """Return the value stored at ``data.end_loc.coordinates`` on a trip.

    :param trip: trip to extract features from
    :return: the destination coordinates, exactly as stored on the trip
    :raises KeyError: if any key on the path data.end_loc.coordinates
        is missing; the original lookup error is chained as the cause
    """
    try:
        end_loc = trip['data']['end_loc']
        coordinates = end_loc["coordinates"]
    except KeyError as missing:
        raise KeyError(
            'Confirmedtrip expected to have path data.end_loc.coordinates'
        ) from missing
    return coordinates
|
||
|
||
def od_features(trip: ecwc.Confirmedtrip) -> List[float]:
    """Concatenate the origin and destination coordinates of a trip.

    Each coordinate list is unpacked into exactly two values, so a trip
    whose coordinates do not have exactly two elements raises ValueError.

    :param trip: trip to extract features from
    :return: a flat four-element list: the two origin values followed by
        the two destination values, each pair in stored order
    :raises KeyError: propagated from origin_features / destination_features
    """
    # NOTE(review): the axis order (lat/lon vs lon/lat) is whatever the trip
    # stores; it is not established here - confirm against the trip schema.
    origin_first, origin_second = origin_features(trip)
    destination_first, destination_second = destination_features(trip)
    return [origin_first, origin_second, destination_first, destination_second]
|
||
def distance_feature(trip: ecwc.Confirmedtrip) -> List[float]:
    """Return the trip's ``data.distance`` value wrapped in a list.

    Provided for forward compatibility.

    :param trip: trip to extract features from
    :return: single-element list holding the distance feature
    :raises KeyError: if data.distance is missing from the trip
    """
    try:
        distance = trip['data']['distance']
    except KeyError as missing:
        raise KeyError(
            'Confirmedtrip expected to have path data.distance'
        ) from missing
    return [distance]
|
||
def duration_feature(trip: ecwc.Confirmedtrip) -> List[float]:
    """Return the trip's ``data.duration`` value wrapped in a list.

    Provided for forward compatibility.

    :param trip: trip to extract features from
    :return: single-element list holding the duration feature
    :raises KeyError: if data.duration is missing from the trip
    """
    try:
        duration = trip['data']['duration']
    except KeyError as missing:
        raise KeyError(
            'Confirmedtrip expected to have path data.duration'
        ) from missing
    return [duration]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from typing import List | ||
import emission.analysis.modelling.similarity.similarity_metric as eamss | ||
import emission.analysis.modelling.similarity.confirmed_trip_feature_extraction as ctfe | ||
import emission.core.wrapper.confirmedtrip as ecwc | ||
import emission.core.common as ecc | ||
|
||
|
||
class OriginDestinationSimilarity(eamss.SimilarityMetric):
    """
    similarity metric which compares, for two trips,
    the distance for origin to origin, and destination to destination,
    in meters.
    """

    def extract_features(self, trip: ecwc.Confirmedtrip) -> List[float]:
        """Extract the four origin/destination values used by this metric.

        :param trip: a confirmed trip
        :return: the result of ctfe.od_features for the trip
        """
        return ctfe.od_features(trip)

    def similarity(self, a: List[float], b: List[float]) -> List[float]:
        """Compute origin-to-origin and destination-to-destination distances.

        :param a: features for a trip; indices 0-1 are the origin,
            indices 2-3 the destination
        :param b: features for another trip, laid out like ``a``
        :return: two-element list, origin distance then destination distance
        """
        origin_pair = ([a[0], a[1]], [b[0], b[1]])
        origin_gap = ecc.calDistance(*origin_pair)
        destination_pair = ([a[2], a[3]], [b[2], b[3]])
        destination_gap = ecc.calDistance(*destination_pair)
        return [origin_gap, destination_gap]
41 changes: 41 additions & 0 deletions
41
emission/analysis/modelling/similarity/similarity_metric.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from abc import ABCMeta, abstractmethod | ||
from typing import List | ||
import logging | ||
|
||
import emission.core.wrapper.confirmedtrip as ecwc | ||
|
||
|
||
class SimilarityMetric(metaclass=ABCMeta):
    """Abstract interface for comparing two trips through extracted features."""

    @abstractmethod
    def extract_features(self, trip: ecwc.Confirmedtrip) -> List[float]:
        """extracts the features we want to compare for similarity

        :param trip: a confirmed trip
        :return: the features to compare
        """

    @abstractmethod
    def similarity(self, a: List[float], b: List[float]) -> List[float]:
        """compares the features, producing their similarity
        as computed by this similarity metric

        :param a: features for a trip
        :param b: features for another trip
        :return: for each feature, the similarity of these features
        """

    def similar(self, a: List[float], b: List[float], thresh: float) -> bool:
        """Decide whether two feature vectors are similar within a threshold.

        Every value produced by ``similarity(a, b)`` must be less than or
        equal to ``thresh``; an empty similarity result counts as similar.

        :param a: features for a trip
        :param b: features for another trip
        :param thresh: threshold for similarity
        :return: true if the feature similarity is within some threshold
        """
        for value in self.similarity(a, b):
            # stop at the first value that is not within the threshold
            if not (value <= thresh):
                return False
        return True
49 changes: 49 additions & 0 deletions
49
emission/analysis/modelling/similarity/similarity_metric_type.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from __future__ import annotations | ||
import enum | ||
|
||
|
||
import emission.analysis.modelling.similarity.od_similarity as eamso | ||
import emission.analysis.modelling.similarity.similarity_metric as eamss | ||
|
||
class SimilarityMetricType(enum.Enum):
    """Enumerates the similarity metrics that can be selected by name."""

    OD_SIMILARITY = 0

    def build(self) -> eamss.SimilarityMetric:
        """
        hey YOU! add future similarity metric types here please!

        :raises KeyError: if the SimilarityMetricType isn't found in the below dictionary
        :return: the associated similarity metric
        """
        # Map to the class rather than an instance so that only the
        # requested metric is constructed.
        factories = {
            SimilarityMetricType.OD_SIMILARITY: eamso.OriginDestinationSimilarity,
        }

        factory = factories.get(self)
        if factory is None:
            # names is a classmethod and must be *called*; joining the bound
            # method itself raises TypeError instead of the intended KeyError.
            names = "{" + ",".join(SimilarityMetricType.names()) + "}"
            msg = f"unknown metric type {self}, must be one of {names}"
            raise KeyError(msg)
        return factory()

    @classmethod
    def names(cls):
        """List the names of all members, in definition order.

        :return: list of member names
        """
        return [member.name for member in cls]

    @classmethod
    def from_str(cls, str):
        """attempts to match the provided string to a known SimilarityMetricType.
        not case sensitive.

        :param str: a string name of a SimilarityMetricType
        :raises KeyError: if the upper-cased string is not a member name
        :return: the matching SimilarityMetricType member
        """
        # NOTE: the parameter shadows the builtin ``str``; the name is kept
        # so that existing keyword callers continue to work.
        try:
            return cls[str.upper()]
        except KeyError as e:
            names = "{" + ",".join(cls.names()) + "}"
            msg = f"{str} is not a known SimilarityMetricType, must be one of {names}"
            raise KeyError(msg) from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import json | ||
import re | ||
from this import d | ||
from typing import Optional | ||
import logging | ||
from numpy import isin | ||
|
||
import emission.analysis.modelling.trip_model.model_storage as eamums | ||
import emission.analysis.modelling.trip_model.model_type as eamumt | ||
|
||
config_filename = "" | ||
|
||
def load_config():
    """Load the trip model configuration from its JSON file.

    Tries ``conf/analysis/trip_model.conf.json`` first and falls back to the
    ``.sample`` file next to it when the primary file cannot be opened. The
    path that is opened is recorded in the module-level ``config_filename``.
    Paths are relative to the current working directory.

    :return: the parsed JSON configuration
    :raises OSError: if neither the primary nor the sample file can be opened
    """
    global config_filename
    config_filename = 'conf/analysis/trip_model.conf.json'
    try:
        config_file = open(config_filename)
    except OSError:
        # Only a failure to open the file triggers the fallback; a bare
        # except would also swallow KeyboardInterrupt and similar.
        print("analysis.trip_model.conf.json not configured, falling back to sample, default configuration")
        config_filename = 'conf/analysis/trip_model.conf.json.sample'
        config_file = open(config_filename)
    # Close the handle even when the JSON is malformed.
    with config_file:
        return json.load(config_file)
|
||
config_data = load_config() | ||
|
||
def reload_config():
    """Replace the module-level ``config_data`` with a fresh ``load_config()`` result."""
    global config_data
    refreshed = load_config()
    config_data = refreshed
|
||
def get_config():
    """Return the module-level configuration object.

    No copy is made: callers receive the same object held in ``config_data``.

    :return: the currently loaded configuration
    """
    return config_data
|
||
def get_optional_config_value(key) -> Optional[str]:
    """
    get a config value at the provided path/key

    The path is walked one segment at a time through nested dicts. If a
    segment is missing, or the path tries to descend into a value that is
    not a dict (for example a string or a number), None is returned.

    NOTE(review): the value is returned as stored, so it may be any JSON
    type, not only a string as the return annotation suggests.

    :param key: a key name or a dot-delimited path to some key within the config object
    :return: the value at the key, or, None if not found
    """
    cursor = config_data
    for segment in key.split("."):
        # A scalar (or an already-missing value) has no children; calling
        # .get on it would raise AttributeError instead of reporting
        # "not found".
        if not isinstance(cursor, dict):
            return None
        cursor = cursor.get(segment)
    return cursor
|
||
def get_config_value_or_raise(key):
    """Look up a required config value, failing loudly when it is absent.

    On a miss, the whole config object is written to the debug log before
    the error is raised.

    :param key: a key name or a dot-delimited path within the config object
    :return: the value found at the key
    :raises KeyError: if get_optional_config_value returns None for the key
    """
    logging.debug(f'getting key {key} in config')
    found = get_optional_config_value(key)
    if found is not None:
        return found

    logging.debug('config object:')
    logging.debug(json.dumps(config_data, indent=2))
    raise KeyError(
        f"expected config key {key} not found in config file {config_filename}"
    )
|
||
def get_model_type():
    """Read the required ``model_type`` config value and convert it.

    :return: the result of eamumt.ModelType.from_str for the configured string
    :raises KeyError: if ``model_type`` is not present in the config
    """
    return eamumt.ModelType.from_str(get_config_value_or_raise('model_type'))
|
||
def get_model_storage():
    """Read the required ``model_storage`` config value and convert it.

    :return: the result of eamums.ModelStorage.from_str for the configured string
    :raises KeyError: if ``model_storage`` is not present in the config
    """
    return eamums.ModelStorage.from_str(get_config_value_or_raise('model_storage'))
|
||
def get_minimum_trips():
    """Read the required ``minimum_trips`` config value and check its type.

    :return: the configured minimum number of trips
    :raises KeyError: if ``minimum_trips`` is not present in the config
    :raises TypeError: if the configured value is not an integer
    """
    minimum_trips = get_config_value_or_raise('minimum_trips')
    # bool is a subclass of int, so a JSON true/false would otherwise slip
    # through the integer check and be treated as 1/0.
    if isinstance(minimum_trips, bool) or not isinstance(minimum_trips, int):
        msg = f"config key 'minimum_trips' not an integer in config file {config_filename}"
        raise TypeError(msg)
    return minimum_trips
|
||
|
||
|
Oops, something went wrong.