summaryrefslogtreecommitdiff
path: root/src/shader_recompiler
diff options
context:
space:
mode:
Diffstat (limited to 'src/shader_recompiler')
-rw-r--r--src/shader_recompiler/frontend/ir/basic_block.cpp2
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp59
2 files changed, 48 insertions, 13 deletions
diff --git a/src/shader_recompiler/frontend/ir/basic_block.cpp b/src/shader_recompiler/frontend/ir/basic_block.cpp
index 50c6a83cd..da33ff6f1 100644
--- a/src/shader_recompiler/frontend/ir/basic_block.cpp
+++ b/src/shader_recompiler/frontend/ir/basic_block.cpp
@@ -87,7 +87,7 @@ static std::string ArgToIndex(const std::map<const Block*, size_t>& block_to_ind
87 } 87 }
88 switch (arg.Type()) { 88 switch (arg.Type()) {
89 case Type::U1: 89 case Type::U1:
90 return fmt::format("#{}", arg.U1() ? '1' : '0'); 90 return fmt::format("#{}", arg.U1() ? "true" : "false");
91 case Type::U8: 91 case Type::U8:
92 return fmt::format("#{}", arg.U8()); 92 return fmt::format("#{}", arg.U8());
93 case Type::U16: 93 case Type::U16:
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 7fb3192d8..f1170c61e 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -3,6 +3,7 @@
3// Refer to the license.txt file included. 3// Refer to the license.txt file included.
4 4
5#include <algorithm> 5#include <algorithm>
6#include <tuple>
6#include <type_traits> 7#include <type_traits>
7 8
8#include "common/bit_cast.h" 9#include "common/bit_cast.h"
@@ -13,12 +14,17 @@
13 14
14namespace Shader::Optimization { 15namespace Shader::Optimization {
15namespace { 16namespace {
16[[nodiscard]] u32 BitFieldUExtract(u32 base, u32 shift, u32 count) { 17// Metaprogramming stuff to get arguments information out of a lambda
17 if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) { 18template <typename Func>
18 throw LogicError("Undefined result in BitFieldUExtract({}, {}, {})", base, shift, count); 19struct LambdaTraits : LambdaTraits<decltype(&std::remove_reference_t<Func>::operator())> {};
19 } 20
20 return (base >> shift) & ((1U << count) - 1); 21template <typename ReturnType, typename LambdaType, typename... Args>
21} 22struct LambdaTraits<ReturnType (LambdaType::*)(Args...) const> {
23 template <size_t I>
24 using ArgType = std::tuple_element_t<I, std::tuple<Args...>>;
25
26 static constexpr size_t NUM_ARGS{sizeof...(Args)};
27};
22 28
23template <typename T> 29template <typename T>
24[[nodiscard]] T Arg(const IR::Value& value) { 30[[nodiscard]] T Arg(const IR::Value& value) {
@@ -104,6 +110,14 @@ void FoldAdd(IR::Inst& inst) {
104 } 110 }
105} 111}
106 112
113template <typename T>
114void FoldSelect(IR::Inst& inst) {
115 const IR::Value cond{inst.Arg(0)};
116 if (cond.IsImmediate()) {
117 inst.ReplaceUsesWith(cond.U1() ? inst.Arg(1) : inst.Arg(2));
118 }
119}
120
107void FoldLogicalAnd(IR::Inst& inst) { 121void FoldLogicalAnd(IR::Inst& inst) {
108 if (!FoldCommutative(inst, [](bool a, bool b) { return a && b; })) { 122 if (!FoldCommutative(inst, [](bool a, bool b) { return a && b; })) {
109 return; 123 return;
@@ -131,6 +145,21 @@ void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
131 } 145 }
132} 146}
133 147
148template <typename Func, size_t... I>
149IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) {
150 using Traits = LambdaTraits<decltype(func)>;
151 return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)};
152}
153
154template <typename Func>
155void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
156 if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
157 return;
158 }
159 using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>;
160 inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
161}
162
134void ConstantPropagation(IR::Inst& inst) { 163void ConstantPropagation(IR::Inst& inst) {
135 switch (inst.Opcode()) { 164 switch (inst.Opcode()) {
136 case IR::Opcode::GetRegister: 165 case IR::Opcode::GetRegister:
@@ -145,14 +174,20 @@ void ConstantPropagation(IR::Inst& inst) {
145 return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32); 174 return FoldBitCast<u32, f32>(inst, IR::Opcode::BitCastF32U32);
146 case IR::Opcode::IAdd64: 175 case IR::Opcode::IAdd64:
147 return FoldAdd<u64>(inst); 176 return FoldAdd<u64>(inst);
148 case IR::Opcode::BitFieldUExtract: 177 case IR::Opcode::Select32:
149 if (inst.AreAllArgsImmediates() && !inst.HasAssociatedPseudoOperation()) { 178 return FoldSelect<u32>(inst);
150 inst.ReplaceUsesWith(IR::Value{
151 BitFieldUExtract(inst.Arg(0).U32(), inst.Arg(1).U32(), inst.Arg(2).U32())});
152 }
153 break;
154 case IR::Opcode::LogicalAnd: 179 case IR::Opcode::LogicalAnd:
155 return FoldLogicalAnd(inst); 180 return FoldLogicalAnd(inst);
181 case IR::Opcode::ULessThan:
182 return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
183 case IR::Opcode::BitFieldUExtract:
184 return FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
185 if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {
186 throw LogicError("Undefined result in {}({}, {}, {})", IR::Opcode::BitFieldUExtract,
187 base, shift, count);
188 }
189 return (base >> shift) & ((1U << count) - 1);
190 });
156 default: 191 default:
157 break; 192 break;
158 } 193 }