-rw-r--r--   src/shader_recompiler/ir_opt/constant_propagation_pass.cpp | 175
1 file changed, 175 insertions, 0 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 08a06da02..c403a5fae 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -3,6 +3,7 @@
 // Refer to the license.txt file included.
 
 #include <algorithm>
+#include <functional>
 #include <tuple>
 #include <type_traits>
 
@@ -88,6 +89,26 @@ bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
     return true;
 }
 
+/// Return true when all values in a range are equal
+template <typename Range>
+bool AreEqual(const Range& range) {
+    auto resolver{[](const auto& value) { return value.Resolve(); }};
+    auto equal{[](const IR::Value& lhs, const IR::Value& rhs) {
+        if (lhs == rhs) {
+            return true;
+        }
+        // Not equal, but try to match if they read the same constant buffer
+        if (!lhs.IsImmediate() && !rhs.IsImmediate() &&
+            lhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
+            rhs.Inst()->GetOpcode() == IR::Opcode::GetCbufU32 &&
+            lhs.Inst()->Arg(0) == rhs.Inst()->Arg(0) && lhs.Inst()->Arg(1) == rhs.Inst()->Arg(1)) {
+            return true;
+        }
+        return false;
+    }};
+    return std::ranges::adjacent_find(range, std::not_fn(equal), resolver) == std::end(range);
+}
+
 void FoldGetRegister(IR::Inst& inst) {
     if (inst.Arg(0).Reg() == IR::Reg::RZ) {
         inst.ReplaceUsesWith(IR::Value{u32{0}});
@@ -100,6 +121,157 @@ void FoldGetPred(IR::Inst& inst) {
     }
 }
 
+/// Replaces the XMAD pattern generated by an integer FMA
+bool FoldXmadMultiplyAdd(IR::Block& block, IR::Inst& inst) {
+    /*
+     * We are looking for this specific pattern:
+     * %6 = BitFieldUExtract %op_b, #0, #16
+     * %7 = BitFieldUExtract %op_a', #16, #16
+     * %8 = IMul32 %6, %7
+     * %10 = BitFieldUExtract %op_a', #0, #16
+     * %11 = BitFieldInsert %8, %10, #16, #16
+     * %15 = BitFieldUExtract %op_b, #0, #16
+     * %16 = BitFieldUExtract %op_a, #0, #16
+     * %17 = IMul32 %15, %16
+     * %18 = IAdd32 %17, %op_c
+     * %22 = BitFieldUExtract %op_b, #16, #16
+     * %23 = BitFieldUExtract %11, #16, #16
+     * %24 = IMul32 %22, %23
+     * %25 = ShiftLeftLogical32 %24, #16
+     * %26 = ShiftLeftLogical32 %11, #16
+     * %27 = IAdd32 %26, %18
+     * %result = IAdd32 %25, %27
+     *
+     * And replace it with:
+     * %temp = IMul32 %op_a, %op_b
+     * %result = IAdd32 %temp, %op_c
+     *
+     * This optimization has been proven safe by reversing Nvidia's compiler logic:
+     * if Nvidia generates this pattern from 'fma(a, b, c)', we can fold it back to that form.
+     */
+    const IR::Value zero{0u};
+    const IR::Value sixteen{16u};
+    IR::Inst* const _25{inst.Arg(0).TryInstRecursive()};
+    IR::Inst* const _27{inst.Arg(1).TryInstRecursive()};
+    if (!_25 || !_27) {
+        return false;
+    }
+    if (_27->GetOpcode() != IR::Opcode::IAdd32) {
+        return false;
+    }
+    if (_25->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _25->Arg(1) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _24{_25->Arg(0).TryInstRecursive()};
+    if (!_24 || _24->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    IR::Inst* const _22{_24->Arg(0).TryInstRecursive()};
+    IR::Inst* const _23{_24->Arg(1).TryInstRecursive()};
+    if (!_22 || !_23) {
+        return false;
+    }
+    if (_22->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_23->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_22->Arg(1) != sixteen || _22->Arg(2) != sixteen) {
+        return false;
+    }
+    if (_23->Arg(1) != sixteen || _23->Arg(2) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _11{_23->Arg(0).TryInstRecursive()};
+    if (!_11 || _11->GetOpcode() != IR::Opcode::BitFieldInsert) {
+        return false;
+    }
+    if (_11->Arg(2) != sixteen || _11->Arg(3) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _8{_11->Arg(0).TryInstRecursive()};
+    IR::Inst* const _10{_11->Arg(1).TryInstRecursive()};
+    if (!_8 || !_10) {
+        return false;
+    }
+    if (_8->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    if (_10->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    IR::Inst* const _6{_8->Arg(0).TryInstRecursive()};
+    IR::Inst* const _7{_8->Arg(1).TryInstRecursive()};
+    if (!_6 || !_7) {
+        return false;
+    }
+    if (_6->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_7->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_6->Arg(1) != zero || _6->Arg(2) != sixteen) {
+        return false;
+    }
+    if (_7->Arg(1) != sixteen || _7->Arg(2) != sixteen) {
+        return false;
+    }
+    IR::Inst* const _26{_27->Arg(0).TryInstRecursive()};
+    IR::Inst* const _18{_27->Arg(1).TryInstRecursive()};
+    if (!_26 || !_18) {
+        return false;
+    }
+    if (_26->GetOpcode() != IR::Opcode::ShiftLeftLogical32 || _26->Arg(1) != sixteen) {
+        return false;
+    }
+    if (_26->Arg(0).InstRecursive() != _11) {
+        return false;
+    }
+    if (_18->GetOpcode() != IR::Opcode::IAdd32) {
+        return false;
+    }
+    IR::Inst* const _17{_18->Arg(0).TryInstRecursive()};
+    if (!_17 || _17->GetOpcode() != IR::Opcode::IMul32) {
+        return false;
+    }
+    IR::Inst* const _15{_17->Arg(0).TryInstRecursive()};
+    IR::Inst* const _16{_17->Arg(1).TryInstRecursive()};
+    if (!_15 || !_16) {
+        return false;
+    }
+    if (_15->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_16->GetOpcode() != IR::Opcode::BitFieldUExtract) {
+        return false;
+    }
+    if (_15->Arg(1) != zero || _16->Arg(1) != zero || _10->Arg(1) != zero) {
+        return false;
+    }
+    if (_15->Arg(2) != sixteen || _16->Arg(2) != sixteen || _10->Arg(2) != sixteen) {
+        return false;
+    }
+    const std::array<IR::Value, 3> op_as{
+        _7->Arg(0).Resolve(),
+        _16->Arg(0).Resolve(),
+        _10->Arg(0).Resolve(),
+    };
+    const std::array<IR::Value, 3> op_bs{
+        _22->Arg(0).Resolve(),
+        _6->Arg(0).Resolve(),
+        _15->Arg(0).Resolve(),
+    };
+    const IR::U32 op_c{_18->Arg(1)};
+    if (!AreEqual(op_as) || !AreEqual(op_bs)) {
+        return false;
+    }
+    IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
+    inst.ReplaceUsesWith(ir.IAdd(ir.IMul(IR::U32{op_as[0]}, IR::U32{op_bs[1]}), op_c));
+    return true;
+}
+
 /// Replaces the pattern generated by two XMAD multiplications
 bool FoldXmadMultiply(IR::Block& block, IR::Inst& inst) {
     /*
@@ -179,6 +351,9 @@ void FoldAdd(IR::Block& block, IR::Inst& inst) {
         if (FoldXmadMultiply(block, inst)) {
             return;
         }
+        if (FoldXmadMultiplyAdd(block, inst)) {
+            return;
+        }
     }
 }
 
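For reference, the all-equal check in the new AreEqual() helper uses std::ranges::adjacent_find with a projection and a negated comparator: if no adjacent pair of projected values fails the comparison, every element in the range is equal. Below is a minimal standalone sketch of the same idiom on a plain value type; Wrapped, AllEqual, and its Resolve() are illustrative stand-ins for the recompiler's IR::Value, not part of the patch.

// Sketch of the all-equal idiom used by AreEqual() above: adjacent_find with a
// negated predicate locates the first mismatching neighbour pair, so reaching the
// end of the range means every element compared equal. The projection mirrors
// IR::Value::Resolve().
#include <algorithm>
#include <array>
#include <cassert>
#include <functional>

struct Wrapped {
    int value;
    int Resolve() const { return value; } // stand-in for IR::Value::Resolve()
};

bool AllEqual(const std::array<Wrapped, 3>& range) {
    auto resolver = [](const Wrapped& v) { return v.Resolve(); };
    auto equal = [](int lhs, int rhs) { return lhs == rhs; };
    return std::ranges::adjacent_find(range, std::not_fn(equal), resolver) == std::end(range);
}

int main() {
    assert(AllEqual({Wrapped{7}, Wrapped{7}, Wrapped{7}}));
    assert(!AllEqual({Wrapped{7}, Wrapped{8}, Wrapped{7}}));
}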
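The safety of FoldXmadMultiplyAdd ultimately rests on an arithmetic identity: splitting the operands into 16-bit halves, forming the partial products, and recombining them with the insert, shifts, and adds shown in the pattern comment reproduces a * b + c modulo 2^32 (op_a and op_a' are the same value once the AreEqual checks pass). The standalone sketch below transcribes the matched pattern into plain uint32_t arithmetic and spot-checks it against the folded form; Extract, Insert, and XmadPattern are illustrative helpers, not recompiler APIs.

// Check of the identity behind FoldXmadMultiplyAdd: the expanded XMAD sequence of
// 16-bit extracts, multiplies, an insert, shifts, and adds equals a * b + c mod 2^32,
// so replacing it with IMul32 + IAdd32 preserves the result.
#include <cassert>
#include <cstdint>

static uint32_t Extract(uint32_t value, uint32_t offset, uint32_t count) {
    return (value >> offset) & ((1u << count) - 1u); // models BitFieldUExtract
}

static uint32_t Insert(uint32_t base, uint32_t field, uint32_t offset, uint32_t count) {
    const uint32_t mask = ((1u << count) - 1u) << offset; // models BitFieldInsert
    return (base & ~mask) | ((field << offset) & mask);
}

// The expanded pattern, transcribed from the comment block in the patch.
static uint32_t XmadPattern(uint32_t op_a, uint32_t op_b, uint32_t op_c) {
    const uint32_t v6 = Extract(op_b, 0, 16);
    const uint32_t v7 = Extract(op_a, 16, 16);
    const uint32_t v8 = v6 * v7;
    const uint32_t v10 = Extract(op_a, 0, 16);
    const uint32_t v11 = Insert(v8, v10, 16, 16);
    const uint32_t v15 = Extract(op_b, 0, 16);
    const uint32_t v16 = Extract(op_a, 0, 16);
    const uint32_t v17 = v15 * v16;
    const uint32_t v18 = v17 + op_c;
    const uint32_t v22 = Extract(op_b, 16, 16);
    const uint32_t v23 = Extract(v11, 16, 16);
    const uint32_t v24 = v22 * v23;
    const uint32_t v25 = v24 << 16;
    const uint32_t v26 = v11 << 16;
    const uint32_t v27 = v26 + v18;
    return v25 + v27;
}

int main() {
    // Spot-check the identity on a few values; uint32_t arithmetic wraps modulo 2^32.
    const uint32_t samples[] = {0u, 1u, 0xFFFFu, 0x12345u, 0x89ABCDEFu, 0xFFFFFFFFu};
    for (const uint32_t a : samples) {
        for (const uint32_t b : samples) {
            for (const uint32_t c : samples) {
                assert(XmadPattern(a, b, c) == a * b + c); // the folded form
            }
        }
    }
}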