1 | //===- RISCVFoldMasks.cpp - MI Vector Pseudo Mask Peepholes ---------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===---------------------------------------------------------------------===// |
8 | // |
// This pass performs various peephole optimisations that fold masks into vector
// pseudo instructions after instruction selection.
//
// Currently it converts
// PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
// ->
// PseudoVMV_V_V %false, %true, %vl, %sew
//
// and rewrites masked pseudos whose mask is an all-ones VMSET into their
// unmasked equivalents.
16 | // |
17 | //===---------------------------------------------------------------------===// |
18 | |
19 | #include "RISCV.h" |
20 | #include "RISCVISelDAGToDAG.h" |
21 | #include "RISCVSubtarget.h" |
22 | #include "llvm/CodeGen/MachineFunctionPass.h" |
23 | #include "llvm/CodeGen/MachineRegisterInfo.h" |
24 | #include "llvm/CodeGen/TargetInstrInfo.h" |
25 | #include "llvm/CodeGen/TargetRegisterInfo.h" |
26 | |
27 | using namespace llvm; |
28 | |
29 | #define DEBUG_TYPE "riscv-fold-masks" |
30 | |
31 | namespace { |
32 | |
33 | class RISCVFoldMasks : public MachineFunctionPass { |
34 | public: |
35 | static char ID; |
36 | const TargetInstrInfo *TII; |
37 | MachineRegisterInfo *MRI; |
38 | const TargetRegisterInfo *TRI; |
39 | RISCVFoldMasks() : MachineFunctionPass(ID) {} |
40 | |
41 | bool runOnMachineFunction(MachineFunction &MF) override; |
42 | MachineFunctionProperties getRequiredProperties() const override { |
43 | return MachineFunctionProperties().set( |
44 | MachineFunctionProperties::Property::IsSSA); |
45 | } |
46 | |
47 | StringRef getPassName() const override { return "RISC-V Fold Masks" ; } |
48 | |
49 | private: |
50 | bool convertToUnmasked(MachineInstr &MI) const; |
51 | bool convertVMergeToVMv(MachineInstr &MI) const; |
52 | |
53 | bool isAllOnesMask(const MachineInstr *MaskDef) const; |
54 | |
55 | /// Maps uses of V0 to the corresponding def of V0. |
56 | DenseMap<const MachineInstr *, const MachineInstr *> V0Defs; |
57 | }; |
58 | |
59 | } // namespace |
60 | |
61 | char RISCVFoldMasks::ID = 0; |
62 | |
63 | INITIALIZE_PASS(RISCVFoldMasks, DEBUG_TYPE, "RISC-V Fold Masks" , false, false) |
64 | |
65 | bool RISCVFoldMasks::isAllOnesMask(const MachineInstr *MaskDef) const { |
66 | assert(MaskDef && MaskDef->isCopy() && |
67 | MaskDef->getOperand(0).getReg() == RISCV::V0); |
68 | Register SrcReg = TRI->lookThruCopyLike(SrcReg: MaskDef->getOperand(i: 1).getReg(), MRI); |
69 | if (!SrcReg.isVirtual()) |
70 | return false; |
71 | MaskDef = MRI->getVRegDef(Reg: SrcReg); |
72 | if (!MaskDef) |
73 | return false; |
74 | |
75 | // TODO: Check that the VMSET is the expected bitwidth? The pseudo has |
76 | // undefined behaviour if it's the wrong bitwidth, so we could choose to |
77 | // assume that it's all-ones? Same applies to its VL. |
78 | switch (MaskDef->getOpcode()) { |
79 | case RISCV::PseudoVMSET_M_B1: |
80 | case RISCV::PseudoVMSET_M_B2: |
81 | case RISCV::PseudoVMSET_M_B4: |
82 | case RISCV::PseudoVMSET_M_B8: |
83 | case RISCV::PseudoVMSET_M_B16: |
84 | case RISCV::PseudoVMSET_M_B32: |
85 | case RISCV::PseudoVMSET_M_B64: |
86 | return true; |
87 | default: |
88 | return false; |
89 | } |
90 | } |
91 | |
92 | // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to |
93 | // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET. |
94 | bool RISCVFoldMasks::convertVMergeToVMv(MachineInstr &MI) const { |
95 | #define CASE_VMERGE_TO_VMV(lmul) \ |
96 | case RISCV::PseudoVMERGE_VVM_##lmul: \ |
97 | NewOpc = RISCV::PseudoVMV_V_V_##lmul; \ |
98 | break; |
99 | unsigned NewOpc; |
100 | switch (MI.getOpcode()) { |
101 | default: |
102 | return false; |
103 | CASE_VMERGE_TO_VMV(MF8) |
104 | CASE_VMERGE_TO_VMV(MF4) |
105 | CASE_VMERGE_TO_VMV(MF2) |
106 | CASE_VMERGE_TO_VMV(M1) |
107 | CASE_VMERGE_TO_VMV(M2) |
108 | CASE_VMERGE_TO_VMV(M4) |
109 | CASE_VMERGE_TO_VMV(M8) |
110 | } |
111 | |
112 | Register MergeReg = MI.getOperand(i: 1).getReg(); |
113 | Register FalseReg = MI.getOperand(i: 2).getReg(); |
114 | // Check merge == false (or merge == undef) |
115 | if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(SrcReg: MergeReg, MRI) != |
116 | TRI->lookThruCopyLike(SrcReg: FalseReg, MRI)) |
117 | return false; |
118 | |
119 | assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0); |
120 | if (!isAllOnesMask(MaskDef: V0Defs.lookup(Val: &MI))) |
121 | return false; |
122 | |
123 | MI.setDesc(TII->get(Opcode: NewOpc)); |
124 | MI.removeOperand(OpNo: 1); // Merge operand |
125 | MI.tieOperands(DefIdx: 0, UseIdx: 1); // Tie false to dest |
126 | MI.removeOperand(OpNo: 3); // Mask operand |
127 | MI.addOperand( |
128 | Op: MachineOperand::CreateImm(Val: RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED)); |
129 | |
130 | // vmv.v.v doesn't have a mask operand, so we may be able to inflate the |
131 | // register class for the destination and merge operands e.g. VRNoV0 -> VR |
132 | MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg()); |
133 | MRI->recomputeRegClass(Reg: MI.getOperand(i: 1).getReg()); |
134 | return true; |
135 | } |
136 | |
137 | bool RISCVFoldMasks::convertToUnmasked(MachineInstr &MI) const { |
138 | const RISCV::RISCVMaskedPseudoInfo *I = |
139 | RISCV::getMaskedPseudoInfo(MI.getOpcode()); |
140 | if (!I) |
141 | return false; |
142 | |
143 | if (!isAllOnesMask(MaskDef: V0Defs.lookup(Val: &MI))) |
144 | return false; |
145 | |
146 | // There are two classes of pseudos in the table - compares and |
147 | // everything else. See the comment on RISCVMaskedPseudo for details. |
148 | const unsigned Opc = I->UnmaskedPseudo; |
149 | const MCInstrDesc &MCID = TII->get(Opcode: Opc); |
150 | [[maybe_unused]] const bool HasPolicyOp = |
151 | RISCVII::hasVecPolicyOp(TSFlags: MCID.TSFlags); |
152 | const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(Desc: MCID); |
153 | #ifndef NDEBUG |
154 | const MCInstrDesc &MaskedMCID = TII->get(Opcode: MI.getOpcode()); |
155 | assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) == |
156 | RISCVII::hasVecPolicyOp(MCID.TSFlags) && |
157 | "Masked and unmasked pseudos are inconsistent" ); |
158 | assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure" ); |
159 | #endif |
160 | (void)HasPolicyOp; |
161 | |
162 | MI.setDesc(MCID); |
163 | |
164 | // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs? |
165 | unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs(); |
166 | MI.removeOperand(OpNo: MaskOpIdx); |
167 | |
168 | // The unmasked pseudo will no longer be constrained to the vrnov0 reg class, |
169 | // so try and relax it to vr. |
170 | MRI->recomputeRegClass(Reg: MI.getOperand(i: 0).getReg()); |
171 | unsigned PassthruOpIdx = MI.getNumExplicitDefs(); |
172 | if (HasPassthru) { |
173 | if (MI.getOperand(i: PassthruOpIdx).getReg() != RISCV::NoRegister) |
174 | MRI->recomputeRegClass(Reg: MI.getOperand(i: PassthruOpIdx).getReg()); |
175 | } else |
176 | MI.removeOperand(OpNo: PassthruOpIdx); |
177 | |
178 | return true; |
179 | } |
180 | |
181 | bool RISCVFoldMasks::runOnMachineFunction(MachineFunction &MF) { |
182 | if (skipFunction(F: MF.getFunction())) |
183 | return false; |
184 | |
185 | // Skip if the vector extension is not enabled. |
186 | const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); |
187 | if (!ST.hasVInstructions()) |
188 | return false; |
189 | |
190 | TII = ST.getInstrInfo(); |
191 | MRI = &MF.getRegInfo(); |
192 | TRI = MRI->getTargetRegisterInfo(); |
193 | |
194 | bool Changed = false; |
195 | |
196 | // Masked pseudos coming out of isel will have their mask operand in the form: |
197 | // |
198 | // $v0:vr = COPY %mask:vr |
199 | // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr |
200 | // |
201 | // Because $v0 isn't in SSA, keep track of its definition at each use so we |
202 | // can check mask operands. |
203 | for (const MachineBasicBlock &MBB : MF) { |
204 | const MachineInstr *CurrentV0Def = nullptr; |
205 | for (const MachineInstr &MI : MBB) { |
206 | if (MI.readsRegister(RISCV::Reg: V0, TRI)) |
207 | V0Defs[&MI] = CurrentV0Def; |
208 | |
209 | if (MI.definesRegister(RISCV::Reg: V0, TRI)) |
210 | CurrentV0Def = &MI; |
211 | } |
212 | } |
213 | |
214 | for (MachineBasicBlock &MBB : MF) { |
215 | for (MachineInstr &MI : MBB) { |
216 | Changed |= convertToUnmasked(MI); |
217 | Changed |= convertVMergeToVMv(MI); |
218 | } |
219 | } |
220 | |
221 | return Changed; |
222 | } |
223 | |
224 | FunctionPass *llvm::createRISCVFoldMasksPass() { return new RISCVFoldMasks(); } |
225 | |