diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h index a65e4667ab76c..a808979f385a6 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfo.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h @@ -842,12 +842,20 @@ class TargetTransformInfo { LLVM_ABI AddressingModeKind getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const; + /// Some targets only support masked load/store with a constant mask. + enum MaskKind { + VariableOrConstantMask, + ConstantMask, + }; + /// Return true if the target supports masked store. - LLVM_ABI bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned AddressSpace) const; + LLVM_ABI bool + isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace, + MaskKind MaskKind = VariableOrConstantMask) const; /// Return true if the target supports masked load. - LLVM_ABI bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned AddressSpace) const; + LLVM_ABI bool + isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace, + MaskKind MaskKind = VariableOrConstantMask) const; /// Return true if the target supports nontemporal store. LLVM_ABI bool isLegalNTStore(Type *DataType, Align Alignment) const; diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h index d8e35748f53e5..af295fc28022b 100644 --- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h +++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h @@ -309,12 +309,14 @@ class TargetTransformInfoImplBase { } virtual bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned AddressSpace) const { + unsigned AddressSpace, + TTI::MaskKind MaskKind) const { return false; } virtual bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned AddressSpace) const { + unsigned AddressSpace, + TTI::MaskKind MaskKind) const { return false; } diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp index 45369f0ffe137..f9d330dfbd0ed 100644 --- a/llvm/lib/Analysis/TargetTransformInfo.cpp +++ b/llvm/lib/Analysis/TargetTransformInfo.cpp @@ -468,13 +468,17 @@ TargetTransformInfo::getPreferredAddressingMode(const Loop *L, } bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned AddressSpace) const { - return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace); + unsigned AddressSpace, + TTI::MaskKind MaskKind) const { + return TTIImpl->isLegalMaskedStore(DataType, Alignment, AddressSpace, + MaskKind); } bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned AddressSpace) const { - return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace); + unsigned AddressSpace, + TTI::MaskKind MaskKind) const { + return TTIImpl->isLegalMaskedLoad(DataType, Alignment, AddressSpace, + MaskKind); } bool TargetTransformInfo::isLegalNTStore(Type *DataType, diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp index 24a18e181ba80..4274e951446b8 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -2465,6 +2465,7 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD, SDValue PassThru = MLD->getPassThru(); Align Alignment = MLD->getBaseAlign(); ISD::LoadExtType ExtType = MLD->getExtensionType(); + MachineMemOperand::Flags MMOFlags = 
MLD->getMemOperand()->getFlags(); // Split Mask operand SDValue MaskLo, MaskHi; @@ -2490,9 +2491,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD, std::tie(PassThruLo, PassThruHi) = DAG.SplitVector(PassThru, dl); MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand( - MLD->getPointerInfo(), MachineMemOperand::MOLoad, - LocationSize::beforeOrAfterPointer(), Alignment, MLD->getAAInfo(), - MLD->getRanges()); + MLD->getPointerInfo(), MMOFlags, LocationSize::beforeOrAfterPointer(), + Alignment, MLD->getAAInfo(), MLD->getRanges()); Lo = DAG.getMaskedLoad(LoVT, dl, Ch, Ptr, Offset, MaskLo, PassThruLo, LoMemVT, MMO, MLD->getAddressingMode(), ExtType, @@ -2515,8 +2515,8 @@ void DAGTypeLegalizer::SplitVecRes_MLOAD(MaskedLoadSDNode *MLD, LoMemVT.getStoreSize().getFixedValue()); MMO = DAG.getMachineFunction().getMachineMemOperand( - MPI, MachineMemOperand::MOLoad, LocationSize::beforeOrAfterPointer(), - Alignment, MLD->getAAInfo(), MLD->getRanges()); + MPI, MMOFlags, LocationSize::beforeOrAfterPointer(), Alignment, + MLD->getAAInfo(), MLD->getRanges()); Hi = DAG.getMaskedLoad(HiVT, dl, Ch, Ptr, Offset, MaskHi, PassThruHi, HiMemVT, MMO, MLD->getAddressingMode(), ExtType, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 985a54ca83256..88b35582a9f7d 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -5063,6 +5063,8 @@ void SelectionDAGBuilder::visitMaskedLoad(const CallInst &I, bool IsExpanding) { auto MMOFlags = MachineMemOperand::MOLoad; if (I.hasMetadata(LLVMContext::MD_nontemporal)) MMOFlags |= MachineMemOperand::MONonTemporal; + if (I.hasMetadata(LLVMContext::MD_invariant_load)) + MMOFlags |= MachineMemOperand::MOInvariant; MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand( MachinePointerInfo(PtrOperand), MMOFlags, diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h index 6cc4987428567..52fc28a98449b 100644 --- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h +++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h @@ -323,12 +323,14 @@ class AArch64TTIImpl final : public BasicTTIImplBase { } bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const override { return isLegalMaskedLoadStore(DataType, Alignment); } bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const override { return isLegalMaskedLoadStore(DataType, Alignment); } diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp index d12b802fe234f..fdb0ec40cb41f 100644 --- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp @@ -1125,7 +1125,8 @@ bool ARMTTIImpl::isProfitableLSRChainElement(Instruction *I) const { } bool ARMTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment, - unsigned /*AddressSpace*/) const { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const { if (!EnableMaskedLoadStores || !ST->hasMVEIntegerOps()) return false; diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h index 919a6fc9fd0b0..30f2151b41239 100644 --- 
a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h +++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h @@ -186,12 +186,16 @@ class ARMTTIImpl final : public BasicTTIImplBase { bool isProfitableLSRChainElement(Instruction *I) const override; - bool isLegalMaskedLoad(Type *DataTy, Align Alignment, - unsigned AddressSpace) const override; - - bool isLegalMaskedStore(Type *DataTy, Align Alignment, - unsigned AddressSpace) const override { - return isLegalMaskedLoad(DataTy, Alignment, AddressSpace); + bool + isLegalMaskedLoad(Type *DataTy, Align Alignment, unsigned AddressSpace, + TTI::MaskKind MaskKind = + TTI::MaskKind::VariableOrConstantMask) const override; + + bool + isLegalMaskedStore(Type *DataTy, Align Alignment, unsigned AddressSpace, + TTI::MaskKind MaskKind = + TTI::MaskKind::VariableOrConstantMask) const override { + return isLegalMaskedLoad(DataTy, Alignment, AddressSpace, MaskKind); } bool forceScalarizeMaskedGather(VectorType *VTy, diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp index 8f3f0cc8abb01..3f84cbb6555ed 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -343,14 +343,16 @@ InstructionCost HexagonTTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, } bool HexagonTTIImpl::isLegalMaskedStore(Type *DataType, Align /*Alignment*/, - unsigned /*AddressSpace*/) const { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const { // This function is called from scalarize-masked-mem-intrin, which runs // in pre-isel. Use ST directly instead of calling isHVXVectorType. return HexagonMaskedVMem && ST.isTypeForHVX(DataType); } bool HexagonTTIImpl::isLegalMaskedLoad(Type *DataType, Align /*Alignment*/, - unsigned /*AddressSpace*/) const { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const { // This function is called from scalarize-masked-mem-intrin, which runs // in pre-isel. Use ST directly instead of calling isHVXVectorType. 
return HexagonMaskedVMem && ST.isTypeForHVX(DataType); diff --git a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h index e95b5a10b76a7..67388984bb3e3 100644 --- a/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ b/llvm/lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -165,9 +165,10 @@ class HexagonTTIImpl final : public BasicTTIImplBase { } bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned AddressSpace) const override; - bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned AddressSpace) const override; + unsigned AddressSpace, + TTI::MaskKind MaskKind) const override; + bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace, + TTI::MaskKind MaskKind) const override; bool isLegalMaskedGather(Type *Ty, Align Alignment) const override; bool isLegalMaskedScatter(Type *Ty, Align Alignment) const override; bool forceScalarizeMaskedGather(VectorType *VTy, diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp index 77913f27838e2..5ff5fa36ac467 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp @@ -395,6 +395,25 @@ void NVPTXInstPrinter::printMemOperand(const MCInst *MI, int OpNum, } } +void NVPTXInstPrinter::printUsedBytesMaskPragma(const MCInst *MI, int OpNum, + raw_ostream &O) { + auto &Op = MI->getOperand(OpNum); + assert(Op.isImm() && "Invalid operand"); + uint32_t Imm = (uint32_t)Op.getImm(); + if (Imm != UINT32_MAX) { + O << ".pragma \"used_bytes_mask " << format_hex(Imm, 1) << "\";\n\t"; + } +} + +void NVPTXInstPrinter::printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, + raw_ostream &O) { + const MCOperand &Op = MI->getOperand(OpNum); + if (Op.isReg() && Op.getReg() == MCRegister::NoRegister) + O << "_"; + else + printOperand(MI, OpNum, O); +} + void NVPTXInstPrinter::printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O) { int64_t Imm = MI->getOperand(OpNum).getImm(); diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h index 92155b01464e8..3d172441adfcc 100644 --- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h +++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h @@ -46,6 +46,8 @@ class NVPTXInstPrinter : public MCInstPrinter { StringRef Modifier = {}); void printMemOperand(const MCInst *MI, int OpNum, raw_ostream &O, StringRef Modifier = {}); + void printUsedBytesMaskPragma(const MCInst *MI, int OpNum, raw_ostream &O); + void printRegisterOrSinkSymbol(const MCInst *MI, int OpNum, raw_ostream &O); void printHexu32imm(const MCInst *MI, int OpNum, raw_ostream &O); void printProtoIdent(const MCInst *MI, int OpNum, raw_ostream &O); void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O); diff --git a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp index a3496090def3c..c8b53571c1e59 100644 --- a/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXForwardParams.cpp @@ -96,7 +96,7 @@ static bool eliminateMove(MachineInstr &Mov, const MachineRegisterInfo &MRI, const MachineOperand *ParamSymbol = Mov.uses().begin(); assert(ParamSymbol->isSymbol()); - constexpr unsigned LDInstBasePtrOpIdx = 5; + constexpr unsigned LDInstBasePtrOpIdx = 6; constexpr unsigned LDInstAddrSpaceOpIdx = 2; for (auto *LI : LoadInsts) { (LI->uses().begin() + LDInstBasePtrOpIdx) 
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp index 996d653940118..0e1125ab8d8b3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp @@ -105,6 +105,7 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) { switch (N->getOpcode()) { case ISD::LOAD: case ISD::ATOMIC_LOAD: + case NVPTXISD::MLoad: if (tryLoad(N)) return; break; @@ -1132,6 +1133,19 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { ? NVPTX::PTXLdStInstCode::Signed : NVPTX::PTXLdStInstCode::Untyped; + uint32_t UsedBytesMask; + switch (N->getOpcode()) { + case ISD::LOAD: + case ISD::ATOMIC_LOAD: + UsedBytesMask = UINT32_MAX; + break; + case NVPTXISD::MLoad: + UsedBytesMask = N->getConstantOperandVal(3); + break; + default: + llvm_unreachable("Unexpected opcode"); + } + assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 && FromTypeWidth <= 128 && "Invalid width for load"); @@ -1142,6 +1156,7 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) { getI32Imm(CodeAddrSpace, DL), getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), + getI32Imm(UsedBytesMask, DL), Base, Offset, Chain}; @@ -1196,14 +1211,14 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { // type is integer // Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float // Read at least 8 bits (predicates are stored as 8-bit values) - // The last operand holds the original LoadSDNode::getExtensionType() value - const unsigned ExtensionType = - N->getConstantOperandVal(N->getNumOperands() - 1); + // Get the original LoadSDNode::getExtensionType() value + const unsigned ExtensionType = N->getConstantOperandVal(4); const unsigned FromType = (ExtensionType == ISD::SEXTLOAD) ? NVPTX::PTXLdStInstCode::Signed : NVPTX::PTXLdStInstCode::Untyped; const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD); + const uint32_t UsedBytesMask = N->getConstantOperandVal(3); assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD)); @@ -1213,6 +1228,7 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) { getI32Imm(CodeAddrSpace, DL), getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), + getI32Imm(UsedBytesMask, DL), Base, Offset, Chain}; @@ -1250,10 +1266,13 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) { SDLoc DL(LD); unsigned ExtensionType; + uint32_t UsedBytesMask; if (const auto *Load = dyn_cast(LD)) { ExtensionType = Load->getExtensionType(); + UsedBytesMask = UINT32_MAX; } else { - ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1); + ExtensionType = LD->getConstantOperandVal(4); + UsedBytesMask = LD->getConstantOperandVal(3); } const unsigned FromType = (ExtensionType == ISD::SEXTLOAD) ? 
NVPTX::PTXLdStInstCode::Signed @@ -1265,8 +1284,12 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) { ExtensionType != ISD::NON_EXTLOAD)); const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG); - SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base, - Offset, LD->getChain()}; + SDValue Ops[] = {getI32Imm(FromType, DL), + getI32Imm(FromTypeWidth, DL), + getI32Imm(UsedBytesMask, DL), + Base, + Offset, + LD->getChain()}; const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy; std::optional Opcode; @@ -1277,6 +1300,10 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) { Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_i16, NVPTX::LD_GLOBAL_NC_i32, NVPTX::LD_GLOBAL_NC_i64); break; + case NVPTXISD::MLoad: + Opcode = pickOpcodeForVT(TargetVT, std::nullopt, NVPTX::LD_GLOBAL_NC_i32, + NVPTX::LD_GLOBAL_NC_i64); + break; case NVPTXISD::LoadV2: Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_GLOBAL_NC_v2i16, diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp index a77eb0240e677..e9026cdf3d699 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp @@ -771,7 +771,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM, setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom); for (MVT VT : MVT::fixedlen_vector_valuetypes()) if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256) - setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom); + setOperationAction({ISD::STORE, ISD::LOAD, ISD::MSTORE, ISD::MLOAD}, VT, + Custom); // Custom legalization for LDU intrinsics. // TODO: The logic to lower these is not very robust and we should rewrite it. @@ -3092,6 +3093,86 @@ static SDValue lowerSELECT(SDValue Op, SelectionDAG &DAG) { return Or; } +static SDValue lowerMSTORE(SDValue Op, SelectionDAG &DAG) { + SDNode *N = Op.getNode(); + + SDValue Chain = N->getOperand(0); + SDValue Val = N->getOperand(1); + SDValue BasePtr = N->getOperand(2); + SDValue Offset = N->getOperand(3); + SDValue Mask = N->getOperand(4); + + SDLoc DL(N); + EVT ValVT = Val.getValueType(); + MemSDNode *MemSD = cast(N); + assert(ValVT.isVector() && "Masked vector store must have vector type"); + assert(MemSD->getAlign() >= DAG.getEVTAlign(ValVT) && + "Unexpected alignment for masked store"); + + unsigned Opcode = 0; + switch (ValVT.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unexpected masked vector store type"); + case MVT::v4i64: + case MVT::v4f64: { + Opcode = NVPTXISD::StoreV4; + break; + } + case MVT::v8i32: + case MVT::v8f32: { + Opcode = NVPTXISD::StoreV8; + break; + } + } + + SmallVector Ops; + + // Construct the new SDNode. First operand is the chain. + Ops.push_back(Chain); + + // The next N operands are the values to store. Encode the mask into the + // values using the sentinel register 0 to represent a masked-off element. + assert(Mask.getValueType().isVector() && + Mask.getValueType().getVectorElementType() == MVT::i1 && + "Mask must be a vector of i1"); + assert(Mask.getOpcode() == ISD::BUILD_VECTOR && + "Mask expected to be a BUILD_VECTOR"); + assert(Mask.getValueType().getVectorNumElements() == + ValVT.getVectorNumElements() && + "Mask size must be the same as the vector size"); + for (auto [I, Op] : enumerate(Mask->ops())) { + // Mask elements must be constants. 
+    if (Op.getNode()->getAsZExtVal() == 0) {
+      // Append a sentinel register 0 to the Ops vector to represent a
+      // masked-off element; this is handled in TableGen.
+      Ops.push_back(DAG.getRegister(MCRegister::NoRegister,
+                                    ValVT.getVectorElementType()));
+    } else {
+      // Extract the element from the vector to store.
+      SDValue ExtVal =
+          DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL,
+                      ValVT.getVectorElementType(), Val,
+                      DAG.getIntPtrConstant(I, DL));
+      Ops.push_back(ExtVal);
+    }
+  }
+
+  // Next, the pointer operand.
+  Ops.push_back(BasePtr);
+
+  // Finally, the offset operand. We expect this to always be undef, and it
+  // will be ignored in lowering, but to mirror the handling of the other
+  // vector store instructions we include it in the new SDNode.
+  assert(Offset.getOpcode() == ISD::UNDEF &&
+         "Offset operand expected to be undef");
+  Ops.push_back(Offset);
+
+  SDValue NewSt =
+      DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
+                              MemSD->getMemoryVT(), MemSD->getMemOperand());
+
+  return NewSt;
+}
+
 SDValue
 NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
   switch (Op.getOpcode()) {
@@ -3128,8 +3209,16 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
+  case ISD::MSTORE: {
+    assert(STI.has256BitVectorLoadStore(
+               cast<MemSDNode>(Op.getNode())->getAddressSpace()) &&
+           "Masked store vector not supported on subtarget.");
+    return lowerMSTORE(Op, DAG);
+  }
   case ISD::LOAD:
     return LowerLOAD(Op, DAG);
+  case ISD::MLOAD:
+    return LowerMLOAD(Op, DAG);
   case ISD::SHL_PARTS:
     return LowerShiftLeftParts(Op, DAG);
   case ISD::SRA_PARTS:
@@ -3321,10 +3410,56 @@ SDValue NVPTXTargetLowering::LowerVASTART(SDValue Op, SelectionDAG &DAG) const {
                        MachinePointerInfo(SV));
 }
 
+static std::pair<MemSDNode *, uint32_t>
+convertMLOADToLoadWithUsedBytesMask(MemSDNode *N, SelectionDAG &DAG) {
+  SDValue Chain = N->getOperand(0);
+  SDValue BasePtr = N->getOperand(1);
+  SDValue Mask = N->getOperand(3);
+  SDValue Passthru = N->getOperand(4);
+
+  SDLoc DL(N);
+  EVT ResVT = N->getValueType(0);
+  assert(ResVT.isVector() && "Masked vector load must have vector type");
+  // While we only expect poison passthru vectors as an input to the backend,
+  // when the legalization framework splits a poison vector in half, it
+  // creates two undef vectors, so we can technically expect those too.
+  assert((Passthru.getOpcode() == ISD::POISON ||
+          Passthru.getOpcode() == ISD::UNDEF) &&
+         "Passthru operand expected to be poison or undef");
+
+  // Extract the mask and convert it to a uint32_t representing the used
+  // bytes of the entire vector load.
+  uint32_t UsedBytesMask = 0;
+  uint32_t ElementSizeInBits = ResVT.getVectorElementType().getSizeInBits();
+  assert(ElementSizeInBits % 8 == 0 && "Unexpected element size");
+  uint32_t ElementSizeInBytes = ElementSizeInBits / 8;
+  uint32_t ElementMask = (1u << ElementSizeInBytes) - 1u;
+
+  for (SDValue Op : reverse(Mask->ops())) {
+    // This shift is only needed for every iteration *but* the first, but in
+    // the first iteration UsedBytesMask is 0, so the shift is a no-op.
+    UsedBytesMask <<= ElementSizeInBytes;
+
+    // Mask elements must be constants.
+    if (Op->getAsZExtVal() != 0)
+      UsedBytesMask |= ElementMask;
+  }
+
+  assert(UsedBytesMask != 0 && UsedBytesMask != UINT32_MAX &&
+         "Unexpected masked load with elements masked all on or all off");
+
+  // Create a new load SDNode to be handled normally by replaceLoadVector.
+  MemSDNode *NewLD = cast<MemSDNode>(
+      DAG.getLoad(ResVT, DL, Chain, BasePtr, N->getMemOperand()).getNode());
+
+  return {NewLD, UsedBytesMask};
+}
+
 /// replaceLoadVector - Convert vector loads into multi-output scalar loads.
 static std::optional<std::pair<SDValue, SDValue>>
 replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
-  LoadSDNode *LD = cast<LoadSDNode>(N);
+  MemSDNode *LD = cast<MemSDNode>(N);
 
   const EVT ResVT = LD->getValueType(0);
   const EVT MemVT = LD->getMemoryVT();
@@ -3351,6 +3486,11 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
     return std::nullopt;
   }
 
+  // If we have a masked load, convert it to a normal load now.
+  std::optional<uint32_t> UsedBytesMask = std::nullopt;
+  if (LD->getOpcode() == ISD::MLOAD)
+    std::tie(LD, UsedBytesMask) = convertMLOADToLoadWithUsedBytesMask(LD, DAG);
+
   // Since LoadV2 is a target node, we cannot rely on DAG type legalization.
   // Therefore, we must ensure the type is legal. For i1 and i8, we set the
   // loaded type to i16 and propagate the "real" type as the memory type.
@@ -3379,9 +3519,13 @@ replaceLoadVector(SDNode *N, SelectionDAG &DAG, const NVPTXSubtarget &STI) {
   // Copy regular operands
   SmallVector<SDValue> OtherOps(LD->ops());
 
+  OtherOps.push_back(
+      DAG.getConstant(UsedBytesMask.value_or(UINT32_MAX), DL, MVT::i32));
+
   // The select routine does not have access to the LoadSDNode instance, so
   // pass along the extension information
-  OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
+  OtherOps.push_back(
+      DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
 
   SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
                                           MemVT, LD->getMemOperand());
@@ -3469,6 +3613,42 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   llvm_unreachable("Unexpected custom lowering for load");
 }
 
+SDValue NVPTXTargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
+  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on the legalizer to
+  // handle masked loads of these types and have to handle them here.
+  // v2f32 also needs to be handled here if the subtarget has f32x2
+  // instructions, making it legal.
+  //
+  // Note: misaligned masked loads should never reach this point because the
+  // override of isLegalMaskedLoad in NVPTXTargetTransformInfo.cpp will
+  // validate alignment. Therefore, we do not need to special-case them here.
+  EVT VT = Op.getValueType();
+  if (NVPTX::isPackedVectorTy(VT)) {
+    auto Result =
+        convertMLOADToLoadWithUsedBytesMask(cast<MemSDNode>(Op.getNode()), DAG);
+    MemSDNode *LD = std::get<0>(Result);
+    uint32_t UsedBytesMask = std::get<1>(Result);
+
+    SDLoc DL(LD);
+
+    // Copy regular operands
+    SmallVector<SDValue> OtherOps(LD->ops());
+
+    OtherOps.push_back(DAG.getConstant(UsedBytesMask, DL, MVT::i32));
+
+    // We are not currently lowering extending loads, but pass the extension
+    // type anyway as later handling expects it.
+    OtherOps.push_back(
+        DAG.getIntPtrConstant(cast<LoadSDNode>(LD)->getExtensionType(), DL));
+    SDValue NewLD =
+        DAG.getMemIntrinsicNode(NVPTXISD::MLoad, DL, LD->getVTList(), OtherOps,
+                                LD->getMemoryVT(), LD->getMemOperand());
+    return NewLD;
+  }
+  return SDValue();
+}
+
 static SDValue lowerSTOREVector(SDValue Op, SelectionDAG &DAG,
                                 const NVPTXSubtarget &STI) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
@@ -5377,6 +5557,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
     // ISD::LOAD -> NVPTXISD::Load (unless it's under-aligned). We have to do it
    // here.
Opcode = NVPTXISD::LoadV2; + // append a "full" used bytes mask operand right before the extension type + // operand, signifying that all bytes are used. + Operands.push_back(DCI.DAG.getConstant(UINT32_MAX, DL, MVT::i32)); Operands.push_back(DCI.DAG.getIntPtrConstant( cast(LD)->getExtensionType(), DL)); break; @@ -5385,9 +5568,9 @@ combineUnpackingMovIntoLoad(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) { Opcode = NVPTXISD::LoadV4; break; case NVPTXISD::LoadV4: - // V8 is only supported for f32. Don't forget, we're not changing the load - // size here. This is already a 256-bit load. - if (ElementVT != MVT::v2f32) + // V8 is only supported for f32/i32. Don't forget, we're not changing the + // load size here. This is already a 256-bit load. + if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32) return SDValue(); OldNumOutputs = 4; Opcode = NVPTXISD::LoadV8; @@ -5462,9 +5645,9 @@ static SDValue combinePackingMovIntoStore(SDNode *N, Opcode = NVPTXISD::StoreV4; break; case NVPTXISD::StoreV4: - // V8 is only supported for f32. Don't forget, we're not changing the store - // size here. This is already a 256-bit store. - if (ElementVT != MVT::v2f32) + // V8 is only supported for f32/i32. Don't forget, we're not changing the + // store size here. This is already a 256-bit store. + if (ElementVT != MVT::v2f32 && ElementVT != MVT::v2i32) return SDValue(); Opcode = NVPTXISD::StoreV8; break; @@ -6615,6 +6798,7 @@ void NVPTXTargetLowering::ReplaceNodeResults( ReplaceBITCAST(N, DAG, Results); return; case ISD::LOAD: + case ISD::MLOAD: replaceLoadVector(N, DAG, Results, STI); return; case ISD::INTRINSIC_W_CHAIN: diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h index d71a86fd463f6..dd8e49de7aa6a 100644 --- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.h +++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.h @@ -235,6 +235,7 @@ class NVPTXTargetLowering : public TargetLowering { SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const; SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerMLOAD(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const; SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 8b129e7e5eeae..77fdf6911a420 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1588,6 +1588,14 @@ def ADDR : Operand { let MIOperandInfo = (ops ADDR_base, i32imm); } +def UsedBytesMask : Operand { + let PrintMethod = "printUsedBytesMaskPragma"; +} + +def RegOrSink : Operand { + let PrintMethod = "printRegisterOrSinkSymbol"; +} + def AtomicCode : Operand { let PrintMethod = "printAtomicCode"; } @@ -1832,8 +1840,10 @@ def Callseq_End : class LD : NVPTXInst< (outs regclass:$dst), - (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign, - i32imm:$fromWidth, ADDR:$addr), + (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, + AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes, + ADDR:$addr), + "${usedBytes}" "ld${sem:sem}${scope:scope}${addsp:addsp}.${Sign:sign}$fromWidth " "\t$dst, [$addr];">; @@ -1865,21 +1875,27 @@ multiclass LD_VEC { def _v2 : NVPTXInst< (outs regclass:$dst1, regclass:$dst2), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, - AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr), + AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes, + ADDR:$addr), + "${usedBytes}" 
"ld${sem:sem}${scope:scope}${addsp:addsp}.v2.${Sign:sign}$fromWidth " "\t{{$dst1, $dst2}}, [$addr];">; def _v4 : NVPTXInst< (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, - AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$addr), + AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes, + ADDR:$addr), + "${usedBytes}" "ld${sem:sem}${scope:scope}${addsp:addsp}.v4.${Sign:sign}$fromWidth " "\t{{$dst1, $dst2, $dst3, $dst4}}, [$addr];">; if support_v8 then def _v8 : NVPTXInst< (outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4, regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8), - (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, AtomicCode:$Sign, - i32imm:$fromWidth, ADDR:$addr), + (ins AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, + AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes, + ADDR:$addr), + "${usedBytes}" "ld${sem:sem}${scope:scope}${addsp:addsp}.v8.${Sign:sign}$fromWidth " "\t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, " "[$addr];">; @@ -1900,7 +1916,7 @@ multiclass ST_VEC { "\t[$addr], {{$src1, $src2}};">; def _v4 : NVPTXInst< (outs), - (ins O:$src1, O:$src2, O:$src3, O:$src4, + (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4, AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth, ADDR:$addr), "st${sem:sem}${scope:scope}${addsp:addsp}.v4.b$fromWidth " @@ -1908,8 +1924,8 @@ multiclass ST_VEC { if support_v8 then def _v8 : NVPTXInst< (outs), - (ins O:$src1, O:$src2, O:$src3, O:$src4, - O:$src5, O:$src6, O:$src7, O:$src8, + (ins RegOrSink:$src1, RegOrSink:$src2, RegOrSink:$src3, RegOrSink:$src4, + RegOrSink:$src5, RegOrSink:$src6, RegOrSink:$src7, RegOrSink:$src8, AtomicCode:$sem, AtomicCode:$scope, AtomicCode:$addsp, i32imm:$fromWidth, ADDR:$addr), "st${sem:sem}${scope:scope}${addsp:addsp}.v8.b$fromWidth " diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index 8501d4d7bb86f..d18c7e20df038 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -2552,7 +2552,10 @@ def LDU_GLOBAL_v4i32 : VLDU_G_ELE_V4; // during the lifetime of the kernel. 
class LDG_G - : NVPTXInst<(outs regclass:$result), (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src), + : NVPTXInst<(outs regclass:$result), + (ins AtomicCode:$Sign, i32imm:$fromWidth, + UsedBytesMask:$usedBytes, ADDR:$src), + "${usedBytes}" "ld.global.nc.${Sign:sign}$fromWidth \t$result, [$src];">; def LD_GLOBAL_NC_i16 : LDG_G; @@ -2564,19 +2567,25 @@ def LD_GLOBAL_NC_i64 : LDG_G; // Elementized vector ldg class VLDG_G_ELE_V2 : NVPTXInst<(outs regclass:$dst1, regclass:$dst2), - (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src), + (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes, + ADDR:$src), + "${usedBytes}" "ld.global.nc.v2.${Sign:sign}$fromWidth \t{{$dst1, $dst2}}, [$src];">; class VLDG_G_ELE_V4 : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4), - (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src), + (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes, + ADDR:$src), + "${usedBytes}" "ld.global.nc.v4.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4}}, [$src];">; class VLDG_G_ELE_V8 : NVPTXInst<(outs regclass:$dst1, regclass:$dst2, regclass:$dst3, regclass:$dst4, regclass:$dst5, regclass:$dst6, regclass:$dst7, regclass:$dst8), - (ins AtomicCode:$Sign, i32imm:$fromWidth, ADDR:$src), + (ins AtomicCode:$Sign, i32imm:$fromWidth, UsedBytesMask:$usedBytes, + ADDR:$src), + "${usedBytes}" "ld.global.nc.v8.${Sign:sign}$fromWidth \t{{$dst1, $dst2, $dst3, $dst4, $dst5, $dst6, $dst7, $dst8}}, [$src];">; // FIXME: 8-bit LDG should be fixed once LDG/LDU nodes are made into proper loads. diff --git a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp index 320c0fb6950a7..4bbf49f93f43b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp @@ -1808,8 +1808,8 @@ bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op, // For CUDA, we preserve the param loads coming from function arguments return false; - assert(TexHandleDef.getOperand(6).isSymbol() && "Load is not a symbol!"); - StringRef Sym = TexHandleDef.getOperand(6).getSymbolName(); + assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!"); + StringRef Sym = TexHandleDef.getOperand(7).getSymbolName(); InstrsToRemove.insert(&TexHandleDef); Op.ChangeToES(Sym.data()); MFI->getImageHandleSymbolIndex(Sym); diff --git a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp index e8ea1ad6c404d..710d063e75725 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.cpp @@ -30,6 +30,7 @@ const char *NVPTXSelectionDAGInfo::getTargetNodeName(unsigned Opcode) const { MAKE_CASE(NVPTXISD::LoadV2) MAKE_CASE(NVPTXISD::LoadV4) MAKE_CASE(NVPTXISD::LoadV8) + MAKE_CASE(NVPTXISD::MLoad) MAKE_CASE(NVPTXISD::LDUV2) MAKE_CASE(NVPTXISD::LDUV4) MAKE_CASE(NVPTXISD::StoreV2) diff --git a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h index 07c130baeaa4f..9dd0a1eaa5856 100644 --- a/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h +++ b/llvm/lib/Target/NVPTX/NVPTXSelectionDAGInfo.h @@ -36,6 +36,7 @@ enum NodeType : unsigned { LoadV2, LoadV4, LoadV8, + MLoad, LDUV2, // LDU.v2 LDUV4, // LDU.v4 StoreV2, diff --git a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp index a4aff44ac04f6..f1774a7c5572e 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp +++ 
b/llvm/lib/Target/NVPTX/NVPTXTagInvariantLoads.cpp @@ -27,13 +27,14 @@ using namespace llvm; -static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) { +static bool isInvariantLoad(const Instruction *I, const Value *Ptr, + const bool IsKernelFn) { // Don't bother with non-global loads - if (LI->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL) + if (Ptr->getType()->getPointerAddressSpace() != NVPTXAS::ADDRESS_SPACE_GLOBAL) return false; // If the load is already marked as invariant, we don't need to do anything - if (LI->getMetadata(LLVMContext::MD_invariant_load)) + if (I->getMetadata(LLVMContext::MD_invariant_load)) return false; // We use getUnderlyingObjects() here instead of getUnderlyingObject() @@ -41,7 +42,7 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) { // not. We need to look through phi nodes to handle pointer induction // variables. SmallVector Objs; - getUnderlyingObjects(LI->getPointerOperand(), Objs); + getUnderlyingObjects(Ptr, Objs); return all_of(Objs, [&](const Value *V) { if (const auto *A = dyn_cast(V)) @@ -53,9 +54,9 @@ static bool isInvariantLoad(const LoadInst *LI, const bool IsKernelFn) { }); } -static void markLoadsAsInvariant(LoadInst *LI) { - LI->setMetadata(LLVMContext::MD_invariant_load, - MDNode::get(LI->getContext(), {})); +static void markLoadsAsInvariant(Instruction *I) { + I->setMetadata(LLVMContext::MD_invariant_load, + MDNode::get(I->getContext(), {})); } static bool tagInvariantLoads(Function &F) { @@ -63,12 +64,17 @@ static bool tagInvariantLoads(Function &F) { bool Changed = false; for (auto &I : instructions(F)) { - if (auto *LI = dyn_cast(&I)) { - if (isInvariantLoad(LI, IsKernelFn)) { + if (auto *LI = dyn_cast(&I)) + if (isInvariantLoad(LI, LI->getPointerOperand(), IsKernelFn)) { markLoadsAsInvariant(LI); Changed = true; } - } + if (auto *II = dyn_cast(&I)) + if (II->getIntrinsicID() == Intrinsic::masked_load && + isInvariantLoad(II, II->getOperand(0), IsKernelFn)) { + markLoadsAsInvariant(II); + Changed = true; + } } return Changed; } diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp index 64593e6439184..5d5553c573b0f 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp @@ -592,6 +592,45 @@ Value *NVPTXTTIImpl::rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, return nullptr; } +bool NVPTXTTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment, + unsigned AddrSpace, + TTI::MaskKind MaskKind) const { + if (MaskKind != TTI::MaskKind::ConstantMask) + return false; + + // We currently only support this feature for 256-bit vectors, so the + // alignment must be at least 32 + if (Alignment < 32) + return false; + + if (!ST->has256BitVectorLoadStore(AddrSpace)) + return false; + + auto *VTy = dyn_cast(DataTy); + if (!VTy) + return false; + + auto *ElemTy = VTy->getScalarType(); + return (ElemTy->getScalarSizeInBits() == 32 && VTy->getNumElements() == 8) || + (ElemTy->getScalarSizeInBits() == 64 && VTy->getNumElements() == 4); +} + +bool NVPTXTTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment, + unsigned /*AddrSpace*/, + TTI::MaskKind MaskKind) const { + if (MaskKind != TTI::MaskKind::ConstantMask) + return false; + + if (Alignment < DL.getTypeStoreSize(DataTy)) + return false; + + // We do not support sub-byte element type masked loads. 
+ auto *VTy = dyn_cast(DataTy); + if (!VTy) + return false; + return VTy->getElementType()->getScalarSizeInBits() >= 8; +} + unsigned NVPTXTTIImpl::getLoadStoreVecRegBitWidth(unsigned AddrSpace) const { // 256 bit loads/stores are currently only supported for global address space if (ST->has256BitVectorLoadStore(AddrSpace)) diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h index 78eb751cf3c2e..d7f4e1da4073b 100644 --- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h +++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h @@ -181,6 +181,12 @@ class NVPTXTTIImpl final : public BasicTTIImplBase { bool collectFlatAddressOperands(SmallVectorImpl &OpIndexes, Intrinsic::ID IID) const override; + bool isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddrSpace, + TTI::MaskKind MaskKind) const override; + + bool isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddrSpace, + TTI::MaskKind MaskKind) const override; + unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const override; Value *rewriteIntrinsicWithAddressSpace(IntrinsicInst *II, Value *OldV, diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h index 39c1173e2986c..484c4791390ac 100644 --- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h +++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h @@ -285,11 +285,13 @@ class RISCVTTIImpl final : public BasicTTIImplBase { } bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const override { return isLegalMaskedLoadStore(DataType, Alignment); } bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + unsigned /*AddressSpace*/, + TTI::MaskKind /*MaskKind*/) const override { return isLegalMaskedLoadStore(DataType, Alignment); } diff --git a/llvm/lib/Target/VE/VETargetTransformInfo.h b/llvm/lib/Target/VE/VETargetTransformInfo.h index 5c0ddca62c761..eed3832c9f1fb 100644 --- a/llvm/lib/Target/VE/VETargetTransformInfo.h +++ b/llvm/lib/Target/VE/VETargetTransformInfo.h @@ -134,12 +134,14 @@ class VETTIImpl final : public BasicTTIImplBase { } // Load & Store { - bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + bool + isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned /*AddressSpace*/, + TargetTransformInfo::MaskKind /*MaskKind*/) const override { return isVectorLaneType(*getLaneType(DataType)); } - bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned /*AddressSpace*/) const override { + bool isLegalMaskedStore( + Type *DataType, Align Alignment, unsigned /*AddressSpace*/, + TargetTransformInfo::MaskKind /*MaskKind*/) const override { return isVectorLaneType(*getLaneType(DataType)); } bool isLegalMaskedGather(Type *DataType, Align Alignment) const override { diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp index 4b77bf925b2ba..10a6b654a037d 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.cpp +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.cpp @@ -6322,7 +6322,8 @@ static bool isLegalMaskedLoadStore(Type *ScalarTy, const X86Subtarget *ST) { } bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment, - unsigned AddressSpace) const { + unsigned AddressSpace, + TTI::MaskKind MaskKind) const { Type *ScalarTy = DataTy->getScalarType(); // The backend 
can't handle a single element vector w/o CFCMOV. @@ -6335,7 +6336,8 @@ bool X86TTIImpl::isLegalMaskedLoad(Type *DataTy, Align Alignment, } bool X86TTIImpl::isLegalMaskedStore(Type *DataTy, Align Alignment, - unsigned AddressSpace) const { + unsigned AddressSpace, + TTI::MaskKind MaskKind) const { Type *ScalarTy = DataTy->getScalarType(); // The backend can't handle a single element vector w/o CFCMOV. diff --git a/llvm/lib/Target/X86/X86TargetTransformInfo.h b/llvm/lib/Target/X86/X86TargetTransformInfo.h index df1393ce16ca1..9b326723ae385 100644 --- a/llvm/lib/Target/X86/X86TargetTransformInfo.h +++ b/llvm/lib/Target/X86/X86TargetTransformInfo.h @@ -267,10 +267,14 @@ class X86TTIImpl final : public BasicTTIImplBase { bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1, const TargetTransformInfo::LSRCost &C2) const override; bool canMacroFuseCmp() const override; - bool isLegalMaskedLoad(Type *DataType, Align Alignment, - unsigned AddressSpace) const override; - bool isLegalMaskedStore(Type *DataType, Align Alignment, - unsigned AddressSpace) const override; + bool + isLegalMaskedLoad(Type *DataType, Align Alignment, unsigned AddressSpace, + TTI::MaskKind MaskKind = + TTI::MaskKind::VariableOrConstantMask) const override; + bool + isLegalMaskedStore(Type *DataType, Align Alignment, unsigned AddressSpace, + TTI::MaskKind MaskKind = + TTI::MaskKind::VariableOrConstantMask) const override; bool isLegalNTLoad(Type *DataType, Align Alignment) const override; bool isLegalNTStore(Type *DataType, Align Alignment) const override; bool isLegalBroadcastLoad(Type *ElementTy, diff --git a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp index 146e7d1047dd0..b7b08ae61ec52 100644 --- a/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp +++ b/llvm/lib/Transforms/Scalar/ScalarizeMaskedMemIntrin.cpp @@ -1123,7 +1123,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, if (TTI.isLegalMaskedLoad( CI->getType(), CI->getParamAlign(0).valueOrOne(), cast(CI->getArgOperand(0)->getType()) - ->getAddressSpace())) + ->getAddressSpace(), + isConstantIntVector(CI->getArgOperand(1)) + ? TTI::MaskKind::ConstantMask + : TTI::MaskKind::VariableOrConstantMask)) return false; scalarizeMaskedLoad(DL, HasBranchDivergence, CI, DTU, ModifiedDT); return true; @@ -1132,7 +1135,10 @@ static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT, CI->getArgOperand(0)->getType(), CI->getParamAlign(1).valueOrOne(), cast(CI->getArgOperand(1)->getType()) - ->getAddressSpace())) + ->getAddressSpace(), + isConstantIntVector(CI->getArgOperand(2)) + ? 
TTI::MaskKind::ConstantMask + : TTI::MaskKind::VariableOrConstantMask)) return false; scalarizeMaskedStore(DL, HasBranchDivergence, CI, DTU, ModifiedDT); return true; diff --git a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir index e3b072549bc04..3158916a3195c 100644 --- a/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir +++ b/llvm/test/CodeGen/MIR/NVPTX/floating-point-immediate-operands.mir @@ -40,9 +40,9 @@ registers: - { id: 7, class: b32 } body: | bb.0.entry: - %0 = LD_i32 0, 0, 4, 2, 32, &test_param_0, 0 + %0 = LD_i32 0, 0, 4, 2, 32, -1, &test_param_0, 0 %1 = CVT_f64_f32 %0, 0 - %2 = LD_i32 0, 0, 4, 0, 32, &test_param_1, 0 + %2 = LD_i32 0, 0, 4, 0, 32, -1, &test_param_1, 0 ; CHECK: %3:b64 = FADD_rnf64ri %1, double 3.250000e+00 %3 = FADD_rnf64ri %1, double 3.250000e+00 %4 = CVT_f32_f64 %3, 5 @@ -66,9 +66,9 @@ registers: - { id: 7, class: b32 } body: | bb.0.entry: - %0 = LD_i32 0, 0, 4, 2, 32, &test2_param_0, 0 + %0 = LD_i32 0, 0, 4, 2, 32, -1, &test2_param_0, 0 %1 = CVT_f64_f32 %0, 0 - %2 = LD_i32 0, 0, 4, 0, 32, &test2_param_1, 0 + %2 = LD_i32 0, 0, 4, 0, 32, -1, &test2_param_1, 0 ; CHECK: %3:b64 = FADD_rnf64ri %1, double 0x7FF8000000000000 %3 = FADD_rnf64ri %1, double 0x7FF8000000000000 %4 = CVT_f32_f64 %3, 5 diff --git a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll index 3fac29f74125b..d219493d2b31b 100644 --- a/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll +++ b/llvm/test/CodeGen/NVPTX/ldg-invariant-256.ll @@ -346,19 +346,15 @@ define i32 @ld_global_v8i32(ptr addrspace(1) %ptr) { ; SM100-LABEL: ld_global_v8i32( ; SM100: { ; SM100-NEXT: .reg .b32 %r<16>; -; SM100-NEXT: .reg .b64 %rd<6>; +; SM100-NEXT: .reg .b64 %rd<2>; ; SM100-EMPTY: ; SM100-NEXT: // %bb.0: ; SM100-NEXT: ld.param.b64 %rd1, [ld_global_v8i32_param_0]; -; SM100-NEXT: ld.global.nc.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1]; -; SM100-NEXT: mov.b64 {%r1, %r2}, %rd5; -; SM100-NEXT: mov.b64 {%r3, %r4}, %rd4; -; SM100-NEXT: mov.b64 {%r5, %r6}, %rd3; -; SM100-NEXT: mov.b64 {%r7, %r8}, %rd2; -; SM100-NEXT: add.s32 %r9, %r7, %r8; -; SM100-NEXT: add.s32 %r10, %r5, %r6; -; SM100-NEXT: add.s32 %r11, %r3, %r4; -; SM100-NEXT: add.s32 %r12, %r1, %r2; +; SM100-NEXT: ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1]; +; SM100-NEXT: add.s32 %r9, %r1, %r2; +; SM100-NEXT: add.s32 %r10, %r3, %r4; +; SM100-NEXT: add.s32 %r11, %r5, %r6; +; SM100-NEXT: add.s32 %r12, %r7, %r8; ; SM100-NEXT: add.s32 %r13, %r9, %r10; ; SM100-NEXT: add.s32 %r14, %r11, %r12; ; SM100-NEXT: add.s32 %r15, %r13, %r14; diff --git a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir index 0b2d85600a2ef..4be91dfc60c6a 100644 --- a/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir +++ b/llvm/test/CodeGen/NVPTX/machinelicm-no-preheader.mir @@ -26,10 +26,10 @@ body: | ; CHECK: bb.0.entry: ; CHECK-NEXT: successors: %bb.2(0x30000000), %bb.3(0x50000000) ; CHECK-NEXT: {{ $}} - ; CHECK-NEXT: [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101) - ; CHECK-NEXT: [[LD_i64_:%[0-9]+]]:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101) + ; CHECK-NEXT: [[LD_i32_:%[0-9]+]]:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101) + ; CHECK-NEXT: [[LD_i64_:%[0-9]+]]:b64 = 
LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101) ; CHECK-NEXT: [[ADD64ri:%[0-9]+]]:b64 = nuw ADD64ri killed [[LD_i64_]], 2 - ; CHECK-NEXT: [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, [[ADD64ri]], 0 + ; CHECK-NEXT: [[LD_i32_1:%[0-9]+]]:b32 = LD_i32 0, 0, 1, 3, 32, -1, [[ADD64ri]], 0 ; CHECK-NEXT: [[SETP_i32ri:%[0-9]+]]:b1 = SETP_i32ri [[LD_i32_]], 0, 0 ; CHECK-NEXT: CBranch killed [[SETP_i32ri]], %bb.2 ; CHECK-NEXT: {{ $}} @@ -54,10 +54,10 @@ body: | bb.0.entry: successors: %bb.2(0x30000000), %bb.1(0x50000000) - %5:b32 = LD_i32 0, 0, 101, 3, 32, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101) - %6:b64 = LD_i64 0, 0, 101, 3, 64, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101) + %5:b32 = LD_i32 0, 0, 101, 3, 32, -1, &test_hoist_param_1, 0 :: (dereferenceable invariant load (s32), addrspace 101) + %6:b64 = LD_i64 0, 0, 101, 3, 64, -1, &test_hoist_param_0, 0 :: (dereferenceable invariant load (s64), addrspace 101) %0:b64 = nuw ADD64ri killed %6, 2 - %1:b32 = LD_i32 0, 0, 1, 3, 32, %0, 0 + %1:b32 = LD_i32 0, 0, 1, 3, 32, -1, %0, 0 %7:b1 = SETP_i32ri %5, 0, 0 CBranch killed %7, %bb.2 GOTO %bb.1 diff --git a/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll new file mode 100644 index 0000000000000..8617dea310d6c --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/masked-load-vectors.ll @@ -0,0 +1,366 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5 +; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90 +; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %} +; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100 +; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %} + + +; Different architectures are tested in this file for the following reasons: +; - SM90 does not have 256-bit load/store instructions +; - SM90 does not have masked store instructions +; - SM90 does not support packed f32x2 instructions + +define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) { +; SM90-LABEL: global_8xi32( +; SM90: { +; SM90-NEXT: .reg .b32 %r<9>; +; SM90-NEXT: .reg .b64 %rd<3>; +; SM90-EMPTY: +; SM90-NEXT: // %bb.0: +; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0]; +; SM90-NEXT: .pragma "used_bytes_mask 0xf000"; +; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16]; +; SM90-NEXT: .pragma "used_bytes_mask 0xf0f"; +; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1]; +; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1]; +; SM90-NEXT: st.global.b32 [%rd2], %r5; +; SM90-NEXT: st.global.b32 [%rd2+8], %r7; +; SM90-NEXT: st.global.b32 [%rd2+28], %r4; +; SM90-NEXT: ret; +; +; SM100-LABEL: global_8xi32( +; SM100: { +; SM100-NEXT: .reg .b32 %r<9>; +; SM100-NEXT: .reg .b64 %rd<3>; +; SM100-EMPTY: +; SM100-NEXT: // %bb.0: +; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0]; +; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f"; +; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1]; +; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1]; +; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8}; +; SM100-NEXT: ret; + %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> , <8 x i32> poison) + tail call void 
@llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> ) + ret void +} + +; Masked stores are only supported for 32-bit element types, +; while masked loads are supported for all element types. +define void @global_16xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) { +; SM90-LABEL: global_16xi16( +; SM90: { +; SM90-NEXT: .reg .b16 %rs<7>; +; SM90-NEXT: .reg .b32 %r<9>; +; SM90-NEXT: .reg .b64 %rd<3>; +; SM90-EMPTY: +; SM90-NEXT: // %bb.0: +; SM90-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0]; +; SM90-NEXT: .pragma "used_bytes_mask 0xf000"; +; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16]; +; SM90-NEXT: mov.b32 {%rs1, %rs2}, %r4; +; SM90-NEXT: .pragma "used_bytes_mask 0xf0f"; +; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1]; +; SM90-NEXT: mov.b32 {%rs3, %rs4}, %r7; +; SM90-NEXT: mov.b32 {%rs5, %rs6}, %r5; +; SM90-NEXT: ld.param.b64 %rd2, [global_16xi16_param_1]; +; SM90-NEXT: st.global.b16 [%rd2], %rs5; +; SM90-NEXT: st.global.b16 [%rd2+2], %rs6; +; SM90-NEXT: st.global.b16 [%rd2+8], %rs3; +; SM90-NEXT: st.global.b16 [%rd2+10], %rs4; +; SM90-NEXT: st.global.b16 [%rd2+28], %rs1; +; SM90-NEXT: st.global.b16 [%rd2+30], %rs2; +; SM90-NEXT: ret; +; +; SM100-LABEL: global_16xi16( +; SM100: { +; SM100-NEXT: .reg .b16 %rs<7>; +; SM100-NEXT: .reg .b32 %r<9>; +; SM100-NEXT: .reg .b64 %rd<3>; +; SM100-EMPTY: +; SM100-NEXT: // %bb.0: +; SM100-NEXT: ld.param.b64 %rd1, [global_16xi16_param_0]; +; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f"; +; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1]; +; SM100-NEXT: mov.b32 {%rs1, %rs2}, %r8; +; SM100-NEXT: mov.b32 {%rs3, %rs4}, %r3; +; SM100-NEXT: mov.b32 {%rs5, %rs6}, %r1; +; SM100-NEXT: ld.param.b64 %rd2, [global_16xi16_param_1]; +; SM100-NEXT: st.global.b16 [%rd2], %rs5; +; SM100-NEXT: st.global.b16 [%rd2+2], %rs6; +; SM100-NEXT: st.global.b16 [%rd2+8], %rs3; +; SM100-NEXT: st.global.b16 [%rd2+10], %rs4; +; SM100-NEXT: st.global.b16 [%rd2+28], %rs1; +; SM100-NEXT: st.global.b16 [%rd2+30], %rs2; +; SM100-NEXT: ret; + %a.load = tail call <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1) align 32 %a, <16 x i1> , <16 x i16> poison) + tail call void @llvm.masked.store.v16i16.p1(<16 x i16> %a.load, ptr addrspace(1) align 32 %b, <16 x i1> ) + ret void +} + +define void @global_8xi32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) { +; CHECK-LABEL: global_8xi32_no_align( +; CHECK: { +; CHECK-NEXT: .reg .b32 %r<4>; +; CHECK-NEXT: .reg .b64 %rd<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b64 %rd1, [global_8xi32_no_align_param_0]; +; CHECK-NEXT: ld.global.b32 %r1, [%rd1]; +; CHECK-NEXT: ld.param.b64 %rd2, [global_8xi32_no_align_param_1]; +; CHECK-NEXT: ld.global.b32 %r2, [%rd1+8]; +; CHECK-NEXT: ld.global.b32 %r3, [%rd1+28]; +; CHECK-NEXT: st.global.b32 [%rd2], %r1; +; CHECK-NEXT: st.global.b32 [%rd2+8], %r2; +; CHECK-NEXT: st.global.b32 [%rd2+28], %r3; +; CHECK-NEXT: ret; + %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 16 %a, <8 x i1> , <8 x i32> poison) + tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 16 %b, <8 x i1> ) + ret void +} + + +define void @global_8xi32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) { +; SM90-LABEL: global_8xi32_invariant( +; SM90: { +; SM90-NEXT: .reg .b32 %r<9>; +; SM90-NEXT: .reg .b64 %rd<3>; +; SM90-EMPTY: +; SM90-NEXT: // %bb.0: +; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0]; +; SM90-NEXT: .pragma "used_bytes_mask 
0xf000"; +; SM90-NEXT: ld.global.nc.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16]; +; SM90-NEXT: .pragma "used_bytes_mask 0xf0f"; +; SM90-NEXT: ld.global.nc.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1]; +; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1]; +; SM90-NEXT: st.global.b32 [%rd2], %r5; +; SM90-NEXT: st.global.b32 [%rd2+8], %r7; +; SM90-NEXT: st.global.b32 [%rd2+28], %r4; +; SM90-NEXT: ret; +; +; SM100-LABEL: global_8xi32_invariant( +; SM100: { +; SM100-NEXT: .reg .b32 %r<9>; +; SM100-NEXT: .reg .b64 %rd<3>; +; SM100-EMPTY: +; SM100-NEXT: // %bb.0: +; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_invariant_param_0]; +; SM100-NEXT: .pragma "used_bytes_mask 0xf0000f0f"; +; SM100-NEXT: ld.global.nc.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1]; +; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_invariant_param_1]; +; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8}; +; SM100-NEXT: ret; + %a.load = tail call <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1) align 32 %a, <8 x i1> , <8 x i32> poison), !invariant.load !0 + tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> ) + ret void +} + +define void @global_2xi16(ptr addrspace(1) %a, ptr addrspace(1) %b) { +; CHECK-LABEL: global_2xi16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b64 %rd<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_param_0]; +; CHECK-NEXT: .pragma "used_bytes_mask 0x3"; +; CHECK-NEXT: ld.global.b32 %r1, [%rd1]; +; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_param_1]; +; CHECK-NEXT: mov.b32 {%rs1, _}, %r1; +; CHECK-NEXT: st.global.b16 [%rd2], %rs1; +; CHECK-NEXT: ret; + %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> , <2 x i16> poison) + tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> ) + ret void +} + +define void @global_2xi16_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) { +; CHECK-LABEL: global_2xi16_invariant( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<2>; +; CHECK-NEXT: .reg .b64 %rd<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_invariant_param_0]; +; CHECK-NEXT: .pragma "used_bytes_mask 0x3"; +; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1]; +; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_invariant_param_1]; +; CHECK-NEXT: mov.b32 {%rs1, _}, %r1; +; CHECK-NEXT: st.global.b16 [%rd2], %rs1; +; CHECK-NEXT: ret; + %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 4 %a, <2 x i1> , <2 x i16> poison), !invariant.load !0 + tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> ) + ret void +} + +define void @global_2xi16_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) { +; CHECK-LABEL: global_2xi16_no_align( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b64 %rd<3>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b64 %rd1, [global_2xi16_no_align_param_0]; +; CHECK-NEXT: ld.global.b16 %rs1, [%rd1]; +; CHECK-NEXT: ld.param.b64 %rd2, [global_2xi16_no_align_param_1]; +; CHECK-NEXT: st.global.b16 [%rd2], %rs1; +; CHECK-NEXT: ret; + %a.load = tail call <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1) align 2 %a, <2 x i1> , <2 x i16> poison) + tail call void @llvm.masked.store.v2i16.p1(<2 x i16> %a.load, ptr addrspace(1) align 4 %b, <2 x i1> ) + ret void +} + 
+define void @global_4xi8(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 0x5";
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_param_1];
+; CHECK-NEXT: st.global.b8 [%rd2], %r1;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT: ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_invariant(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_invariant_param_0];
+; CHECK-NEXT: .pragma "used_bytes_mask 0x5";
+; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_invariant_param_1];
+; CHECK-NEXT: st.global.b8 [%rd2], %r1;
+; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7772U;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %r2;
+; CHECK-NEXT: ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 4 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_4xi8_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_4xi8_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b16 %rs<3>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_4xi8_no_align_param_0];
+; CHECK-NEXT: ld.global.b8 %rs1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_4xi8_no_align_param_1];
+; CHECK-NEXT: ld.global.b8 %rs2, [%rd1+2];
+; CHECK-NEXT: st.global.b8 [%rd2], %rs1;
+; CHECK-NEXT: st.global.b8 [%rd2+2], %rs2;
+; CHECK-NEXT: ret;
+  %a.load = tail call <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1) align 2 %a, <4 x i1> <i1 true, i1 false, i1 true, i1 false>, <4 x i8> poison)
+  tail call void @llvm.masked.store.v4i8.p1(<4 x i8> %a.load, ptr addrspace(1) align 4 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; On sm_100+, we pack 2xf32 loads into a single b64 load during lowering
+define void @global_2xf32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<3>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf";
+; SM90-NEXT: ld.global.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_2xf32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<2>;
+; SM100-NEXT: .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 0xf";
+; SM100-NEXT: ld.global.b64 %rd2, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_param_1];
+; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT: st.global.b32 [%rd3], %r1;
+; SM100-NEXT: ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xf32_invariant(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_2xf32_invariant(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<3>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM90-NEXT: .pragma "used_bytes_mask 0xf";
+; SM90-NEXT: ld.global.nc.v2.b32 {%r1, %r2}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_2xf32_invariant_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_2xf32_invariant(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<2>;
+; SM100-NEXT: .reg .b64 %rd<4>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_2xf32_invariant_param_0];
+; SM100-NEXT: .pragma "used_bytes_mask 0xf";
+; SM100-NEXT: ld.global.nc.b64 %rd2, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd3, [global_2xf32_invariant_param_1];
+; SM100-NEXT: mov.b64 {%r1, _}, %rd2;
+; SM100-NEXT: st.global.b32 [%rd3], %r1;
+; SM100-NEXT: ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 8 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison), !invariant.load !0
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+define void @global_2xf32_no_align(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_2xf32_no_align(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<2>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [global_2xf32_no_align_param_0];
+; CHECK-NEXT: ld.global.b32 %r1, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [global_2xf32_no_align_param_1];
+; CHECK-NEXT: st.global.b32 [%rd2], %r1;
+; CHECK-NEXT: ret;
+  %a.load = tail call <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1) align 4 %a, <2 x i1> <i1 true, i1 false>, <2 x float> poison)
+  tail call void @llvm.masked.store.v2f32.p1(<2 x float> %a.load, ptr addrspace(1) align 8 %b, <2 x i1> <i1 true, i1 false>)
+  ret void
+}
+
+declare <8 x i32> @llvm.masked.load.v8i32.p1(ptr addrspace(1), <8 x i1>, <8 x i32>)
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare <16 x i16> @llvm.masked.load.v16i16.p1(ptr addrspace(1), <16 x i1>, <16 x i16>)
+declare void @llvm.masked.store.v16i16.p1(<16 x i16>, ptr addrspace(1), <16 x i1>)
+declare <2 x i16> @llvm.masked.load.v2i16.p1(ptr addrspace(1), <2 x i1>, <2 x i16>)
+declare void @llvm.masked.store.v2i16.p1(<2 x i16>, ptr addrspace(1), <2 x i1>)
+declare <4 x i8> @llvm.masked.load.v4i8.p1(ptr addrspace(1), <4 x i1>, <4 x i8>)
+declare void @llvm.masked.store.v4i8.p1(<4 x i8>, ptr addrspace(1), <4 x i1>)
+declare <2 x float> @llvm.masked.load.v2f32.p1(ptr addrspace(1), <2 x i1>, <2 x float>)
+declare void @llvm.masked.store.v2f32.p1(<2 x float>, ptr addrspace(1), <2 x i1>)
+!0 = !{}
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
new file mode 100644
index 0000000000000..9f23acaf93bc8
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-variable-mask.ll
@@ -0,0 +1,56 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; Confirm that a masked store with a variable mask is scalarized before lowering
+
+define void @global_variable_mask(ptr addrspace(1) %a, ptr addrspace(1) %b, <4 x i1> %mask) {
+; CHECK-LABEL: global_variable_mask(
+; CHECK: {
+; CHECK-NEXT: .reg .pred %p<9>;
+; CHECK-NEXT: .reg .b16 %rs<9>;
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b8 %rs1, [global_variable_mask_param_2+3];
+; CHECK-NEXT: ld.param.b8 %rs3, [global_variable_mask_param_2+2];
+; CHECK-NEXT: and.b16 %rs4, %rs3, 1;
+; CHECK-NEXT: ld.param.b8 %rs5, [global_variable_mask_param_2+1];
+; CHECK-NEXT: and.b16 %rs6, %rs5, 1;
+; CHECK-NEXT: setp.ne.b16 %p2, %rs6, 0;
+; CHECK-NEXT: ld.param.b8 %rs7, [global_variable_mask_param_2];
+; CHECK-NEXT: and.b16 %rs8, %rs7, 1;
+; CHECK-NEXT: setp.ne.b16 %p1, %rs8, 0;
+; CHECK-NEXT: ld.param.b64 %rd5, [global_variable_mask_param_1];
+; CHECK-NEXT: ld.param.b64 %rd6, [global_variable_mask_param_0];
+; CHECK-NEXT: ld.global.v4.b64 {%rd1, %rd2, %rd3, %rd4}, [%rd6];
+; CHECK-NEXT: not.pred %p5, %p1;
+; CHECK-NEXT: @%p5 bra $L__BB0_2;
+; CHECK-NEXT: // %bb.1: // %cond.store
+; CHECK-NEXT: st.global.b64 [%rd5], %rd1;
+; CHECK-NEXT: $L__BB0_2: // %else
+; CHECK-NEXT: and.b16 %rs2, %rs1, 1;
+; CHECK-NEXT: setp.ne.b16 %p3, %rs4, 0;
+; CHECK-NEXT: not.pred %p6, %p2;
+; CHECK-NEXT: @%p6 bra $L__BB0_4;
+; CHECK-NEXT: // %bb.3: // %cond.store1
+; CHECK-NEXT: st.global.b64 [%rd5+8], %rd2;
+; CHECK-NEXT: $L__BB0_4: // %else2
+; CHECK-NEXT: setp.ne.b16 %p4, %rs2, 0;
+; CHECK-NEXT: not.pred %p7, %p3;
+; CHECK-NEXT: @%p7 bra $L__BB0_6;
+; CHECK-NEXT: // %bb.5: // %cond.store3
+; CHECK-NEXT: st.global.b64 [%rd5+16], %rd3;
+; CHECK-NEXT: $L__BB0_6: // %else4
+; CHECK-NEXT: not.pred %p8, %p4;
+; CHECK-NEXT: @%p8 bra $L__BB0_8;
+; CHECK-NEXT: // %bb.7: // %cond.store5
+; CHECK-NEXT: st.global.b64 [%rd5+24], %rd4;
+; CHECK-NEXT: $L__BB0_8: // %else6
+; CHECK-NEXT: ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> %mask)
+  ret void
+}
+
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
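The branchy PTX above is the generic per-lane expansion of a masked store whose mask is not a compile-time constant: each lane is tested and conditionally stored. A minimal IR-level sketch of that expansion for lane 0 only; the function and block names are illustrative and do not appear in the patch.

define void @scalarized_lane0_sketch(<4 x i64> %v, ptr addrspace(1) %p, <4 x i1> %mask) {
  ; test lane 0 of the mask and skip the store when it is clear
  %m0 = extractelement <4 x i1> %mask, i64 0
  br i1 %m0, label %cond.store, label %else

cond.store:
  %e0 = extractelement <4 x i64> %v, i64 0
  store i64 %e0, ptr addrspace(1) %p, align 32
  br label %else

else:
  ; lanes 1-3 are handled the same way at offsets 8, 16 and 24
  ret void
}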
diff --git a/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
new file mode 100644
index 0000000000000..feb7b7e0a0b39
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/masked-store-vectors-256.ll
@@ -0,0 +1,318 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_90 | FileCheck %s -check-prefixes=CHECK,SM90
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_90 | %ptxas-verify -arch=sm_90 %}
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | FileCheck %s -check-prefixes=CHECK,SM100
+; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_100 -mattr=+ptx88 | %ptxas-verify -arch=sm_100 %}
+
+; This test is based on load-store-vectors.ll,
+; and tests the lowering of 256-bit masked vector stores.
+
+; Types we are checking: i32, i64, f32, f64
+
+; Address spaces we are checking: generic, global
+; - Global is the only address space that currently supports masked stores.
+; - The generic stores will get legalized before the backend via scalarization;
+;   this file tests that lowering even though the LSV will not generate them.
+
+; 256-bit vector loads/stores are only legal on Blackwell (sm_100+), so on sm_90 the vectors will be split
+
+; generic address space
+
+define void @generic_8xi32(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xi32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xi32_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xi32_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+  %a.load = load <8 x i32>, ptr %a
+  tail call void @llvm.masked.store.v8i32.p0(<8 x i32> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xi64(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xi64(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xi64_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xi64_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+  %a.load = load <4 x i64>, ptr %a
+  tail call void @llvm.masked.store.v4i64.p0(<4 x i64> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @generic_8xfloat(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_8xfloat(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<9>;
+; CHECK-NEXT: .reg .b64 %rd<3>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_8xfloat_param_0];
+; CHECK-NEXT: ld.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; CHECK-NEXT: ld.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd2, [generic_8xfloat_param_1];
+; CHECK-NEXT: st.b32 [%rd2], %r5;
+; CHECK-NEXT: st.b32 [%rd2+8], %r7;
+; CHECK-NEXT: st.b32 [%rd2+28], %r4;
+; CHECK-NEXT: ret;
+  %a.load = load <8 x float>, ptr %a
+  tail call void @llvm.masked.store.v8f32.p0(<8 x float> %a.load, ptr align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @generic_4xdouble(ptr %a, ptr %b) {
+; CHECK-LABEL: generic_4xdouble(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [generic_4xdouble_param_0];
+; CHECK-NEXT: ld.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; CHECK-NEXT: ld.v2.b64 {%rd4, %rd5}, [%rd1];
+; CHECK-NEXT: ld.param.b64 %rd6, [generic_4xdouble_param_1];
+; CHECK-NEXT: st.b64 [%rd6], %rd4;
+; CHECK-NEXT: st.b64 [%rd6+16], %rd2;
+; CHECK-NEXT: ret;
+  %a.load = load <4 x double>, ptr %a
+  tail call void @llvm.masked.store.v4f64.p0(<4 x double> %a.load, ptr align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+; global address space
+
+define void @global_8xi32(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xi32_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xi64(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xi64(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xi64(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xi64_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xi64_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+  %a.load = load <4 x i64>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4i64.p1(<4 x i64> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
+define void @global_8xfloat(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xfloat(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+8], %r7;
+; SM90-NEXT: st.global.b32 [%rd2+28], %r4;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xfloat(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xfloat_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [global_8xfloat_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, _, %r3, _, _, _, _, %r8};
+; SM100-NEXT: ret;
+  %a.load = load <8 x float>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8f32.p1(<8 x float> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 false, i1 true, i1 false, i1 false, i1 false, i1 false, i1 true>)
+  ret void
+}
+
+define void @global_4xdouble(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_4xdouble(
+; SM90: {
+; SM90-NEXT: .reg .b64 %rd<7>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM90-NEXT: ld.global.v2.b64 {%rd2, %rd3}, [%rd1+16];
+; SM90-NEXT: ld.global.v2.b64 {%rd4, %rd5}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM90-NEXT: st.global.b64 [%rd6], %rd4;
+; SM90-NEXT: st.global.b64 [%rd6+16], %rd2;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_4xdouble(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_4xdouble_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_4xdouble_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, _, %rd4, _};
+; SM100-NEXT: ret;
+  %a.load = load <4 x double>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v4f64.p1(<4 x double> %a.load, ptr addrspace(1) align 32 %b, <4 x i1> <i1 true, i1 false, i1 true, i1 false>)
+  ret void
+}
+
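As a reading aid for the SM90 checks above: when the 256-bit access is not legal, the data and its mask are split in half together, so each 128-bit half keeps its own lanes. An editorial sketch for global_8xi32, not part of the patch:

; <8 x i1> <1,0,1,0,0,0,0,1> splits into
;   lo <4 x i1> <1,0,1,0> -> elements 0 and 2 -> st.global.b32 [%rd2], [%rd2+8]
;   hi <4 x i1> <0,0,0,1> -> element 7        -> st.global.b32 [%rd2+28]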
+; edge cases
+define void @global_8xi32_all_mask_on(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; SM90-LABEL: global_8xi32_all_mask_on(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1+16];
+; SM90-NEXT: ld.param.b64 %rd2, [global_8xi32_all_mask_on_param_1];
+; SM90-NEXT: st.global.v4.b32 [%rd2+16], {%r5, %r6, %r7, %r8};
+; SM90-NEXT: st.global.v4.b32 [%rd2], {%r1, %r2, %r3, %r4};
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: global_8xi32_all_mask_on(
+; SM100: {
+; SM100-NEXT: .reg .b64 %rd<7>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [global_8xi32_all_mask_on_param_0];
+; SM100-NEXT: ld.global.v4.b64 {%rd2, %rd3, %rd4, %rd5}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd6, [global_8xi32_all_mask_on_param_1];
+; SM100-NEXT: st.global.v4.b64 [%rd6], {%rd2, %rd3, %rd4, %rd5};
+; SM100-NEXT: ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+define void @global_8xi32_all_mask_off(ptr addrspace(1) %a, ptr addrspace(1) %b) {
+; CHECK-LABEL: global_8xi32_all_mask_off(
+; CHECK: {
+; CHECK-EMPTY:
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ret;
+  %a.load = load <8 x i32>, ptr addrspace(1) %a
+  tail call void @llvm.masked.store.v8i32.p1(<8 x i32> %a.load, ptr addrspace(1) align 32 %b, <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>)
+  ret void
+}
+
+; This is an example pattern for the LSV's output of these masked stores
+define void @vectorizerOutput(ptr addrspace(1) %in, ptr addrspace(1) %out) {
+; SM90-LABEL: vectorizerOutput(
+; SM90: {
+; SM90-NEXT: .reg .b32 %r<9>;
+; SM90-NEXT: .reg .b64 %rd<3>;
+; SM90-EMPTY:
+; SM90-NEXT: // %bb.0:
+; SM90-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM90-NEXT: ld.global.v4.b32 {%r1, %r2, %r3, %r4}, [%rd1+16];
+; SM90-NEXT: ld.global.v4.b32 {%r5, %r6, %r7, %r8}, [%rd1];
+; SM90-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM90-NEXT: st.global.b32 [%rd2], %r5;
+; SM90-NEXT: st.global.b32 [%rd2+4], %r6;
+; SM90-NEXT: st.global.b32 [%rd2+12], %r8;
+; SM90-NEXT: st.global.b32 [%rd2+16], %r1;
+; SM90-NEXT: ret;
+;
+; SM100-LABEL: vectorizerOutput(
+; SM100: {
+; SM100-NEXT: .reg .b32 %r<9>;
+; SM100-NEXT: .reg .b64 %rd<3>;
+; SM100-EMPTY:
+; SM100-NEXT: // %bb.0:
+; SM100-NEXT: ld.param.b64 %rd1, [vectorizerOutput_param_0];
+; SM100-NEXT: ld.global.v8.b32 {%r1, %r2, %r3, %r4, %r5, %r6, %r7, %r8}, [%rd1];
+; SM100-NEXT: ld.param.b64 %rd2, [vectorizerOutput_param_1];
+; SM100-NEXT: st.global.v8.b32 [%rd2], {%r1, %r2, _, %r4, %r5, _, _, _};
+; SM100-NEXT: ret;
+  %1 = load <8 x i32>, ptr addrspace(1) %in, align 32
+  %load05 = extractelement <8 x i32> %1, i32 0
+  %load16 = extractelement <8 x i32> %1, i32 1
+  %load38 = extractelement <8 x i32> %1, i32 3
+  %load49 = extractelement <8 x i32> %1, i32 4
+  %2 = insertelement <8 x i32> poison, i32 %load05, i32 0
+  %3 = insertelement <8 x i32> %2, i32 %load16, i32 1
+  %4 = insertelement <8 x i32> %3, i32 poison, i32 2
+  %5 = insertelement <8 x i32> %4, i32 %load38, i32 3
+  %6 = insertelement <8 x i32> %5, i32 %load49, i32 4
+  %7 = insertelement <8 x i32> %6, i32 poison, i32 5
+  %8 = insertelement <8 x i32> %7, i32 poison, i32 6
+  %9 = insertelement <8 x i32> %8, i32 poison, i32 7
+  call void @llvm.masked.store.v8i32.p1(<8 x i32> %9, ptr addrspace(1) align 32 %out, <8 x i1> <i1 true, i1 true, i1 false, i1 true, i1 true, i1 false, i1 false, i1 false>)
+  ret void
+}
+
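For context on the comment above the vectorizerOutput test: a hypothetical scalar sequence that the LoadStoreVectorizer could turn into exactly this masked-store pattern (elements 0, 1, 3 and 4 stored, the rest skipped) might look like the sketch below; the names are illustrative and this is not part of the patch.

define void @lsv_input_sketch(ptr addrspace(1) %in, ptr addrspace(1) %out) {
  %v = load <8 x i32>, ptr addrspace(1) %in, align 32
  %e0 = extractelement <8 x i32> %v, i32 0
  %e1 = extractelement <8 x i32> %v, i32 1
  %e3 = extractelement <8 x i32> %v, i32 3
  %e4 = extractelement <8 x i32> %v, i32 4
  ; adjacent scalar stores with gaps at elements 2 and 5-7
  store i32 %e0, ptr addrspace(1) %out, align 32
  %p1 = getelementptr inbounds i8, ptr addrspace(1) %out, i64 4
  store i32 %e1, ptr addrspace(1) %p1, align 4
  %p3 = getelementptr inbounds i8, ptr addrspace(1) %out, i64 12
  store i32 %e3, ptr addrspace(1) %p3, align 4
  %p4 = getelementptr inbounds i8, ptr addrspace(1) %out, i64 16
  store i32 %e4, ptr addrspace(1) %p4, align 16
  ret void
}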
+declare void @llvm.masked.store.v8i32.p0(<8 x i32>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4i64.p0(<4 x i64>, ptr, <4 x i1>)
+declare void @llvm.masked.store.v8f32.p0(<8 x float>, ptr, <8 x i1>)
+declare void @llvm.masked.store.v4f64.p0(<4 x double>, ptr, <4 x i1>)
+
+declare void @llvm.masked.store.v8i32.p1(<8 x i32>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4i64.p1(<4 x i64>, ptr addrspace(1), <4 x i1>)
+declare void @llvm.masked.store.v8f32.p1(<8 x float>, ptr addrspace(1), <8 x i1>)
+declare void @llvm.masked.store.v4f64.p1(<4 x double>, ptr addrspace(1), <4 x i1>)
diff --git a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
index dfc84177fb0e6..a84b7fcd33836 100644
--- a/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
+++ b/llvm/test/CodeGen/NVPTX/proxy-reg-erasure.mir
@@ -77,7 +77,7 @@ constants: []
 machineFunctionInfo: {}
 body: |
   bb.0:
-    %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, &retval0, 0 :: (load (s128), addrspace 101)
+    %0:b32, %1:b32, %2:b32, %3:b32 = LDV_i32_v4 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s128), addrspace 101)
     ; CHECK-NOT: ProxyReg
     %4:b32 = ProxyRegB32 killed %0
     %5:b32 = ProxyRegB32 killed %1
@@ -86,7 +86,7 @@ body: |
     ; CHECK: STV_i32_v4 killed %0, killed %1, killed %2, killed %3
     STV_i32_v4 killed %4, killed %5, killed %6, killed %7, 0, 0, 101, 32, &func_retval0, 0 :: (store (s128), addrspace 101)
-    %8:b32 = LD_i32 0, 0, 101, 3, 32, &retval0, 0 :: (load (s32), addrspace 101)
+    %8:b32 = LD_i32 0, 0, 101, 3, 32, -1, &retval0, 0 :: (load (s32), addrspace 101)
     ; CHECK-NOT: ProxyReg
     %9:b32 = ProxyRegB32 killed %8
     %10:b32 = ProxyRegB32 killed %9