summaryrefslogtreecommitdiff
path: root/src/video_core/renderer_vulkan
diff options
context:
space:
mode:
Diffstat (limited to 'src/video_core/renderer_vulkan')
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp359
1 files changed, 320 insertions, 39 deletions
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 77fc58f25..8bcd04221 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) {
88 88
89} // namespace 89} // namespace
90 90
91class ASTDecompiler;
92class ExprDecompiler;
93
91class SPIRVDecompiler : public Sirit::Module { 94class SPIRVDecompiler : public Sirit::Module {
92public: 95public:
93 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage) 96 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage)
@@ -97,27 +100,7 @@ public:
97 AddExtension("SPV_KHR_variable_pointers"); 100 AddExtension("SPV_KHR_variable_pointers");
98 } 101 }
99 102
100 void Decompile() { 103 void DecompileBranchMode() {
101 AllocateBindings();
102 AllocateLabels();
103
104 DeclareVertex();
105 DeclareGeometry();
106 DeclareFragment();
107 DeclareRegisters();
108 DeclarePredicates();
109 DeclareLocalMemory();
110 DeclareInternalFlags();
111 DeclareInputAttributes();
112 DeclareOutputAttributes();
113 DeclareConstantBuffers();
114 DeclareGlobalBuffers();
115 DeclareSamplers();
116
117 execute_function =
118 Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
119 Emit(OpLabel());
120
121 const u32 first_address = ir.GetBasicBlocks().begin()->first; 104 const u32 first_address = ir.GetBasicBlocks().begin()->first;
122 const Id loop_label = OpLabel("loop"); 105 const Id loop_label = OpLabel("loop");
123 const Id merge_label = OpLabel("merge"); 106 const Id merge_label = OpLabel("merge");
@@ -174,6 +157,43 @@ public:
174 Emit(continue_label); 157 Emit(continue_label);
175 Emit(OpBranch(loop_label)); 158 Emit(OpBranch(loop_label));
176 Emit(merge_label); 159 Emit(merge_label);
160 }
161
162 void DecompileAST();
163
164 void Decompile() {
165 const bool is_fully_decompiled = ir.IsDecompiled();
166 AllocateBindings();
167 if (!is_fully_decompiled) {
168 AllocateLabels();
169 }
170
171 DeclareVertex();
172 DeclareGeometry();
173 DeclareFragment();
174 DeclareRegisters();
175 DeclarePredicates();
176 if (is_fully_decompiled) {
177 DeclareFlowVariables();
178 }
179 DeclareLocalMemory();
180 DeclareInternalFlags();
181 DeclareInputAttributes();
182 DeclareOutputAttributes();
183 DeclareConstantBuffers();
184 DeclareGlobalBuffers();
185 DeclareSamplers();
186
187 execute_function =
188 Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
189 Emit(OpLabel());
190
191 if (is_fully_decompiled) {
192 DecompileAST();
193 } else {
194 DecompileBranchMode();
195 }
196
177 Emit(OpReturn()); 197 Emit(OpReturn());
178 Emit(OpFunctionEnd()); 198 Emit(OpFunctionEnd());
179 } 199 }
@@ -206,6 +226,9 @@ public:
206 } 226 }
207 227
208private: 228private:
229 friend class ASTDecompiler;
230 friend class ExprDecompiler;
231
209 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount); 232 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
210 233
211 void AllocateBindings() { 234 void AllocateBindings() {
@@ -294,6 +317,14 @@ private:
294 } 317 }
295 } 318 }
296 319
320 void DeclareFlowVariables() {
321 for (u32 i = 0; i < ir.GetASTNumVariables(); i++) {
322 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
323 Name(id, fmt::format("flow_var_{}", static_cast<u32>(i)));
324 flow_variables.emplace(i, AddGlobalVariable(id));
325 }
326 }
327
297 void DeclareLocalMemory() { 328 void DeclareLocalMemory() {
298 if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { 329 if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
299 const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4); 330 const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
@@ -615,9 +646,15 @@ private:
615 Emit(OpBranchConditional(condition, true_label, skip_label)); 646 Emit(OpBranchConditional(condition, true_label, skip_label));
616 Emit(true_label); 647 Emit(true_label);
617 648
649 ++conditional_nest_count;
618 VisitBasicBlock(conditional->GetCode()); 650 VisitBasicBlock(conditional->GetCode());
651 --conditional_nest_count;
619 652
620 Emit(OpBranch(skip_label)); 653 if (inside_branch == 0) {
654 Emit(OpBranch(skip_label));
655 } else {
656 inside_branch--;
657 }
621 Emit(skip_label); 658 Emit(skip_label);
622 return {}; 659 return {};
623 660
@@ -980,7 +1017,11 @@ private:
980 UNIMPLEMENTED_IF(!target); 1017 UNIMPLEMENTED_IF(!target);
981 1018
982 Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue()))); 1019 Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue())));
983 BranchingOp([&]() { Emit(OpBranch(continue_label)); }); 1020 Emit(OpBranch(continue_label));
1021 inside_branch = conditional_nest_count;
1022 if (conditional_nest_count == 0) {
1023 Emit(OpLabel());
1024 }
984 return {}; 1025 return {};
985 } 1026 }
986 1027
@@ -988,7 +1029,11 @@ private:
988 const Id op_a = VisitOperand<Type::Uint>(operation, 0); 1029 const Id op_a = VisitOperand<Type::Uint>(operation, 0);
989 1030
990 Emit(OpStore(jmp_to, op_a)); 1031 Emit(OpStore(jmp_to, op_a));
991 BranchingOp([&]() { Emit(OpBranch(continue_label)); }); 1032 Emit(OpBranch(continue_label));
1033 inside_branch = conditional_nest_count;
1034 if (conditional_nest_count == 0) {
1035 Emit(OpLabel());
1036 }
992 return {}; 1037 return {};
993 } 1038 }
994 1039
@@ -1015,11 +1060,15 @@ private:
1015 1060
1016 Emit(OpStore(flow_stack_top, previous)); 1061 Emit(OpStore(flow_stack_top, previous));
1017 Emit(OpStore(jmp_to, target)); 1062 Emit(OpStore(jmp_to, target));
1018 BranchingOp([&]() { Emit(OpBranch(continue_label)); }); 1063 Emit(OpBranch(continue_label));
1064 inside_branch = conditional_nest_count;
1065 if (conditional_nest_count == 0) {
1066 Emit(OpLabel());
1067 }
1019 return {}; 1068 return {};
1020 } 1069 }
1021 1070
1022 Id Exit(Operation operation) { 1071 Id PreExit() {
1023 switch (stage) { 1072 switch (stage) {
1024 case ShaderStage::Vertex: { 1073 case ShaderStage::Vertex: {
1025 // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't 1074 // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
@@ -1067,12 +1116,35 @@ private:
1067 } 1116 }
1068 } 1117 }
1069 1118
1070 BranchingOp([&]() { Emit(OpReturn()); }); 1119 return {};
1120 }
1121
1122 Id Exit(Operation operation) {
1123 PreExit();
1124 inside_branch = conditional_nest_count;
1125 if (conditional_nest_count > 0) {
1126 Emit(OpReturn());
1127 } else {
1128 const Id dummy = OpLabel();
1129 Emit(OpBranch(dummy));
1130 Emit(dummy);
1131 Emit(OpReturn());
1132 Emit(OpLabel());
1133 }
1071 return {}; 1134 return {};
1072 } 1135 }
1073 1136
1074 Id Discard(Operation operation) { 1137 Id Discard(Operation operation) {
1075 BranchingOp([&]() { Emit(OpKill()); }); 1138 inside_branch = conditional_nest_count;
1139 if (conditional_nest_count > 0) {
1140 Emit(OpKill());
1141 } else {
1142 const Id dummy = OpLabel();
1143 Emit(OpBranch(dummy));
1144 Emit(dummy);
1145 Emit(OpKill());
1146 Emit(OpLabel());
1147 }
1076 return {}; 1148 return {};
1077 } 1149 }
1078 1150
@@ -1267,17 +1339,6 @@ private:
1267 return {}; 1339 return {};
1268 } 1340 }
1269 1341
1270 void BranchingOp(std::function<void()> call) {
1271 const Id true_label = OpLabel();
1272 const Id skip_label = OpLabel();
1273 Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::Flatten));
1274 Emit(OpBranchConditional(v_true, true_label, skip_label, 1, 0));
1275 Emit(true_label);
1276 call();
1277
1278 Emit(skip_label);
1279 }
1280
1281 std::tuple<Id, Id> CreateFlowStack() { 1342 std::tuple<Id, Id> CreateFlowStack() {
1282 // TODO(Rodrigo): Figure out the actual depth of the flow stack, for now it seems unlikely 1343 // TODO(Rodrigo): Figure out the actual depth of the flow stack, for now it seems unlikely
1283 // that shaders will use 20 nested SSYs and PBKs. 1344 // that shaders will use 20 nested SSYs and PBKs.
@@ -1483,6 +1544,8 @@ private:
1483 const ShaderIR& ir; 1544 const ShaderIR& ir;
1484 const ShaderStage stage; 1545 const ShaderStage stage;
1485 const Tegra::Shader::Header header; 1546 const Tegra::Shader::Header header;
1547 u64 conditional_nest_count{};
1548 u64 inside_branch{};
1486 1549
1487 const Id t_void = Name(TypeVoid(), "void"); 1550 const Id t_void = Name(TypeVoid(), "void");
1488 1551
@@ -1545,6 +1608,7 @@ private:
1545 Id per_vertex{}; 1608 Id per_vertex{};
1546 std::map<u32, Id> registers; 1609 std::map<u32, Id> registers;
1547 std::map<Tegra::Shader::Pred, Id> predicates; 1610 std::map<Tegra::Shader::Pred, Id> predicates;
1611 std::map<u32, Id> flow_variables;
1548 Id local_memory{}; 1612 Id local_memory{};
1549 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{}; 1613 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
1550 std::map<Attribute::Index, Id> input_attributes; 1614 std::map<Attribute::Index, Id> input_attributes;
@@ -1580,6 +1644,223 @@ private:
1580 std::map<u32, Id> labels; 1644 std::map<u32, Id> labels;
1581}; 1645};
1582 1646
1647class ExprDecompiler {
1648public:
1649 explicit ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
1650
1651 Id operator()(VideoCommon::Shader::ExprAnd& expr) {
1652 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1653 const Id op1 = Visit(expr.operand1);
1654 const Id op2 = Visit(expr.operand2);
1655 return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
1656 }
1657
1658 Id operator()(VideoCommon::Shader::ExprOr& expr) {
1659 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1660 const Id op1 = Visit(expr.operand1);
1661 const Id op2 = Visit(expr.operand2);
1662 return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
1663 }
1664
1665 Id operator()(VideoCommon::Shader::ExprNot& expr) {
1666 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1667 const Id op1 = Visit(expr.operand1);
1668 return decomp.Emit(decomp.OpLogicalNot(type_def, op1));
1669 }
1670
1671 Id operator()(VideoCommon::Shader::ExprPredicate& expr) {
1672 const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
1673 return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
1674 }
1675
1676 Id operator()(VideoCommon::Shader::ExprCondCode& expr) {
1677 const Node cc = decomp.ir.GetConditionCode(expr.cc);
1678 Id target;
1679
1680 if (const auto pred = std::get_if<PredicateNode>(&*cc)) {
1681 const auto index = pred->GetIndex();
1682 switch (index) {
1683 case Tegra::Shader::Pred::NeverExecute:
1684 target = decomp.v_false;
1685 case Tegra::Shader::Pred::UnusedIndex:
1686 target = decomp.v_true;
1687 default:
1688 target = decomp.predicates.at(index);
1689 }
1690 } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
1691 target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
1692 }
1693 return decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
1694 }
1695
1696 Id operator()(VideoCommon::Shader::ExprVar& expr) {
1697 return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
1698 }
1699
1700 Id operator()(VideoCommon::Shader::ExprBoolean& expr) {
1701 return expr.value ? decomp.v_true : decomp.v_false;
1702 }
1703
1704 Id Visit(VideoCommon::Shader::Expr& node) {
1705 return std::visit(*this, *node);
1706 }
1707
1708private:
1709 SPIRVDecompiler& decomp;
1710};
1711
1712class ASTDecompiler {
1713public:
1714 explicit ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
1715
1716 void operator()(VideoCommon::Shader::ASTProgram& ast) {
1717 ASTNode current = ast.nodes.GetFirst();
1718 while (current) {
1719 Visit(current);
1720 current = current->GetNext();
1721 }
1722 }
1723
1724 void operator()(VideoCommon::Shader::ASTIfThen& ast) {
1725 ExprDecompiler expr_parser{decomp};
1726 const Id condition = expr_parser.Visit(ast.condition);
1727 const Id then_label = decomp.OpLabel();
1728 const Id endif_label = decomp.OpLabel();
1729 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
1730 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
1731 decomp.Emit(then_label);
1732 ASTNode current = ast.nodes.GetFirst();
1733 while (current) {
1734 Visit(current);
1735 current = current->GetNext();
1736 }
1737 decomp.Emit(decomp.OpBranch(endif_label));
1738 decomp.Emit(endif_label);
1739 }
1740
1741 void operator()(VideoCommon::Shader::ASTIfElse& ast) {
1742 UNREACHABLE();
1743 }
1744
1745 void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) {
1746 UNREACHABLE();
1747 }
1748
1749 void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) {
1750 decomp.VisitBasicBlock(ast.nodes);
1751 }
1752
1753 void operator()(VideoCommon::Shader::ASTVarSet& ast) {
1754 ExprDecompiler expr_parser{decomp};
1755 const Id condition = expr_parser.Visit(ast.condition);
1756 decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
1757 }
1758
1759 void operator()(VideoCommon::Shader::ASTLabel& ast) {
1760 // Do nothing
1761 }
1762
1763 void operator()(VideoCommon::Shader::ASTGoto& ast) {
1764 UNREACHABLE();
1765 }
1766
1767 void operator()(VideoCommon::Shader::ASTDoWhile& ast) {
1768 const Id loop_label = decomp.OpLabel();
1769 const Id endloop_label = decomp.OpLabel();
1770 const Id loop_start_block = decomp.OpLabel();
1771 const Id loop_end_block = decomp.OpLabel();
1772 current_loop_exit = endloop_label;
1773 decomp.Emit(decomp.OpBranch(loop_label));
1774 decomp.Emit(loop_label);
1775 decomp.Emit(
1776 decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
1777 decomp.Emit(decomp.OpBranch(loop_start_block));
1778 decomp.Emit(loop_start_block);
1779 ASTNode current = ast.nodes.GetFirst();
1780 while (current) {
1781 Visit(current);
1782 current = current->GetNext();
1783 }
1784 ExprDecompiler expr_parser{decomp};
1785 const Id condition = expr_parser.Visit(ast.condition);
1786 decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
1787 decomp.Emit(endloop_label);
1788 }
1789
1790 void operator()(VideoCommon::Shader::ASTReturn& ast) {
1791 if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) {
1792 ExprDecompiler expr_parser{decomp};
1793 const Id condition = expr_parser.Visit(ast.condition);
1794 const Id then_label = decomp.OpLabel();
1795 const Id endif_label = decomp.OpLabel();
1796 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
1797 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
1798 decomp.Emit(then_label);
1799 if (ast.kills) {
1800 decomp.Emit(decomp.OpKill());
1801 } else {
1802 decomp.PreExit();
1803 decomp.Emit(decomp.OpReturn());
1804 }
1805 decomp.Emit(endif_label);
1806 } else {
1807 const Id next_block = decomp.OpLabel();
1808 decomp.Emit(decomp.OpBranch(next_block));
1809 decomp.Emit(next_block);
1810 if (ast.kills) {
1811 decomp.Emit(decomp.OpKill());
1812 } else {
1813 decomp.PreExit();
1814 decomp.Emit(decomp.OpReturn());
1815 }
1816 decomp.Emit(decomp.OpLabel());
1817 }
1818 }
1819
1820 void operator()(VideoCommon::Shader::ASTBreak& ast) {
1821 if (!VideoCommon::Shader::ExprIsTrue(ast.condition)) {
1822 ExprDecompiler expr_parser{decomp};
1823 const Id condition = expr_parser.Visit(ast.condition);
1824 const Id then_label = decomp.OpLabel();
1825 const Id endif_label = decomp.OpLabel();
1826 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
1827 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
1828 decomp.Emit(then_label);
1829 decomp.Emit(decomp.OpBranch(current_loop_exit));
1830 decomp.Emit(endif_label);
1831 } else {
1832 const Id next_block = decomp.OpLabel();
1833 decomp.Emit(decomp.OpBranch(next_block));
1834 decomp.Emit(next_block);
1835 decomp.Emit(decomp.OpBranch(current_loop_exit));
1836 decomp.Emit(decomp.OpLabel());
1837 }
1838 }
1839
1840 void Visit(VideoCommon::Shader::ASTNode& node) {
1841 std::visit(*this, *node->GetInnerData());
1842 }
1843
1844private:
1845 SPIRVDecompiler& decomp;
1846 Id current_loop_exit{};
1847};
1848
1849void SPIRVDecompiler::DecompileAST() {
1850 const u32 num_flow_variables = ir.GetASTNumVariables();
1851 for (u32 i = 0; i < num_flow_variables; i++) {
1852 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
1853 Name(id, fmt::format("flow_var_{}", i));
1854 flow_variables.emplace(i, AddGlobalVariable(id));
1855 }
1856 ASTDecompiler decompiler{*this};
1857 VideoCommon::Shader::ASTNode program = ir.GetASTProgram();
1858 decompiler.Visit(program);
1859 const Id next_block = OpLabel();
1860 Emit(OpBranch(next_block));
1861 Emit(next_block);
1862}
1863
1583DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, 1864DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
1584 Maxwell::ShaderStage stage) { 1865 Maxwell::ShaderStage stage) {
1585 auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage); 1866 auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);