@@ -13527,6 +13527,40 @@ static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
1352713527 return true;
1352813528}
1352913529
13530+ /// Match the index vector of a scatter or gather node as the shuffle mask
13531+ /// which performs the rearrangement if possible. Will only match if
13532+ /// all lanes are touched, and thus replacing the scatter or gather with
13533+ /// a unit strided access and shuffle is legal.
13534+ static bool matchIndexAsShuffle(EVT VT, SDValue Index, SDValue Mask,
13535+ SmallVector<int> &ShuffleMask) {
13536+ if (!ISD::isConstantSplatVectorAllOnes(Mask.getNode()))
13537+ return false;
13538+ if (!ISD::isBuildVectorOfConstantSDNodes(Index.getNode()))
13539+ return false;
13540+
13541+ const unsigned ElementSize = VT.getScalarStoreSize();
13542+ const unsigned NumElems = VT.getVectorNumElements();
13543+
13544+ // Create the shuffle mask and check all bits active
13545+ assert(ShuffleMask.empty());
13546+ BitVector ActiveLanes(NumElems);
13547+ for (unsigned i = 0; i < Index->getNumOperands(); i++) {
13548+ // TODO: We've found an active bit of UB, and could be
13549+ // more aggressive here if desired.
13550+ if (Index->getOperand(i)->isUndef())
13551+ return false;
13552+ uint64_t C = Index->getConstantOperandVal(i);
13553+ if (C % ElementSize != 0)
13554+ return false;
13555+ C = C / ElementSize;
13556+ if (C >= NumElems)
13557+ return false;
13558+ ShuffleMask.push_back(C);
13559+ ActiveLanes.set(C);
13560+ }
13561+ return ActiveLanes.all();
13562+ }
13563+
1353013564SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1353113565 DAGCombinerInfo &DCI) const {
1353213566 SelectionDAG &DAG = DCI.DAG;
@@ -13874,6 +13908,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1387413908 }
1387513909 case ISD::MGATHER: {
1387613910 const auto *MGN = dyn_cast<MaskedGatherSDNode>(N);
13911+ const EVT VT = N->getValueType(0);
1387713912 SDValue Index = MGN->getIndex();
1387813913 SDValue ScaleOp = MGN->getScale();
1387913914 ISD::MemIndexType IndexType = MGN->getIndexType();
@@ -13894,6 +13929,19 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1389413929 {MGN->getChain(), MGN->getPassThru(), MGN->getMask(),
1389513930 MGN->getBasePtr(), Index, ScaleOp},
1389613931 MGN->getMemOperand(), IndexType, MGN->getExtensionType());
13932+
13933+ SmallVector<int> ShuffleMask;
13934+ if (MGN->getExtensionType() == ISD::NON_EXTLOAD &&
13935+ matchIndexAsShuffle(VT, Index, MGN->getMask(), ShuffleMask)) {
13936+ SDValue Load = DAG.getMaskedLoad(VT, DL, MGN->getChain(),
13937+ MGN->getBasePtr(), DAG.getUNDEF(XLenVT),
13938+ MGN->getMask(), DAG.getUNDEF(VT),
13939+ MGN->getMemoryVT(), MGN->getMemOperand(),
13940+ ISD::UNINDEXED, ISD::NON_EXTLOAD);
13941+ SDValue Shuffle =
13942+ DAG.getVectorShuffle(VT, DL, Load, DAG.getUNDEF(VT), ShuffleMask);
13943+ return DAG.getMergeValues({Shuffle, Load.getValue(1)}, DL);
13944+ }
1389713945 break;
1389813946 }
1389913947 case ISD::MSCATTER:{
@@ -13918,6 +13966,18 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1391813966 {MSN->getChain(), MSN->getValue(), MSN->getMask(), MSN->getBasePtr(),
1391913967 Index, ScaleOp},
1392013968 MSN->getMemOperand(), IndexType, MSN->isTruncatingStore());
13969+
13970+ EVT VT = MSN->getValue()->getValueType(0);
13971+ SmallVector<int> ShuffleMask;
13972+ if (!MSN->isTruncatingStore() &&
13973+ matchIndexAsShuffle(VT, Index, MSN->getMask(), ShuffleMask)) {
13974+ SDValue Shuffle = DAG.getVectorShuffle(VT, DL, MSN->getValue(),
13975+ DAG.getUNDEF(VT), ShuffleMask);
13976+ return DAG.getMaskedStore(MSN->getChain(), DL, Shuffle, MSN->getBasePtr(),
13977+ DAG.getUNDEF(XLenVT), MSN->getMask(),
13978+ MSN->getMemoryVT(), MSN->getMemOperand(),
13979+ ISD::UNINDEXED, false);
13980+ }
1392113981 break;
1392213982 }
1392313983 case ISD::VP_GATHER: {
0 commit comments