From 8b29db7b6d4c330f5790a06d12bc183cccecc184 Mon Sep 17 00:00:00 2001 From: Benjamin Cretois Date: Mon, 25 Sep 2023 10:01:04 +0200 Subject: [PATCH] [ADD] condition for writing .csv file --- README.md | 14 +++++++++++++- src/predict.py | 22 +++++++++++++++------- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index d758934..b2eb926 100644 --- a/README.md +++ b/README.md @@ -27,12 +27,24 @@ cd snowmobile_analyzer docker build -t snowmobile -f Dockerfile . ``` -Run the program: +Run the program using the `analyze.sh` script which is a wrapper around the Docker command ```bash ./analyze.sh ./example/example_audio.mp3 ``` +Note that if you want to have more control over the arguments you can use Docker: + +```bash +docker run \ + --rm \ + --gpus all \ + -v ./logs:/app/logs \ # Important to write the log files + -v "$FOLDER_TO_EXPOSE":/data \ + snowmobile \ + --input /data/"$FILENAME" +``` + Note that you can change `./example/example_audio.mp3` to the path of your own file. ### Use without Docker diff --git a/src/predict.py b/src/predict.py index 47f999d..171a2ed 100644 --- a/src/predict.py +++ b/src/predict.py @@ -41,7 +41,7 @@ def compute_hr(array): return hr -def predict(testLoader, model, device, threshold=0.95): +def predict(testLoader, model, device, threshold=0.99): proba_list = [] hr_list = [] @@ -112,13 +112,20 @@ def write_results(prob_audioclip_array, hr_array, outname, min_hr, min_conf): # Update the start time of the detection idx_begin = idx_end - with open(outname, "w") as file: + # Write only if there are some detections respecting our conditions + if len(rows_for_csv) > 0: + with open(outname, "w") as file: - writer = csv.writer(file) - header = ["start_detection", "end_detection", "label", "confidence", "hr"] + writer = csv.writer(file) + header = ["start_detection", "end_detection", "label", "confidence", "hr"] - writer.writerow(header) - writer.writerows(rows_for_csv) + writer.writerow(header) + writer.writerows(rows_for_csv) + else: + file_analyzed = os.path.basename(outname) + message = f"No detection has been made for {file_analyzed}" + print(message) + logging.info(message) def analyzeFile( @@ -148,6 +155,7 @@ def analyzeFile( ) pred_audioclip_array, pred_hr_array = predict(predLoader, model, device) + write_results(pred_audioclip_array, pred_hr_array, outpath, min_hr, min_conf) # Give the tim it took to analyze file @@ -189,7 +197,7 @@ def analyzeFile( parser.add_argument( "--min_conf", help="Minimum value for model confidence to take detection in", - default=0.95, + default=0.99, required=False, type=int, )