From 3455781035a0c8a553197c2b1a0351162fc1582f Mon Sep 17 00:00:00 2001 From: "Maksimova, Viktoria" Date: Wed, 26 Feb 2025 07:59:11 -0800 Subject: [PATCH] Revert "Align translation of `OpCooperativeMatrixLengthKHR` to match the spec" This reverts commit 45da76224aa9878841dcb64942999469687b19f1. And resolves #17079 --- llvm-spirv/lib/SPIRV/SPIRVReader.cpp | 3 ++- llvm-spirv/lib/SPIRV/SPIRVWriter.cpp | 4 ---- llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp | 11 ----------- llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h | 3 --- .../cooperative_matrix_checked.ll | 3 ++- .../cooperative_matrix_prefetch.ll | 3 ++- .../SPV_KHR_cooperative_matrix/cooperative_matrix.ll | 3 ++- 7 files changed, 8 insertions(+), 22 deletions(-) diff --git a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp index ac95590d20800..3212e071ae75f 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp @@ -3627,7 +3627,8 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName, Func->addFnAttr(Attribute::Convergent); } CallInst *Call; - if (OC == OpCooperativeMatrixLengthKHR) { + if (OC == OpCooperativeMatrixLengthKHR && + Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) { // OpCooperativeMatrixLengthKHR needs special handling as its operand is // a Type instead of a Value. llvm::Type *MatTy = transType(reinterpret_cast(Ops[0])); diff --git a/llvm-spirv/lib/SPIRV/SPIRVWriter.cpp b/llvm-spirv/lib/SPIRV/SPIRVWriter.cpp index 5a5e9992ec7e9..91e8e27b174cf 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVWriter.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVWriter.cpp @@ -6787,10 +6787,6 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI, transValue(CI->getArgOperand(2), BB), BB); return BM->addStoreInst(transValue(CI->getArgOperand(0), BB), V, {}, BB); } - case OpCooperativeMatrixLengthKHR: { - return BM->addCooperativeMatrixLengthKHRInst( - transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB); - } case OpGroupNonUniformShuffleDown: { Function *F = CI->getCalledFunction(); if (F->arg_size() && F->getArg(0)->hasStructRetAttr()) { diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp index 753d5c03633eb..f1b210ccd7fe0 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp +++ b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp @@ -279,9 +279,6 @@ class SPIRVModuleImpl : public SPIRVModule { SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() override; SPIRVInstruction *addTaskSequenceGetINTELInst(SPIRVType *, SPIRVValue *, SPIRVBasicBlock *) override; - SPIRVInstruction * - addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *, - SPIRVBasicBlock *) override; SPIRVType *addOpaqueGenericType(Op) override; SPIRVTypeDeviceEvent *addDeviceEventType() override; SPIRVTypeQueue *addQueueType() override; @@ -1097,14 +1094,6 @@ SPIRVInstruction *SPIRVModuleImpl::addTaskSequenceGetINTELInst( BB); } -SPIRVInstruction *SPIRVModuleImpl::addCooperativeMatrixLengthKHRInst( - SPIRVType *RetTy, SPIRVType *MatTy, SPIRVBasicBlock *BB) { - return addInstruction( - SPIRVInstTemplateBase::create(OpCooperativeMatrixLengthKHR, RetTy, - getId(), getVec(MatTy->getId()), BB, this), - BB); -} - SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) { return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId())); } diff --git a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h index 8b2b0462e223e..41932cceab2c8 100644 --- a/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h +++ b/llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h @@ -272,9 +272,6 @@ class SPIRVModule { virtual SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() = 0; virtual SPIRVInstruction * addTaskSequenceGetINTELInst(SPIRVType *, SPIRVValue *, SPIRVBasicBlock *) = 0; - virtual SPIRVInstruction * - addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *, - SPIRVBasicBlock *) = 0; virtual SPIRVTypeVoid *addVoidType() = 0; virtual SPIRVType *addOpaqueGenericType(Op) = 0; virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0; diff --git a/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll b/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll index 1cebf53c10620..1c43d2f30f713 100644 --- a/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll +++ b/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll @@ -32,7 +32,8 @@ ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]] ; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy2]] [[#Load1:]] -; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]] +; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR. +; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]] ; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]] ; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL diff --git a/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll b/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll index 480832d666eed..53f6a51a71656 100644 --- a/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll +++ b/llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll @@ -32,7 +32,8 @@ ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]] ; CHECK-SPIRV: CompositeConstruct [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]] -; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]] +; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR. +; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]] ; CHECK-SPIRV: CooperativeMatrixPrefetchINTEL ; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]] ; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] diff --git a/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll b/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll index 4e99ab4ccb392..71d7139ee7962 100644 --- a/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll +++ b/llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll @@ -30,7 +30,8 @@ ; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]] ; CHECK-SPIRV: CompositeConstruct [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]] -; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]] +; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR. +; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]] ; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]] ; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]] ; CHECK-SPIRV: CooperativeMatrixStoreKHR