-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* wip * requirements fix * update spectro * compressed spectrogram model * linting * change spectrogram to utilzie URLs instead of base64 * update viewer defaults * setup update, pip upgrade, merged testing --------- Co-authored-by: Jason Parham <[email protected]>
- Loading branch information
1 parent
0f8f479
commit 7972422
Showing
21 changed files
with
530 additions
and
352 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from django.contrib import admin | ||
|
||
from bats_ai.core.models import CompressedSpectrogram | ||
|
||
|
||
@admin.register(CompressedSpectrogram) | ||
class CompressedSpectrogramAdmin(admin.ModelAdmin): | ||
list_display = [ | ||
'pk', | ||
'recording', | ||
'spectrogram', | ||
'image_file', | ||
'length', | ||
'widths', | ||
'starts', | ||
'stops', | ||
] | ||
list_display_links = ['pk', 'recording', 'spectrogram'] | ||
list_select_related = True | ||
autocomplete_fields = ['recording'] | ||
readonly_fields = [ | ||
'recording', | ||
'spectrogram', | ||
'image_file', | ||
'created', | ||
'modified', | ||
'widths', | ||
'starts', | ||
'stops', | ||
] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Generated by Django 4.1.13 on 2024-04-19 13:55 | ||
|
||
import django.contrib.postgres.fields | ||
from django.db import migrations, models | ||
import django.db.models.deletion | ||
import django_extensions.db.fields | ||
|
||
|
||
class Migration(migrations.Migration): | ||
dependencies = [ | ||
('core', '0009_annotations_type_recording_computed_species_and_more'), | ||
] | ||
|
||
operations = [ | ||
migrations.CreateModel( | ||
name='CompressedSpectrogram', | ||
fields=[ | ||
( | ||
'id', | ||
models.BigAutoField( | ||
auto_created=True, primary_key=True, serialize=False, verbose_name='ID' | ||
), | ||
), | ||
( | ||
'created', | ||
django_extensions.db.fields.CreationDateTimeField( | ||
auto_now_add=True, verbose_name='created' | ||
), | ||
), | ||
( | ||
'modified', | ||
django_extensions.db.fields.ModificationDateTimeField( | ||
auto_now=True, verbose_name='modified' | ||
), | ||
), | ||
('image_file', models.FileField(upload_to='')), | ||
('length', models.IntegerField()), | ||
( | ||
'starts', | ||
django.contrib.postgres.fields.ArrayField( | ||
base_field=django.contrib.postgres.fields.ArrayField( | ||
base_field=models.IntegerField(), size=None | ||
), | ||
size=None, | ||
), | ||
), | ||
( | ||
'stops', | ||
django.contrib.postgres.fields.ArrayField( | ||
base_field=django.contrib.postgres.fields.ArrayField( | ||
base_field=models.IntegerField(), size=None | ||
), | ||
size=None, | ||
), | ||
), | ||
( | ||
'widths', | ||
django.contrib.postgres.fields.ArrayField( | ||
base_field=django.contrib.postgres.fields.ArrayField( | ||
base_field=models.IntegerField(), size=None | ||
), | ||
size=None, | ||
), | ||
), | ||
('cache_invalidated', models.BooleanField(default=True)), | ||
( | ||
'recording', | ||
models.ForeignKey( | ||
on_delete=django.db.models.deletion.CASCADE, to='core.recording' | ||
), | ||
), | ||
( | ||
'spectrogram', | ||
models.ForeignKey( | ||
on_delete=django.db.models.deletion.CASCADE, to='core.spectrogram' | ||
), | ||
), | ||
], | ||
options={ | ||
'get_latest_by': 'modified', | ||
'abstract': False, | ||
}, | ||
), | ||
] |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from PIL import Image | ||
import cv2 | ||
from django.contrib.postgres.fields import ArrayField | ||
from django.core.files.storage import default_storage | ||
from django.db import models | ||
from django.dispatch import receiver | ||
from django_extensions.db.models import TimeStampedModel | ||
import numpy as np | ||
|
||
from .recording import Recording | ||
from .spectrogram import Spectrogram | ||
|
||
FREQ_MIN = 5e3 | ||
FREQ_MAX = 120e3 | ||
FREQ_PAD = 2e3 | ||
|
||
|
||
# TimeStampedModel also provides "created" and "modified" fields | ||
class CompressedSpectrogram(TimeStampedModel, models.Model): | ||
recording = models.ForeignKey(Recording, on_delete=models.CASCADE) | ||
spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE) | ||
image_file = models.FileField() | ||
length = models.IntegerField() | ||
starts = ArrayField(ArrayField(models.IntegerField())) | ||
stops = ArrayField(ArrayField(models.IntegerField())) | ||
widths = ArrayField(ArrayField(models.IntegerField())) | ||
cache_invalidated = models.BooleanField(default=True) | ||
|
||
@property | ||
def image_url(self): | ||
return default_storage.url(self.image_file.name) | ||
|
||
def predict(self): | ||
import json | ||
import os | ||
|
||
import onnx | ||
import onnxruntime as ort | ||
import tqdm | ||
|
||
img = Image.open(self.image_file) | ||
|
||
relative = ('..',) * 4 | ||
asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets')) | ||
|
||
onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx') | ||
assert os.path.exists(onnx_filename) | ||
|
||
session = ort.InferenceSession( | ||
onnx_filename, | ||
providers=[ | ||
( | ||
'CUDAExecutionProvider', | ||
{ | ||
'cudnn_conv_use_max_workspace': '1', | ||
'device_id': 0, | ||
'cudnn_conv_algo_search': 'HEURISTIC', | ||
}, | ||
), | ||
'CPUExecutionProvider', | ||
], | ||
) | ||
|
||
img = np.array(img) | ||
|
||
h, w, c = img.shape | ||
ratio_y = 224 / h | ||
ratio_x = ratio_y * 0.5 | ||
raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4) | ||
|
||
h, w, c = raw.shape | ||
if w <= h: | ||
canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype) | ||
canvas[:, :w, :] = raw | ||
raw = canvas | ||
h, w, c = raw.shape | ||
|
||
inputs_ = [] | ||
for index in range(0, w - h, 100): | ||
inputs_.append(raw[:, index : index + h, :]) | ||
inputs_.append(raw[:, -h:, :]) | ||
inputs_ = np.array(inputs_) | ||
|
||
chunksize = 1 | ||
chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize)) | ||
outputs = [] | ||
for chunk in tqdm.tqdm(chunks, desc='Inference'): | ||
outputs_ = session.run( | ||
None, | ||
{'input': chunk}, | ||
) | ||
outputs.append(outputs_[0]) | ||
outputs = np.vstack(outputs) | ||
outputs = outputs.mean(axis=0) | ||
|
||
model = onnx.load(onnx_filename) | ||
mapping = json.loads(model.metadata_props[0].value) | ||
labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))] | ||
|
||
prediction = np.argmax(outputs) | ||
label = labels[prediction] | ||
score = outputs[prediction] | ||
|
||
confs = dict(zip(labels, outputs)) | ||
|
||
return label, score, confs | ||
|
||
|
||
@receiver(models.signals.pre_delete, sender=Spectrogram) | ||
def delete_content(sender, instance, **kwargs): | ||
if instance.image_file: | ||
instance.image_file.delete(save=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.