From e63e0231e0b5e18a168f45d62207bcb7304c8ffc Mon Sep 17 00:00:00 2001 From: chick Date: Wed, 25 Oct 2017 18:54:24 -0700 Subject: [PATCH] Update node widths with Firrtl Transform Working transform that will adjust widths of registers, ports and wires through the annotation system. Second piece of the augmented tool chain that will ultimately take advantage of firrt interpreters instrumentation output. Adjusting widths according to data gathered thereby. Part of Issue #114 --- .../resizer/ChangeWidthTransform.scala | 140 ++++++++++++++++++ .../resizer/ChangeWidthTransformSpec.scala | 58 ++++++++ 2 files changed, 198 insertions(+) create mode 100644 src/main/scala/dsptools/numbers/resizer/ChangeWidthTransform.scala create mode 100644 src/test/scala/dsptools/resizer/ChangeWidthTransformSpec.scala diff --git a/src/main/scala/dsptools/numbers/resizer/ChangeWidthTransform.scala b/src/main/scala/dsptools/numbers/resizer/ChangeWidthTransform.scala new file mode 100644 index 00000000..9c923034 --- /dev/null +++ b/src/main/scala/dsptools/numbers/resizer/ChangeWidthTransform.scala @@ -0,0 +1,140 @@ +// See LICENSE for license details. + +package dsptools.numbers.resizer + +import firrtl.{CircuitForm, CircuitState, LowForm, Transform} +import firrtl.annotations.{Annotation, Named} +import firrtl.ir._ +import firrtl.Mappers._ +import logger.{LazyLogging, LogLevel, Logger} + +object ChangeWidthAnnotation { + def apply(target: Named, value: String): Annotation = Annotation(target, classOf[ChangeWidthTransform], value) + + def unapply(a: Annotation): Option[(Named, String)] = a match { + case Annotation(named, t, value) if t == classOf[ChangeWidthTransform] => Some((named, value)) + case _ => None + } +} + +class ChangeWidthTransform extends Transform with LazyLogging { + override def inputForm: CircuitForm = LowForm + override def outputForm: CircuitForm = LowForm + + + def makeChangeRequests(annotations: Seq[Annotation]): Map[String, ChangeRequest] = { + annotations.map { annotation => + val componentName :: widthString :: _ = annotation.value.split("""=""", 2).toList + componentName -> ChangeRequest(componentName, BigInt(widthString, 10)) + }.toMap + } + + //scalastyle:off method.length cyclomatic.complexity + private def run(c: Circuit, changeRequests: Map[String, ChangeRequest]): Circuit = { + def findModule(name: String): DefModule = { + c.modules.find(module => module.name == name) match { + case Some(m: Module) => m + case Some(m: ExtModule) => m + case _ => + throw new Exception(s"Error: could not fine $name in $c") + } + } + + def changeTpe(originalType: Type, newWidth: BigInt): Type = { + originalType match { + case SIntType(IntWidth(n)) => + val newType = SIntType(IntWidth(newWidth)) + logger.info(s"Changing $originalType to $newType") + newType + case UIntType(IntWidth(n)) => + val newType = UIntType(IntWidth(newWidth)) + logger.info(s"Changing $originalType to $newType") + newType + case other => other + } + } + + def changeWidthsInModule(module: Module, pathString: String = ""): Module = { + def expand(name: String): String = { + if(pathString.isEmpty) { + name + } + else { + pathString + "." + name + } + } + + def shouldChange(name: String): Boolean = { + changeRequests.contains(name) + } + + def annotationToWidth(annotation: Annotation): Width = { + //TODO (chick) complete this + IntWidth(32) + } + + def changeWidthsInExpression(expression: Expression): Expression = { + expression + } + + def changeWidthInPorts(ports: Seq[Port]): Seq[Port] = { + ports.map { port => + changeRequests.get(expand(port.name)) match { + case Some(changeRequest) => + port.copy(tpe = changeTpe(port.tpe, changeRequest.newWidth)) + case _ => + port + } + } + } + + def changeWidthsInStatement(statement: Statement): Statement = { + val resultStatement = statement map changeWidthsInStatement map changeWidthsInExpression + resultStatement match { + case register: DefRegister => + changeRequests.get(expand(register.name)) match { + case Some(changeReqest) => + register.copy(tpe = changeTpe(register.tpe, changeReqest.newWidth)) + case _ => register + } + case wire: DefWire => + changeRequests.get(expand(wire.name)) match { + case Some(changeReqest) => + wire.copy(tpe = changeTpe(wire.tpe, changeReqest.newWidth)) + case _ => wire + } + case instance: DefInstance => findModule(instance.module) match { + case m: ExtModule => instance + case m: Module => + changeWidthsInModule(m, s"$pathString.${module.name}.") + instance + } + case otherStatement => otherStatement + } + } + + module.copy( + ports = changeWidthInPorts(module.ports), + body = changeWidthsInStatement(module.body) + ) + } + + val modulesx = c.modules.map { + case m: ExtModule => m + case m: Module => changeWidthsInModule(m) + } + Circuit(c.info, modulesx, c.main) + } + + override def execute(state: CircuitState): CircuitState = { + Logger.setLevel(LogLevel.Debug) + getMyAnnotations(state) match { + case Nil => state + case myAnnotations => + val changeRequests = makeChangeRequests(myAnnotations) + state.copy(circuit = run(state.circuit, changeRequests)) + } + } +} + +case class ChangeRequest(name: String, newWidth: BigInt) diff --git a/src/test/scala/dsptools/resizer/ChangeWidthTransformSpec.scala b/src/test/scala/dsptools/resizer/ChangeWidthTransformSpec.scala new file mode 100644 index 00000000..7638e62c --- /dev/null +++ b/src/test/scala/dsptools/resizer/ChangeWidthTransformSpec.scala @@ -0,0 +1,58 @@ +// See LICENSE for license details. + +package dsptools.resizer + +import dsptools.numbers.resizer.ChangeWidthTransform +import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName} +import firrtl.{AnnotationMap, CircuitState, LowForm, Parser} +import org.scalatest.{FreeSpec, Matchers} + +class ChangeWidthTransformSpec extends FreeSpec with Matchers { + """parse a firrtl file and change the widths""" in { + val input = + """ + |circuit InstrumentingAdder : @[:@2.0] + | module InstrumentingAdder : @[:@3.2] + | input clock : Clock @[:@4.4] + | input reset : UInt<1> @[:@5.4] + | input io_a1 : SInt<32> @[:@6.4] + | input io_a2 : SInt<32> @[:@6.4] + | output io_c : SInt<32> @[:@6.4] + | + | reg register1 : SInt<32>, clock with : + | reset => (UInt<1>("h0"), register1) @[InstrumentingSpec.scala 20:22:@11.4] + | node _T_6 = add(io_a1, io_a2) @[FixedPointTypeClass.scala 21:58:@12.4] + | node _T_7 = tail(_T_6, 1) @[FixedPointTypeClass.scala 21:58:@13.4] + | node _T_8 = asSInt(_T_7) @[FixedPointTypeClass.scala 21:58:@14.4] + | io_c <= register1 + | register1 <= _T_8 + | + """.stripMargin + + val annotations = AnnotationMap(Seq( + Annotation( + ComponentName("io_a1", ModuleName("InstrumentingAdder", CircuitName("InstrumentingAdder"))), + classOf[ChangeWidthTransform], + "io_a1=16" + ), + Annotation( + ComponentName("io_a1", ModuleName("register1", CircuitName("InstrumentingAdder"))), + classOf[ChangeWidthTransform], + "register1=8" + ) + )) + + val circuitState = CircuitState(Parser.parse(input), LowForm, Some(annotations)) + + val transform = new ChangeWidthTransform + + val newCircuitState = transform.execute(circuitState) + + val newFirrtlString = newCircuitState.circuit.serialize + + newFirrtlString should include ("input io_a1 : SInt<16>") + newFirrtlString should include ("register1 : SInt<8>") + + println(s"After ChangeWidthTransform\n$newFirrtlString") + } +}