summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/shader_recompiler/ir_opt/constant_propagation_pass.cpp')
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp49
1 files changed, 26 insertions, 23 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 1720d7a09..61fbbe04c 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -58,7 +58,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
58 } 58 }
59 if (is_lhs_immediate && !is_rhs_immediate) { 59 if (is_lhs_immediate && !is_rhs_immediate) {
60 IR::Inst* const rhs_inst{rhs.InstRecursive()}; 60 IR::Inst* const rhs_inst{rhs.InstRecursive()};
61 if (rhs_inst->Opcode() == inst.Opcode() && rhs_inst->Arg(1).IsImmediate()) { 61 if (rhs_inst->GetOpcode() == inst.GetOpcode() && rhs_inst->Arg(1).IsImmediate()) {
62 const auto combined{imm_fn(Arg<T>(lhs), Arg<T>(rhs_inst->Arg(1)))}; 62 const auto combined{imm_fn(Arg<T>(lhs), Arg<T>(rhs_inst->Arg(1)))};
63 inst.SetArg(0, rhs_inst->Arg(0)); 63 inst.SetArg(0, rhs_inst->Arg(0));
64 inst.SetArg(1, IR::Value{combined}); 64 inst.SetArg(1, IR::Value{combined});
@@ -70,7 +70,7 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
70 } 70 }
71 if (!is_lhs_immediate && is_rhs_immediate) { 71 if (!is_lhs_immediate && is_rhs_immediate) {
72 const IR::Inst* const lhs_inst{lhs.InstRecursive()}; 72 const IR::Inst* const lhs_inst{lhs.InstRecursive()};
73 if (lhs_inst->Opcode() == inst.Opcode() && lhs_inst->Arg(1).IsImmediate()) { 73 if (lhs_inst->GetOpcode() == inst.GetOpcode() && lhs_inst->Arg(1).IsImmediate()) {
74 const auto combined{imm_fn(Arg<T>(rhs), Arg<T>(lhs_inst->Arg(1)))}; 74 const auto combined{imm_fn(Arg<T>(rhs), Arg<T>(lhs_inst->Arg(1)))};
75 inst.SetArg(0, lhs_inst->Arg(0)); 75 inst.SetArg(0, lhs_inst->Arg(0));
76 inst.SetArg(1, IR::Value{combined}); 76 inst.SetArg(1, IR::Value{combined});
@@ -123,7 +123,8 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
123 return false; 123 return false;
124 } 124 }
125 IR::Inst* const lhs_shl{lhs_arg.InstRecursive()}; 125 IR::Inst* const lhs_shl{lhs_arg.InstRecursive()};
126 if (lhs_shl->Opcode() != IR::Opcode::ShiftLeftLogical32 || lhs_shl->Arg(1) != IR::Value{16U}) { 126 if (lhs_shl->GetOpcode() != IR::Opcode::ShiftLeftLogical32 ||
127 lhs_shl->Arg(1) != IR::Value{16U}) {
127 return false; 128 return false;
128 } 129 }
129 if (lhs_shl->Arg(0).IsImmediate()) { 130 if (lhs_shl->Arg(0).IsImmediate()) {
@@ -131,7 +132,7 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
131 } 132 }
132 IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()}; 133 IR::Inst* const lhs_mul{lhs_shl->Arg(0).InstRecursive()};
133 IR::Inst* const rhs_mul{rhs_arg.InstRecursive()}; 134 IR::Inst* const rhs_mul{rhs_arg.InstRecursive()};
134 if (lhs_mul->Opcode() != IR::Opcode::IMul32 || rhs_mul->Opcode() != IR::Opcode::IMul32) { 135 if (lhs_mul->GetOpcode() != IR::Opcode::IMul32 || rhs_mul->GetOpcode() != IR::Opcode::IMul32) {
135 return false; 136 return false;
136 } 137 }
137 if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) { 138 if (lhs_mul->Arg(1).Resolve() != rhs_mul->Arg(1).Resolve()) {
@@ -143,10 +144,10 @@ bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
143 } 144 }
144 IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()}; 145 IR::Inst* const lhs_bfe{lhs_mul->Arg(0).InstRecursive()};
145 IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()}; 146 IR::Inst* const rhs_bfe{rhs_mul->Arg(0).InstRecursive()};
146 if (lhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) { 147 if (lhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) {
147 return false; 148 return false;
148 } 149 }
149 if (rhs_bfe->Opcode() != IR::Opcode::BitFieldUExtract) { 150 if (rhs_bfe->GetOpcode() != IR::Opcode::BitFieldUExtract) {
150 return false; 151 return false;
151 } 152 }
152 if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) { 153 if (lhs_bfe->Arg(1) != IR::Value{16U} || lhs_bfe->Arg(2) != IR::Value{16U}) {
@@ -194,8 +195,9 @@ void FoldISub32(IR::Inst& inst) {
194 // ISub32 is generally used to subtract two constant buffers, compare and replace this with 195 // ISub32 is generally used to subtract two constant buffers, compare and replace this with
195 // zero if they equal. 196 // zero if they equal.
196 const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) { 197 const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) {
197 return a->Opcode() == IR::Opcode::GetCbufU32 && b->Opcode() == IR::Opcode::GetCbufU32 && 198 return a->GetOpcode() == IR::Opcode::GetCbufU32 &&
198 a->Arg(0) == b->Arg(0) && a->Arg(1) == b->Arg(1); 199 b->GetOpcode() == IR::Opcode::GetCbufU32 && a->Arg(0) == b->Arg(0) &&
200 a->Arg(1) == b->Arg(1);
199 }}; 201 }};
200 IR::Inst* op_a{inst.Arg(0).InstRecursive()}; 202 IR::Inst* op_a{inst.Arg(0).InstRecursive()};
201 IR::Inst* op_b{inst.Arg(1).InstRecursive()}; 203 IR::Inst* op_b{inst.Arg(1).InstRecursive()};
@@ -204,15 +206,15 @@ void FoldISub32(IR::Inst& inst) {
204 return; 206 return;
205 } 207 }
206 // It's also possible a value is being added to a cbuf and then subtracted 208 // It's also possible a value is being added to a cbuf and then subtracted
207 if (op_b->Opcode() == IR::Opcode::IAdd32) { 209 if (op_b->GetOpcode() == IR::Opcode::IAdd32) {
208 // Canonicalize local variables to simplify the following logic 210 // Canonicalize local variables to simplify the following logic
209 std::swap(op_a, op_b); 211 std::swap(op_a, op_b);
210 } 212 }
211 if (op_b->Opcode() != IR::Opcode::GetCbufU32) { 213 if (op_b->GetOpcode() != IR::Opcode::GetCbufU32) {
212 return; 214 return;
213 } 215 }
214 IR::Inst* const inst_cbuf{op_b}; 216 IR::Inst* const inst_cbuf{op_b};
215 if (op_a->Opcode() != IR::Opcode::IAdd32) { 217 if (op_a->GetOpcode() != IR::Opcode::IAdd32) {
216 return; 218 return;
217 } 219 }
218 IR::Value add_op_a{op_a->Arg(0)}; 220 IR::Value add_op_a{op_a->Arg(0)};
@@ -250,7 +252,8 @@ void FoldFPMul32(IR::Inst& inst) {
250 } 252 }
251 IR::Inst* const lhs_op{lhs_value.InstRecursive()}; 253 IR::Inst* const lhs_op{lhs_value.InstRecursive()};
252 IR::Inst* const rhs_op{rhs_value.InstRecursive()}; 254 IR::Inst* const rhs_op{rhs_value.InstRecursive()};
253 if (lhs_op->Opcode() != IR::Opcode::FPMul32 || rhs_op->Opcode() != IR::Opcode::FPRecip32) { 255 if (lhs_op->GetOpcode() != IR::Opcode::FPMul32 ||
256 rhs_op->GetOpcode() != IR::Opcode::FPRecip32) {
254 return; 257 return;
255 } 258 }
256 const IR::Value recip_source{rhs_op->Arg(0)}; 259 const IR::Value recip_source{rhs_op->Arg(0)};
@@ -260,8 +263,8 @@ void FoldFPMul32(IR::Inst& inst) {
260 } 263 }
261 IR::Inst* const attr_a{recip_source.InstRecursive()}; 264 IR::Inst* const attr_a{recip_source.InstRecursive()};
262 IR::Inst* const attr_b{lhs_mul_source.InstRecursive()}; 265 IR::Inst* const attr_b{lhs_mul_source.InstRecursive()};
263 if (attr_a->Opcode() != IR::Opcode::GetAttribute || 266 if (attr_a->GetOpcode() != IR::Opcode::GetAttribute ||
264 attr_b->Opcode() != IR::Opcode::GetAttribute) { 267 attr_b->GetOpcode() != IR::Opcode::GetAttribute) {
265 return; 268 return;
266 } 269 }
267 if (attr_a->Arg(0).Attribute() == attr_b->Arg(0).Attribute()) { 270 if (attr_a->Arg(0).Attribute() == attr_b->Arg(0).Attribute()) {
@@ -304,7 +307,7 @@ void FoldLogicalNot(IR::Inst& inst) {
304 return; 307 return;
305 } 308 }
306 IR::Inst* const arg{value.InstRecursive()}; 309 IR::Inst* const arg{value.InstRecursive()};
307 if (arg->Opcode() == IR::Opcode::LogicalNot) { 310 if (arg->GetOpcode() == IR::Opcode::LogicalNot) {
308 inst.ReplaceUsesWith(arg->Arg(0)); 311 inst.ReplaceUsesWith(arg->Arg(0));
309 } 312 }
310} 313}
@@ -317,12 +320,12 @@ void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
317 return; 320 return;
318 } 321 }
319 IR::Inst* const arg_inst{value.InstRecursive()}; 322 IR::Inst* const arg_inst{value.InstRecursive()};
320 if (arg_inst->Opcode() == reverse) { 323 if (arg_inst->GetOpcode() == reverse) {
321 inst.ReplaceUsesWith(arg_inst->Arg(0)); 324 inst.ReplaceUsesWith(arg_inst->Arg(0));
322 return; 325 return;
323 } 326 }
324 if constexpr (op == IR::Opcode::BitCastF32U32) { 327 if constexpr (op == IR::Opcode::BitCastF32U32) {
325 if (arg_inst->Opcode() == IR::Opcode::GetCbufU32) { 328 if (arg_inst->GetOpcode() == IR::Opcode::GetCbufU32) {
326 // Replace the bitcast with a typed constant buffer read 329 // Replace the bitcast with a typed constant buffer read
327 inst.ReplaceOpcode(IR::Opcode::GetCbufF32); 330 inst.ReplaceOpcode(IR::Opcode::GetCbufF32);
328 inst.SetArg(0, arg_inst->Arg(0)); 331 inst.SetArg(0, arg_inst->Arg(0));
@@ -338,7 +341,7 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
338 return; 341 return;
339 } 342 }
340 IR::Inst* const arg_inst{value.InstRecursive()}; 343 IR::Inst* const arg_inst{value.InstRecursive()};
341 if (arg_inst->Opcode() == reverse) { 344 if (arg_inst->GetOpcode() == reverse) {
342 inst.ReplaceUsesWith(arg_inst->Arg(0)); 345 inst.ReplaceUsesWith(arg_inst->Arg(0));
343 return; 346 return;
344 } 347 }
@@ -347,7 +350,7 @@ void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
347template <typename Func, size_t... I> 350template <typename Func, size_t... I>
348IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) { 351IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) {
349 using Traits = LambdaTraits<decltype(func)>; 352 using Traits = LambdaTraits<decltype(func)>;
350 return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)}; 353 return IR::Value{func(Arg<typename Traits::template ArgType<I>>(inst.Arg(I))...)};
351} 354}
352 355
353void FoldBranchConditional(IR::Inst& inst) { 356void FoldBranchConditional(IR::Inst& inst) {
@@ -357,7 +360,7 @@ void FoldBranchConditional(IR::Inst& inst) {
357 return; 360 return;
358 } 361 }
359 const IR::Inst* cond_inst{cond.InstRecursive()}; 362 const IR::Inst* cond_inst{cond.InstRecursive()};
360 if (cond_inst->Opcode() == IR::Opcode::LogicalNot) { 363 if (cond_inst->GetOpcode() == IR::Opcode::LogicalNot) {
361 const IR::Value true_label{inst.Arg(1)}; 364 const IR::Value true_label{inst.Arg(1)};
362 const IR::Value false_label{inst.Arg(2)}; 365 const IR::Value false_label{inst.Arg(2)};
363 // Remove negation on the conditional (take the parameter out of LogicalNot) and swap 366 // Remove negation on the conditional (take the parameter out of LogicalNot) and swap
@@ -371,10 +374,10 @@ void FoldBranchConditional(IR::Inst& inst) {
371std::optional<IR::Value> FoldCompositeExtractImpl(IR::Value inst_value, IR::Opcode insert, 374std::optional<IR::Value> FoldCompositeExtractImpl(IR::Value inst_value, IR::Opcode insert,
372 IR::Opcode construct, u32 first_index) { 375 IR::Opcode construct, u32 first_index) {
373 IR::Inst* const inst{inst_value.InstRecursive()}; 376 IR::Inst* const inst{inst_value.InstRecursive()};
374 if (inst->Opcode() == construct) { 377 if (inst->GetOpcode() == construct) {
375 return inst->Arg(first_index); 378 return inst->Arg(first_index);
376 } 379 }
377 if (inst->Opcode() != insert) { 380 if (inst->GetOpcode() != insert) {
378 return std::nullopt; 381 return std::nullopt;
379 } 382 }
380 IR::Value value_index{inst->Arg(2)}; 383 IR::Value value_index{inst->Arg(2)};
@@ -410,7 +413,7 @@ void FoldCompositeExtract(IR::Inst& inst, IR::Opcode construct, IR::Opcode inser
410} 413}
411 414
412void ConstantPropagation(IR::Block& block, IR::Inst& inst) { 415void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
413 switch (inst.Opcode()) { 416 switch (inst.GetOpcode()) {
414 case IR::Opcode::GetRegister: 417 case IR::Opcode::GetRegister:
415 return FoldGetRegister(inst); 418 return FoldGetRegister(inst);
416 case IR::Opcode::GetPred: 419 case IR::Opcode::GetPred: