Skip to content

Commit

Permalink
Dev/bryon spectrogram update (#102)
Browse files Browse the repository at this point in the history
* wip

* requirements fix

* update spectro

* compressed spectrogram model

* linting

* change spectrogram to utilzie URLs instead of base64

* update viewer defaults

* setup update, pip upgrade, merged testing

---------

Co-authored-by: Jason Parham <[email protected]>
  • Loading branch information
BryonLewis and bluemellophone authored Apr 19, 2024
1 parent 0f8f479 commit 7972422
Show file tree
Hide file tree
Showing 21 changed files with 530 additions and 352 deletions.
2 changes: 2 additions & 0 deletions bats_ai/core/admin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .annotations import AnnotationsAdmin
from .compressed_spectrogram import CompressedSpectrogramAdmin
from .grts_cells import GRTSCellsAdmin
from .image import ImageAdmin
from .recording import RecordingAdmin
Expand All @@ -14,4 +15,5 @@
'TemporalAnnotationsAdmin',
'SpeciesAdmin',
'GRTSCellsAdmin',
'CompressedSpectrogramAdmin',
]
30 changes: 30 additions & 0 deletions bats_ai/core/admin/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from django.contrib import admin

from bats_ai.core.models import CompressedSpectrogram


@admin.register(CompressedSpectrogram)
class CompressedSpectrogramAdmin(admin.ModelAdmin):
list_display = [
'pk',
'recording',
'spectrogram',
'image_file',
'length',
'widths',
'starts',
'stops',
]
list_display_links = ['pk', 'recording', 'spectrogram']
list_select_related = True
autocomplete_fields = ['recording']
readonly_fields = [
'recording',
'spectrogram',
'image_file',
'created',
'modified',
'widths',
'starts',
'stops',
]
17 changes: 0 additions & 17 deletions bats_ai/core/migrations/0009_annotations_type.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Generated by Django 4.1.13 on 2024-04-03 13:07
# Generated by Django 4.1.13 on 2024-04-11 13:06

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
('core', '0009_annotations_type'),
('core', '0008_grtscells_recording_recorded_time'),
]

operations = [
migrations.AddField(
model_name='annotations',
name='type',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='recording',
name='computed_species',
Expand Down Expand Up @@ -48,4 +53,9 @@ class Migration(migrations.Migration):
name='unusual_occurrences',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='spectrogram',
name='colormap',
field=models.CharField(max_length=20, null=True),
),
]
84 changes: 84 additions & 0 deletions bats_ai/core/migrations/0010_compressedspectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Generated by Django 4.1.13 on 2024-04-19 13:55

import django.contrib.postgres.fields
from django.db import migrations, models
import django.db.models.deletion
import django_extensions.db.fields


class Migration(migrations.Migration):
dependencies = [
('core', '0009_annotations_type_recording_computed_species_and_more'),
]

operations = [
migrations.CreateModel(
name='CompressedSpectrogram',
fields=[
(
'id',
models.BigAutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name='ID'
),
),
(
'created',
django_extensions.db.fields.CreationDateTimeField(
auto_now_add=True, verbose_name='created'
),
),
(
'modified',
django_extensions.db.fields.ModificationDateTimeField(
auto_now=True, verbose_name='modified'
),
),
('image_file', models.FileField(upload_to='')),
('length', models.IntegerField()),
(
'starts',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
(
'stops',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
(
'widths',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
('cache_invalidated', models.BooleanField(default=True)),
(
'recording',
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='core.recording'
),
),
(
'spectrogram',
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='core.spectrogram'
),
),
],
options={
'get_latest_by': 'modified',
'abstract': False,
},
),
]
17 changes: 0 additions & 17 deletions bats_ai/core/migrations/0011_spectrogram_colormap.py

This file was deleted.

2 changes: 2 additions & 0 deletions bats_ai/core/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .annotations import Annotations
from .compressed_spectrogram import CompressedSpectrogram
from .grts_cells import GRTSCells
from .image import Image
from .recording import Recording, colormap
Expand All @@ -17,4 +18,5 @@
'TemporalAnnotations',
'GRTSCells',
'colormap',
'CompressedSpectrogram',
]
112 changes: 112 additions & 0 deletions bats_ai/core/models/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from PIL import Image
import cv2
from django.contrib.postgres.fields import ArrayField
from django.core.files.storage import default_storage
from django.db import models
from django.dispatch import receiver
from django_extensions.db.models import TimeStampedModel
import numpy as np

from .recording import Recording
from .spectrogram import Spectrogram

FREQ_MIN = 5e3
FREQ_MAX = 120e3
FREQ_PAD = 2e3


# TimeStampedModel also provides "created" and "modified" fields
class CompressedSpectrogram(TimeStampedModel, models.Model):
recording = models.ForeignKey(Recording, on_delete=models.CASCADE)
spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE)
image_file = models.FileField()
length = models.IntegerField()
starts = ArrayField(ArrayField(models.IntegerField()))
stops = ArrayField(ArrayField(models.IntegerField()))
widths = ArrayField(ArrayField(models.IntegerField()))
cache_invalidated = models.BooleanField(default=True)

@property
def image_url(self):
return default_storage.url(self.image_file.name)

def predict(self):
import json
import os

import onnx
import onnxruntime as ort
import tqdm

img = Image.open(self.image_file)

relative = ('..',) * 4
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets'))

onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx')
assert os.path.exists(onnx_filename)

session = ort.InferenceSession(
onnx_filename,
providers=[
(
'CUDAExecutionProvider',
{
'cudnn_conv_use_max_workspace': '1',
'device_id': 0,
'cudnn_conv_algo_search': 'HEURISTIC',
},
),
'CPUExecutionProvider',
],
)

img = np.array(img)

h, w, c = img.shape
ratio_y = 224 / h
ratio_x = ratio_y * 0.5
raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4)

h, w, c = raw.shape
if w <= h:
canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype)
canvas[:, :w, :] = raw
raw = canvas
h, w, c = raw.shape

inputs_ = []
for index in range(0, w - h, 100):
inputs_.append(raw[:, index : index + h, :])
inputs_.append(raw[:, -h:, :])
inputs_ = np.array(inputs_)

chunksize = 1
chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize))
outputs = []
for chunk in tqdm.tqdm(chunks, desc='Inference'):
outputs_ = session.run(
None,
{'input': chunk},
)
outputs.append(outputs_[0])
outputs = np.vstack(outputs)
outputs = outputs.mean(axis=0)

model = onnx.load(onnx_filename)
mapping = json.loads(model.metadata_props[0].value)
labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))]

prediction = np.argmax(outputs)
label = labels[prediction]
score = outputs[prediction]

confs = dict(zip(labels, outputs))

return label, score, confs


@receiver(models.signals.pre_delete, sender=Spectrogram)
def delete_content(sender, instance, **kwargs):
if instance.image_file:
instance.image_file.delete(save=False)
15 changes: 8 additions & 7 deletions bats_ai/core/models/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.contrib.auth.models import User
from django.contrib.gis.db import models
from django.dispatch import receiver
from django_extensions.db.models import TimeStampedModel

from .species import Species
Expand Down Expand Up @@ -67,17 +68,17 @@ def spectrograms(self):

@property
def spectrogram(self):
from bats_ai.core.models import Spectrogram
pass

spectrograms = self.spectrograms

if len(spectrograms) == 0:
Spectrogram.generate(self, colormap=COLORMAP)

spectrograms = self.spectrograms
assert len(spectrograms) == 1

assert len(spectrograms) >= 1
spectrogram = spectrograms[0] # most recently created

return spectrogram


@receiver(models.signals.pre_delete, sender=Recording)
def delete_content(sender, instance, **kwargs):
if instance.audio_file:
instance.audio_file.delete(save=False)
Loading

0 comments on commit 7972422

Please sign in to comment.