Skip to content
This repository has been archived by the owner on Aug 20, 2024. It is now read-only.

Add a data structure for memory conf reading and writing #1041

Merged
merged 2 commits into from
Mar 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions src/main/scala/firrtl/passes/memlib/MemConf.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// See LICENSE for license details.

package firrtl.passes
package memlib

import scala.util.matching._

sealed abstract class MemPort(val name: String) { override def toString = name }

case object ReadPort extends MemPort("read")
case object WritePort extends MemPort("write")
case object MaskedWritePort extends MemPort("mwrite")
case object ReadWritePort extends MemPort("rw")
case object MaskedReadWritePort extends MemPort("mrw")

object MemPort {

val all = Set(ReadPort, WritePort, MaskedWritePort, ReadWritePort, MaskedReadWritePort)

def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s)

def fromString(s: String): Map[MemPort, Int] = {
s.split(",").toSeq.map(MemPort.apply).map(_ match {
case Some(x) => x
case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
}).groupBy(identity).mapValues(_.size)
}
}

case class MemConf(
name: String,
depth: Int,
width: Int,
ports: Map[MemPort, Int],
maskGranularity: Option[Int]
) {

private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",")
private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("")

// Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes
// (otherwise an entry of XYZPort -> 0 would not be equivalent to another with no XYZPort despite being semantically the same)
ports.foreach { case (k, v) => require(v > 0, "Cannot have negative or zero entry in the port map") }

override def toString = s"name ${name} depth ${depth} width ${width} ports ${portsStr} ${maskGranStr} \n"
}

object MemConf {

val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r

def fromString(s: String): Seq[MemConf] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scaladoc indicating that this takes in a string containing multiple lines delimited by \n?

s.split("\n").toSeq.map(_ match {
case MemConf.regex(name, depth, width, ports, maskGran) => MemConf(name, depth.toInt, width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))
case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
})
}

def apply(name: String, depth: Int, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = {
val ports: Map[MemPort, Int] = (if (maskGranularity.isEmpty) {
(if (writePorts == 0) Map.empty[MemPort, Int] else Map(WritePort -> writePorts)) ++
(if (readWritePorts == 0) Map.empty[MemPort, Int] else Map(ReadWritePort -> readWritePorts))
} else {
(if (writePorts == 0) Map.empty[MemPort, Int] else Map(MaskedWritePort -> writePorts)) ++
(if (readWritePorts == 0) Map.empty[MemPort, Int] else Map(MaskedReadWritePort -> readWritePorts))
}) ++ (if (readPorts == 0) Map.empty[MemPort, Int] else Map(ReadPort -> readPorts))
return new MemConf(name, depth, width, ports, maskGranularity)
}
}
14 changes: 5 additions & 9 deletions src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,11 @@ class ConfWriter(filename: String) {
val outputBuffer = new CharArrayWriter
def append(m: DefAnnotatedMemory) = {
// legacy
val maskGran = m.maskGran
val readers = List.fill(m.readers.length)("read")
val writers = List.fill(m.writers.length)(if (maskGran.isEmpty) "write" else "mwrite")
val readwriters = List.fill(m.readwriters.length)(if (maskGran.isEmpty) "rw" else "mrw")
val ports = (writers ++ readers ++ readwriters) mkString ","
val maskGranConf = maskGran match { case None => "" case Some(p) => s"mask_gran $p" }
val width = bitWidth(m.dataType)
val conf = s"name ${m.name} depth ${m.depth} width $width ports $ports $maskGranConf \n"
outputBuffer.append(conf)
// assert that we don't overflow going from BigInt to Int conversion
require(bitWidth(m.dataType) <= Int.MaxValue)
m.maskGran.foreach { case x => require(x <= Int.MaxValue) }
val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt))
outputBuffer.append(conf.toString)
}
def serialize() = {
val outputFile = new PrintWriter(filename)
Expand Down
61 changes: 57 additions & 4 deletions src/test/scala/firrtlTests/ReplSeqMemTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
}
)

def checkMemConf(filename: String, mems: Set[MemConf]) {
// Read the mem conf
val file = scala.io.Source.fromFile(filename)
val text = try file.mkString finally file.close()
// Verify that this does not throw an exception
val fromConf = MemConf.fromString(text)
// Verify the mems in the conf are the same as the expected ones
require(Set(fromConf: _*) == mems, "Parsed conf set:\n {\n " + fromConf.mkString(" ") + " }\n must be the same as reference conf set: \n {\n " + mems.toSeq.mkString(" ") + " }\n")
}

"ReplSeqMem" should "generate blackbox wrappers for mems of bundle type" in {
val input = """
circuit Top :
Expand Down Expand Up @@ -63,11 +73,17 @@ circuit Top :
read mport R1 = entries_info2[head_ptr], clock
io2.commit_entry.bits.info <- R1
""".stripMargin
val mems = Set(
MemConf("entries_info_ext", 24, 30, Map(WritePort -> 1, ReadPort -> 1), None),
MemConf("entries_info2_ext", 24, 30, Map(MaskedWritePort -> 1, ReadPort -> 1), Some(10))
)
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

Expand All @@ -85,11 +101,14 @@ circuit Top :
when p_valid :
write mport T_155 = mem[p_address], clock
""".stripMargin
val mems = Set(MemConf("mem_ext", 32, 64, Map(MaskedWritePort -> 1), Some(64)))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

Expand All @@ -110,11 +129,14 @@ circuit CustomMemory :
_T_18 <= io.dI
skip
""".stripMargin
val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

Expand All @@ -135,11 +157,14 @@ circuit CustomMemory :
_T_18 <= io.dI
skip
""".stripMargin
val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

Expand Down Expand Up @@ -188,6 +213,7 @@ circuit Top :
tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) }

}

"ReplSeqMem" should "not de-duplicate memories with the nodedupe annotation " in {
val input = """
circuit CustomMemory :
Expand All @@ -209,6 +235,10 @@ circuit CustomMemory :
_T_20 <= io.dI
skip
"""
val mems = Set(
MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None),
MemConf("mem_1_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)
)
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(
ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
Expand All @@ -221,6 +251,8 @@ circuit CustomMemory :
case _ => false
}
numExtMods should be (2)
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

Expand Down Expand Up @@ -249,6 +281,10 @@ circuit CustomMemory :
_T_22 <= io.dI
skip
"""
val mems = Set(
MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None),
MemConf("mem_1_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)
)
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(
ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
Expand All @@ -261,6 +297,8 @@ circuit CustomMemory :
case _ => false
}
numExtMods should be (2)
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

Expand Down Expand Up @@ -300,6 +338,10 @@ circuit CustomMemory :
w1 <= io.dI
w2 <= io.dI
"""
val mems = Set(
MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None),
MemConf("mem_0_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)
)
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(
ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
Expand All @@ -316,6 +358,8 @@ circuit CustomMemory :
// If the NoDedupMemAnnotation were handled incorrectly as it was prior to this test, there
// would be 3 ExtModules
numExtMods should be (2)
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

Expand All @@ -340,6 +384,7 @@ circuit CustomMemory :
_T_20 <= io.dI
skip
"""
val mems = Set(MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
Expand All @@ -353,7 +398,7 @@ circuit CustomMemory :
(new java.io.File(confLoc)).delete()
}

"ReplSeqMem" should "should not have a mask if there is none" in {
"ReplSeqMem" should "not have a mask if there is none" in {
val input = """
circuit CustomMemory :
module CustomMemory :
Expand All @@ -368,14 +413,17 @@ circuit CustomMemory :
write mport w = mem[io.waddr], clock
w <= io.wdata
"""
val mems = Set(MemConf("mem_ext", 1024, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
res.getEmittedCircuit.value shouldNot include ("mask")
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

"ReplSeqMem" should "should not conjoin enable signal with mask condition" in {
"ReplSeqMem" should "not conjoin enable signal with mask condition" in {
val input = """
circuit CustomMemory :
module CustomMemory :
Expand All @@ -393,16 +441,19 @@ circuit CustomMemory :
when io.mask[1] :
w[1] <= io.wdata[1]
"""
val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedWritePort -> 1, ReadPort -> 1), Some(8)))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
res should containLine ("mem.W0_mask_0 <= validif(io_en, io_mask_0)")
res should containLine ("mem.W0_mask_1 <= validif(io_en, io_mask_1)")
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}

"ReplSeqMem" should "should not conjoin enable signal with wmask condition (RW Port)" in {
"ReplSeqMem" should "not conjoin enable signal with wmask condition (RW Port)" in {
val input = """
circuit CustomMemory :
module CustomMemory :
Expand All @@ -424,16 +475,18 @@ circuit CustomMemory :
io.out <= r

"""
val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedReadWritePort -> 1), Some(8)))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
InferReadWriteAnnotation)
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
res should containLine ("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)")
res should containLine ("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)")
// Check the emitted conf
checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
}

// TODO: make more checks
// conf