diff options
Diffstat (limited to 'src/shader_recompiler/ir_opt/constant_propagation_pass.cpp')
| -rw-r--r-- | src/shader_recompiler/ir_opt/constant_propagation_pass.cpp | 49 |
1 file changed, 26 insertions, 23 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp index 1720d7a09..61fbbe04c 100644 --- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp +++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp | |||
| @@ -58,7 +58,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | |||
| 58 | } | 58 | } |
| 59 | if (is_lhs_immediate && !is_rhs_immediate) { | 59 | if (is_lhs_immediate && !is_rhs_immediate) { |
| 60 | IR::Inst* const rhs_inst{rhs.InstRecursive()}; | 60 | IR::Inst* const rhs_inst{rhs.InstRecursive()}; |
| 61 | if (rhs_inst->Opcode() == inst.Opcode() && rhs_inst->Arg(1).IsImmediate()) { | 61 | if (rhs_inst->GetOpcode() == inst.GetOpcode() && rhs_inst->Arg(1).IsImmediate()) { |
| 62 | const auto combined{imm_fn(Arg<T>(lhs), Arg<T>(rhs_inst->Arg(1)))}; | 62 | const auto combined{imm_fn(Arg<T>(lhs), Arg<T>(rhs_inst->Arg(1)))}; |
| 63 | inst.SetArg(0, rhs_inst->Arg(0)); | 63 | inst.SetArg(0, rhs_inst->Arg(0)); |
| 64 | inst.SetArg(1, IR::Value{combined}); | 64 | inst.SetArg(1, IR::Value{combined}); |
| @@ -70,7 +70,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) { | |||
| 70 | } | 70 | } |
| 71 | if (!is_lhs_immediate && is_rhs_immediate) { | 71 | if (!is_lhs_immediate && is_rhs_immediate) { |
| 72 | const IR::Inst* const lhs_inst{lhs.InstRecursive()}; | 72 | const IR::Inst* const lhs_inst{lhs.InstRecursive()}; |
| 73 | if (lhs_inst->Opcode() == inst.Opcode() && lhs_inst->Arg(1).IsImmediate()) { | 73 | if (lhs_inst->GetOpcode() == inst.GetOpcode() && lhs_inst->Arg(1).IsImmediate()) { |
| 74 | const auto combined{imm_fn(Arg<T>(rhs), Arg<T>(lhs_inst->Arg(1)))}; | 74 | const auto combined{imm_fn(Arg<T>(rhs), Arg<T>(lhs_inst->Arg(1)))}; |
| 75 | inst.SetArg(0, lhs_inst->Arg(0)); | 75 | inst.SetArg(0, lhs_inst->Arg(0)); |
| 76 | inst.SetArg(1, IR::Value{combined}); | 76 | inst.SetArg(1, IR::Value{combined}); |
| @@ -123,7 +123,8 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) { | |||
| 123 | return false; | 123 | return false; |
| 124 | } | 124 | } |
| 125 | IR::Inst* const lhs_shl{lhs_arg.InstRecursive()}; | 125 | IR::Inst* const lhs_shl{lhs_arg.InstRecursive()}; |
| 126 | if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) { | 126 | if (lhs_shl->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || |
| 127 | lhs_shl->Arg(1) != IR::Value{16U}) { | ||
| 127 | return false; | 128 | return false; |
| 128 | } | 129 | } |
| 129 | if (lhs_shl->Arg(0).IsImmediate()) { | 130 | if (lhs_shl->Arg(0).IsImmediate()) { |
| @@ -131,7 +132,7 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) { | |||
| 131 | } | 132 | } |
| 132 | IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()}; | 133 | IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()}; |
| 133 | IR::Inst* const rhs_mul{rhs_arg.InstRecursive()}; | 134 | IR::Inst* const rhs_mul{rhs_arg.InstRecursive()}; |
| 134 | if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) { | 135 | if (lhs_mul->GetOpcode() != IR::Opcode::IMul32 || rhs_mul->GetOpcode() != IR::Opcode::IMul32) { |
| 135 | return false; | 136 | return false; |
| 136 | } | 137 | } |
| 137 | if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) { | 138 | if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) { |
| @@ -143,10 +144,10 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) { | |||
| 143 | } | 144 | } |
| 144 | IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()}; | 145 | IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()}; |
| 145 | IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()}; | 146 | IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()}; |
| 146 | if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) { | 147 | if (lhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) { |
| 147 | return false; | 148 | return false; |
| 148 | } | 149 | } |
| 149 | if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) { | 150 | if (rhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) { |
| 150 | return false; | 151 | return false; |
| 151 | } | 152 | } |
| 152 | if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) { | 153 | if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) { |
| @@ -194,8 +195,9 @@ void FoldISub32(IR::Inst& inst) { | |||
| 194 | // ISub32 is generally used to subtract two constant buffers, compare and replace this with | 195 | // ISub32 is generally used to subtract two constant buffers, compare and replace this with |
| 195 | // zero if they equal. | 196 | // zero if they equal. |
| 196 | const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) { | 197 | const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) { |
| 197 | return a->Opcode() == IR::Opcode::GetCbufU32 && b->Opcode() == IR::Opcode::GetCbufU32 && | 198 | return a->GetOpcode() == IR::Opcode::GetCbufU32 && |
| 198 | a->Arg(0) == b->Arg(0) && a->Arg(1) == b->Arg(1); | 199 | b->GetOpcode() == IR::Opcode::GetCbufU32 && a->Arg(0) == b->Arg(0) && |
| 200 | a->Arg(1) == b->Arg(1); | ||
| 199 | }}; | 201 | }}; |
| 200 | IR::Inst* op_a{inst.Arg(0).InstRecursive()}; | 202 | IR::Inst* op_a{inst.Arg(0).InstRecursive()}; |
| 201 | IR::Inst* op_b{inst.Arg(1).InstRecursive()}; | 203 | IR::Inst* op_b{inst.Arg(1).InstRecursive()}; |
| @@ -204,15 +206,15 @@ void FoldISub32(IR::Inst& inst) { | |||
| 204 | return; | 206 | return; |
| 205 | } | 207 | } |
| 206 | // It's also possible a value is being added to a cbuf and then subtracted | 208 | // It's also possible a value is being added to a cbuf and then subtracted |
| 207 | if (op_b->Opcode() == IR::Opcode::IAdd32) { | 209 | if (op_b->GetOpcode() == IR::Opcode::IAdd32) { |
| 208 | // Canonicalize local variables to simplify the following logic | 210 | // Canonicalize local variables to simplify the following logic |
| 209 | std::swap(op_a, op_b); | 211 | std::swap(op_a, op_b); |
| 210 | } | 212 | } |
| 211 | if (op_b->Opcode() != IR::Opcode::GetCbufU32) { | 213 | if (op_b->GetOpcode() != IR::Opcode::GetCbufU32) { |
| 212 | return; | 214 | return; |
| 213 | } | 215 | } |
| 214 | IR::Inst* const inst_cbuf{op_b}; | 216 | IR::Inst* const inst_cbuf{op_b}; |
| 215 | if (op_a->Opcode() != IR::Opcode::IAdd32) { | 217 | if (op_a->GetOpcode() != IR::Opcode::IAdd32) { |
| 216 | return; | 218 | return; |
| 217 | } | 219 | } |
| 218 | IR::Value add_op_a{op_a->Arg(0)}; | 220 | IR::Value add_op_a{op_a->Arg(0)}; |
| @@ -250,7 +252,8 @@ void FoldFPMul32(IR::Inst& inst) { | |||
| 250 | } | 252 | } |
| 251 | IR::Inst* const lhs_op{lhs_value.InstRecursive()}; | 253 | IR::Inst* const lhs_op{lhs_value.InstRecursive()}; |
| 252 | IR::Inst* const rhs_op{rhs_value.InstRecursive()}; | 254 | IR::Inst* const rhs_op{rhs_value.InstRecursive()}; |
| 253 | if (lhs_op->Opcode() != IR::Opcode::FPMul32 || rhs_op->Opcode() != IR::Opcode::FPRecip32) { | 255 | if (lhs_op->GetOpcode() != IR::Opcode::FPMul32 || |
| 256 | rhs_op->GetOpcode() != IR::Opcode::FPRecip32) { | ||
| 254 | return; | 257 | return; |
| 255 | } | 258 | } |
| 256 | const IR::Value recip_source{rhs_op->Arg(0)}; | 259 | const IR::Value recip_source{rhs_op->Arg(0)}; |
| @@ -260,8 +263,8 @@ void FoldFPMul32(IR::Inst& inst) { | |||
| 260 | } | 263 | } |
| 261 | IR::Inst* const attr_a{recip_source.InstRecursive()}; | 264 | IR::Inst* const attr_a{recip_source.InstRecursive()}; |
| 262 | IR::Inst* const attr_b{lhs_mul_source.InstRecursive()}; | 265 | IR::Inst* const attr_b{lhs_mul_source.InstRecursive()}; |
| 263 | if (attr_a->Opcode() != IR::Opcode::GetAttribute || | 266 | if (attr_a->GetOpcode() != IR::Opcode::GetAttribute || |
| 264 | attr_b->Opcode() != IR::Opcode::GetAttribute) { | 267 | attr_b->GetOpcode() != IR::Opcode::GetAttribute) { |
| 265 | return; | 268 | return; |
| 266 | } | 269 | } |
| 267 | if (attr_a->Arg(0).Attribute() == attr_b->Arg(0).Attribute()) { | 270 | if (attr_a->Arg(0).Attribute() == attr_b->Arg(0).Attribute()) { |
| @@ -304,7 +307,7 @@ void FoldLogicalNot(IR::Inst& inst) { | |||
| 304 | return; | 307 | return; |
| 305 | } | 308 | } |
| 306 | IR::Inst* const arg{value.InstRecursive()}; | 309 | IR::Inst* const arg{value.InstRecursive()}; |
| 307 | if (arg->Opcode() == IR::Opcode::LogicalNot) { | 310 | if (arg->GetOpcode() == IR::Opcode::LogicalNot) { |
| 308 | inst.ReplaceUsesWith(arg->Arg(0)); | 311 | inst.ReplaceUsesWith(arg->Arg(0)); |
| 309 | } | 312 | } |
| 310 | } | 313 | } |
| @@ -317,12 +320,12 @@ void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) { | |||
| 317 | return; | 320 | return; |
| 318 | } | 321 | } |
| 319 | IR::Inst* const arg_inst{value.InstRecursive()}; | 322 | IR::Inst* const arg_inst{value.InstRecursive()}; |
| 320 | if (arg_inst->Opcode() == reverse) { | 323 | if (arg_inst->GetOpcode() == reverse) { |
| 321 | inst.ReplaceUsesWith(arg_inst->Arg(0)); | 324 | inst.ReplaceUsesWith(arg_inst->Arg(0)); |
| 322 | return; | 325 | return; |
| 323 | } | 326 | } |
| 324 | if constexpr (op == IR::Opcode::BitCastF32U32) { | 327 | if constexpr (op == IR::Opcode::BitCastF32U32) { |
| 325 | if (arg_inst->Opcode() == IR::Opcode::GetCbufU32) { | 328 | if (arg_inst->GetOpcode() == IR::Opcode::GetCbufU32) { |
| 326 | // Replace the bitcast with a typed constant buffer read | 329 | // Replace the bitcast with a typed constant buffer read |
| 327 | inst.ReplaceOpcode(IR::Opcode::GetCbufF32); | 330 | inst.ReplaceOpcode(IR::Opcode::GetCbufF32); |
| 328 | inst.SetArg(0, arg_inst->Arg(0)); | 331 | inst.SetArg(0, arg_inst->Arg(0)); |
| @@ -338,7 +341,7 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) { | |||
| 338 | return; | 341 | return; |
| 339 | } | 342 | } |
| 340 | IR::Inst* const arg_inst{value.InstRecursive()}; | 343 | IR::Inst* const arg_inst{value.InstRecursive()}; |
| 341 | if (arg_inst->Opcode() == reverse) { | 344 | if (arg_inst->GetOpcode() == reverse) { |
| 342 | inst.ReplaceUsesWith(arg_inst->Arg(0)); | 345 | inst.ReplaceUsesWith(arg_inst->Arg(0)); |
| 343 | return; | 346 | return; |
| 344 | } | 347 | } |
| @@ -347,7 +350,7 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) { | |||
| 347 | template <typename Func, size_t... I> | 350 | template <typename Func, size_t... I> |
| 348 | IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) { | 351 | IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) { |
| 349 | using Traits = LambdaTraits<decltype(func)>; | 352 | using Traits = LambdaTraits<decltype(func)>; |
| 350 | return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)}; | 353 | return IR::Value{func(Arg<typename Traits::template ArgType<I>>(inst.Arg(I))...)}; |
| 351 | } | 354 | } |
| 352 | 355 | ||
| 353 | void FoldBranchConditional(IR::Inst& inst) { | 356 | void FoldBranchConditional(IR::Inst& inst) { |
| @@ -357,7 +360,7 @@ void FoldBranchConditional(IR::Inst& inst) { | |||
| 357 | return; | 360 | return; |
| 358 | } | 361 | } |
| 359 | const IR::Inst* cond_inst{cond.InstRecursive()}; | 362 | const IR::Inst* cond_inst{cond.InstRecursive()}; |
| 360 | if (cond_inst->Opcode() == IR::Opcode::LogicalNot) { | 363 | if (cond_inst->GetOpcode() == IR::Opcode::LogicalNot) { |
| 361 | const IR::Value true_label{inst.Arg(1)}; | 364 | const IR::Value true_label{inst.Arg(1)}; |
| 362 | const IR::Value false_label{inst.Arg(2)}; | 365 | const IR::Value false_label{inst.Arg(2)}; |
| 363 | // Remove negation on the conditional (take the parameter out of LogicalNot) and swap | 366 | // Remove negation on the conditional (take the parameter out of LogicalNot) and swap |
| @@ -371,10 +374,10 @@ void FoldBranchConditional(IR::Inst& inst) { | |||
| 371 | std::optional<IR::Value> FoldCompositeExtractImpl(IR::Value inst_value, IR::Opcode insert, | 374 | std::optional<IR::Value> FoldCompositeExtractImpl(IR::Value inst_value, IR::Opcode insert, |
| 372 | IR::Opcode construct, u32 first_index) { | 375 | IR::Opcode construct, u32 first_index) { |
| 373 | IR::Inst* const inst{inst_value.InstRecursive()}; | 376 | IR::Inst* const inst{inst_value.InstRecursive()}; |
| 374 | if (inst->Opcode() == construct) { | 377 | if (inst->GetOpcode() == construct) { |
| 375 | return inst->Arg(first_index); | 378 | return inst->Arg(first_index); |
| 376 | } | 379 | } |
| 377 | if (inst->Opcode() != insert) { | 380 | if (inst->GetOpcode() != insert) { |
| 378 | return std::nullopt; | 381 | return std::nullopt; |
| 379 | } | 382 | } |
| 380 | IR::Value value_index{inst->Arg(2)}; | 383 | IR::Value value_index{inst->Arg(2)}; |
| @@ -410,7 +413,7 @@ void FoldCompositeExtract(IR::Inst& inst, IR::Opcode construct, IR::Opcode inser | |||
| 410 | } | 413 | } |
| 411 | 414 | ||
| 412 | void ConstantPropagation(IR::Block& block, IR::Inst& inst) { | 415 | void ConstantPropagation(IR::Block& block, IR::Inst& inst) { |
| 413 | switch (inst.Opcode()) { | 416 | switch (inst.GetOpcode()) { |
| 414 | case IR::Opcode::GetRegister: | 417 | case IR::Opcode::GetRegister: |
| 415 | return FoldGetRegister(inst); | 418 | return FoldGetRegister(inst); |
| 416 | case IR::Opcode::GetPred: | 419 | case IR::Opcode::GetPred: |