Skip to content

Commit

Permalink
Add a Pip-based CI testing for inference code, with Colab dependencies.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627135153
  • Loading branch information
sdenton4 authored and copybara-github committed May 11, 2024
1 parent c829814 commit 0df75a9
Show file tree
Hide file tree
Showing 5 changed files with 1,426 additions and 1,353 deletions.
40 changes: 40 additions & 0 deletions .github/install_colab_deps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Installs Colab dependencies for CI testing."""

from typing import Sequence

from absl import app
import requests


REQS_FILE = 'https://raw.githubusercontent.com/googlecolab/backend-info/main/pip-freeze.txt'
COLAB_REQS_FILE = '/tmp/colab_reqs.txt'


def main(unused_argv: Sequence[str]) -> None:
got = requests.get(REQS_FILE)
requirements_str = str(got.content, 'utf8')
# Skip the file:// lines, which we do not have access to.
lines = [
ln + '\n' for ln in requirements_str.split('\n') if 'file://' not in ln
]
with open(COLAB_REQS_FILE, 'w') as f:
f.writelines(lines)


if __name__ == '__main__':
app.run(main)
31 changes: 21 additions & 10 deletions chirp/inference/tests/bootstrap_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""Tests for project state handling."""

import os
import shutil
import tempfile

from chirp import audio_utils
Expand All @@ -33,20 +34,29 @@

class BootstrapTest(absltest.TestCase):

def setUp(self):
super().setUp()
# `self.create_tempdir()` raises an UnparsedFlagAccessError, which is why
# we use `tempdir` directly.
self.tempdir = tempfile.mkdtemp()

def tearDown(self):
super().tearDown()
shutil.rmtree(self.tempdir)

def make_wav_files(self, classes, filenames):
# Create a pile of files.
rng = np.random.default_rng(seed=42)
tmpdir = self.create_tempdir()
for subdir in classes:
subdir_path = os.path.join(tmpdir.full_path, subdir)
subdir_path = os.path.join(self.tempdir, subdir)
os.mkdir(subdir_path)
for filename in filenames:
with open(
os.path.join(subdir_path, f'{filename}_{subdir}.wav'), 'wb'
) as f:
noise = rng.normal(scale=0.2, size=16000)
wavfile.write(f, 16000, noise)
audio_glob = os.path.join(tmpdir.full_path, '*/*.wav')
audio_glob = os.path.join(self.tempdir, '*/*.wav')
return audio_glob

def write_placeholder_embeddings(self, audio_glob, source_infos, embed_dir):
Expand Down Expand Up @@ -129,15 +139,16 @@ def test_bootstrap_from_embeddings(self):
source_infos = embed_lib.create_source_infos([audio_glob], shard_len_s=5.0)
self.assertLen(source_infos, len(classes) * len(filenames))

embed_dir = self.create_tempdir()
labeled_dir = self.create_tempdir()
self.write_placeholder_embeddings(
audio_glob, source_infos, embed_dir.full_path
)
embed_dir = os.path.join(self.tempdir, 'embeddings')
labeled_dir = os.path.join(self.tempdir, 'labeled')
epath.Path(embed_dir).mkdir(parents=True, exist_ok=True)
epath.Path(labeled_dir).mkdir(parents=True, exist_ok=True)

self.write_placeholder_embeddings(audio_glob, source_infos, embed_dir)

bootstrap_config = bootstrap.BootstrapConfig.load_from_embedding_path(
embeddings_path=embed_dir.full_path,
annotated_path=labeled_dir.full_path,
embeddings_path=embed_dir,
annotated_path=labeled_dir,
)
print('config hash : ', bootstrap_config.embedding_config_hash())

Expand Down
7 changes: 6 additions & 1 deletion chirp/inference/tests/embed_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,12 @@ def test_logits_output_head(self):
logits_model = _make_output_head_model(
'/tmp/logits_model', embedding_dim=128
)
base_outputs = base_model.embed(np.zeros(5 * 22050))
base_outputs = base_model.embed(np.zeros(5 * 22050, dtype=np.float32))
print('Keras model: ', logits_model.logits_model)
print(
'logits model output: ',
logits_model(np.zeros([1, 128], dtype=np.float32)),
)
updated_outputs = logits_model.add_logits(base_outputs, keep_original=True)
self.assertSequenceEqual(
updated_outputs.logits['other_label'].shape,
Expand Down
Loading

0 comments on commit 0df75a9

Please sign in to comment.