Skip to content

Commit

Permalink
Merge pull request #1183 from freechipsproject/mem-read-under-write
Browse files Browse the repository at this point in the history
Implement read-first memory behavior in Verilog
  • Loading branch information
albert-magyar authored Oct 1, 2019
2 parents 4ca2b85 + 082bc99 commit 1ced6cf
Show file tree
Hide file tree
Showing 15 changed files with 197 additions and 50 deletions.
2 changes: 1 addition & 1 deletion src/main/antlr4/FIRRTL.g4
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ stmt
| 'reg' id ':' type exp ('with' ':' reset_block)? info?
| 'mem' id ':' info? INDENT memField* DEDENT
| 'cmem' id ':' type info?
| 'smem' id ':' type info?
| 'smem' id ':' type ruw? info?
| mdir 'mport' id '=' id '[' exp ']' exp info?
| 'inst' id 'of' id info?
| 'node' id '=' exp info?
Expand Down
8 changes: 8 additions & 0 deletions src/main/proto/firrtl.proto
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ message Firrtl {
Expression init = 5;
}

enum ReadUnderWrite {
UNDEFINED = 0;
OLD = 1;
NEW = 2;
}

message Memory {
// Required.
string id = 1;
Expand All @@ -121,6 +127,7 @@ message Firrtl {
repeated string reader_id = 6;
repeated string writer_id = 7;
repeated string readwriter_id = 8;
ReadUnderWrite read_under_write = 10;
}

message CMemory {
Expand All @@ -138,6 +145,7 @@ message Firrtl {
}
// Required.
bool sync_read = 3;
ReadUnderWrite read_under_write = 5;
}

message Instance {
Expand Down
21 changes: 15 additions & 6 deletions src/main/scala/firrtl/Visitor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,23 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
private def visitSuite(ctx: SuiteContext): Statement =
Block(ctx.simple_stmt().asScala.flatMap(x => Option(x.stmt).map(visitStmt)))

private def visitRuw(ctx: Option[RuwContext]): ReadUnderWrite.Value = ctx match {
case None => ReadUnderWrite.Undefined
case Some(ctx) => ctx.getText match {
case "undefined" => ReadUnderWrite.Undefined
case "old" => ReadUnderWrite.Old
case "new" => ReadUnderWrite.New
}
}

// Memories are fairly complicated to translate thus have a dedicated method
private def visitMem(ctx: StmtContext): Statement = {
val readers = mutable.ArrayBuffer.empty[String]
val writers = mutable.ArrayBuffer.empty[String]
val readwriters = mutable.ArrayBuffer.empty[String]
case class ParamValue(typ: Option[Type] = None, lit: Option[BigInt] = None, ruw: Option[String] = None, unique: Boolean = true)
case class ParamValue(typ: Option[Type] = None, lit: Option[BigInt] = None, ruw: ReadUnderWrite.Value = ReadUnderWrite.Undefined, unique: Boolean = true)
val fieldMap = mutable.HashMap[String, ParamValue]()

val memName = ctx.id(0).getText
def parseMemFields(memFields: Seq[MemFieldContext]): Unit =
memFields.foreach { field =>
val fieldName = field.children.asScala(0).getText
Expand All @@ -184,7 +192,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
case _ =>
val paramDef = fieldName match {
case "data-type" => ParamValue(typ = Some(visitType(field.`type`())))
case "read-under-write" => ParamValue(ruw = Some(field.ruw().getText)) // TODO
case "read-under-write" => ParamValue(ruw = visitRuw(Option(field.ruw)))
case _ => ParamValue(lit = Some(BigInt(field.intLit().getText)))
}
if (fieldMap.contains(fieldName))
Expand All @@ -210,10 +218,11 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}

def lit(param: String) = fieldMap(param).lit.get
val ruw = fieldMap.get("read-under-write").map(_.ruw).getOrElse(None)
val ruw = fieldMap.get("read-under-write").map(_.ruw).getOrElse(ir.ReadUnderWrite.Undefined)

DefMemory(info,
name = ctx.id(0).getText, dataType = fieldMap("data-type").typ.get,
name = memName,
dataType = fieldMap("data-type").typ.get,
depth = lit("depth"),
writeLatency = lit("write-latency").toInt,
readLatency = lit("read-latency").toInt,
Expand Down Expand Up @@ -269,7 +278,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
CDefMemory(info, ctx.id(0).getText, tpe, size, seq = false)
case "smem" =>
val (tpe, size) = visitCMemType(ctx.`type`())
CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true)
CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true, readUnderWrite = visitRuw(Option(ctx.ruw)))
case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText)
case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0)))

Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/firrtl/WIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,8 @@ case class CDefMemory(
name: String,
tpe: Type,
size: BigInt,
seq: Boolean) extends Statement with HasInfo {
seq: Boolean,
readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) extends Statement with HasInfo {
def serialize: String = (if (seq) "smem" else "cmem") +
s" $name : ${tpe.serialize} [$size]" + info.serialize
def mapExpr(f: Expression => Expression): Statement = this
Expand Down
11 changes: 9 additions & 2 deletions src/main/scala/firrtl/ir/IR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,13 @@ case class DefInstance(info: Info, name: String, module: String) extends Stateme
def foreachString(f: String => Unit): Unit = f(name)
def foreachInfo(f: Info => Unit): Unit = f(info)
}

object ReadUnderWrite extends Enumeration {
val Undefined = Value("undefined")
val Old = Value("old")
val New = Value("new")
}

case class DefMemory(
info: Info,
name: String,
Expand All @@ -296,7 +303,7 @@ case class DefMemory(
writers: Seq[String],
readwriters: Seq[String],
// TODO: handle read-under-write
readUnderWrite: Option[String] = None) extends Statement with IsDeclaration {
readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) extends Statement with IsDeclaration {
def serialize: String =
s"mem $name :" + info.serialize +
indent(
Expand All @@ -307,7 +314,7 @@ case class DefMemory(
(readers map ("reader => " + _)) ++
(writers map ("writer => " + _)) ++
(readwriters map ("readwriter => " + _)) ++
Seq("read-under-write => undefined")) mkString "\n")
Seq(s"read-under-write => ${readUnderWrite}")) mkString "\n")
def mapStmt(f: Statement => Statement): Statement = this
def mapExpr(f: Expression => Expression): Statement = this
def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ object RemoveCHIRRTL extends Transform {
set_enable(rws, "en") ++
set_write(rws, "wdata", "wmask")
val mem = DefMemory(sx.info, sx.name, sx.tpe, sx.size, 1, if (sx.seq) 1 else 0,
rds map (_.name), wrs map (_.name), rws map (_.name))
rds map (_.name), wrs map (_.name), rws map (_.name), sx.readUnderWrite)
Block(mem +: stmts)
case sx: CDefMPort =>
types.get(sx.mem) match {
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/firrtl/passes/memlib/MemIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ case class DefAnnotatedMemory(
readers: Seq[String],
writers: Seq[String],
readwriters: Seq[String],
readUnderWrite: Option[String],
readUnderWrite: ReadUnderWrite.Value,
maskGran: Option[BigInt],
memRef: Option[(String, String)] /* (Module, Mem) */
//pins: Seq[Pin],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,6 @@ object MemTransformUtils {
}

def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem)
def memPortField(s: DefAnnotatedMemory, p: String, f: String): Expression =
def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField =
MemPortUtils.memPortField(s.toMem, p, f)
}
2 changes: 1 addition & 1 deletion src/main/scala/firrtl/passes/memlib/MemUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ object MemPortUtils {
(mem.readwriters map (Field(_, Flip, rwType))))
}

def memPortField(s: DefMemory, p: String, f: String): Expression = {
def memPortField(s: DefMemory, p: String, f: String): WSubField = {
val mem = WRef(s.name, memType(s), MemKind, UnknownFlow)
val t1 = field_type(mem.tpe, p)
val t2 = field_type(t1, f)
Expand Down
7 changes: 4 additions & 3 deletions src/main/scala/firrtl/passes/memlib/ToMemIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@ import firrtl.ir._
* - read latency and write latency of one
* - only one readwrite port or write port
* - zero or one read port
* - undefined read-under-write behavior
*/
object ToMemIR extends Pass {
/** Only annotate memories that are candidates for memory macro replacements
* i.e. rw, w + r (read, write 1 cycle delay)
* i.e. rw, w + r (read, write 1 cycle delay) and read-under-write "undefined."
*/
import ReadUnderWrite._
def updateStmts(s: Statement): Statement = s match {
case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
(m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
case m @ DefMemory(_,_,_,_,1,1,r,w,rw,Undefined) if (w.length + rw.length) == 1 && r.length <= 1 =>
DefAnnotatedMemory(m)
case sx => sx map updateStmts
}
Expand Down
99 changes: 70 additions & 29 deletions src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@ import MemPortUtils._

import collection.mutable

object DelayPipe {
private case class PipeState(ref: Expression, decl: Statement = EmptyStmt, connect: Statement = EmptyStmt, idx: Int = 0)

def apply(ns: Namespace)(e: Expression, delay: Int, clock: Expression): (Expression, Seq[Statement]) = {
def addStage(prev: PipeState): PipeState = {
val idx = prev.idx + 1
val name = ns.newName(s"${e.serialize}_r${idx}".replace('.', '_'))
val regRef = WRef(name, e.tpe, RegKind)
val regDecl = DefRegister(NoInfo, name, e.tpe, clock, zero, regRef)
PipeState(regRef, regDecl, Connect(NoInfo, regRef, prev.ref), idx)
}
val pipeline = Seq.iterate(PipeState(e), delay+1)(addStage)
(pipeline.last.ref, pipeline.map(_.decl) ++ pipeline.map(_.connect))
}
}

/** This pass generates delay reigsters for memories for verilog */
object VerilogMemDelays extends Pass {
val ug = UnknownFlow
Expand Down Expand Up @@ -49,7 +65,7 @@ object VerilogMemDelays extends Pass {
readers = sx.readers ++ (sx.readwriters map (rw => rwMap(rw)._1)),
writers = sx.writers ++ (sx.readwriters map (rw => rwMap(rw)._2)),
readwriters = Nil, readLatency = 0, writeLatency = 1)
def pipe(e: Expression, // Expression to be piped
def prependPipe(e: Expression, // Expression to be piped
n: Int, // pipe depth
clk: Expression, // clock expression
cond: Expression // condition for pipes
Expand Down Expand Up @@ -96,40 +112,69 @@ object VerilogMemDelays extends Pass {
)

stmts ++= ((sx.readers flatMap {reader =>
// generate latency pipes for read ports (enable & addr)
val clk = netlist(memPortField(sx, reader, "clk"))
val (en, ss1) = pipe(memPortField(sx, reader, "en"), sx.readLatency - 1, clk, one)
val (addr, ss2) = pipe(memPortField(sx, reader, "addr"), sx.readLatency, clk, en)
ss1 ++ ss2 ++ readPortConnects(reader, clk, en, addr)
if (sx.readUnderWrite == ReadUnderWrite.Old) {
// For a read-first ("old") mem, read data gets delayed, so don't delay read address/en
val rdata = memPortField(sx, reader, "data")
val enDriver = netlist(memPortField(sx, reader, "en"))
val addrDriver = netlist(memPortField(sx, reader, "addr"))
readPortConnects(reader, clk, enDriver, addrDriver)
} else {
// For a write-first ("new") or undefined mem, delay read control inputs
val (en, ss1) = prependPipe(memPortField(sx, reader, "en"), sx.readLatency - 1, clk, one)
val (addr, ss2) = prependPipe(memPortField(sx, reader, "addr"), sx.readLatency, clk, en)
ss1 ++ ss2 ++ readPortConnects(reader, clk, en, addr)
}
}) ++ (sx.writers flatMap {writer =>
// generate latency pipes for write ports (enable, mask, addr, data)
val clk = netlist(memPortField(sx, writer, "clk"))
val (en, ss1) = pipe(memPortField(sx, writer, "en"), sx.writeLatency - 1, clk, one)
val (mask, ss2) = pipe(memPortField(sx, writer, "mask"), sx.writeLatency - 1, clk, one)
val (addr, ss3) = pipe(memPortField(sx, writer, "addr"), sx.writeLatency - 1, clk, one)
val (data, ss4) = pipe(memPortField(sx, writer, "data"), sx.writeLatency - 1, clk, one)
val (en, ss1) = prependPipe(memPortField(sx, writer, "en"), sx.writeLatency - 1, clk, one)
val (mask, ss2) = prependPipe(memPortField(sx, writer, "mask"), sx.writeLatency - 1, clk, one)
val (addr, ss3) = prependPipe(memPortField(sx, writer, "addr"), sx.writeLatency - 1, clk, one)
val (data, ss4) = prependPipe(memPortField(sx, writer, "data"), sx.writeLatency - 1, clk, one)
ss1 ++ ss2 ++ ss3 ++ ss4 ++ writePortConnects(writer, clk, en, mask, addr, data)
}) ++ (sx.readwriters flatMap {readwriter =>
val (reader, writer) = rwMap(readwriter)
val clk = netlist(memPortField(sx, readwriter, "clk"))
// generate latency pipes for readwrite ports (enable, addr, wmode, wmask, wdata)
val (en, ss1) = pipe(memPortField(sx, readwriter, "en"), sx.readLatency - 1, clk, one)
val (wmode, ss2) = pipe(memPortField(sx, readwriter, "wmode"), sx.writeLatency - 1, clk, one)
val (wmask, ss3) = pipe(memPortField(sx, readwriter, "wmask"), sx.writeLatency - 1, clk, one)
val (wdata, ss4) = pipe(memPortField(sx, readwriter, "wdata"), sx.writeLatency - 1, clk, one)
val (raddr, ss5) = pipe(memPortField(sx, readwriter, "addr"), sx.readLatency, clk, AND(en, NOT(wmode)))
val (waddr, ss6) = pipe(memPortField(sx, readwriter, "addr"), sx.writeLatency - 1, clk, one)
repl(memPortField(sx, readwriter, "rdata")) = memPortField(mem, reader, "data")
ss1 ++ ss2 ++ ss3 ++ ss4 ++ ss5 ++ ss6 ++
readPortConnects(reader, clk, en, raddr) ++
writePortConnects(writer, clk, AND(en, wmode), wmask, waddr, wdata)
val (en, ss1) = prependPipe(memPortField(sx, readwriter, "en"), sx.readLatency - 1, clk, one)
val (wmode, ss2) = prependPipe(memPortField(sx, readwriter, "wmode"), sx.writeLatency - 1, clk, one)
val (wmask, ss3) = prependPipe(memPortField(sx, readwriter, "wmask"), sx.writeLatency - 1, clk, one)
val (wdata, ss4) = prependPipe(memPortField(sx, readwriter, "wdata"), sx.writeLatency - 1, clk, one)
val (waddr, ss5) = prependPipe(memPortField(sx, readwriter, "addr"), sx.writeLatency - 1, clk, one)
val stmts = ss1 ++ ss2 ++ ss3 ++ ss4 ++ ss5 ++ writePortConnects(writer, clk, AND(en, wmode), wmask, waddr, wdata)
if (sx.readUnderWrite == ReadUnderWrite.Old) {
// For a read-first ("old") mem, read data gets delayed, so don't delay read address/en
val enDriver = netlist(memPortField(sx, readwriter, "en"))
val addrDriver = netlist(memPortField(sx, readwriter, "addr"))
val wmodeDriver = netlist(memPortField(sx, readwriter, "wmode"))
stmts ++ readPortConnects(reader, clk, AND(enDriver, NOT(wmodeDriver)), addrDriver)
} else {
// For a write-first ("new") or undefined mem, delay read control inputs
val (raddr, raddrPipeStmts) = prependPipe(memPortField(sx, readwriter, "addr"), sx.readLatency, clk, AND(en, NOT(wmode)))
repl(memPortField(sx, readwriter, "rdata")) = memPortField(mem, reader, "data")
stmts ++ raddrPipeStmts ++ readPortConnects(reader, clk, en, raddr)
}
}))
mem // The mem stays put
case sx: Connect => kind(sx.loc) match {
case MemKind => EmptyStmt
case _ => sx
}
case sx => sx

def pipeReadData(p: String): Seq[Statement] = {
val newName = rwMap.get(p).map(_._1).getOrElse(p) // Name of final read port, whether renamed (rw port) or not
val rdataNew = memPortField(mem, newName, "data")
val rdataOld = rwMap.get(p).map(rw => memPortField(sx, p, "rdata")).getOrElse(rdataNew)
val clk = netlist(rdataOld.copy(name = "clk"))
val (rdataPipe, rdataPipeStmts) = DelayPipe(namespace)(rdataNew, sx.readLatency, clk) // TODO: use enable
repl(rdataOld) = rdataPipe
rdataPipeStmts
}

// We actually pipe the read data here; this groups it with the mem declaration to keep declarations early
if (sx.readUnderWrite == ReadUnderWrite.Old) {
Block(mem +: (sx.readers ++ sx.readwriters).flatMap(pipeReadData(_)))
} else {
mem
}
case sx: Connect if kind(sx.loc) == MemKind => EmptyStmt
case sx => sx map replaceExp(repl)
}

def replaceExp(repl: Netlist)(e: Expression): Expression = e match {
Expand All @@ -140,9 +185,6 @@ object VerilogMemDelays extends Pass {
case ex => ex map replaceExp(repl)
}

def replaceStmt(repl: Netlist)(s: Statement): Statement =
s map replaceStmt(repl) map replaceExp(repl)

def appendStmts(sx: Seq[Statement])(s: Statement): Statement = Block(s +: sx)

def memDelayMod(m: DefModule): DefModule = {
Expand All @@ -152,7 +194,6 @@ object VerilogMemDelays extends Pass {
val extraStmts = mutable.ArrayBuffer.empty[Statement]
m.foreach(buildNetlist(netlist))
m.map(memDelayStmt(netlist, namespace, repl, extraStmts))
.map(replaceStmt(repl))
.map(appendStmts(extraStmts))
}

Expand Down
11 changes: 9 additions & 2 deletions src/main/scala/firrtl/proto/FromProto.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import java.io.{File, FileInputStream, InputStream}
import collection.JavaConverters._
import FirrtlProtos._
import com.google.protobuf.CodedInputStream
import Firrtl.Statement.ReadUnderWrite

object FromProto {

Expand Down Expand Up @@ -133,6 +134,12 @@ object FromProto {
ir.Conditionally(convert(info), convert(when.getPredicate), conseq, alt)
}

def convert(ruw: ReadUnderWrite): ir.ReadUnderWrite.Value = ruw match {
case ReadUnderWrite.UNDEFINED => ir.ReadUnderWrite.Undefined
case ReadUnderWrite.OLD => ir.ReadUnderWrite.Old
case ReadUnderWrite.NEW => ir.ReadUnderWrite.New
}

def convert(dt: Firrtl.Statement.CMemory.TypeAndDepth): (ir.Type, BigInt) =
(convert(dt.getDataType), convert(dt.getDepth))

Expand All @@ -145,7 +152,7 @@ object FromProto {
case TYPE_AND_DEPTH_FIELD_NUMBER =>
convert(cmem.getTypeAndDepth)
}
CDefMemory(convert(info), cmem.getId, tpe, depth, cmem.getSyncRead)
CDefMemory(convert(info), cmem.getId, tpe, depth, cmem.getSyncRead, convert(cmem.getReadUnderWrite))
}

import Firrtl.Statement.MemoryPort.Direction._
Expand Down Expand Up @@ -181,7 +188,7 @@ object FromProto {
case BIGINT_DEPTH_FIELD_NUMBER => convert(mem.getBigintDepth)
}
ir.DefMemory(convert(info), mem.getId, dtype, depth, mem.getWriteLatency, mem.getReadLatency,
rs, ws, rws, None)
rs, ws, rws, convert(mem.getReadUnderWrite))
}

def convert(attach: Firrtl.Statement.Attach, info: Firrtl.SourceInfo): ir.Attach = {
Expand Down
Loading

0 comments on commit 1ced6cf

Please sign in to comment.