Skip to content

Commit

Permalink
[test] Fix FP units rdy signals
Browse files Browse the repository at this point in the history
  • Loading branch information
tancheng committed Jan 3, 2025
1 parent 3ce0af4 commit 418e753
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 250 deletions.
132 changes: 0 additions & 132 deletions fu/flexible/translate/FlexibleFuRTL_test.py

This file was deleted.

Empty file removed fu/flexible/translate/__init__.py
Empty file.
14 changes: 8 additions & 6 deletions fu/float/FpAddRTL.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class FpAddRTL(Fu):
def construct(s, DataType, PredicateType, CtrlType,
num_inports, num_outports, data_mem_size, exp_nbits = 4,
sig_nbits = 11):

super(FpAddRTL, s).construct(DataType, PredicateType, CtrlType,
num_inports, num_outports,
data_mem_size)
Expand All @@ -36,24 +37,25 @@ def construct(s, DataType, PredicateType, CtrlType,
num_entries = 2
FuInType = mk_bits(clog2(num_inports + 1))
CountType = mk_bits(clog2(num_entries + 1))

# TODO: parameterize rounding mode
s.rounding_mode = 0b000
s.FLOATING_ONE = concat(
b1(0), mk_bits(exp_nbits)(2**(exp_nbits-1)-1),
mk_bits(sig_nbits)() )

# Components
s.fadd = AddFN( exp_nbits+1, sig_nbits )
s.fadd = AddFN(exp_nbits + 1, sig_nbits)
s.fadd.roundingMode //= s.rounding_mode
s.fadd.subOp //= lambda: s.recv_opt.msg.ctrl == OPT_FSUB

# Wires
s.in0 = Wire( FuInType )
s.in1 = Wire( FuInType )
s.in0 = Wire(FuInType)
s.in1 = Wire(FuInType)

idx_nbits = clog2( num_inports )
s.in0_idx = Wire( idx_nbits )
s.in1_idx = Wire( idx_nbits )
idx_nbits = clog2(num_inports)
s.in0_idx = Wire(idx_nbits)
s.in1_idx = Wire(idx_nbits)

s.in0_idx //= s.in0[0:idx_nbits]
s.in1_idx //= s.in1[0:idx_nbits]
Expand Down
22 changes: 10 additions & 12 deletions fu/float/FpMulRTL.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def construct(s, DataType, PredicateType, CtrlType,
s.rounding_mode = 0b000

# Components
s.fmul = MulFN(exp_nbits+1, sig_nbits)
s.fmul = MulFN(exp_nbits + 1, sig_nbits)
s.fmul.roundingMode //= s.rounding_mode

# Wires
Expand All @@ -61,29 +61,27 @@ def construct(s, DataType, PredicateType, CtrlType,
@update
def comb_logic():

s.recv_all_val @= 0
# For pick input register
s.in0 @= 0
s.in1 @= 0
for i in range( num_inports ):
for i in range(num_inports):
s.recv_in[i].rdy @= b1(0)

for i in range( num_outports ):
s.send_out[i].en @= s.recv_opt.en
for i in range(num_outports):
s.send_out[i].val @= 0
s.send_out[i].msg @= DataType()

s.recv_const.rdy @= 0
s.recv_predicate.rdy @= b1(0)
s.recv_opt.rdy @= 0

if s.recv_opt.en:
if s.recv_opt.val & s.send_out[0].rdy:
if s.recv_opt.msg.fu_in[0] != 0:
s.in0 @= zext(s.recv_opt.msg.fu_in[0] - 1, FuInType)
if s.recv_opt.msg.fu_in[1] != 0:
s.in1 @= zext(s.recv_opt.msg.fu_in[1] - 1, FuInType)

s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
s.recv_in[s.in1_idx].msg.predicate

if s.recv_opt.val:
if s.recv_opt.msg.ctrl == OPT_FMUL:
s.fmul.a @= s.recv_in[s.in0_idx].msg.payload
Expand All @@ -105,16 +103,16 @@ def comb_logic():
s.send_out[0].msg.predicate @= s.recv_in[s.in0_idx].msg.predicate & \
(~s.recv_opt.msg.predicate | \
s.recv_predicate.msg.predicate)
s.recv_all_val @= s.recv_in[s.in0_idx].val & s.recv_in[s.in1_idx].val & \
s.recv_all_val @= s.recv_in[s.in0_idx].val & \
((s.recv_opt.msg.predicate == b1(0)) | s.recv_predicate.val)
s.send_out[0].val @= s.recv_all_val
s.recv_in[s.in0_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_in[s.in1_idx].rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_const.rdy @= s.recv_all_val & s.send_out[0].rdy
s.recv_opt.rdy @= s.recv_all_val & s.send_out[0].rdy

else:
for j in range( num_outports ):
s.send_out[j].val @= b1( 0 )
for j in range(num_outports):
s.send_out[j].val @= b1(0)
s.recv_opt.rdy @= 0
s.recv_in[s.in0_idx].rdy @= 0
s.recv_in[s.in1_idx].rdy @= 0
Expand Down
96 changes: 47 additions & 49 deletions fu/float/test/FpAddRTL_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Date : Aug 8, 2023
"""


from pymtl3 import *
from pymtl3.stdlib.test_utils import (run_sim,
config_model_with_cmdline_opts)
Expand All @@ -21,50 +20,49 @@
from ....lib.opt_type import *
from ....mem.const.ConstQueueRTL import ConstQueueRTL


round_near_even = 0b000

def test_elaborate( cmdline_opts ):
DataType = mk_data( 16, 1 )
PredicateType = mk_predicate( 1, 1 )
def test_elaborate(cmdline_opts):
DataType = mk_data(16, 1)
PredType = mk_predicate(1, 1)
ConfigType = mk_ctrl()
data_mem_size = 8
num_inports = 2
num_outports = 1
FuInType = mk_bits( clog2( num_inports + 1 ) )
pickRegister = [ FuInType( x+1 ) for x in range( num_inports ) ]
src_in0 = [ DataType(1, 1), DataType(7, 1), DataType(4, 1) ]
src_in1 = [ DataType(2, 1), DataType(3, 1), DataType(1, 1) ]
src_predicate = [ PredicateType(1, 0), PredicateType(1, 0), PredicateType(1, 1) ]
src_const = [ DataType(5, 1), DataType(0, 0), DataType(7, 1) ]
FuInType = mk_bits(clog2(num_inports + 1))
pick_register = [ FuInType(x + 1) for x in range(num_inports) ]
src_in0 = [ DataType(1, 1), DataType(7, 1), DataType(4, 1) ]
src_in1 = [ DataType(3, 1), ]
src_predicate = [ PredType(1, 0), PredType(1, 0), PredType(1, 1) ]
src_const = [ DataType(5, 1), DataType(7, 1) ]
sink_out = [ DataType(6, 0), DataType(4, 0), DataType(11, 1) ]
src_opt = [ ConfigType( OPT_ADD_CONST, b1( 1 ), pickRegister ),
ConfigType( OPT_SUB, b1( 1 ), pickRegister ),
ConfigType( OPT_ADD_CONST, b1( 1 ), pickRegister ) ]
dut = FpAddRTL( DataType, PredicateType, ConfigType, num_inports,
num_outports, data_mem_size, exp_nbits = 4, sig_nbits = 11 )
dut = config_model_with_cmdline_opts( dut, cmdline_opts, duts=[] )
src_opt = [ ConfigType( OPT_ADD_CONST, b1(1), pick_register ),
ConfigType( OPT_SUB, b1(1), pick_register ),
ConfigType( OPT_ADD_CONST, b1(1), pick_register ) ]
dut = FpAddRTL(DataType, PredType, ConfigType, num_inports,
num_outports, data_mem_size, exp_nbits = 4, sig_nbits = 11)
dut = config_model_with_cmdline_opts(dut, cmdline_opts, duts = [])

#-------------------------------------------------------------------------
# Test harness
#-------------------------------------------------------------------------

class TestHarness( Component ):
class TestHarness(Component):

def construct( s, FunctionUnit, DataType, PredicateType, ConfigType,
def construct(s, FunctionUnit, DataType, PredType, ConfigType,
num_inports, num_outports, data_mem_size,
src0_msgs, src1_msgs, src_predicate, src_const,
ctrl_msgs, sink_msgs ):
ctrl_msgs, sink_msgs):

s.src_in0 = TestSrcRTL( DataType, src0_msgs )
s.src_in1 = TestSrcRTL( DataType, src1_msgs )
s.src_predicate = TestSrcRTL( PredicateType, src_predicate )
s.src_opt = TestSrcRTL( ConfigType, ctrl_msgs )
s.sink_out = TestSinkRTL( DataType, sink_msgs )
s.src_in0 = TestSrcRTL ( DataType, src0_msgs )
s.src_in1 = TestSrcRTL ( DataType, src1_msgs )
s.src_predicate = TestSrcRTL ( PredType, src_predicate )
s.src_opt = TestSrcRTL ( ConfigType, ctrl_msgs )
s.sink_out = TestSinkRTL( DataType, sink_msgs )

s.const_queue = ConstQueueRTL( DataType, src_const )
s.dut = FunctionUnit( DataType, PredicateType, ConfigType,
num_inports, num_outports, data_mem_size )
s.const_queue = ConstQueueRTL(DataType, src_const)
s.dut = FunctionUnit(DataType, PredType, ConfigType,
num_inports, num_outports, data_mem_size)

connect( s.src_in0.send, s.dut.recv_in[0] )
connect( s.src_in1.send, s.dut.recv_in[1] )
Expand All @@ -73,43 +71,43 @@ def construct( s, FunctionUnit, DataType, PredicateType, ConfigType,
connect( s.src_opt.send, s.dut.recv_opt )
connect( s.dut.send_out[0], s.sink_out.recv )

def done( s ):
def done(s):
return s.src_in0.done() and s.src_in1.done() and \
s.src_opt.done() and s.sink_out.done()

def line_trace( s ):
def line_trace(s):
return s.dut.line_trace()

def mk_float_to_bits_fn( DataType, exp_nbits = 4, sig_nbits = 11 ):
def mk_float_to_bits_fn(DataType, exp_nbits = 4, sig_nbits = 11):
return lambda f_value, predicate: (
DataType( floatToFN( f_value,
precision = 1 + exp_nbits + sig_nbits ),
predicate ) )
DataType(floatToFN(f_value,
precision = 1 + exp_nbits + sig_nbits),
predicate))

def test_add_basic():
FU = FpAddRTL
exp_nbits = 4
sig_nbits = 11
DataType = mk_data( 1 + exp_nbits + sig_nbits, 1 )
f2b = mk_float_to_bits_fn( DataType, exp_nbits, sig_nbits )
PredicateType = mk_predicate( 1, 1 )
DataType = mk_data(1 + exp_nbits + sig_nbits, 1)
f2b = mk_float_to_bits_fn(DataType, exp_nbits, sig_nbits)
PredType = mk_predicate(1, 1)
ConfigType = mk_ctrl()
data_mem_size = 8
num_inports = 2
num_outports = 1
FuInType = mk_bits( clog2( num_inports + 1 ) )
pickRegister = [ FuInType( x+1 ) for x in range( num_inports ) ]
src_in0 = [ f2b(1.1, 1), f2b(7.7, 1), f2b(4.4, 1) ]
src_in1 = [ f2b(2.2, 1), f2b(3.3, 1), f2b(1.1, 1) ]
src_predicate = [ PredicateType(1, 0), PredicateType(1, 0), PredicateType(1, 1) ]
src_const = [ f2b(5.5, 1), f2b(0, 0), f2b(7.7, 1) ]
sink_out = [ f2b(6.602, 0), f2b(4.4, 0), f2b(12.1, 1) ] # 6.6 -> 6.602
src_opt = [ ConfigType( OPT_FADD_CONST, b1( 1 ), pickRegister ),
ConfigType( OPT_FSUB, b1( 1 ), pickRegister ),
ConfigType( OPT_FADD_CONST, b1( 1 ), pickRegister ) ]
th = TestHarness( FU, DataType, PredicateType, ConfigType,
FuInType = mk_bits(clog2(num_inports + 1))
pick_register = [ FuInType(x + 1) for x in range(num_inports) ]
src_predicate = [ PredType(1, 0), PredType(1, 0), PredType(1, 1) ]
src_in0 = [ f2b(1.1, 1), f2b(7.7, 1), f2b(4.4, 1) ]
src_in1 = [ f2b(3.3, 1), ]
src_const = [ f2b(5.5, 1), f2b(7.7, 1) ]
sink_out = [ f2b(6.602, 0), f2b(4.4, 0), f2b(12.1, 1) ] # 6.6 -> 6.602
src_opt = [ ConfigType( OPT_FADD_CONST, b1(1), pick_register ),
ConfigType( OPT_FSUB, b1(1), pick_register ),
ConfigType( OPT_FADD_CONST, b1(1), pick_register ) ]
th = TestHarness( FU, DataType, PredType, ConfigType,
num_inports, num_outports, data_mem_size,
src_in0, src_in1, src_predicate, src_const, src_opt,
sink_out )
run_sim( th )
run_sim(th)

Loading

0 comments on commit 418e753

Please sign in to comment.