author     ReinUsesLisp  2021-02-16 04:10:22 -0300
committer  ameerj        2021-07-22 21:51:22 -0400
commit     b5d7279d878211654b4abb165d94af763a365f47 (patch)
tree       9b3a7b6e9d7d2b8945fe87d27ff75f1712ef06aa /src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
parent     shader: Improve object pool (diff)
spirv: Initial bindings support
Diffstat (limited to 'src/shader_recompiler/ir_opt/constant_propagation_pass.cpp')
-rw-r--r--  src/shader_recompiler/ir_opt/constant_propagation_pass.cpp  76
1 file changed, 64 insertions(+), 12 deletions(-)
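In this file, the patch makes three related changes: FoldWhenAllImmediates moves above the fold routines and now returns bool instead of void, so a caller can tell whether the instruction was actually replaced; a new FoldISub32 folds fully-constant subtractions and cancels subtractions of matching constant-buffer reads; and IR::Opcode::ISub32 is wired into the ConstantPropagation switch, with the SLessThan, ULessThan, and BitFieldUExtract cases adjusted to discard the helper's new return value.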
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index cbde65b9b..f1ad16d60 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -77,6 +77,16 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
     return true;
 }
 
+template <typename Func>
+bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
+    if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
+        return false;
+    }
+    using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>;
+    inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
+    return true;
+}
+
 void FoldGetRegister(IR::Inst& inst) {
     if (inst.Arg(0).Reg() == IR::Reg::RZ) {
         inst.ReplaceUsesWith(IR::Value{u32{0}});
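The Indices pack above is built from LambdaTraits, a helper defined earlier in this file but not shown in this hunk. A minimal sketch of how such a trait can recover the argument count and argument types from a lambda's call operator (this mirrors what the pass relies on, but the exact definition here is an assumption):

    #include <cstddef>
    #include <tuple>
    #include <type_traits>

    // Peel the lambda type down to its operator() to expose the parameter list.
    template <typename Func>
    struct LambdaTraits : LambdaTraits<decltype(&std::remove_reference_t<Func>::operator())> {};

    template <typename ReturnType, typename LambdaType, typename... Args>
    struct LambdaTraits<ReturnType (LambdaType::*)(Args...) const> {
        template <std::size_t I>
        using ArgType = std::tuple_element_t<I, std::tuple<Args...>>;

        static constexpr std::size_t NUM_ARGS{sizeof...(Args)};
    };

    // A two-argument fold lambda yields NUM_ARGS == 2, so EvalImmediates is
    // instantiated with index_sequence<0, 1> and reads exactly two arguments.
    inline constexpr auto sub = [](unsigned a, unsigned b) { return a - b; };
    static_assert(LambdaTraits<decltype(sub)>::NUM_ARGS == 2);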
@@ -103,6 +113,52 @@ void FoldAdd(IR::Inst& inst) {
     }
 }
 
+void FoldISub32(IR::Inst& inst) {
+    if (FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a - b; })) {
+        return;
+    }
+    if (inst.Arg(0).IsImmediate() || inst.Arg(1).IsImmediate()) {
+        return;
+    }
+    // ISub32 is generally used to subtract two constant buffers, compare and replace this with
+    // zero if they equal.
+    const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) {
+        return a->Opcode() == IR::Opcode::GetCbuf && b->Opcode() == IR::Opcode::GetCbuf &&
+               a->Arg(0) == b->Arg(0) && a->Arg(1) == b->Arg(1);
+    }};
+    IR::Inst* op_a{inst.Arg(0).InstRecursive()};
+    IR::Inst* op_b{inst.Arg(1).InstRecursive()};
+    if (equal_cbuf(op_a, op_b)) {
+        inst.ReplaceUsesWith(IR::Value{u32{0}});
+        return;
+    }
+    // It's also possible a value is being added to a cbuf and then subtracted
+    if (op_b->Opcode() == IR::Opcode::IAdd32) {
+        // Canonicalize local variables to simplify the following logic
+        std::swap(op_a, op_b);
+    }
+    if (op_b->Opcode() != IR::Opcode::GetCbuf) {
+        return;
+    }
+    IR::Inst* const inst_cbuf{op_b};
+    if (op_a->Opcode() != IR::Opcode::IAdd32) {
+        return;
+    }
+    IR::Value add_op_a{op_a->Arg(0)};
+    IR::Value add_op_b{op_a->Arg(1)};
+    if (add_op_b.IsImmediate()) {
+        // Canonicalize
+        std::swap(add_op_a, add_op_b);
+    }
+    if (add_op_b.IsImmediate()) {
+        return;
+    }
+    IR::Inst* const add_cbuf{add_op_b.InstRecursive()};
+    if (equal_cbuf(add_cbuf, inst_cbuf)) {
+        inst.ReplaceUsesWith(add_op_a);
+    }
+}
+
 template <typename T>
 void FoldSelect(IR::Inst& inst) {
     const IR::Value cond{inst.Arg(0)};
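The constant-buffer cancellation above rewrites (cbuf + x) - cbuf into x. That rewrite is unconditionally sound for u32 operands because unsigned arithmetic is modular: (a + k) - a == k even when a + k wraps. A standalone check of the identity, with made-up values standing in for the GetCbuf result:

    #include <cassert>
    #include <cstdint>

    int main() {
        // Illustrative values only: `cbuf` stands in for a GetCbuf read and
        // `x` for the value an IAdd32 put on top of it.
        const std::uint32_t cbuf = 0xFFFF'FFF0u;
        const std::uint32_t x = 0x40u;

        // cbuf + x wraps past 2^32, yet the subtraction still cancels exactly,
        // which is why the pass can replace the ISub32 with `x` outright.
        assert((cbuf + x) - cbuf == x);
        return 0;
    }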
@@ -170,15 +226,6 @@ IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<
     return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)};
 }
 
-template <typename Func>
-void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
-    if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
-        return;
-    }
-    using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>;
-    inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
-}
-
 void FoldBranchConditional(IR::Inst& inst) {
     const IR::U1 cond{inst.Arg(0)};
     if (cond.IsImmediate()) {
@@ -205,6 +252,8 @@ void ConstantPropagation(IR::Inst& inst) {
         return FoldGetPred(inst);
     case IR::Opcode::IAdd32:
         return FoldAdd<u32>(inst);
+    case IR::Opcode::ISub32:
+        return FoldISub32(inst);
     case IR::Opcode::BitCastF32U32:
         return FoldBitCast<f32, u32>(inst, IR::Opcode::BitCastU32F32);
     case IR::Opcode::BitCastU32F32:
@@ -220,17 +269,20 @@ void ConstantPropagation(IR::Inst& inst) {
     case IR::Opcode::LogicalNot:
         return FoldLogicalNot(inst);
     case IR::Opcode::SLessThan:
-        return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; });
+        FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; });
+        return;
     case IR::Opcode::ULessThan:
-        return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
+        FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
+        return;
     case IR::Opcode::BitFieldUExtract:
-        return FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
+        FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
             if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {
                 throw LogicError("Undefined result in {}({}, {}, {})", IR::Opcode::BitFieldUExtract,
                                  base, shift, count);
             }
             return (base >> shift) & ((1U << count) - 1);
         });
+        return;
     case IR::Opcode::BranchConditional:
         return FoldBranchConditional(inst);
     default:
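The last hunk is mostly fallout from the helper's new signature: ConstantPropagation returns void, so `return FoldWhenAllImmediates(...);` would no longer compile once the helper returns bool, and each of these cases now invokes the helper and then returns on its own line. As a quick sanity check of the BitFieldUExtract lambda's arithmetic on a made-up input (the constant is mine, the expression is from the code above):

    #include <cstdint>

    // (base >> shift) & ((1U << count) - 1) keeps `count` bits starting at bit
    // `shift`. With base = 0xDEADBEEF, shift = 8, count = 16: 0xDEADBEEF >> 8
    // is 0x00DEADBE, and masking with 0xFFFF leaves 0xADBE.
    static_assert(((std::uint32_t{0xDEADBEEF} >> 8) & ((1U << 16) - 1)) == 0xADBE);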