summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Fernando Sahmkow2019-12-10 08:01:41 -0400
committerGravatar GitHub2019-12-10 08:01:41 -0400
commit6edadef96d57cb021d0131929d5a122ae102ad9e (patch)
treee7f42fb2af3b1d83665725db9034d8fb6f3d6c78 /src
parentMerge pull request #3205 from ReinUsesLisp/vk-device (diff)
parentvk_shader_decompiler: Fix build issues on old gcc versions (diff)
downloadyuzu-6edadef96d57cb021d0131929d5a122ae102ad9e.tar.gz
yuzu-6edadef96d57cb021d0131929d5a122ae102ad9e.tar.xz
yuzu-6edadef96d57cb021d0131929d5a122ae102ad9e.zip
Merge pull request #3208 from ReinUsesLisp/vk-shader-decompiler
vk_shader_decompiler: Add tessellation and misc changes
Diffstat (limited to 'src')
-rw-r--r--src/video_core/engines/shader_bytecode.h3
-rw-r--r--src/video_core/renderer_opengl/gl_shader_decompiler.cpp5
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp2275
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.h74
-rw-r--r--src/video_core/shader/decode/memory.cpp34
-rw-r--r--src/video_core/shader/decode/other.cpp2
-rw-r--r--src/video_core/shader/decode/warp.cpp3
-rw-r--r--src/video_core/shader/node.h21
-rw-r--r--src/video_core/shader/shader_ir.h5
-rw-r--r--src/video_core/shader/track.cpp1
10 files changed, 1705 insertions, 718 deletions
diff --git a/src/video_core/engines/shader_bytecode.h b/src/video_core/engines/shader_bytecode.h
index 8b7dcbe9d..7703a76a3 100644
--- a/src/video_core/engines/shader_bytecode.h
+++ b/src/video_core/engines/shader_bytecode.h
@@ -98,10 +98,11 @@ union Attribute {
98 BitField<20, 10, u64> immediate; 98 BitField<20, 10, u64> immediate;
99 BitField<22, 2, u64> element; 99 BitField<22, 2, u64> element;
100 BitField<24, 6, Index> index; 100 BitField<24, 6, Index> index;
101 BitField<31, 1, u64> patch;
101 BitField<47, 3, AttributeSize> size; 102 BitField<47, 3, AttributeSize> size;
102 103
103 bool IsPhysical() const { 104 bool IsPhysical() const {
104 return element == 0 && static_cast<u64>(index.Value()) == 0; 105 return patch == 0 && element == 0 && static_cast<u64>(index.Value()) == 0;
105 } 106 }
106 } fmt20; 107 } fmt20;
107 108
diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
index 3d3cd21f3..9700c2ebe 100644
--- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
+++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
@@ -1915,6 +1915,10 @@ private:
1915 return {}; 1915 return {};
1916 } 1916 }
1917 1917
1918 Expression InvocationId(Operation operation) {
1919 return {"gl_InvocationID", Type::Int};
1920 }
1921
1918 Expression YNegate(Operation operation) { 1922 Expression YNegate(Operation operation) {
1919 return {"y_direction", Type::Float}; 1923 return {"y_direction", Type::Float};
1920 } 1924 }
@@ -2153,6 +2157,7 @@ private:
2153 &GLSLDecompiler::EmitVertex, 2157 &GLSLDecompiler::EmitVertex,
2154 &GLSLDecompiler::EndPrimitive, 2158 &GLSLDecompiler::EndPrimitive,
2155 2159
2160 &GLSLDecompiler::InvocationId,
2156 &GLSLDecompiler::YNegate, 2161 &GLSLDecompiler::YNegate,
2157 &GLSLDecompiler::LocalInvocationId<0>, 2162 &GLSLDecompiler::LocalInvocationId<0>,
2158 &GLSLDecompiler::LocalInvocationId<1>, 2163 &GLSLDecompiler::LocalInvocationId<1>,
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 76894275b..8ad89b58a 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -3,8 +3,10 @@
3// Refer to the license.txt file included. 3// Refer to the license.txt file included.
4 4
5#include <functional> 5#include <functional>
6#include <limits>
6#include <map> 7#include <map>
7#include <set> 8#include <type_traits>
9#include <utility>
8 10
9#include <fmt/format.h> 11#include <fmt/format.h>
10 12
@@ -23,7 +25,9 @@
23#include "video_core/shader/node.h" 25#include "video_core/shader/node.h"
24#include "video_core/shader/shader_ir.h" 26#include "video_core/shader/shader_ir.h"
25 27
26namespace Vulkan::VKShader { 28namespace Vulkan {
29
30namespace {
27 31
28using Sirit::Id; 32using Sirit::Id;
29using Tegra::Engines::ShaderType; 33using Tegra::Engines::ShaderType;
@@ -35,22 +39,60 @@ using namespace VideoCommon::Shader;
35using Maxwell = Tegra::Engines::Maxwell3D::Regs; 39using Maxwell = Tegra::Engines::Maxwell3D::Regs;
36using Operation = const OperationNode&; 40using Operation = const OperationNode&;
37 41
42class ASTDecompiler;
43class ExprDecompiler;
44
38// TODO(Rodrigo): Use rasterizer's value 45// TODO(Rodrigo): Use rasterizer's value
39constexpr u32 MAX_CONSTBUFFER_FLOATS = 0x4000; 46constexpr u32 MaxConstBufferFloats = 0x4000;
40constexpr u32 MAX_CONSTBUFFER_ELEMENTS = MAX_CONSTBUFFER_FLOATS / 4; 47constexpr u32 MaxConstBufferElements = MaxConstBufferFloats / 4;
41constexpr u32 STAGE_BINDING_STRIDE = 0x100; 48
49constexpr u32 NumInputPatches = 32; // This value seems to be the standard
50
51enum class Type { Void, Bool, Bool2, Float, Int, Uint, HalfFloat };
52
53class Expression final {
54public:
55 Expression(Id id, Type type) : id{id}, type{type} {
56 ASSERT(type != Type::Void);
57 }
58 Expression() : type{Type::Void} {}
42 59
43enum class Type { Bool, Bool2, Float, Int, Uint, HalfFloat }; 60 Id id{};
61 Type type{};
62};
63static_assert(std::is_standard_layout_v<Expression>);
44 64
45struct SamplerImage { 65struct TexelBuffer {
46 Id image_type; 66 Id image_type{};
47 Id sampled_image_type; 67 Id image{};
48 Id sampler;
49}; 68};
50 69
51namespace { 70struct SampledImage {
71 Id image_type{};
72 Id sampled_image_type{};
73 Id sampler{};
74};
75
76struct StorageImage {
77 Id image_type{};
78 Id image{};
79};
80
81struct AttributeType {
82 Type type;
83 Id scalar;
84 Id vector;
85};
86
87struct VertexIndices {
88 std::optional<u32> position;
89 std::optional<u32> viewport;
90 std::optional<u32> point_size;
91 std::optional<u32> clip_distances;
92};
52 93
53spv::Dim GetSamplerDim(const Sampler& sampler) { 94spv::Dim GetSamplerDim(const Sampler& sampler) {
95 ASSERT(!sampler.IsBuffer());
54 switch (sampler.GetType()) { 96 switch (sampler.GetType()) {
55 case Tegra::Shader::TextureType::Texture1D: 97 case Tegra::Shader::TextureType::Texture1D:
56 return spv::Dim::Dim1D; 98 return spv::Dim::Dim1D;
@@ -66,6 +108,138 @@ spv::Dim GetSamplerDim(const Sampler& sampler) {
66 } 108 }
67} 109}
68 110
111std::pair<spv::Dim, bool> GetImageDim(const Image& image) {
112 switch (image.GetType()) {
113 case Tegra::Shader::ImageType::Texture1D:
114 return {spv::Dim::Dim1D, false};
115 case Tegra::Shader::ImageType::TextureBuffer:
116 return {spv::Dim::Buffer, false};
117 case Tegra::Shader::ImageType::Texture1DArray:
118 return {spv::Dim::Dim1D, true};
119 case Tegra::Shader::ImageType::Texture2D:
120 return {spv::Dim::Dim2D, false};
121 case Tegra::Shader::ImageType::Texture2DArray:
122 return {spv::Dim::Dim2D, true};
123 case Tegra::Shader::ImageType::Texture3D:
124 return {spv::Dim::Dim3D, false};
125 default:
126 UNIMPLEMENTED_MSG("Unimplemented image type={}", static_cast<u32>(image.GetType()));
127 return {spv::Dim::Dim2D, false};
128 }
129}
130
131/// Returns the number of vertices present in a primitive topology.
132u32 GetNumPrimitiveTopologyVertices(Maxwell::PrimitiveTopology primitive_topology) {
133 switch (primitive_topology) {
134 case Maxwell::PrimitiveTopology::Points:
135 return 1;
136 case Maxwell::PrimitiveTopology::Lines:
137 case Maxwell::PrimitiveTopology::LineLoop:
138 case Maxwell::PrimitiveTopology::LineStrip:
139 return 2;
140 case Maxwell::PrimitiveTopology::Triangles:
141 case Maxwell::PrimitiveTopology::TriangleStrip:
142 case Maxwell::PrimitiveTopology::TriangleFan:
143 return 3;
144 case Maxwell::PrimitiveTopology::LinesAdjacency:
145 case Maxwell::PrimitiveTopology::LineStripAdjacency:
146 return 4;
147 case Maxwell::PrimitiveTopology::TrianglesAdjacency:
148 case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
149 return 6;
150 case Maxwell::PrimitiveTopology::Quads:
151 UNIMPLEMENTED_MSG("Quads");
152 return 3;
153 case Maxwell::PrimitiveTopology::QuadStrip:
154 UNIMPLEMENTED_MSG("QuadStrip");
155 return 3;
156 case Maxwell::PrimitiveTopology::Polygon:
157 UNIMPLEMENTED_MSG("Polygon");
158 return 3;
159 case Maxwell::PrimitiveTopology::Patches:
160 UNIMPLEMENTED_MSG("Patches");
161 return 3;
162 default:
163 UNREACHABLE();
164 return 3;
165 }
166}
167
168spv::ExecutionMode GetExecutionMode(Maxwell::TessellationPrimitive primitive) {
169 switch (primitive) {
170 case Maxwell::TessellationPrimitive::Isolines:
171 return spv::ExecutionMode::Isolines;
172 case Maxwell::TessellationPrimitive::Triangles:
173 return spv::ExecutionMode::Triangles;
174 case Maxwell::TessellationPrimitive::Quads:
175 return spv::ExecutionMode::Quads;
176 }
177 UNREACHABLE();
178 return spv::ExecutionMode::Triangles;
179}
180
181spv::ExecutionMode GetExecutionMode(Maxwell::TessellationSpacing spacing) {
182 switch (spacing) {
183 case Maxwell::TessellationSpacing::Equal:
184 return spv::ExecutionMode::SpacingEqual;
185 case Maxwell::TessellationSpacing::FractionalOdd:
186 return spv::ExecutionMode::SpacingFractionalOdd;
187 case Maxwell::TessellationSpacing::FractionalEven:
188 return spv::ExecutionMode::SpacingFractionalEven;
189 }
190 UNREACHABLE();
191 return spv::ExecutionMode::SpacingEqual;
192}
193
194spv::ExecutionMode GetExecutionMode(Maxwell::PrimitiveTopology input_topology) {
195 switch (input_topology) {
196 case Maxwell::PrimitiveTopology::Points:
197 return spv::ExecutionMode::InputPoints;
198 case Maxwell::PrimitiveTopology::Lines:
199 case Maxwell::PrimitiveTopology::LineLoop:
200 case Maxwell::PrimitiveTopology::LineStrip:
201 return spv::ExecutionMode::InputLines;
202 case Maxwell::PrimitiveTopology::Triangles:
203 case Maxwell::PrimitiveTopology::TriangleStrip:
204 case Maxwell::PrimitiveTopology::TriangleFan:
205 return spv::ExecutionMode::Triangles;
206 case Maxwell::PrimitiveTopology::LinesAdjacency:
207 case Maxwell::PrimitiveTopology::LineStripAdjacency:
208 return spv::ExecutionMode::InputLinesAdjacency;
209 case Maxwell::PrimitiveTopology::TrianglesAdjacency:
210 case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
211 return spv::ExecutionMode::InputTrianglesAdjacency;
212 case Maxwell::PrimitiveTopology::Quads:
213 UNIMPLEMENTED_MSG("Quads");
214 return spv::ExecutionMode::Triangles;
215 case Maxwell::PrimitiveTopology::QuadStrip:
216 UNIMPLEMENTED_MSG("QuadStrip");
217 return spv::ExecutionMode::Triangles;
218 case Maxwell::PrimitiveTopology::Polygon:
219 UNIMPLEMENTED_MSG("Polygon");
220 return spv::ExecutionMode::Triangles;
221 case Maxwell::PrimitiveTopology::Patches:
222 UNIMPLEMENTED_MSG("Patches");
223 return spv::ExecutionMode::Triangles;
224 }
225 UNREACHABLE();
226 return spv::ExecutionMode::Triangles;
227}
228
229spv::ExecutionMode GetExecutionMode(Tegra::Shader::OutputTopology output_topology) {
230 switch (output_topology) {
231 case Tegra::Shader::OutputTopology::PointList:
232 return spv::ExecutionMode::OutputPoints;
233 case Tegra::Shader::OutputTopology::LineStrip:
234 return spv::ExecutionMode::OutputLineStrip;
235 case Tegra::Shader::OutputTopology::TriangleStrip:
236 return spv::ExecutionMode::OutputTriangleStrip;
237 default:
238 UNREACHABLE();
239 return spv::ExecutionMode::OutputPoints;
240 }
241}
242
69/// Returns true if an attribute index is one of the 32 generic attributes 243/// Returns true if an attribute index is one of the 32 generic attributes
70constexpr bool IsGenericAttribute(Attribute::Index attribute) { 244constexpr bool IsGenericAttribute(Attribute::Index attribute) {
71 return attribute >= Attribute::Index::Attribute_0 && 245 return attribute >= Attribute::Index::Attribute_0 &&
@@ -73,7 +247,7 @@ constexpr bool IsGenericAttribute(Attribute::Index attribute) {
73} 247}
74 248
75/// Returns the location of a generic attribute 249/// Returns the location of a generic attribute
76constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) { 250u32 GetGenericAttributeLocation(Attribute::Index attribute) {
77 ASSERT(IsGenericAttribute(attribute)); 251 ASSERT(IsGenericAttribute(attribute));
78 return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0); 252 return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0);
79} 253}
@@ -87,20 +261,146 @@ bool IsPrecise(Operation operand) {
87 return false; 261 return false;
88} 262}
89 263
90} // namespace 264class SPIRVDecompiler final : public Sirit::Module {
91
92class ASTDecompiler;
93class ExprDecompiler;
94
95class SPIRVDecompiler : public Sirit::Module {
96public: 265public:
97 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage) 266 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage,
98 : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()} { 267 const Specialization& specialization)
268 : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()},
269 specialization{specialization} {
99 AddCapability(spv::Capability::Shader); 270 AddCapability(spv::Capability::Shader);
271 AddCapability(spv::Capability::UniformAndStorageBuffer16BitAccess);
272 AddCapability(spv::Capability::ImageQuery);
273 AddCapability(spv::Capability::Image1D);
274 AddCapability(spv::Capability::ImageBuffer);
275 AddCapability(spv::Capability::ImageGatherExtended);
276 AddCapability(spv::Capability::SampledBuffer);
277 AddCapability(spv::Capability::StorageImageWriteWithoutFormat);
278 AddCapability(spv::Capability::SubgroupBallotKHR);
279 AddCapability(spv::Capability::SubgroupVoteKHR);
280 AddExtension("SPV_KHR_shader_ballot");
281 AddExtension("SPV_KHR_subgroup_vote");
100 AddExtension("SPV_KHR_storage_buffer_storage_class"); 282 AddExtension("SPV_KHR_storage_buffer_storage_class");
101 AddExtension("SPV_KHR_variable_pointers"); 283 AddExtension("SPV_KHR_variable_pointers");
284
285 if (ir.UsesViewportIndex()) {
286 AddCapability(spv::Capability::MultiViewport);
287 if (device.IsExtShaderViewportIndexLayerSupported()) {
288 AddExtension("SPV_EXT_shader_viewport_index_layer");
289 AddCapability(spv::Capability::ShaderViewportIndexLayerEXT);
290 }
291 }
292
293 if (device.IsFloat16Supported()) {
294 AddCapability(spv::Capability::Float16);
295 }
296 t_scalar_half = Name(TypeFloat(device.IsFloat16Supported() ? 16 : 32), "scalar_half");
297 t_half = Name(TypeVector(t_scalar_half, 2), "half");
298
299 const Id main = Decompile();
300
301 switch (stage) {
302 case ShaderType::Vertex:
303 AddEntryPoint(spv::ExecutionModel::Vertex, main, "main", interfaces);
304 break;
305 case ShaderType::TesselationControl:
306 AddCapability(spv::Capability::Tessellation);
307 AddEntryPoint(spv::ExecutionModel::TessellationControl, main, "main", interfaces);
308 AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
309 header.common2.threads_per_input_primitive);
310 break;
311 case ShaderType::TesselationEval:
312 AddCapability(spv::Capability::Tessellation);
313 AddEntryPoint(spv::ExecutionModel::TessellationEvaluation, main, "main", interfaces);
314 AddExecutionMode(main, GetExecutionMode(specialization.tessellation.primitive));
315 AddExecutionMode(main, GetExecutionMode(specialization.tessellation.spacing));
316 AddExecutionMode(main, specialization.tessellation.clockwise
317 ? spv::ExecutionMode::VertexOrderCw
318 : spv::ExecutionMode::VertexOrderCcw);
319 break;
320 case ShaderType::Geometry:
321 AddCapability(spv::Capability::Geometry);
322 AddEntryPoint(spv::ExecutionModel::Geometry, main, "main", interfaces);
323 AddExecutionMode(main, GetExecutionMode(specialization.primitive_topology));
324 AddExecutionMode(main, GetExecutionMode(header.common3.output_topology));
325 AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
326 header.common4.max_output_vertices);
327 // TODO(Rodrigo): Where can we get this info from?
328 AddExecutionMode(main, spv::ExecutionMode::Invocations, 1U);
329 break;
330 case ShaderType::Fragment:
331 AddEntryPoint(spv::ExecutionModel::Fragment, main, "main", interfaces);
332 AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft);
333 if (header.ps.omap.depth) {
334 AddExecutionMode(main, spv::ExecutionMode::DepthReplacing);
335 }
336 break;
337 case ShaderType::Compute:
338 const auto workgroup_size = specialization.workgroup_size;
339 AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0],
340 workgroup_size[1], workgroup_size[2]);
341 AddEntryPoint(spv::ExecutionModel::GLCompute, main, "main", interfaces);
342 break;
343 }
102 } 344 }
103 345
346private:
347 Id Decompile() {
348 DeclareCommon();
349 DeclareVertex();
350 DeclareTessControl();
351 DeclareTessEval();
352 DeclareGeometry();
353 DeclareFragment();
354 DeclareCompute();
355 DeclareRegisters();
356 DeclarePredicates();
357 DeclareLocalMemory();
358 DeclareSharedMemory();
359 DeclareInternalFlags();
360 DeclareInputAttributes();
361 DeclareOutputAttributes();
362
363 u32 binding = specialization.base_binding;
364 binding = DeclareConstantBuffers(binding);
365 binding = DeclareGlobalBuffers(binding);
366 binding = DeclareTexelBuffers(binding);
367 binding = DeclareSamplers(binding);
368 binding = DeclareImages(binding);
369
370 const Id main = OpFunction(t_void, {}, TypeFunction(t_void));
371 AddLabel();
372
373 if (ir.IsDecompiled()) {
374 DeclareFlowVariables();
375 DecompileAST();
376 } else {
377 AllocateLabels();
378 DecompileBranchMode();
379 }
380
381 OpReturn();
382 OpFunctionEnd();
383
384 return main;
385 }
386
387 void DefinePrologue() {
388 if (stage == ShaderType::Vertex) {
389 // Clear Position to avoid reading trash on the Z conversion.
390 const auto position_index = out_indices.position.value();
391 const Id position = AccessElement(t_out_float4, out_vertex, position_index);
392 OpStore(position, v_varying_default);
393
394 if (specialization.point_size) {
395 const u32 point_size_index = out_indices.point_size.value();
396 const Id out_point_size = AccessElement(t_out_float, out_vertex, point_size_index);
397 OpStore(out_point_size, Constant(t_float, *specialization.point_size));
398 }
399 }
400 }
401
402 void DecompileAST();
403
104 void DecompileBranchMode() { 404 void DecompileBranchMode() {
105 const u32 first_address = ir.GetBasicBlocks().begin()->first; 405 const u32 first_address = ir.GetBasicBlocks().begin()->first;
106 const Id loop_label = OpLabel("loop"); 406 const Id loop_label = OpLabel("loop");
@@ -111,14 +411,15 @@ public:
111 411
112 std::vector<Sirit::Literal> literals; 412 std::vector<Sirit::Literal> literals;
113 std::vector<Id> branch_labels; 413 std::vector<Id> branch_labels;
114 for (const auto& pair : labels) { 414 for (const auto& [literal, label] : labels) {
115 const auto [literal, label] = pair;
116 literals.push_back(literal); 415 literals.push_back(literal);
117 branch_labels.push_back(label); 416 branch_labels.push_back(label);
118 } 417 }
119 418
120 jmp_to = Emit(OpVariable(TypePointer(spv::StorageClass::Function, t_uint), 419 jmp_to = OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
121 spv::StorageClass::Function, Constant(t_uint, first_address))); 420 spv::StorageClass::Function, Constant(t_uint, first_address));
421 AddLocalVariable(jmp_to);
422
122 std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack(); 423 std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack();
123 std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack(); 424 std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack();
124 425
@@ -128,151 +429,118 @@ public:
128 Name(pbk_flow_stack, "pbk_flow_stack"); 429 Name(pbk_flow_stack, "pbk_flow_stack");
129 Name(pbk_flow_stack_top, "pbk_flow_stack_top"); 430 Name(pbk_flow_stack_top, "pbk_flow_stack_top");
130 431
131 Emit(OpBranch(loop_label)); 432 DefinePrologue();
132 Emit(loop_label); 433
133 Emit(OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::Unroll)); 434 OpBranch(loop_label);
134 Emit(OpBranch(dummy_label)); 435 AddLabel(loop_label);
436 OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::MaskNone);
437 OpBranch(dummy_label);
135 438
136 Emit(dummy_label); 439 AddLabel(dummy_label);
137 const Id default_branch = OpLabel(); 440 const Id default_branch = OpLabel();
138 const Id jmp_to_load = Emit(OpLoad(t_uint, jmp_to)); 441 const Id jmp_to_load = OpLoad(t_uint, jmp_to);
139 Emit(OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone)); 442 OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone);
140 Emit(OpSwitch(jmp_to_load, default_branch, literals, branch_labels)); 443 OpSwitch(jmp_to_load, default_branch, literals, branch_labels);
141 444
142 Emit(default_branch); 445 AddLabel(default_branch);
143 Emit(OpReturn()); 446 OpReturn();
144 447
145 for (const auto& pair : ir.GetBasicBlocks()) { 448 for (const auto& [address, bb] : ir.GetBasicBlocks()) {
146 const auto& [address, bb] = pair; 449 AddLabel(labels.at(address));
147 Emit(labels.at(address));
148 450
149 VisitBasicBlock(bb); 451 VisitBasicBlock(bb);
150 452
151 const auto next_it = labels.lower_bound(address + 1); 453 const auto next_it = labels.lower_bound(address + 1);
152 const Id next_label = next_it != labels.end() ? next_it->second : default_branch; 454 const Id next_label = next_it != labels.end() ? next_it->second : default_branch;
153 Emit(OpBranch(next_label)); 455 OpBranch(next_label);
154 } 456 }
155 457
156 Emit(jump_label); 458 AddLabel(jump_label);
157 Emit(OpBranch(continue_label)); 459 OpBranch(continue_label);
158 Emit(continue_label); 460 AddLabel(continue_label);
159 Emit(OpBranch(loop_label)); 461 OpBranch(loop_label);
160 Emit(merge_label); 462 AddLabel(merge_label);
161 } 463 }
162 464
163 void DecompileAST(); 465private:
466 friend class ASTDecompiler;
467 friend class ExprDecompiler;
164 468
165 void Decompile() { 469 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
166 const bool is_fully_decompiled = ir.IsDecompiled();
167 AllocateBindings();
168 if (!is_fully_decompiled) {
169 AllocateLabels();
170 }
171 470
172 DeclareVertex(); 471 void AllocateLabels() {
173 DeclareGeometry(); 472 for (const auto& pair : ir.GetBasicBlocks()) {
174 DeclareFragment(); 473 const u32 address = pair.first;
175 DeclareRegisters(); 474 labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
176 DeclarePredicates();
177 if (is_fully_decompiled) {
178 DeclareFlowVariables();
179 } 475 }
180 DeclareLocalMemory(); 476 }
181 DeclareInternalFlags();
182 DeclareInputAttributes();
183 DeclareOutputAttributes();
184 DeclareConstantBuffers();
185 DeclareGlobalBuffers();
186 DeclareSamplers();
187 477
188 execute_function = 478 void DeclareCommon() {
189 Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void))); 479 thread_id =
190 Emit(OpLabel()); 480 DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
481 }
191 482
192 if (is_fully_decompiled) { 483 void DeclareVertex() {
193 DecompileAST(); 484 if (stage != ShaderType::Vertex) {
194 } else { 485 return;
195 DecompileBranchMode();
196 } 486 }
487 Id out_vertex_struct;
488 std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
489 const Id vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
490 out_vertex = OpVariable(vertex_ptr, spv::StorageClass::Output);
491 interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
197 492
198 Emit(OpReturn()); 493 // Declare input attributes
199 Emit(OpFunctionEnd()); 494 vertex_index = DeclareInputBuiltIn(spv::BuiltIn::VertexIndex, t_in_uint, "vertex_index");
495 instance_index =
496 DeclareInputBuiltIn(spv::BuiltIn::InstanceIndex, t_in_uint, "instance_index");
200 } 497 }
201 498
202 ShaderEntries GetShaderEntries() const { 499 void DeclareTessControl() {
203 ShaderEntries entries; 500 if (stage != ShaderType::TesselationControl) {
204 entries.const_buffers_base_binding = const_buffers_base_binding; 501 return;
205 entries.global_buffers_base_binding = global_buffers_base_binding;
206 entries.samplers_base_binding = samplers_base_binding;
207 for (const auto& cbuf : ir.GetConstantBuffers()) {
208 entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
209 }
210 for (const auto& gmem_pair : ir.GetGlobalMemory()) {
211 const auto& [base, usage] = gmem_pair;
212 entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset);
213 }
214 for (const auto& sampler : ir.GetSamplers()) {
215 entries.samplers.emplace_back(sampler);
216 }
217 for (const auto& attribute : ir.GetInputAttributes()) {
218 if (IsGenericAttribute(attribute)) {
219 entries.attributes.insert(GetGenericAttributeLocation(attribute));
220 }
221 } 502 }
222 entries.clip_distances = ir.GetClipDistances(); 503 DeclareInputVertexArray(NumInputPatches);
223 entries.shader_length = ir.GetLength(); 504 DeclareOutputVertexArray(header.common2.threads_per_input_primitive);
224 entries.entry_function = execute_function;
225 entries.interfaces = interfaces;
226 return entries;
227 }
228
229private:
230 friend class ASTDecompiler;
231 friend class ExprDecompiler;
232
233 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
234 505
235 void AllocateBindings() { 506 tess_level_outer = DeclareBuiltIn(
236 const u32 binding_base = static_cast<u32>(stage) * STAGE_BINDING_STRIDE; 507 spv::BuiltIn::TessLevelOuter, spv::StorageClass::Output,
237 u32 binding_iterator = binding_base; 508 TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 4U))),
509 "tess_level_outer");
510 Decorate(tess_level_outer, spv::Decoration::Patch);
238 511
239 const auto Allocate = [&binding_iterator](std::size_t count) { 512 tess_level_inner = DeclareBuiltIn(
240 const u32 current_binding = binding_iterator; 513 spv::BuiltIn::TessLevelInner, spv::StorageClass::Output,
241 binding_iterator += static_cast<u32>(count); 514 TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 2U))),
242 return current_binding; 515 "tess_level_inner");
243 }; 516 Decorate(tess_level_inner, spv::Decoration::Patch);
244 const_buffers_base_binding = Allocate(ir.GetConstantBuffers().size());
245 global_buffers_base_binding = Allocate(ir.GetGlobalMemory().size());
246 samplers_base_binding = Allocate(ir.GetSamplers().size());
247 517
248 ASSERT_MSG(binding_iterator - binding_base < STAGE_BINDING_STRIDE, 518 invocation_id = DeclareInputBuiltIn(spv::BuiltIn::InvocationId, t_in_int, "invocation_id");
249 "Stage binding stride is too small");
250 } 519 }
251 520
252 void AllocateLabels() { 521 void DeclareTessEval() {
253 for (const auto& pair : ir.GetBasicBlocks()) { 522 if (stage != ShaderType::TesselationEval) {
254 const u32 address = pair.first;
255 labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
256 }
257 }
258
259 void DeclareVertex() {
260 if (stage != ShaderType::Vertex)
261 return; 523 return;
524 }
525 DeclareInputVertexArray(NumInputPatches);
526 DeclareOutputVertex();
262 527
263 DeclareVertexRedeclarations(); 528 tess_coord = DeclareInputBuiltIn(spv::BuiltIn::TessCoord, t_in_float3, "tess_coord");
264 } 529 }
265 530
266 void DeclareGeometry() { 531 void DeclareGeometry() {
267 if (stage != ShaderType::Geometry) 532 if (stage != ShaderType::Geometry) {
268 return; 533 return;
269 534 }
270 UNIMPLEMENTED(); 535 const u32 num_input = GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
536 DeclareInputVertexArray(num_input);
537 DeclareOutputVertex();
271 } 538 }
272 539
273 void DeclareFragment() { 540 void DeclareFragment() {
274 if (stage != ShaderType::Fragment) 541 if (stage != ShaderType::Fragment) {
275 return; 542 return;
543 }
276 544
277 for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) { 545 for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) {
278 if (!IsRenderTargetUsed(rt)) { 546 if (!IsRenderTargetUsed(rt)) {
@@ -296,10 +564,19 @@ private:
296 interfaces.push_back(frag_depth); 564 interfaces.push_back(frag_depth);
297 } 565 }
298 566
299 frag_coord = DeclareBuiltIn(spv::BuiltIn::FragCoord, spv::StorageClass::Input, t_in_float4, 567 frag_coord = DeclareInputBuiltIn(spv::BuiltIn::FragCoord, t_in_float4, "frag_coord");
300 "frag_coord"); 568 front_facing = DeclareInputBuiltIn(spv::BuiltIn::FrontFacing, t_in_bool, "front_facing");
301 front_facing = DeclareBuiltIn(spv::BuiltIn::FrontFacing, spv::StorageClass::Input, 569 point_coord = DeclareInputBuiltIn(spv::BuiltIn::PointCoord, t_in_float2, "point_coord");
302 t_in_bool, "front_facing"); 570 }
571
572 void DeclareCompute() {
573 if (stage != ShaderType::Compute) {
574 return;
575 }
576
577 workgroup_id = DeclareInputBuiltIn(spv::BuiltIn::WorkgroupId, t_in_uint3, "workgroup_id");
578 local_invocation_id =
579 DeclareInputBuiltIn(spv::BuiltIn::LocalInvocationId, t_in_uint3, "local_invocation_id");
303 } 580 }
304 581
305 void DeclareRegisters() { 582 void DeclareRegisters() {
@@ -327,21 +604,44 @@ private:
327 } 604 }
328 605
329 void DeclareLocalMemory() { 606 void DeclareLocalMemory() {
330 if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { 607 // TODO(Rodrigo): Unstub kernel local memory size and pass it from a register at
331 const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4); 608 // specialization time.
332 const Id type_array = TypeArray(t_float, Constant(t_uint, element_count)); 609 const u64 lmem_size = stage == ShaderType::Compute ? 0x400 : header.GetLocalMemorySize();
333 const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array); 610 if (lmem_size == 0) {
334 Name(type_pointer, "LocalMemory"); 611 return;
612 }
613 const auto element_count = static_cast<u32>(Common::AlignUp(lmem_size, 4) / 4);
614 const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
615 const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
616 Name(type_pointer, "LocalMemory");
617
618 local_memory =
619 OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
620 AddGlobalVariable(Name(local_memory, "local_memory"));
621 }
622
623 void DeclareSharedMemory() {
624 if (stage != ShaderType::Compute) {
625 return;
626 }
627 t_smem_uint = TypePointer(spv::StorageClass::Workgroup, t_uint);
335 628
336 local_memory = 629 const u32 smem_size = specialization.shared_memory_size;
337 OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array)); 630 if (smem_size == 0) {
338 AddGlobalVariable(Name(local_memory, "local_memory")); 631 // Avoid declaring an empty array.
632 return;
339 } 633 }
634 const auto element_count = static_cast<u32>(Common::AlignUp(smem_size, 4) / 4);
635 const Id type_array = TypeArray(t_uint, Constant(t_uint, element_count));
636 const Id type_pointer = TypePointer(spv::StorageClass::Workgroup, type_array);
637 Name(type_pointer, "SharedMemory");
638
639 shared_memory = OpVariable(type_pointer, spv::StorageClass::Workgroup);
640 AddGlobalVariable(Name(shared_memory, "shared_memory"));
340 } 641 }
341 642
342 void DeclareInternalFlags() { 643 void DeclareInternalFlags() {
343 constexpr std::array<const char*, INTERNAL_FLAGS_COUNT> names = {"zero", "sign", "carry", 644 constexpr std::array names = {"zero", "sign", "carry", "overflow"};
344 "overflow"};
345 for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) { 645 for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) {
346 const auto flag_code = static_cast<InternalFlag>(flag); 646 const auto flag_code = static_cast<InternalFlag>(flag);
347 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false); 647 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
@@ -349,17 +649,53 @@ private:
349 } 649 }
350 } 650 }
351 651
652 void DeclareInputVertexArray(u32 length) {
653 constexpr auto storage = spv::StorageClass::Input;
654 std::tie(in_indices, in_vertex) = DeclareVertexArray(storage, "in_indices", length);
655 }
656
657 void DeclareOutputVertexArray(u32 length) {
658 constexpr auto storage = spv::StorageClass::Output;
659 std::tie(out_indices, out_vertex) = DeclareVertexArray(storage, "out_indices", length);
660 }
661
662 std::tuple<VertexIndices, Id> DeclareVertexArray(spv::StorageClass storage_class,
663 std::string name, u32 length) {
664 const auto [struct_id, indices] = DeclareVertexStruct();
665 const Id vertex_array = TypeArray(struct_id, Constant(t_uint, length));
666 const Id vertex_ptr = TypePointer(storage_class, vertex_array);
667 const Id vertex = OpVariable(vertex_ptr, storage_class);
668 AddGlobalVariable(Name(vertex, std::move(name)));
669 interfaces.push_back(vertex);
670 return {indices, vertex};
671 }
672
673 void DeclareOutputVertex() {
674 Id out_vertex_struct;
675 std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
676 const Id out_vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
677 out_vertex = OpVariable(out_vertex_ptr, spv::StorageClass::Output);
678 interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
679 }
680
352 void DeclareInputAttributes() { 681 void DeclareInputAttributes() {
353 for (const auto index : ir.GetInputAttributes()) { 682 for (const auto index : ir.GetInputAttributes()) {
354 if (!IsGenericAttribute(index)) { 683 if (!IsGenericAttribute(index)) {
355 continue; 684 continue;
356 } 685 }
357 686
358 UNIMPLEMENTED_IF(stage == ShaderType::Geometry);
359
360 const u32 location = GetGenericAttributeLocation(index); 687 const u32 location = GetGenericAttributeLocation(index);
361 const Id id = OpVariable(t_in_float4, spv::StorageClass::Input); 688 const auto type_descriptor = GetAttributeType(location);
362 Name(AddGlobalVariable(id), fmt::format("in_attr{}", location)); 689 Id type;
690 if (IsInputAttributeArray()) {
691 type = GetTypeVectorDefinitionLut(type_descriptor.type).at(3);
692 type = TypeArray(type, Constant(t_uint, GetNumInputVertices()));
693 type = TypePointer(spv::StorageClass::Input, type);
694 } else {
695 type = type_descriptor.vector;
696 }
697 const Id id = OpVariable(type, spv::StorageClass::Input);
698 AddGlobalVariable(Name(id, fmt::format("in_attr{}", location)));
363 input_attributes.emplace(index, id); 699 input_attributes.emplace(index, id);
364 interfaces.push_back(id); 700 interfaces.push_back(id);
365 701
@@ -389,8 +725,21 @@ private:
389 if (!IsGenericAttribute(index)) { 725 if (!IsGenericAttribute(index)) {
390 continue; 726 continue;
391 } 727 }
392 const auto location = GetGenericAttributeLocation(index); 728 const u32 location = GetGenericAttributeLocation(index);
393 const Id id = OpVariable(t_out_float4, spv::StorageClass::Output); 729 Id type = t_float4;
730 Id varying_default = v_varying_default;
731 if (IsOutputAttributeArray()) {
732 const u32 num = GetNumOutputVertices();
733 type = TypeArray(type, Constant(t_uint, num));
734 if (device.GetDriverID() != vk::DriverIdKHR::eIntelProprietaryWindows) {
735 // Intel's proprietary driver fails to setup defaults for arrayed output
736 // attributes.
737 varying_default = ConstantComposite(type, std::vector(num, varying_default));
738 }
739 }
740 type = TypePointer(spv::StorageClass::Output, type);
741
742 const Id id = OpVariable(type, spv::StorageClass::Output, varying_default);
394 Name(AddGlobalVariable(id), fmt::format("out_attr{}", location)); 743 Name(AddGlobalVariable(id), fmt::format("out_attr{}", location));
395 output_attributes.emplace(index, id); 744 output_attributes.emplace(index, id);
396 interfaces.push_back(id); 745 interfaces.push_back(id);
@@ -399,10 +748,8 @@ private:
399 } 748 }
400 } 749 }
401 750
402 void DeclareConstantBuffers() { 751 u32 DeclareConstantBuffers(u32 binding) {
403 u32 binding = const_buffers_base_binding; 752 for (const auto& [index, size] : ir.GetConstantBuffers()) {
404 for (const auto& entry : ir.GetConstantBuffers()) {
405 const auto [index, size] = entry;
406 const Id type = device.IsKhrUniformBufferStandardLayoutSupported() ? t_cbuf_scalar_ubo 753 const Id type = device.IsKhrUniformBufferStandardLayoutSupported() ? t_cbuf_scalar_ubo
407 : t_cbuf_std140_ubo; 754 : t_cbuf_std140_ubo;
408 const Id id = OpVariable(type, spv::StorageClass::Uniform); 755 const Id id = OpVariable(type, spv::StorageClass::Uniform);
@@ -412,12 +759,11 @@ private:
412 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); 759 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
413 constant_buffers.emplace(index, id); 760 constant_buffers.emplace(index, id);
414 } 761 }
762 return binding;
415 } 763 }
416 764
417 void DeclareGlobalBuffers() { 765 u32 DeclareGlobalBuffers(u32 binding) {
418 u32 binding = global_buffers_base_binding; 766 for (const auto& [base, usage] : ir.GetGlobalMemory()) {
419 for (const auto& entry : ir.GetGlobalMemory()) {
420 const auto [base, usage] = entry;
421 const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer); 767 const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer);
422 AddGlobalVariable( 768 AddGlobalVariable(
423 Name(id, fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset))); 769 Name(id, fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset)));
@@ -426,89 +772,187 @@ private:
426 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); 772 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
427 global_buffers.emplace(base, id); 773 global_buffers.emplace(base, id);
428 } 774 }
775 return binding;
776 }
777
778 u32 DeclareTexelBuffers(u32 binding) {
779 for (const auto& sampler : ir.GetSamplers()) {
780 if (!sampler.IsBuffer()) {
781 continue;
782 }
783 ASSERT(!sampler.IsArray());
784 ASSERT(!sampler.IsShadow());
785
786 constexpr auto dim = spv::Dim::Buffer;
787 constexpr int depth = 0;
788 constexpr int arrayed = 0;
789 constexpr bool ms = false;
790 constexpr int sampled = 1;
791 constexpr auto format = spv::ImageFormat::Unknown;
792 const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
793 const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
794 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
795 AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
796 Decorate(id, spv::Decoration::Binding, binding++);
797 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
798
799 texel_buffers.emplace(sampler.GetIndex(), TexelBuffer{image_type, id});
800 }
801 return binding;
429 } 802 }
430 803
431 void DeclareSamplers() { 804 u32 DeclareSamplers(u32 binding) {
432 u32 binding = samplers_base_binding;
433 for (const auto& sampler : ir.GetSamplers()) { 805 for (const auto& sampler : ir.GetSamplers()) {
806 if (sampler.IsBuffer()) {
807 continue;
808 }
434 const auto dim = GetSamplerDim(sampler); 809 const auto dim = GetSamplerDim(sampler);
435 const int depth = sampler.IsShadow() ? 1 : 0; 810 const int depth = sampler.IsShadow() ? 1 : 0;
436 const int arrayed = sampler.IsArray() ? 1 : 0; 811 const int arrayed = sampler.IsArray() ? 1 : 0;
437 // TODO(Rodrigo): Sampled 1 indicates that the image will be used with a sampler. When 812 constexpr bool ms = false;
438 // SULD and SUST instructions are implemented, replace this value. 813 constexpr int sampled = 1;
439 const int sampled = 1; 814 constexpr auto format = spv::ImageFormat::Unknown;
440 const Id image_type = 815 const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
441 TypeImage(t_float, dim, depth, arrayed, false, sampled, spv::ImageFormat::Unknown);
442 const Id sampled_image_type = TypeSampledImage(image_type); 816 const Id sampled_image_type = TypeSampledImage(image_type);
443 const Id pointer_type = 817 const Id pointer_type =
444 TypePointer(spv::StorageClass::UniformConstant, sampled_image_type); 818 TypePointer(spv::StorageClass::UniformConstant, sampled_image_type);
445 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant); 819 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
446 AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex()))); 820 AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
821 Decorate(id, spv::Decoration::Binding, binding++);
822 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
447 823
448 sampler_images.insert( 824 sampled_images.emplace(sampler.GetIndex(),
449 {static_cast<u32>(sampler.GetIndex()), {image_type, sampled_image_type, id}}); 825 SampledImage{image_type, sampled_image_type, id});
826 }
827 return binding;
828 }
829
830 u32 DeclareImages(u32 binding) {
831 for (const auto& image : ir.GetImages()) {
832 const auto [dim, arrayed] = GetImageDim(image);
833 constexpr int depth = 0;
834 constexpr bool ms = false;
835 constexpr int sampled = 2; // This won't be accessed with a sampler
836 constexpr auto format = spv::ImageFormat::Unknown;
837 const Id image_type = TypeImage(t_uint, dim, depth, arrayed, ms, sampled, format, {});
838 const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
839 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
840 AddGlobalVariable(Name(id, fmt::format("image_{}", image.GetIndex())));
450 841
451 Decorate(id, spv::Decoration::Binding, binding++); 842 Decorate(id, spv::Decoration::Binding, binding++);
452 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); 843 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
844 if (image.IsRead() && !image.IsWritten()) {
845 Decorate(id, spv::Decoration::NonWritable);
846 } else if (image.IsWritten() && !image.IsRead()) {
847 Decorate(id, spv::Decoration::NonReadable);
848 }
849
850 images.emplace(static_cast<u32>(image.GetIndex()), StorageImage{image_type, id});
453 } 851 }
852 return binding;
853 }
854
855 bool IsInputAttributeArray() const {
856 return stage == ShaderType::TesselationControl || stage == ShaderType::TesselationEval ||
857 stage == ShaderType::Geometry;
454 } 858 }
455 859
456 void DeclareVertexRedeclarations() { 860 bool IsOutputAttributeArray() const {
457 vertex_index = DeclareBuiltIn(spv::BuiltIn::VertexIndex, spv::StorageClass::Input, 861 return stage == ShaderType::TesselationControl;
458 t_in_uint, "vertex_index"); 862 }
459 instance_index = DeclareBuiltIn(spv::BuiltIn::InstanceIndex, spv::StorageClass::Input,
460 t_in_uint, "instance_index");
461 863
462 bool is_clip_distances_declared = false; 864 u32 GetNumInputVertices() const {
463 for (const auto index : ir.GetOutputAttributes()) { 865 switch (stage) {
464 if (index == Attribute::Index::ClipDistances0123 || 866 case ShaderType::Geometry:
465 index == Attribute::Index::ClipDistances4567) { 867 return GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
466 is_clip_distances_declared = true; 868 case ShaderType::TesselationControl:
467 } 869 case ShaderType::TesselationEval:
870 return NumInputPatches;
871 default:
872 UNREACHABLE();
873 return 1;
468 } 874 }
875 }
469 876
470 std::vector<Id> members; 877 u32 GetNumOutputVertices() const {
471 members.push_back(t_float4); 878 switch (stage) {
472 if (ir.UsesPointSize()) { 879 case ShaderType::TesselationControl:
473 members.push_back(t_float); 880 return header.common2.threads_per_input_primitive;
474 } 881 default:
475 if (is_clip_distances_declared) { 882 UNREACHABLE();
476 members.push_back(TypeArray(t_float, Constant(t_uint, 8))); 883 return 1;
477 } 884 }
478 885 }
479 const Id gl_per_vertex_struct = Name(TypeStruct(members), "PerVertex"); 886
480 Decorate(gl_per_vertex_struct, spv::Decoration::Block); 887 std::tuple<Id, VertexIndices> DeclareVertexStruct() {
481 888 struct BuiltIn {
482 u32 declaration_index = 0; 889 Id type;
483 const auto MemberDecorateBuiltIn = [&](spv::BuiltIn builtin, std::string name, 890 spv::BuiltIn builtin;
484 bool condition) { 891 const char* name;
485 if (!condition) 892 };
486 return u32{}; 893 std::vector<BuiltIn> members;
487 MemberName(gl_per_vertex_struct, declaration_index, name); 894 members.reserve(4);
488 MemberDecorate(gl_per_vertex_struct, declaration_index, spv::Decoration::BuiltIn, 895
489 static_cast<u32>(builtin)); 896 const auto AddBuiltIn = [&](Id type, spv::BuiltIn builtin, const char* name) {
490 return declaration_index++; 897 const auto index = static_cast<u32>(members.size());
898 members.push_back(BuiltIn{type, builtin, name});
899 return index;
491 }; 900 };
492 901
493 position_index = MemberDecorateBuiltIn(spv::BuiltIn::Position, "position", true); 902 VertexIndices indices;
494 point_size_index = 903 indices.position = AddBuiltIn(t_float4, spv::BuiltIn::Position, "position");
495 MemberDecorateBuiltIn(spv::BuiltIn::PointSize, "point_size", ir.UsesPointSize()); 904
496 clip_distances_index = MemberDecorateBuiltIn(spv::BuiltIn::ClipDistance, "clip_distances", 905 if (ir.UsesViewportIndex()) {
497 is_clip_distances_declared); 906 if (stage != ShaderType::Vertex || device.IsExtShaderViewportIndexLayerSupported()) {
907 indices.viewport = AddBuiltIn(t_int, spv::BuiltIn::ViewportIndex, "viewport_index");
908 } else {
909 LOG_ERROR(Render_Vulkan,
910 "Shader requires ViewportIndex but it's not supported on this "
911 "stage with this device.");
912 }
913 }
914
915 if (ir.UsesPointSize() || specialization.point_size) {
916 indices.point_size = AddBuiltIn(t_float, spv::BuiltIn::PointSize, "point_size");
917 }
918
919 const auto& output_attributes = ir.GetOutputAttributes();
920 const bool declare_clip_distances =
921 std::any_of(output_attributes.begin(), output_attributes.end(), [](const auto& index) {
922 return index == Attribute::Index::ClipDistances0123 ||
923 index == Attribute::Index::ClipDistances4567;
924 });
925 if (declare_clip_distances) {
926 indices.clip_distances = AddBuiltIn(TypeArray(t_float, Constant(t_uint, 8)),
927 spv::BuiltIn::ClipDistance, "clip_distances");
928 }
929
930 std::vector<Id> member_types;
931 member_types.reserve(members.size());
932 for (std::size_t i = 0; i < members.size(); ++i) {
933 member_types.push_back(members[i].type);
934 }
935 const Id per_vertex_struct = Name(TypeStruct(member_types), "PerVertex");
936 Decorate(per_vertex_struct, spv::Decoration::Block);
937
938 for (std::size_t index = 0; index < members.size(); ++index) {
939 const auto& member = members[index];
940 MemberName(per_vertex_struct, static_cast<u32>(index), member.name);
941 MemberDecorate(per_vertex_struct, static_cast<u32>(index), spv::Decoration::BuiltIn,
942 static_cast<u32>(member.builtin));
943 }
498 944
499 const Id type_pointer = TypePointer(spv::StorageClass::Output, gl_per_vertex_struct); 945 return {per_vertex_struct, indices};
500 per_vertex = OpVariable(type_pointer, spv::StorageClass::Output);
501 AddGlobalVariable(Name(per_vertex, "per_vertex"));
502 interfaces.push_back(per_vertex);
503 } 946 }
504 947
505 void VisitBasicBlock(const NodeBlock& bb) { 948 void VisitBasicBlock(const NodeBlock& bb) {
506 for (const auto& node : bb) { 949 for (const auto& node : bb) {
507 static_cast<void>(Visit(node)); 950 [[maybe_unused]] const Type type = Visit(node).type;
951 ASSERT(type == Type::Void);
508 } 952 }
509 } 953 }
510 954
511 Id Visit(const Node& node) { 955 Expression Visit(const Node& node) {
512 if (const auto operation = std::get_if<OperationNode>(&*node)) { 956 if (const auto operation = std::get_if<OperationNode>(&*node)) {
513 const auto operation_index = static_cast<std::size_t>(operation->GetCode()); 957 const auto operation_index = static_cast<std::size_t>(operation->GetCode());
514 const auto decompiler = operation_decompilers[operation_index]; 958 const auto decompiler = operation_decompilers[operation_index];
@@ -516,18 +960,21 @@ private:
516 UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index); 960 UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
517 } 961 }
518 return (this->*decompiler)(*operation); 962 return (this->*decompiler)(*operation);
963 }
519 964
520 } else if (const auto gpr = std::get_if<GprNode>(&*node)) { 965 if (const auto gpr = std::get_if<GprNode>(&*node)) {
521 const u32 index = gpr->GetIndex(); 966 const u32 index = gpr->GetIndex();
522 if (index == Register::ZeroIndex) { 967 if (index == Register::ZeroIndex) {
523 return Constant(t_float, 0.0f); 968 return {v_float_zero, Type::Float};
524 } 969 }
525 return Emit(OpLoad(t_float, registers.at(index))); 970 return {OpLoad(t_float, registers.at(index)), Type::Float};
971 }
526 972
527 } else if (const auto immediate = std::get_if<ImmediateNode>(&*node)) { 973 if (const auto immediate = std::get_if<ImmediateNode>(&*node)) {
528 return BitcastTo<Type::Float>(Constant(t_uint, immediate->GetValue())); 974 return {Constant(t_uint, immediate->GetValue()), Type::Uint};
975 }
529 976
530 } else if (const auto predicate = std::get_if<PredicateNode>(&*node)) { 977 if (const auto predicate = std::get_if<PredicateNode>(&*node)) {
531 const auto value = [&]() -> Id { 978 const auto value = [&]() -> Id {
532 switch (const auto index = predicate->GetIndex(); index) { 979 switch (const auto index = predicate->GetIndex(); index) {
533 case Tegra::Shader::Pred::UnusedIndex: 980 case Tegra::Shader::Pred::UnusedIndex:
@@ -535,74 +982,108 @@ private:
535 case Tegra::Shader::Pred::NeverExecute: 982 case Tegra::Shader::Pred::NeverExecute:
536 return v_false; 983 return v_false;
537 default: 984 default:
538 return Emit(OpLoad(t_bool, predicates.at(index))); 985 return OpLoad(t_bool, predicates.at(index));
539 } 986 }
540 }(); 987 }();
541 if (predicate->IsNegated()) { 988 if (predicate->IsNegated()) {
542 return Emit(OpLogicalNot(t_bool, value)); 989 return {OpLogicalNot(t_bool, value), Type::Bool};
543 } 990 }
544 return value; 991 return {value, Type::Bool};
992 }
545 993
546 } else if (const auto abuf = std::get_if<AbufNode>(&*node)) { 994 if (const auto abuf = std::get_if<AbufNode>(&*node)) {
547 const auto attribute = abuf->GetIndex(); 995 const auto attribute = abuf->GetIndex();
548 const auto element = abuf->GetElement(); 996 const u32 element = abuf->GetElement();
997 const auto& buffer = abuf->GetBuffer();
998
999 const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
1000 std::vector<Id> members;
1001 members.reserve(std::size(indices) + 1);
1002
1003 if (buffer && IsInputAttributeArray()) {
1004 members.push_back(AsUint(Visit(buffer)));
1005 }
1006 for (const u32 index : indices) {
1007 members.push_back(Constant(t_uint, index));
1008 }
1009 return OpAccessChain(pointer_type, composite, members);
1010 };
549 1011
550 switch (attribute) { 1012 switch (attribute) {
551 case Attribute::Index::Position: 1013 case Attribute::Index::Position: {
552 if (stage != ShaderType::Fragment) { 1014 if (stage == ShaderType::Fragment) {
553 UNIMPLEMENTED();
554 break;
555 } else {
556 if (element == 3) { 1015 if (element == 3) {
557 return Constant(t_float, 1.0f); 1016 return {Constant(t_float, 1.0f), Type::Float};
558 } 1017 }
559 return Emit(OpLoad(t_float, AccessElement(t_in_float, frag_coord, element))); 1018 return {OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)),
1019 Type::Float};
1020 }
1021 const std::vector elements = {in_indices.position.value(), element};
1022 return {OpLoad(t_float, ArrayPass(t_in_float, in_vertex, elements)), Type::Float};
1023 }
1024 case Attribute::Index::PointCoord: {
1025 switch (element) {
1026 case 0:
1027 case 1:
1028 return {OpCompositeExtract(t_float, OpLoad(t_float2, point_coord), element),
1029 Type::Float};
560 } 1030 }
1031 UNIMPLEMENTED_MSG("Unimplemented point coord element={}", element);
1032 return {v_float_zero, Type::Float};
1033 }
561 case Attribute::Index::TessCoordInstanceIDVertexID: 1034 case Attribute::Index::TessCoordInstanceIDVertexID:
562 // TODO(Subv): Find out what the values are for the first two elements when inside a 1035 // TODO(Subv): Find out what the values are for the first two elements when inside a
563 // vertex shader, and what's the value of the fourth element when inside a Tess Eval 1036 // vertex shader, and what's the value of the fourth element when inside a Tess Eval
564 // shader. 1037 // shader.
565 ASSERT(stage == ShaderType::Vertex);
566 switch (element) { 1038 switch (element) {
1039 case 0:
1040 case 1:
1041 return {OpLoad(t_float, AccessElement(t_in_float, tess_coord, element)),
1042 Type::Float};
567 case 2: 1043 case 2:
568 return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index))); 1044 return {OpLoad(t_uint, instance_index), Type::Uint};
569 case 3: 1045 case 3:
570 return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, vertex_index))); 1046 return {OpLoad(t_uint, vertex_index), Type::Uint};
571 } 1047 }
572 UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element); 1048 UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element);
573 return Constant(t_float, 0); 1049 return {Constant(t_uint, 0U), Type::Uint};
574 case Attribute::Index::FrontFacing: 1050 case Attribute::Index::FrontFacing:
575 // TODO(Subv): Find out what the values are for the other elements. 1051 // TODO(Subv): Find out what the values are for the other elements.
576 ASSERT(stage == ShaderType::Fragment); 1052 ASSERT(stage == ShaderType::Fragment);
577 if (element == 3) { 1053 if (element == 3) {
578 const Id is_front_facing = Emit(OpLoad(t_bool, front_facing)); 1054 const Id is_front_facing = OpLoad(t_bool, front_facing);
579 const Id true_value = 1055 const Id true_value = Constant(t_int, static_cast<s32>(-1));
580 BitcastTo<Type::Float>(Constant(t_int, static_cast<s32>(-1))); 1056 const Id false_value = Constant(t_int, 0);
581 const Id false_value = BitcastTo<Type::Float>(Constant(t_int, 0)); 1057 return {OpSelect(t_int, is_front_facing, true_value, false_value), Type::Int};
582 return Emit(OpSelect(t_float, is_front_facing, true_value, false_value));
583 } 1058 }
584 UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element); 1059 UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element);
585 return Constant(t_float, 0.0f); 1060 return {v_float_zero, Type::Float};
586 default: 1061 default:
587 if (IsGenericAttribute(attribute)) { 1062 if (IsGenericAttribute(attribute)) {
588 const Id pointer = 1063 const u32 location = GetGenericAttributeLocation(attribute);
589 AccessElement(t_in_float, input_attributes.at(attribute), element); 1064 const auto type_descriptor = GetAttributeType(location);
590 return Emit(OpLoad(t_float, pointer)); 1065 const Type type = type_descriptor.type;
1066 const Id attribute_id = input_attributes.at(attribute);
1067 const std::vector elements = {element};
1068 const Id pointer = ArrayPass(type_descriptor.scalar, attribute_id, elements);
1069 return {OpLoad(GetTypeDefinition(type), pointer), type};
591 } 1070 }
592 break; 1071 break;
593 } 1072 }
594 UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute)); 1073 UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute));
1074 return {v_float_zero, Type::Float};
1075 }
595 1076
596 } else if (const auto cbuf = std::get_if<CbufNode>(&*node)) { 1077 if (const auto cbuf = std::get_if<CbufNode>(&*node)) {
597 const Node& offset = cbuf->GetOffset(); 1078 const Node& offset = cbuf->GetOffset();
598 const Id buffer_id = constant_buffers.at(cbuf->GetIndex()); 1079 const Id buffer_id = constant_buffers.at(cbuf->GetIndex());
599 1080
600 Id pointer{}; 1081 Id pointer{};
601 if (device.IsKhrUniformBufferStandardLayoutSupported()) { 1082 if (device.IsKhrUniformBufferStandardLayoutSupported()) {
602 const Id buffer_offset = Emit(OpShiftRightLogical( 1083 const Id buffer_offset =
603 t_uint, BitcastTo<Type::Uint>(Visit(offset)), Constant(t_uint, 2u))); 1084 OpShiftRightLogical(t_uint, AsUint(Visit(offset)), Constant(t_uint, 2U));
604 pointer = Emit( 1085 pointer =
605 OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0u), buffer_offset)); 1086 OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0U), buffer_offset);
606 } else { 1087 } else {
607 Id buffer_index{}; 1088 Id buffer_index{};
608 Id buffer_element{}; 1089 Id buffer_element{};
@@ -614,53 +1095,76 @@ private:
614 buffer_element = Constant(t_uint, (offset_imm / 4) % 4); 1095 buffer_element = Constant(t_uint, (offset_imm / 4) % 4);
615 } else if (std::holds_alternative<OperationNode>(*offset)) { 1096 } else if (std::holds_alternative<OperationNode>(*offset)) {
616 // Indirect access 1097 // Indirect access
617 const Id offset_id = BitcastTo<Type::Uint>(Visit(offset)); 1098 const Id offset_id = AsUint(Visit(offset));
618 const Id unsafe_offset = Emit(OpUDiv(t_uint, offset_id, Constant(t_uint, 4))); 1099 const Id unsafe_offset = OpUDiv(t_uint, offset_id, Constant(t_uint, 4));
619 const Id final_offset = Emit(OpUMod( 1100 const Id final_offset =
620 t_uint, unsafe_offset, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS - 1))); 1101 OpUMod(t_uint, unsafe_offset, Constant(t_uint, MaxConstBufferElements - 1));
621 buffer_index = Emit(OpUDiv(t_uint, final_offset, Constant(t_uint, 4))); 1102 buffer_index = OpUDiv(t_uint, final_offset, Constant(t_uint, 4));
622 buffer_element = Emit(OpUMod(t_uint, final_offset, Constant(t_uint, 4))); 1103 buffer_element = OpUMod(t_uint, final_offset, Constant(t_uint, 4));
623 } else { 1104 } else {
624 UNREACHABLE_MSG("Unmanaged offset node type"); 1105 UNREACHABLE_MSG("Unmanaged offset node type");
625 } 1106 }
626 pointer = Emit(OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), 1107 pointer = OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), buffer_index,
627 buffer_index, buffer_element)); 1108 buffer_element);
628 } 1109 }
629 return Emit(OpLoad(t_float, pointer)); 1110 return {OpLoad(t_float, pointer), Type::Float};
1111 }
630 1112
631 } else if (const auto gmem = std::get_if<GmemNode>(&*node)) { 1113 if (const auto gmem = std::get_if<GmemNode>(&*node)) {
632 const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor()); 1114 const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
633 const Id real = BitcastTo<Type::Uint>(Visit(gmem->GetRealAddress())); 1115 const Id real = AsUint(Visit(gmem->GetRealAddress()));
634 const Id base = BitcastTo<Type::Uint>(Visit(gmem->GetBaseAddress())); 1116 const Id base = AsUint(Visit(gmem->GetBaseAddress()));
1117
1118 Id offset = OpISub(t_uint, real, base);
1119 offset = OpUDiv(t_uint, offset, Constant(t_uint, 4U));
1120 return {OpLoad(t_float,
1121 OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0U), offset)),
1122 Type::Float};
1123 }
635 1124
636 Id offset = Emit(OpISub(t_uint, real, base)); 1125 if (const auto lmem = std::get_if<LmemNode>(&*node)) {
637 offset = Emit(OpUDiv(t_uint, offset, Constant(t_uint, 4u))); 1126 Id address = AsUint(Visit(lmem->GetAddress()));
638 return Emit(OpLoad(t_float, Emit(OpAccessChain(t_gmem_float, gmem_buffer, 1127 address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
639 Constant(t_uint, 0u), offset)))); 1128 const Id pointer = OpAccessChain(t_prv_float, local_memory, address);
1129 return {OpLoad(t_float, pointer), Type::Float};
1130 }
1131
1132 if (const auto smem = std::get_if<SmemNode>(&*node)) {
1133 Id address = AsUint(Visit(smem->GetAddress()));
1134 address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
1135 const Id pointer = OpAccessChain(t_smem_uint, shared_memory, address);
1136 return {OpLoad(t_uint, pointer), Type::Uint};
1137 }
640 1138
641 } else if (const auto conditional = std::get_if<ConditionalNode>(&*node)) { 1139 if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) {
1140 const Id flag = internal_flags.at(static_cast<std::size_t>(internal_flag->GetFlag()));
1141 return {OpLoad(t_bool, flag), Type::Bool};
1142 }
1143
1144 if (const auto conditional = std::get_if<ConditionalNode>(&*node)) {
642 // It's invalid to call conditional on nested nodes, use an operation instead 1145 // It's invalid to call conditional on nested nodes, use an operation instead
643 const Id true_label = OpLabel(); 1146 const Id true_label = OpLabel();
644 const Id skip_label = OpLabel(); 1147 const Id skip_label = OpLabel();
645 const Id condition = Visit(conditional->GetCondition()); 1148 const Id condition = AsBool(Visit(conditional->GetCondition()));
646 Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone)); 1149 OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone);
647 Emit(OpBranchConditional(condition, true_label, skip_label)); 1150 OpBranchConditional(condition, true_label, skip_label);
648 Emit(true_label); 1151 AddLabel(true_label);
649 1152
650 ++conditional_nest_count; 1153 conditional_branch_set = true;
1154 inside_branch = false;
651 VisitBasicBlock(conditional->GetCode()); 1155 VisitBasicBlock(conditional->GetCode());
652 --conditional_nest_count; 1156 conditional_branch_set = false;
653 1157 if (!inside_branch) {
654 if (inside_branch == 0) { 1158 OpBranch(skip_label);
655 Emit(OpBranch(skip_label));
656 } else { 1159 } else {
657 inside_branch--; 1160 inside_branch = false;
658 } 1161 }
659 Emit(skip_label); 1162 AddLabel(skip_label);
660 return {}; 1163 return {};
1164 }
661 1165
662 } else if (const auto comment = std::get_if<CommentNode>(&*node)) { 1166 if (const auto comment = std::get_if<CommentNode>(&*node)) {
663 Name(Emit(OpUndef(t_void)), comment->GetText()); 1167 Name(OpUndef(t_void), comment->GetText());
664 return {}; 1168 return {};
665 } 1169 }
666 1170
@@ -669,94 +1173,126 @@ private:
669 } 1173 }
670 1174
671 template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type> 1175 template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
672 Id Unary(Operation operation) { 1176 Expression Unary(Operation operation) {
673 const Id type_def = GetTypeDefinition(result_type); 1177 const Id type_def = GetTypeDefinition(result_type);
674 const Id op_a = VisitOperand<type_a>(operation, 0); 1178 const Id op_a = As(Visit(operation[0]), type_a);
675 1179
676 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a))); 1180 const Id value = (this->*func)(type_def, op_a);
677 if (IsPrecise(operation)) { 1181 if (IsPrecise(operation)) {
678 Decorate(value, spv::Decoration::NoContraction); 1182 Decorate(value, spv::Decoration::NoContraction);
679 } 1183 }
680 return value; 1184 return {value, result_type};
681 } 1185 }
682 1186
683 template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type, 1187 template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
684 Type type_b = type_a> 1188 Type type_b = type_a>
685 Id Binary(Operation operation) { 1189 Expression Binary(Operation operation) {
686 const Id type_def = GetTypeDefinition(result_type); 1190 const Id type_def = GetTypeDefinition(result_type);
687 const Id op_a = VisitOperand<type_a>(operation, 0); 1191 const Id op_a = As(Visit(operation[0]), type_a);
688 const Id op_b = VisitOperand<type_b>(operation, 1); 1192 const Id op_b = As(Visit(operation[1]), type_b);
689 1193
690 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b))); 1194 const Id value = (this->*func)(type_def, op_a, op_b);
691 if (IsPrecise(operation)) { 1195 if (IsPrecise(operation)) {
692 Decorate(value, spv::Decoration::NoContraction); 1196 Decorate(value, spv::Decoration::NoContraction);
693 } 1197 }
694 return value; 1198 return {value, result_type};
695 } 1199 }
696 1200
697 template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type, 1201 template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
698 Type type_b = type_a, Type type_c = type_b> 1202 Type type_b = type_a, Type type_c = type_b>
699 Id Ternary(Operation operation) { 1203 Expression Ternary(Operation operation) {
700 const Id type_def = GetTypeDefinition(result_type); 1204 const Id type_def = GetTypeDefinition(result_type);
701 const Id op_a = VisitOperand<type_a>(operation, 0); 1205 const Id op_a = As(Visit(operation[0]), type_a);
702 const Id op_b = VisitOperand<type_b>(operation, 1); 1206 const Id op_b = As(Visit(operation[1]), type_b);
703 const Id op_c = VisitOperand<type_c>(operation, 2); 1207 const Id op_c = As(Visit(operation[2]), type_c);
704 1208
705 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c))); 1209 const Id value = (this->*func)(type_def, op_a, op_b, op_c);
706 if (IsPrecise(operation)) { 1210 if (IsPrecise(operation)) {
707 Decorate(value, spv::Decoration::NoContraction); 1211 Decorate(value, spv::Decoration::NoContraction);
708 } 1212 }
709 return value; 1213 return {value, result_type};
710 } 1214 }
711 1215
712 template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type, 1216 template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
713 Type type_b = type_a, Type type_c = type_b, Type type_d = type_c> 1217 Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
714 Id Quaternary(Operation operation) { 1218 Expression Quaternary(Operation operation) {
715 const Id type_def = GetTypeDefinition(result_type); 1219 const Id type_def = GetTypeDefinition(result_type);
716 const Id op_a = VisitOperand<type_a>(operation, 0); 1220 const Id op_a = As(Visit(operation[0]), type_a);
717 const Id op_b = VisitOperand<type_b>(operation, 1); 1221 const Id op_b = As(Visit(operation[1]), type_b);
718 const Id op_c = VisitOperand<type_c>(operation, 2); 1222 const Id op_c = As(Visit(operation[2]), type_c);
719 const Id op_d = VisitOperand<type_d>(operation, 3); 1223 const Id op_d = As(Visit(operation[3]), type_d);
720 1224
721 const Id value = 1225 const Id value = (this->*func)(type_def, op_a, op_b, op_c, op_d);
722 BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d)));
723 if (IsPrecise(operation)) { 1226 if (IsPrecise(operation)) {
724 Decorate(value, spv::Decoration::NoContraction); 1227 Decorate(value, spv::Decoration::NoContraction);
725 } 1228 }
726 return value; 1229 return {value, result_type};
727 } 1230 }
728 1231
729 Id Assign(Operation operation) { 1232 Expression Assign(Operation operation) {
730 const Node& dest = operation[0]; 1233 const Node& dest = operation[0];
731 const Node& src = operation[1]; 1234 const Node& src = operation[1];
732 1235
733 Id target{}; 1236 Expression target{};
734 if (const auto gpr = std::get_if<GprNode>(&*dest)) { 1237 if (const auto gpr = std::get_if<GprNode>(&*dest)) {
735 if (gpr->GetIndex() == Register::ZeroIndex) { 1238 if (gpr->GetIndex() == Register::ZeroIndex) {
736 // Writing to Register::ZeroIndex is a no op 1239 // Writing to Register::ZeroIndex is a no op
737 return {}; 1240 return {};
738 } 1241 }
739 target = registers.at(gpr->GetIndex()); 1242 target = {registers.at(gpr->GetIndex()), Type::Float};
740 1243
741 } else if (const auto abuf = std::get_if<AbufNode>(&*dest)) { 1244 } else if (const auto abuf = std::get_if<AbufNode>(&*dest)) {
742 target = [&]() -> Id { 1245 const auto& buffer = abuf->GetBuffer();
1246 const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
1247 std::vector<Id> members;
1248 members.reserve(std::size(indices) + 1);
1249
1250 if (buffer && IsOutputAttributeArray()) {
1251 members.push_back(AsUint(Visit(buffer)));
1252 }
1253 for (const u32 index : indices) {
1254 members.push_back(Constant(t_uint, index));
1255 }
1256 return OpAccessChain(pointer_type, composite, members);
1257 };
1258
1259 target = [&]() -> Expression {
1260 const u32 element = abuf->GetElement();
743 switch (const auto attribute = abuf->GetIndex(); attribute) { 1261 switch (const auto attribute = abuf->GetIndex(); attribute) {
744 case Attribute::Index::Position: 1262 case Attribute::Index::Position: {
745 return AccessElement(t_out_float, per_vertex, position_index, 1263 const u32 index = out_indices.position.value();
746 abuf->GetElement()); 1264 return {ArrayPass(t_out_float, out_vertex, {index, element}), Type::Float};
1265 }
747 case Attribute::Index::LayerViewportPointSize: 1266 case Attribute::Index::LayerViewportPointSize:
748 UNIMPLEMENTED_IF(abuf->GetElement() != 3); 1267 switch (element) {
749 return AccessElement(t_out_float, per_vertex, point_size_index); 1268 case 2: {
750 case Attribute::Index::ClipDistances0123: 1269 if (!out_indices.viewport) {
751 return AccessElement(t_out_float, per_vertex, clip_distances_index, 1270 return {};
752 abuf->GetElement()); 1271 }
753 case Attribute::Index::ClipDistances4567: 1272 const u32 index = out_indices.viewport.value();
754 return AccessElement(t_out_float, per_vertex, clip_distances_index, 1273 return {AccessElement(t_out_int, out_vertex, index), Type::Int};
755 abuf->GetElement() + 4); 1274 }
1275 case 3: {
1276 const auto index = out_indices.point_size.value();
1277 return {AccessElement(t_out_float, out_vertex, index), Type::Float};
1278 }
1279 default:
1280 UNIMPLEMENTED_MSG("LayerViewportPoint element={}", abuf->GetElement());
1281 return {};
1282 }
1283 case Attribute::Index::ClipDistances0123: {
1284 const u32 index = out_indices.clip_distances.value();
1285 return {AccessElement(t_out_float, out_vertex, index, element), Type::Float};
1286 }
1287 case Attribute::Index::ClipDistances4567: {
1288 const u32 index = out_indices.clip_distances.value();
1289 return {AccessElement(t_out_float, out_vertex, index, element + 4),
1290 Type::Float};
1291 }
756 default: 1292 default:
757 if (IsGenericAttribute(attribute)) { 1293 if (IsGenericAttribute(attribute)) {
758 return AccessElement(t_out_float, output_attributes.at(attribute), 1294 const Id composite = output_attributes.at(attribute);
759 abuf->GetElement()); 1295 return {ArrayPass(t_out_float, composite, {element}), Type::Float};
760 } 1296 }
761 UNIMPLEMENTED_MSG("Unhandled output attribute: {}", 1297 UNIMPLEMENTED_MSG("Unhandled output attribute: {}",
762 static_cast<u32>(attribute)); 1298 static_cast<u32>(attribute));
@@ -764,72 +1300,154 @@ private:
764 } 1300 }
765 }(); 1301 }();
766 1302
1303 } else if (const auto patch = std::get_if<PatchNode>(&*dest)) {
1304 target = [&]() -> Expression {
1305 const u32 offset = patch->GetOffset();
1306 switch (offset) {
1307 case 0:
1308 case 1:
1309 case 2:
1310 case 3:
1311 return {AccessElement(t_out_float, tess_level_outer, offset % 4), Type::Float};
1312 case 4:
1313 case 5:
1314 return {AccessElement(t_out_float, tess_level_inner, offset % 4), Type::Float};
1315 }
1316 UNIMPLEMENTED_MSG("Unhandled patch output offset: {}", offset);
1317 return {};
1318 }();
1319
767 } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) { 1320 } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) {
768 Id address = BitcastTo<Type::Uint>(Visit(lmem->GetAddress())); 1321 Id address = AsUint(Visit(lmem->GetAddress()));
769 address = Emit(OpUDiv(t_uint, address, Constant(t_uint, 4))); 1322 address = OpUDiv(t_uint, address, Constant(t_uint, 4));
770 target = Emit(OpAccessChain(t_prv_float, local_memory, {address})); 1323 target = {OpAccessChain(t_prv_float, local_memory, address), Type::Float};
1324
1325 } else if (const auto smem = std::get_if<SmemNode>(&*dest)) {
1326 ASSERT(stage == ShaderType::Compute);
1327 Id address = AsUint(Visit(smem->GetAddress()));
1328 address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
1329 target = {OpAccessChain(t_smem_uint, shared_memory, address), Type::Uint};
1330
1331 } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) {
1332 const Id real = AsUint(Visit(gmem->GetRealAddress()));
1333 const Id base = AsUint(Visit(gmem->GetBaseAddress()));
1334 const Id diff = OpISub(t_uint, real, base);
1335 const Id offset = OpShiftRightLogical(t_uint, diff, Constant(t_uint, 2));
1336
1337 const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
1338 target = {OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0), offset),
1339 Type::Float};
1340
1341 } else {
1342 UNIMPLEMENTED();
771 } 1343 }
772 1344
773 Emit(OpStore(target, Visit(src))); 1345 OpStore(target.id, As(Visit(src), target.type));
774 return {}; 1346 return {};
775 } 1347 }
776 1348
777 Id FCastHalf0(Operation operation) { 1349 template <u32 offset>
778 UNIMPLEMENTED(); 1350 Expression FCastHalf(Operation operation) {
779 return {}; 1351 const Id value = AsHalfFloat(Visit(operation[0]));
1352 return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, offset)),
1353 Type::Float};
780 } 1354 }
781 1355
782 Id FCastHalf1(Operation operation) { 1356 Expression FSwizzleAdd(Operation operation) {
783 UNIMPLEMENTED(); 1357 const Id minus = Constant(t_float, -1.0f);
784 return {}; 1358 const Id plus = v_float_one;
785 } 1359 const Id zero = v_float_zero;
1360 const Id lut_a = ConstantComposite(t_float4, minus, plus, minus, zero);
1361 const Id lut_b = ConstantComposite(t_float4, minus, minus, plus, minus);
786 1362
787 Id FSwizzleAdd(Operation operation) { 1363 Id mask = OpLoad(t_uint, thread_id);
788 UNIMPLEMENTED(); 1364 mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
789 return {}; 1365 mask = OpShiftLeftLogical(t_uint, mask, Constant(t_uint, 1));
790 } 1366 mask = OpShiftRightLogical(t_uint, AsUint(Visit(operation[2])), mask);
1367 mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
791 1368
792 Id HNegate(Operation operation) { 1369 const Id modifier_a = OpVectorExtractDynamic(t_float, lut_a, mask);
793 UNIMPLEMENTED(); 1370 const Id modifier_b = OpVectorExtractDynamic(t_float, lut_b, mask);
794 return {}; 1371
1372 const Id op_a = OpFMul(t_float, AsFloat(Visit(operation[0])), modifier_a);
1373 const Id op_b = OpFMul(t_float, AsFloat(Visit(operation[1])), modifier_b);
1374 return {OpFAdd(t_float, op_a, op_b), Type::Float};
795 } 1375 }
796 1376
797 Id HClamp(Operation operation) { 1377 Expression HNegate(Operation operation) {
798 UNIMPLEMENTED(); 1378 const bool is_f16 = device.IsFloat16Supported();
799 return {}; 1379 const Id minus_one = Constant(t_scalar_half, is_f16 ? 0xbc00 : 0xbf800000);
1380 const Id one = Constant(t_scalar_half, is_f16 ? 0x3c00 : 0x3f800000);
1381 const auto GetNegate = [&](std::size_t index) {
1382 return OpSelect(t_scalar_half, AsBool(Visit(operation[index])), minus_one, one);
1383 };
1384 const Id negation = OpCompositeConstruct(t_half, GetNegate(1), GetNegate(2));
1385 return {OpFMul(t_half, AsHalfFloat(Visit(operation[0])), negation), Type::HalfFloat};
800 } 1386 }
801 1387
802 Id HCastFloat(Operation operation) { 1388 Expression HClamp(Operation operation) {
803 UNIMPLEMENTED(); 1389 const auto Pack = [&](std::size_t index) {
804 return {}; 1390 const Id scalar = GetHalfScalarFromFloat(AsFloat(Visit(operation[index])));
1391 return OpCompositeConstruct(t_half, scalar, scalar);
1392 };
1393 const Id value = AsHalfFloat(Visit(operation[0]));
1394 const Id min = Pack(1);
1395 const Id max = Pack(2);
1396
1397 const Id clamped = OpFClamp(t_half, value, min, max);
1398 if (IsPrecise(operation)) {
1399 Decorate(clamped, spv::Decoration::NoContraction);
1400 }
1401 return {clamped, Type::HalfFloat};
805 } 1402 }
806 1403
807 Id HUnpack(Operation operation) { 1404 Expression HCastFloat(Operation operation) {
808 UNIMPLEMENTED(); 1405 const Id value = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
809 return {}; 1406 return {OpCompositeConstruct(t_half, value, Constant(t_scalar_half, 0)), Type::HalfFloat};
810 } 1407 }
811 1408
812 Id HMergeF32(Operation operation) { 1409 Expression HUnpack(Operation operation) {
813 UNIMPLEMENTED(); 1410 Expression operand = Visit(operation[0]);
814 return {}; 1411 const auto type = std::get<Tegra::Shader::HalfType>(operation.GetMeta());
1412 if (type == Tegra::Shader::HalfType::H0_H1) {
1413 return operand;
1414 }
1415 const auto value = [&] {
1416 switch (std::get<Tegra::Shader::HalfType>(operation.GetMeta())) {
1417 case Tegra::Shader::HalfType::F32:
1418 return GetHalfScalarFromFloat(AsFloat(operand));
1419 case Tegra::Shader::HalfType::H0_H0:
1420 return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 0);
1421 case Tegra::Shader::HalfType::H1_H1:
1422 return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 1);
1423 default:
1424 UNREACHABLE();
1425 return ConstantNull(t_half);
1426 }
1427 }();
1428 return {OpCompositeConstruct(t_half, value, value), Type::HalfFloat};
815 } 1429 }
816 1430
817 Id HMergeH0(Operation operation) { 1431 Expression HMergeF32(Operation operation) {
818 UNIMPLEMENTED(); 1432 const Id value = AsHalfFloat(Visit(operation[0]));
819 return {}; 1433 return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, 0)), Type::Float};
820 } 1434 }
821 1435
822 Id HMergeH1(Operation operation) { 1436 template <u32 offset>
823 UNIMPLEMENTED(); 1437 Expression HMergeHN(Operation operation) {
824 return {}; 1438 const Id target = AsHalfFloat(Visit(operation[0]));
1439 const Id source = AsHalfFloat(Visit(operation[1]));
1440 const Id object = OpCompositeExtract(t_scalar_half, source, offset);
1441 return {OpCompositeInsert(t_half, object, target, offset), Type::HalfFloat};
825 } 1442 }
826 1443
827 Id HPack2(Operation operation) { 1444 Expression HPack2(Operation operation) {
828 UNIMPLEMENTED(); 1445 const Id low = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
829 return {}; 1446 const Id high = GetHalfScalarFromFloat(AsFloat(Visit(operation[1])));
1447 return {OpCompositeConstruct(t_half, low, high), Type::HalfFloat};
830 } 1448 }
831 1449
832 Id LogicalAssign(Operation operation) { 1450 Expression LogicalAssign(Operation operation) {
833 const Node& dest = operation[0]; 1451 const Node& dest = operation[0];
834 const Node& src = operation[1]; 1452 const Node& src = operation[1];
835 1453
@@ -850,106 +1468,190 @@ private:
850 target = internal_flags.at(static_cast<u32>(flag->GetFlag())); 1468 target = internal_flags.at(static_cast<u32>(flag->GetFlag()));
851 } 1469 }
852 1470
853 Emit(OpStore(target, Visit(src))); 1471 OpStore(target, AsBool(Visit(src)));
854 return {}; 1472 return {};
855 } 1473 }
856 1474
857 Id LogicalPick2(Operation operation) { 1475 Id GetTextureSampler(Operation operation) {
858 UNIMPLEMENTED(); 1476 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
859 return {}; 1477 ASSERT(!meta.sampler.IsBuffer());
1478
1479 const auto& entry = sampled_images.at(meta.sampler.GetIndex());
1480 return OpLoad(entry.sampled_image_type, entry.sampler);
860 } 1481 }
861 1482
862 Id LogicalAnd2(Operation operation) { 1483 Id GetTextureImage(Operation operation) {
863 UNIMPLEMENTED(); 1484 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
864 return {}; 1485 const u32 index = meta.sampler.GetIndex();
1486 if (meta.sampler.IsBuffer()) {
1487 const auto& entry = texel_buffers.at(index);
1488 return OpLoad(entry.image_type, entry.image);
1489 } else {
1490 const auto& entry = sampled_images.at(index);
1491 return OpImage(entry.image_type, GetTextureSampler(operation));
1492 }
865 } 1493 }
866 1494
867 Id GetTextureSampler(Operation operation) { 1495 Id GetImage(Operation operation) {
868 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1496 const auto& meta = std::get<MetaImage>(operation.GetMeta());
869 const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex())); 1497 const auto entry = images.at(meta.image.GetIndex());
870 return Emit(OpLoad(entry.sampled_image_type, entry.sampler)); 1498 return OpLoad(entry.image_type, entry.image);
871 } 1499 }
872 1500
873 Id GetTextureImage(Operation operation) { 1501 Id AssembleVector(const std::vector<Id>& coords, Type type) {
874 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1502 const Id coords_type = GetTypeVectorDefinitionLut(type).at(coords.size() - 1);
875 const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex())); 1503 return coords.size() == 1 ? coords[0] : OpCompositeConstruct(coords_type, coords);
876 return Emit(OpImage(entry.image_type, GetTextureSampler(operation)));
877 } 1504 }
878 1505
879 Id GetTextureCoordinates(Operation operation) { 1506 Id GetCoordinates(Operation operation, Type type) {
880 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
881 std::vector<Id> coords; 1507 std::vector<Id> coords;
882 for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) { 1508 for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) {
883 coords.push_back(Visit(operation[i])); 1509 coords.push_back(As(Visit(operation[i]), type));
884 } 1510 }
885 if (meta->sampler.IsArray()) { 1511 if (const auto meta = std::get_if<MetaTexture>(&operation.GetMeta())) {
886 const Id array_integer = BitcastTo<Type::Int>(Visit(meta->array)); 1512 // Add array coordinate for textures
887 coords.push_back(Emit(OpConvertSToF(t_float, array_integer))); 1513 if (meta->sampler.IsArray()) {
1514 Id array = AsInt(Visit(meta->array));
1515 if (type == Type::Float) {
1516 array = OpConvertSToF(t_float, array);
1517 }
1518 coords.push_back(array);
1519 }
888 } 1520 }
889 if (meta->sampler.IsShadow()) { 1521 return AssembleVector(coords, type);
890 coords.push_back(Visit(meta->depth_compare)); 1522 }
1523
1524 Id GetOffsetCoordinates(Operation operation) {
1525 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
1526 std::vector<Id> coords;
1527 coords.reserve(meta.aoffi.size());
1528 for (const auto& coord : meta.aoffi) {
1529 coords.push_back(AsInt(Visit(coord)));
891 } 1530 }
1531 return AssembleVector(coords, Type::Int);
1532 }
1533
1534 std::pair<Id, Id> GetDerivatives(Operation operation) {
1535 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
1536 const auto& derivatives = meta.derivates;
1537 ASSERT(derivatives.size() % 2 == 0);
892 1538
893 const std::array<Id, 4> t_float_lut = {nullptr, t_float2, t_float3, t_float4}; 1539 const std::size_t components = derivatives.size() / 2;
894 return coords.size() == 1 1540 std::vector<Id> dx, dy;
895 ? coords[0] 1541 dx.reserve(components);
896 : Emit(OpCompositeConstruct(t_float_lut.at(coords.size() - 1), coords)); 1542 dy.reserve(components);
1543 for (std::size_t index = 0; index < components; ++index) {
1544 dx.push_back(AsFloat(Visit(derivatives.at(index * 2 + 0))));
1545 dy.push_back(AsFloat(Visit(derivatives.at(index * 2 + 1))));
1546 }
1547 return {AssembleVector(dx, Type::Float), AssembleVector(dy, Type::Float)};
897 } 1548 }
898 1549
899 Id GetTextureElement(Operation operation, Id sample_value) { 1550 Expression GetTextureElement(Operation operation, Id sample_value, Type type) {
900 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1551 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
901 ASSERT(meta); 1552 const auto type_def = GetTypeDefinition(type);
902 return Emit(OpCompositeExtract(t_float, sample_value, meta->element)); 1553 return {OpCompositeExtract(type_def, sample_value, meta.element), type};
903 } 1554 }
904 1555
905 Id Texture(Operation operation) { 1556 Expression Texture(Operation operation) {
906 const Id texture = Emit(OpImageSampleImplicitLod(t_float4, GetTextureSampler(operation), 1557 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
907 GetTextureCoordinates(operation))); 1558 UNIMPLEMENTED_IF(!meta.aoffi.empty());
908 return GetTextureElement(operation, texture); 1559
1560 const bool can_implicit = stage == ShaderType::Fragment;
1561 const Id sampler = GetTextureSampler(operation);
1562 const Id coords = GetCoordinates(operation, Type::Float);
1563
1564 if (meta.depth_compare) {
1565 // Depth sampling
1566 UNIMPLEMENTED_IF(meta.bias);
1567 const Id dref = AsFloat(Visit(meta.depth_compare));
1568 if (can_implicit) {
1569 return {OpImageSampleDrefImplicitLod(t_float, sampler, coords, dref, {}),
1570 Type::Float};
1571 } else {
1572 return {OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref,
1573 spv::ImageOperandsMask::Lod, v_float_zero),
1574 Type::Float};
1575 }
1576 }
1577
1578 std::vector<Id> operands;
1579 spv::ImageOperandsMask mask{};
1580 if (meta.bias) {
1581 mask = mask | spv::ImageOperandsMask::Bias;
1582 operands.push_back(AsFloat(Visit(meta.bias)));
1583 }
1584
1585 Id texture;
1586 if (can_implicit) {
1587 texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, operands);
1588 } else {
1589 texture = OpImageSampleExplicitLod(t_float4, sampler, coords,
1590 mask | spv::ImageOperandsMask::Lod, v_float_zero,
1591 operands);
1592 }
1593 return GetTextureElement(operation, texture, Type::Float);
909 } 1594 }
910 1595
911 Id TextureLod(Operation operation) { 1596 Expression TextureLod(Operation operation) {
912 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1597 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
913 const Id texture = Emit(OpImageSampleExplicitLod( 1598
914 t_float4, GetTextureSampler(operation), GetTextureCoordinates(operation), 1599 const Id sampler = GetTextureSampler(operation);
915 spv::ImageOperandsMask::Lod, Visit(meta->lod))); 1600 const Id coords = GetCoordinates(operation, Type::Float);
916 return GetTextureElement(operation, texture); 1601 const Id lod = AsFloat(Visit(meta.lod));
1602
1603 spv::ImageOperandsMask mask = spv::ImageOperandsMask::Lod;
1604 std::vector<Id> operands;
1605 if (!meta.aoffi.empty()) {
1606 mask = mask | spv::ImageOperandsMask::Offset;
1607 operands.push_back(GetOffsetCoordinates(operation));
1608 }
1609
1610 if (meta.sampler.IsShadow()) {
1611 const Id dref = AsFloat(Visit(meta.depth_compare));
1612 return {
1613 OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, mask, lod, operands),
1614 Type::Float};
1615 }
1616 const Id texture = OpImageSampleExplicitLod(t_float4, sampler, coords, mask, lod, operands);
1617 return GetTextureElement(operation, texture, Type::Float);
917 } 1618 }
918 1619
919 Id TextureGather(Operation operation) { 1620 Expression TextureGather(Operation operation) {
920 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1621 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
921 const auto coords = GetTextureCoordinates(operation); 1622 UNIMPLEMENTED_IF(!meta.aoffi.empty());
922 1623
923 Id texture; 1624 const Id coords = GetCoordinates(operation, Type::Float);
924 if (meta->sampler.IsShadow()) { 1625 Id texture{};
925 texture = Emit(OpImageDrefGather(t_float4, GetTextureSampler(operation), coords, 1626 if (meta.sampler.IsShadow()) {
926 Visit(meta->component))); 1627 texture = OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
1628 AsFloat(Visit(meta.depth_compare)));
927 } else { 1629 } else {
928 u32 component_value = 0; 1630 u32 component_value = 0;
929 if (meta->component) { 1631 if (meta.component) {
930 const auto component = std::get_if<ImmediateNode>(&*meta->component); 1632 const auto component = std::get_if<ImmediateNode>(&*meta.component);
931 ASSERT_MSG(component, "Component is not an immediate value"); 1633 ASSERT_MSG(component, "Component is not an immediate value");
932 component_value = component->GetValue(); 1634 component_value = component->GetValue();
933 } 1635 }
934 texture = Emit(OpImageGather(t_float4, GetTextureSampler(operation), coords, 1636 texture = OpImageGather(t_float4, GetTextureSampler(operation), coords,
935 Constant(t_uint, component_value))); 1637 Constant(t_uint, component_value));
936 } 1638 }
937 1639 return GetTextureElement(operation, texture, Type::Float);
938 return GetTextureElement(operation, texture);
939 } 1640 }
940 1641
941 Id TextureQueryDimensions(Operation operation) { 1642 Expression TextureQueryDimensions(Operation operation) {
942 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1643 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
943 const auto image_id = GetTextureImage(operation); 1644 UNIMPLEMENTED_IF(!meta.aoffi.empty());
944 AddCapability(spv::Capability::ImageQuery); 1645 UNIMPLEMENTED_IF(meta.depth_compare);
945 1646
946 if (meta->element == 3) { 1647 const auto image_id = GetTextureImage(operation);
947 return BitcastTo<Type::Float>(Emit(OpImageQueryLevels(t_int, image_id))); 1648 if (meta.element == 3) {
1649 return {OpImageQueryLevels(t_int, image_id), Type::Int};
948 } 1650 }
949 1651
950 const Id lod = VisitOperand<Type::Uint>(operation, 0); 1652 const Id lod = AsUint(Visit(operation[0]));
951 const std::size_t coords_count = [&]() { 1653 const std::size_t coords_count = [&]() {
952 switch (const auto type = meta->sampler.GetType(); type) { 1654 switch (const auto type = meta.sampler.GetType(); type) {
953 case Tegra::Shader::TextureType::Texture1D: 1655 case Tegra::Shader::TextureType::Texture1D:
954 return 1; 1656 return 1;
955 case Tegra::Shader::TextureType::Texture2D: 1657 case Tegra::Shader::TextureType::Texture2D:
@@ -963,141 +1665,190 @@ private:
963 } 1665 }
964 }(); 1666 }();
965 1667
966 if (meta->element >= coords_count) { 1668 if (meta.element >= coords_count) {
967 return Constant(t_float, 0.0f); 1669 return {v_float_zero, Type::Float};
968 } 1670 }
969 1671
970 const std::array<Id, 3> types = {t_int, t_int2, t_int3}; 1672 const std::array<Id, 3> types = {t_int, t_int2, t_int3};
971 const Id sizes = Emit(OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod)); 1673 const Id sizes = OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod);
972 const Id size = Emit(OpCompositeExtract(t_int, sizes, meta->element)); 1674 const Id size = OpCompositeExtract(t_int, sizes, meta.element);
973 return BitcastTo<Type::Float>(size); 1675 return {size, Type::Int};
974 } 1676 }
975 1677
976 Id TextureQueryLod(Operation operation) { 1678 Expression TextureQueryLod(Operation operation) {
977 UNIMPLEMENTED(); 1679 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
978 return {}; 1680 UNIMPLEMENTED_IF(!meta.aoffi.empty());
1681 UNIMPLEMENTED_IF(meta.depth_compare);
1682
1683 if (meta.element >= 2) {
1684 UNREACHABLE_MSG("Invalid element");
1685 return {v_float_zero, Type::Float};
1686 }
1687 const auto sampler_id = GetTextureSampler(operation);
1688
1689 const Id multiplier = Constant(t_float, 256.0f);
1690 const Id multipliers = ConstantComposite(t_float2, multiplier, multiplier);
1691
1692 const Id coords = GetCoordinates(operation, Type::Float);
1693 Id size = OpImageQueryLod(t_float2, sampler_id, coords);
1694 size = OpFMul(t_float2, size, multipliers);
1695 size = OpConvertFToS(t_int2, size);
1696 return GetTextureElement(operation, size, Type::Int);
979 } 1697 }
980 1698
981 Id TexelFetch(Operation operation) { 1699 Expression TexelFetch(Operation operation) {
1700 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
1701 UNIMPLEMENTED_IF(meta.depth_compare);
1702
1703 const Id image = GetTextureImage(operation);
1704 const Id coords = GetCoordinates(operation, Type::Int);
1705 Id fetch;
1706 if (meta.lod && !meta.sampler.IsBuffer()) {
1707 fetch = OpImageFetch(t_float4, image, coords, spv::ImageOperandsMask::Lod,
1708 AsInt(Visit(meta.lod)));
1709 } else {
1710 fetch = OpImageFetch(t_float4, image, coords);
1711 }
1712 return GetTextureElement(operation, fetch, Type::Float);
1713 }
1714
1715 Expression TextureGradient(Operation operation) {
1716 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
1717 UNIMPLEMENTED_IF(!meta.aoffi.empty());
1718
1719 const Id sampler = GetTextureSampler(operation);
1720 const Id coords = GetCoordinates(operation, Type::Float);
1721 const auto [dx, dy] = GetDerivatives(operation);
1722 const std::vector grad = {dx, dy};
1723
1724 static constexpr auto mask = spv::ImageOperandsMask::Grad;
1725 const Id texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, grad);
1726 return GetTextureElement(operation, texture, Type::Float);
1727 }
1728
1729 Expression ImageLoad(Operation operation) {
982 UNIMPLEMENTED(); 1730 UNIMPLEMENTED();
983 return {}; 1731 return {};
984 } 1732 }
985 1733
986 Id TextureGradient(Operation operation) { 1734 Expression ImageStore(Operation operation) {
987 UNIMPLEMENTED(); 1735 const auto meta{std::get<MetaImage>(operation.GetMeta())};
1736 std::vector<Id> colors;
1737 for (const auto& value : meta.values) {
1738 colors.push_back(AsUint(Visit(value)));
1739 }
1740
1741 const Id coords = GetCoordinates(operation, Type::Int);
1742 const Id texel = OpCompositeConstruct(t_uint4, colors);
1743
1744 OpImageWrite(GetImage(operation), coords, texel, {});
988 return {}; 1745 return {};
989 } 1746 }
990 1747
991 Id ImageLoad(Operation operation) { 1748 Expression AtomicImageAdd(Operation operation) {
992 UNIMPLEMENTED(); 1749 UNIMPLEMENTED();
993 return {}; 1750 return {};
994 } 1751 }
995 1752
996 Id ImageStore(Operation operation) { 1753 Expression AtomicImageMin(Operation operation) {
997 UNIMPLEMENTED(); 1754 UNIMPLEMENTED();
998 return {}; 1755 return {};
999 } 1756 }
1000 1757
1001 Id AtomicImageAdd(Operation operation) { 1758 Expression AtomicImageMax(Operation operation) {
1002 UNIMPLEMENTED(); 1759 UNIMPLEMENTED();
1003 return {}; 1760 return {};
1004 } 1761 }
1005 1762
1006 Id AtomicImageAnd(Operation operation) { 1763 Expression AtomicImageAnd(Operation operation) {
1007 UNIMPLEMENTED(); 1764 UNIMPLEMENTED();
1008 return {}; 1765 return {};
1009 } 1766 }
1010 1767
1011 Id AtomicImageOr(Operation operation) { 1768 Expression AtomicImageOr(Operation operation) {
1012 UNIMPLEMENTED(); 1769 UNIMPLEMENTED();
1013 return {}; 1770 return {};
1014 } 1771 }
1015 1772
1016 Id AtomicImageXor(Operation operation) { 1773 Expression AtomicImageXor(Operation operation) {
1017 UNIMPLEMENTED(); 1774 UNIMPLEMENTED();
1018 return {}; 1775 return {};
1019 } 1776 }
1020 1777
1021 Id AtomicImageExchange(Operation operation) { 1778 Expression AtomicImageExchange(Operation operation) {
1022 UNIMPLEMENTED(); 1779 UNIMPLEMENTED();
1023 return {}; 1780 return {};
1024 } 1781 }
1025 1782
1026 Id Branch(Operation operation) { 1783 Expression Branch(Operation operation) {
1027 const auto target = std::get_if<ImmediateNode>(&*operation[0]); 1784 const auto& target = std::get<ImmediateNode>(*operation[0]);
1028 UNIMPLEMENTED_IF(!target); 1785 OpStore(jmp_to, Constant(t_uint, target.GetValue()));
1029 1786 OpBranch(continue_label);
1030 Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue()))); 1787 inside_branch = true;
1031 Emit(OpBranch(continue_label)); 1788 if (!conditional_branch_set) {
1032 inside_branch = conditional_nest_count; 1789 AddLabel();
1033 if (conditional_nest_count == 0) {
1034 Emit(OpLabel());
1035 } 1790 }
1036 return {}; 1791 return {};
1037 } 1792 }
1038 1793
1039 Id BranchIndirect(Operation operation) { 1794 Expression BranchIndirect(Operation operation) {
1040 const Id op_a = VisitOperand<Type::Uint>(operation, 0); 1795 const Id op_a = AsUint(Visit(operation[0]));
1041 1796
1042 Emit(OpStore(jmp_to, op_a)); 1797 OpStore(jmp_to, op_a);
1043 Emit(OpBranch(continue_label)); 1798 OpBranch(continue_label);
1044 inside_branch = conditional_nest_count; 1799 inside_branch = true;
1045 if (conditional_nest_count == 0) { 1800 if (!conditional_branch_set) {
1046 Emit(OpLabel()); 1801 AddLabel();
1047 } 1802 }
1048 return {}; 1803 return {};
1049 } 1804 }
1050 1805
1051 Id PushFlowStack(Operation operation) { 1806 Expression PushFlowStack(Operation operation) {
1052 const auto target = std::get_if<ImmediateNode>(&*operation[0]); 1807 const auto& target = std::get<ImmediateNode>(*operation[0]);
1053 ASSERT(target);
1054
1055 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation); 1808 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
1056 const Id current = Emit(OpLoad(t_uint, flow_stack_top)); 1809 const Id current = OpLoad(t_uint, flow_stack_top);
1057 const Id next = Emit(OpIAdd(t_uint, current, Constant(t_uint, 1))); 1810 const Id next = OpIAdd(t_uint, current, Constant(t_uint, 1));
1058 const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, current)); 1811 const Id access = OpAccessChain(t_func_uint, flow_stack, current);
1059 1812
1060 Emit(OpStore(access, Constant(t_uint, target->GetValue()))); 1813 OpStore(access, Constant(t_uint, target.GetValue()));
1061 Emit(OpStore(flow_stack_top, next)); 1814 OpStore(flow_stack_top, next);
1062 return {}; 1815 return {};
1063 } 1816 }
1064 1817
1065 Id PopFlowStack(Operation operation) { 1818 Expression PopFlowStack(Operation operation) {
1066 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation); 1819 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
1067 const Id current = Emit(OpLoad(t_uint, flow_stack_top)); 1820 const Id current = OpLoad(t_uint, flow_stack_top);
1068 const Id previous = Emit(OpISub(t_uint, current, Constant(t_uint, 1))); 1821 const Id previous = OpISub(t_uint, current, Constant(t_uint, 1));
1069 const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, previous)); 1822 const Id access = OpAccessChain(t_func_uint, flow_stack, previous);
1070 const Id target = Emit(OpLoad(t_uint, access)); 1823 const Id target = OpLoad(t_uint, access);
1071 1824
1072 Emit(OpStore(flow_stack_top, previous)); 1825 OpStore(flow_stack_top, previous);
1073 Emit(OpStore(jmp_to, target)); 1826 OpStore(jmp_to, target);
1074 Emit(OpBranch(continue_label)); 1827 OpBranch(continue_label);
1075 inside_branch = conditional_nest_count; 1828 inside_branch = true;
1076 if (conditional_nest_count == 0) { 1829 if (!conditional_branch_set) {
1077 Emit(OpLabel()); 1830 AddLabel();
1078 } 1831 }
1079 return {}; 1832 return {};
1080 } 1833 }
1081 1834
1082 Id PreExit() { 1835 void PreExit() {
1083 switch (stage) { 1836 if (stage == ShaderType::Vertex) {
1084 case ShaderType::Vertex: { 1837 const u32 position_index = out_indices.position.value();
1085 // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't 1838 const Id z_pointer = AccessElement(t_out_float, out_vertex, position_index, 2U);
1086 // seem to be working on Nvidia's drivers and Intel (mesa and blob) doesn't support it. 1839 const Id w_pointer = AccessElement(t_out_float, out_vertex, position_index, 3U);
1087 const Id z_pointer = AccessElement(t_out_float, per_vertex, position_index, 2u); 1840 Id depth = OpLoad(t_float, z_pointer);
1088 Id depth = Emit(OpLoad(t_float, z_pointer)); 1841 depth = OpFAdd(t_float, depth, OpLoad(t_float, w_pointer));
1089 depth = Emit(OpFAdd(t_float, depth, Constant(t_float, 1.0f))); 1842 depth = OpFMul(t_float, depth, Constant(t_float, 0.5f));
1090 depth = Emit(OpFMul(t_float, depth, Constant(t_float, 0.5f))); 1843 OpStore(z_pointer, depth);
1091 Emit(OpStore(z_pointer, depth));
1092 break;
1093 } 1844 }
1094 case ShaderType::Fragment: { 1845 if (stage == ShaderType::Fragment) {
1095 const auto SafeGetRegister = [&](u32 reg) { 1846 const auto SafeGetRegister = [&](u32 reg) {
1096 // TODO(Rodrigo): Replace with contains once C++20 releases 1847 // TODO(Rodrigo): Replace with contains once C++20 releases
1097 if (const auto it = registers.find(reg); it != registers.end()) { 1848 if (const auto it = registers.find(reg); it != registers.end()) {
1098 return Emit(OpLoad(t_float, it->second)); 1849 return OpLoad(t_float, it->second);
1099 } 1850 }
1100 return Constant(t_float, 0.0f); 1851 return v_float_zero;
1101 }; 1852 };
1102 1853
1103 UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0, 1854 UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0,
@@ -1112,8 +1863,8 @@ private:
1112 // TODO(Subv): Figure out how dual-source blending is configured in the Switch. 1863 // TODO(Subv): Figure out how dual-source blending is configured in the Switch.
1113 for (u32 component = 0; component < 4; ++component) { 1864 for (u32 component = 0; component < 4; ++component) {
1114 if (header.ps.IsColorComponentOutputEnabled(rt, component)) { 1865 if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
1115 Emit(OpStore(AccessElement(t_out_float, frag_colors.at(rt), component), 1866 OpStore(AccessElement(t_out_float, frag_colors.at(rt), component),
1116 SafeGetRegister(current_reg))); 1867 SafeGetRegister(current_reg));
1117 ++current_reg; 1868 ++current_reg;
1118 } 1869 }
1119 } 1870 }
@@ -1121,110 +1872,117 @@ private:
1121 if (header.ps.omap.depth) { 1872 if (header.ps.omap.depth) {
1122 // The depth output is always 2 registers after the last color output, and 1873 // The depth output is always 2 registers after the last color output, and
1123 // current_reg already contains one past the last color register. 1874 // current_reg already contains one past the last color register.
1124 Emit(OpStore(frag_depth, SafeGetRegister(current_reg + 1))); 1875 OpStore(frag_depth, SafeGetRegister(current_reg + 1));
1125 } 1876 }
1126 break;
1127 } 1877 }
1128 }
1129
1130 return {};
1131 } 1878 }
1132 1879
1133 Id Exit(Operation operation) { 1880 Expression Exit(Operation operation) {
1134 PreExit(); 1881 PreExit();
1135 inside_branch = conditional_nest_count; 1882 inside_branch = true;
1136 if (conditional_nest_count > 0) { 1883 if (conditional_branch_set) {
1137 Emit(OpReturn()); 1884 OpReturn();
1138 } else { 1885 } else {
1139 const Id dummy = OpLabel(); 1886 const Id dummy = OpLabel();
1140 Emit(OpBranch(dummy)); 1887 OpBranch(dummy);
1141 Emit(dummy); 1888 AddLabel(dummy);
1142 Emit(OpReturn()); 1889 OpReturn();
1143 Emit(OpLabel()); 1890 AddLabel();
1144 } 1891 }
1145 return {}; 1892 return {};
1146 } 1893 }
1147 1894
1148 Id Discard(Operation operation) { 1895 Expression Discard(Operation operation) {
1149 inside_branch = conditional_nest_count; 1896 inside_branch = true;
1150 if (conditional_nest_count > 0) { 1897 if (conditional_branch_set) {
1151 Emit(OpKill()); 1898 OpKill();
1152 } else { 1899 } else {
1153 const Id dummy = OpLabel(); 1900 const Id dummy = OpLabel();
1154 Emit(OpBranch(dummy)); 1901 OpBranch(dummy);
1155 Emit(dummy); 1902 AddLabel(dummy);
1156 Emit(OpKill()); 1903 OpKill();
1157 Emit(OpLabel()); 1904 AddLabel();
1158 } 1905 }
1159 return {}; 1906 return {};
1160 } 1907 }
1161 1908
1162 Id EmitVertex(Operation operation) { 1909 Expression EmitVertex(Operation) {
1163 UNIMPLEMENTED(); 1910 OpEmitVertex();
1164 return {}; 1911 return {};
1165 } 1912 }
1166 1913
1167 Id EndPrimitive(Operation operation) { 1914 Expression EndPrimitive(Operation operation) {
1168 UNIMPLEMENTED(); 1915 OpEndPrimitive();
1169 return {}; 1916 return {};
1170 } 1917 }
1171 1918
1172 Id YNegate(Operation operation) { 1919 Expression InvocationId(Operation) {
1173 UNIMPLEMENTED(); 1920 return {OpLoad(t_int, invocation_id), Type::Int};
1174 return {};
1175 } 1921 }
1176 1922
1177 template <u32 element> 1923 Expression YNegate(Operation) {
1178 Id LocalInvocationId(Operation) { 1924 LOG_WARNING(Render_Vulkan, "(STUBBED)");
1179 UNIMPLEMENTED(); 1925 return {Constant(t_float, 1.0f), Type::Float};
1180 return {};
1181 } 1926 }
1182 1927
1183 template <u32 element> 1928 template <u32 element>
1184 Id WorkGroupId(Operation) { 1929 Expression LocalInvocationId(Operation) {
1185 UNIMPLEMENTED(); 1930 const Id id = OpLoad(t_uint3, local_invocation_id);
1186 return {}; 1931 return {OpCompositeExtract(t_uint, id, element), Type::Uint};
1187 } 1932 }
1188 1933
1189 Id BallotThread(Operation) { 1934 template <u32 element>
1190 UNIMPLEMENTED(); 1935 Expression WorkGroupId(Operation operation) {
1191 return {}; 1936 const Id id = OpLoad(t_uint3, workgroup_id);
1937 return {OpCompositeExtract(t_uint, id, element), Type::Uint};
1192 } 1938 }
1193 1939
1194 Id VoteAll(Operation) { 1940 Expression BallotThread(Operation operation) {
1195 UNIMPLEMENTED(); 1941 const Id predicate = AsBool(Visit(operation[0]));
1196 return {}; 1942 const Id ballot = OpSubgroupBallotKHR(t_uint4, predicate);
1197 }
1198 1943
1199 Id VoteAny(Operation) { 1944 if (!device.IsWarpSizePotentiallyBiggerThanGuest()) {
1200 UNIMPLEMENTED(); 1945 // Guest-like devices can just return the first index.
1201 return {}; 1946 return {OpCompositeExtract(t_uint, ballot, 0U), Type::Uint};
1947 }
1948
1949 // The others will have to return what is local to the current thread.
1950 // For instance a device with a warp size of 64 will return the upper uint when the current
1951 // thread is 38.
1952 const Id tid = OpLoad(t_uint, thread_id);
1953 const Id thread_index = OpShiftRightLogical(t_uint, tid, Constant(t_uint, 5));
1954 return {OpVectorExtractDynamic(t_uint, ballot, thread_index), Type::Uint};
1202 } 1955 }
1203 1956
1204 Id VoteEqual(Operation) { 1957 template <Id (Module::*func)(Id, Id)>
1205 UNIMPLEMENTED(); 1958 Expression Vote(Operation operation) {
1206 return {}; 1959 // TODO(Rodrigo): Handle devices with different warp sizes
1960 const Id predicate = AsBool(Visit(operation[0]));
1961 return {(this->*func)(t_bool, predicate), Type::Bool};
1207 } 1962 }
1208 1963
1209 Id ThreadId(Operation) { 1964 Expression ThreadId(Operation) {
1210 UNIMPLEMENTED(); 1965 return {OpLoad(t_uint, thread_id), Type::Uint};
1211 return {};
1212 } 1966 }
1213 1967
1214 Id ShuffleIndexed(Operation) { 1968 Expression ShuffleIndexed(Operation operation) {
1215 UNIMPLEMENTED(); 1969 const Id value = AsFloat(Visit(operation[0]));
1216 return {}; 1970 const Id index = AsUint(Visit(operation[1]));
1971 return {OpSubgroupReadInvocationKHR(t_float, value, index), Type::Float};
1217 } 1972 }
1218 1973
1219 Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, 1974 Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, std::string name) {
1220 const std::string& name) {
1221 const Id id = OpVariable(type, storage); 1975 const Id id = OpVariable(type, storage);
1222 Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin)); 1976 Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin));
1223 AddGlobalVariable(Name(id, name)); 1977 AddGlobalVariable(Name(id, std::move(name)));
1224 interfaces.push_back(id); 1978 interfaces.push_back(id);
1225 return id; 1979 return id;
1226 } 1980 }
1227 1981
1982 Id DeclareInputBuiltIn(spv::BuiltIn builtin, Id type, std::string name) {
1983 return DeclareBuiltIn(builtin, spv::StorageClass::Input, type, std::move(name));
1984 }
1985
1228 bool IsRenderTargetUsed(u32 rt) const { 1986 bool IsRenderTargetUsed(u32 rt) const {
1229 for (u32 component = 0; component < 4; ++component) { 1987 for (u32 component = 0; component < 4; ++component) {
1230 if (header.ps.IsColorComponentOutputEnabled(rt, component)) { 1988 if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
@@ -1242,66 +2000,148 @@ private:
1242 members.push_back(Constant(t_uint, element)); 2000 members.push_back(Constant(t_uint, element));
1243 } 2001 }
1244 2002
1245 return Emit(OpAccessChain(pointer_type, composite, members)); 2003 return OpAccessChain(pointer_type, composite, members);
1246 } 2004 }
1247 2005
1248 template <Type type> 2006 Id As(Expression expr, Type wanted_type) {
1249 Id VisitOperand(Operation operation, std::size_t operand_index) { 2007 switch (wanted_type) {
1250 const Id value = Visit(operation[operand_index]);
1251
1252 switch (type) {
1253 case Type::Bool: 2008 case Type::Bool:
2009 return AsBool(expr);
1254 case Type::Bool2: 2010 case Type::Bool2:
2011 return AsBool2(expr);
1255 case Type::Float: 2012 case Type::Float:
1256 return value; 2013 return AsFloat(expr);
1257 case Type::Int: 2014 case Type::Int:
1258 return Emit(OpBitcast(t_int, value)); 2015 return AsInt(expr);
1259 case Type::Uint: 2016 case Type::Uint:
1260 return Emit(OpBitcast(t_uint, value)); 2017 return AsUint(expr);
1261 case Type::HalfFloat: 2018 case Type::HalfFloat:
1262 UNIMPLEMENTED(); 2019 return AsHalfFloat(expr);
2020 default:
2021 UNREACHABLE();
2022 return expr.id;
1263 } 2023 }
1264 UNREACHABLE();
1265 return value;
1266 } 2024 }
1267 2025
1268 template <Type type> 2026 Id AsBool(Expression expr) {
1269 Id BitcastFrom(Id value) { 2027 ASSERT(expr.type == Type::Bool);
1270 switch (type) { 2028 return expr.id;
1271 case Type::Bool: 2029 }
1272 case Type::Bool2: 2030
2031 Id AsBool2(Expression expr) {
2032 ASSERT(expr.type == Type::Bool2);
2033 return expr.id;
2034 }
2035
2036 Id AsFloat(Expression expr) {
2037 switch (expr.type) {
1273 case Type::Float: 2038 case Type::Float:
1274 return value; 2039 return expr.id;
1275 case Type::Int: 2040 case Type::Int:
1276 case Type::Uint: 2041 case Type::Uint:
1277 return Emit(OpBitcast(t_float, value)); 2042 return OpBitcast(t_float, expr.id);
1278 case Type::HalfFloat: 2043 case Type::HalfFloat:
1279 UNIMPLEMENTED(); 2044 if (device.IsFloat16Supported()) {
2045 return OpBitcast(t_float, expr.id);
2046 }
2047 return OpBitcast(t_float, OpPackHalf2x16(t_uint, expr.id));
2048 default:
2049 UNREACHABLE();
2050 return expr.id;
1280 } 2051 }
1281 UNREACHABLE();
1282 return value;
1283 } 2052 }
1284 2053
1285 template <Type type> 2054 Id AsInt(Expression expr) {
1286 Id BitcastTo(Id value) { 2055 switch (expr.type) {
1287 switch (type) { 2056 case Type::Int:
1288 case Type::Bool: 2057 return expr.id;
1289 case Type::Bool2: 2058 case Type::Float:
2059 case Type::Uint:
2060 return OpBitcast(t_int, expr.id);
2061 case Type::HalfFloat:
2062 if (device.IsFloat16Supported()) {
2063 return OpBitcast(t_int, expr.id);
2064 }
2065 return OpPackHalf2x16(t_int, expr.id);
2066 default:
1290 UNREACHABLE(); 2067 UNREACHABLE();
2068 return expr.id;
2069 }
2070 }
2071
2072 Id AsUint(Expression expr) {
2073 switch (expr.type) {
2074 case Type::Uint:
2075 return expr.id;
1291 case Type::Float: 2076 case Type::Float:
1292 return Emit(OpBitcast(t_float, value));
1293 case Type::Int: 2077 case Type::Int:
1294 return Emit(OpBitcast(t_int, value)); 2078 return OpBitcast(t_uint, expr.id);
1295 case Type::Uint:
1296 return Emit(OpBitcast(t_uint, value));
1297 case Type::HalfFloat: 2079 case Type::HalfFloat:
1298 UNIMPLEMENTED(); 2080 if (device.IsFloat16Supported()) {
2081 return OpBitcast(t_uint, expr.id);
2082 }
2083 return OpPackHalf2x16(t_uint, expr.id);
2084 default:
2085 UNREACHABLE();
2086 return expr.id;
2087 }
2088 }
2089
2090 Id AsHalfFloat(Expression expr) {
2091 switch (expr.type) {
2092 case Type::HalfFloat:
2093 return expr.id;
2094 case Type::Float:
2095 case Type::Int:
2096 case Type::Uint:
2097 if (device.IsFloat16Supported()) {
2098 return OpBitcast(t_half, expr.id);
2099 }
2100 return OpUnpackHalf2x16(t_half, AsUint(expr));
2101 default:
2102 UNREACHABLE();
2103 return expr.id;
2104 }
2105 }
2106
2107 Id GetHalfScalarFromFloat(Id value) {
2108 if (device.IsFloat16Supported()) {
2109 return OpFConvert(t_scalar_half, value);
1299 } 2110 }
1300 UNREACHABLE();
1301 return value; 2111 return value;
1302 } 2112 }
1303 2113
1304 Id GetTypeDefinition(Type type) { 2114 Id GetFloatFromHalfScalar(Id value) {
2115 if (device.IsFloat16Supported()) {
2116 return OpFConvert(t_float, value);
2117 }
2118 return value;
2119 }
2120
2121 AttributeType GetAttributeType(u32 location) const {
2122 if (stage != ShaderType::Vertex) {
2123 return {Type::Float, t_in_float, t_in_float4};
2124 }
2125 switch (specialization.attribute_types.at(location)) {
2126 case Maxwell::VertexAttribute::Type::SignedNorm:
2127 case Maxwell::VertexAttribute::Type::UnsignedNorm:
2128 case Maxwell::VertexAttribute::Type::Float:
2129 return {Type::Float, t_in_float, t_in_float4};
2130 case Maxwell::VertexAttribute::Type::SignedInt:
2131 return {Type::Int, t_in_int, t_in_int4};
2132 case Maxwell::VertexAttribute::Type::UnsignedInt:
2133 return {Type::Uint, t_in_uint, t_in_uint4};
2134 case Maxwell::VertexAttribute::Type::UnsignedScaled:
2135 case Maxwell::VertexAttribute::Type::SignedScaled:
2136 UNIMPLEMENTED();
2137 return {Type::Float, t_in_float, t_in_float4};
2138 default:
2139 UNREACHABLE();
2140 return {Type::Float, t_in_float, t_in_float4};
2141 }
2142 }
2143
2144 Id GetTypeDefinition(Type type) const {
1305 switch (type) { 2145 switch (type) {
1306 case Type::Bool: 2146 case Type::Bool:
1307 return t_bool; 2147 return t_bool;
@@ -1314,10 +2154,25 @@ private:
1314 case Type::Uint: 2154 case Type::Uint:
1315 return t_uint; 2155 return t_uint;
1316 case Type::HalfFloat: 2156 case Type::HalfFloat:
2157 return t_half;
2158 default:
2159 UNREACHABLE();
2160 return {};
2161 }
2162 }
2163
2164 std::array<Id, 4> GetTypeVectorDefinitionLut(Type type) const {
2165 switch (type) {
2166 case Type::Float:
2167 return {nullptr, t_float2, t_float3, t_float4};
2168 case Type::Int:
2169 return {nullptr, t_int2, t_int3, t_int4};
2170 case Type::Uint:
2171 return {nullptr, t_uint2, t_uint3, t_uint4};
2172 default:
1317 UNIMPLEMENTED(); 2173 UNIMPLEMENTED();
2174 return {};
1318 } 2175 }
1319 UNREACHABLE();
1320 return {};
1321 } 2176 }
1322 2177
1323 std::tuple<Id, Id> CreateFlowStack() { 2178 std::tuple<Id, Id> CreateFlowStack() {
@@ -1327,9 +2182,11 @@ private:
1327 constexpr auto storage_class = spv::StorageClass::Function; 2182 constexpr auto storage_class = spv::StorageClass::Function;
1328 2183
1329 const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE)); 2184 const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE));
1330 const Id stack = Emit(OpVariable(TypePointer(storage_class, flow_stack_type), storage_class, 2185 const Id stack = OpVariable(TypePointer(storage_class, flow_stack_type), storage_class,
1331 ConstantNull(flow_stack_type))); 2186 ConstantNull(flow_stack_type));
1332 const Id top = Emit(OpVariable(t_func_uint, storage_class, Constant(t_uint, 0))); 2187 const Id top = OpVariable(t_func_uint, storage_class, Constant(t_uint, 0));
2188 AddLocalVariable(stack);
2189 AddLocalVariable(top);
1333 return std::tie(stack, top); 2190 return std::tie(stack, top);
1334 } 2191 }
1335 2192
@@ -1358,8 +2215,8 @@ private:
1358 &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>, 2215 &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
1359 &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>, 2216 &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
1360 &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>, 2217 &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
1361 &SPIRVDecompiler::FCastHalf0, 2218 &SPIRVDecompiler::FCastHalf<0>,
1362 &SPIRVDecompiler::FCastHalf1, 2219 &SPIRVDecompiler::FCastHalf<1>,
1363 &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>, 2220 &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
1364 &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>, 2221 &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
1365 &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>, 2222 &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
@@ -1407,7 +2264,7 @@ private:
1407 &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>, 2264 &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
1408 &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>, 2265 &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
1409 &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>, 2266 &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
1410 &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>, 2267 &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
1411 &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>, 2268 &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
1412 &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>, 2269 &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
1413 &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>, 2270 &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
@@ -1426,8 +2283,8 @@ private:
1426 &SPIRVDecompiler::HCastFloat, 2283 &SPIRVDecompiler::HCastFloat,
1427 &SPIRVDecompiler::HUnpack, 2284 &SPIRVDecompiler::HUnpack,
1428 &SPIRVDecompiler::HMergeF32, 2285 &SPIRVDecompiler::HMergeF32,
1429 &SPIRVDecompiler::HMergeH0, 2286 &SPIRVDecompiler::HMergeHN<0>,
1430 &SPIRVDecompiler::HMergeH1, 2287 &SPIRVDecompiler::HMergeHN<1>,
1431 &SPIRVDecompiler::HPack2, 2288 &SPIRVDecompiler::HPack2,
1432 2289
1433 &SPIRVDecompiler::LogicalAssign, 2290 &SPIRVDecompiler::LogicalAssign,
@@ -1435,8 +2292,9 @@ private:
1435 &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>, 2292 &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
1436 &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>, 2293 &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
1437 &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>, 2294 &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
1438 &SPIRVDecompiler::LogicalPick2, 2295 &SPIRVDecompiler::Binary<&Module::OpVectorExtractDynamic, Type::Bool, Type::Bool2,
1439 &SPIRVDecompiler::LogicalAnd2, 2296 Type::Uint>,
2297 &SPIRVDecompiler::Unary<&Module::OpAll, Type::Bool, Type::Bool2>,
1440 2298
1441 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>, 2299 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
1442 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>, 2300 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
@@ -1444,7 +2302,7 @@ private:
1444 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>, 2302 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
1445 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>, 2303 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
1446 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>, 2304 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
1447 &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>, 2305 &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool, Type::Float>,
1448 2306
1449 &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>, 2307 &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
1450 &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>, 2308 &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
@@ -1460,19 +2318,19 @@ private:
1460 &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>, 2318 &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
1461 &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>, 2319 &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,
1462 2320
1463 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>, 2321 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
1464 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>, 2322 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
1465 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>, 2323 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
1466 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>, 2324 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
1467 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>, 2325 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
1468 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>, 2326 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
1469 // TODO(Rodrigo): Should these use the OpFUnord* variants? 2327 // TODO(Rodrigo): Should these use the OpFUnord* variants?
1470 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>, 2328 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
1471 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>, 2329 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
1472 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>, 2330 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
1473 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>, 2331 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
1474 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>, 2332 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
1475 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>, 2333 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
1476 2334
1477 &SPIRVDecompiler::Texture, 2335 &SPIRVDecompiler::Texture,
1478 &SPIRVDecompiler::TextureLod, 2336 &SPIRVDecompiler::TextureLod,
@@ -1500,6 +2358,7 @@ private:
1500 &SPIRVDecompiler::EmitVertex, 2358 &SPIRVDecompiler::EmitVertex,
1501 &SPIRVDecompiler::EndPrimitive, 2359 &SPIRVDecompiler::EndPrimitive,
1502 2360
2361 &SPIRVDecompiler::InvocationId,
1503 &SPIRVDecompiler::YNegate, 2362 &SPIRVDecompiler::YNegate,
1504 &SPIRVDecompiler::LocalInvocationId<0>, 2363 &SPIRVDecompiler::LocalInvocationId<0>,
1505 &SPIRVDecompiler::LocalInvocationId<1>, 2364 &SPIRVDecompiler::LocalInvocationId<1>,
@@ -1509,9 +2368,9 @@ private:
1509 &SPIRVDecompiler::WorkGroupId<2>, 2368 &SPIRVDecompiler::WorkGroupId<2>,
1510 2369
1511 &SPIRVDecompiler::BallotThread, 2370 &SPIRVDecompiler::BallotThread,
1512 &SPIRVDecompiler::VoteAll, 2371 &SPIRVDecompiler::Vote<&Module::OpSubgroupAllKHR>,
1513 &SPIRVDecompiler::VoteAny, 2372 &SPIRVDecompiler::Vote<&Module::OpSubgroupAnyKHR>,
1514 &SPIRVDecompiler::VoteEqual, 2373 &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
1515 2374
1516 &SPIRVDecompiler::ThreadId, 2375 &SPIRVDecompiler::ThreadId,
1517 &SPIRVDecompiler::ShuffleIndexed, 2376 &SPIRVDecompiler::ShuffleIndexed,
@@ -1522,8 +2381,7 @@ private:
1522 const ShaderIR& ir; 2381 const ShaderIR& ir;
1523 const ShaderType stage; 2382 const ShaderType stage;
1524 const Tegra::Shader::Header header; 2383 const Tegra::Shader::Header header;
1525 u64 conditional_nest_count{}; 2384 const Specialization& specialization;
1526 u64 inside_branch{};
1527 2385
1528 const Id t_void = Name(TypeVoid(), "void"); 2386 const Id t_void = Name(TypeVoid(), "void");
1529 2387
@@ -1551,20 +2409,28 @@ private:
1551 const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint"); 2409 const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint");
1552 2410
1553 const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool"); 2411 const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool");
2412 const Id t_in_int = Name(TypePointer(spv::StorageClass::Input, t_int), "in_int");
2413 const Id t_in_int4 = Name(TypePointer(spv::StorageClass::Input, t_int4), "in_int4");
1554 const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint"); 2414 const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint");
2415 const Id t_in_uint3 = Name(TypePointer(spv::StorageClass::Input, t_uint3), "in_uint3");
2416 const Id t_in_uint4 = Name(TypePointer(spv::StorageClass::Input, t_uint4), "in_uint4");
1555 const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float"); 2417 const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float");
2418 const Id t_in_float2 = Name(TypePointer(spv::StorageClass::Input, t_float2), "in_float2");
2419 const Id t_in_float3 = Name(TypePointer(spv::StorageClass::Input, t_float3), "in_float3");
1556 const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4"); 2420 const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4");
1557 2421
2422 const Id t_out_int = Name(TypePointer(spv::StorageClass::Output, t_int), "out_int");
2423
1558 const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float"); 2424 const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float");
1559 const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4"); 2425 const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4");
1560 2426
1561 const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float); 2427 const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float);
1562 const Id t_cbuf_std140 = Decorate( 2428 const Id t_cbuf_std140 = Decorate(
1563 Name(TypeArray(t_float4, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS)), "CbufStd140Array"), 2429 Name(TypeArray(t_float4, Constant(t_uint, MaxConstBufferElements)), "CbufStd140Array"),
1564 spv::Decoration::ArrayStride, 16u); 2430 spv::Decoration::ArrayStride, 16U);
1565 const Id t_cbuf_scalar = Decorate( 2431 const Id t_cbuf_scalar = Decorate(
1566 Name(TypeArray(t_float, Constant(t_uint, MAX_CONSTBUFFER_FLOATS)), "CbufScalarArray"), 2432 Name(TypeArray(t_float, Constant(t_uint, MaxConstBufferFloats)), "CbufScalarArray"),
1567 spv::Decoration::ArrayStride, 4u); 2433 spv::Decoration::ArrayStride, 4U);
1568 const Id t_cbuf_std140_struct = MemberDecorate( 2434 const Id t_cbuf_std140_struct = MemberDecorate(
1569 Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0); 2435 Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
1570 const Id t_cbuf_scalar_struct = MemberDecorate( 2436 const Id t_cbuf_scalar_struct = MemberDecorate(
@@ -1572,28 +2438,43 @@ private:
1572 const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct); 2438 const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct);
1573 const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct); 2439 const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct);
1574 2440
2441 Id t_smem_uint{};
2442
1575 const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float); 2443 const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float);
1576 const Id t_gmem_array = 2444 const Id t_gmem_array =
1577 Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4u), "GmemArray"); 2445 Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4U), "GmemArray");
1578 const Id t_gmem_struct = MemberDecorate( 2446 const Id t_gmem_struct = MemberDecorate(
1579 Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0); 2447 Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
1580 const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct); 2448 const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct);
1581 2449
1582 const Id v_float_zero = Constant(t_float, 0.0f); 2450 const Id v_float_zero = Constant(t_float, 0.0f);
2451 const Id v_float_one = Constant(t_float, 1.0f);
2452
2453 // Nvidia uses these defaults for varyings (e.g. position and generic attributes)
2454 const Id v_varying_default =
2455 ConstantComposite(t_float4, v_float_zero, v_float_zero, v_float_zero, v_float_one);
2456
1583 const Id v_true = ConstantTrue(t_bool); 2457 const Id v_true = ConstantTrue(t_bool);
1584 const Id v_false = ConstantFalse(t_bool); 2458 const Id v_false = ConstantFalse(t_bool);
1585 2459
1586 Id per_vertex{}; 2460 Id t_scalar_half{};
2461 Id t_half{};
2462
2463 Id out_vertex{};
2464 Id in_vertex{};
1587 std::map<u32, Id> registers; 2465 std::map<u32, Id> registers;
1588 std::map<Tegra::Shader::Pred, Id> predicates; 2466 std::map<Tegra::Shader::Pred, Id> predicates;
1589 std::map<u32, Id> flow_variables; 2467 std::map<u32, Id> flow_variables;
1590 Id local_memory{}; 2468 Id local_memory{};
2469 Id shared_memory{};
1591 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{}; 2470 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
1592 std::map<Attribute::Index, Id> input_attributes; 2471 std::map<Attribute::Index, Id> input_attributes;
1593 std::map<Attribute::Index, Id> output_attributes; 2472 std::map<Attribute::Index, Id> output_attributes;
1594 std::map<u32, Id> constant_buffers; 2473 std::map<u32, Id> constant_buffers;
1595 std::map<GlobalMemoryBase, Id> global_buffers; 2474 std::map<GlobalMemoryBase, Id> global_buffers;
1596 std::map<u32, SamplerImage> sampler_images; 2475 std::map<u32, TexelBuffer> texel_buffers;
2476 std::map<u32, SampledImage> sampled_images;
2477 std::map<u32, StorageImage> images;
1597 2478
1598 Id instance_index{}; 2479 Id instance_index{};
1599 Id vertex_index{}; 2480 Id vertex_index{};
@@ -1601,18 +2482,20 @@ private:
1601 Id frag_depth{}; 2482 Id frag_depth{};
1602 Id frag_coord{}; 2483 Id frag_coord{};
1603 Id front_facing{}; 2484 Id front_facing{};
1604 2485 Id point_coord{};
1605 u32 position_index{}; 2486 Id tess_level_outer{};
1606 u32 point_size_index{}; 2487 Id tess_level_inner{};
1607 u32 clip_distances_index{}; 2488 Id tess_coord{};
2489 Id invocation_id{};
2490 Id workgroup_id{};
2491 Id local_invocation_id{};
2492 Id thread_id{};
2493
2494 VertexIndices in_indices;
2495 VertexIndices out_indices;
1608 2496
1609 std::vector<Id> interfaces; 2497 std::vector<Id> interfaces;
1610 2498
1611 u32 const_buffers_base_binding{};
1612 u32 global_buffers_base_binding{};
1613 u32 samplers_base_binding{};
1614
1615 Id execute_function{};
1616 Id jmp_to{}; 2499 Id jmp_to{};
1617 Id ssy_flow_stack_top{}; 2500 Id ssy_flow_stack_top{};
1618 Id pbk_flow_stack_top{}; 2501 Id pbk_flow_stack_top{};
@@ -1620,6 +2503,9 @@ private:
1620 Id pbk_flow_stack{}; 2503 Id pbk_flow_stack{};
1621 Id continue_label{}; 2504 Id continue_label{};
1622 std::map<u32, Id> labels; 2505 std::map<u32, Id> labels;
2506
2507 bool conditional_branch_set{};
2508 bool inside_branch{};
1623}; 2509};
1624 2510
1625class ExprDecompiler { 2511class ExprDecompiler {
@@ -1630,25 +2516,25 @@ public:
1630 const Id type_def = decomp.GetTypeDefinition(Type::Bool); 2516 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1631 const Id op1 = Visit(expr.operand1); 2517 const Id op1 = Visit(expr.operand1);
1632 const Id op2 = Visit(expr.operand2); 2518 const Id op2 = Visit(expr.operand2);
1633 return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2)); 2519 return decomp.OpLogicalAnd(type_def, op1, op2);
1634 } 2520 }
1635 2521
1636 Id operator()(const ExprOr& expr) { 2522 Id operator()(const ExprOr& expr) {
1637 const Id type_def = decomp.GetTypeDefinition(Type::Bool); 2523 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1638 const Id op1 = Visit(expr.operand1); 2524 const Id op1 = Visit(expr.operand1);
1639 const Id op2 = Visit(expr.operand2); 2525 const Id op2 = Visit(expr.operand2);
1640 return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2)); 2526 return decomp.OpLogicalOr(type_def, op1, op2);
1641 } 2527 }
1642 2528
1643 Id operator()(const ExprNot& expr) { 2529 Id operator()(const ExprNot& expr) {
1644 const Id type_def = decomp.GetTypeDefinition(Type::Bool); 2530 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1645 const Id op1 = Visit(expr.operand1); 2531 const Id op1 = Visit(expr.operand1);
1646 return decomp.Emit(decomp.OpLogicalNot(type_def, op1)); 2532 return decomp.OpLogicalNot(type_def, op1);
1647 } 2533 }
1648 2534
1649 Id operator()(const ExprPredicate& expr) { 2535 Id operator()(const ExprPredicate& expr) {
1650 const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate); 2536 const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
1651 return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred))); 2537 return decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred));
1652 } 2538 }
1653 2539
1654 Id operator()(const ExprCondCode& expr) { 2540 Id operator()(const ExprCondCode& expr) {
@@ -1670,12 +2556,15 @@ public:
1670 } 2556 }
1671 } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) { 2557 } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
1672 target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag())); 2558 target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
2559 } else {
2560 UNREACHABLE();
1673 } 2561 }
1674 return decomp.Emit(decomp.OpLoad(decomp.t_bool, target)); 2562
2563 return decomp.OpLoad(decomp.t_bool, target);
1675 } 2564 }
1676 2565
1677 Id operator()(const ExprVar& expr) { 2566 Id operator()(const ExprVar& expr) {
1678 return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index))); 2567 return decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index));
1679 } 2568 }
1680 2569
1681 Id operator()(const ExprBoolean& expr) { 2570 Id operator()(const ExprBoolean& expr) {
@@ -1684,9 +2573,9 @@ public:
1684 2573
1685 Id operator()(const ExprGprEqual& expr) { 2574 Id operator()(const ExprGprEqual& expr) {
1686 const Id target = decomp.Constant(decomp.t_uint, expr.value); 2575 const Id target = decomp.Constant(decomp.t_uint, expr.value);
1687 const Id gpr = decomp.BitcastTo<Type::Uint>( 2576 Id gpr = decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr));
1688 decomp.Emit(decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr)))); 2577 gpr = decomp.OpBitcast(decomp.t_uint, gpr);
1689 return decomp.Emit(decomp.OpLogicalEqual(decomp.t_uint, gpr, target)); 2578 return decomp.OpLogicalEqual(decomp.t_uint, gpr, target);
1690 } 2579 }
1691 2580
1692 Id Visit(const Expr& node) { 2581 Id Visit(const Expr& node) {
@@ -1714,16 +2603,16 @@ public:
1714 const Id condition = expr_parser.Visit(ast.condition); 2603 const Id condition = expr_parser.Visit(ast.condition);
1715 const Id then_label = decomp.OpLabel(); 2604 const Id then_label = decomp.OpLabel();
1716 const Id endif_label = decomp.OpLabel(); 2605 const Id endif_label = decomp.OpLabel();
1717 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); 2606 decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
1718 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); 2607 decomp.OpBranchConditional(condition, then_label, endif_label);
1719 decomp.Emit(then_label); 2608 decomp.AddLabel(then_label);
1720 ASTNode current = ast.nodes.GetFirst(); 2609 ASTNode current = ast.nodes.GetFirst();
1721 while (current) { 2610 while (current) {
1722 Visit(current); 2611 Visit(current);
1723 current = current->GetNext(); 2612 current = current->GetNext();
1724 } 2613 }
1725 decomp.Emit(decomp.OpBranch(endif_label)); 2614 decomp.OpBranch(endif_label);
1726 decomp.Emit(endif_label); 2615 decomp.AddLabel(endif_label);
1727 } 2616 }
1728 2617
1729 void operator()([[maybe_unused]] const ASTIfElse& ast) { 2618 void operator()([[maybe_unused]] const ASTIfElse& ast) {
@@ -1741,7 +2630,7 @@ public:
1741 void operator()(const ASTVarSet& ast) { 2630 void operator()(const ASTVarSet& ast) {
1742 ExprDecompiler expr_parser{decomp}; 2631 ExprDecompiler expr_parser{decomp};
1743 const Id condition = expr_parser.Visit(ast.condition); 2632 const Id condition = expr_parser.Visit(ast.condition);
1744 decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition)); 2633 decomp.OpStore(decomp.flow_variables.at(ast.index), condition);
1745 } 2634 }
1746 2635
1747 void operator()([[maybe_unused]] const ASTLabel& ast) { 2636 void operator()([[maybe_unused]] const ASTLabel& ast) {
@@ -1758,12 +2647,11 @@ public:
1758 const Id loop_start_block = decomp.OpLabel(); 2647 const Id loop_start_block = decomp.OpLabel();
1759 const Id loop_end_block = decomp.OpLabel(); 2648 const Id loop_end_block = decomp.OpLabel();
1760 current_loop_exit = endloop_label; 2649 current_loop_exit = endloop_label;
1761 decomp.Emit(decomp.OpBranch(loop_label)); 2650 decomp.OpBranch(loop_label);
1762 decomp.Emit(loop_label); 2651 decomp.AddLabel(loop_label);
1763 decomp.Emit( 2652 decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone);
1764 decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone)); 2653 decomp.OpBranch(loop_start_block);
1765 decomp.Emit(decomp.OpBranch(loop_start_block)); 2654 decomp.AddLabel(loop_start_block);
1766 decomp.Emit(loop_start_block);
1767 ASTNode current = ast.nodes.GetFirst(); 2655 ASTNode current = ast.nodes.GetFirst();
1768 while (current) { 2656 while (current) {
1769 Visit(current); 2657 Visit(current);
@@ -1771,8 +2659,8 @@ public:
1771 } 2659 }
1772 ExprDecompiler expr_parser{decomp}; 2660 ExprDecompiler expr_parser{decomp};
1773 const Id condition = expr_parser.Visit(ast.condition); 2661 const Id condition = expr_parser.Visit(ast.condition);
1774 decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label)); 2662 decomp.OpBranchConditional(condition, loop_label, endloop_label);
1775 decomp.Emit(endloop_label); 2663 decomp.AddLabel(endloop_label);
1776 } 2664 }
1777 2665
1778 void operator()(const ASTReturn& ast) { 2666 void operator()(const ASTReturn& ast) {
@@ -1781,27 +2669,27 @@ public:
1781 const Id condition = expr_parser.Visit(ast.condition); 2669 const Id condition = expr_parser.Visit(ast.condition);
1782 const Id then_label = decomp.OpLabel(); 2670 const Id then_label = decomp.OpLabel();
1783 const Id endif_label = decomp.OpLabel(); 2671 const Id endif_label = decomp.OpLabel();
1784 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); 2672 decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
1785 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); 2673 decomp.OpBranchConditional(condition, then_label, endif_label);
1786 decomp.Emit(then_label); 2674 decomp.AddLabel(then_label);
1787 if (ast.kills) { 2675 if (ast.kills) {
1788 decomp.Emit(decomp.OpKill()); 2676 decomp.OpKill();
1789 } else { 2677 } else {
1790 decomp.PreExit(); 2678 decomp.PreExit();
1791 decomp.Emit(decomp.OpReturn()); 2679 decomp.OpReturn();
1792 } 2680 }
1793 decomp.Emit(endif_label); 2681 decomp.AddLabel(endif_label);
1794 } else { 2682 } else {
1795 const Id next_block = decomp.OpLabel(); 2683 const Id next_block = decomp.OpLabel();
1796 decomp.Emit(decomp.OpBranch(next_block)); 2684 decomp.OpBranch(next_block);
1797 decomp.Emit(next_block); 2685 decomp.AddLabel(next_block);
1798 if (ast.kills) { 2686 if (ast.kills) {
1799 decomp.Emit(decomp.OpKill()); 2687 decomp.OpKill();
1800 } else { 2688 } else {
1801 decomp.PreExit(); 2689 decomp.PreExit();
1802 decomp.Emit(decomp.OpReturn()); 2690 decomp.OpReturn();
1803 } 2691 }
1804 decomp.Emit(decomp.OpLabel()); 2692 decomp.AddLabel(decomp.OpLabel());
1805 } 2693 }
1806 } 2694 }
1807 2695
@@ -1811,17 +2699,17 @@ public:
1811 const Id condition = expr_parser.Visit(ast.condition); 2699 const Id condition = expr_parser.Visit(ast.condition);
1812 const Id then_label = decomp.OpLabel(); 2700 const Id then_label = decomp.OpLabel();
1813 const Id endif_label = decomp.OpLabel(); 2701 const Id endif_label = decomp.OpLabel();
1814 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); 2702 decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
1815 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); 2703 decomp.OpBranchConditional(condition, then_label, endif_label);
1816 decomp.Emit(then_label); 2704 decomp.AddLabel(then_label);
1817 decomp.Emit(decomp.OpBranch(current_loop_exit)); 2705 decomp.OpBranch(current_loop_exit);
1818 decomp.Emit(endif_label); 2706 decomp.AddLabel(endif_label);
1819 } else { 2707 } else {
1820 const Id next_block = decomp.OpLabel(); 2708 const Id next_block = decomp.OpLabel();
1821 decomp.Emit(decomp.OpBranch(next_block)); 2709 decomp.OpBranch(next_block);
1822 decomp.Emit(next_block); 2710 decomp.AddLabel(next_block);
1823 decomp.Emit(decomp.OpBranch(current_loop_exit)); 2711 decomp.OpBranch(current_loop_exit);
1824 decomp.Emit(decomp.OpLabel()); 2712 decomp.AddLabel(decomp.OpLabel());
1825 } 2713 }
1826 } 2714 }
1827 2715
@@ -1842,20 +2730,51 @@ void SPIRVDecompiler::DecompileAST() {
1842 flow_variables.emplace(i, AddGlobalVariable(id)); 2730 flow_variables.emplace(i, AddGlobalVariable(id));
1843 } 2731 }
1844 2732
2733 DefinePrologue();
2734
1845 const ASTNode program = ir.GetASTProgram(); 2735 const ASTNode program = ir.GetASTProgram();
1846 ASTDecompiler decompiler{*this}; 2736 ASTDecompiler decompiler{*this};
1847 decompiler.Visit(program); 2737 decompiler.Visit(program);
1848 2738
1849 const Id next_block = OpLabel(); 2739 const Id next_block = OpLabel();
1850 Emit(OpBranch(next_block)); 2740 OpBranch(next_block);
1851 Emit(next_block); 2741 AddLabel(next_block);
2742}
2743
2744} // Anonymous namespace
2745
2746ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir) {
2747 ShaderEntries entries;
2748 for (const auto& cbuf : ir.GetConstantBuffers()) {
2749 entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
2750 }
2751 for (const auto& [base, usage] : ir.GetGlobalMemory()) {
2752 entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset, usage.is_written);
2753 }
2754 for (const auto& sampler : ir.GetSamplers()) {
2755 if (sampler.IsBuffer()) {
2756 entries.texel_buffers.emplace_back(sampler);
2757 } else {
2758 entries.samplers.emplace_back(sampler);
2759 }
2760 }
2761 for (const auto& image : ir.GetImages()) {
2762 entries.images.emplace_back(image);
2763 }
2764 for (const auto& attribute : ir.GetInputAttributes()) {
2765 if (IsGenericAttribute(attribute)) {
2766 entries.attributes.insert(GetGenericAttributeLocation(attribute));
2767 }
2768 }
2769 entries.clip_distances = ir.GetClipDistances();
2770 entries.shader_length = ir.GetLength();
2771 entries.uses_warps = ir.UsesWarps();
2772 return entries;
1852} 2773}
1853 2774
1854DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, 2775std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
1855 ShaderType stage) { 2776 ShaderType stage, const Specialization& specialization) {
1856 auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage); 2777 return SPIRVDecompiler(device, ir, stage, specialization).Assemble();
1857 decompiler->Decompile();
1858 return {std::move(decompiler), decompiler->GetShaderEntries()};
1859} 2778}
1860 2779
1861} // namespace Vulkan::VKShader 2780} // namespace Vulkan
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.h b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
index 203fc00d0..2b01321b6 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.h
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
@@ -5,29 +5,28 @@
5#pragma once 5#pragma once
6 6
7#include <array> 7#include <array>
8#include <bitset>
8#include <memory> 9#include <memory>
9#include <set> 10#include <set>
11#include <type_traits>
10#include <utility> 12#include <utility>
11#include <vector> 13#include <vector>
12 14
13#include <sirit/sirit.h>
14
15#include "common/common_types.h" 15#include "common/common_types.h"
16#include "video_core/engines/maxwell_3d.h" 16#include "video_core/engines/maxwell_3d.h"
17#include "video_core/engines/shader_type.h"
17#include "video_core/shader/shader_ir.h" 18#include "video_core/shader/shader_ir.h"
18 19
19namespace VideoCommon::Shader {
20class ShaderIR;
21}
22
23namespace Vulkan { 20namespace Vulkan {
24class VKDevice; 21class VKDevice;
25} 22}
26 23
27namespace Vulkan::VKShader { 24namespace Vulkan {
28 25
29using Maxwell = Tegra::Engines::Maxwell3D::Regs; 26using Maxwell = Tegra::Engines::Maxwell3D::Regs;
27using TexelBufferEntry = VideoCommon::Shader::Sampler;
30using SamplerEntry = VideoCommon::Shader::Sampler; 28using SamplerEntry = VideoCommon::Shader::Sampler;
29using ImageEntry = VideoCommon::Shader::Image;
31 30
32constexpr u32 DESCRIPTOR_SET = 0; 31constexpr u32 DESCRIPTOR_SET = 0;
33 32
@@ -46,39 +45,74 @@ private:
46 45
47class GlobalBufferEntry { 46class GlobalBufferEntry {
48public: 47public:
49 explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset) 48 constexpr explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset, bool is_written)
50 : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset} {} 49 : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset}, is_written{is_written} {}
51 50
52 u32 GetCbufIndex() const { 51 constexpr u32 GetCbufIndex() const {
53 return cbuf_index; 52 return cbuf_index;
54 } 53 }
55 54
56 u32 GetCbufOffset() const { 55 constexpr u32 GetCbufOffset() const {
57 return cbuf_offset; 56 return cbuf_offset;
58 } 57 }
59 58
59 constexpr bool IsWritten() const {
60 return is_written;
61 }
62
60private: 63private:
61 u32 cbuf_index{}; 64 u32 cbuf_index{};
62 u32 cbuf_offset{}; 65 u32 cbuf_offset{};
66 bool is_written{};
63}; 67};
64 68
65struct ShaderEntries { 69struct ShaderEntries {
66 u32 const_buffers_base_binding{}; 70 u32 NumBindings() const {
67 u32 global_buffers_base_binding{}; 71 return static_cast<u32>(const_buffers.size() + global_buffers.size() +
68 u32 samplers_base_binding{}; 72 texel_buffers.size() + samplers.size() + images.size());
73 }
74
69 std::vector<ConstBufferEntry> const_buffers; 75 std::vector<ConstBufferEntry> const_buffers;
70 std::vector<GlobalBufferEntry> global_buffers; 76 std::vector<GlobalBufferEntry> global_buffers;
77 std::vector<TexelBufferEntry> texel_buffers;
71 std::vector<SamplerEntry> samplers; 78 std::vector<SamplerEntry> samplers;
79 std::vector<ImageEntry> images;
72 std::set<u32> attributes; 80 std::set<u32> attributes;
73 std::array<bool, Maxwell::NumClipDistances> clip_distances{}; 81 std::array<bool, Maxwell::NumClipDistances> clip_distances{};
74 std::size_t shader_length{}; 82 std::size_t shader_length{};
75 Sirit::Id entry_function{}; 83 bool uses_warps{};
76 std::vector<Sirit::Id> interfaces; 84};
85
86struct Specialization final {
87 u32 base_binding{};
88
89 // Compute specific
90 std::array<u32, 3> workgroup_size{};
91 u32 shared_memory_size{};
92
93 // Graphics specific
94 Maxwell::PrimitiveTopology primitive_topology{};
95 std::optional<float> point_size{};
96 std::array<Maxwell::VertexAttribute::Type, Maxwell::NumVertexAttributes> attribute_types{};
97
98 // Tessellation specific
99 struct {
100 Maxwell::TessellationPrimitive primitive{};
101 Maxwell::TessellationSpacing spacing{};
102 bool clockwise{};
103 } tessellation;
104};
105// Old gcc versions don't consider this trivially copyable.
106// static_assert(std::is_trivially_copyable_v<Specialization>);
107
108struct SPIRVShader {
109 std::vector<u32> code;
110 ShaderEntries entries;
77}; 111};
78 112
79using DecompilerResult = std::pair<std::unique_ptr<Sirit::Module>, ShaderEntries>; 113ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir);
80 114
81DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, 115std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
82 Tegra::Engines::ShaderType stage); 116 Tegra::Engines::ShaderType stage, const Specialization& specialization);
83 117
84} // namespace Vulkan::VKShader 118} // namespace Vulkan
diff --git a/src/video_core/shader/decode/memory.cpp b/src/video_core/shader/decode/memory.cpp
index 335d78146..78e92f52e 100644
--- a/src/video_core/shader/decode/memory.cpp
+++ b/src/video_core/shader/decode/memory.cpp
@@ -21,6 +21,7 @@ using Tegra::Shader::OpCode;
21using Tegra::Shader::Register; 21using Tegra::Shader::Register;
22 22
23namespace { 23namespace {
24
24u32 GetUniformTypeElementsCount(Tegra::Shader::UniformType uniform_type) { 25u32 GetUniformTypeElementsCount(Tegra::Shader::UniformType uniform_type) {
25 switch (uniform_type) { 26 switch (uniform_type) {
26 case Tegra::Shader::UniformType::Single: 27 case Tegra::Shader::UniformType::Single:
@@ -35,6 +36,7 @@ u32 GetUniformTypeElementsCount(Tegra::Shader::UniformType uniform_type) {
35 return 1; 36 return 1;
36 } 37 }
37} 38}
39
38} // Anonymous namespace 40} // Anonymous namespace
39 41
40u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) { 42u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
@@ -196,28 +198,28 @@ u32 ShaderIR::DecodeMemory(NodeBlock& bb, u32 pc) {
196 UNIMPLEMENTED_IF_MSG((instr.attribute.fmt20.immediate.Value() % sizeof(u32)) != 0, 198 UNIMPLEMENTED_IF_MSG((instr.attribute.fmt20.immediate.Value() % sizeof(u32)) != 0,
197 "Unaligned attribute loads are not supported"); 199 "Unaligned attribute loads are not supported");
198 200
199 u64 next_element = instr.attribute.fmt20.element; 201 u64 element = instr.attribute.fmt20.element;
200 auto next_index = static_cast<u64>(instr.attribute.fmt20.index.Value()); 202 auto index = static_cast<u64>(instr.attribute.fmt20.index.Value());
201 203
202 const auto StoreNextElement = [&](u32 reg_offset) { 204 const u32 num_words = static_cast<u32>(instr.attribute.fmt20.size.Value()) + 1;
203 const auto dest = GetOutputAttribute(static_cast<Attribute::Index>(next_index), 205 for (u32 reg_offset = 0; reg_offset < num_words; ++reg_offset) {
204 next_element, GetRegister(instr.gpr39)); 206 Node dest;
207 if (instr.attribute.fmt20.patch) {
208 const u32 offset = static_cast<u32>(index) * 4 + static_cast<u32>(element);
209 dest = MakeNode<PatchNode>(offset);
210 } else {
211 dest = GetOutputAttribute(static_cast<Attribute::Index>(index), element,
212 GetRegister(instr.gpr39));
213 }
205 const auto src = GetRegister(instr.gpr0.Value() + reg_offset); 214 const auto src = GetRegister(instr.gpr0.Value() + reg_offset);
206 215
207 bb.push_back(Operation(OperationCode::Assign, dest, src)); 216 bb.push_back(Operation(OperationCode::Assign, dest, src));
208 217
209 // Load the next attribute element into the following register. If the element 218 // Load the next attribute element into the following register. If the element to load
210 // to load goes beyond the vec4 size, load the first element of the next 219 // goes beyond the vec4 size, load the first element of the next attribute.
211 // attribute. 220 element = (element + 1) % 4;
212 next_element = (next_element + 1) % 4; 221 index = index + (element == 0 ? 1 : 0);
213 next_index = next_index + (next_element == 0 ? 1 : 0);
214 };
215
216 const u32 num_words = static_cast<u32>(instr.attribute.fmt20.size.Value()) + 1;
217 for (u32 reg_offset = 0; reg_offset < num_words; ++reg_offset) {
218 StoreNextElement(reg_offset);
219 } 222 }
220
221 break; 223 break;
222 } 224 }
223 case OpCode::Id::ST_L: 225 case OpCode::Id::ST_L:
diff --git a/src/video_core/shader/decode/other.cpp b/src/video_core/shader/decode/other.cpp
index 17cd45d3c..5c802886b 100644
--- a/src/video_core/shader/decode/other.cpp
+++ b/src/video_core/shader/decode/other.cpp
@@ -69,6 +69,8 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) {
69 case OpCode::Id::MOV_SYS: { 69 case OpCode::Id::MOV_SYS: {
70 const Node value = [this, instr] { 70 const Node value = [this, instr] {
71 switch (instr.sys20) { 71 switch (instr.sys20) {
72 case SystemVariable::InvocationId:
73 return Operation(OperationCode::InvocationId);
72 case SystemVariable::Ydirection: 74 case SystemVariable::Ydirection:
73 return Operation(OperationCode::YNegate); 75 return Operation(OperationCode::YNegate);
74 case SystemVariable::InvocationInfo: 76 case SystemVariable::InvocationInfo:
diff --git a/src/video_core/shader/decode/warp.cpp b/src/video_core/shader/decode/warp.cpp
index d98d0e1dd..11b77f795 100644
--- a/src/video_core/shader/decode/warp.cpp
+++ b/src/video_core/shader/decode/warp.cpp
@@ -38,6 +38,9 @@ u32 ShaderIR::DecodeWarp(NodeBlock& bb, u32 pc) {
38 const Instruction instr = {program_code[pc]}; 38 const Instruction instr = {program_code[pc]};
39 const auto opcode = OpCode::Decode(instr); 39 const auto opcode = OpCode::Decode(instr);
40 40
41 // Signal the backend that this shader uses warp instructions.
42 uses_warps = true;
43
41 switch (opcode->get().GetId()) { 44 switch (opcode->get().GetId()) {
42 case OpCode::Id::VOTE: { 45 case OpCode::Id::VOTE: {
43 const Node value = GetPredicate(instr.vote.value, instr.vote.negate_value != 0); 46 const Node value = GetPredicate(instr.vote.value, instr.vote.negate_value != 0);
diff --git a/src/video_core/shader/node.h b/src/video_core/shader/node.h
index b2576bdd6..1a4d28ae9 100644
--- a/src/video_core/shader/node.h
+++ b/src/video_core/shader/node.h
@@ -172,6 +172,7 @@ enum class OperationCode {
172 EmitVertex, /// () -> void 172 EmitVertex, /// () -> void
173 EndPrimitive, /// () -> void 173 EndPrimitive, /// () -> void
174 174
175 InvocationId, /// () -> int
175 YNegate, /// () -> float 176 YNegate, /// () -> float
176 LocalInvocationIdX, /// () -> uint 177 LocalInvocationIdX, /// () -> uint
177 LocalInvocationIdY, /// () -> uint 178 LocalInvocationIdY, /// () -> uint
@@ -213,13 +214,14 @@ class PredicateNode;
213class AbufNode; 214class AbufNode;
214class CbufNode; 215class CbufNode;
215class LmemNode; 216class LmemNode;
217class PatchNode;
216class SmemNode; 218class SmemNode;
217class GmemNode; 219class GmemNode;
218class CommentNode; 220class CommentNode;
219 221
220using NodeData = 222using NodeData = std::variant<OperationNode, ConditionalNode, GprNode, ImmediateNode,
221 std::variant<OperationNode, ConditionalNode, GprNode, ImmediateNode, InternalFlagNode, 223 InternalFlagNode, PredicateNode, AbufNode, PatchNode, CbufNode,
222 PredicateNode, AbufNode, CbufNode, LmemNode, SmemNode, GmemNode, CommentNode>; 224 LmemNode, SmemNode, GmemNode, CommentNode>;
223using Node = std::shared_ptr<NodeData>; 225using Node = std::shared_ptr<NodeData>;
224using Node4 = std::array<Node, 4>; 226using Node4 = std::array<Node, 4>;
225using NodeBlock = std::vector<Node>; 227using NodeBlock = std::vector<Node>;
@@ -542,6 +544,19 @@ private:
542 u32 element{}; 544 u32 element{};
543}; 545};
544 546
547/// Patch memory (used to communicate tessellation stages).
548class PatchNode final {
549public:
550 explicit PatchNode(u32 offset) : offset{offset} {}
551
552 u32 GetOffset() const {
553 return offset;
554 }
555
556private:
557 u32 offset{};
558};
559
545/// Constant buffer node, usually mapped to uniform buffers in GLSL 560/// Constant buffer node, usually mapped to uniform buffers in GLSL
546class CbufNode final { 561class CbufNode final {
547public: 562public:
diff --git a/src/video_core/shader/shader_ir.h b/src/video_core/shader/shader_ir.h
index 2f71a50d2..580f84fcb 100644
--- a/src/video_core/shader/shader_ir.h
+++ b/src/video_core/shader/shader_ir.h
@@ -137,6 +137,10 @@ public:
137 return uses_vertex_id; 137 return uses_vertex_id;
138 } 138 }
139 139
140 bool UsesWarps() const {
141 return uses_warps;
142 }
143
140 bool HasPhysicalAttributes() const { 144 bool HasPhysicalAttributes() const {
141 return uses_physical_attributes; 145 return uses_physical_attributes;
142 } 146 }
@@ -415,6 +419,7 @@ private:
415 bool uses_physical_attributes{}; // Shader uses AL2P or physical attribute read/writes 419 bool uses_physical_attributes{}; // Shader uses AL2P or physical attribute read/writes
416 bool uses_instance_id{}; 420 bool uses_instance_id{};
417 bool uses_vertex_id{}; 421 bool uses_vertex_id{};
422 bool uses_warps{};
418 423
419 Tegra::Shader::Header header; 424 Tegra::Shader::Header header;
420}; 425};
diff --git a/src/video_core/shader/track.cpp b/src/video_core/shader/track.cpp
index 55f5949e4..165c79330 100644
--- a/src/video_core/shader/track.cpp
+++ b/src/video_core/shader/track.cpp
@@ -7,6 +7,7 @@
7#include <variant> 7#include <variant>
8 8
9#include "common/common_types.h" 9#include "common/common_types.h"
10#include "video_core/shader/node.h"
10#include "video_core/shader/shader_ir.h" 11#include "video_core/shader/shader_ir.h"
11 12
12namespace VideoCommon::Shader { 13namespace VideoCommon::Shader {