diff --git a/requirements.txt b/requirements.txt
index 1399d57..10e41df 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,10 @@
 defusedxml==0.7.1
+langchain_core==0.3.6
+langchain_openai==0.2.1
+numpy==2.1.1
+pandas==2.2.3
+python-dotenv==1.0.1
 questionary==2.0.1
+scipy==1.14.1
 tcxreader==0.4.10
+tqdm==4.66.5
diff --git a/src/main.py b/src/main.py
index 9451615..a97655f 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,14 +1,25 @@
 import re
 import os
+import time
 import logging
 import webbrowser
-import time
-from defusedxml.minidom import parseString
+
+from typing import Tuple
 
 import questionary
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+from dotenv import load_dotenv
+from langchain_openai import ChatOpenAI
+from langchain_core.prompts.prompt import PromptTemplate
+from defusedxml.minidom import parseString
+from scipy.spatial.distance import squareform, pdist
 
 from tcxreader.tcxreader import TCXReader
 
+
+load_dotenv()
 
 logger = logging.getLogger()
 if not logger.handlers:
@@ -39,15 +50,22 @@ def main():
         else:
             file_path = ask_file_path(file_location)
 
-        if sport in ["Swim", "Other"]:
-            logger.info("Formatting the TCX file to be imported to TrainingPeaks")
-            format_to_swim(file_path)
-        elif sport in ["Bike", "Run"]:
-            logger.info("Validating the TCX file")
-            validate_tcx_file(file_path)
-        else:
-            logger.error("Invalid sport selected")
-            raise ValueError("Invalid sport selected")
+        if file_path:
+            if sport in ["Swim", "Other"]:
+                logger.info(
+                    "Formatting the TCX file to be imported to TrainingPeaks"
+                )
+                format_to_swim(file_path)
+            elif sport in ["Bike", "Run"]:
+                logger.info("Validating the TCX file")
+                _, tcx_data = validate_tcx_file(file_path)
+                if ask_llm_analysis():
+                    plan = ask_training_plan()
+                    logger.info("Performing LLM analysis")
+                    perform_llm_analysis(tcx_data, sport, plan)
+            else:
+                logger.error("Invalid sport selected")
+                raise ValueError("Invalid sport selected")
 
         indent_xml_file(file_path)
         logger.info("Process completed successfully!")
@@ -76,7 +94,8 @@ def ask_activity_id() -> str:
 
 def download_tcx_file(activity_id: str, sport: str) -> None:
     if sport in ["Swim", "Other"]:
-        url = f"https://www.strava.com/activities/{activity_id}/export_original"
+        url = ("https://www.strava.com/activities/"
+               f"{activity_id}/export_original")
     else:
         url = f"https://www.strava.com/activities/{activity_id}/export_tcx"
     try:
@@ -104,14 +123,23 @@ def get_latest_download() -> str:
     return latest_file
 
 
-def ask_file_path(file_location) -> str:
-    question = "Enter the path to the TCX file:" if file_location == "Provide path" else "Check if the TCX file was downloaded and then enter the path to the file:"
+def ask_file_path(file_location: str) -> str:
+    if file_location == "Provide path":
+        question = "Enter the path to the TCX file:"
+    else:
+        question = "Check if the TCX was downloaded and validate the file:"
+
     return questionary.path(
         question,
-        validate=os.path.isfile
+        validate=validation,
+        only_directories=False
     ).ask()
 
 
+def validation(path: str) -> bool:
+    return os.path.isfile(path)
+
+
 def format_to_swim(file_path: str) -> None:
     xml_str = read_xml_file(file_path)
     xml_str = modify_xml_header(xml_str)
@@ -138,7 +166,7 @@ def write_xml_file(file_path: str, xml_str: str) -> None:
         xml_file.write(xml_str)
 
 
-def validate_tcx_file(file_path: str) -> bool:
+def validate_tcx_file(file_path: str) -> Tuple[bool, TCXReader]:
     xml_str = read_xml_file(file_path)
     if not xml_str:
         logger.error("The TCX file is empty.")
@@ -151,12 +179,157 @@ def validate_tcx_file(file_path: str) -> bool:
             "The TCX file is valid. You covered a significant distance in this activity, with %d meters.",
             data.distance
         )
-        return True
+        return True, data
     except Exception as err:
         logger.error("Invalid TCX file.")
         raise ValueError(f"Error reading the TCX file: {err}") from err
 
 
+def ask_llm_analysis() -> bool:
+    """Ask whether the activity should be analysed by the LLM.
+
+    Returns the answer to a yes/no prompt that defaults to "no".
+    """
+    return questionary.confirm(
+        "Do you want to perform AI analysis?",
+        default=False
+    ).ask()
+
+
+def ask_training_plan() -> str:
+    """Ask what, if anything, was planned for this training session."""
+    return questionary.text(
+        "Was there anything planned for this training?"
+    ).ask()
+
+
+def perform_llm_analysis(data: TCXReader, sport: str, plan: str) -> str:
+    """Send the activity's trackpoints to the LLM and return its analysis.
+
+    Args:
+        data: Parsed TCX activity, as returned by validate_tcx_file().
+        sport: Sport of the activity; interpolated into the prompt.
+        plan: What was planned for the session; appended to the prompt
+            only when non-empty.
+
+    Returns:
+        The text content of the model's response.
+    """
+    dataframe = preprocess_trackpoints_data(data)
+
+    prompt = """SYSTEM: You are an AI Assistant that helps athletes to improve their performance.
+    Based on the following csv data that is related to a {sport} training session, carry out an analysis highlighting positive points, where the athlete did well and where he did poorly and what he can do to improve in the next {sport}.
+
+    {data}
+
+    """
+    prompt += "plan: {plan}" if plan else ""
+    prompt = PromptTemplate.from_template(prompt)
+    prompt = prompt.format(
+        sport=sport,
+        data=dataframe.to_csv(index=False),
+        plan=plan
+    )
+
+    openai_llm = ChatOpenAI(
+        openai_api_key=os.getenv("OPENAI_API_KEY"),
+        model_name="gpt-4o",
+        max_tokens=1500,
+        temperature=0.6,
+        max_retries=5
+    )
+    response = openai_llm.invoke(prompt)
+    logger.info("AI analysis completed successfully.")
+    logger.info("\nAI response:\n %s \n", response.content)
+    return response.content
+
+
+def preprocess_trackpoints_data(data) -> pd.DataFrame:
+    """Turn an activity's trackpoints into a compact dataframe for the prompt.
+
+    Renames and rescales the distance/time/speed columns, adds a pace column,
+    removes duplicate/incomplete rows and thins out near-identical points.
+    """
+    dataframe = pd.DataFrame(data.trackpoints_to_dict())
+    dataframe.rename(
+        columns={
+            "distance": "Distance_Km",
+            "time": "Time",
+            "Speed": "Speed_Kmh"
+        }, inplace=True
+    )
+    dataframe["Time"] = dataframe["Time"].apply(lambda x: x.value / 10**9)
+    dataframe["Distance_Km"] = round(dataframe["Distance_Km"] / 1000, 2)
+    dataframe["Speed_Kmh"] = dataframe["Speed_Kmh"] * 3.6
+    dataframe["Pace"] = round(
+        dataframe["Speed_Kmh"].apply(lambda x: 60 / x if x > 0 else 0),
+        2
+    )
+    if dataframe["cadence"].isnull().sum() >= len(dataframe) / 2:
+        dataframe.drop(columns=["cadence"], inplace=True)
+
+    dataframe = dataframe.drop_duplicates()
+    # Drop incomplete rows before renumbering so the index stays contiguous.
+    dataframe = dataframe.dropna()
+    dataframe = dataframe.reset_index(drop=True)
+
+    if dataframe.shape[0] > 4000:
+        dataframe = run_euclidean_dist_deletion(dataframe, 0.55)
+    elif dataframe.shape[0] > 1000:
+        dataframe = run_euclidean_dist_deletion(dataframe, 0.35)
+    else:
+        dataframe = run_euclidean_dist_deletion(dataframe, 0.10)
+
+    dataframe["Time"] = pd.to_datetime(
+        dataframe["Time"],
+        unit='s'
+    ).dt.strftime('%H:%M:%S')
+
+    return dataframe
+
+
+def run_euclidean_dist_deletion(
+    dataframe: pd.DataFrame, percentage: float
+) -> pd.DataFrame:
+    """Thin out a dataframe by removing its most redundant rows.
+
+    Repeatedly finds the closest pair of rows (Euclidean distance over all
+    columns) and removes one row of the pair, keeping the other, until
+    ``percentage`` of the rows has been removed.
+
+    Args:
+        dataframe: Numeric data, one row per point.
+        percentage: Fraction of the rows to remove, between 0 and 1.
+
+    Returns:
+        The remaining rows with a fresh 0..n-1 index.
+    """
+    total_rows = int(percentage * len(dataframe))
+    if total_rows < 1:
+        return dataframe.reset_index(drop=True)
+
+    dists = squareform(pdist(dataframe, metric='euclidean'))
+    np.fill_diagonal(dists, np.inf)
+
+    keep = np.ones(len(dataframe), dtype=bool)
+    with tqdm(total=total_rows, desc="Removing similar points") as pbar:
+        for _ in range(total_rows):
+            row, col = np.unravel_index(np.argmin(dists), dists.shape)
+            if not np.isfinite(dists[row, col]):
+                break  # fewer than two rows left to compare
+            # Retire `row` entirely (its row AND its column). Masking the
+            # partner's column instead leaves the mirrored entry in place,
+            # so the partner would be removed on the next pass as well.
+            dists[row, :] = np.inf
+            dists[:, row] = np.inf
+            keep[row] = False
+            pbar.update(1)
+
+    # `row` is a position in the distance matrix, not an index label, so
+    # select the survivors positionally.
+    return dataframe.iloc[keep].reset_index(drop=True)
+
+
 def indent_xml_file(file_path: str) -> None:
     try:
         with open(file_path, "r", encoding='utf-8') as xml_file:
diff --git a/tests/test_main.py b/tests/test_main.py
index adb94ce..b102fc2 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1,7 +1,10 @@
 import os
-#import sys
+# import sys
 import unittest
+
 from unittest.mock import patch
+from pandas import DataFrame
+from tcxreader.tcxreader import TCXReader
 
 # sys.path.append(os.path.abspath(''))
 
@@ -18,13 +21,21 @@
     ask_file_location,
     ask_activity_id,
     ask_file_path,
-    get_latest_download
+    get_latest_download,
+    validation,
+    ask_training_plan,
+    ask_llm_analysis,
+    perform_llm_analysis,
+    preprocess_trackpoints_data,
+    run_euclidean_dist_deletion
 )
 
 
 class TestMain(unittest.TestCase):
     def setUp(self) -> None:
-        pass
+        tcx_reader = TCXReader()
+        self.running_example_data = tcx_reader.read("assets/run.tcx")
+        self.biking_example_data = tcx_reader.read("assets/bike.tcx")
 
     @patch('src.main.webbrowser.open')
     def test_download_tcx_file(self, mock_open):
@@ -96,6 +107,7 @@ def test_validate_tcx_file(self):
         file_path = "assets/bike.tcx"
         result = validate_tcx_file(file_path)
         self.assertTrue(result)
+        self.assertEqual(len(result), 2)
 
     def test_validate_tcx_file_error(self):
         file_path = "assets/swim.tcx"
@@ -185,6 +197,9 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
         mock_validate.assert_not_called()
         mock_indent.assert_not_called()
 
+    @patch('src.main.ask_training_plan')
+    @patch('src.main.perform_llm_analysis')
+    @patch('src.main.ask_llm_analysis')
     @patch('src.main.ask_sport')
     @patch('src.main.ask_file_location')
     @patch('src.main.ask_activity_id')
@@ -194,10 +209,15 @@ def test_main_invalid_sport(self, mock_indent, mock_validate, mock_format, mock_
     @patch('src.main.validate_tcx_file')
     @patch('src.main.indent_xml_file')
     def test_main_bike_sport(self, mock_indent, mock_validate, mock_format, mock_ask_path, mock_download,
-                             mock_ask_id, mock_ask_location, mock_ask_sport):
+                             mock_ask_id, mock_ask_location, mock_ask_sport, mock_llm_analysis, mock_perform_llm,
+                             mock_training_plan):
         mock_ask_sport.return_value = "Bike"
         mock_ask_location.return_value = "Local"
         mock_ask_path.return_value = "assets/bike.tcx"
+        mock_llm_analysis.return_value = True
+        mock_validate.return_value = True, "TCX Data"
+        mock_perform_llm.return_value = "Training Plan"
+        mock_training_plan.return_value = ""
 
         main()
 
@@ -207,6 +227,8 @@ def test_main_bike_sport(self, mock_indent, mock_validate, mock_format, mock_ask
         mock_ask_path.assert_called_once()
         mock_download.assert_not_called()
         mock_format.assert_not_called()
+        mock_llm_analysis.assert_called_once()
+        mock_perform_llm.assert_called_once()
         mock_validate.assert_called_once_with("assets/bike.tcx")
         mock_indent.assert_called_once_with("assets/bike.tcx")
 
@@ -245,7 +267,8 @@ def test_ask_file_path(self):
             result = ask_file_path("Provide path")
             mock_path.assert_called_once_with(
                 "Enter the path to the TCX file:",
-                validate=os.path.isfile
+                validate=validation,
+                only_directories=False
             )
             self.assertEqual(result, "assets/test.tcx")
 
@@ -254,8 +277,9 @@ def test_ask_file_path(self):
             mock_path.return_value.ask.return_value = "assets/downloaded.tcx"
             result = ask_file_path("Download")
             mock_path.assert_called_once_with(
-                "Check if the TCX file was downloaded and then enter the path to the file:",
-                validate=os.path.isfile
+                "Check if the TCX was downloaded and validate the file:",
+                validate=validation,
+                only_directories=False
             )
             self.assertEqual(result, "assets/downloaded.tcx")
 
@@ -275,6 +299,60 @@ def test_get_latest_downloads_with_ask(self, mock_ask_path):
 
         self.assertEqual(result, "assets/bike.tcx")
 
+    def test_validation(self):
+        file_path = "assets/bike.tcx"
+        result = validation(file_path)
+
+        self.assertTrue(result)
+
+    def test_ask_training_plan(self):
+        with patch('src.main.questionary.text') as mock_text:
+            mock_text.return_value.ask.return_value = ""
+            result = ask_training_plan()
+            mock_text.assert_called_once_with(
+                "Was there anything planned for this training?"
+            )
+            self.assertEqual(result, "")
+
+    def test_ask_llm_analysis(self):
+        with patch('src.main.questionary.confirm') as mock_confirm:
+            mock_confirm.return_value.ask.return_value = True
+            result = ask_llm_analysis()
+            mock_confirm.assert_called_once_with(
+                "Do you want to perform AI analysis?",
+                default=False
+            )
+            self.assertTrue(result)
+
+    @patch('src.main.ChatOpenAI')
+    def test_perform_llm_analysis(self, mock_chat):
+        mock_invoke = mock_chat.return_value.invoke.return_value
+        mock_invoke.content = "Training Plan"
+        tcx_data = self.running_example_data
+        sport = "Run"
+        plan = "Training Plan"
+
+        result = perform_llm_analysis(tcx_data, sport, plan)
+        self.assertEqual(result, "Training Plan")
+
+    def test_preprocess_running_trackpoints_data(self):
+        tcx_data = self.running_example_data
+        result = preprocess_trackpoints_data(tcx_data)
+        self.assertEqual(len(result), 1646)
+
+    def test_preprocess_biking_trackpoints_data(self):
+        tcx_data = self.biking_example_data
+        result = preprocess_trackpoints_data(tcx_data)
+        self.assertEqual(len(result), 2028)
+
+    def test_run_euclidean_distance(self):
+        dataframe = DataFrame({
+            'latitude': [1, 2, 3, 3.5, 4, 5, 6, 6.5, 7, 8, 9],
+            'longitude': [1, 2, 3, 3.5, 4, 5, 6, 6.5, 7, 8, 9]
+        })
+        result = run_euclidean_dist_deletion(dataframe, 0.1)
+        self.assertEqual(len(result), 10)
+
 
 if __name__ == '__main__':
     unittest.main()