Skip to content

Commit

Permalink
[rtl] fix ffo
Browse files Browse the repository at this point in the history
  • Loading branch information
qinjun-li authored and sequencer committed Nov 9, 2024
1 parent d61b2f3 commit aca05e4
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 62 deletions.
1 change: 1 addition & 0 deletions t1/src/Bundles.scala
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,7 @@ class MaskUnitExeReq(parameter: LaneParameter) extends Bundle {
/** Response bundle sent from the mask unit towards a lane.
  *
  * Field order is part of the hardware interface (it determines the packed bit layout of the
  * bundle), so fields must not be reordered.
  *
  * NOTE(review): the usage notes on the fields are based on the assignments visible around this
  * definition (MaskExchangeUnit / BitLevelMaskWrite / MaskUnit); confirm against the full sources.
  */
class MaskUnitExeResponse(parameter: LaneParameter) extends Bundle {
// Single-bit flag; MaskExchangeUnit forwards it as `ffoByOtherLanes` of the pipe-type response.
// Presumably means "a find-first-one hit was already produced by another lane" - TODO confirm.
val ffoByOther: Bool = Bool()
// VRF write payload; members accessed by users of this bundle include data, mask, groupCounter and vd.
val writeData = new MaskUnitWriteBundle(parameter)
// datapathWidth-bit data carried alongside the write payload; MaskExchangeUnit forwards it as
// `pipeData` of the pipe-type response.
val pipeData: UInt = UInt(parameter.datapathWidth.W)
// Instruction index (instructionIndexBits wide); forwarded as `instructionIndex` by MaskExchangeUnit.
val index: UInt = UInt(parameter.instructionIndexBits.W)
}

Expand Down
29 changes: 1 addition & 28 deletions t1/src/Lane.scala
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,6 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
)
)

val ffoRecord: FFORecord = RegInit(0.U.asTypeOf(new FFORecord))

/** VRF read requests for each slot; the 3 entries are for [[source1]], [[source2]] and [[source3]].
*/
val vrfReadRequest: Vec[Vec[DecoupledIO[VRFReadRequest]]] = Wire(
Expand Down Expand Up @@ -616,7 +614,6 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
laneState.vd := record.laneRequest.vd
laneState.instructionIndex := record.laneRequest.instructionIndex
laneState.skipEnable := skipEnable
laneState.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
laneState.additionalRW := record.additionalRW
laneState.skipRead := record.laneRequest.decodeResult(Decoder.other) &&
(record.laneRequest.decodeResult(Decoder.uop) === 9.U)
Expand Down Expand Up @@ -760,9 +757,6 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
sink := source
}

executionUnit.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
executionUnit.selfCompleted := ffoRecord.selfCompleted

// executionUnit <> vfu
requestVec(index) := executionUnit.vfuRequest.bits
executeDecodeVec(index) := executionUnit.executeDecode
Expand Down Expand Up @@ -791,7 +785,7 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
stage3EnqWire.bits.instructionIndex := stage2.dequeue.bits.instructionIndex
stage3EnqWire.bits.loadStore := stage2.dequeue.bits.loadStore
stage3EnqWire.bits.vd := stage2.dequeue.bits.vd
stage3EnqWire.bits.ffoByOtherLanes := ffoRecord.ffoByOtherLanes
stage3EnqWire.bits.ffoByOtherLanes := false.B
stage3EnqWire.bits.groupCounter := stage2.dequeue.bits.groupCounter
stage3EnqWire.bits.mask := stage2.dequeue.bits.mask
if (isLastSlot) {
Expand All @@ -808,19 +802,6 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
stage2.dequeue.bits.sSendResponse.foreach(_ => stage3EnqWire.bits.sSendResponse := _)
executionUnit.dequeue.bits.ffoSuccess.foreach(_ => stage3EnqWire.bits.ffoSuccess := _)

if (isLastSlot) {
when(maskUnitResponse.valid) {
when(maskUnitResponse.bits.ffoByOther) {
ffoRecord.ffoByOtherLanes := true.B
}
}
when(stage3EnqWire.fire) {
executionUnit.dequeue.bits.ffoSuccess.foreach(ffoRecord.selfCompleted := _)
// This group found means the next group ended early
ffoRecord.ffoByOtherLanes := ffoRecord.ffoByOtherLanes || ffoRecord.selfCompleted
}
}

// --- stage 3 end & stage 4 start ---
// vrfWriteQueue try to write vrf
vrfWriteArbiter(index).valid := stage3.vrfWriteRequest.valid
Expand Down Expand Up @@ -1089,14 +1070,6 @@ class Lane(val parameter: LaneParameter) extends Module with SerializableModule[
}
}

// slot 0 update
when(slotEnqueueFire.head) {
// new ffo enq
when(slotControl(1).laneRequest.decodeResult(Decoder.ffo)) {
ffoRecord := 0.U.asTypeOf(ffoRecord)
}
}

val slotDequeueFire: Seq[Bool] = (slotCanShift.head && slotOccupied.head) +: slotEnqueueFire
Seq.tabulate(parameter.chainingSize) { slotIndex =>
when(slotEnqueueFire(slotIndex) ^ slotDequeueFire(slotIndex)) {
Expand Down
7 changes: 1 addition & 6 deletions t1/src/laneStage/LaneExecutionBridge.scala
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,6 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd
@public
val dataResponse: ValidIO[VFUResponseToSlot] = IO(Flipped(Valid(new VFUResponseToSlot(parameter))))

@public
val ffoByOtherLanes: Bool = IO(Input(Bool()))
@public
val selfCompleted: Bool = IO(Input(Bool()))

@public
val executeDecode: DecodeBundle = IO(Output(Decoder.bundle(parameter.decoderParam)))
@public
Expand Down Expand Up @@ -318,7 +313,7 @@ class LaneExecutionBridge(parameter: LaneParameter, isLastSlot: Boolean, slotInd
vfuRequest.bits.popInit := reduceResult.getOrElse(0.U)
vfuRequest.bits.groupIndex := executionRecord.groupCounter
vfuRequest.bits.laneIndex := executionRecord.laneIndex
vfuRequest.bits.complete := ffoByOtherLanes || selfCompleted
vfuRequest.bits.complete := false.B
vfuRequest.bits.maskType := executionRecord.maskType
vfuRequest.bits.narrow := narrowInRecord
vfuRequest.bits.unitSelet.foreach(_ := executionRecord.decodeResult(Decoder.fpExecutionType))
Expand Down
1 change: 0 additions & 1 deletion t1/src/laneStage/LaneStage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class LaneState(parameter: LaneParameter) extends Bundle {
val maskType: Bool = Bool()
val maskNotMaskedElement: Bool = Bool()
val skipEnable: Bool = Bool()
val ffoByOtherLanes: Bool = Bool()

/** vs1 or imm */
val vs1: UInt = UInt(5.W)
Expand Down
1 change: 0 additions & 1 deletion t1/src/laneStage/LaneStage0.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class LaneStage0Enqueue(parameter: LaneParameter) extends Bundle {
// vm = 0
val maskType: Bool = Bool()
val maskNotMaskedElement: Bool = Bool()
val ffoByOtherLanes: Bool = Bool()

/** vs1 or imm */
val vs1: UInt = UInt(5.W)
Expand Down
3 changes: 1 addition & 2 deletions t1/src/laneStage/LaneStage3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ class LaneStage3(parameter: LaneParameter, isLastSlot: Boolean) extends Module {

val dataSelect: Option[UInt] = Option.when(isLastSlot) {
Mux(
pipeEnqueue.get.decodeResult(Decoder.nr) ||
(enqueue.bits.ffoByOtherLanes && pipeEnqueue.get.decodeResult(Decoder.ffo)),
pipeEnqueue.get.decodeResult(Decoder.nr) || pipeEnqueue.get.ffoByOtherLanes,
pipeEnqueue.get.pipeData,
pipeEnqueue.get.data
)
Expand Down
17 changes: 7 additions & 10 deletions t1/src/laneStage/MaskExchangeUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ class MaskExchangeUnit(parameter: LaneParameter) extends Module {

val maskUnitWriteQueue: QueueIO[MaskUnitExeResponse] =
Queue.io(new MaskUnitExeResponse(parameter), parameter.maskUnitVefWriteQueueSize)
val wMaskResponse: Bool = RegInit(true.B)

// todo: sSendResponse -> sendResponse
val enqIsMaskRequest: Bool = !enqueue.bits.sSendResponse
val enqSendToDeq: Bool = !enqueue.bits.decodeResult(Decoder.maskUnit)
val enqFFo: Bool = enqueue.bits.decodeResult(Decoder.ffo)
val enqFFoIndex: Bool = enqueue.bits.decodeResult(Decoder.ffo) &&
enqueue.bits.decodeResult(Decoder.targetRd)

val maskRequestAllow: Bool =
pipeToken(parameter.maskRequestQueueSize)(maskReq.valid, tokenIO.maskRequestRelease)
// todo: connect mask request & response
maskReq.valid := enqIsMaskRequest && enqueue.valid && wMaskResponse && maskRequestAllow
maskReq.valid := enqIsMaskRequest && enqueue.valid && maskRequestAllow
maskReq.bits.source1 := enqueue.bits.pipeData
maskReq.bits.source2 := Mux(
enqFFo,
enqFFoIndex,
enqueue.bits.ffoIndex,
enqueue.bits.data
)
Expand All @@ -64,7 +64,8 @@ class MaskExchangeUnit(parameter: LaneParameter) extends Module {
maskUnitResponsePipeType.mask := maskUnitWriteQueue.deq.bits.writeData.mask
maskUnitResponsePipeType.vd := maskUnitWriteQueue.deq.bits.writeData.vd
maskUnitResponsePipeType.instructionIndex := maskUnitWriteQueue.deq.bits.index
maskUnitResponsePipeType.ffoByOtherLanes := enqueue.bits.ffoByOtherLanes
maskUnitResponsePipeType.ffoByOtherLanes := maskUnitWriteQueue.deq.bits.ffoByOther
maskUnitResponsePipeType.pipeData := maskUnitWriteQueue.deq.bits.pipeData

maskUnitWriteQueue.enq.valid := maskUnitResponse.valid
maskUnitWriteQueue.enq.bits := maskUnitResponse.bits
Expand All @@ -74,12 +75,8 @@ class MaskExchangeUnit(parameter: LaneParameter) extends Module {

dequeue.valid := (enqueue.valid && enqSendToDeq) || maskUnitWriteQueue.deq.valid
dequeue.bits := Mux(enqWantToSend, enqueue.bits, maskUnitResponsePipeType)
enqueue.ready := (dequeue.ready) && (enqWantToSend || wMaskResponse) && maskRequestEnqReady
enqueue.ready := Mux(enqSendToDeq, dequeue.ready, maskRequestEnqReady)
maskUnitWriteQueue.deq.ready := dequeue.ready && !enqWantToSend
tokenIO.maskResponseRelease := maskUnitWriteQueue.deq.fire

// update wMaskResponse
when(maskReq.fire && enqFFo || maskUnitResponse.fire) {
wMaskResponse := maskUnitResponse.fire
}
}
6 changes: 5 additions & 1 deletion t1/src/mask/BitLevelMaskWrite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import org.chipsalliance.dwbb.stdlib.queue.{Queue, QueueIO}

/** Write request bundle for the bit-level masked write path.
  *
  * Field order is part of the hardware interface (it determines the packed bit layout of the
  * bundle), so fields must not be reordered.
  *
  * NOTE(review): presumably this is the per-lane request type consumed by BitLevelMaskWrite;
  * confirm against the module's IO declaration.
  */
class BitLevelWriteRequest(parameter: T1Parameter) extends Bundle {
// Data to write, one datapath word (datapathWidth bits).
val data: UInt = UInt(parameter.datapathWidth.W)
// Extra datapath word carried with the request; BitLevelMaskWrite copies it unchanged into the
// response's `pipeData`.
val pipeData: UInt = UInt(parameter.datapathWidth.W)
// One bit per data bit (datapathWidth bits). Presumably selects which bits of `data` are
// written when a write-after-read merge is needed - TODO confirm.
val bitMask: UInt = UInt(parameter.datapathWidth.W)
// One bit per byte of the datapath word (datapathWidth / 8 bits).
val mask: UInt = UInt((parameter.datapathWidth / 8).W)
// Group counter of the write, groupNumberBits wide; copied into the response's writeData.groupCounter.
val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
// Single-bit flag; BitLevelMaskWrite copies it unchanged into the response's `ffoByOther`.
val ffoByOther: Bool = Bool()
}

class BitLevelMaskWrite(parameter: T1Parameter) extends Module {
Expand Down Expand Up @@ -75,9 +77,11 @@ class BitLevelMaskWrite(parameter: T1Parameter) extends Module {
res.valid := WaitReadQueue.deq.valid && readResultValid
WaitReadQueue.deq.ready := res.ready && readResultValid
res.bits := DontCare
res.bits.pipeData := WaitReadQueue.deq.bits.pipeData
res.bits.ffoByOther := WaitReadQueue.deq.bits.ffoByOther
res.bits.writeData.data := Mux(needWAR, WARData, WaitReadQueue.deq.bits.data)
res.bits.writeData.mask := maskEnable(!needWAR, WaitReadQueue.deq.bits.mask)
res.bits.writeData.groupCounter := WaitReadQueue.deq.bits.groupCounter
res.bits.writeData.mask := maskEnable(!needWAR, WaitReadQueue.deq.bits.mask)

// valid token
val counter = RegInit(0.U(3.W))
Expand Down
11 changes: 8 additions & 3 deletions t1/src/mask/MaskCompress.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class CompressInput(parameter: T1Parameter) extends Bundle {
val source2: UInt = UInt((parameter.laneNumber * parameter.datapathWidth).W)
val groupCounter: UInt = UInt(parameter.laneParam.groupNumberBits.W)
val ffoInput: UInt = UInt(parameter.laneNumber.W)
val validInput: UInt = UInt(parameter.laneNumber.W)
val lastCompress: Bool = Bool()
}

Expand Down Expand Up @@ -171,11 +172,14 @@ class MaskCompress(parameter: T1Parameter) extends Module {
val mvMask = Mux1H(eew1H, Seq(1.U, 3.U, 15.U))
val mvData = in.bits.readFromScalar

val ffoMask: UInt = FillInterleaved(parameter.datapathWidth / 8, in.bits.validInput)

out.data := Mux1H(
Seq(
compress -> compressResult,
viota -> viotaResult,
mv -> mvData
mv -> mvData,
ffoType -> in.bits.source2
)
)

Expand All @@ -184,7 +188,8 @@ class MaskCompress(parameter: T1Parameter) extends Module {
Seq(
compress -> compressMask,
viota -> viotaMask,
mv -> mvMask
mv -> mvMask,
ffoType -> ffoMask
)
)

Expand Down Expand Up @@ -217,5 +222,5 @@ class MaskCompress(parameter: T1Parameter) extends Module {
)
}
}
out.ffoOutput := completedLeftOr
out.ffoOutput := completedLeftOr | Fill(parameter.laneNumber, ffoValid)
}
22 changes: 12 additions & 10 deletions t1/src/mask/MaskUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class MaskUnit(parameter: T1Parameter) extends Module {
val ffo: Bool = instReg.decodeResult(Decoder.topUop) === BitPat("b0111?")
val extendType: Bool = unitType(3) && (subType(2) || subType(1))

val allGroupExecute: Bool = maskDestinationType || unitType(2) || compress
val allGroupExecute: Bool = maskDestinationType || unitType(2) || compress || ffo
val useDefaultSew: Bool = unitType(0)
// todo: decode ?
// Indicates how many times a set of data will be executed
Expand Down Expand Up @@ -259,8 +259,7 @@ class MaskUnit(parameter: T1Parameter) extends Module {
// for prioritizeLane
// for last group remainder lastGroupRemaining
val laneDatalog = log2Ceil(parameter.datapathWidth)
val lastLaneIndex = (lastGroupRemaining >> laneDatalog).asUInt +
changeUIntSize(lastGroupRemaining, laneDatalog).orR
val lastLaneIndex = (lastGroupRemaining >> laneDatalog).asUInt
val dataNeedForPL = (~scanLeftOr(UIntToOH(lastLaneIndex))).asUInt

// for !prioritizeLane
Expand Down Expand Up @@ -543,13 +542,13 @@ class MaskUnit(parameter: T1Parameter) extends Module {
val readWaitQueue: QueueIO[MaskUnitWaitReadQueue] = Queue.io(new MaskUnitWaitReadQueue(parameter), 64)

// s0 pipe request from lane
val laseExecuteGroupDeq: Bool = Wire(Bool())
val lastExecuteGroupDeq: Bool = Wire(Bool())
exeRequestQueue.zip(exeReqReg).foreach { case (req, reg) =>
req.deq.ready := !reg.valid || laseExecuteGroupDeq || viota
req.deq.ready := !reg.valid || lastExecuteGroupDeq || viota
when(req.deq.fire) {
reg.bits := req.deq.bits
}
when(req.deq.fire ^ laseExecuteGroupDeq) {
when(req.deq.fire ^ lastExecuteGroupDeq) {
reg.valid := req.deq.fire && !viota
}
}
Expand Down Expand Up @@ -658,7 +657,7 @@ class MaskUnit(parameter: T1Parameter) extends Module {
readWaitQueue.enq.bits.last := readIssueStageState.last

// last execute group in this request group dequeue
laseExecuteGroupDeq := requestStageDeq && isLastExecuteGroup
lastExecuteGroupDeq := requestStageDeq && isLastExecuteGroup

// s1 read vrf
val write1HPipe: Vec[UInt] = Wire(Vec(parameter.laneNumber, UInt(parameter.laneNumber.W)))
Expand Down Expand Up @@ -738,7 +737,8 @@ class MaskUnit(parameter: T1Parameter) extends Module {

val writeRequest: Seq[MaskUnitExeResponse] = Seq.tabulate(parameter.laneNumber) { laneIndex =>
val res: MaskUnitExeResponse = Wire(new MaskUnitExeResponse(parameter.laneParam))
res.ffoByOther := false.B
res.ffoByOther := DontCare
res.pipeData := DontCare
res.index := instReg.instructionIndex
res.writeData.groupCounter := waiteReadDataPipeReg.groupCounter
res.writeData.vd := instReg.vd
Expand Down Expand Up @@ -845,6 +845,7 @@ class MaskUnit(parameter: T1Parameter) extends Module {
compressUnit.in.bits.groupCounter := requestCounter
compressUnit.in.bits.lastCompress := lastGroup
compressUnit.in.bits.ffoInput := VecInit(exeReqReg.map(_.bits.ffo)).asUInt
compressUnit.in.bits.validInput := VecInit(exeReqReg.map(_.valid)).asUInt
compressUnit.newInstruction := instReq.valid

reduceUnit.in.valid := executeEnqValid && unitType(2)
Expand Down Expand Up @@ -925,15 +926,17 @@ class MaskUnit(parameter: T1Parameter) extends Module {
)
)

val executeWriteByteMask: UInt = Mux(compress, compressUnit.out.mask, executeByteMask)
val executeWriteByteMask: UInt = Mux(compress || ffo, compressUnit.out.mask, executeByteMask)
maskedWrite.needWAR := maskDestinationType
maskedWrite.vd := instReg.vd
maskedWrite.in.zipWithIndex.foreach { case (req, index) =>
req.valid := executeValid
req.bits.mask := cutUIntBySize(executeWriteByteMask, parameter.laneNumber)(index)
req.bits.data := cutUInt(executeResult, parameter.datapathWidth)(index)
req.bits.pipeData := exeReqReg(index).bits.source1
req.bits.bitMask := cutUInt(currentMaskGroupForDestination, parameter.datapathWidth)(index)
req.bits.groupCounter := executeDeqGroupCounter
req.bits.ffoByOther := compressUnit.out.ffoOutput(index) && ffo
if (index == 0) {
// reduce result
when(unitType(2)) {
Expand All @@ -958,7 +961,6 @@ class MaskUnit(parameter: T1Parameter) extends Module {
when(readTypeWriteVrf) {
queue.enq.bits := writeRequest(index)
}
queue.enq.bits.ffoByOther := compressUnit.out.ffoOutput(index)
queue.enq.bits.index := instReg.instructionIndex

// write token
Expand Down

0 comments on commit aca05e4

Please sign in to comment.