 src/shader_recompiler/ir_opt/constant_propagation_pass.cpp | 82 ++++++++++++++-
 1 file changed, 77 insertions(+), 5 deletions(-)

diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index f1ad16d60..9eb61b54c 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -9,6 +9,7 @@
 #include "common/bit_cast.h"
 #include "common/bit_util.h"
 #include "shader_recompiler/exception.h"
+#include "shader_recompiler/frontend/ir/ir_emitter.h"
 #include "shader_recompiler/frontend/ir/microinstruction.h"
 #include "shader_recompiler/ir_opt/passes.h"
 
@@ -99,8 +100,71 @@ void FoldGetPred(IR::Inst& inst) {
     }
 }
 
+/// Replaces the pattern generated by two XMAD multiplications
+bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
+    /*
+     * We are looking for this pattern:
+     *   %rhs_bfe = BitFieldUExtract %factor_a, #0, #16 (uses: 1)
+     *   %rhs_mul = IMul32 %rhs_bfe, %factor_b (uses: 1)
+     *   %lhs_bfe = BitFieldUExtract %factor_a, #16, #16 (uses: 1)
+     *   %lhs_mul = IMul32 %lhs_bfe, %factor_b (uses: 1)
+     *   %lhs_shl = ShiftLeftLogical32 %lhs_mul, #16 (uses: 1)
+     *   %result = IAdd32 %lhs_shl, %rhs_mul (uses: 10)
+     *
+     * And replacing it with:
+     *   %result = IMul32 %factor_a, %factor_b
+     *
+     * This optimization is also performed by LLVM and MSVC, so it is taken to be safe.
+     */
+    const IR::Value lhs_arg{inst.Arg(0)};
+    const IR::Value rhs_arg{inst.Arg(1)};
+    if (lhs_arg.IsImmediate() || rhs_arg.IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
+    if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_shl->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
+    if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_b{lhs_mul->Arg(1)};
+    if (lhs_mul->Arg(0).IsImmediate() || rhs_mul->Arg(0).IsImmediate()) {
+        return false;
+    }
+    IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
+    IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
+    if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (rhs_bfe->Arg(1) != IR::Value{0U} || rhs_bfe->Arg(2) != IR::Value{16U}) {
+        return false;
+    }
+    if (lhs_bfe->Arg(0).Resolve() != rhs_bfe->Arg(0).Resolve()) {
+        return false;
+    }
+    const IR::U32 factor_a{lhs_bfe->Arg(0)};
+    IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+    inst.ReplaceUsesWith(ir.IMul(factor_a, factor_b));
+    return true;
+}
+
 template <typename T>
-void FoldAdd(IR::Inst& inst) {
+void FoldAdd(IR::Block& block, IR::Inst& inst) {
     if (inst.HasAssociatedPseudoOperation()) {
         return;
     }
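
The fold above relies on the arithmetic identity behind the XMAD lowering: split one factor into 16-bit halves, multiply each half by the other factor, and recombine; under 32-bit wraparound this reproduces the full product exactly. A minimal standalone sketch (hypothetical helper name, not part of this patch) mirroring the matched IR pattern:

#include <cassert>
#include <cstdint>

// Mirrors the IR pattern matched by FoldXmadMultiply:
//   rhs_mul = BitFieldUExtract(a, 0, 16) * b
//   lhs_shl = (BitFieldUExtract(a, 16, 16) * b) << 16
//   result  = lhs_shl + rhs_mul
uint32_t xmad_style_mul(uint32_t a, uint32_t b) {
    const uint32_t lo{a & 0xffffu};         // BitFieldUExtract %factor_a, #0, #16
    const uint32_t hi{(a >> 16) & 0xffffu}; // BitFieldUExtract %factor_a, #16, #16
    return ((hi * b) << 16) + (lo * b);     // unsigned arithmetic wraps mod 2^32
}

int main() {
    // Both sides reduce to a * b modulo 2^32, so the fold is exact.
    assert(xmad_style_mul(0xdeadbeefu, 0x12345678u) == 0xdeadbeefu * 0x12345678u);
    assert(xmad_style_mul(0xffffffffu, 0xffffffffu) == 0xffffffffu * 0xffffffffu);
    return 0;
}
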
@@ -110,6 +174,12 @@ void FoldAdd(IR::Inst& inst) {
     const IR::Value rhs{inst.Arg(1)};
     if (rhs.IsImmediate() && Arg<T>(rhs) == 0) {
         inst.ReplaceUsesWith(inst.Arg(0));
+        return;
+    }
+    if constexpr (std::is_same_v<T, u32>) {
+        if (FoldXmadMultiply(block, inst)) {
+            return;
+        }
    }
 }
 
@@ -244,14 +314,14 @@ void FoldBranchConditional(IR::Inst& inst) {
     }
 }
 
-void ConstantPropagation(IR::Inst& inst) {
+void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
     switch (inst.Opcode()) {
     case IR::Opcode::GetRegister:
         return FoldGetRegister(inst);
     case IR::Opcode::GetPred:
         return FoldGetPred(inst);
     case IR::Opcode::IAdd32:
-        return FoldAdd<u32>(inst);
+        return FoldAdd<u32>(block, inst);
     case IR::Opcode::ISub32:
         return FoldISub32(inst);
     case IR::Opcode::BitCastF32U32:
@@ -259,7 +329,7 @@ void ConstantPropagation(IR::Inst& inst) {
     case IR::Opcode::BitCastU32F32:
         return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
     case IR::Opcode::IAdd64:
-        return FoldAdd<u64>(inst);
+        return FoldAdd<u64>(block, inst);
     case IR::Opcode::Select32:
         return FoldSelect<u32>(inst);
     case IR::Opcode::LogicalAnd:
@@ -292,7 +362,9 @@ void ConstantPropagation(IR::Inst& inst) {
 } // Anonymous namespace
 
 void ConstantPropagationPass(IR::Block& block) {
-    std::ranges::for_each(block, ConstantPropagation);
+    for (IR::Inst& inst : block) {
+        ConstantPropagation(block, inst);
+    }
 }
 
 } // namespace Shader::Optimization
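
The driver loop changes because ConstantPropagation now takes the enclosing block as well as the instruction, so it no longer matches the unary callable std::ranges::for_each expects; the patch switches to a plain range-for. An equivalent alternative (sketched with stand-in types, not from the patch) would capture the block in a lambda:

#include <algorithm>
#include <list>

// Stand-in types; the real IR::Block is itself iterable over its instructions.
struct Inst {};
struct Block {
    std::list<Inst> instructions;
};

void ConstantPropagation(Block& block, Inst& inst) {
    // The folds from the pass body would be applied here.
    static_cast<void>(block);
    static_cast<void>(inst);
}

void ConstantPropagationPass(Block& block) {
    // A capturing lambda adapts the two-argument function back to the
    // unary shape std::ranges::for_each requires.
    std::ranges::for_each(block.instructions, [&block](Inst& inst) {
        ConstantPropagation(block, inst);
    });
}

int main() {
    Block block{};
    block.instructions.emplace_back();
    ConstantPropagationPass(block);
    return 0;
}
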