diff --git a/bats_ai/core/admin/__init__.py b/bats_ai/core/admin/__init__.py index 48b05ad..8f782b8 100644 --- a/bats_ai/core/admin/__init__.py +++ b/bats_ai/core/admin/__init__.py @@ -1,4 +1,5 @@ from .annotations import AnnotationsAdmin +from .compressed_spectrogram import CompressedSpectrogramAdmin from .grts_cells import GRTSCellsAdmin from .image import ImageAdmin from .recording import RecordingAdmin @@ -14,4 +15,5 @@ 'TemporalAnnotationsAdmin', 'SpeciesAdmin', 'GRTSCellsAdmin', + 'CompressedSpectrogramAdmin', ] diff --git a/bats_ai/core/admin/compressed_spectrogram.py b/bats_ai/core/admin/compressed_spectrogram.py new file mode 100644 index 0000000..a8ce90c --- /dev/null +++ b/bats_ai/core/admin/compressed_spectrogram.py @@ -0,0 +1,30 @@ +from django.contrib import admin + +from bats_ai.core.models import CompressedSpectrogram + + +@admin.register(CompressedSpectrogram) +class CompressedSpectrogramAdmin(admin.ModelAdmin): + list_display = [ + 'pk', + 'recording', + 'spectrogram', + 'image_file', + 'length', + 'widths', + 'starts', + 'stops', + ] + list_display_links = ['pk', 'recording', 'spectrogram'] + list_select_related = True + autocomplete_fields = ['recording'] + readonly_fields = [ + 'recording', + 'spectrogram', + 'image_file', + 'created', + 'modified', + 'widths', + 'starts', + 'stops', + ] diff --git a/bats_ai/core/migrations/0009_annotations_type.py b/bats_ai/core/migrations/0009_annotations_type.py deleted file mode 100644 index 6091cd7..0000000 --- a/bats_ai/core/migrations/0009_annotations_type.py +++ /dev/null @@ -1,17 +0,0 @@ -# Generated by Django 4.1.13 on 2024-03-22 17:23 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ('core', '0008_grtscells_recording_recorded_time'), - ] - - operations = [ - migrations.AddField( - model_name='annotations', - name='type', - field=models.TextField(blank=True, null=True), - ), - ] diff --git a/bats_ai/core/migrations/0010_recording_computed_species_recording_detector_and_more.py b/bats_ai/core/migrations/0009_annotations_type_recording_computed_species_and_more.py similarity index 76% rename from bats_ai/core/migrations/0010_recording_computed_species_recording_detector_and_more.py rename to bats_ai/core/migrations/0009_annotations_type_recording_computed_species_and_more.py index 1819063..ce429a8 100644 --- a/bats_ai/core/migrations/0010_recording_computed_species_recording_detector_and_more.py +++ b/bats_ai/core/migrations/0009_annotations_type_recording_computed_species_and_more.py @@ -1,14 +1,19 @@ -# Generated by Django 4.1.13 on 2024-04-03 13:07 +# Generated by Django 4.1.13 on 2024-04-11 13:06 from django.db import migrations, models class Migration(migrations.Migration): dependencies = [ - ('core', '0009_annotations_type'), + ('core', '0008_grtscells_recording_recorded_time'), ] operations = [ + migrations.AddField( + model_name='annotations', + name='type', + field=models.TextField(blank=True, null=True), + ), migrations.AddField( model_name='recording', name='computed_species', @@ -48,4 +53,9 @@ class Migration(migrations.Migration): name='unusual_occurrences', field=models.TextField(blank=True, null=True), ), + migrations.AddField( + model_name='spectrogram', + name='colormap', + field=models.CharField(max_length=20, null=True), + ), ] diff --git a/bats_ai/core/migrations/0010_compressedspectrogram.py b/bats_ai/core/migrations/0010_compressedspectrogram.py new file mode 100644 index 0000000..28fe4fb --- /dev/null +++ b/bats_ai/core/migrations/0010_compressedspectrogram.py @@ -0,0 +1,84 @@ +# Generated by Django 4.1.13 on 2024-04-19 13:55 + +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion +import django_extensions.db.fields + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0009_annotations_type_recording_computed_species_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='CompressedSpectrogram', + fields=[ + ( + 'id', + models.BigAutoField( + auto_created=True, primary_key=True, serialize=False, verbose_name='ID' + ), + ), + ( + 'created', + django_extensions.db.fields.CreationDateTimeField( + auto_now_add=True, verbose_name='created' + ), + ), + ( + 'modified', + django_extensions.db.fields.ModificationDateTimeField( + auto_now=True, verbose_name='modified' + ), + ), + ('image_file', models.FileField(upload_to='')), + ('length', models.IntegerField()), + ( + 'starts', + django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.IntegerField(), size=None + ), + size=None, + ), + ), + ( + 'stops', + django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.IntegerField(), size=None + ), + size=None, + ), + ), + ( + 'widths', + django.contrib.postgres.fields.ArrayField( + base_field=django.contrib.postgres.fields.ArrayField( + base_field=models.IntegerField(), size=None + ), + size=None, + ), + ), + ('cache_invalidated', models.BooleanField(default=True)), + ( + 'recording', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='core.recording' + ), + ), + ( + 'spectrogram', + models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, to='core.spectrogram' + ), + ), + ], + options={ + 'get_latest_by': 'modified', + 'abstract': False, + }, + ), + ] diff --git a/bats_ai/core/migrations/0011_spectrogram_colormap.py b/bats_ai/core/migrations/0011_spectrogram_colormap.py deleted file mode 100644 index bc22fd8..0000000 --- a/bats_ai/core/migrations/0011_spectrogram_colormap.py +++ /dev/null @@ -1,17 +0,0 @@ -# Generated by Django 4.1.13 on 2024-04-09 13:49 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ('core', '0010_recording_computed_species_recording_detector_and_more'), - ] - - operations = [ - migrations.AddField( - model_name='spectrogram', - name='colormap', - field=models.CharField(max_length=20, null=True), - ), - ] diff --git a/bats_ai/core/models/__init__.py b/bats_ai/core/models/__init__.py index 5979660..36e2f0f 100644 --- a/bats_ai/core/models/__init__.py +++ b/bats_ai/core/models/__init__.py @@ -1,4 +1,5 @@ from .annotations import Annotations +from .compressed_spectrogram import CompressedSpectrogram from .grts_cells import GRTSCells from .image import Image from .recording import Recording, colormap @@ -17,4 +18,5 @@ 'TemporalAnnotations', 'GRTSCells', 'colormap', + 'CompressedSpectrogram', ] diff --git a/bats_ai/core/models/compressed_spectrogram.py b/bats_ai/core/models/compressed_spectrogram.py new file mode 100644 index 0000000..5dd57c1 --- /dev/null +++ b/bats_ai/core/models/compressed_spectrogram.py @@ -0,0 +1,112 @@ +from PIL import Image +import cv2 +from django.contrib.postgres.fields import ArrayField +from django.core.files.storage import default_storage +from django.db import models +from django.dispatch import receiver +from django_extensions.db.models import TimeStampedModel +import numpy as np + +from .recording import Recording +from .spectrogram import Spectrogram + +FREQ_MIN = 5e3 +FREQ_MAX = 120e3 +FREQ_PAD = 2e3 + + +# TimeStampedModel also provides "created" and "modified" fields +class CompressedSpectrogram(TimeStampedModel, models.Model): + recording = models.ForeignKey(Recording, on_delete=models.CASCADE) + spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE) + image_file = models.FileField() + length = models.IntegerField() + starts = ArrayField(ArrayField(models.IntegerField())) + stops = ArrayField(ArrayField(models.IntegerField())) + widths = ArrayField(ArrayField(models.IntegerField())) + cache_invalidated = models.BooleanField(default=True) + + @property + def image_url(self): + return default_storage.url(self.image_file.name) + + def predict(self): + import json + import os + + import onnx + import onnxruntime as ort + import tqdm + + img = Image.open(self.image_file) + + relative = ('..',) * 4 + asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets')) + + onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx') + assert os.path.exists(onnx_filename) + + session = ort.InferenceSession( + onnx_filename, + providers=[ + ( + 'CUDAExecutionProvider', + { + 'cudnn_conv_use_max_workspace': '1', + 'device_id': 0, + 'cudnn_conv_algo_search': 'HEURISTIC', + }, + ), + 'CPUExecutionProvider', + ], + ) + + img = np.array(img) + + h, w, c = img.shape + ratio_y = 224 / h + ratio_x = ratio_y * 0.5 + raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4) + + h, w, c = raw.shape + if w <= h: + canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype) + canvas[:, :w, :] = raw + raw = canvas + h, w, c = raw.shape + + inputs_ = [] + for index in range(0, w - h, 100): + inputs_.append(raw[:, index : index + h, :]) + inputs_.append(raw[:, -h:, :]) + inputs_ = np.array(inputs_) + + chunksize = 1 + chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize)) + outputs = [] + for chunk in tqdm.tqdm(chunks, desc='Inference'): + outputs_ = session.run( + None, + {'input': chunk}, + ) + outputs.append(outputs_[0]) + outputs = np.vstack(outputs) + outputs = outputs.mean(axis=0) + + model = onnx.load(onnx_filename) + mapping = json.loads(model.metadata_props[0].value) + labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))] + + prediction = np.argmax(outputs) + label = labels[prediction] + score = outputs[prediction] + + confs = dict(zip(labels, outputs)) + + return label, score, confs + + +@receiver(models.signals.pre_delete, sender=Spectrogram) +def delete_content(sender, instance, **kwargs): + if instance.image_file: + instance.image_file.delete(save=False) diff --git a/bats_ai/core/models/recording.py b/bats_ai/core/models/recording.py index a4574bc..4eb3b42 100644 --- a/bats_ai/core/models/recording.py +++ b/bats_ai/core/models/recording.py @@ -2,6 +2,7 @@ from django.contrib.auth.models import User from django.contrib.gis.db import models +from django.dispatch import receiver from django_extensions.db.models import TimeStampedModel from .species import Species @@ -67,17 +68,17 @@ def spectrograms(self): @property def spectrogram(self): - from bats_ai.core.models import Spectrogram + pass spectrograms = self.spectrograms - if len(spectrograms) == 0: - Spectrogram.generate(self, colormap=COLORMAP) - - spectrograms = self.spectrograms - assert len(spectrograms) == 1 - assert len(spectrograms) >= 1 spectrogram = spectrograms[0] # most recently created return spectrogram + + +@receiver(models.signals.pre_delete, sender=Recording) +def delete_content(sender, instance, **kwargs): + if instance.audio_file: + instance.audio_file.delete(save=False) diff --git a/bats_ai/core/models/spectrogram.py b/bats_ai/core/models/spectrogram.py index fc127fc..eaedefb 100644 --- a/bats_ai/core/models/spectrogram.py +++ b/bats_ai/core/models/spectrogram.py @@ -6,16 +6,17 @@ from PIL import Image import cv2 from django.core.files import File +from django.core.files.storage import default_storage from django.db import models from django.db.models.fields.files import FieldFile +from django.dispatch import receiver from django_extensions.db.models import TimeStampedModel import librosa import matplotlib.pyplot as plt import numpy as np -import scipy import tqdm -from bats_ai.core.models import Annotations, Recording +from .recording import Recording logger = logging.getLogger(__name__) @@ -114,7 +115,7 @@ def generate(cls, recording, colormap=None, dpi=520): vmin = window.min() vmax = window.max() - chunksize = int(5e3) + chunksize = int(2e3) arange = np.arange(chunksize, window.shape[1], chunksize) chunks = np.array_split(window, arange, axis=1) @@ -218,7 +219,7 @@ def generate(cls, recording, colormap=None, dpi=520): img.save(buf, format='JPEG', quality=80) buf.seek(0) - name = 'spectrogram.jpg' + name = f'{recording.pk}_{colormap}_spectrogram.jpg' image_file = File(buf, name=name) spectrogram = cls( @@ -232,148 +233,7 @@ def generate(cls, recording, colormap=None, dpi=520): colormap=colormap, ) spectrogram.save() - - @property - def compressed(self): - img = self.image_np - - annotations = Annotations.objects.filter(recording=self.recording) - - threshold = 0.5 - while True: - canvas = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) - canvas = canvas.astype(np.float32) - - is_light = np.median(canvas) > 128.0 - if is_light: - canvas = 255.0 - canvas - - amplitude = canvas.max(axis=0) - amplitude -= amplitude.min() - amplitude /= amplitude.max() - amplitude[amplitude < threshold] = 0.0 - amplitude[amplitude > 0] = 1.0 - amplitude = amplitude.reshape(1, -1) - - canvas -= canvas.min() - canvas /= canvas.max() - canvas *= 255.0 - canvas *= amplitude - canvas = np.around(canvas).astype(np.uint8) - - width = canvas.shape[1] - for annotation in annotations: - start = annotation.start_time / self.duration - stop = annotation.end_time / self.duration - - start = int(np.around(start * width)) - stop = int(np.around(stop * width)) - canvas[:, start : stop + 1] = 255.0 - - mask = canvas.max(axis=0) - mask = scipy.signal.medfilt(mask, 3) - mask[0] = 0 - mask[-1] = 0 - - starts = [] - stops = [] - for index in range(1, len(mask) - 1): - value_pre = mask[index - 1] - value = mask[index] - value_post = mask[index + 1] - if value != 0: - if value_pre == 0: - starts.append(index) - if value_post == 0: - stops.append(index) - assert len(starts) == len(stops) - - starts = [val - 40 for val in starts] # 10 ms buffer - stops = [val + 40 for val in stops] # 10 ms buffer - ranges = list(zip(starts, stops)) - - while True: - found = False - merged = [] - index = 0 - while index < len(ranges) - 1: - start1, stop1 = ranges[index] - start2, stop2 = ranges[index + 1] - - start1 = min(max(start1, 0), len(mask)) - start2 = min(max(start2, 0), len(mask)) - stop1 = min(max(stop1, 0), len(mask)) - stop2 = min(max(stop2, 0), len(mask)) - - if stop1 >= start2: - found = True - merged.append((start1, stop2)) - index += 2 - else: - merged.append((start1, stop1)) - index += 1 - if index == len(ranges) - 1: - merged.append((start2, stop2)) - ranges = merged - if not found: - for index in range(1, len(ranges)): - start1, stop1 = ranges[index - 1] - start2, stop2 = ranges[index] - assert start1 < stop1 - assert start2 < stop2 - assert start1 < start2 - assert stop1 < stop2 - assert stop1 < start2 - break - - segments = [] - starts_ = [] - stops_ = [] - domain = img.shape[1] - widths = [] - total_width = 0 - for start, stop in ranges: - segment = img[:, start:stop] - segments.append(segment) - - starts_.append(int(round(self.duration * (start / domain)))) - stops_.append(int(round(self.duration * (stop / domain)))) - widths.append(stop - start) - total_width += stop - start - - # buffer = np.zeros((len(img), 20, 3), dtype=img.dtype) - # segments.append(buffer) - # segments = segments[:-1] - - if len(segments) > 0: - break - - threshold -= 0.05 - if threshold < 0: - segments = None - break - - if segments is None: - canvas = img.copy() - else: - canvas = np.hstack(segments) - - canvas = Image.fromarray(canvas, 'RGB') - - # canvas.save('temp.compressed.jpg') - - buf = io.BytesIO() - canvas.save(buf, format='JPEG', quality=80) - buf.seek(0) - img_base64 = base64.b64encode(buf.getvalue()).decode('utf-8') - - metadata = { - 'starts': starts_, - 'stops': stops_, - 'widths': widths, - 'length': total_width, - } - return canvas, img_base64, metadata + return spectrogram.pk @property def image_np(self): @@ -395,77 +255,12 @@ def base64(self): return img_base64 - def predict(self): - import json - import os - - import onnx - import onnxruntime as ort - import tqdm - - img, _, _ = self.compressed - - relative = ('..',) * 4 - asset_path = os.path.abspath(os.path.join(__file__, *relative, 'assets')) - - onnx_filename = os.path.join(asset_path, 'model.mobilenet.onnx') - assert os.path.exists(onnx_filename) - - session = ort.InferenceSession( - onnx_filename, - providers=[ - ( - 'CUDAExecutionProvider', - { - 'cudnn_conv_use_max_workspace': '1', - 'device_id': 0, - 'cudnn_conv_algo_search': 'HEURISTIC', - }, - ), - 'CPUExecutionProvider', - ], - ) + @property + def image_url(self): + return default_storage.url(self.image_file.name) + - img = np.array(img) - - h, w, c = img.shape - ratio_y = 224 / h - ratio_x = ratio_y * 0.5 - raw = cv2.resize(img, None, fx=ratio_x, fy=ratio_y, interpolation=cv2.INTER_LANCZOS4) - - h, w, c = raw.shape - if w <= h: - canvas = np.zeros((h, h + 1, 3), dtype=raw.dtype) - canvas[:, :w, :] = raw - raw = canvas - h, w, c = raw.shape - - inputs_ = [] - for index in range(0, w - h, 100): - inputs_.append(raw[:, index : index + h, :]) - inputs_.append(raw[:, -h:, :]) - inputs_ = np.array(inputs_) - - chunksize = 1 - chunks = np.array_split(inputs_, np.arange(chunksize, len(inputs_), chunksize)) - outputs = [] - for chunk in tqdm.tqdm(chunks, desc='Inference'): - outputs_ = session.run( - None, - {'input': chunk}, - ) - outputs.append(outputs_[0]) - outputs = np.vstack(outputs) - outputs = outputs.mean(axis=0) - - model = onnx.load(onnx_filename) - mapping = json.loads(model.metadata_props[0].value) - labels = [mapping['forward'][str(index)] for index in range(len(mapping['forward']))] - - prediction = np.argmax(outputs) - label = labels[prediction] - score = outputs[prediction] - - confs = dict(zip(labels, outputs)) - - return label, score, confs +@receiver(models.signals.pre_delete, sender=Spectrogram) +def delete_content(sender, instance, **kwargs): + if instance.image_file: + instance.image_file.delete(save=False) diff --git a/bats_ai/core/rest/__init__.py b/bats_ai/core/rest/__init__.py index ecd2d00..73bcd90 100644 --- a/bats_ai/core/rest/__init__.py +++ b/bats_ai/core/rest/__init__.py @@ -1,10 +1,8 @@ from rest_framework import routers -from .image import ImageViewSet from .spectrogram import SpectrogramViewSet -__all__ = ['ImageViewSet', 'SpectrogramViewSet'] +__all__ = ['SpectrogramViewSet'] rest = routers.SimpleRouter() -rest.register(r'images', ImageViewSet) rest.register(r'spectrograms', SpectrogramViewSet) diff --git a/bats_ai/core/rest/image.py b/bats_ai/core/rest/image.py deleted file mode 100644 index d749110..0000000 --- a/bats_ai/core/rest/image.py +++ /dev/null @@ -1,42 +0,0 @@ -from django.http import HttpResponseRedirect -from django_filters import rest_framework as filters -from rest_framework import serializers, status -from rest_framework.decorators import action -from rest_framework.pagination import PageNumberPagination -from rest_framework.permissions import IsAuthenticatedOrReadOnly -from rest_framework.response import Response -from rest_framework.viewsets import ReadOnlyModelViewSet - -from bats_ai.core.models import Image -from bats_ai.core.tasks import image_compute_checksum - - -class ImageSerializer(serializers.ModelSerializer): - class Meta: - model = Image - fields = ['id', 'name', 'checksum', 'created', 'owner'] - read_only_fields = ['checksum', 'created'] - - -class ImageViewSet(ReadOnlyModelViewSet): - queryset = Image.objects.all() - - permission_classes = [IsAuthenticatedOrReadOnly] - serializer_class = ImageSerializer - - filter_backends = [filters.DjangoFilterBackend] - filterset_fields = ['name', 'checksum'] - - pagination_class = PageNumberPagination - - @action(detail=True, methods=['get']) - def download(self, request, pk=None): - image = self.get_object() - return HttpResponseRedirect(image.blob.url) - - @action(detail=True, methods=['post']) - def compute(self, request, pk=None): - # Ensure that the image exists, so a non-existent pk isn't dispatched - image = self.get_object() - image_compute_checksum.delay(image.pk) - return Response('', status=status.HTTP_202_ACCEPTED) diff --git a/bats_ai/core/tasks.py b/bats_ai/core/tasks.py index 376adfe..18604e2 100644 --- a/bats_ai/core/tasks.py +++ b/bats_ai/core/tasks.py @@ -1,6 +1,159 @@ +import io +import tempfile + +from PIL import Image from celery import shared_task +import cv2 +from django.core.files import File +import numpy as np +import scipy + +from bats_ai.core.models import Annotations, CompressedSpectrogram, Recording, Spectrogram, colormap + + +def generate_compressed(recording: Recording, spectrogram: Spectrogram): + img = spectrogram.image_np + annotations = Annotations.objects.filter(recording=recording) + + threshold = 0.5 + while True: + canvas = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + canvas = canvas.astype(np.float32) + + is_light = np.median(canvas) > 128.0 + if is_light: + canvas = 255.0 - canvas + + amplitude = canvas.max(axis=0) + amplitude -= amplitude.min() + amplitude /= amplitude.max() + amplitude[amplitude < threshold] = 0.0 + amplitude[amplitude > 0] = 1.0 + amplitude = amplitude.reshape(1, -1) + + canvas -= canvas.min() + canvas /= canvas.max() + canvas *= 255.0 + canvas *= amplitude + canvas = np.around(canvas).astype(np.uint8) + + width = canvas.shape[1] + for annotation in annotations: + start = annotation.start_time / spectrogram.duration + stop = annotation.end_time / spectrogram.duration + + start = int(np.around(start * width)) + stop = int(np.around(stop * width)) + canvas[:, start : stop + 1] = 255.0 + + mask = canvas.max(axis=0) + mask = scipy.signal.medfilt(mask, 3) + mask[0] = 0 + mask[-1] = 0 + + starts = [] + stops = [] + for index in range(1, len(mask) - 1): + value_pre = mask[index - 1] + value = mask[index] + value_post = mask[index + 1] + if value != 0: + if value_pre == 0: + starts.append(index) + if value_post == 0: + stops.append(index) + assert len(starts) == len(stops) + + starts = [val - 40 for val in starts] # 10 ms buffer + stops = [val + 40 for val in stops] # 10 ms buffer + ranges = list(zip(starts, stops)) + + while True: + found = False + merged = [] + index = 0 + while index < len(ranges) - 1: + start1, stop1 = ranges[index] + start2, stop2 = ranges[index + 1] + + start1 = min(max(start1, 0), len(mask)) + start2 = min(max(start2, 0), len(mask)) + stop1 = min(max(stop1, 0), len(mask)) + stop2 = min(max(stop2, 0), len(mask)) + + if stop1 >= start2: + found = True + merged.append((start1, stop2)) + index += 2 + else: + merged.append((start1, stop1)) + index += 1 + if index == len(ranges) - 1: + merged.append((start2, stop2)) + ranges = merged + if not found: + for index in range(1, len(ranges)): + start1, stop1 = ranges[index - 1] + start2, stop2 = ranges[index] + assert start1 < stop1 + assert start2 < stop2 + assert start1 < start2 + assert stop1 < stop2 + assert stop1 < start2 + break + + segments = [] + starts_ = [] + stops_ = [] + domain = img.shape[1] + widths = [] + total_width = 0 + for start, stop in ranges: + segment = img[:, start:stop] + segments.append(segment) -from bats_ai.core.models import Image, Recording, colormap + starts_.append(int(round(spectrogram.duration * (start / domain)))) + stops_.append(int(round(spectrogram.duration * (stop / domain)))) + widths.append(stop - start) + total_width += stop - start + + # buffer = np.zeros((len(img), 20, 3), dtype=img.dtype) + # segments.append(buffer) + # segments = segments[:-1] + + if len(segments) > 0: + break + + threshold -= 0.05 + if threshold < 0: + segments = None + break + + if segments is None: + canvas = img.copy() + else: + canvas = np.hstack(segments) + + canvas = Image.fromarray(canvas, 'RGB') + buf = io.BytesIO() + canvas.save(buf, format='JPEG', quality=80) + buf.seek(0) + + # Use temporary files + with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file: + temp_file_name = temp_file.name + canvas.save(temp_file_name) + + # Read the temporary file + with open(temp_file_name, 'rb') as f: + temp_file_content = f.read() + + # Wrap the content in BytesIO + buf = io.BytesIO(temp_file_content) + name = f'{spectrogram.pk}_spectrogram_compressed.jpg' + image_file = File(buf, name=name) + + return total_width, image_file, widths, starts_, stops_ @shared_task @@ -16,12 +169,48 @@ def recording_compute_spectrogram(recording_id: int): cmaps = [ None, # Default (dark) spectrogram - 'inference', # Machine learning inference spectrogram 'light', # Light spectrogram ] - + spectrogram_id = None for cmap in cmaps: with colormap(cmap): - if not recording.has_spectrogram: - recording.spectrogram # compute by simply referenceing the attribute - assert recording.has_spectrogram + spectrogram_id_temp = Spectrogram.generate(recording, cmap) + if cmap is None: + spectrogram_id = spectrogram_id_temp + if spectrogram_id is not None: + generate_compress_spectrogram.delay(recording_id, spectrogram_id) + + +@shared_task +def generate_compress_spectrogram(recording_id: int, spectrogram_id: int): + recording = Recording.objects.get(pk=recording_id) + spectrogram = Spectrogram.objects.get(pk=spectrogram_id) + length, image_file, widths, starts, stops = generate_compressed(recording, spectrogram) + found = CompressedSpectrogram.objects.filter(recording=recording, spectrogram=spectrogram) + if found.exists(): + existing = found.first() + existing.length = length + existing.image_file = image_file + existing.widths = widths + existing.starts = starts + existing.stops = stops + existing.cache_invalidated = False + existing.save() + else: + CompressedSpectrogram.objects.create( + recording=recording, + spectrogram=spectrogram, + image_file=image_file, + length=length, + widths=widths, + starts=starts, + stops=stops, + cache_invalidated=False, + ) + + +@shared_task +def predict(compressed_spectrogram_id: int): + compressed_spectrogram = CompressedSpectrogram.objects.get(pk=compressed_spectrogram_id) + label, score, confs = compressed_spectrogram.predict() + return label, score, confs diff --git a/bats_ai/core/views/recording.py b/bats_ai/core/views/recording.py index 60e97be..44321ce 100644 --- a/bats_ai/core/views/recording.py +++ b/bats_ai/core/views/recording.py @@ -11,7 +11,14 @@ from ninja.files import UploadedFile from ninja.pagination import RouterPaginated -from bats_ai.core.models import Annotations, Recording, Species, TemporalAnnotations, colormap +from bats_ai.core.models import ( + Annotations, + CompressedSpectrogram, + Recording, + Species, + TemporalAnnotations, + colormap, +) from bats_ai.core.tasks import recording_compute_spectrogram from bats_ai.core.views.species import SpeciesSchema from bats_ai.core.views.temporal_annotations import ( @@ -268,11 +275,11 @@ def get_spectrogram(request: HttpRequest, id: int): except Recording.DoesNotExist: return {'error': 'Recording not found'} - with colormap(): + with colormap(None): spectrogram = recording.spectrogram spectro_data = { - 'base64_spectrogram': spectrogram.base64, + 'url': spectrogram.image_url, 'spectroInfo': { 'width': spectrogram.width, 'height': spectrogram.height, @@ -328,30 +335,29 @@ def get_spectrogram(request: HttpRequest, id: int): def get_spectrogram_compressed(request: HttpRequest, id: int): try: recording = Recording.objects.get(pk=id) - except Recording.DoesNotExist: - return {'error': 'Recording not found'} - - with colormap('inference'): - label, score, confs = recording.spectrogram.predict() - print(label, score, confs) + compressed_spectrogram = CompressedSpectrogram.objects.filter(recording=id).first() + except compressed_spectrogram.DoesNotExist: + return {'error': 'Compressed Spectrogram'} + except recording.DoesNotExist: + return {'error': 'Recording does not exist'} with colormap(): - spectrogram = recording.spectrogram - _, compressed_base64, metadata = spectrogram.compressed + label, score, confs = compressed_spectrogram.predict() + print(label, score, confs) spectro_data = { - 'base64_spectrogram': compressed_base64, + 'url': compressed_spectrogram.image_url, 'spectroInfo': { - 'width': spectrogram.width, + 'width': compressed_spectrogram.spectrogram.width, 'start_time': 0, - 'end_time': spectrogram.duration, - 'height': spectrogram.height, - 'low_freq': spectrogram.frequency_min, - 'high_freq': spectrogram.frequency_max, - 'start_times': metadata['starts'], - 'end_times': metadata['stops'], - 'widths': metadata['widths'], - 'compressedWidth': metadata['length'], + 'end_time': compressed_spectrogram.spectrogram.duration, + 'height': compressed_spectrogram.spectrogram.height, + 'low_freq': compressed_spectrogram.spectrogram.frequency_min, + 'high_freq': compressed_spectrogram.spectrogram.frequency_max, + 'start_times': compressed_spectrogram.starts, + 'end_times': compressed_spectrogram.stops, + 'widths': compressed_spectrogram.widths, + 'compressedWidth': compressed_spectrogram.length, }, } diff --git a/client/src/api/api.ts b/client/src/api/api.ts index 8754ee8..fea6f67 100644 --- a/client/src/api/api.ts +++ b/client/src/api/api.ts @@ -108,8 +108,7 @@ export interface UserInfo { id: number; } export interface Spectrogram { - 'base64_spectrogram': string; - url?: string; + url: string; filename?: string; annotations?: SpectrogramAnnotation[]; temporal?: SpectrogramTemporalAnnotation[]; diff --git a/client/src/components/ThumbnailViewer.vue b/client/src/components/ThumbnailViewer.vue index 3ba31f5..d766862 100644 --- a/client/src/components/ThumbnailViewer.vue +++ b/client/src/components/ThumbnailViewer.vue @@ -160,7 +160,7 @@ export default defineComponent({ clientHeight.value = containerRef.value.clientHeight; } if (containerRef.value && ! geoJS.getGeoViewer().value) { - geoJS.initializeViewer(containerRef.value, naturalWidth, naturalHeight, true); + geoJS.initializeViewer(containerRef.value, naturalWidth, naturalHeight, true); } const coords = geoJS.getGeoViewer().value.camera().worldToDisplay({x: 0, y:0}); const end = geoJS.getGeoViewer().value.camera().worldToDisplay({x: 0, y:naturalHeight}); @@ -235,8 +235,8 @@ export default defineComponent({ position: absolute; top: 50%; left: 50%; - -ms-transform: translate(-50%, -50%); - transform: translate(-50%, -50%); + // -ms-transform: translate(-50%, -50%); + // transform: translate(-50%, -50%); } .geojs-map.annotation-input { cursor: inherit; diff --git a/client/src/components/geoJS/geoJSUtils.ts b/client/src/components/geoJS/geoJSUtils.ts index c799cba..0269de5 100644 --- a/client/src/components/geoJS/geoJSUtils.ts +++ b/client/src/components/geoJS/geoJSUtils.ts @@ -17,6 +17,8 @@ const useGeoJS = () => { right: 1, }; + let originalDimensions = { width: 0, height: 0 }; + const getGeoViewer = () => { return geoViewer; }; @@ -29,6 +31,7 @@ const useGeoJS = () => { ) => { thumbnail.value = thumbanilVal; container.value = sourceContainer; + originalDimensions = {width, height }; const params = geo.util.pixelCoordinateParams(container.value, width, height); if (!container.value) { return; @@ -112,7 +115,7 @@ const useGeoJS = () => { .draw(); } if (resetCam) { - resetMapDimensions(width, height, 0.3,resetCam); + resetMapDimensions(width, height, 0.3, resetCam); } else { const params = geo.util.pixelCoordinateParams(container.value, width, height, width, height); const margin = 0.3; @@ -129,16 +132,21 @@ const useGeoJS = () => { }; const resetZoom = () => { - const { width: mapWidth } = geoViewer.value.camera().viewport; + const { width: mapWidth, } = geoViewer.value.camera().viewport; const bounds = !thumbnail.value ? { - left: -125, // Making sure the legend is on the screen - top: 0, - right: mapWidth, + left: 0, // Making sure the legend is on the screen + top: -(originalBounds.bottom - originalDimensions.height) / 2.0, + right: mapWidth*2, bottom: originalBounds.bottom, } - : originalBounds; + : { + left: 0, + top: 0, + right: originalDimensions.width, + bottom: originalDimensions.height, + }; const zoomAndCenter = geoViewer.value.zoomAndCenterFromBounds(bounds, 0); geoViewer.value.zoom(zoomAndCenter.zoom); geoViewer.value.center(zoomAndCenter.center); @@ -154,13 +162,13 @@ const useGeoJS = () => { }); const params = geo.util.pixelCoordinateParams(container.value, width, height, width, height); const { right, bottom } = params.map.maxBounds; - originalBounds = params.map.maxBounds; geoViewer.value.maxBounds({ left: 0 - right * margin, top: 0 - bottom * margin, right: right * (1 + margin), bottom: bottom * (1 + margin), }); + originalBounds = geoViewer.value.maxBounds(); geoViewer.value.zoomRange({ // do not set a min limit so that bounds clamping determines min min: -Infinity, diff --git a/client/src/views/Spectrogram.vue b/client/src/views/Spectrogram.vue index 2147323..81819a0 100644 --- a/client/src/views/Spectrogram.vue +++ b/client/src/views/Spectrogram.vue @@ -84,17 +84,25 @@ export default defineComponent({ }; const loadData = async () => { + loadedImage.value = false; const response = compressed.value ? await getSpectrogramCompressed(props.id) : await getSpectrogram(props.id); - image.value.src = `data:image/png;base64,${response.data["base64_spectrogram"]}`; + if (response.data['url']) { + image.value.src = response.data['url']; + } else { + // TODO Error Out if there is no URL + console.error('No URL found for the spectrogram'); + } + image.value.onload = () => { + loadedImage.value = true; + }; spectroInfo.value = response.data["spectroInfo"]; annotations.value = response.data["annotations"]?.sort((a, b) => a.start_time - b.start_time) || []; temporalAnnotations.value = response.data["temporal"]?.sort((a, b) => a.start_time - b.start_time) || []; if (response.data.currentUser) { currentUser.value = response.data.currentUser; } - loadedImage.value = true; const speciesResponse = await getSpecies(); speciesList.value = speciesResponse.data; if (response.data.otherUsers && spectroInfo.value) { diff --git a/client/yarn.lock b/client/yarn.lock index cf77368..2095fa3 100644 --- a/client/yarn.lock +++ b/client/yarn.lock @@ -2394,6 +2394,10 @@ es-define-property@^1.0.0: dependencies: get-intrinsic "^1.2.4" +esbuild-darwin-arm64@0.14.54: + version "0.14.54" + resolved "https://registry.npmjs.org/esbuild-darwin-arm64/-/esbuild-darwin-arm64-0.14.54.tgz" + integrity sha512-OPafJHD2oUPyvJMrsCvDGkRrVCar5aVyHfWGQzY1dWnzErjrDuSETxwA2HSsyg2jORLY8yBfzc1MIpUkXlctmw== es-errors@^1.3.0: version "1.3.0" resolved "https://registry.yarnpkg.com/es-errors/-/es-errors-1.3.0.tgz#05f75a25dab98e4fb1dcd5e1472c0546d5057c8f" diff --git a/dev/django.Dockerfile b/dev/django.Dockerfile index 5992260..7861630 100644 --- a/dev/django.Dockerfile +++ b/dev/django.Dockerfile @@ -28,6 +28,8 @@ COPY ./setup.py /opt/django-project/setup.py # Use a directory name which will never be an import name, as isort considers this as first-party. WORKDIR /opt/django-project +# hadolint ignore=DL3013 +RUN pip install --no-cache-dir --upgrade pip RUN set -ex \ && pip install --no-cache-dir -e .[dev] diff --git a/setup.py b/setup.py index 839f99c..02f4b0f 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,6 @@ 'django-allauth', 'django-configurations[database,email]', 'django-extensions', - 'django-filter', 'django-large-image', 'django-oauth-toolkit', 'djangorestframework', @@ -55,17 +54,22 @@ 'django-s3-file-field[boto3]<1', 'gunicorn', 'flower', - 'large-image[rasterio,pil]>=1.22', + # Spectrogram Generation 'librosa', 'matplotlib', 'mercantile', 'numpy', + # 'onnxruntime-gpu', 'onnx', 'onnxruntime', - # 'onnxruntime-gpu', 'opencv-python-headless', - 'rio-cogeo', 'tqdm', + # large image + 'django-large-image>=0.10.0', + 'large-image[rasterio,pil]>=1.22', + 'rio-cogeo', + # guano metadata + 'guano', ], extras_require={ 'dev': [