From 96870511463d644868deede6f47b0f2f65788eab Mon Sep 17 00:00:00 2001 From: Bryon Lewis Date: Tue, 17 Dec 2024 13:43:28 -0500 Subject: [PATCH] initial prediction on compressed spectrogram generation --- bats_ai/core/admin/__init__.py | 2 ++ bats_ai/core/admin/recording_annotations.py | 28 +++++++++++++++ ...012_recordingannotation_additional_data.py | 19 ++++++++++ bats_ai/core/models/recording_annotation.py | 3 ++ bats_ai/core/tasks.py | 35 +++++++++++++++++-- bats_ai/core/views/recording.py | 23 ++++++++++++ 6 files changed, 107 insertions(+), 3 deletions(-) create mode 100644 bats_ai/core/admin/recording_annotations.py create mode 100644 bats_ai/core/migrations/0012_recordingannotation_additional_data.py diff --git a/bats_ai/core/admin/__init__.py b/bats_ai/core/admin/__init__.py index 8f782b8..52f0de7 100644 --- a/bats_ai/core/admin/__init__.py +++ b/bats_ai/core/admin/__init__.py @@ -3,6 +3,7 @@ from .grts_cells import GRTSCellsAdmin from .image import ImageAdmin from .recording import RecordingAdmin +from .recording_annotations import RecordingAnnotationAdmin from .species import SpeciesAdmin from .spectrogram import SpectrogramAdmin from .temporal_annotations import TemporalAnnotationsAdmin @@ -16,4 +17,5 @@ 'SpeciesAdmin', 'GRTSCellsAdmin', 'CompressedSpectrogramAdmin', + 'RecordingAnnotationAdmin', ] diff --git a/bats_ai/core/admin/recording_annotations.py b/bats_ai/core/admin/recording_annotations.py new file mode 100644 index 0000000..d83c31c --- /dev/null +++ b/bats_ai/core/admin/recording_annotations.py @@ -0,0 +1,28 @@ +from django.contrib import admin + +from bats_ai.core.models import RecordingAnnotation + + +@admin.register(RecordingAnnotation) +class RecordingAnnotationAdmin(admin.ModelAdmin): + list_display = [ + 'pk', + 'recording', + 'owner', + 'species_codes', # Add the custom field here + 'confidence', + 'additional_data', + 'comments', + 'model', + ] + list_select_related = True + filter_horizontal = ('species',) # or filter_vertical + autocomplete_fields = ['owner'] + + # Custom method to display the species codes as a comma-separated string + @admin.display(description='Species Codes') + def species_codes(self, obj): + # Assuming species have a `species_code` field + return ', '.join([species.species_code for species in obj.species.all()]) + + # Optionally, you can also add a verbose name for this field diff --git a/bats_ai/core/migrations/0012_recordingannotation_additional_data.py b/bats_ai/core/migrations/0012_recordingannotation_additional_data.py new file mode 100644 index 0000000..e2620e4 --- /dev/null +++ b/bats_ai/core/migrations/0012_recordingannotation_additional_data.py @@ -0,0 +1,19 @@ +# Generated by Django 4.1.13 on 2024-12-17 17:57 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0011_alter_annotations_options_annotations_confidence_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='recordingannotation', + name='additional_data', + field=models.JSONField( + blank=True, help_text='Additional information about the models/data', null=True + ), + ), + ] diff --git a/bats_ai/core/models/recording_annotation.py b/bats_ai/core/models/recording_annotation.py index 3429560..ab5d9b5 100644 --- a/bats_ai/core/models/recording_annotation.py +++ b/bats_ai/core/models/recording_annotation.py @@ -21,3 +21,6 @@ class RecordingAnnotation(TimeStampedModel, models.Model): ], help_text='A confidence value between 0 and 1.0, default is 1.0.', ) + additional_data = models.JSONField( + blank=True, null=True, help_text='Additional information about the models/data' + ) diff --git a/bats_ai/core/tasks.py b/bats_ai/core/tasks.py index 18604e2..65fd580 100644 --- a/bats_ai/core/tasks.py +++ b/bats_ai/core/tasks.py @@ -8,7 +8,15 @@ import numpy as np import scipy -from bats_ai.core.models import Annotations, CompressedSpectrogram, Recording, Spectrogram, colormap +from bats_ai.core.models import ( + Annotations, + CompressedSpectrogram, + Recording, + RecordingAnnotation, + Species, + Spectrogram, + colormap, +) def generate_compressed(recording: Recording, spectrogram: Spectrogram): @@ -178,7 +186,8 @@ def recording_compute_spectrogram(recording_id: int): if cmap is None: spectrogram_id = spectrogram_id_temp if spectrogram_id is not None: - generate_compress_spectrogram.delay(recording_id, spectrogram_id) + compressed_spectro = generate_compress_spectrogram(recording_id, spectrogram_id) + predict(compressed_spectro.pk) @shared_task @@ -197,7 +206,7 @@ def generate_compress_spectrogram(recording_id: int, spectrogram_id: int): existing.cache_invalidated = False existing.save() else: - CompressedSpectrogram.objects.create( + existing = CompressedSpectrogram.objects.create( recording=recording, spectrogram=spectrogram, image_file=image_file, @@ -207,10 +216,30 @@ def generate_compress_spectrogram(recording_id: int, spectrogram_id: int): stops=stops, cache_invalidated=False, ) + return existing @shared_task def predict(compressed_spectrogram_id: int): compressed_spectrogram = CompressedSpectrogram.objects.get(pk=compressed_spectrogram_id) label, score, confs = compressed_spectrogram.predict() + confidences = [{'label': key, 'value': float(value)} for key, value in confs.items()] + sorted_confidences = sorted(confidences, key=lambda x: x['value'], reverse=True) + output = { + 'label': label, + 'score': float(score), + 'confidences': sorted_confidences, + } + species = Species.objects.filter(species_code=label) + + recording_annotation = RecordingAnnotation.objects.create( + recording=compressed_spectrogram.recording, + owner=compressed_spectrogram.recording.owner, + comments='Compressed Spectrogram Generation Prediction', + model='model.mobilenet.onnx', + confidence=output['score'], + additional_data=output, + ) + recording_annotation.species.set(species) + recording_annotation.save() return label, score, confs diff --git a/bats_ai/core/views/recording.py b/bats_ai/core/views/recording.py index c935880..3f0adb7 100644 --- a/bats_ai/core/views/recording.py +++ b/bats_ai/core/views/recording.py @@ -837,3 +837,26 @@ def delete_temporal_annotation(request, recording_id: int, id: int): return {'error': 'Recording not found'} except Annotations.DoesNotExist: return {'error': 'Annotation not found'} + + +# TODO - this may be modified to use different models in the +@router.post('/{id}/spectrogram/compressed/predict') +def precit_spectrogram_compressed(request: HttpRequest, id: int): + try: + recording = Recording.objects.get(pk=id) + compressed_spectrogram = CompressedSpectrogram.objects.filter(recording=id).first() + except compressed_spectrogram.DoesNotExist: + return {'error': 'Compressed Spectrogram'} + except recording.DoesNotExist: + return {'error': 'Recording does not exist'} + + label, score, confs = compressed_spectrogram.predict() + confidences = [] + confidences = [{'label': key, 'value': float(value)} for key, value in confs.items()] + sorted_confidences = sorted(confidences, key=lambda x: x['value'], reverse=True) + output = { + 'label': label, + 'score': float(score), + 'confidences': sorted_confidences, + } + return output