-rw-r--r--   src/shader_recompiler/ir_opt/constant_propagation_pass.cpp   82
1 file changed, 77 insertions(+), 5 deletions(-)
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index f1ad16d60..9eb61b54c 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -9,6 +9,7 @@
 #include "common/bit_cast.h"
 #include "common/bit_util.h"
 #include "shader_recompiler/exception.h"
+#include "shader_recompiler/frontend/ir/ir_emitter.h"
 #include "shader_recompiler/frontend/ir/microinstruction.h"
 #include "shader_recompiler/ir_opt/passes.h"
 
@@ -99,8 +100,71 @@ void FoldGetPred(IR::Inst& inst) {
     }
 }
 
+/// Replaces the pattern generated by two XMAD multiplications
+bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
+    /*
+     * We are looking for this pattern:
+     *   %rhs_bfe = BitFieldUExtract %factor_a, #0, #16 (uses: 1)
+     *   %rhs_mul = IMul32 %rhs_bfe, %factor_b (uses: 1)
+     *   %lhs_bfe = BitFieldUExtract %factor_a, #16, #16 (uses: 1)
+     *   %lhs_mul = IMul32 %lhs_bfe, %factor_b (uses: 1)
+     *   %lhs_shl = ShiftLeftLogical32 %lhs_mul, #16 (uses: 1)
+     *   %result  = IAdd32 %lhs_shl, %rhs_mul (uses: 10)
+     *
+     * And replacing it with:
+     *   %result  = IMul32 %factor_a, %factor_b
+     *
+     * LLVM and MSVC perform this same fold, so it is known to be safe.
+     */
+    const IR::Value lhs_arg{inst.Arg(0)};
+    const IR::Value rhs_arg{inst.Arg(1)};
+    if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
+    if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_shl->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
+    if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_b{lhs_mul->Arg(1)};
+    if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
+    if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_a{lhs_bfe->Arg(0)};
+    IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+    inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
+    return true;
+}
+
 template <typename T>
-void FoldAdd(IR::Inst& inst) {
+void FoldAdd(IR::Block& block, IR::Inst& inst) {
     if (inst.HasAssociatedPseudoOperation()) {
         return;
     }
@@ -110,6 +174,12 @@ void FoldAdd(IR::Inst& inst) {
     const IR::Value rhs{inst.Arg(1)};
     if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
         inst.ReplaceUsesWith(inst.Arg(0));
+        return;
+    }
+    if constexpr (std::is_same_v<T, u32>) {
+        if (FoldXmadMultiply(block, inst)) {
+            return;
+        }
     }
 }
 
@@ -244,14 +314,14 @@ void FoldBranchConditional(IR::Inst& inst) {
     }
 }
 
-void ConstantPropagation(IR::Inst& inst) {
+void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
     switch (inst.Opcode()) {
     case IR::Opcode::GetRegister:
         return FoldGetRegister(inst);
     case IR::Opcode::GetPred:
         return FoldGetPred(inst);
     case IR::Opcode::IAdd32:
-        return FoldAdd<u32>(inst);
+        return FoldAdd<u32>(block, inst);
     case IR::Opcode::ISub32:
         return FoldISub32(inst);
     case IR::Opcode::BitCastF32U32:
@@ -259,7 +329,7 @@ void ConstantPropagation(IR::Inst& inst) {
     case IR::Opcode::BitCastU32F32:
         return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
     case IR::Opcode::IAdd64:
-        return FoldAdd<u64>(inst);
+        return FoldAdd<u64>(block, inst);
     case IR::Opcode::Select32:
         return FoldSelect<u32>(inst);
     case IR::Opcode::LogicalAnd:
@@ -292,7 +362,9 @@ void ConstantPropagation(IR::Inst& inst) {
 } // Anonymous namespace
 
 void ConstantPropagationPass(IR::Block& block) {
-    std::ranges::for_each(block, ConstantPropagation);
+    for (IR::Inst& inst : block) {
+        ConstantPropagation(block, inst);
+    }
 }
 
 } // namespace Shader::Optimization
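
A quick sanity check of the fold above (not part of the patch; XmadExpansion and the sample values are illustrative): the matched sequence computes (factor_a & 0xffff) * factor_b + (((factor_a >> 16) * factor_b) << 16), and because IMul32, ShiftLeftLogical32 and IAdd32 all wrap modulo 2^32, that expression equals factor_a * factor_b for every 32-bit input. The standalone C++ sketch below checks the identity with plain uint32_t arithmetic:

// Standalone sketch, not part of the patch: verifies the identity that
// FoldXmadMultiply relies on, using ordinary wrapping uint32_t math.
#include <cstdint>
#include <cstdio>

// Mirrors the matched IR pattern: the low and high 16-bit halves of factor_a
// are multiplied by factor_b separately and recombined with a shift and add.
static std::uint32_t XmadExpansion(std::uint32_t factor_a, std::uint32_t factor_b) {
    const std::uint32_t rhs_bfe = factor_a & 0xffffu; // BitFieldUExtract %factor_a, #0, #16
    const std::uint32_t rhs_mul = rhs_bfe * factor_b; // IMul32
    const std::uint32_t lhs_bfe = factor_a >> 16;     // BitFieldUExtract %factor_a, #16, #16
    const std::uint32_t lhs_mul = lhs_bfe * factor_b; // IMul32
    const std::uint32_t lhs_shl = lhs_mul << 16;      // ShiftLeftLogical32 %lhs_mul, #16
    return lhs_shl + rhs_mul;                         // IAdd32
}

int main() {
    const std::uint32_t samples[]{0u, 1u, 0xffffu, 0x10000u, 0x12345678u, 0xdeadbeefu, 0xffffffffu};
    for (const std::uint32_t a : samples) {
        for (const std::uint32_t b : samples) {
            if (XmadExpansion(a, b) != a * b) {
                std::printf("mismatch at a=%08x b=%08x\n", static_cast<unsigned>(a),
                            static_cast<unsigned>(b));
                return 1;
            }
        }
    }
    std::printf("XMAD expansion matches IMul32 for all sampled inputs\n");
    return 0;
}

This is also why the pass only has to match unsigned 16-bit extracts at bit offsets 0 and 16: together they cover the full 32-bit operand, and any carries past bit 31 are discarded identically on both sides of the identity.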