summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/ir_opt
diff options
context:
space:
mode:
authorGravatar ReinUsesLisp2021-02-11 16:39:06 -0300
committerGravatar ameerj2021-07-22 21:51:22 -0400
commit9170200a11715d131645d1ffb92e86e6ef0d7e88 (patch)
tree6c6f84c38a9b59d023ecb09c0737ea56da166b64 /src/shader_recompiler/ir_opt
parentspirv: Initial SPIR-V support (diff)
downloadyuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar.gz
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.tar.xz
yuzu-9170200a11715d131645d1ffb92e86e6ef0d7e88.zip
shader: Initial implementation of an AST
Diffstat (limited to 'src/shader_recompiler/ir_opt')
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp50
-rw-r--r--src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp24
-rw-r--r--src/shader_recompiler/ir_opt/verification_pass.cpp4
3 files changed, 75 insertions, 3 deletions
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index f1170c61e..9fba6ac23 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -132,6 +132,32 @@ void FoldLogicalAnd(IR::Inst& inst) {
132 } 132 }
133} 133}
134 134
135void FoldLogicalOr(IR::Inst& inst) {
136 if (!FoldCommutative(inst, [](bool a, bool b) { return a || b; })) {
137 return;
138 }
139 const IR::Value rhs{inst.Arg(1)};
140 if (rhs.IsImmediate()) {
141 if (rhs.U1()) {
142 inst.ReplaceUsesWith(IR::Value{true});
143 } else {
144 inst.ReplaceUsesWith(inst.Arg(0));
145 }
146 }
147}
148
149void FoldLogicalNot(IR::Inst& inst) {
150 const IR::U1 value{inst.Arg(0)};
151 if (value.IsImmediate()) {
152 inst.ReplaceUsesWith(IR::Value{!value.U1()});
153 return;
154 }
155 IR::Inst* const arg{value.InstRecursive()};
156 if (arg->Opcode() == IR::Opcode::LogicalNot) {
157 inst.ReplaceUsesWith(arg->Arg(0));
158 }
159}
160
135template <typename Dest, typename Source> 161template <typename Dest, typename Source>
136void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) { 162void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
137 const IR::Value value{inst.Arg(0)}; 163 const IR::Value value{inst.Arg(0)};
@@ -160,6 +186,24 @@ void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
160 inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{})); 186 inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
161} 187}
162 188
189void FoldBranchConditional(IR::Inst& inst) {
190 const IR::U1 cond{inst.Arg(0)};
191 if (cond.IsImmediate()) {
192 // TODO: Convert to Branch
193 return;
194 }
195 const IR::Inst* cond_inst{cond.InstRecursive()};
196 if (cond_inst->Opcode() == IR::Opcode::LogicalNot) {
197 const IR::Value true_label{inst.Arg(1)};
198 const IR::Value false_label{inst.Arg(2)};
199 // Remove negation on the conditional (take the parameter out of LogicalNot) and swap
200 // the branches
201 inst.SetArg(0, cond_inst->Arg(0));
202 inst.SetArg(1, false_label);
203 inst.SetArg(2, true_label);
204 }
205}
206
163void ConstantPropagation(IR::Inst& inst) { 207void ConstantPropagation(IR::Inst& inst) {
164 switch (inst.Opcode()) { 208 switch (inst.Opcode()) {
165 case IR::Opcode::GetRegister: 209 case IR::Opcode::GetRegister:
@@ -178,6 +222,10 @@ void ConstantPropagation(IR::Inst& inst) {
178 return FoldSelect<u32>(inst); 222 return FoldSelect<u32>(inst);
179 case IR::Opcode::LogicalAnd: 223 case IR::Opcode::LogicalAnd:
180 return FoldLogicalAnd(inst); 224 return FoldLogicalAnd(inst);
225 case IR::Opcode::LogicalOr:
226 return FoldLogicalOr(inst);
227 case IR::Opcode::LogicalNot:
228 return FoldLogicalNot(inst);
181 case IR::Opcode::ULessThan: 229 case IR::Opcode::ULessThan:
182 return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); 230 return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
183 case IR::Opcode::BitFieldUExtract: 231 case IR::Opcode::BitFieldUExtract:
@@ -188,6 +236,8 @@ void ConstantPropagation(IR::Inst& inst) {
188 } 236 }
189 return (base >> shift) & ((1U << count) - 1); 237 return (base >> shift) & ((1U << count) - 1);
190 }); 238 });
239 case IR::Opcode::BranchConditional:
240 return FoldBranchConditional(inst);
191 default: 241 default:
192 break; 242 break;
193 } 243 }
diff --git a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
index 15a9db90a..8ca996e93 100644
--- a/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
+++ b/src/shader_recompiler/ir_opt/ssa_rewrite_pass.cpp
@@ -34,6 +34,13 @@ struct SignFlagTag : FlagTag {};
34struct CarryFlagTag : FlagTag {}; 34struct CarryFlagTag : FlagTag {};
35struct OverflowFlagTag : FlagTag {}; 35struct OverflowFlagTag : FlagTag {};
36 36
37struct GotoVariable : FlagTag {
38 GotoVariable() = default;
39 explicit GotoVariable(u32 index_) : index{index_} {}
40
41 u32 index;
42};
43
37struct DefTable { 44struct DefTable {
38 [[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept { 45 [[nodiscard]] ValueMap& operator[](IR::Reg variable) noexcept {
39 return regs[IR::RegIndex(variable)]; 46 return regs[IR::RegIndex(variable)];
@@ -43,6 +50,10 @@ struct DefTable {
43 return preds[IR::PredIndex(variable)]; 50 return preds[IR::PredIndex(variable)];
44 } 51 }
45 52
53 [[nodiscard]] ValueMap& operator[](GotoVariable goto_variable) {
54 return goto_vars[goto_variable.index];
55 }
56
46 [[nodiscard]] ValueMap& operator[](ZeroFlagTag) noexcept { 57 [[nodiscard]] ValueMap& operator[](ZeroFlagTag) noexcept {
47 return zero_flag; 58 return zero_flag;
48 } 59 }
@@ -61,6 +72,7 @@ struct DefTable {
61 72
62 std::array<ValueMap, IR::NUM_USER_REGS> regs; 73 std::array<ValueMap, IR::NUM_USER_REGS> regs;
63 std::array<ValueMap, IR::NUM_USER_PREDS> preds; 74 std::array<ValueMap, IR::NUM_USER_PREDS> preds;
75 boost::container::flat_map<u32, ValueMap> goto_vars;
64 ValueMap zero_flag; 76 ValueMap zero_flag;
65 ValueMap sign_flag; 77 ValueMap sign_flag;
66 ValueMap carry_flag; 78 ValueMap carry_flag;
@@ -68,15 +80,15 @@ struct DefTable {
68}; 80};
69 81
70IR::Opcode UndefOpcode(IR::Reg) noexcept { 82IR::Opcode UndefOpcode(IR::Reg) noexcept {
71 return IR::Opcode::Undef32; 83 return IR::Opcode::UndefU32;
72} 84}
73 85
74IR::Opcode UndefOpcode(IR::Pred) noexcept { 86IR::Opcode UndefOpcode(IR::Pred) noexcept {
75 return IR::Opcode::Undef1; 87 return IR::Opcode::UndefU1;
76} 88}
77 89
78IR::Opcode UndefOpcode(const FlagTag&) noexcept { 90IR::Opcode UndefOpcode(const FlagTag&) noexcept {
79 return IR::Opcode::Undef1; 91 return IR::Opcode::UndefU1;
80} 92}
81 93
82[[nodiscard]] bool IsPhi(const IR::Inst& inst) noexcept { 94[[nodiscard]] bool IsPhi(const IR::Inst& inst) noexcept {
@@ -165,6 +177,9 @@ void SsaRewritePass(IR::Function& function) {
165 pass.WriteVariable(pred, block, inst.Arg(1)); 177 pass.WriteVariable(pred, block, inst.Arg(1));
166 } 178 }
167 break; 179 break;
180 case IR::Opcode::SetGotoVariable:
181 pass.WriteVariable(GotoVariable{inst.Arg(0).U32()}, block, inst.Arg(1));
182 break;
168 case IR::Opcode::SetZFlag: 183 case IR::Opcode::SetZFlag:
169 pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0)); 184 pass.WriteVariable(ZeroFlagTag{}, block, inst.Arg(0));
170 break; 185 break;
@@ -187,6 +202,9 @@ void SsaRewritePass(IR::Function& function) {
187 inst.ReplaceUsesWith(pass.ReadVariable(pred, block)); 202 inst.ReplaceUsesWith(pass.ReadVariable(pred, block));
188 } 203 }
189 break; 204 break;
205 case IR::Opcode::GetGotoVariable:
206 inst.ReplaceUsesWith(pass.ReadVariable(GotoVariable{inst.Arg(0).U32()}, block));
207 break;
190 case IR::Opcode::GetZFlag: 208 case IR::Opcode::GetZFlag:
191 inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block)); 209 inst.ReplaceUsesWith(pass.ReadVariable(ZeroFlagTag{}, block));
192 break; 210 break;
diff --git a/src/shader_recompiler/ir_opt/verification_pass.cpp b/src/shader_recompiler/ir_opt/verification_pass.cpp
index 8a5adf5a2..32b56eb57 100644
--- a/src/shader_recompiler/ir_opt/verification_pass.cpp
+++ b/src/shader_recompiler/ir_opt/verification_pass.cpp
@@ -14,6 +14,10 @@ namespace Shader::Optimization {
14static void ValidateTypes(const IR::Function& function) { 14static void ValidateTypes(const IR::Function& function) {
15 for (const auto& block : function.blocks) { 15 for (const auto& block : function.blocks) {
16 for (const IR::Inst& inst : *block) { 16 for (const IR::Inst& inst : *block) {
17 if (inst.Opcode() == IR::Opcode::Phi) {
18 // Skip validation on phi nodes
19 continue;
20 }
17 const size_t num_args{inst.NumArgs()}; 21 const size_t num_args{inst.NumArgs()};
18 for (size_t i = 0; i < num_args; ++i) { 22 for (size_t i = 0; i < num_args; ++i) {
19 const IR::Type t1{inst.Arg(i).Type()}; 23 const IR::Type t1{inst.Arg(i).Type()};