-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3a11c50
commit 4c98154
Showing
13 changed files
with
414 additions
and
215 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from django.contrib import admin | ||
|
||
from bats_ai.core.models import CompressedSpectrogram | ||
|
||
@admin.register(CompressedSpectrogram) | ||
class CompressedSpectrogramAdmin(admin.ModelAdmin): | ||
list_display = [ | ||
'pk', | ||
'recording', | ||
'spectrogram', | ||
'image_file', | ||
'length', | ||
'widths', | ||
'starts', | ||
'stops', | ||
] | ||
list_display_links = ['pk', 'recording', 'spectrogram'] | ||
list_select_related = True | ||
autocomplete_fields = ['recording'] | ||
readonly_fields = [ | ||
'recording', | ||
'spectrogram', | ||
'image_file', | ||
'created', | ||
'modified', | ||
'widths', | ||
'starts', | ||
'stops', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Generated by Django 4.1.13 on 2024-04-12 18:56 | ||
|
||
import django.contrib.postgres.fields | ||
from django.db import migrations, models | ||
import django.db.models.deletion | ||
import django_extensions.db.fields | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
('core', '0009_annotations_type_recording_computed_species_and_more'), | ||
] | ||
|
||
operations = [ | ||
migrations.CreateModel( | ||
name='CompressedSpectrogram', | ||
fields=[ | ||
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), | ||
('created', django_extensions.db.fields.CreationDateTimeField(auto_now_add=True, verbose_name='created')), | ||
('modified', django_extensions.db.fields.ModificationDateTimeField(auto_now=True, verbose_name='modified')), | ||
('image_file', models.FileField(upload_to='')), | ||
('length', models.IntegerField()), | ||
('starts', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), size=None), size=None)), | ||
('stops', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), size=None), size=None)), | ||
('widths', django.contrib.postgres.fields.ArrayField(base_field=django.contrib.postgres.fields.ArrayField(base_field=models.IntegerField(), size=None), size=None)), | ||
('cache_invalidated', models.BooleanField(default=True)), | ||
('recording', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.recording')), | ||
('spectrogram', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.spectrogram')), | ||
], | ||
options={ | ||
'get_latest_by': 'modified', | ||
'abstract': False, | ||
}, | ||
), | ||
] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import base64 | ||
import io | ||
import math | ||
|
||
from PIL import Image | ||
import cv2 | ||
from django.core.files import File | ||
from django.db import models | ||
from django.db.models.fields.files import FieldFile | ||
from django.contrib.postgres.fields import ArrayField | ||
from django_extensions.db.models import TimeStampedModel | ||
import librosa | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import scipy | ||
import tqdm | ||
|
||
from bats_ai.core.models import Annotations, Recording, Spectrogram | ||
from django.dispatch import receiver | ||
|
||
FREQ_MIN = 5e3 | ||
FREQ_MAX = 120e3 | ||
FREQ_PAD = 2e3 | ||
|
||
|
||
# TimeStampedModel also provides "created" and "modified" fields | ||
class CompressedSpectrogram(TimeStampedModel, models.Model): | ||
recording = models.ForeignKey(Recording, on_delete=models.CASCADE) | ||
spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE) | ||
image_file = models.FileField() | ||
length = models.IntegerField() # pixels | ||
starts = ArrayField(ArrayField(models.IntegerField())) # pixels | ||
stops = ArrayField(ArrayField(models.IntegerField())) # milliseconds | ||
widths = ArrayField(ArrayField(models.IntegerField())) # hz | ||
cache_invalidated = models.BooleanField(default=True) | ||
|
||
|
||
def predict(self): | ||
import json | ||
import os | ||
|
||
import onnx | ||
import onnxruntime as ort | ||
import tqdm | ||
|
||
img = Image.open(self.image_file) | ||
|
||
relative = ('..',) * 4 | ||
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets')) | ||
|
||
onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx') | ||
assert os.path.exists(onnx_filename) | ||
|
||
session = ort.InferenceSession( | ||
onnx_filename, | ||
providers=[ | ||
( | ||
'CUDAExecutionProvider', | ||
{ | ||
'cudnn_conv_use_max_workspace': '1', | ||
'device_id': 0, | ||
'cudnn_conv_algo_search': 'HEURISTIC', | ||
}, | ||
), | ||
'CPUExecutionProvider', | ||
], | ||
) | ||
|
||
img = np.array(img) | ||
|
||
h, w, c = img.shape | ||
ratio_y = 224 / h | ||
ratio_x = ratio_y * 0.5 | ||
raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4) | ||
|
||
h, w, c = raw.shape | ||
if w <= h: | ||
canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype) | ||
canvas[:, :w, :] = raw | ||
raw = canvas | ||
h, w, c = raw.shape | ||
|
||
inputs_ = [] | ||
for index in range(0, w - h, 100): | ||
inputs_.append(raw[:, index : index + h, :]) | ||
inputs_.append(raw[:, -h:, :]) | ||
inputs_ = np.array(inputs_) | ||
|
||
chunksize = 1 | ||
chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize)) | ||
outputs = [] | ||
for chunk in tqdm.tqdm(chunks, desc='Inference'): | ||
outputs_ = session.run( | ||
None, | ||
{'input': chunk}, | ||
) | ||
outputs.append(outputs_[0]) | ||
outputs = np.vstack(outputs) | ||
outputs = outputs.mean(axis=0) | ||
|
||
model = onnx.load(onnx_filename) | ||
mapping = json.loads(model.metadata_props[0].value) | ||
labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))] | ||
|
||
prediction = np.argmax(outputs) | ||
label = labels[prediction] | ||
score = outputs[prediction] | ||
|
||
confs = dict(zip(labels, outputs)) | ||
|
||
return label, score, confs | ||
|
||
@receiver(models.signals.pre_delete, sender=Spectrogram) | ||
def delete_content(sender, instance, **kwargs): | ||
if instance.image_file: | ||
instance.image_file.delete(save=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.