summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Fernando Sahmkow2019-08-25 15:32:00 -0400
committerGravatar FernandoS272019-10-04 18:52:51 -0400
commitca9901867e91cd0be0cc75094ee8ea2fb2767c47 (patch)
tree1d46ef166d33459447a006f3da3dd663fab4fcaa /src
parentShader_IR: mark labels as unused for partial decompile. (diff)
downloadyuzu-ca9901867e91cd0be0cc75094ee8ea2fb2767c47.tar.gz
yuzu-ca9901867e91cd0be0cc75094ee8ea2fb2767c47.tar.xz
yuzu-ca9901867e91cd0be0cc75094ee8ea2fb2767c47.zip
vk_shader_compiler: Implement the decompiler in SPIR-V
Diffstat (limited to 'src')
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp298
-rw-r--r--src/video_core/shader/ast.h22
-rw-r--r--src/video_core/shader/shader_ir.h4
3 files changed, 301 insertions, 23 deletions
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 77fc58f25..505e49570 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -88,6 +88,9 @@ bool IsPrecise(Operation operand) {
88 88
89} // namespace 89} // namespace
90 90
91class ASTDecompiler;
92class ExprDecompiler;
93
91class SPIRVDecompiler : public Sirit::Module { 94class SPIRVDecompiler : public Sirit::Module {
92public: 95public:
93 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage) 96 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderStage stage)
@@ -97,27 +100,7 @@ public:
97 AddExtension("SPV_KHR_variable_pointers"); 100 AddExtension("SPV_KHR_variable_pointers");
98 } 101 }
99 102
100 void Decompile() { 103 void DecompileBranchMode() {
101 AllocateBindings();
102 AllocateLabels();
103
104 DeclareVertex();
105 DeclareGeometry();
106 DeclareFragment();
107 DeclareRegisters();
108 DeclarePredicates();
109 DeclareLocalMemory();
110 DeclareInternalFlags();
111 DeclareInputAttributes();
112 DeclareOutputAttributes();
113 DeclareConstantBuffers();
114 DeclareGlobalBuffers();
115 DeclareSamplers();
116
117 execute_function =
118 Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
119 Emit(OpLabel());
120
121 const u32 first_address = ir.GetBasicBlocks().begin()->first; 104 const u32 first_address = ir.GetBasicBlocks().begin()->first;
122 const Id loop_label = OpLabel("loop"); 105 const Id loop_label = OpLabel("loop");
123 const Id merge_label = OpLabel("merge"); 106 const Id merge_label = OpLabel("merge");
@@ -174,6 +157,43 @@ public:
174 Emit(continue_label); 157 Emit(continue_label);
175 Emit(OpBranch(loop_label)); 158 Emit(OpBranch(loop_label));
176 Emit(merge_label); 159 Emit(merge_label);
160 }
161
162 void DecompileAST();
163
164 void Decompile() {
165 const bool is_fully_decompiled = ir.IsDecompiled();
166 AllocateBindings();
167 if (!is_fully_decompiled) {
168 AllocateLabels();
169 }
170
171 DeclareVertex();
172 DeclareGeometry();
173 DeclareFragment();
174 DeclareRegisters();
175 DeclarePredicates();
176 if (is_fully_decompiled) {
177 DeclareFlowVariables();
178 }
179 DeclareLocalMemory();
180 DeclareInternalFlags();
181 DeclareInputAttributes();
182 DeclareOutputAttributes();
183 DeclareConstantBuffers();
184 DeclareGlobalBuffers();
185 DeclareSamplers();
186
187 execute_function =
188 Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
189 Emit(OpLabel());
190
191 if (is_fully_decompiled) {
192 DecompileAST();
193 } else {
194 DecompileBranchMode();
195 }
196
177 Emit(OpReturn()); 197 Emit(OpReturn());
178 Emit(OpFunctionEnd()); 198 Emit(OpFunctionEnd());
179 } 199 }
@@ -206,6 +226,9 @@ public:
206 } 226 }
207 227
208private: 228private:
229 friend class ASTDecompiler;
230 friend class ExprDecompiler;
231
209 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount); 232 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
210 233
211 void AllocateBindings() { 234 void AllocateBindings() {
@@ -294,6 +317,14 @@ private:
294 } 317 }
295 } 318 }
296 319
320 void DeclareFlowVariables() {
321 for (u32 i = 0; i < ir.GetASTNumVariables(); i++) {
322 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
323 Name(id, fmt::format("flow_var_{}", static_cast<u32>(i)));
324 flow_variables.emplace(i, AddGlobalVariable(id));
325 }
326 }
327
297 void DeclareLocalMemory() { 328 void DeclareLocalMemory() {
298 if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { 329 if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
299 const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4); 330 const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
@@ -1019,7 +1050,7 @@ private:
1019 return {}; 1050 return {};
1020 } 1051 }
1021 1052
1022 Id Exit(Operation operation) { 1053 Id PreExit() {
1023 switch (stage) { 1054 switch (stage) {
1024 case ShaderStage::Vertex: { 1055 case ShaderStage::Vertex: {
1025 // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't 1056 // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
@@ -1067,6 +1098,11 @@ private:
1067 } 1098 }
1068 } 1099 }
1069 1100
1101 return {};
1102 }
1103
1104 Id Exit(Operation operation) {
1105 PreExit();
1070 BranchingOp([&]() { Emit(OpReturn()); }); 1106 BranchingOp([&]() { Emit(OpReturn()); });
1071 return {}; 1107 return {};
1072 } 1108 }
@@ -1545,6 +1581,7 @@ private:
1545 Id per_vertex{}; 1581 Id per_vertex{};
1546 std::map<u32, Id> registers; 1582 std::map<u32, Id> registers;
1547 std::map<Tegra::Shader::Pred, Id> predicates; 1583 std::map<Tegra::Shader::Pred, Id> predicates;
1584 std::map<u32, Id> flow_variables;
1548 Id local_memory{}; 1585 Id local_memory{};
1549 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{}; 1586 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
1550 std::map<Attribute::Index, Id> input_attributes; 1587 std::map<Attribute::Index, Id> input_attributes;
@@ -1580,6 +1617,223 @@ private:
1580 std::map<u32, Id> labels; 1617 std::map<u32, Id> labels;
1581}; 1618};
1582 1619
1620class ExprDecompiler {
1621public:
1622 ExprDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
1623
1624 void operator()(VideoCommon::Shader::ExprAnd& expr) {
1625 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1626 const Id op1 = Visit(expr.operand1);
1627 const Id op2 = Visit(expr.operand2);
1628 current_id = decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2));
1629 }
1630
1631 void operator()(VideoCommon::Shader::ExprOr& expr) {
1632 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1633 const Id op1 = Visit(expr.operand1);
1634 const Id op2 = Visit(expr.operand2);
1635 current_id = decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2));
1636 }
1637
1638 void operator()(VideoCommon::Shader::ExprNot& expr) {
1639 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1640 const Id op1 = Visit(expr.operand1);
1641 current_id = decomp.Emit(decomp.OpLogicalNot(type_def, op1));
1642 }
1643
1644 void operator()(VideoCommon::Shader::ExprPredicate& expr) {
1645 auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
1646 current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred)));
1647 }
1648
1649 void operator()(VideoCommon::Shader::ExprCondCode& expr) {
1650 Node cc = decomp.ir.GetConditionCode(expr.cc);
1651 Id target;
1652
1653 if (const auto pred = std::get_if<PredicateNode>(&*cc)) {
1654 const auto index = pred->GetIndex();
1655 switch (index) {
1656 case Tegra::Shader::Pred::NeverExecute:
1657 target = decomp.v_false;
1658 case Tegra::Shader::Pred::UnusedIndex:
1659 target = decomp.v_true;
1660 default:
1661 target = decomp.predicates.at(index);
1662 }
1663 } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
1664 target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
1665 }
1666 current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, target));
1667 }
1668
1669 void operator()(VideoCommon::Shader::ExprVar& expr) {
1670 current_id = decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index)));
1671 }
1672
1673 void operator()(VideoCommon::Shader::ExprBoolean& expr) {
1674 current_id = expr.value ? decomp.v_true : decomp.v_false;
1675 }
1676
1677 Id GetResult() {
1678 return current_id;
1679 }
1680
1681 Id Visit(VideoCommon::Shader::Expr& node) {
1682 std::visit(*this, *node);
1683 return current_id;
1684 }
1685
1686private:
1687 Id current_id;
1688 SPIRVDecompiler& decomp;
1689};
1690
1691class ASTDecompiler {
1692public:
1693 ASTDecompiler(SPIRVDecompiler& decomp) : decomp{decomp} {}
1694
1695 void operator()(VideoCommon::Shader::ASTProgram& ast) {
1696 ASTNode current = ast.nodes.GetFirst();
1697 while (current) {
1698 Visit(current);
1699 current = current->GetNext();
1700 }
1701 }
1702
1703 void operator()(VideoCommon::Shader::ASTIfThen& ast) {
1704 ExprDecompiler expr_parser{decomp};
1705 const Id condition = expr_parser.Visit(ast.condition);
1706 const Id then_label = decomp.OpLabel();
1707 const Id endif_label = decomp.OpLabel();
1708 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
1709 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
1710 decomp.Emit(then_label);
1711 ASTNode current = ast.nodes.GetFirst();
1712 while (current) {
1713 Visit(current);
1714 current = current->GetNext();
1715 }
1716 decomp.Emit(endif_label);
1717 }
1718
1719 void operator()(VideoCommon::Shader::ASTIfElse& ast) {
1720 UNREACHABLE();
1721 }
1722
1723 void operator()(VideoCommon::Shader::ASTBlockEncoded& ast) {
1724 UNREACHABLE();
1725 }
1726
1727 void operator()(VideoCommon::Shader::ASTBlockDecoded& ast) {
1728 decomp.VisitBasicBlock(ast.nodes);
1729 }
1730
1731 void operator()(VideoCommon::Shader::ASTVarSet& ast) {
1732 ExprDecompiler expr_parser{decomp};
1733 const Id condition = expr_parser.Visit(ast.condition);
1734 decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition));
1735 }
1736
1737 void operator()(VideoCommon::Shader::ASTLabel& ast) {
1738 // Do nothing
1739 }
1740
1741 void operator()(VideoCommon::Shader::ASTGoto& ast) {
1742 UNREACHABLE();
1743 }
1744
1745 void operator()(VideoCommon::Shader::ASTDoWhile& ast) {
1746 const Id loop_label = decomp.OpLabel();
1747 const Id endloop_label = decomp.OpLabel();
1748 const Id loop_start_block = decomp.OpLabel();
1749 const Id loop_end_block = decomp.OpLabel();
1750 current_loop_exit = endloop_label;
1751 decomp.Emit(loop_label);
1752 decomp.Emit(decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone));
1753 decomp.Emit(decomp.OpBranch(loop_start_block));
1754 decomp.Emit(loop_start_block);
1755 ASTNode current = ast.nodes.GetFirst();
1756 while (current) {
1757 Visit(current);
1758 current = current->GetNext();
1759 }
1760 decomp.Emit(decomp.OpBranch(loop_end_block));
1761 decomp.Emit(loop_end_block);
1762 ExprDecompiler expr_parser{decomp};
1763 const Id condition = expr_parser.Visit(ast.condition);
1764 decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label));
1765 decomp.Emit(endloop_label);
1766 }
1767
1768 void operator()(VideoCommon::Shader::ASTReturn& ast) {
1769 bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
1770 if (!is_true) {
1771 ExprDecompiler expr_parser{decomp};
1772 const Id condition = expr_parser.Visit(ast.condition);
1773 const Id then_label = decomp.OpLabel();
1774 const Id endif_label = decomp.OpLabel();
1775 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
1776 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
1777 decomp.Emit(then_label);
1778 if (ast.kills) {
1779 decomp.Emit(decomp.OpKill());
1780 } else {
1781 decomp.PreExit();
1782 decomp.Emit(decomp.OpReturn());
1783 }
1784 decomp.Emit(endif_label);
1785 } else {
1786 decomp.Emit(decomp.OpLabel());
1787 if (ast.kills) {
1788 decomp.Emit(decomp.OpKill());
1789 } else {
1790 decomp.PreExit();
1791 decomp.Emit(decomp.OpReturn());
1792 }
1793 decomp.Emit(decomp.OpLabel());
1794 }
1795 }
1796
1797 void operator()(VideoCommon::Shader::ASTBreak& ast) {
1798 bool is_true = VideoCommon::Shader::ExprIsTrue(ast.condition);
1799 if (!is_true) {
1800 ExprDecompiler expr_parser{decomp};
1801 const Id condition = expr_parser.Visit(ast.condition);
1802 const Id then_label = decomp.OpLabel();
1803 const Id endif_label = decomp.OpLabel();
1804 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone));
1805 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label));
1806 decomp.Emit(then_label);
1807 decomp.Emit(decomp.OpBranch(current_loop_exit));
1808 decomp.Emit(endif_label);
1809 } else {
1810 decomp.Emit(decomp.OpLabel());
1811 decomp.Emit(decomp.OpBranch(current_loop_exit));
1812 decomp.Emit(decomp.OpLabel());
1813 }
1814 }
1815
1816 void Visit(VideoCommon::Shader::ASTNode& node) {
1817 std::visit(*this, *node->GetInnerData());
1818 }
1819
1820private:
1821 SPIRVDecompiler& decomp;
1822 Id current_loop_exit;
1823};
1824
1825void SPIRVDecompiler::DecompileAST() {
1826 u32 num_flow_variables = ir.GetASTNumVariables();
1827 for (u32 i = 0; i < num_flow_variables; i++) {
1828 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
1829 Name(id, fmt::format("flow_var_{}", i));
1830 flow_variables.emplace(i, AddGlobalVariable(id));
1831 }
1832 ASTDecompiler decompiler{*this};
1833 VideoCommon::Shader::ASTNode program = ir.GetASTProgram();
1834 decompiler.Visit(program);
1835}
1836
1583DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, 1837DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
1584 Maxwell::ShaderStage stage) { 1838 Maxwell::ShaderStage stage) {
1585 auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage); 1839 auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage);
diff --git a/src/video_core/shader/ast.h b/src/video_core/shader/ast.h
index 07deb58e4..12db336df 100644
--- a/src/video_core/shader/ast.h
+++ b/src/video_core/shader/ast.h
@@ -205,13 +205,29 @@ public:
205 return nullptr; 205 return nullptr;
206 } 206 }
207 207
208 void MarkLabelUnused() const { 208 void MarkLabelUnused() {
209 auto inner = std::get_if<ASTLabel>(&data); 209 auto inner = std::get_if<ASTLabel>(&data);
210 if (inner) { 210 if (inner) {
211 inner->unused = true; 211 inner->unused = true;
212 } 212 }
213 } 213 }
214 214
215 bool IsLabelUnused() const {
216 auto inner = std::get_if<ASTLabel>(&data);
217 if (inner) {
218 return inner->unused;
219 }
220 return true;
221 }
222
223 u32 GetLabelIndex() const {
224 auto inner = std::get_if<ASTLabel>(&data);
225 if (inner) {
226 return inner->index;
227 }
228 return -1;
229 }
230
215 Expr GetIfCondition() const { 231 Expr GetIfCondition() const {
216 auto inner = std::get_if<ASTIfThen>(&data); 232 auto inner = std::get_if<ASTIfThen>(&data);
217 if (inner) { 233 if (inner) {
@@ -336,6 +352,10 @@ public:
336 return variables; 352 return variables;
337 } 353 }
338 354
355 const std::vector<ASTNode>& GetLabels() const {
356 return labels;
357 }
358
339private: 359private:
340 bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const; 360 bool IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const;
341 361
diff --git a/src/video_core/shader/shader_ir.h b/src/video_core/shader/shader_ir.h
index 7a91c9bb6..105981d67 100644
--- a/src/video_core/shader/shader_ir.h
+++ b/src/video_core/shader/shader_ir.h
@@ -151,6 +151,10 @@ public:
151 return decompiled; 151 return decompiled;
152 } 152 }
153 153
154 const ASTManager& GetASTManager() const {
155 return program_manager;
156 }
157
154 ASTNode GetASTProgram() const { 158 ASTNode GetASTProgram() const {
155 return program_manager.GetProgram(); 159 return program_manager.GetProgram();
156 } 160 }