Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/bryon spectrogram update #102

Merged
merged 12 commits into from
Apr 19, 2024
2 changes: 2 additions & 0 deletions bats_ai/core/admin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .annotations import AnnotationsAdmin
from .compressed_spectrogram import CompressedSpectrogramAdmin
from .grts_cells import GRTSCellsAdmin
from .image import ImageAdmin
from .recording import RecordingAdmin
Expand All @@ -14,4 +15,5 @@
'TemporalAnnotationsAdmin',
'SpeciesAdmin',
'GRTSCellsAdmin',
'CompressedSpectrogramAdmin',
]
30 changes: 30 additions & 0 deletions bats_ai/core/admin/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from django.contrib import admin

from bats_ai.core.models import CompressedSpectrogram


@admin.register(CompressedSpectrogram)
class CompressedSpectrogramAdmin(admin.ModelAdmin):
list_display = [
'pk',
'recording',
'spectrogram',
'image_file',
'length',
'widths',
'starts',
'stops',
]
list_display_links = ['pk', 'recording', 'spectrogram']
list_select_related = True
autocomplete_fields = ['recording']
readonly_fields = [
'recording',
'spectrogram',
'image_file',
'created',
'modified',
'widths',
'starts',
'stops',
]
17 changes: 0 additions & 17 deletions bats_ai/core/migrations/0009_annotations_type.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Generated by Django 4.1.13 on 2024-04-03 13:07
# Generated by Django 4.1.13 on 2024-04-11 13:06

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
('core', '0009_annotations_type'),
('core', '0008_grtscells_recording_recorded_time'),
]

operations = [
migrations.AddField(
model_name='annotations',
name='type',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='recording',
name='computed_species',
Expand Down Expand Up @@ -48,4 +53,9 @@ class Migration(migrations.Migration):
name='unusual_occurrences',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='spectrogram',
name='colormap',
field=models.CharField(max_length=20, null=True),
),
]
84 changes: 84 additions & 0 deletions bats_ai/core/migrations/0010_compressedspectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Generated by Django 4.1.13 on 2024-04-19 13:55

import django.contrib.postgres.fields
from django.db import migrations, models
import django.db.models.deletion
import django_extensions.db.fields


class Migration(migrations.Migration):
dependencies = [
('core', '0009_annotations_type_recording_computed_species_and_more'),
]

operations = [
migrations.CreateModel(
name='CompressedSpectrogram',
fields=[
(
'id',
models.BigAutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name='ID'
),
),
(
'created',
django_extensions.db.fields.CreationDateTimeField(
auto_now_add=True, verbose_name='created'
),
),
(
'modified',
django_extensions.db.fields.ModificationDateTimeField(
auto_now=True, verbose_name='modified'
),
),
('image_file', models.FileField(upload_to='')),
('length', models.IntegerField()),
(
'starts',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
(
'stops',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
(
'widths',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
('cache_invalidated', models.BooleanField(default=True)),
(
'recording',
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='core.recording'
),
),
(
'spectrogram',
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='core.spectrogram'
),
),
],
options={
'get_latest_by': 'modified',
'abstract': False,
},
),
]
17 changes: 0 additions & 17 deletions bats_ai/core/migrations/0011_spectrogram_colormap.py

This file was deleted.

2 changes: 2 additions & 0 deletions bats_ai/core/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .annotations import Annotations
from .compressed_spectrogram import CompressedSpectrogram
from .grts_cells import GRTSCells
from .image import Image
from .recording import Recording, colormap
Expand All @@ -17,4 +18,5 @@
'TemporalAnnotations',
'GRTSCells',
'colormap',
'CompressedSpectrogram',
]
112 changes: 112 additions & 0 deletions bats_ai/core/models/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from PIL import Image
import cv2
from django.contrib.postgres.fields import ArrayField
from django.core.files.storage import default_storage
from django.db import models
from django.dispatch import receiver
from django_extensions.db.models import TimeStampedModel
import numpy as np

from .recording import Recording
from .spectrogram import Spectrogram

FREQ_MIN = 5e3
FREQ_MAX = 120e3
FREQ_PAD = 2e3


# TimeStampedModel also provides "created" and "modified" fields
class CompressedSpectrogram(TimeStampedModel, models.Model):
recording = models.ForeignKey(Recording, on_delete=models.CASCADE)
spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE)
image_file = models.FileField()
length = models.IntegerField()
starts = ArrayField(ArrayField(models.IntegerField()))
stops = ArrayField(ArrayField(models.IntegerField()))
widths = ArrayField(ArrayField(models.IntegerField()))
cache_invalidated = models.BooleanField(default=True)

@property
def image_url(self):
return default_storage.url(self.image_file.name)

def predict(self):
import json
import os

import onnx
import onnxruntime as ort
import tqdm

img = Image.open(self.image_file)

relative = ('..',) * 4
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets'))

onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx')
assert os.path.exists(onnx_filename)

session = ort.InferenceSession(
onnx_filename,
providers=[
(
'CUDAExecutionProvider',
{
'cudnn_conv_use_max_workspace': '1',
'device_id': 0,
'cudnn_conv_algo_search': 'HEURISTIC',
},
),
'CPUExecutionProvider',
],
)

img = np.array(img)

h, w, c = img.shape
ratio_y = 224 / h
ratio_x = ratio_y * 0.5
raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4)

h, w, c = raw.shape
if w <= h:
canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype)
canvas[:, :w, :] = raw
raw = canvas
h, w, c = raw.shape

inputs_ = []
for index in range(0, w - h, 100):
inputs_.append(raw[:, index : index + h, :])
inputs_.append(raw[:, -h:, :])
inputs_ = np.array(inputs_)

chunksize = 1
chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize))
outputs = []
for chunk in tqdm.tqdm(chunks, desc='Inference'):
outputs_ = session.run(
None,
{'input': chunk},
)
outputs.append(outputs_[0])
outputs = np.vstack(outputs)
outputs = outputs.mean(axis=0)

model = onnx.load(onnx_filename)
mapping = json.loads(model.metadata_props[0].value)
labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))]

prediction = np.argmax(outputs)
label = labels[prediction]
score = outputs[prediction]

confs = dict(zip(labels, outputs))

return label, score, confs


@receiver(models.signals.pre_delete, sender=Spectrogram)
def delete_content(sender, instance, **kwargs):
if instance.image_file:
instance.image_file.delete(save=False)
15 changes: 8 additions & 7 deletions bats_ai/core/models/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.contrib.auth.models import User
from django.contrib.gis.db import models
from django.dispatch import receiver
from django_extensions.db.models import TimeStampedModel

from .species import Species
Expand Down Expand Up @@ -67,17 +68,17 @@ def spectrograms(self):

@property
def spectrogram(self):
from bats_ai.core.models import Spectrogram
pass

spectrograms = self.spectrograms

if len(spectrograms) == 0:
Spectrogram.generate(self, colormap=COLORMAP)

spectrograms = self.spectrograms
assert len(spectrograms) == 1

assert len(spectrograms) >= 1
spectrogram = spectrograms[0] # most recently created

return spectrogram


@receiver(models.signals.pre_delete, sender=Recording)
def delete_content(sender, instance, **kwargs):
if instance.audio_file:
instance.audio_file.delete(save=False)
Loading
Loading