-rw-r--r--  src/shader_recompiler/ir_opt/constant_propagation_pass.cpp | 175
1 file changed, 175 insertions, 0 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 08a06da02..c403a5fae 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -3,6 +3,7 @@
 // Refer to the license.txt file included.
 
 #include <algorithm>
+#include <functional>
 #include <tuple>
 #include <type_traits>
 
@@ -88,6 +89,26 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
     return true;
 }
 
+/// Return true when all values in a range are equal
+template <typename Range>
+bool AreEqual(const Range& range) {
+    auto resolver{[](const auto& value) { return value.Resolve(); }};
+    auto equal{[](const IR::Value& lhs, const IR::Value& rhs) {
+        if (lhs == rhs) {
+            return true;
+        }
+        // Not equal, but try to match if they read the same constant buffer
+        if (!lhs.IsImmediate() && !rhs.IsImmediate() &&
+            lhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
+            rhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
+            lhs.Inst()->Arg(0) == rhs.Inst()->Arg(0) && lhs.Inst()->Arg(1) == rhs.Inst()->Arg(1)) {
+            return true;
+        }
+        return false;
+    }};
+    return std::ranges::adjacent_find(range, std::not_fn(equal), resolver) == std::end(range);
+}
+
 void FoldGetRegister(IR::Inst& inst) {
     if (inst.Arg(0).Reg() == IR::Reg::RZ) {
         inst.ReplaceUsesWith(IR::Value{u32{0}});
@@ -100,6 +121,157 @@ void FoldGetPred(IR::Inst& inst) {
     }
 }
 
+/// Replaces the XMAD pattern generated by an integer FMA
+bool FoldXmadMultiplyAdd(IR::Block& block, IR::Inst& inst) {
+    /*
+     * We are looking for this specific pattern:
+     * %6 = BitFieldUExtract %op_b, #0, #16
+     * %7 = BitFieldUExtract %op_a', #16, #16
+     * %8 = IMul32 %6, %7
+     * %10 = BitFieldUExtract %op_a', #0, #16
+     * %11 = BitFieldInsert %8, %10, #16, #16
+     * %15 = BitFieldUExtract %op_b, #0, #16
+     * %16 = BitFieldUExtract %op_a, #0, #16
+     * %17 = IMul32 %15, %16
+     * %18 = IAdd32 %17, %op_c
+     * %22 = BitFieldUExtract %op_b, #16, #16
+     * %23 = BitFieldUExtract %11, #16, #16
+     * %24 = IMul32 %22, %23
+     * %25 = ShiftLeftLogical32 %24, #16
+     * %26 = ShiftLeftLogical32 %11, #16
+     * %27 = IAdd32 %26, %18
+     * %result = IAdd32 %25, %27
+     *
+     * And replace it with:
+     * %temp = IMul32 %op_a, %op_b
+     * %result = IAdd32 %temp, %op_c
+     *
+     * This fold is justified by reversing Nvidia's compiler logic: since Nvidia's compiler
+     * generates this pattern from 'fma(a, b, c)', we can safely fold it back to a multiply-add.
+     */
+    const IR::Value zero{0u};
+    const IR::Value sixteen{16u};
+    IR::Inst* const _25{inst.Arg(0).TryInstRecursive()};
+    IR::Inst* const _27{inst.Arg(1).TryInstRecursive()};
+    if (!_25 || !_27) {
+        return false;
+    }
+    if (_27->GetOpcode() != IR::Opcode::IAdd32) {
+        return false;
+    }
+    if (_25->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _25->Arg(1) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _24{_25->Arg(0).TryInstRecursive()};
+    if (!_24 || _24->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    IR::Inst* const _22{_24->Arg(0).TryInstRecursive()};
+    IR::Inst* const _23{_24->Arg(1).TryInstRecursive()};
+    if (!_22 || !_23) {
+        return false;
+    }
+    if (_22->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_23->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_22->Arg(1) != sixteen || _22->Arg(2) != sixteen) {
+        return false;
+    }
+    if (_23->Arg(1) != sixteen || _23->Arg(2) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _11{_23->Arg(0).TryInstRecursive()};
+    if (!_11 || _11->GetOpcode() != IR::Opcode::BitFieldInsert) {
+        return false;
+    }
+    if (_11->Arg(2) != sixteen || _11->Arg(3) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _8{_11->Arg(0).TryInstRecursive()};
+    IR::Inst* const _10{_11->Arg(1).TryInstRecursive()};
+    if (!_8 || !_10) {
+        return false;
+    }
+    if (_8->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    if (_10->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    IR::Inst* const _6{_8->Arg(0).TryInstRecursive()};
+    IR::Inst* const _7{_8->Arg(1).TryInstRecursive()};
+    if (!_6 || !_7) {
+        return false;
+    }
+    if (_6->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_7->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_6->Arg(1) != zero || _6->Arg(2) != sixteen) {
+        return false;
+    }
+    if (_7->Arg(1) != sixteen || _7->Arg(2) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _26{_27->Arg(0).TryInstRecursive()};
+    IR::Inst* const _18{_27->Arg(1).TryInstRecursive()};
+    if (!_26 || !_18) {
+        return false;
+    }
+    if (_26->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _26->Arg(1) != sixteen) {
+        return false;
+    }
+    if (_26->Arg(0).InstRecursive() != _11) {
+        return false;
+    }
+    if (_18->GetOpcode() != IR::Opcode::IAdd32) {
+        return false;
+    }
+    IR::Inst* const _17{_18->Arg(0).TryInstRecursive()};
+    if (!_17 || _17->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    IR::Inst* const _15{_17->Arg(0).TryInstRecursive()};
+    IR::Inst* const _16{_17->Arg(1).TryInstRecursive()};
+    if (!_15 || !_16) {
+        return false;
+    }
+    if (_15->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_16->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_15->Arg(1) != zero || _16->Arg(1) != zero || _10->Arg(1) != zero) {
+        return false;
+    }
+    if (_15->Arg(2) != sixteen || _16->Arg(2) != sixteen || _10->Arg(2) != sixteen) {
+        return false;
+    }
+    const std::array<IR::Value, 3> op_as{
+        _7->Arg(0).Resolve(),
+        _16->Arg(0).Resolve(),
+        _10->Arg(0).Resolve(),
+    };
+    const std::array<IR::Value, 3> op_bs{
+        _22->Arg(0).Resolve(),
+        _6->Arg(0).Resolve(),
+        _15->Arg(0).Resolve(),
+    };
+    const IR::U32 op_c{_18->Arg(1)};
+    if (!AreEqual(op_as) || !AreEqual(op_bs)) {
+        return false;
+    }
+    IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+    inst.ReplaceUsesWith(ir.IAdd(ir.IMul(IR::U32{op_as[0]}, IR::U32{op_bs[1]}), op_c));
+    return true;
+}
+
 /// Replaces the pattern generated by two XMAD multiplications
 bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
     /*
@@ -179,6 +351,9 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
         if (FoldXmadMultiply(block, inst)) {
             return;
         }
+        if (FoldXmadMultiplyAdd(block, inst)) {
+            return;
+        }
     }
 }
 
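For reference, here is a standalone sketch, separate from the patch above, that mirrors the matched %6..%27 chain with plain uint32_t arithmetic and checks that it always evaluates to a * b + c modulo 2^32, i.e. exactly the IMul32/IAdd32 pair the fold substitutes. It assumes BitFieldUExtract(value, #offset, #16) reads 16 bits starting at offset and BitFieldInsert(base, insert, #16, #16) overwrites the upper half of base; the helper names below exist only in this sketch.

// Standalone sanity check of the arithmetic behind FoldXmadMultiplyAdd.
#include <cassert>
#include <cstdint>
#include <random>

namespace {

// BitFieldUExtract value, #offset, #16 (assumed semantics)
std::uint32_t Extract(std::uint32_t value, int offset) {
    return (value >> offset) & 0xffffu;
}

// Replays the matched pattern on concrete 32-bit values
std::uint32_t XmadMultiplyAddPattern(std::uint32_t a, std::uint32_t b, std::uint32_t c) {
    const std::uint32_t v6{Extract(b, 0)};                      // %6  = b[0:16]
    const std::uint32_t v7{Extract(a, 16)};                     // %7  = a[16:16]
    const std::uint32_t v8{v6 * v7};                            // %8  = IMul32 %6, %7
    const std::uint32_t v10{Extract(a, 0)};                     // %10 = a[0:16]
    const std::uint32_t v11{(v8 & 0xffffu) | (v10 << 16)};      // %11 = BitFieldInsert %8, %10, #16, #16
    const std::uint32_t v17{Extract(b, 0) * Extract(a, 0)};     // %17 = IMul32 %15, %16
    const std::uint32_t v18{v17 + c};                           // %18 = IAdd32 %17, c
    const std::uint32_t v24{Extract(b, 16) * Extract(v11, 16)}; // %24 = IMul32 %22, %23
    const std::uint32_t v25{v24 << 16};                         // %25 = ShiftLeftLogical32 %24, #16
    const std::uint32_t v26{v11 << 16};                         // %26 = ShiftLeftLogical32 %11, #16
    const std::uint32_t v27{v26 + v18};                         // %27 = IAdd32 %26, %18
    return v25 + v27;                                           // %result = IAdd32 %25, %27
}

} // namespace

int main() {
    std::mt19937 rng{1337};
    for (int i = 0; i < 100'000; ++i) {
        const auto a{static_cast<std::uint32_t>(rng())};
        const auto b{static_cast<std::uint32_t>(rng())};
        const auto c{static_cast<std::uint32_t>(rng())};
        // The expanded XMAD chain must equal the folded IMul32 + IAdd32 form.
        assert(XmadMultiplyAddPattern(a, b, c) == a * b + c);
    }
}

The equivalence follows from a * b mod 2^32 = a_lo * b_lo + ((a_lo * b_hi + a_hi * b_lo) << 16), since the a_hi * b_hi term is shifted out of the low 32 bits.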