Skip to content

Commit

Permalink
compressed spectrogram model
Browse files Browse the repository at this point in the history
  • Loading branch information
BryonLewis committed Apr 12, 2024
1 parent 3a11c50 commit 4c98154
Show file tree
Hide file tree
Showing 13 changed files with 414 additions and 215 deletions.
2 changes: 2 additions & 0 deletions bats_ai/core/admin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .species import SpeciesAdmin
from .spectrogram import SpectrogramAdmin
from .temporal_annotations import TemporalAnnotationsAdmin
from .compressed_spectrogram import CompressedSpectrogramAdmin

__all__ = [
'AnnotationsAdmin',
Expand All @@ -14,4 +15,5 @@
'TemporalAnnotationsAdmin',
'SpeciesAdmin',
'GRTSCellsAdmin',
'CompressedSpectrogramAdmin',
]
29 changes: 29 additions & 0 deletions bats_ai/core/admin/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from django.contrib import admin

from bats_ai.core.models import CompressedSpectrogram

@admin.register(CompressedSpectrogram)
class CompressedSpectrogramAdmin(admin.ModelAdmin):
list_display = [
'pk',
'recording',
'spectrogram',
'image_file',
'length',
'widths',
'starts',
'stops',
]
list_display_links = ['pk', 'recording', 'spectrogram']
list_select_related = True
autocomplete_fields = ['recording']
readonly_fields = [
'recording',
'spectrogram',
'image_file',
'created',
'modified',
'widths',
'starts',
'stops',
]
14 changes: 14 additions & 0 deletions bats_ai/core/admin/spectrogram.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from django.contrib import admin

from bats_ai.core.models import Spectrogram
from django.db.models import QuerySet
from django.http import HttpRequest


@admin.register(Spectrogram)
Expand Down Expand Up @@ -31,3 +33,15 @@ class SpectrogramAdmin(admin.ModelAdmin):
'frequency_min',
'frequency_max',
]

actions = ['computed_compressed_spectrogram']


@admin.action(description='Compute Compressed Spectrograms')
def computed_compressed_spectrogram(self, request: HttpRequest, queryset: QuerySet):
counter = 0
for recording in queryset:
if not recording.has_spectrogram:
recording_compute_spectrogram.delay(recording.pk)
counter += 1
self.message_user(request, f'{counter} recordings queued', messages.SUCCESS)
17 changes: 0 additions & 17 deletions bats_ai/core/migrations/0009_annotations_type.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,20 +1,24 @@
# Generated by Django 4.1.13 on 2024-04-03 13:07
# Generated by Django 4.1.13 on 2024-04-11 13:06

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('core', '0009_annotations_type'),
('core', '0008_grtscells_recording_recorded_time'),
]

operations = [
migrations.AddField(
model_name='annotations',
name='type',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='recording',
name='computed_species',
field=models.ManyToManyField(
related_name='recording_computed_species', to='core.species'
),
field=models.ManyToManyField(related_name='recording_computed_species', to='core.species'),
),
migrations.AddField(
model_name='recording',
Expand All @@ -24,9 +28,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name='recording',
name='official_species',
field=models.ManyToManyField(
related_name='recording_official_species', to='core.species'
),
field=models.ManyToManyField(related_name='recording_official_species', to='core.species'),
),
migrations.AddField(
model_name='recording',
Expand All @@ -48,4 +50,9 @@ class Migration(migrations.Migration):
name='unusual_occurrences',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='spectrogram',
name='colormap',
field=models.CharField(max_length=20, null=True),
),
]
17 changes: 0 additions & 17 deletions bats_ai/core/migrations/0009_spectrogram_colormap.py

This file was deleted.

36 changes: 36 additions & 0 deletions bats_ai/core/migrations/0010_compressedspectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Generated by Django 4.1.13 on 2024-04-12 18:56

import django.contrib.postgres.fields
from django.db import migrations, models
import django.db.models.deletion
import django_extensions.db.fields


class Migration(migrations.Migration):

dependencies = [
('core', '0009_annotations_type_recording_computed_species_and_more'),
]

operations = [
migrations.CreateModel(
name='CompressedSpectrogram',
fields=[
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('created', django_extensions.db.fields.CreationDateTimeField(auto_now_add=True, verbose_name='created')),
('modified', django_extensions.db.fields.ModificationDateTimeField(auto_now=True, verbose_name='modified')),
('image_file', models.FileField(upload_to='')),
('length', models.IntegerField()),
('starts', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), size=None), size=None)),
('stops', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), size=None), size=None)),
('widths', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), size=None), size=None)),
('cache_invalidated', models.BooleanField(default=True)),
('recording', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.recording')),
('spectrogram', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.spectrogram')),
],
options={
'get_latest_by': 'modified',
'abstract': False,
},
),
]
18 changes: 0 additions & 18 deletions bats_ai/core/migrations/0011_spectrogram_colormap.py

This file was deleted.

2 changes: 2 additions & 0 deletions bats_ai/core/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .species import Species
from .spectrogram import Spectrogram
from .temporal_annotations import TemporalAnnotations
from .compressed_spectrogram import CompressedSpectrogram

__all__ = [
'Annotations',
Expand All @@ -17,4 +18,5 @@
'TemporalAnnotations',
'GRTSCells',
'colormap',
'CompressedSpectrogram',
]
116 changes: 116 additions & 0 deletions bats_ai/core/models/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import base64
import io
import math

from PIL import Image
import cv2
from django.core.files import File
from django.db import models
from django.db.models.fields.files import FieldFile
from django.contrib.postgres.fields import ArrayField
from django_extensions.db.models import TimeStampedModel
import librosa
import matplotlib.pyplot as plt
import numpy as np
import scipy
import tqdm

from bats_ai.core.models import Annotations, Recording, Spectrogram
from django.dispatch import receiver

FREQ_MIN = 5e3
FREQ_MAX = 120e3
FREQ_PAD = 2e3


# TimeStampedModel also provides "created" and "modified" fields
class CompressedSpectrogram(TimeStampedModel, models.Model):
recording = models.ForeignKey(Recording, on_delete=models.CASCADE)
spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE)
image_file = models.FileField()
length = models.IntegerField() # pixels
starts = ArrayField(ArrayField(models.IntegerField())) # pixels
stops = ArrayField(ArrayField(models.IntegerField())) # milliseconds
widths = ArrayField(ArrayField(models.IntegerField())) # hz
cache_invalidated = models.BooleanField(default=True)


def predict(self):
import json
import os

import onnx
import onnxruntime as ort
import tqdm

img = Image.open(self.image_file)

relative = ('..',) * 4
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets'))

onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx')
assert os.path.exists(onnx_filename)

session = ort.InferenceSession(
onnx_filename,
providers=[
(
'CUDAExecutionProvider',
{
'cudnn_conv_use_max_workspace': '1',
'device_id': 0,
'cudnn_conv_algo_search': 'HEURISTIC',
},
),
'CPUExecutionProvider',
],
)

img = np.array(img)

h, w, c = img.shape
ratio_y = 224 / h
ratio_x = ratio_y * 0.5
raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4)

h, w, c = raw.shape
if w <= h:
canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype)
canvas[:, :w, :] = raw
raw = canvas
h, w, c = raw.shape

inputs_ = []
for index in range(0, w - h, 100):
inputs_.append(raw[:, index : index + h, :])
inputs_.append(raw[:, -h:, :])
inputs_ = np.array(inputs_)

chunksize = 1
chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize))
outputs = []
for chunk in tqdm.tqdm(chunks, desc='Inference'):
outputs_ = session.run(
None,
{'input': chunk},
)
outputs.append(outputs_[0])
outputs = np.vstack(outputs)
outputs = outputs.mean(axis=0)

model = onnx.load(onnx_filename)
mapping = json.loads(model.metadata_props[0].value)
labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))]

prediction = np.argmax(outputs)
label = labels[prediction]
score = outputs[prediction]

confs = dict(zip(labels, outputs))

return label, score, confs

@receiver(models.signals.pre_delete, sender=Spectrogram)
def delete_content(sender, instance, **kwargs):
if instance.image_file:
instance.image_file.delete(save=False)
12 changes: 6 additions & 6 deletions bats_ai/core/models/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.contrib.auth.models import User
from django.contrib.gis.db import models
from django_extensions.db.models import TimeStampedModel
from django.dispatch import receiver

from .species import Species

Expand Down Expand Up @@ -86,13 +87,12 @@ def spectrogram(self):

spectrograms = self.spectrograms

if len(spectrograms) == 0:
Spectrogram.generate(self, colormap=COLORMAP)

spectrograms = self.spectrograms
assert len(spectrograms) == 1

assert len(spectrograms) >= 1
spectrogram = spectrograms[0] # most recently created

return spectrogram

@receiver(models.signals.pre_delete, sender=Recording)
def delete_content(sender, instance, **kwargs):
if instance.audio_file:
instance.audio_file.delete(save=False)
Loading

0 comments on commit 4c98154

Please sign in to comment.