diff --git a/src/main/scala/dsptools/numbers/resizer/ChangeWidthTransform.scala b/src/main/scala/dsptools/numbers/resizer/ChangeWidthTransform.scala new file mode 100644 index 00000000..9c923034 --- /dev/null +++ b/src/main/scala/dsptools/numbers/resizer/ChangeWidthTransform.scala @@ -0,0 +1,140 @@ +// See LICENSE for license details. + +package dsptools.numbers.resizer + +import firrtl.{CircuitForm, CircuitState, LowForm, Transform} +import firrtl.annotations.{Annotation, Named} +import firrtl.ir._ +import firrtl.Mappers._ +import logger.{LazyLogging, LogLevel, Logger} + +object ChangeWidthAnnotation { + def apply(target: Named, value: String): Annotation = Annotation(target, classOf[ChangeWidthTransform], value) + + def unapply(a: Annotation): Option[(Named, String)] = a match { + case Annotation(named, t, value) if t == classOf[ChangeWidthTransform] => Some((named, value)) + case _ => None + } +} + +class ChangeWidthTransform extends Transform with LazyLogging { + override def inputForm: CircuitForm = LowForm + override def outputForm: CircuitForm = LowForm + + + def makeChangeRequests(annotations: Seq[Annotation]): Map[String, ChangeRequest] = { + annotations.map { annotation => + val componentName :: widthString :: _ = annotation.value.split("""=""", 2).toList + componentName -> ChangeRequest(componentName, BigInt(widthString, 10)) + }.toMap + } + + //scalastyle:off method.length cyclomatic.complexity + private def run(c: Circuit, changeRequests: Map[String, ChangeRequest]): Circuit = { + def findModule(name: String): DefModule = { + c.modules.find(module => module.name == name) match { + case Some(m: Module) => m + case Some(m: ExtModule) => m + case _ => + throw new Exception(s"Error: could not fine $name in $c") + } + } + + def changeTpe(originalType: Type, newWidth: BigInt): Type = { + originalType match { + case SIntType(IntWidth(n)) => + val newType = SIntType(IntWidth(newWidth)) + logger.info(s"Changing $originalType to $newType") + newType + case UIntType(IntWidth(n)) => + val newType = UIntType(IntWidth(newWidth)) + logger.info(s"Changing $originalType to $newType") + newType + case other => other + } + } + + def changeWidthsInModule(module: Module, pathString: String = ""): Module = { + def expand(name: String): String = { + if(pathString.isEmpty) { + name + } + else { + pathString + "." + name + } + } + + def shouldChange(name: String): Boolean = { + changeRequests.contains(name) + } + + def annotationToWidth(annotation: Annotation): Width = { + //TODO (chick) complete this + IntWidth(32) + } + + def changeWidthsInExpression(expression: Expression): Expression = { + expression + } + + def changeWidthInPorts(ports: Seq[Port]): Seq[Port] = { + ports.map { port => + changeRequests.get(expand(port.name)) match { + case Some(changeRequest) => + port.copy(tpe = changeTpe(port.tpe, changeRequest.newWidth)) + case _ => + port + } + } + } + + def changeWidthsInStatement(statement: Statement): Statement = { + val resultStatement = statement map changeWidthsInStatement map changeWidthsInExpression + resultStatement match { + case register: DefRegister => + changeRequests.get(expand(register.name)) match { + case Some(changeReqest) => + register.copy(tpe = changeTpe(register.tpe, changeReqest.newWidth)) + case _ => register + } + case wire: DefWire => + changeRequests.get(expand(wire.name)) match { + case Some(changeReqest) => + wire.copy(tpe = changeTpe(wire.tpe, changeReqest.newWidth)) + case _ => wire + } + case instance: DefInstance => findModule(instance.module) match { + case m: ExtModule => instance + case m: Module => + changeWidthsInModule(m, s"$pathString.${module.name}.") + instance + } + case otherStatement => otherStatement + } + } + + module.copy( + ports = changeWidthInPorts(module.ports), + body = changeWidthsInStatement(module.body) + ) + } + + val modulesx = c.modules.map { + case m: ExtModule => m + case m: Module => changeWidthsInModule(m) + } + Circuit(c.info, modulesx, c.main) + } + + override def execute(state: CircuitState): CircuitState = { + Logger.setLevel(LogLevel.Debug) + getMyAnnotations(state) match { + case Nil => state + case myAnnotations => + val changeRequests = makeChangeRequests(myAnnotations) + state.copy(circuit = run(state.circuit, changeRequests)) + } + } +} + +case class ChangeRequest(name: String, newWidth: BigInt) diff --git a/src/test/scala/dsptools/resizer/ChangeWidthTransformSpec.scala b/src/test/scala/dsptools/resizer/ChangeWidthTransformSpec.scala new file mode 100644 index 00000000..7638e62c --- /dev/null +++ b/src/test/scala/dsptools/resizer/ChangeWidthTransformSpec.scala @@ -0,0 +1,58 @@ +// See LICENSE for license details. + +package dsptools.resizer + +import dsptools.numbers.resizer.ChangeWidthTransform +import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName} +import firrtl.{AnnotationMap, CircuitState, LowForm, Parser} +import org.scalatest.{FreeSpec, Matchers} + +class ChangeWidthTransformSpec extends FreeSpec with Matchers { + """parse a firrtl file and change the widths""" in { + val input = + """ + |circuit InstrumentingAdder : @[:@2.0] + | module InstrumentingAdder : @[:@3.2] + | input clock : Clock @[:@4.4] + | input reset : UInt<1> @[:@5.4] + | input io_a1 : SInt<32> @[:@6.4] + | input io_a2 : SInt<32> @[:@6.4] + | output io_c : SInt<32> @[:@6.4] + | + | reg register1 : SInt<32>, clock with : + | reset => (UInt<1>("h0"), register1) @[InstrumentingSpec.scala 20:22:@11.4] + | node _T_6 = add(io_a1, io_a2) @[FixedPointTypeClass.scala 21:58:@12.4] + | node _T_7 = tail(_T_6, 1) @[FixedPointTypeClass.scala 21:58:@13.4] + | node _T_8 = asSInt(_T_7) @[FixedPointTypeClass.scala 21:58:@14.4] + | io_c <= register1 + | register1 <= _T_8 + | + """.stripMargin + + val annotations = AnnotationMap(Seq( + Annotation( + ComponentName("io_a1", ModuleName("InstrumentingAdder", CircuitName("InstrumentingAdder"))), + classOf[ChangeWidthTransform], + "io_a1=16" + ), + Annotation( + ComponentName("io_a1", ModuleName("register1", CircuitName("InstrumentingAdder"))), + classOf[ChangeWidthTransform], + "register1=8" + ) + )) + + val circuitState = CircuitState(Parser.parse(input), LowForm, Some(annotations)) + + val transform = new ChangeWidthTransform + + val newCircuitState = transform.execute(circuitState) + + val newFirrtlString = newCircuitState.circuit.serialize + + newFirrtlString should include ("input io_a1 : SInt<16>") + newFirrtlString should include ("register1 : SInt<8>") + + println(s"After ChangeWidthTransform\n$newFirrtlString") + } +}