Skip to content

Commit

Permalink
Got tests to run, but failing
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaschoi03 committed Mar 12, 2024
1 parent 65795d4 commit 0713f66
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 62 deletions.
Binary file not shown.
44 changes: 39 additions & 5 deletions src/classifier/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,57 @@
from pymtl3.stdlib import stream
from pymtl3.passes.backends.verilog import *
from os import path

from src.serdes.deserializer import Deserializer
from src.serdes.serializer import Serializer

# Pymtl3 harness for the `Classifier` module.
class Classifier(VerilogPlaceholder, Component):
# Constructor

def construct(s, BIT_WIDTH=32, DECIMAL_PT = 16, N_SAMPLES = 8, CUTOFF_FREQ = 65536000, CUTOFF_MAG = 1310720, SAMPLING_FREQUENCY = 44000):
# Interface

s.recv_msg = [InPort(BIT_WIDTH) for _ in range(N_SAMPLES)]
s.recv_val = InPort()
s.recv_rdy = OutPort()

s.recv = stream.ifcs.RecvIfcRTL( mk_bits(BIT_WIDTH*N_SAMPLES) )
s.send = stream.ifcs.SendIfcRTL(mk_bits(1))
s.send_msg = OutPort()
s.send_val = OutPort()
s.send_rdy = InPort()

# Name of the top level module to be imported
s.set_metadata(VerilogPlaceholderPass.top_module, "HarnessClassifier")
s.set_metadata(VerilogPlaceholderPass.top_module, "Classifier")
# Source file path
# The ../ is necessary here because pytest is run from the build directory
s.set_metadata(
VerilogPlaceholderPass.src_file,
path.join(path.dirname(__file__), "harness/classifier.v"),
path.join(path.dirname(__file__), "classifier.v"),
)

class ClassifierWrapper(Component):
def construct(s, BIT_WIDTH, DECIMAL_PT, N_SAMPLES, CUTOFF_FREQ, CUTOFF_MAG, SAMPLING_FREQUENCY):
s.recv = stream.ifcs.RecvIfcRTL(mk_bits(BIT_WIDTH))
s.send = stream.ifcs.SendIfcRTL(mk_bits(1))

# Hook up a deserializer
s.deserializer = Deserializer(BIT_WIDTH, N_SAMPLES)
s.recv.msg //= s.deserializer.recv_msg
s.recv.val //= s.deserializer.recv_val
s.deserializer.recv_rdy //= s.recv.rdy

# Hook up the FFT
s.dut = Classifier(BIT_WIDTH, DECIMAL_PT, N_SAMPLES, CUTOFF_FREQ, CUTOFF_MAG, SAMPLING_FREQUENCY)

s.dut.send_msg //= s.send.msg
s.dut.send_val //= s.send.val
s.send.rdy //= s.dut.send_rdy

# Hook up the deserializer to the FFT
for i in range(N_SAMPLES):
s.deserializer.send_msg[i] //= s.dut.recv_msg[i]

s.deserializer.send_val //= s.dut.recv_val
s.dut.recv_rdy //= s.deserializer.send_rdy

def line_trace(s):
return f"{s.deserializer.line_trace()} > {s.dut.line_trace()} > {s.serializer.line_trace()}"
45 changes: 0 additions & 45 deletions src/classifier/harness/classifier.v

This file was deleted.

54 changes: 42 additions & 12 deletions src/classifier/tests/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from pymtl3.stdlib.test_utils import mk_test_case_table, run_sim
from tools.utils import mk_packed
from src.classifier.classifier import Classifier
from src.classifier.classifier import ClassifierWrapper
import numpy as np
from fixedpt import Fixed
import wave

# ----------------------------------------------------------------------
# Helper Functions
Expand All @@ -25,16 +27,31 @@ def make_arr_fixed (n, d, a):
def pack_msg (n, arr):
return mk_packed(n)(arr)

def read_wav_file(file_path):
with wave.open(file_path, 'rb') as wav_file:
audio_data = wav_file.readframes(-1)
audio_array = np.frombuffer(audio_data, dtype=np.int16)
sample_rate = wav_file.getframerate()
num_channels = wav_file.getnchannels()

return audio_array, sample_rate, num_channels

# Cast a Fixed object to a Bits object
def fixed_bits(f: Fixed) -> Bits:
value = f.get()

return mk_bits(len(f))(value)

# -------------------------------------------------------------------------
# TestHarness
# -------------------------------------------------------------------------
class TestHarness(Component):
def construct(s, classifier, BIT_WIDTH=32, DECIMAL_PT = 16, N_SAMPLES = 8, CUTOFF_FREQ = 65536000, CUTOFF_MAG = 1310720, SAMPLING_FREQUENCY = 44000):
# Instantiate models

s.src = stream.SourceRTL(mk_bits(BIT_WIDTH*N_SAMPLES))
s.src = stream.SourceRTL(mk_bits(BIT_WIDTH))
s.sink = stream.SinkRTL(mk_bits(1))
s.classifier = classifier
s.classifier = ClassifierWrapper(BIT_WIDTH, DECIMAL_PT, N_SAMPLES, CUTOFF_FREQ, CUTOFF_MAG, SAMPLING_FREQUENCY)

# Connect

Expand All @@ -44,20 +61,33 @@ def construct(s, classifier, BIT_WIDTH=32, DECIMAL_PT = 16, N_SAMPLES = 8, CUTOF
def done(s):
return s.src.done() and s.sink.done()

def simple_test():
def false_test():
audio_array = np.array([0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5,0,0.5,1,0.5])
audio_array = audio_array - 0.5
sample_rate = 50000
frq_arr = np.fft.fft(audio_array)
real_part = frq_arr.real
return [real_part, 0]

def true_test():
file_path = '/home/tic3/c2s2_ip/src/classifier/audio_files/ABS_MCBY_-YWX_MixPre-672.WAV'
audio_array, sample_rate, num_channels = read_wav_file(file_path)
audio_array_0 = audio_array[:64]
audio_array_1 = audio_array[64:128]
audio_array_2 = audio_array[128:192]
audio_array_3 = audio_array[192:256]
frq_arr = np.fft.fft(audio_array_0)
real_part = frq_arr.real
return [real_part, 1]


test_case_table = mk_test_case_table(
[
(
"msgs src_delay sink_delay BIT_WIDTH DECIMAL_PT N_SAMPLES CUTOFF_FREQ CUTOFF_MAG SAMPLING_FREQUENCY slow"
),
["simple_test", simple_test, 4, 4, 32, 16, 64, 65536000, 1310720, 50000, False],
#["false_test", false_test, 4, 4, 32, 16, 64, 65536000, 1310720, 5000, False],
["true_test", true_test, 4, 4, 32, 16, 64, 65536000, 1310720, 96000, False],
]
)

Expand All @@ -67,24 +97,24 @@ def test(test_params, cmdline_opts):
Classifier(test_params.BIT_WIDTH, test_params.DECIMAL_PT, test_params.N_SAMPLES, test_params.CUTOFF_FREQ, test_params.CUTOFF_MAG, test_params.SAMPLING_FREQUENCY),
test_params.BIT_WIDTH, test_params.DECIMAL_PT, test_params.N_SAMPLES, test_params.CUTOFF_FREQ, test_params.CUTOFF_MAG, test_params.SAMPLING_FREQUENCY
)

msgs = test_params.msgs()
print(msgs)
msgs = [make_arr_fixed(test_params.BIT_WIDTH, test_params.DECIMAL_PT, x) if i%2 != 0 else x for i, x in enumerate(msgs, start=1)]
print(msgs)
msgs = [mk_packed(test_params.BIT_WIDTH)(*x) if i%2 != 0 else x for i, x in enumerate(msgs, start=1)]
print(msgs)
inputs = [[Fixed(x, True, test_params.BIT_WIDTH, test_params.DECIMAL_PT) for x in sample] for sample in msgs[::2]]
outputs = [x for x in msgs[1::2]]

inputs = [fixed_bits(x) for sample in inputs for x in sample]
outputs = [x for x in outputs]

th.set_param(
"top.src.construct",
msgs=msgs[::2],
msgs=inputs,
initial_delay=test_params.src_delay,
interval_delay=test_params.src_delay,
)

th.set_param(
"top.sink.construct",
msgs=msgs[1::2],
msgs=outputs,
initial_delay=test_params.sink_delay,
interval_delay=test_params.sink_delay,
)
Expand Down

0 comments on commit 0713f66

Please sign in to comment.