Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/spectrogram updates #93

Merged
merged 5 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
assets/example.wav filter=lfs diff=lfs merge=lfs -text
assets/model.mobilenet.onnx filter=lfs diff=lfs merge=lfs -text
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,10 @@
/**/*.shp
/**/*.shx
/**/*.csv
models/datasets/
models/spectrograms/
models/ignore/
models/*.jpg
models/*.pkl
temp*.jpg
temp*.png
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ This is the simplest configuration for developers to start with.

### Run Vue Frontend

1. Run `npm install`
2. Run `npm run dev`
1. Run `cd client/`
2. Run `npm install`
3. Run `npm run dev`

### Run Application

Expand Down
3 changes: 3 additions & 0 deletions assets/model.mobilenet.onnx
Git LFS file not shown
2 changes: 2 additions & 0 deletions bats_ai/core/admin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .annotations import AnnotationsAdmin
from .compressed_spectrogram import CompressedSpectrogramAdmin
from .grts_cells import GRTSCellsAdmin
from .image import ImageAdmin
from .recording import RecordingAdmin
Expand All @@ -14,4 +15,5 @@
'TemporalAnnotationsAdmin',
'SpeciesAdmin',
'GRTSCellsAdmin',
'CompressedSpectrogramAdmin',
]
30 changes: 30 additions & 0 deletions bats_ai/core/admin/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from django.contrib import admin

from bats_ai.core.models import CompressedSpectrogram


@admin.register(CompressedSpectrogram)
class CompressedSpectrogramAdmin(admin.ModelAdmin):
list_display = [
'pk',
'recording',
'spectrogram',
'image_file',
'length',
'widths',
'starts',
'stops',
]
list_display_links = ['pk', 'recording', 'spectrogram']
list_select_related = True
autocomplete_fields = ['recording']
readonly_fields = [
'recording',
'spectrogram',
'image_file',
'created',
'modified',
'widths',
'starts',
'stops',
]
5 changes: 2 additions & 3 deletions bats_ai/core/admin/recording.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ def spectrogram_status(self, recording: Recording):
def compute_spectrograms(self, request: HttpRequest, queryset: QuerySet):
counter = 0
for recording in queryset:
if not recording.has_spectrogram:
recording_compute_spectrogram.delay(recording.pk)
counter += 1
recording_compute_spectrogram.delay(recording.pk)
counter += 1
self.message_user(request, f'{counter} recordings queued', messages.SUCCESS)
2 changes: 2 additions & 0 deletions bats_ai/core/admin/spectrogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class SpectrogramAdmin(admin.ModelAdmin):
list_display = [
'pk',
'recording',
'colormap',
'image_file',
'width',
'height',
Expand All @@ -20,6 +21,7 @@ class SpectrogramAdmin(admin.ModelAdmin):
autocomplete_fields = ['recording']
readonly_fields = [
'recording',
'colormap',
'image_file',
'created',
'modified',
Expand Down
17 changes: 0 additions & 17 deletions bats_ai/core/migrations/0009_annotations_type.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
# Generated by Django 4.1.13 on 2024-04-03 13:07
# Generated by Django 4.1.13 on 2024-04-11 13:06

from django.db import migrations, models


class Migration(migrations.Migration):
dependencies = [
('core', '0009_annotations_type'),
('core', '0008_grtscells_recording_recorded_time'),
]

operations = [
migrations.AddField(
model_name='annotations',
name='type',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='recording',
name='computed_species',
Expand Down Expand Up @@ -48,4 +53,9 @@ class Migration(migrations.Migration):
name='unusual_occurrences',
field=models.TextField(blank=True, null=True),
),
migrations.AddField(
model_name='spectrogram',
name='colormap',
field=models.CharField(max_length=20, null=True),
),
]
84 changes: 84 additions & 0 deletions bats_ai/core/migrations/0010_compressedspectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Generated by Django 4.1.13 on 2024-04-19 13:55

import django.contrib.postgres.fields
from django.db import migrations, models
import django.db.models.deletion
import django_extensions.db.fields


class Migration(migrations.Migration):
dependencies = [
('core', '0009_annotations_type_recording_computed_species_and_more'),
]

operations = [
migrations.CreateModel(
name='CompressedSpectrogram',
fields=[
(
'id',
models.BigAutoField(
auto_created=True, primary_key=True, serialize=False, verbose_name='ID'
),
),
(
'created',
django_extensions.db.fields.CreationDateTimeField(
auto_now_add=True, verbose_name='created'
),
),
(
'modified',
django_extensions.db.fields.ModificationDateTimeField(
auto_now=True, verbose_name='modified'
),
),
('image_file', models.FileField(upload_to='')),
('length', models.IntegerField()),
(
'starts',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
(
'stops',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
(
'widths',
django.contrib.postgres.fields.ArrayField(
base_field=django.contrib.postgres.fields.ArrayField(
base_field=models.IntegerField(), size=None
),
size=None,
),
),
('cache_invalidated', models.BooleanField(default=True)),
(
'recording',
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='core.recording'
),
),
(
'spectrogram',
models.ForeignKey(
on_delete=django.db.models.deletion.CASCADE, to='core.spectrogram'
),
),
],
options={
'get_latest_by': 'modified',
'abstract': False,
},
),
]
5 changes: 4 additions & 1 deletion bats_ai/core/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .annotations import Annotations
from .compressed_spectrogram import CompressedSpectrogram
from .grts_cells import GRTSCells
from .image import Image
from .recording import Recording
from .recording import Recording, colormap
from .recording_annotation_status import RecordingAnnotationStatus
from .species import Species
from .spectrogram import Spectrogram
Expand All @@ -16,4 +17,6 @@
'Spectrogram',
'TemporalAnnotations',
'GRTSCells',
'colormap',
'CompressedSpectrogram',
]
112 changes: 112 additions & 0 deletions bats_ai/core/models/compressed_spectrogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from PIL import Image
import cv2
from django.contrib.postgres.fields import ArrayField
from django.core.files.storage import default_storage
from django.db import models
from django.dispatch import receiver
from django_extensions.db.models import TimeStampedModel
import numpy as np

from .recording import Recording
from .spectrogram import Spectrogram

FREQ_MIN = 5e3
FREQ_MAX = 120e3
FREQ_PAD = 2e3


# TimeStampedModel also provides "created" and "modified" fields
class CompressedSpectrogram(TimeStampedModel, models.Model):
recording = models.ForeignKey(Recording, on_delete=models.CASCADE)
spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE)
image_file = models.FileField()
length = models.IntegerField()
starts = ArrayField(ArrayField(models.IntegerField()))
stops = ArrayField(ArrayField(models.IntegerField()))
widths = ArrayField(ArrayField(models.IntegerField()))
cache_invalidated = models.BooleanField(default=True)

@property
def image_url(self):
return default_storage.url(self.image_file.name)

def predict(self):
import json
import os

import onnx
import onnxruntime as ort
import tqdm

img = Image.open(self.image_file)

relative = ('..',) * 4
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets'))

onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx')
assert os.path.exists(onnx_filename)

session = ort.InferenceSession(
onnx_filename,
providers=[
(
'CUDAExecutionProvider',
{
'cudnn_conv_use_max_workspace': '1',
'device_id': 0,
'cudnn_conv_algo_search': 'HEURISTIC',
},
),
'CPUExecutionProvider',
],
)

img = np.array(img)

h, w, c = img.shape
ratio_y = 224 / h
ratio_x = ratio_y * 0.5
raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4)

h, w, c = raw.shape
if w <= h:
canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype)
canvas[:, :w, :] = raw
raw = canvas
h, w, c = raw.shape

inputs_ = []
for index in range(0, w - h, 100):
inputs_.append(raw[:, index : index + h, :])
inputs_.append(raw[:, -h:, :])
inputs_ = np.array(inputs_)

chunksize = 1
chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize))
outputs = []
for chunk in tqdm.tqdm(chunks, desc='Inference'):
outputs_ = session.run(
None,
{'input': chunk},
)
outputs.append(outputs_[0])
outputs = np.vstack(outputs)
outputs = outputs.mean(axis=0)

model = onnx.load(onnx_filename)
mapping = json.loads(model.metadata_props[0].value)
labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))]

prediction = np.argmax(outputs)
label = labels[prediction]
score = outputs[prediction]

confs = dict(zip(labels, outputs))

return label, score, confs


@receiver(models.signals.pre_delete, sender=Spectrogram)
def delete_content(sender, instance, **kwargs):
if instance.image_file:
instance.image_file.delete(save=False)
Loading
Loading