Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to the Cluster API used by the GPU Project #3714

Merged
merged 11 commits into from
Feb 9, 2025
4 changes: 4 additions & 0 deletions src/main/scala/subsystem/BusTopology.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ case object COH extends TLBusWrapperLocation("coh")
case class CSBUS(clusterId: Int) extends TLBusWrapperLocation(s"csbus$clusterId")
case class CMBUS(clusterId: Int) extends TLBusWrapperLocation(s"cmbus$clusterId")
case class CCBUS(clusterId: Int) extends TLBusWrapperLocation(s"ccbus$clusterId")
case class CLBUS(clusterId: Int) extends TLBusWrapperLocation(s"clbus$clusterId")
case class CCOH (clusterId: Int) extends TLBusWrapperLocation(s"ccoh$clusterId")

/** Parameterizes the subsystem in terms of optional clock-crossings
Expand Down Expand Up @@ -120,11 +121,14 @@ case class ClusterBusTopologyParams(
) extends TLBusWrapperTopology(
instantiations = List(
(CSBUS(clusterId), csbus),
(CLBUS(clusterId), csbus), // TODO don't copy from csbus params
(CCBUS(clusterId), ccbus)) ++ (if (coherence.nBanks == 0) Nil else List(
(CMBUS(clusterId), csbus),
(CCOH (clusterId), CoherenceManagerWrapperParams(csbus.blockBytes, csbus.beatBytes, coherence.nBanks, CCOH(clusterId).name)(coherence.coherenceManager)))),
connections = if (coherence.nBanks == 0) Nil else List(
(CSBUS(clusterId), CCOH (clusterId), TLBusWrapperConnection(driveClockFromMaster = Some(true), nodeBinding = BIND_STAR)()),
// NOTE(hansung): not sure this is necessary
(CLBUS(clusterId), CCOH (clusterId), TLBusWrapperConnection(driveClockFromMaster = Some(true), nodeBinding = BIND_STAR)()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the difference between the CLBUS and the CSBUS?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want the CLBUS to be totally private, you probably don't want it to the CCOH (ClusterCoherence) device, since that exposes cluster-external memory

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks for the heads up

(CCOH (clusterId), CMBUS(clusterId), TLBusWrapperConnection.crossTo(
xType = NoCrossing,
driveClockFromMaster = Some(true),
Expand Down
63 changes: 39 additions & 24 deletions src/main/scala/subsystem/Cluster.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,20 @@ import scala.collection.immutable.SortedMap

case class ClustersLocated(loc: HierarchicalLocation) extends Field[Seq[CanAttachCluster]](Nil)

trait BaseClusterParams extends HierarchicalElementParams {
val clusterId: Int
}

abstract class InstantiableClusterParams[ClusterType <: Cluster]
extends HierarchicalElementParams
with BaseClusterParams {
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByClusterIdImpl)(implicit p: Parameters): ClusterType
}

case class ClusterParams(
val clusterId: Int,
val clockSinkParams: ClockSinkParameters = ClockSinkParameters()
) extends HierarchicalElementParams {
) extends InstantiableClusterParams[Cluster] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are many of these changes intended to let you extend Cluster?

val baseName = "cluster"
val uniqueName = s"${baseName}_$clusterId"
def instantiate(crossing: HierarchicalElementCrossingParamsLike, lookup: LookupByClusterIdImpl)(implicit p: Parameters): Cluster = {
Expand All @@ -31,7 +41,7 @@ case class ClusterParams(
}

class Cluster(
val thisClusterParams: ClusterParams,
val thisClusterParams: BaseClusterParams,
crossing: ClockCrossingType,
lookup: LookupByClusterIdImpl)(implicit p: Parameters) extends BaseHierarchicalElement(crossing)(p)
with Attachable
Expand All @@ -46,10 +56,12 @@ class Cluster(
lazy val allClockGroupsNode = ClockGroupIdentityNode()

val csbus = tlBusWrapperLocationMap(CSBUS(clusterId)) // like the sbus in the base subsystem
val clbus = tlBusWrapperLocationMap(CLBUS(clusterId)) // like the sbus in the base subsystem
val ccbus = tlBusWrapperLocationMap(CCBUS(clusterId)) // like the cbus in the base subsystem
val cmbus = tlBusWrapperLocationMap.lift(CMBUS(clusterId)).getOrElse(csbus)

csbus.clockGroupNode := allClockGroupsNode
clbus.clockGroupNode := allClockGroupsNode
ccbus.clockGroupNode := allClockGroupsNode

val slaveNode = ccbus.inwardNode
Expand All @@ -66,7 +78,7 @@ class Cluster(
def toPlicDomain = this
lazy val msipNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val meipNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val seipNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val seipNodes = totalTiles.filter(_._2.tileParams.core.useSupervisor).keys.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val tileToPlicNodes = totalTileIdList.map { i => (i, IntIdentityNode()) }.to(SortedMap)
lazy val debugNodes = totalTileIdList.map { i => (i, IntSyncIdentityNode()) }.to(SortedMap)
lazy val nmiNodes = totalTiles.filter { case (i,t) => t.tileParams.core.useNMI }
Expand All @@ -79,7 +91,7 @@ class Cluster(
// TODO fix: shouldn't need to connect dummy notifications
tileHaltXbarNode := NullIntSource()
tileWFIXbarNode := NullIntSource()
tileCeaseXbarNode := NullIntSource()
// tileCeaseXbarNode := NullIntSource()

override lazy val module = new ClusterModuleImp(this)
}
Expand All @@ -88,12 +100,12 @@ class ClusterModuleImp(outer: Cluster) extends BaseHierarchicalElementModuleImp[

case class InCluster(id: Int) extends HierarchicalLocation(s"Cluster$id")

class ClusterPRCIDomain(
abstract class ClusterPRCIDomain[ClusterType <: Cluster](
clockSinkParams: ClockSinkParameters,
crossingParams: HierarchicalElementCrossingParamsLike,
clusterParams: ClusterParams,
clusterParams: InstantiableClusterParams[ClusterType],
lookup: LookupByClusterIdImpl)
(implicit p: Parameters) extends HierarchicalElementPRCIDomain[Cluster](clockSinkParams, crossingParams)
(implicit p: Parameters) extends HierarchicalElementPRCIDomain[ClusterType](clockSinkParams, crossingParams)
{
val element = element_reset_domain {
LazyModule(clusterParams.instantiate(crossingParams, lookup))
Expand All @@ -104,19 +116,19 @@ class ClusterPRCIDomain(


trait CanAttachCluster {
type ClusterType <: Cluster
type ClusterContextType <: DefaultHierarchicalElementContextType

def clusterParams: ClusterParams
def clusterParams: InstantiableClusterParams[ClusterType]
def crossingParams: HierarchicalElementCrossingParamsLike

def instantiate(allClusterParams: Seq[ClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain])(implicit p: Parameters): ClusterPRCIDomain = {
def instantiate(allClusterParams: Seq[BaseClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain[_]])(implicit p: Parameters): ClusterPRCIDomain[ClusterType] = {
val clockSinkParams = clusterParams.clockSinkParams.copy(name = Some(clusterParams.uniqueName))
val cluster_prci_domain = LazyModule(new ClusterPRCIDomain(
clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)))
val cluster_prci_domain = LazyModule(new ClusterPRCIDomain[ClusterType](
clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)) {})
cluster_prci_domain
}

def connect(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connect(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
connectMasterPorts(domain, context)
connectSlavePorts(domain, context)
connectInterrupts(domain, context)
Expand All @@ -126,21 +138,21 @@ trait CanAttachCluster {
connectTrace(domain, context)
}

def connectMasterPorts(domain: ClusterPRCIDomain, context: Attachable): Unit = {
def connectMasterPorts(domain: ClusterPRCIDomain[ClusterType], context: Attachable): Unit = {
implicit val p = context.p
val dataBus = context.locateTLBusWrapper(crossingParams.master.where)
dataBus.coupleFrom(clusterParams.baseName) { bus =>
bus :=* crossingParams.master.injectNode(context) :=* domain.crossMasterPort(crossingParams.crossingType)
}
}
def connectSlavePorts(domain: ClusterPRCIDomain, context: Attachable): Unit = {
def connectSlavePorts(domain: ClusterPRCIDomain[ClusterType], context: Attachable): Unit = {
implicit val p = context.p
val controlBus = context.locateTLBusWrapper(crossingParams.slave.where)
controlBus.coupleTo(clusterParams.baseName) { bus =>
domain.crossSlavePort(crossingParams.crossingType) :*= crossingParams.slave.injectNode(context) :*= TLWidthWidget(controlBus.beatBytes) :*= bus
}
}
def connectInterrupts(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectInterrupts(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p

domain.element.debugNodes.foreach { case (hartid, node) =>
Expand Down Expand Up @@ -170,23 +182,23 @@ trait CanAttachCluster {
}
}

def connectPRC(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectPRC(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
domain.element.allClockGroupsNode :*= context.allClockGroupsNode
domain {
domain.element_reset_domain.clockNode := crossingParams.resetCrossingType.injectClockNode := domain.clockNode
}
}

def connectOutputNotifications(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectOutputNotifications(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
context.tileHaltXbarNode :=* domain.crossIntOut(NoCrossing, domain.element.tileHaltXbarNode)
context.tileWFIXbarNode :=* domain.crossIntOut(NoCrossing, domain.element.tileWFIXbarNode)
context.tileCeaseXbarNode :=* domain.crossIntOut(NoCrossing, domain.element.tileCeaseXbarNode)

}

def connectInputConstants(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectInputConstants(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
val tlBusToGetPrefixFrom = context.locateTLBusWrapper(crossingParams.mmioBaseAddressPrefixWhere)
domain.element.tileHartIdNodes.foreach { case (hartid, node) =>
Expand All @@ -197,7 +209,7 @@ trait CanAttachCluster {
}
}

def connectTrace(domain: ClusterPRCIDomain, context: ClusterContextType): Unit = {
def connectTrace(domain: ClusterPRCIDomain[ClusterType], context: ClusterContextType): Unit = {
implicit val p = context.p
domain.element.traceNodes.foreach { case (hartid, node) =>
val traceNexusNode = BundleBridgeBlockDuringReset[TraceBundle](
Expand All @@ -212,23 +224,26 @@ trait CanAttachCluster {
}
}

case class ClusterAttachParams(
case class ClusterAttachParams (
clusterParams: ClusterParams,
crossingParams: HierarchicalElementCrossingParamsLike
) extends CanAttachCluster
) extends CanAttachCluster {
type ClusterType = Cluster
}

case class CloneClusterAttachParams(
sourceClusterId: Int,
cloneParams: CanAttachCluster
) extends CanAttachCluster {
type ClusterType = cloneParams.ClusterType
def clusterParams = cloneParams.clusterParams
def crossingParams = cloneParams.crossingParams

override def instantiate(allClusterParams: Seq[ClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain])(implicit p: Parameters): ClusterPRCIDomain = {
override def instantiate(allClusterParams: Seq[BaseClusterParams], instantiatedClusters: SortedMap[Int, ClusterPRCIDomain[_]])(implicit p: Parameters): ClusterPRCIDomain[ClusterType] = {
require(instantiatedClusters.contains(sourceClusterId))
val clockSinkParams = clusterParams.clockSinkParams.copy(name = Some(clusterParams.uniqueName))
val cluster_prci_domain = CloneLazyModule(
new ClusterPRCIDomain(clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)),
new ClusterPRCIDomain[ClusterType](clockSinkParams, crossingParams, clusterParams, PriorityMuxClusterIdFromSeq(allClusterParams)) {},
instantiatedClusters(sourceClusterId)
)
cluster_prci_domain
Expand Down
8 changes: 4 additions & 4 deletions src/main/scala/subsystem/HasHierarchicalElements.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,16 +92,16 @@ trait InstantiatesHierarchicalElements { this: LazyModule with Attachable =>
}

val clusterAttachParams: Seq[CanAttachCluster] = p(ClustersLocated(location)).sortBy(_.clusterParams.clusterId)
val clusterParams: Seq[ClusterParams] = clusterAttachParams.map(_.clusterParams)
val clusterParams: Seq[BaseClusterParams] = clusterAttachParams.map(_.clusterParams)
val clusterCrossingTypes: Seq[ClockCrossingType] = clusterAttachParams.map(_.crossingParams.crossingType)
val cluster_prci_domains: SortedMap[Int, ClusterPRCIDomain] = clusterAttachParams.foldLeft(SortedMap[Int, ClusterPRCIDomain]()) {
val cluster_prci_domains: SortedMap[Int, ClusterPRCIDomain[_]] = clusterAttachParams.foldLeft(SortedMap[Int, ClusterPRCIDomain[_]]()) {
case (instantiated, params) => instantiated + (params.clusterParams.clusterId -> params.instantiate(clusterParams, instantiated)(p))
}

val element_prci_domains: Seq[HierarchicalElementPRCIDomain[_]] = tile_prci_domains.values.toSeq ++ cluster_prci_domains.values.toSeq

val leafTiles: SortedMap[Int, BaseTile] = SortedMap(tile_prci_domains.mapValues(_.element.asInstanceOf[BaseTile]).toSeq.sortBy(_._1):_*)
val totalTiles: SortedMap[Int, BaseTile] = (leafTiles ++ cluster_prci_domains.values.map(_.element.totalTiles).flatten)
val totalTiles: SortedMap[Int, BaseTile] = (leafTiles ++ cluster_prci_domains.values.map(_.element.asInstanceOf[Cluster].totalTiles).flatten)

// Helper functions for accessing certain parameters that are popular to refer to in subsystem code
def nLeafTiles: Int = leafTiles.size
Expand All @@ -123,7 +123,7 @@ trait HasHierarchicalElements extends DefaultHierarchicalElementContextType
params.connect(tile_prci_domains(params.tileParams.tileId).asInstanceOf[TilePRCIDomain[params.TileType]], this.asInstanceOf[params.TileContextType])
}
clusterAttachParams.foreach { params =>
params.connect(cluster_prci_domains(params.clusterParams.clusterId).asInstanceOf[ClusterPRCIDomain], this.asInstanceOf[params.ClusterContextType])
params.connect(cluster_prci_domains(params.clusterParams.clusterId).asInstanceOf[ClusterPRCIDomain[params.ClusterType]], this.asInstanceOf[params.ClusterContextType])
}
}

Expand Down
10 changes: 5 additions & 5 deletions src/main/scala/subsystem/LookupByClusterId.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@ import chisel3._
import chisel3.util._

abstract class LookupByClusterIdImpl {
def apply[T <: Data](f: ClusterParams => Option[T], clusterId: UInt): T
def apply[T <: Data](f: BaseClusterParams => Option[T], clusterId: UInt): T
}

case class ClustersWontDeduplicate(t: ClusterParams) extends LookupByClusterIdImpl {
def apply[T <: Data](f: ClusterParams => Option[T], clusterId: UInt): T = f(t).get
case class ClustersWontDeduplicate(t: BaseClusterParams) extends LookupByClusterIdImpl {
def apply[T <: Data](f: BaseClusterParams => Option[T], clusterId: UInt): T = f(t).get
}

case class PriorityMuxClusterIdFromSeq(seq: Seq[ClusterParams]) extends LookupByClusterIdImpl {
def apply[T <: Data](f: ClusterParams => Option[T], clusterId: UInt): T =
case class PriorityMuxClusterIdFromSeq(seq: Seq[BaseClusterParams]) extends LookupByClusterIdImpl {
def apply[T <: Data](f: BaseClusterParams => Option[T], clusterId: UInt): T =
PriorityMux(seq.collect { case t if f(t).isDefined => (t.clusterId.U === clusterId) -> f(t).get })
}
Loading