summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar ReinUsesLisp2019-12-09 23:44:29 -0300
committerGravatar ReinUsesLisp2019-12-09 23:51:57 -0300
commitecbfa416f0ab7e6c830e3df0d83158865de49ad8 (patch)
treee4fbd918616b6fe611162d0c7b07399835d1d9fa
parentshader: Keep track of shaders using warp instructions (diff)
downloadyuzu-ecbfa416f0ab7e6c830e3df0d83158865de49ad8.tar.gz
yuzu-ecbfa416f0ab7e6c830e3df0d83158865de49ad8.tar.xz
yuzu-ecbfa416f0ab7e6c830e3df0d83158865de49ad8.zip
vk_shader_decompiler: Misc changes
Update Sirit and its usage in vk_shader_decompiler. Highlights: - Implement tessellation shaders - Implement geometry shaders - Implement some missing features - Use native half float instructions when available.
m---------externals/sirit0
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp2271
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.h74
3 files changed, 1648 insertions, 697 deletions
diff --git a/externals/sirit b/externals/sirit
Subproject f7c4b07a7e14edb1dcd93bc9879c823423705c2 Subproject e1a6729df7f11e33f6dc0939b18995a57c8bf3d
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 76894275b..8f517bdc1 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -3,8 +3,10 @@
3// Refer to the license.txt file included. 3// Refer to the license.txt file included.
4 4
5#include <functional> 5#include <functional>
6#include <limits>
6#include <map> 7#include <map>
7#include <set> 8#include <type_traits>
9#include <utility>
8 10
9#include <fmt/format.h> 11#include <fmt/format.h>
10 12
@@ -23,7 +25,9 @@
23#include "video_core/shader/node.h" 25#include "video_core/shader/node.h"
24#include "video_core/shader/shader_ir.h" 26#include "video_core/shader/shader_ir.h"
25 27
26namespace Vulkan::VKShader { 28namespace Vulkan {
29
30namespace {
27 31
28using Sirit::Id; 32using Sirit::Id;
29using Tegra::Engines::ShaderType; 33using Tegra::Engines::ShaderType;
@@ -35,22 +39,60 @@ using namespace VideoCommon::Shader;
35using Maxwell = Tegra::Engines::Maxwell3D::Regs; 39using Maxwell = Tegra::Engines::Maxwell3D::Regs;
36using Operation = const OperationNode&; 40using Operation = const OperationNode&;
37 41
42class ASTDecompiler;
43class ExprDecompiler;
44
38// TODO(Rodrigo): Use rasterizer's value 45// TODO(Rodrigo): Use rasterizer's value
39constexpr u32 MAX_CONSTBUFFER_FLOATS = 0x4000; 46constexpr u32 MaxConstBufferFloats = 0x4000;
40constexpr u32 MAX_CONSTBUFFER_ELEMENTS = MAX_CONSTBUFFER_FLOATS / 4; 47constexpr u32 MaxConstBufferElements = MaxConstBufferFloats / 4;
41constexpr u32 STAGE_BINDING_STRIDE = 0x100; 48
49constexpr u32 NumInputPatches = 32; // This value seems to be the standard
50
51enum class Type { Void, Bool, Bool2, Float, Int, Uint, HalfFloat };
52
53class Expression final {
54public:
55 Expression(Id id, Type type) : id{id}, type{type} {
56 ASSERT(type != Type::Void);
57 }
58 Expression() : type{Type::Void} {}
42 59
43enum class Type { Bool, Bool2, Float, Int, Uint, HalfFloat }; 60 Id id{};
61 Type type{};
62};
63static_assert(std::is_standard_layout_v<Expression>);
44 64
45struct SamplerImage { 65struct TexelBuffer {
46 Id image_type; 66 Id image_type{};
47 Id sampled_image_type; 67 Id image{};
48 Id sampler;
49}; 68};
50 69
51namespace { 70struct SampledImage {
71 Id image_type{};
72 Id sampled_image_type{};
73 Id sampler{};
74};
75
76struct StorageImage {
77 Id image_type{};
78 Id image{};
79};
80
81struct AttributeType {
82 Type type;
83 Id scalar;
84 Id vector;
85};
86
87struct VertexIndices {
88 std::optional<u32> position;
89 std::optional<u32> viewport;
90 std::optional<u32> point_size;
91 std::optional<u32> clip_distances;
92};
52 93
53spv::Dim GetSamplerDim(const Sampler& sampler) { 94spv::Dim GetSamplerDim(const Sampler& sampler) {
95 ASSERT(!sampler.IsBuffer());
54 switch (sampler.GetType()) { 96 switch (sampler.GetType()) {
55 case Tegra::Shader::TextureType::Texture1D: 97 case Tegra::Shader::TextureType::Texture1D:
56 return spv::Dim::Dim1D; 98 return spv::Dim::Dim1D;
@@ -66,6 +108,138 @@ spv::Dim GetSamplerDim(const Sampler& sampler) {
66 } 108 }
67} 109}
68 110
111std::pair<spv::Dim, bool> GetImageDim(const Image& image) {
112 switch (image.GetType()) {
113 case Tegra::Shader::ImageType::Texture1D:
114 return {spv::Dim::Dim1D, false};
115 case Tegra::Shader::ImageType::TextureBuffer:
116 return {spv::Dim::Buffer, false};
117 case Tegra::Shader::ImageType::Texture1DArray:
118 return {spv::Dim::Dim1D, true};
119 case Tegra::Shader::ImageType::Texture2D:
120 return {spv::Dim::Dim2D, false};
121 case Tegra::Shader::ImageType::Texture2DArray:
122 return {spv::Dim::Dim2D, true};
123 case Tegra::Shader::ImageType::Texture3D:
124 return {spv::Dim::Dim3D, false};
125 default:
126 UNIMPLEMENTED_MSG("Unimplemented image type={}", static_cast<u32>(image.GetType()));
127 return {spv::Dim::Dim2D, false};
128 }
129}
130
131/// Returns the number of vertices present in a primitive topology.
132u32 GetNumPrimitiveTopologyVertices(Maxwell::PrimitiveTopology primitive_topology) {
133 switch (primitive_topology) {
134 case Maxwell::PrimitiveTopology::Points:
135 return 1;
136 case Maxwell::PrimitiveTopology::Lines:
137 case Maxwell::PrimitiveTopology::LineLoop:
138 case Maxwell::PrimitiveTopology::LineStrip:
139 return 2;
140 case Maxwell::PrimitiveTopology::Triangles:
141 case Maxwell::PrimitiveTopology::TriangleStrip:
142 case Maxwell::PrimitiveTopology::TriangleFan:
143 return 3;
144 case Maxwell::PrimitiveTopology::LinesAdjacency:
145 case Maxwell::PrimitiveTopology::LineStripAdjacency:
146 return 4;
147 case Maxwell::PrimitiveTopology::TrianglesAdjacency:
148 case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
149 return 6;
150 case Maxwell::PrimitiveTopology::Quads:
151 UNIMPLEMENTED_MSG("Quads");
152 return 3;
153 case Maxwell::PrimitiveTopology::QuadStrip:
154 UNIMPLEMENTED_MSG("QuadStrip");
155 return 3;
156 case Maxwell::PrimitiveTopology::Polygon:
157 UNIMPLEMENTED_MSG("Polygon");
158 return 3;
159 case Maxwell::PrimitiveTopology::Patches:
160 UNIMPLEMENTED_MSG("Patches");
161 return 3;
162 default:
163 UNREACHABLE();
164 return 3;
165 }
166}
167
168spv::ExecutionMode GetExecutionMode(Maxwell::TessellationPrimitive primitive) {
169 switch (primitive) {
170 case Maxwell::TessellationPrimitive::Isolines:
171 return spv::ExecutionMode::Isolines;
172 case Maxwell::TessellationPrimitive::Triangles:
173 return spv::ExecutionMode::Triangles;
174 case Maxwell::TessellationPrimitive::Quads:
175 return spv::ExecutionMode::Quads;
176 }
177 UNREACHABLE();
178 return spv::ExecutionMode::Triangles;
179}
180
181spv::ExecutionMode GetExecutionMode(Maxwell::TessellationSpacing spacing) {
182 switch (spacing) {
183 case Maxwell::TessellationSpacing::Equal:
184 return spv::ExecutionMode::SpacingEqual;
185 case Maxwell::TessellationSpacing::FractionalOdd:
186 return spv::ExecutionMode::SpacingFractionalOdd;
187 case Maxwell::TessellationSpacing::FractionalEven:
188 return spv::ExecutionMode::SpacingFractionalEven;
189 }
190 UNREACHABLE();
191 return spv::ExecutionMode::SpacingEqual;
192}
193
194spv::ExecutionMode GetExecutionMode(Maxwell::PrimitiveTopology input_topology) {
195 switch (input_topology) {
196 case Maxwell::PrimitiveTopology::Points:
197 return spv::ExecutionMode::InputPoints;
198 case Maxwell::PrimitiveTopology::Lines:
199 case Maxwell::PrimitiveTopology::LineLoop:
200 case Maxwell::PrimitiveTopology::LineStrip:
201 return spv::ExecutionMode::InputLines;
202 case Maxwell::PrimitiveTopology::Triangles:
203 case Maxwell::PrimitiveTopology::TriangleStrip:
204 case Maxwell::PrimitiveTopology::TriangleFan:
205 return spv::ExecutionMode::Triangles;
206 case Maxwell::PrimitiveTopology::LinesAdjacency:
207 case Maxwell::PrimitiveTopology::LineStripAdjacency:
208 return spv::ExecutionMode::InputLinesAdjacency;
209 case Maxwell::PrimitiveTopology::TrianglesAdjacency:
210 case Maxwell::PrimitiveTopology::TriangleStripAdjacency:
211 return spv::ExecutionMode::InputTrianglesAdjacency;
212 case Maxwell::PrimitiveTopology::Quads:
213 UNIMPLEMENTED_MSG("Quads");
214 return spv::ExecutionMode::Triangles;
215 case Maxwell::PrimitiveTopology::QuadStrip:
216 UNIMPLEMENTED_MSG("QuadStrip");
217 return spv::ExecutionMode::Triangles;
218 case Maxwell::PrimitiveTopology::Polygon:
219 UNIMPLEMENTED_MSG("Polygon");
220 return spv::ExecutionMode::Triangles;
221 case Maxwell::PrimitiveTopology::Patches:
222 UNIMPLEMENTED_MSG("Patches");
223 return spv::ExecutionMode::Triangles;
224 }
225 UNREACHABLE();
226 return spv::ExecutionMode::Triangles;
227}
228
229spv::ExecutionMode GetExecutionMode(Tegra::Shader::OutputTopology output_topology) {
230 switch (output_topology) {
231 case Tegra::Shader::OutputTopology::PointList:
232 return spv::ExecutionMode::OutputPoints;
233 case Tegra::Shader::OutputTopology::LineStrip:
234 return spv::ExecutionMode::OutputLineStrip;
235 case Tegra::Shader::OutputTopology::TriangleStrip:
236 return spv::ExecutionMode::OutputTriangleStrip;
237 default:
238 UNREACHABLE();
239 return spv::ExecutionMode::OutputPoints;
240 }
241}
242
69/// Returns true if an attribute index is one of the 32 generic attributes 243/// Returns true if an attribute index is one of the 32 generic attributes
70constexpr bool IsGenericAttribute(Attribute::Index attribute) { 244constexpr bool IsGenericAttribute(Attribute::Index attribute) {
71 return attribute >= Attribute::Index::Attribute_0 && 245 return attribute >= Attribute::Index::Attribute_0 &&
@@ -73,7 +247,7 @@ constexpr bool IsGenericAttribute(Attribute::Index attribute) {
73} 247}
74 248
75/// Returns the location of a generic attribute 249/// Returns the location of a generic attribute
76constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) { 250u32 GetGenericAttributeLocation(Attribute::Index attribute) {
77 ASSERT(IsGenericAttribute(attribute)); 251 ASSERT(IsGenericAttribute(attribute));
78 return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0); 252 return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0);
79} 253}
@@ -87,20 +261,146 @@ bool IsPrecise(Operation operand) {
87 return false; 261 return false;
88} 262}
89 263
90} // namespace 264class SPIRVDecompiler final : public Sirit::Module {
91
92class ASTDecompiler;
93class ExprDecompiler;
94
95class SPIRVDecompiler : public Sirit::Module {
96public: 265public:
97 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage) 266 explicit SPIRVDecompiler(const VKDevice& device, const ShaderIR& ir, ShaderType stage,
98 : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()} { 267 const Specialization& specialization)
268 : Module(0x00010300), device{device}, ir{ir}, stage{stage}, header{ir.GetHeader()},
269 specialization{specialization} {
99 AddCapability(spv::Capability::Shader); 270 AddCapability(spv::Capability::Shader);
271 AddCapability(spv::Capability::UniformAndStorageBuffer16BitAccess);
272 AddCapability(spv::Capability::ImageQuery);
273 AddCapability(spv::Capability::Image1D);
274 AddCapability(spv::Capability::ImageBuffer);
275 AddCapability(spv::Capability::ImageGatherExtended);
276 AddCapability(spv::Capability::SampledBuffer);
277 AddCapability(spv::Capability::StorageImageWriteWithoutFormat);
278 AddCapability(spv::Capability::SubgroupBallotKHR);
279 AddCapability(spv::Capability::SubgroupVoteKHR);
280 AddExtension("SPV_KHR_shader_ballot");
281 AddExtension("SPV_KHR_subgroup_vote");
100 AddExtension("SPV_KHR_storage_buffer_storage_class"); 282 AddExtension("SPV_KHR_storage_buffer_storage_class");
101 AddExtension("SPV_KHR_variable_pointers"); 283 AddExtension("SPV_KHR_variable_pointers");
284
285 if (ir.UsesViewportIndex()) {
286 AddCapability(spv::Capability::MultiViewport);
287 if (device.IsExtShaderViewportIndexLayerSupported()) {
288 AddExtension("SPV_EXT_shader_viewport_index_layer");
289 AddCapability(spv::Capability::ShaderViewportIndexLayerEXT);
290 }
291 }
292
293 if (device.IsFloat16Supported()) {
294 AddCapability(spv::Capability::Float16);
295 }
296 t_scalar_half = Name(TypeFloat(device.IsFloat16Supported() ? 16 : 32), "scalar_half");
297 t_half = Name(TypeVector(t_scalar_half, 2), "half");
298
299 const Id main = Decompile();
300
301 switch (stage) {
302 case ShaderType::Vertex:
303 AddEntryPoint(spv::ExecutionModel::Vertex, main, "main", interfaces);
304 break;
305 case ShaderType::TesselationControl:
306 AddCapability(spv::Capability::Tessellation);
307 AddEntryPoint(spv::ExecutionModel::TessellationControl, main, "main", interfaces);
308 AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
309 header.common2.threads_per_input_primitive);
310 break;
311 case ShaderType::TesselationEval:
312 AddCapability(spv::Capability::Tessellation);
313 AddEntryPoint(spv::ExecutionModel::TessellationEvaluation, main, "main", interfaces);
314 AddExecutionMode(main, GetExecutionMode(specialization.tessellation.primitive));
315 AddExecutionMode(main, GetExecutionMode(specialization.tessellation.spacing));
316 AddExecutionMode(main, specialization.tessellation.clockwise
317 ? spv::ExecutionMode::VertexOrderCw
318 : spv::ExecutionMode::VertexOrderCcw);
319 break;
320 case ShaderType::Geometry:
321 AddCapability(spv::Capability::Geometry);
322 AddEntryPoint(spv::ExecutionModel::Geometry, main, "main", interfaces);
323 AddExecutionMode(main, GetExecutionMode(specialization.primitive_topology));
324 AddExecutionMode(main, GetExecutionMode(header.common3.output_topology));
325 AddExecutionMode(main, spv::ExecutionMode::OutputVertices,
326 header.common4.max_output_vertices);
327 // TODO(Rodrigo): Where can we get this info from?
328 AddExecutionMode(main, spv::ExecutionMode::Invocations, 1U);
329 break;
330 case ShaderType::Fragment:
331 AddEntryPoint(spv::ExecutionModel::Fragment, main, "main", interfaces);
332 AddExecutionMode(main, spv::ExecutionMode::OriginUpperLeft);
333 if (header.ps.omap.depth) {
334 AddExecutionMode(main, spv::ExecutionMode::DepthReplacing);
335 }
336 break;
337 case ShaderType::Compute:
338 const auto workgroup_size = specialization.workgroup_size;
339 AddExecutionMode(main, spv::ExecutionMode::LocalSize, workgroup_size[0],
340 workgroup_size[1], workgroup_size[2]);
341 AddEntryPoint(spv::ExecutionModel::GLCompute, main, "main", interfaces);
342 break;
343 }
102 } 344 }
103 345
346private:
347 Id Decompile() {
348 DeclareCommon();
349 DeclareVertex();
350 DeclareTessControl();
351 DeclareTessEval();
352 DeclareGeometry();
353 DeclareFragment();
354 DeclareCompute();
355 DeclareRegisters();
356 DeclarePredicates();
357 DeclareLocalMemory();
358 DeclareSharedMemory();
359 DeclareInternalFlags();
360 DeclareInputAttributes();
361 DeclareOutputAttributes();
362
363 u32 binding = specialization.base_binding;
364 binding = DeclareConstantBuffers(binding);
365 binding = DeclareGlobalBuffers(binding);
366 binding = DeclareTexelBuffers(binding);
367 binding = DeclareSamplers(binding);
368 binding = DeclareImages(binding);
369
370 const Id main = OpFunction(t_void, {}, TypeFunction(t_void));
371 AddLabel();
372
373 if (ir.IsDecompiled()) {
374 DeclareFlowVariables();
375 DecompileAST();
376 } else {
377 AllocateLabels();
378 DecompileBranchMode();
379 }
380
381 OpReturn();
382 OpFunctionEnd();
383
384 return main;
385 }
386
387 void DefinePrologue() {
388 if (stage == ShaderType::Vertex) {
389 // Clear Position to avoid reading trash on the Z conversion.
390 const auto position_index = out_indices.position.value();
391 const Id position = AccessElement(t_out_float4, out_vertex, position_index);
392 OpStore(position, v_varying_default);
393
394 if (specialization.point_size) {
395 const u32 point_size_index = out_indices.point_size.value();
396 const Id out_point_size = AccessElement(t_out_float, out_vertex, point_size_index);
397 OpStore(out_point_size, Constant(t_float, *specialization.point_size));
398 }
399 }
400 }
401
402 void DecompileAST();
403
104 void DecompileBranchMode() { 404 void DecompileBranchMode() {
105 const u32 first_address = ir.GetBasicBlocks().begin()->first; 405 const u32 first_address = ir.GetBasicBlocks().begin()->first;
106 const Id loop_label = OpLabel("loop"); 406 const Id loop_label = OpLabel("loop");
@@ -111,14 +411,15 @@ public:
111 411
112 std::vector<Sirit::Literal> literals; 412 std::vector<Sirit::Literal> literals;
113 std::vector<Id> branch_labels; 413 std::vector<Id> branch_labels;
114 for (const auto& pair : labels) { 414 for (const auto& [literal, label] : labels) {
115 const auto [literal, label] = pair;
116 literals.push_back(literal); 415 literals.push_back(literal);
117 branch_labels.push_back(label); 416 branch_labels.push_back(label);
118 } 417 }
119 418
120 jmp_to = Emit(OpVariable(TypePointer(spv::StorageClass::Function, t_uint), 419 jmp_to = OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
121 spv::StorageClass::Function, Constant(t_uint, first_address))); 420 spv::StorageClass::Function, Constant(t_uint, first_address));
421 AddLocalVariable(jmp_to);
422
122 std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack(); 423 std::tie(ssy_flow_stack, ssy_flow_stack_top) = CreateFlowStack();
123 std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack(); 424 std::tie(pbk_flow_stack, pbk_flow_stack_top) = CreateFlowStack();
124 425
@@ -128,151 +429,118 @@ public:
128 Name(pbk_flow_stack, "pbk_flow_stack"); 429 Name(pbk_flow_stack, "pbk_flow_stack");
129 Name(pbk_flow_stack_top, "pbk_flow_stack_top"); 430 Name(pbk_flow_stack_top, "pbk_flow_stack_top");
130 431
131 Emit(OpBranch(loop_label)); 432 DefinePrologue();
132 Emit(loop_label); 433
133 Emit(OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::Unroll)); 434 OpBranch(loop_label);
134 Emit(OpBranch(dummy_label)); 435 AddLabel(loop_label);
436 OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::MaskNone);
437 OpBranch(dummy_label);
135 438
136 Emit(dummy_label); 439 AddLabel(dummy_label);
137 const Id default_branch = OpLabel(); 440 const Id default_branch = OpLabel();
138 const Id jmp_to_load = Emit(OpLoad(t_uint, jmp_to)); 441 const Id jmp_to_load = OpLoad(t_uint, jmp_to);
139 Emit(OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone)); 442 OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone);
140 Emit(OpSwitch(jmp_to_load, default_branch, literals, branch_labels)); 443 OpSwitch(jmp_to_load, default_branch, literals, branch_labels);
141 444
142 Emit(default_branch); 445 AddLabel(default_branch);
143 Emit(OpReturn()); 446 OpReturn();
144 447
145 for (const auto& pair : ir.GetBasicBlocks()) { 448 for (const auto& [address, bb] : ir.GetBasicBlocks()) {
146 const auto& [address, bb] = pair; 449 AddLabel(labels.at(address));
147 Emit(labels.at(address));
148 450
149 VisitBasicBlock(bb); 451 VisitBasicBlock(bb);
150 452
151 const auto next_it = labels.lower_bound(address + 1); 453 const auto next_it = labels.lower_bound(address + 1);
152 const Id next_label = next_it != labels.end() ? next_it->second : default_branch; 454 const Id next_label = next_it != labels.end() ? next_it->second : default_branch;
153 Emit(OpBranch(next_label)); 455 OpBranch(next_label);
154 } 456 }
155 457
156 Emit(jump_label); 458 AddLabel(jump_label);
157 Emit(OpBranch(continue_label)); 459 OpBranch(continue_label);
158 Emit(continue_label); 460 AddLabel(continue_label);
159 Emit(OpBranch(loop_label)); 461 OpBranch(loop_label);
160 Emit(merge_label); 462 AddLabel(merge_label);
161 } 463 }
162 464
163 void DecompileAST(); 465private:
466 friend class ASTDecompiler;
467 friend class ExprDecompiler;
164 468
165 void Decompile() { 469 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
166 const bool is_fully_decompiled = ir.IsDecompiled();
167 AllocateBindings();
168 if (!is_fully_decompiled) {
169 AllocateLabels();
170 }
171 470
172 DeclareVertex(); 471 void AllocateLabels() {
173 DeclareGeometry(); 472 for (const auto& pair : ir.GetBasicBlocks()) {
174 DeclareFragment(); 473 const u32 address = pair.first;
175 DeclareRegisters(); 474 labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
176 DeclarePredicates();
177 if (is_fully_decompiled) {
178 DeclareFlowVariables();
179 } 475 }
180 DeclareLocalMemory(); 476 }
181 DeclareInternalFlags();
182 DeclareInputAttributes();
183 DeclareOutputAttributes();
184 DeclareConstantBuffers();
185 DeclareGlobalBuffers();
186 DeclareSamplers();
187 477
188 execute_function = 478 void DeclareCommon() {
189 Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void))); 479 thread_id =
190 Emit(OpLabel()); 480 DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
481 }
191 482
192 if (is_fully_decompiled) { 483 void DeclareVertex() {
193 DecompileAST(); 484 if (stage != ShaderType::Vertex) {
194 } else { 485 return;
195 DecompileBranchMode();
196 } 486 }
487 Id out_vertex_struct;
488 std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
489 const Id vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
490 out_vertex = OpVariable(vertex_ptr, spv::StorageClass::Output);
491 interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
197 492
198 Emit(OpReturn()); 493 // Declare input attributes
199 Emit(OpFunctionEnd()); 494 vertex_index = DeclareInputBuiltIn(spv::BuiltIn::VertexIndex, t_in_uint, "vertex_index");
495 instance_index =
496 DeclareInputBuiltIn(spv::BuiltIn::InstanceIndex, t_in_uint, "instance_index");
200 } 497 }
201 498
202 ShaderEntries GetShaderEntries() const { 499 void DeclareTessControl() {
203 ShaderEntries entries; 500 if (stage != ShaderType::TesselationControl) {
204 entries.const_buffers_base_binding = const_buffers_base_binding; 501 return;
205 entries.global_buffers_base_binding = global_buffers_base_binding;
206 entries.samplers_base_binding = samplers_base_binding;
207 for (const auto& cbuf : ir.GetConstantBuffers()) {
208 entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
209 }
210 for (const auto& gmem_pair : ir.GetGlobalMemory()) {
211 const auto& [base, usage] = gmem_pair;
212 entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset);
213 }
214 for (const auto& sampler : ir.GetSamplers()) {
215 entries.samplers.emplace_back(sampler);
216 }
217 for (const auto& attribute : ir.GetInputAttributes()) {
218 if (IsGenericAttribute(attribute)) {
219 entries.attributes.insert(GetGenericAttributeLocation(attribute));
220 }
221 } 502 }
222 entries.clip_distances = ir.GetClipDistances(); 503 DeclareInputVertexArray(NumInputPatches);
223 entries.shader_length = ir.GetLength(); 504 DeclareOutputVertexArray(header.common2.threads_per_input_primitive);
224 entries.entry_function = execute_function;
225 entries.interfaces = interfaces;
226 return entries;
227 }
228
229private:
230 friend class ASTDecompiler;
231 friend class ExprDecompiler;
232
233 static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
234 505
235 void AllocateBindings() { 506 tess_level_outer = DeclareBuiltIn(
236 const u32 binding_base = static_cast<u32>(stage) * STAGE_BINDING_STRIDE; 507 spv::BuiltIn::TessLevelOuter, spv::StorageClass::Output,
237 u32 binding_iterator = binding_base; 508 TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 4U))),
509 "tess_level_outer");
510 Decorate(tess_level_outer, spv::Decoration::Patch);
238 511
239 const auto Allocate = [&binding_iterator](std::size_t count) { 512 tess_level_inner = DeclareBuiltIn(
240 const u32 current_binding = binding_iterator; 513 spv::BuiltIn::TessLevelInner, spv::StorageClass::Output,
241 binding_iterator += static_cast<u32>(count); 514 TypePointer(spv::StorageClass::Output, TypeArray(t_float, Constant(t_uint, 2U))),
242 return current_binding; 515 "tess_level_inner");
243 }; 516 Decorate(tess_level_inner, spv::Decoration::Patch);
244 const_buffers_base_binding = Allocate(ir.GetConstantBuffers().size());
245 global_buffers_base_binding = Allocate(ir.GetGlobalMemory().size());
246 samplers_base_binding = Allocate(ir.GetSamplers().size());
247 517
248 ASSERT_MSG(binding_iterator - binding_base < STAGE_BINDING_STRIDE, 518 invocation_id = DeclareInputBuiltIn(spv::BuiltIn::InvocationId, t_in_int, "invocation_id");
249 "Stage binding stride is too small");
250 } 519 }
251 520
252 void AllocateLabels() { 521 void DeclareTessEval() {
253 for (const auto& pair : ir.GetBasicBlocks()) { 522 if (stage != ShaderType::TesselationEval) {
254 const u32 address = pair.first;
255 labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
256 }
257 }
258
259 void DeclareVertex() {
260 if (stage != ShaderType::Vertex)
261 return; 523 return;
524 }
525 DeclareInputVertexArray(NumInputPatches);
526 DeclareOutputVertex();
262 527
263 DeclareVertexRedeclarations(); 528 tess_coord = DeclareInputBuiltIn(spv::BuiltIn::TessCoord, t_in_float3, "tess_coord");
264 } 529 }
265 530
266 void DeclareGeometry() { 531 void DeclareGeometry() {
267 if (stage != ShaderType::Geometry) 532 if (stage != ShaderType::Geometry) {
268 return; 533 return;
269 534 }
270 UNIMPLEMENTED(); 535 const u32 num_input = GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
536 DeclareInputVertexArray(num_input);
537 DeclareOutputVertex();
271 } 538 }
272 539
273 void DeclareFragment() { 540 void DeclareFragment() {
274 if (stage != ShaderType::Fragment) 541 if (stage != ShaderType::Fragment) {
275 return; 542 return;
543 }
276 544
277 for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) { 545 for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) {
278 if (!IsRenderTargetUsed(rt)) { 546 if (!IsRenderTargetUsed(rt)) {
@@ -296,10 +564,19 @@ private:
296 interfaces.push_back(frag_depth); 564 interfaces.push_back(frag_depth);
297 } 565 }
298 566
299 frag_coord = DeclareBuiltIn(spv::BuiltIn::FragCoord, spv::StorageClass::Input, t_in_float4, 567 frag_coord = DeclareInputBuiltIn(spv::BuiltIn::FragCoord, t_in_float4, "frag_coord");
300 "frag_coord"); 568 front_facing = DeclareInputBuiltIn(spv::BuiltIn::FrontFacing, t_in_bool, "front_facing");
301 front_facing = DeclareBuiltIn(spv::BuiltIn::FrontFacing, spv::StorageClass::Input, 569 point_coord = DeclareInputBuiltIn(spv::BuiltIn::PointCoord, t_in_float2, "point_coord");
302 t_in_bool, "front_facing"); 570 }
571
572 void DeclareCompute() {
573 if (stage != ShaderType::Compute) {
574 return;
575 }
576
577 workgroup_id = DeclareInputBuiltIn(spv::BuiltIn::WorkgroupId, t_in_uint3, "workgroup_id");
578 local_invocation_id =
579 DeclareInputBuiltIn(spv::BuiltIn::LocalInvocationId, t_in_uint3, "local_invocation_id");
303 } 580 }
304 581
305 void DeclareRegisters() { 582 void DeclareRegisters() {
@@ -327,21 +604,44 @@ private:
327 } 604 }
328 605
329 void DeclareLocalMemory() { 606 void DeclareLocalMemory() {
330 if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) { 607 // TODO(Rodrigo): Unstub kernel local memory size and pass it from a register at
331 const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4); 608 // specialization time.
332 const Id type_array = TypeArray(t_float, Constant(t_uint, element_count)); 609 const u64 lmem_size = stage == ShaderType::Compute ? 0x400 : header.GetLocalMemorySize();
333 const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array); 610 if (lmem_size == 0) {
334 Name(type_pointer, "LocalMemory"); 611 return;
612 }
613 const auto element_count = static_cast<u32>(Common::AlignUp(lmem_size, 4) / 4);
614 const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
615 const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
616 Name(type_pointer, "LocalMemory");
617
618 local_memory =
619 OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
620 AddGlobalVariable(Name(local_memory, "local_memory"));
621 }
622
623 void DeclareSharedMemory() {
624 if (stage != ShaderType::Compute) {
625 return;
626 }
627 t_smem_uint = TypePointer(spv::StorageClass::Workgroup, t_uint);
335 628
336 local_memory = 629 const u32 smem_size = specialization.shared_memory_size;
337 OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array)); 630 if (smem_size == 0) {
338 AddGlobalVariable(Name(local_memory, "local_memory")); 631 // Avoid declaring an empty array.
632 return;
339 } 633 }
634 const auto element_count = static_cast<u32>(Common::AlignUp(smem_size, 4) / 4);
635 const Id type_array = TypeArray(t_uint, Constant(t_uint, element_count));
636 const Id type_pointer = TypePointer(spv::StorageClass::Workgroup, type_array);
637 Name(type_pointer, "SharedMemory");
638
639 shared_memory = OpVariable(type_pointer, spv::StorageClass::Workgroup);
640 AddGlobalVariable(Name(shared_memory, "shared_memory"));
340 } 641 }
341 642
342 void DeclareInternalFlags() { 643 void DeclareInternalFlags() {
343 constexpr std::array<const char*, INTERNAL_FLAGS_COUNT> names = {"zero", "sign", "carry", 644 constexpr std::array names = {"zero", "sign", "carry", "overflow"};
344 "overflow"};
345 for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) { 645 for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) {
346 const auto flag_code = static_cast<InternalFlag>(flag); 646 const auto flag_code = static_cast<InternalFlag>(flag);
347 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false); 647 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
@@ -349,17 +649,53 @@ private:
349 } 649 }
350 } 650 }
351 651
652 void DeclareInputVertexArray(u32 length) {
653 constexpr auto storage = spv::StorageClass::Input;
654 std::tie(in_indices, in_vertex) = DeclareVertexArray(storage, "in_indices", length);
655 }
656
657 void DeclareOutputVertexArray(u32 length) {
658 constexpr auto storage = spv::StorageClass::Output;
659 std::tie(out_indices, out_vertex) = DeclareVertexArray(storage, "out_indices", length);
660 }
661
662 std::tuple<VertexIndices, Id> DeclareVertexArray(spv::StorageClass storage_class,
663 std::string name, u32 length) {
664 const auto [struct_id, indices] = DeclareVertexStruct();
665 const Id vertex_array = TypeArray(struct_id, Constant(t_uint, length));
666 const Id vertex_ptr = TypePointer(storage_class, vertex_array);
667 const Id vertex = OpVariable(vertex_ptr, storage_class);
668 AddGlobalVariable(Name(vertex, std::move(name)));
669 interfaces.push_back(vertex);
670 return {indices, vertex};
671 }
672
673 void DeclareOutputVertex() {
674 Id out_vertex_struct;
675 std::tie(out_vertex_struct, out_indices) = DeclareVertexStruct();
676 const Id out_vertex_ptr = TypePointer(spv::StorageClass::Output, out_vertex_struct);
677 out_vertex = OpVariable(out_vertex_ptr, spv::StorageClass::Output);
678 interfaces.push_back(AddGlobalVariable(Name(out_vertex, "out_vertex")));
679 }
680
352 void DeclareInputAttributes() { 681 void DeclareInputAttributes() {
353 for (const auto index : ir.GetInputAttributes()) { 682 for (const auto index : ir.GetInputAttributes()) {
354 if (!IsGenericAttribute(index)) { 683 if (!IsGenericAttribute(index)) {
355 continue; 684 continue;
356 } 685 }
357 686
358 UNIMPLEMENTED_IF(stage == ShaderType::Geometry);
359
360 const u32 location = GetGenericAttributeLocation(index); 687 const u32 location = GetGenericAttributeLocation(index);
361 const Id id = OpVariable(t_in_float4, spv::StorageClass::Input); 688 const auto type_descriptor = GetAttributeType(location);
362 Name(AddGlobalVariable(id), fmt::format("in_attr{}", location)); 689 Id type;
690 if (IsInputAttributeArray()) {
691 type = GetTypeVectorDefinitionLut(type_descriptor.type).at(3);
692 type = TypeArray(type, Constant(t_uint, GetNumInputVertices()));
693 type = TypePointer(spv::StorageClass::Input, type);
694 } else {
695 type = type_descriptor.vector;
696 }
697 const Id id = OpVariable(type, spv::StorageClass::Input);
698 AddGlobalVariable(Name(id, fmt::format("in_attr{}", location)));
363 input_attributes.emplace(index, id); 699 input_attributes.emplace(index, id);
364 interfaces.push_back(id); 700 interfaces.push_back(id);
365 701
@@ -389,8 +725,21 @@ private:
389 if (!IsGenericAttribute(index)) { 725 if (!IsGenericAttribute(index)) {
390 continue; 726 continue;
391 } 727 }
392 const auto location = GetGenericAttributeLocation(index); 728 const u32 location = GetGenericAttributeLocation(index);
393 const Id id = OpVariable(t_out_float4, spv::StorageClass::Output); 729 Id type = t_float4;
730 Id varying_default = v_varying_default;
731 if (IsOutputAttributeArray()) {
732 const u32 num = GetNumOutputVertices();
733 type = TypeArray(type, Constant(t_uint, num));
734 if (device.GetDriverID() != vk::DriverIdKHR::eIntelProprietaryWindows) {
735 // Intel's proprietary driver fails to setup defaults for arrayed output
736 // attributes.
737 varying_default = ConstantComposite(type, std::vector(num, varying_default));
738 }
739 }
740 type = TypePointer(spv::StorageClass::Output, type);
741
742 const Id id = OpVariable(type, spv::StorageClass::Output, varying_default);
394 Name(AddGlobalVariable(id), fmt::format("out_attr{}", location)); 743 Name(AddGlobalVariable(id), fmt::format("out_attr{}", location));
395 output_attributes.emplace(index, id); 744 output_attributes.emplace(index, id);
396 interfaces.push_back(id); 745 interfaces.push_back(id);
@@ -399,10 +748,8 @@ private:
399 } 748 }
400 } 749 }
401 750
402 void DeclareConstantBuffers() { 751 u32 DeclareConstantBuffers(u32 binding) {
403 u32 binding = const_buffers_base_binding; 752 for (const auto& [index, size] : ir.GetConstantBuffers()) {
404 for (const auto& entry : ir.GetConstantBuffers()) {
405 const auto [index, size] = entry;
406 const Id type = device.IsKhrUniformBufferStandardLayoutSupported() ? t_cbuf_scalar_ubo 753 const Id type = device.IsKhrUniformBufferStandardLayoutSupported() ? t_cbuf_scalar_ubo
407 : t_cbuf_std140_ubo; 754 : t_cbuf_std140_ubo;
408 const Id id = OpVariable(type, spv::StorageClass::Uniform); 755 const Id id = OpVariable(type, spv::StorageClass::Uniform);
@@ -412,12 +759,11 @@ private:
412 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); 759 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
413 constant_buffers.emplace(index, id); 760 constant_buffers.emplace(index, id);
414 } 761 }
762 return binding;
415 } 763 }
416 764
417 void DeclareGlobalBuffers() { 765 u32 DeclareGlobalBuffers(u32 binding) {
418 u32 binding = global_buffers_base_binding; 766 for (const auto& [base, usage] : ir.GetGlobalMemory()) {
419 for (const auto& entry : ir.GetGlobalMemory()) {
420 const auto [base, usage] = entry;
421 const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer); 767 const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer);
422 AddGlobalVariable( 768 AddGlobalVariable(
423 Name(id, fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset))); 769 Name(id, fmt::format("gmem_{}_{}", base.cbuf_index, base.cbuf_offset)));
@@ -426,89 +772,187 @@ private:
426 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); 772 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
427 global_buffers.emplace(base, id); 773 global_buffers.emplace(base, id);
428 } 774 }
775 return binding;
776 }
777
778 u32 DeclareTexelBuffers(u32 binding) {
779 for (const auto& sampler : ir.GetSamplers()) {
780 if (!sampler.IsBuffer()) {
781 continue;
782 }
783 ASSERT(!sampler.IsArray());
784 ASSERT(!sampler.IsShadow());
785
786 constexpr auto dim = spv::Dim::Buffer;
787 constexpr int depth = 0;
788 constexpr int arrayed = 0;
789 constexpr bool ms = false;
790 constexpr int sampled = 1;
791 constexpr auto format = spv::ImageFormat::Unknown;
792 const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
793 const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
794 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
795 AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
796 Decorate(id, spv::Decoration::Binding, binding++);
797 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
798
799 texel_buffers.emplace(sampler.GetIndex(), TexelBuffer{image_type, id});
800 }
801 return binding;
429 } 802 }
430 803
431 void DeclareSamplers() { 804 u32 DeclareSamplers(u32 binding) {
432 u32 binding = samplers_base_binding;
433 for (const auto& sampler : ir.GetSamplers()) { 805 for (const auto& sampler : ir.GetSamplers()) {
806 if (sampler.IsBuffer()) {
807 continue;
808 }
434 const auto dim = GetSamplerDim(sampler); 809 const auto dim = GetSamplerDim(sampler);
435 const int depth = sampler.IsShadow() ? 1 : 0; 810 const int depth = sampler.IsShadow() ? 1 : 0;
436 const int arrayed = sampler.IsArray() ? 1 : 0; 811 const int arrayed = sampler.IsArray() ? 1 : 0;
437 // TODO(Rodrigo): Sampled 1 indicates that the image will be used with a sampler. When 812 constexpr bool ms = false;
438 // SULD and SUST instructions are implemented, replace this value. 813 constexpr int sampled = 1;
439 const int sampled = 1; 814 constexpr auto format = spv::ImageFormat::Unknown;
440 const Id image_type = 815 const Id image_type = TypeImage(t_float, dim, depth, arrayed, ms, sampled, format);
441 TypeImage(t_float, dim, depth, arrayed, false, sampled, spv::ImageFormat::Unknown);
442 const Id sampled_image_type = TypeSampledImage(image_type); 816 const Id sampled_image_type = TypeSampledImage(image_type);
443 const Id pointer_type = 817 const Id pointer_type =
444 TypePointer(spv::StorageClass::UniformConstant, sampled_image_type); 818 TypePointer(spv::StorageClass::UniformConstant, sampled_image_type);
445 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant); 819 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
446 AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex()))); 820 AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));
821 Decorate(id, spv::Decoration::Binding, binding++);
822 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
447 823
448 sampler_images.insert( 824 sampled_images.emplace(sampler.GetIndex(),
449 {static_cast<u32>(sampler.GetIndex()), {image_type, sampled_image_type, id}}); 825 SampledImage{image_type, sampled_image_type, id});
826 }
827 return binding;
828 }
829
830 u32 DeclareImages(u32 binding) {
831 for (const auto& image : ir.GetImages()) {
832 const auto [dim, arrayed] = GetImageDim(image);
833 constexpr int depth = 0;
834 constexpr bool ms = false;
835 constexpr int sampled = 2; // This won't be accessed with a sampler
836 constexpr auto format = spv::ImageFormat::Unknown;
837 const Id image_type = TypeImage(t_uint, dim, depth, arrayed, ms, sampled, format, {});
838 const Id pointer_type = TypePointer(spv::StorageClass::UniformConstant, image_type);
839 const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
840 AddGlobalVariable(Name(id, fmt::format("image_{}", image.GetIndex())));
450 841
451 Decorate(id, spv::Decoration::Binding, binding++); 842 Decorate(id, spv::Decoration::Binding, binding++);
452 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET); 843 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
844 if (image.IsRead() && !image.IsWritten()) {
845 Decorate(id, spv::Decoration::NonWritable);
846 } else if (image.IsWritten() && !image.IsRead()) {
847 Decorate(id, spv::Decoration::NonReadable);
848 }
849
850 images.emplace(static_cast<u32>(image.GetIndex()), StorageImage{image_type, id});
453 } 851 }
852 return binding;
853 }
854
855 bool IsInputAttributeArray() const {
856 return stage == ShaderType::TesselationControl || stage == ShaderType::TesselationEval ||
857 stage == ShaderType::Geometry;
454 } 858 }
455 859
456 void DeclareVertexRedeclarations() { 860 bool IsOutputAttributeArray() const {
457 vertex_index = DeclareBuiltIn(spv::BuiltIn::VertexIndex, spv::StorageClass::Input, 861 return stage == ShaderType::TesselationControl;
458 t_in_uint, "vertex_index"); 862 }
459 instance_index = DeclareBuiltIn(spv::BuiltIn::InstanceIndex, spv::StorageClass::Input,
460 t_in_uint, "instance_index");
461 863
462 bool is_clip_distances_declared = false; 864 u32 GetNumInputVertices() const {
463 for (const auto index : ir.GetOutputAttributes()) { 865 switch (stage) {
464 if (index == Attribute::Index::ClipDistances0123 || 866 case ShaderType::Geometry:
465 index == Attribute::Index::ClipDistances4567) { 867 return GetNumPrimitiveTopologyVertices(specialization.primitive_topology);
466 is_clip_distances_declared = true; 868 case ShaderType::TesselationControl:
467 } 869 case ShaderType::TesselationEval:
870 return NumInputPatches;
871 default:
872 UNREACHABLE();
873 return 1;
468 } 874 }
875 }
469 876
470 std::vector<Id> members; 877 u32 GetNumOutputVertices() const {
471 members.push_back(t_float4); 878 switch (stage) {
472 if (ir.UsesPointSize()) { 879 case ShaderType::TesselationControl:
473 members.push_back(t_float); 880 return header.common2.threads_per_input_primitive;
474 } 881 default:
475 if (is_clip_distances_declared) { 882 UNREACHABLE();
476 members.push_back(TypeArray(t_float, Constant(t_uint, 8))); 883 return 1;
477 } 884 }
478 885 }
479 const Id gl_per_vertex_struct = Name(TypeStruct(members), "PerVertex"); 886
480 Decorate(gl_per_vertex_struct, spv::Decoration::Block); 887 std::tuple<Id, VertexIndices> DeclareVertexStruct() {
481 888 struct BuiltIn {
482 u32 declaration_index = 0; 889 Id type;
483 const auto MemberDecorateBuiltIn = [&](spv::BuiltIn builtin, std::string name, 890 spv::BuiltIn builtin;
484 bool condition) { 891 const char* name;
485 if (!condition) 892 };
486 return u32{}; 893 std::vector<BuiltIn> members;
487 MemberName(gl_per_vertex_struct, declaration_index, name); 894 members.reserve(4);
488 MemberDecorate(gl_per_vertex_struct, declaration_index, spv::Decoration::BuiltIn, 895
489 static_cast<u32>(builtin)); 896 const auto AddBuiltIn = [&](Id type, spv::BuiltIn builtin, const char* name) {
490 return declaration_index++; 897 const auto index = static_cast<u32>(members.size());
898 members.push_back(BuiltIn{type, builtin, name});
899 return index;
491 }; 900 };
492 901
493 position_index = MemberDecorateBuiltIn(spv::BuiltIn::Position, "position", true); 902 VertexIndices indices;
494 point_size_index = 903 indices.position = AddBuiltIn(t_float4, spv::BuiltIn::Position, "position");
495 MemberDecorateBuiltIn(spv::BuiltIn::PointSize, "point_size", ir.UsesPointSize()); 904
496 clip_distances_index = MemberDecorateBuiltIn(spv::BuiltIn::ClipDistance, "clip_distances", 905 if (ir.UsesViewportIndex()) {
497 is_clip_distances_declared); 906 if (stage != ShaderType::Vertex || device.IsExtShaderViewportIndexLayerSupported()) {
907 indices.viewport = AddBuiltIn(t_int, spv::BuiltIn::ViewportIndex, "viewport_index");
908 } else {
909 LOG_ERROR(Render_Vulkan,
910 "Shader requires ViewportIndex but it's not supported on this "
911 "stage with this device.");
912 }
913 }
914
915 if (ir.UsesPointSize() || specialization.point_size) {
916 indices.point_size = AddBuiltIn(t_float, spv::BuiltIn::PointSize, "point_size");
917 }
918
919 const auto& output_attributes = ir.GetOutputAttributes();
920 const bool declare_clip_distances =
921 std::any_of(output_attributes.begin(), output_attributes.end(), [](const auto& index) {
922 return index == Attribute::Index::ClipDistances0123 ||
923 index == Attribute::Index::ClipDistances4567;
924 });
925 if (declare_clip_distances) {
926 indices.clip_distances = AddBuiltIn(TypeArray(t_float, Constant(t_uint, 8)),
927 spv::BuiltIn::ClipDistance, "clip_distances");
928 }
929
930 std::vector<Id> member_types;
931 member_types.reserve(members.size());
932 for (std::size_t i = 0; i < members.size(); ++i) {
933 member_types.push_back(members[i].type);
934 }
935 const Id per_vertex_struct = Name(TypeStruct(member_types), "PerVertex");
936 Decorate(per_vertex_struct, spv::Decoration::Block);
937
938 for (std::size_t index = 0; index < members.size(); ++index) {
939 const auto& member = members[index];
940 MemberName(per_vertex_struct, static_cast<u32>(index), member.name);
941 MemberDecorate(per_vertex_struct, static_cast<u32>(index), spv::Decoration::BuiltIn,
942 static_cast<u32>(member.builtin));
943 }
498 944
499 const Id type_pointer = TypePointer(spv::StorageClass::Output, gl_per_vertex_struct); 945 return {per_vertex_struct, indices};
500 per_vertex = OpVariable(type_pointer, spv::StorageClass::Output);
501 AddGlobalVariable(Name(per_vertex, "per_vertex"));
502 interfaces.push_back(per_vertex);
503 } 946 }
504 947
505 void VisitBasicBlock(const NodeBlock& bb) { 948 void VisitBasicBlock(const NodeBlock& bb) {
506 for (const auto& node : bb) { 949 for (const auto& node : bb) {
507 static_cast<void>(Visit(node)); 950 [[maybe_unused]] const Type type = Visit(node).type;
951 ASSERT(type == Type::Void);
508 } 952 }
509 } 953 }
510 954
511 Id Visit(const Node& node) { 955 Expression Visit(const Node& node) {
512 if (const auto operation = std::get_if<OperationNode>(&*node)) { 956 if (const auto operation = std::get_if<OperationNode>(&*node)) {
513 const auto operation_index = static_cast<std::size_t>(operation->GetCode()); 957 const auto operation_index = static_cast<std::size_t>(operation->GetCode());
514 const auto decompiler = operation_decompilers[operation_index]; 958 const auto decompiler = operation_decompilers[operation_index];
@@ -516,18 +960,21 @@ private:
516 UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index); 960 UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
517 } 961 }
518 return (this->*decompiler)(*operation); 962 return (this->*decompiler)(*operation);
963 }
519 964
520 } else if (const auto gpr = std::get_if<GprNode>(&*node)) { 965 if (const auto gpr = std::get_if<GprNode>(&*node)) {
521 const u32 index = gpr->GetIndex(); 966 const u32 index = gpr->GetIndex();
522 if (index == Register::ZeroIndex) { 967 if (index == Register::ZeroIndex) {
523 return Constant(t_float, 0.0f); 968 return {v_float_zero, Type::Float};
524 } 969 }
525 return Emit(OpLoad(t_float, registers.at(index))); 970 return {OpLoad(t_float, registers.at(index)), Type::Float};
971 }
526 972
527 } else if (const auto immediate = std::get_if<ImmediateNode>(&*node)) { 973 if (const auto immediate = std::get_if<ImmediateNode>(&*node)) {
528 return BitcastTo<Type::Float>(Constant(t_uint, immediate->GetValue())); 974 return {Constant(t_uint, immediate->GetValue()), Type::Uint};
975 }
529 976
530 } else if (const auto predicate = std::get_if<PredicateNode>(&*node)) { 977 if (const auto predicate = std::get_if<PredicateNode>(&*node)) {
531 const auto value = [&]() -> Id { 978 const auto value = [&]() -> Id {
532 switch (const auto index = predicate->GetIndex(); index) { 979 switch (const auto index = predicate->GetIndex(); index) {
533 case Tegra::Shader::Pred::UnusedIndex: 980 case Tegra::Shader::Pred::UnusedIndex:
@@ -535,74 +982,107 @@ private:
535 case Tegra::Shader::Pred::NeverExecute: 982 case Tegra::Shader::Pred::NeverExecute:
536 return v_false; 983 return v_false;
537 default: 984 default:
538 return Emit(OpLoad(t_bool, predicates.at(index))); 985 return OpLoad(t_bool, predicates.at(index));
539 } 986 }
540 }(); 987 }();
541 if (predicate->IsNegated()) { 988 if (predicate->IsNegated()) {
542 return Emit(OpLogicalNot(t_bool, value)); 989 return {OpLogicalNot(t_bool, value), Type::Bool};
543 } 990 }
544 return value; 991 return {value, Type::Bool};
992 }
545 993
546 } else if (const auto abuf = std::get_if<AbufNode>(&*node)) { 994 if (const auto abuf = std::get_if<AbufNode>(&*node)) {
547 const auto attribute = abuf->GetIndex(); 995 const auto attribute = abuf->GetIndex();
548 const auto element = abuf->GetElement(); 996 const u32 element = abuf->GetElement();
997 const auto& buffer = abuf->GetBuffer();
998
999 const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
1000 std::vector<Id> members;
1001 members.reserve(std::size(indices) + 1);
1002
1003 if (buffer && IsInputAttributeArray()) {
1004 members.push_back(AsUint(Visit(buffer)));
1005 }
1006 for (const u32 index : indices) {
1007 members.push_back(Constant(t_uint, index));
1008 }
1009 return OpAccessChain(pointer_type, composite, members);
1010 };
549 1011
550 switch (attribute) { 1012 switch (attribute) {
551 case Attribute::Index::Position: 1013 case Attribute::Index::Position: {
552 if (stage != ShaderType::Fragment) { 1014 if (stage == ShaderType::Fragment) {
553 UNIMPLEMENTED();
554 break;
555 } else {
556 if (element == 3) { 1015 if (element == 3) {
557 return Constant(t_float, 1.0f); 1016 return {Constant(t_float, 1.0f), Type::Float};
558 } 1017 }
559 return Emit(OpLoad(t_float, AccessElement(t_in_float, frag_coord, element))); 1018 return {OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)),
1019 Type::Float};
1020 }
1021 const auto elements = {in_indices.position.value(), element};
1022 return {OpLoad(t_float, ArrayPass(t_in_float, in_vertex, elements)), Type::Float};
1023 }
1024 case Attribute::Index::PointCoord: {
1025 switch (element) {
1026 case 0:
1027 case 1:
1028 return {OpCompositeExtract(t_float, OpLoad(t_float2, point_coord), element),
1029 Type::Float};
560 } 1030 }
1031 UNIMPLEMENTED_MSG("Unimplemented point coord element={}", element);
1032 return {v_float_zero, Type::Float};
1033 }
561 case Attribute::Index::TessCoordInstanceIDVertexID: 1034 case Attribute::Index::TessCoordInstanceIDVertexID:
562 // TODO(Subv): Find out what the values are for the first two elements when inside a 1035 // TODO(Subv): Find out what the values are for the first two elements when inside a
563 // vertex shader, and what's the value of the fourth element when inside a Tess Eval 1036 // vertex shader, and what's the value of the fourth element when inside a Tess Eval
564 // shader. 1037 // shader.
565 ASSERT(stage == ShaderType::Vertex);
566 switch (element) { 1038 switch (element) {
1039 case 0:
1040 case 1:
1041 return {OpLoad(t_float, AccessElement(t_in_float, tess_coord, element)),
1042 Type::Float};
567 case 2: 1043 case 2:
568 return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index))); 1044 return {OpLoad(t_uint, instance_index), Type::Uint};
569 case 3: 1045 case 3:
570 return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, vertex_index))); 1046 return {OpLoad(t_uint, vertex_index), Type::Uint};
571 } 1047 }
572 UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element); 1048 UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element);
573 return Constant(t_float, 0); 1049 return {Constant(t_uint, 0U), Type::Uint};
574 case Attribute::Index::FrontFacing: 1050 case Attribute::Index::FrontFacing:
575 // TODO(Subv): Find out what the values are for the other elements. 1051 // TODO(Subv): Find out what the values are for the other elements.
576 ASSERT(stage == ShaderType::Fragment); 1052 ASSERT(stage == ShaderType::Fragment);
577 if (element == 3) { 1053 if (element == 3) {
578 const Id is_front_facing = Emit(OpLoad(t_bool, front_facing)); 1054 const Id is_front_facing = OpLoad(t_bool, front_facing);
579 const Id true_value = 1055 const Id true_value = Constant(t_int, static_cast<s32>(-1));
580 BitcastTo<Type::Float>(Constant(t_int, static_cast<s32>(-1))); 1056 const Id false_value = Constant(t_int, 0);
581 const Id false_value = BitcastTo<Type::Float>(Constant(t_int, 0)); 1057 return {OpSelect(t_int, is_front_facing, true_value, false_value), Type::Int};
582 return Emit(OpSelect(t_float, is_front_facing, true_value, false_value));
583 } 1058 }
584 UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element); 1059 UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element);
585 return Constant(t_float, 0.0f); 1060 return {v_float_zero, Type::Float};
586 default: 1061 default:
587 if (IsGenericAttribute(attribute)) { 1062 if (IsGenericAttribute(attribute)) {
588 const Id pointer = 1063 const u32 location = GetGenericAttributeLocation(attribute);
589 AccessElement(t_in_float, input_attributes.at(attribute), element); 1064 const auto type_descriptor = GetAttributeType(location);
590 return Emit(OpLoad(t_float, pointer)); 1065 const Type type = type_descriptor.type;
1066 const Id attribute_id = input_attributes.at(attribute);
1067 const Id pointer = ArrayPass(type_descriptor.scalar, attribute_id, {element});
1068 return {OpLoad(GetTypeDefinition(type), pointer), type};
591 } 1069 }
592 break; 1070 break;
593 } 1071 }
594 UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute)); 1072 UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute));
1073 return {v_float_zero, Type::Float};
1074 }
595 1075
596 } else if (const auto cbuf = std::get_if<CbufNode>(&*node)) { 1076 if (const auto cbuf = std::get_if<CbufNode>(&*node)) {
597 const Node& offset = cbuf->GetOffset(); 1077 const Node& offset = cbuf->GetOffset();
598 const Id buffer_id = constant_buffers.at(cbuf->GetIndex()); 1078 const Id buffer_id = constant_buffers.at(cbuf->GetIndex());
599 1079
600 Id pointer{}; 1080 Id pointer{};
601 if (device.IsKhrUniformBufferStandardLayoutSupported()) { 1081 if (device.IsKhrUniformBufferStandardLayoutSupported()) {
602 const Id buffer_offset = Emit(OpShiftRightLogical( 1082 const Id buffer_offset =
603 t_uint, BitcastTo<Type::Uint>(Visit(offset)), Constant(t_uint, 2u))); 1083 OpShiftRightLogical(t_uint, AsUint(Visit(offset)), Constant(t_uint, 2U));
604 pointer = Emit( 1084 pointer =
605 OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0u), buffer_offset)); 1085 OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0U), buffer_offset);
606 } else { 1086 } else {
607 Id buffer_index{}; 1087 Id buffer_index{};
608 Id buffer_element{}; 1088 Id buffer_element{};
@@ -614,53 +1094,76 @@ private:
614 buffer_element = Constant(t_uint, (offset_imm / 4) % 4); 1094 buffer_element = Constant(t_uint, (offset_imm / 4) % 4);
615 } else if (std::holds_alternative<OperationNode>(*offset)) { 1095 } else if (std::holds_alternative<OperationNode>(*offset)) {
616 // Indirect access 1096 // Indirect access
617 const Id offset_id = BitcastTo<Type::Uint>(Visit(offset)); 1097 const Id offset_id = AsUint(Visit(offset));
618 const Id unsafe_offset = Emit(OpUDiv(t_uint, offset_id, Constant(t_uint, 4))); 1098 const Id unsafe_offset = OpUDiv(t_uint, offset_id, Constant(t_uint, 4));
619 const Id final_offset = Emit(OpUMod( 1099 const Id final_offset =
620 t_uint, unsafe_offset, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS - 1))); 1100 OpUMod(t_uint, unsafe_offset, Constant(t_uint, MaxConstBufferElements - 1));
621 buffer_index = Emit(OpUDiv(t_uint, final_offset, Constant(t_uint, 4))); 1101 buffer_index = OpUDiv(t_uint, final_offset, Constant(t_uint, 4));
622 buffer_element = Emit(OpUMod(t_uint, final_offset, Constant(t_uint, 4))); 1102 buffer_element = OpUMod(t_uint, final_offset, Constant(t_uint, 4));
623 } else { 1103 } else {
624 UNREACHABLE_MSG("Unmanaged offset node type"); 1104 UNREACHABLE_MSG("Unmanaged offset node type");
625 } 1105 }
626 pointer = Emit(OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), 1106 pointer = OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0), buffer_index,
627 buffer_index, buffer_element)); 1107 buffer_element);
628 } 1108 }
629 return Emit(OpLoad(t_float, pointer)); 1109 return {OpLoad(t_float, pointer), Type::Float};
1110 }
630 1111
631 } else if (const auto gmem = std::get_if<GmemNode>(&*node)) { 1112 if (const auto gmem = std::get_if<GmemNode>(&*node)) {
632 const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor()); 1113 const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
633 const Id real = BitcastTo<Type::Uint>(Visit(gmem->GetRealAddress())); 1114 const Id real = AsUint(Visit(gmem->GetRealAddress()));
634 const Id base = BitcastTo<Type::Uint>(Visit(gmem->GetBaseAddress())); 1115 const Id base = AsUint(Visit(gmem->GetBaseAddress()));
1116
1117 Id offset = OpISub(t_uint, real, base);
1118 offset = OpUDiv(t_uint, offset, Constant(t_uint, 4U));
1119 return {OpLoad(t_float,
1120 OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0U), offset)),
1121 Type::Float};
1122 }
635 1123
636 Id offset = Emit(OpISub(t_uint, real, base)); 1124 if (const auto lmem = std::get_if<LmemNode>(&*node)) {
637 offset = Emit(OpUDiv(t_uint, offset, Constant(t_uint, 4u))); 1125 Id address = AsUint(Visit(lmem->GetAddress()));
638 return Emit(OpLoad(t_float, Emit(OpAccessChain(t_gmem_float, gmem_buffer, 1126 address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
639 Constant(t_uint, 0u), offset)))); 1127 const Id pointer = OpAccessChain(t_prv_float, local_memory, address);
1128 return {OpLoad(t_float, pointer), Type::Float};
1129 }
1130
1131 if (const auto smem = std::get_if<SmemNode>(&*node)) {
1132 Id address = AsUint(Visit(smem->GetAddress()));
1133 address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
1134 const Id pointer = OpAccessChain(t_smem_uint, shared_memory, address);
1135 return {OpLoad(t_uint, pointer), Type::Uint};
1136 }
640 1137
641 } else if (const auto conditional = std::get_if<ConditionalNode>(&*node)) { 1138 if (const auto internal_flag = std::get_if<InternalFlagNode>(&*node)) {
1139 const Id flag = internal_flags.at(static_cast<std::size_t>(internal_flag->GetFlag()));
1140 return {OpLoad(t_bool, flag), Type::Bool};
1141 }
1142
1143 if (const auto conditional = std::get_if<ConditionalNode>(&*node)) {
642 // It's invalid to call conditional on nested nodes, use an operation instead 1144 // It's invalid to call conditional on nested nodes, use an operation instead
643 const Id true_label = OpLabel(); 1145 const Id true_label = OpLabel();
644 const Id skip_label = OpLabel(); 1146 const Id skip_label = OpLabel();
645 const Id condition = Visit(conditional->GetCondition()); 1147 const Id condition = AsBool(Visit(conditional->GetCondition()));
646 Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone)); 1148 OpSelectionMerge(skip_label, spv::SelectionControlMask::MaskNone);
647 Emit(OpBranchConditional(condition, true_label, skip_label)); 1149 OpBranchConditional(condition, true_label, skip_label);
648 Emit(true_label); 1150 AddLabel(true_label);
649 1151
650 ++conditional_nest_count; 1152 conditional_branch_set = true;
1153 inside_branch = false;
651 VisitBasicBlock(conditional->GetCode()); 1154 VisitBasicBlock(conditional->GetCode());
652 --conditional_nest_count; 1155 conditional_branch_set = false;
653 1156 if (!inside_branch) {
654 if (inside_branch == 0) { 1157 OpBranch(skip_label);
655 Emit(OpBranch(skip_label));
656 } else { 1158 } else {
657 inside_branch--; 1159 inside_branch = false;
658 } 1160 }
659 Emit(skip_label); 1161 AddLabel(skip_label);
660 return {}; 1162 return {};
1163 }
661 1164
662 } else if (const auto comment = std::get_if<CommentNode>(&*node)) { 1165 if (const auto comment = std::get_if<CommentNode>(&*node)) {
663 Name(Emit(OpUndef(t_void)), comment->GetText()); 1166 Name(OpUndef(t_void), comment->GetText());
664 return {}; 1167 return {};
665 } 1168 }
666 1169
@@ -669,94 +1172,126 @@ private:
669 } 1172 }
670 1173
671 template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type> 1174 template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
672 Id Unary(Operation operation) { 1175 Expression Unary(Operation operation) {
673 const Id type_def = GetTypeDefinition(result_type); 1176 const Id type_def = GetTypeDefinition(result_type);
674 const Id op_a = VisitOperand<type_a>(operation, 0); 1177 const Id op_a = As(Visit(operation[0]), type_a);
675 1178
676 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a))); 1179 const Id value = (this->*func)(type_def, op_a);
677 if (IsPrecise(operation)) { 1180 if (IsPrecise(operation)) {
678 Decorate(value, spv::Decoration::NoContraction); 1181 Decorate(value, spv::Decoration::NoContraction);
679 } 1182 }
680 return value; 1183 return {value, result_type};
681 } 1184 }
682 1185
683 template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type, 1186 template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
684 Type type_b = type_a> 1187 Type type_b = type_a>
685 Id Binary(Operation operation) { 1188 Expression Binary(Operation operation) {
686 const Id type_def = GetTypeDefinition(result_type); 1189 const Id type_def = GetTypeDefinition(result_type);
687 const Id op_a = VisitOperand<type_a>(operation, 0); 1190 const Id op_a = As(Visit(operation[0]), type_a);
688 const Id op_b = VisitOperand<type_b>(operation, 1); 1191 const Id op_b = As(Visit(operation[1]), type_b);
689 1192
690 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b))); 1193 const Id value = (this->*func)(type_def, op_a, op_b);
691 if (IsPrecise(operation)) { 1194 if (IsPrecise(operation)) {
692 Decorate(value, spv::Decoration::NoContraction); 1195 Decorate(value, spv::Decoration::NoContraction);
693 } 1196 }
694 return value; 1197 return {value, result_type};
695 } 1198 }
696 1199
697 template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type, 1200 template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
698 Type type_b = type_a, Type type_c = type_b> 1201 Type type_b = type_a, Type type_c = type_b>
699 Id Ternary(Operation operation) { 1202 Expression Ternary(Operation operation) {
700 const Id type_def = GetTypeDefinition(result_type); 1203 const Id type_def = GetTypeDefinition(result_type);
701 const Id op_a = VisitOperand<type_a>(operation, 0); 1204 const Id op_a = As(Visit(operation[0]), type_a);
702 const Id op_b = VisitOperand<type_b>(operation, 1); 1205 const Id op_b = As(Visit(operation[1]), type_b);
703 const Id op_c = VisitOperand<type_c>(operation, 2); 1206 const Id op_c = As(Visit(operation[2]), type_c);
704 1207
705 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c))); 1208 const Id value = (this->*func)(type_def, op_a, op_b, op_c);
706 if (IsPrecise(operation)) { 1209 if (IsPrecise(operation)) {
707 Decorate(value, spv::Decoration::NoContraction); 1210 Decorate(value, spv::Decoration::NoContraction);
708 } 1211 }
709 return value; 1212 return {value, result_type};
710 } 1213 }
711 1214
712 template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type, 1215 template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
713 Type type_b = type_a, Type type_c = type_b, Type type_d = type_c> 1216 Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
714 Id Quaternary(Operation operation) { 1217 Expression Quaternary(Operation operation) {
715 const Id type_def = GetTypeDefinition(result_type); 1218 const Id type_def = GetTypeDefinition(result_type);
716 const Id op_a = VisitOperand<type_a>(operation, 0); 1219 const Id op_a = As(Visit(operation[0]), type_a);
717 const Id op_b = VisitOperand<type_b>(operation, 1); 1220 const Id op_b = As(Visit(operation[1]), type_b);
718 const Id op_c = VisitOperand<type_c>(operation, 2); 1221 const Id op_c = As(Visit(operation[2]), type_c);
719 const Id op_d = VisitOperand<type_d>(operation, 3); 1222 const Id op_d = As(Visit(operation[3]), type_d);
720 1223
721 const Id value = 1224 const Id value = (this->*func)(type_def, op_a, op_b, op_c, op_d);
722 BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d)));
723 if (IsPrecise(operation)) { 1225 if (IsPrecise(operation)) {
724 Decorate(value, spv::Decoration::NoContraction); 1226 Decorate(value, spv::Decoration::NoContraction);
725 } 1227 }
726 return value; 1228 return {value, result_type};
727 } 1229 }
728 1230
729 Id Assign(Operation operation) { 1231 Expression Assign(Operation operation) {
730 const Node& dest = operation[0]; 1232 const Node& dest = operation[0];
731 const Node& src = operation[1]; 1233 const Node& src = operation[1];
732 1234
733 Id target{}; 1235 Expression target{};
734 if (const auto gpr = std::get_if<GprNode>(&*dest)) { 1236 if (const auto gpr = std::get_if<GprNode>(&*dest)) {
735 if (gpr->GetIndex() == Register::ZeroIndex) { 1237 if (gpr->GetIndex() == Register::ZeroIndex) {
736 // Writing to Register::ZeroIndex is a no op 1238 // Writing to Register::ZeroIndex is a no op
737 return {}; 1239 return {};
738 } 1240 }
739 target = registers.at(gpr->GetIndex()); 1241 target = {registers.at(gpr->GetIndex()), Type::Float};
740 1242
741 } else if (const auto abuf = std::get_if<AbufNode>(&*dest)) { 1243 } else if (const auto abuf = std::get_if<AbufNode>(&*dest)) {
742 target = [&]() -> Id { 1244 const auto& buffer = abuf->GetBuffer();
1245 const auto ArrayPass = [&](Id pointer_type, Id composite, std::vector<u32> indices) {
1246 std::vector<Id> members;
1247 members.reserve(std::size(indices) + 1);
1248
1249 if (buffer && IsOutputAttributeArray()) {
1250 members.push_back(AsUint(Visit(buffer)));
1251 }
1252 for (const u32 index : indices) {
1253 members.push_back(Constant(t_uint, index));
1254 }
1255 return OpAccessChain(pointer_type, composite, members);
1256 };
1257
1258 target = [&]() -> Expression {
1259 const u32 element = abuf->GetElement();
743 switch (const auto attribute = abuf->GetIndex(); attribute) { 1260 switch (const auto attribute = abuf->GetIndex(); attribute) {
744 case Attribute::Index::Position: 1261 case Attribute::Index::Position: {
745 return AccessElement(t_out_float, per_vertex, position_index, 1262 const u32 index = out_indices.position.value();
746 abuf->GetElement()); 1263 return {ArrayPass(t_out_float, out_vertex, {index, element}), Type::Float};
1264 }
747 case Attribute::Index::LayerViewportPointSize: 1265 case Attribute::Index::LayerViewportPointSize:
748 UNIMPLEMENTED_IF(abuf->GetElement() != 3); 1266 switch (element) {
749 return AccessElement(t_out_float, per_vertex, point_size_index); 1267 case 2: {
750 case Attribute::Index::ClipDistances0123: 1268 if (!out_indices.viewport) {
751 return AccessElement(t_out_float, per_vertex, clip_distances_index, 1269 return {};
752 abuf->GetElement()); 1270 }
753 case Attribute::Index::ClipDistances4567: 1271 const u32 index = out_indices.viewport.value();
754 return AccessElement(t_out_float, per_vertex, clip_distances_index, 1272 return {AccessElement(t_out_int, out_vertex, index), Type::Int};
755 abuf->GetElement() + 4); 1273 }
1274 case 3: {
1275 const auto index = out_indices.point_size.value();
1276 return {AccessElement(t_out_float, out_vertex, index), Type::Float};
1277 }
1278 default:
1279 UNIMPLEMENTED_MSG("LayerViewportPoint element={}", abuf->GetElement());
1280 return {};
1281 }
1282 case Attribute::Index::ClipDistances0123: {
1283 const u32 index = out_indices.clip_distances.value();
1284 return {AccessElement(t_out_float, out_vertex, index, element), Type::Float};
1285 }
1286 case Attribute::Index::ClipDistances4567: {
1287 const u32 index = out_indices.clip_distances.value();
1288 return {AccessElement(t_out_float, out_vertex, index, element + 4),
1289 Type::Float};
1290 }
756 default: 1291 default:
757 if (IsGenericAttribute(attribute)) { 1292 if (IsGenericAttribute(attribute)) {
758 return AccessElement(t_out_float, output_attributes.at(attribute), 1293 const Id composite = output_attributes.at(attribute);
759 abuf->GetElement()); 1294 return {ArrayPass(t_out_float, composite, {element}), Type::Float};
760 } 1295 }
761 UNIMPLEMENTED_MSG("Unhandled output attribute: {}", 1296 UNIMPLEMENTED_MSG("Unhandled output attribute: {}",
762 static_cast<u32>(attribute)); 1297 static_cast<u32>(attribute));
@@ -764,72 +1299,154 @@ private:
764 } 1299 }
765 }(); 1300 }();
766 1301
1302 } else if (const auto patch = std::get_if<PatchNode>(&*dest)) {
1303 target = [&]() -> Expression {
1304 const u32 offset = patch->GetOffset();
1305 switch (offset) {
1306 case 0:
1307 case 1:
1308 case 2:
1309 case 3:
1310 return {AccessElement(t_out_float, tess_level_outer, offset % 4), Type::Float};
1311 case 4:
1312 case 5:
1313 return {AccessElement(t_out_float, tess_level_inner, offset % 4), Type::Float};
1314 }
1315 UNIMPLEMENTED_MSG("Unhandled patch output offset: {}", offset);
1316 return {};
1317 }();
1318
767 } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) { 1319 } else if (const auto lmem = std::get_if<LmemNode>(&*dest)) {
768 Id address = BitcastTo<Type::Uint>(Visit(lmem->GetAddress())); 1320 Id address = AsUint(Visit(lmem->GetAddress()));
769 address = Emit(OpUDiv(t_uint, address, Constant(t_uint, 4))); 1321 address = OpUDiv(t_uint, address, Constant(t_uint, 4));
770 target = Emit(OpAccessChain(t_prv_float, local_memory, {address})); 1322 target = {OpAccessChain(t_prv_float, local_memory, address), Type::Float};
1323
1324 } else if (const auto smem = std::get_if<SmemNode>(&*dest)) {
1325 ASSERT(stage == ShaderType::Compute);
1326 Id address = AsUint(Visit(smem->GetAddress()));
1327 address = OpShiftRightLogical(t_uint, address, Constant(t_uint, 2U));
1328 target = {OpAccessChain(t_smem_uint, shared_memory, address), Type::Uint};
1329
1330 } else if (const auto gmem = std::get_if<GmemNode>(&*dest)) {
1331 const Id real = AsUint(Visit(gmem->GetRealAddress()));
1332 const Id base = AsUint(Visit(gmem->GetBaseAddress()));
1333 const Id diff = OpISub(t_uint, real, base);
1334 const Id offset = OpShiftRightLogical(t_uint, diff, Constant(t_uint, 2));
1335
1336 const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
1337 target = {OpAccessChain(t_gmem_float, gmem_buffer, Constant(t_uint, 0), offset),
1338 Type::Float};
1339
1340 } else {
1341 UNIMPLEMENTED();
771 } 1342 }
772 1343
773 Emit(OpStore(target, Visit(src))); 1344 OpStore(target.id, As(Visit(src), target.type));
774 return {}; 1345 return {};
775 } 1346 }
776 1347
777 Id FCastHalf0(Operation operation) { 1348 template <u32 offset>
778 UNIMPLEMENTED(); 1349 Expression FCastHalf(Operation operation) {
779 return {}; 1350 const Id value = AsHalfFloat(Visit(operation[0]));
1351 return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, offset)),
1352 Type::Float};
780 } 1353 }
781 1354
782 Id FCastHalf1(Operation operation) { 1355 Expression FSwizzleAdd(Operation operation) {
783 UNIMPLEMENTED(); 1356 const Id minus = Constant(t_float, -1.0f);
784 return {}; 1357 const Id plus = v_float_one;
785 } 1358 const Id zero = v_float_zero;
1359 const Id lut_a = ConstantComposite(t_float4, minus, plus, minus, zero);
1360 const Id lut_b = ConstantComposite(t_float4, minus, minus, plus, minus);
786 1361
787 Id FSwizzleAdd(Operation operation) { 1362 Id mask = OpLoad(t_uint, thread_id);
788 UNIMPLEMENTED(); 1363 mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
789 return {}; 1364 mask = OpShiftLeftLogical(t_uint, mask, Constant(t_uint, 1));
790 } 1365 mask = OpShiftRightLogical(t_uint, AsUint(Visit(operation[2])), mask);
1366 mask = OpBitwiseAnd(t_uint, mask, Constant(t_uint, 3));
791 1367
792 Id HNegate(Operation operation) { 1368 const Id modifier_a = OpVectorExtractDynamic(t_float, lut_a, mask);
793 UNIMPLEMENTED(); 1369 const Id modifier_b = OpVectorExtractDynamic(t_float, lut_b, mask);
794 return {}; 1370
1371 const Id op_a = OpFMul(t_float, AsFloat(Visit(operation[0])), modifier_a);
1372 const Id op_b = OpFMul(t_float, AsFloat(Visit(operation[1])), modifier_b);
1373 return {OpFAdd(t_float, op_a, op_b), Type::Float};
795 } 1374 }
796 1375
797 Id HClamp(Operation operation) { 1376 Expression HNegate(Operation operation) {
798 UNIMPLEMENTED(); 1377 const bool is_f16 = device.IsFloat16Supported();
799 return {}; 1378 const Id minus_one = Constant(t_scalar_half, is_f16 ? 0xbc00 : 0xbf800000);
1379 const Id one = Constant(t_scalar_half, is_f16 ? 0x3c00 : 0x3f800000);
1380 const auto GetNegate = [&](std::size_t index) {
1381 return OpSelect(t_scalar_half, AsBool(Visit(operation[index])), minus_one, one);
1382 };
1383 const Id negation = OpCompositeConstruct(t_half, GetNegate(1), GetNegate(2));
1384 return {OpFMul(t_half, AsHalfFloat(Visit(operation[0])), negation), Type::HalfFloat};
800 } 1385 }
801 1386
802 Id HCastFloat(Operation operation) { 1387 Expression HClamp(Operation operation) {
803 UNIMPLEMENTED(); 1388 const auto Pack = [&](std::size_t index) {
804 return {}; 1389 const Id scalar = GetHalfScalarFromFloat(AsFloat(Visit(operation[index])));
1390 return OpCompositeConstruct(t_half, scalar, scalar);
1391 };
1392 const Id value = AsHalfFloat(Visit(operation[0]));
1393 const Id min = Pack(1);
1394 const Id max = Pack(2);
1395
1396 const Id clamped = OpFClamp(t_half, value, min, max);
1397 if (IsPrecise(operation)) {
1398 Decorate(clamped, spv::Decoration::NoContraction);
1399 }
1400 return {clamped, Type::HalfFloat};
805 } 1401 }
806 1402
807 Id HUnpack(Operation operation) { 1403 Expression HCastFloat(Operation operation) {
808 UNIMPLEMENTED(); 1404 const Id value = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
809 return {}; 1405 return {OpCompositeConstruct(t_half, value, Constant(t_scalar_half, 0)), Type::HalfFloat};
810 } 1406 }
811 1407
812 Id HMergeF32(Operation operation) { 1408 Expression HUnpack(Operation operation) {
813 UNIMPLEMENTED(); 1409 Expression operand = Visit(operation[0]);
814 return {}; 1410 const auto type = std::get<Tegra::Shader::HalfType>(operation.GetMeta());
1411 if (type == Tegra::Shader::HalfType::H0_H1) {
1412 return operand;
1413 }
1414 const auto value = [&] {
1415 switch (std::get<Tegra::Shader::HalfType>(operation.GetMeta())) {
1416 case Tegra::Shader::HalfType::F32:
1417 return GetHalfScalarFromFloat(AsFloat(operand));
1418 case Tegra::Shader::HalfType::H0_H0:
1419 return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 0);
1420 case Tegra::Shader::HalfType::H1_H1:
1421 return OpCompositeExtract(t_scalar_half, AsHalfFloat(operand), 1);
1422 default:
1423 UNREACHABLE();
1424 return ConstantNull(t_half);
1425 }
1426 }();
1427 return {OpCompositeConstruct(t_half, value, value), Type::HalfFloat};
815 } 1428 }
816 1429
817 Id HMergeH0(Operation operation) { 1430 Expression HMergeF32(Operation operation) {
818 UNIMPLEMENTED(); 1431 const Id value = AsHalfFloat(Visit(operation[0]));
819 return {}; 1432 return {GetFloatFromHalfScalar(OpCompositeExtract(t_scalar_half, value, 0)), Type::Float};
820 } 1433 }
821 1434
822 Id HMergeH1(Operation operation) { 1435 template <u32 offset>
823 UNIMPLEMENTED(); 1436 Expression HMergeHN(Operation operation) {
824 return {}; 1437 const Id target = AsHalfFloat(Visit(operation[0]));
1438 const Id source = AsHalfFloat(Visit(operation[1]));
1439 const Id object = OpCompositeExtract(t_scalar_half, source, offset);
1440 return {OpCompositeInsert(t_half, object, target, offset), Type::HalfFloat};
825 } 1441 }
826 1442
827 Id HPack2(Operation operation) { 1443 Expression HPack2(Operation operation) {
828 UNIMPLEMENTED(); 1444 const Id low = GetHalfScalarFromFloat(AsFloat(Visit(operation[0])));
829 return {}; 1445 const Id high = GetHalfScalarFromFloat(AsFloat(Visit(operation[1])));
1446 return {OpCompositeConstruct(t_half, low, high), Type::HalfFloat};
830 } 1447 }
831 1448
832 Id LogicalAssign(Operation operation) { 1449 Expression LogicalAssign(Operation operation) {
833 const Node& dest = operation[0]; 1450 const Node& dest = operation[0];
834 const Node& src = operation[1]; 1451 const Node& src = operation[1];
835 1452
@@ -850,106 +1467,190 @@ private:
850 target = internal_flags.at(static_cast<u32>(flag->GetFlag())); 1467 target = internal_flags.at(static_cast<u32>(flag->GetFlag()));
851 } 1468 }
852 1469
853 Emit(OpStore(target, Visit(src))); 1470 OpStore(target, AsBool(Visit(src)));
854 return {}; 1471 return {};
855 } 1472 }
856 1473
857 Id LogicalPick2(Operation operation) { 1474 Id GetTextureSampler(Operation operation) {
858 UNIMPLEMENTED(); 1475 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
859 return {}; 1476 ASSERT(!meta.sampler.IsBuffer());
1477
1478 const auto& entry = sampled_images.at(meta.sampler.GetIndex());
1479 return OpLoad(entry.sampled_image_type, entry.sampler);
860 } 1480 }
861 1481
862 Id LogicalAnd2(Operation operation) { 1482 Id GetTextureImage(Operation operation) {
863 UNIMPLEMENTED(); 1483 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
864 return {}; 1484 const u32 index = meta.sampler.GetIndex();
1485 if (meta.sampler.IsBuffer()) {
1486 const auto& entry = texel_buffers.at(index);
1487 return OpLoad(entry.image_type, entry.image);
1488 } else {
1489 const auto& entry = sampled_images.at(index);
1490 return OpImage(entry.image_type, GetTextureSampler(operation));
1491 }
865 } 1492 }
866 1493
867 Id GetTextureSampler(Operation operation) { 1494 Id GetImage(Operation operation) {
868 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1495 const auto& meta = std::get<MetaImage>(operation.GetMeta());
869 const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex())); 1496 const auto entry = images.at(meta.image.GetIndex());
870 return Emit(OpLoad(entry.sampled_image_type, entry.sampler)); 1497 return OpLoad(entry.image_type, entry.image);
871 } 1498 }
872 1499
873 Id GetTextureImage(Operation operation) { 1500 Id AssembleVector(const std::vector<Id>& coords, Type type) {
874 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1501 const Id coords_type = GetTypeVectorDefinitionLut(type).at(coords.size() - 1);
875 const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex())); 1502 return coords.size() == 1 ? coords[0] : OpCompositeConstruct(coords_type, coords);
876 return Emit(OpImage(entry.image_type, GetTextureSampler(operation)));
877 } 1503 }
878 1504
879 Id GetTextureCoordinates(Operation operation) { 1505 Id GetCoordinates(Operation operation, Type type) {
880 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
881 std::vector<Id> coords; 1506 std::vector<Id> coords;
882 for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) { 1507 for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) {
883 coords.push_back(Visit(operation[i])); 1508 coords.push_back(As(Visit(operation[i]), type));
884 } 1509 }
885 if (meta->sampler.IsArray()) { 1510 if (const auto meta = std::get_if<MetaTexture>(&operation.GetMeta())) {
886 const Id array_integer = BitcastTo<Type::Int>(Visit(meta->array)); 1511 // Add array coordinate for textures
887 coords.push_back(Emit(OpConvertSToF(t_float, array_integer))); 1512 if (meta->sampler.IsArray()) {
1513 Id array = AsInt(Visit(meta->array));
1514 if (type == Type::Float) {
1515 array = OpConvertSToF(t_float, array);
1516 }
1517 coords.push_back(array);
1518 }
888 } 1519 }
889 if (meta->sampler.IsShadow()) { 1520 return AssembleVector(coords, type);
890 coords.push_back(Visit(meta->depth_compare)); 1521 }
1522
1523 Id GetOffsetCoordinates(Operation operation) {
1524 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
1525 std::vector<Id> coords;
1526 coords.reserve(meta.aoffi.size());
1527 for (const auto& coord : meta.aoffi) {
1528 coords.push_back(AsInt(Visit(coord)));
891 } 1529 }
1530 return AssembleVector(coords, Type::Int);
1531 }
892 1532
893 const std::array<Id, 4> t_float_lut = {nullptr, t_float2, t_float3, t_float4}; 1533 std::pair<Id, Id> GetDerivatives(Operation operation) {
894 return coords.size() == 1 1534 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
895 ? coords[0] 1535 const auto& derivatives = meta.derivates;
896 : Emit(OpCompositeConstruct(t_float_lut.at(coords.size() - 1), coords)); 1536 ASSERT(derivatives.size() % 2 == 0);
1537
1538 const std::size_t components = derivatives.size() / 2;
1539 std::vector<Id> dx, dy;
1540 dx.reserve(components);
1541 dy.reserve(components);
1542 for (std::size_t index = 0; index < components; ++index) {
1543 dx.push_back(AsFloat(Visit(derivatives.at(index * 2 + 0))));
1544 dy.push_back(AsFloat(Visit(derivatives.at(index * 2 + 1))));
1545 }
1546 return {AssembleVector(dx, Type::Float), AssembleVector(dy, Type::Float)};
897 } 1547 }
898 1548
899 Id GetTextureElement(Operation operation, Id sample_value) { 1549 Expression GetTextureElement(Operation operation, Id sample_value, Type type) {
900 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1550 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
901 ASSERT(meta); 1551 const auto type_def = GetTypeDefinition(type);
902 return Emit(OpCompositeExtract(t_float, sample_value, meta->element)); 1552 return {OpCompositeExtract(type_def, sample_value, meta.element), type};
903 } 1553 }
904 1554
905 Id Texture(Operation operation) { 1555 Expression Texture(Operation operation) {
906 const Id texture = Emit(OpImageSampleImplicitLod(t_float4, GetTextureSampler(operation), 1556 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
907 GetTextureCoordinates(operation))); 1557 UNIMPLEMENTED_IF(!meta.aoffi.empty());
908 return GetTextureElement(operation, texture); 1558
1559 const bool can_implicit = stage == ShaderType::Fragment;
1560 const Id sampler = GetTextureSampler(operation);
1561 const Id coords = GetCoordinates(operation, Type::Float);
1562
1563 if (meta.depth_compare) {
1564 // Depth sampling
1565 UNIMPLEMENTED_IF(meta.bias);
1566 const Id dref = AsFloat(Visit(meta.depth_compare));
1567 if (can_implicit) {
1568 return {OpImageSampleDrefImplicitLod(t_float, sampler, coords, dref, {}),
1569 Type::Float};
1570 } else {
1571 return {OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref,
1572 spv::ImageOperandsMask::Lod, v_float_zero),
1573 Type::Float};
1574 }
1575 }
1576
1577 std::vector<Id> operands;
1578 spv::ImageOperandsMask mask{};
1579 if (meta.bias) {
1580 mask = mask | spv::ImageOperandsMask::Bias;
1581 operands.push_back(AsFloat(Visit(meta.bias)));
1582 }
1583
1584 Id texture;
1585 if (can_implicit) {
1586 texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, operands);
1587 } else {
1588 texture = OpImageSampleExplicitLod(t_float4, sampler, coords,
1589 mask | spv::ImageOperandsMask::Lod, v_float_zero,
1590 operands);
1591 }
1592 return GetTextureElement(operation, texture, Type::Float);
909 } 1593 }
910 1594
911 Id TextureLod(Operation operation) { 1595 Expression TextureLod(Operation operation) {
912 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1596 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
913 const Id texture = Emit(OpImageSampleExplicitLod( 1597
914 t_float4, GetTextureSampler(operation), GetTextureCoordinates(operation), 1598 const Id sampler = GetTextureSampler(operation);
915 spv::ImageOperandsMask::Lod, Visit(meta->lod))); 1599 const Id coords = GetCoordinates(operation, Type::Float);
916 return GetTextureElement(operation, texture); 1600 const Id lod = AsFloat(Visit(meta.lod));
1601
1602 spv::ImageOperandsMask mask = spv::ImageOperandsMask::Lod;
1603 std::vector<Id> operands;
1604 if (!meta.aoffi.empty()) {
1605 mask = mask | spv::ImageOperandsMask::Offset;
1606 operands.push_back(GetOffsetCoordinates(operation));
1607 }
1608
1609 if (meta.sampler.IsShadow()) {
1610 const Id dref = AsFloat(Visit(meta.depth_compare));
1611 return {
1612 OpImageSampleDrefExplicitLod(t_float, sampler, coords, dref, mask, lod, operands),
1613 Type::Float};
1614 }
1615 const Id texture = OpImageSampleExplicitLod(t_float4, sampler, coords, mask, lod, operands);
1616 return GetTextureElement(operation, texture, Type::Float);
917 } 1617 }
918 1618
919 Id TextureGather(Operation operation) { 1619 Expression TextureGather(Operation operation) {
920 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1620 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
921 const auto coords = GetTextureCoordinates(operation); 1621 UNIMPLEMENTED_IF(!meta.aoffi.empty());
922 1622
923 Id texture; 1623 const Id coords = GetCoordinates(operation, Type::Float);
924 if (meta->sampler.IsShadow()) { 1624 Id texture{};
925 texture = Emit(OpImageDrefGather(t_float4, GetTextureSampler(operation), coords, 1625 if (meta.sampler.IsShadow()) {
926 Visit(meta->component))); 1626 texture = OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
1627 AsFloat(Visit(meta.depth_compare)));
927 } else { 1628 } else {
928 u32 component_value = 0; 1629 u32 component_value = 0;
929 if (meta->component) { 1630 if (meta.component) {
930 const auto component = std::get_if<ImmediateNode>(&*meta->component); 1631 const auto component = std::get_if<ImmediateNode>(&*meta.component);
931 ASSERT_MSG(component, "Component is not an immediate value"); 1632 ASSERT_MSG(component, "Component is not an immediate value");
932 component_value = component->GetValue(); 1633 component_value = component->GetValue();
933 } 1634 }
934 texture = Emit(OpImageGather(t_float4, GetTextureSampler(operation), coords, 1635 texture = OpImageGather(t_float4, GetTextureSampler(operation), coords,
935 Constant(t_uint, component_value))); 1636 Constant(t_uint, component_value));
936 } 1637 }
937 1638 return GetTextureElement(operation, texture, Type::Float);
938 return GetTextureElement(operation, texture);
939 } 1639 }
940 1640
941 Id TextureQueryDimensions(Operation operation) { 1641 Expression TextureQueryDimensions(Operation operation) {
942 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta()); 1642 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
943 const auto image_id = GetTextureImage(operation); 1643 UNIMPLEMENTED_IF(!meta.aoffi.empty());
944 AddCapability(spv::Capability::ImageQuery); 1644 UNIMPLEMENTED_IF(meta.depth_compare);
945 1645
946 if (meta->element == 3) { 1646 const auto image_id = GetTextureImage(operation);
947 return BitcastTo<Type::Float>(Emit(OpImageQueryLevels(t_int, image_id))); 1647 if (meta.element == 3) {
1648 return {OpImageQueryLevels(t_int, image_id), Type::Int};
948 } 1649 }
949 1650
950 const Id lod = VisitOperand<Type::Uint>(operation, 0); 1651 const Id lod = AsUint(Visit(operation[0]));
951 const std::size_t coords_count = [&]() { 1652 const std::size_t coords_count = [&]() {
952 switch (const auto type = meta->sampler.GetType(); type) { 1653 switch (const auto type = meta.sampler.GetType(); type) {
953 case Tegra::Shader::TextureType::Texture1D: 1654 case Tegra::Shader::TextureType::Texture1D:
954 return 1; 1655 return 1;
955 case Tegra::Shader::TextureType::Texture2D: 1656 case Tegra::Shader::TextureType::Texture2D:
@@ -963,141 +1664,190 @@ private:
963 } 1664 }
964 }(); 1665 }();
965 1666
966 if (meta->element >= coords_count) { 1667 if (meta.element >= coords_count) {
967 return Constant(t_float, 0.0f); 1668 return {v_float_zero, Type::Float};
968 } 1669 }
969 1670
970 const std::array<Id, 3> types = {t_int, t_int2, t_int3}; 1671 const std::array<Id, 3> types = {t_int, t_int2, t_int3};
971 const Id sizes = Emit(OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod)); 1672 const Id sizes = OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod);
972 const Id size = Emit(OpCompositeExtract(t_int, sizes, meta->element)); 1673 const Id size = OpCompositeExtract(t_int, sizes, meta.element);
973 return BitcastTo<Type::Float>(size); 1674 return {size, Type::Int};
974 } 1675 }
975 1676
976 Id TextureQueryLod(Operation operation) { 1677 Expression TextureQueryLod(Operation operation) {
977 UNIMPLEMENTED(); 1678 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
978 return {}; 1679 UNIMPLEMENTED_IF(!meta.aoffi.empty());
1680 UNIMPLEMENTED_IF(meta.depth_compare);
1681
1682 if (meta.element >= 2) {
1683 UNREACHABLE_MSG("Invalid element");
1684 return {v_float_zero, Type::Float};
1685 }
1686 const auto sampler_id = GetTextureSampler(operation);
1687
1688 const Id multiplier = Constant(t_float, 256.0f);
1689 const Id multipliers = ConstantComposite(t_float2, multiplier, multiplier);
1690
1691 const Id coords = GetCoordinates(operation, Type::Float);
1692 Id size = OpImageQueryLod(t_float2, sampler_id, coords);
1693 size = OpFMul(t_float2, size, multipliers);
1694 size = OpConvertFToS(t_int2, size);
1695 return GetTextureElement(operation, size, Type::Int);
979 } 1696 }
980 1697
981 Id TexelFetch(Operation operation) { 1698 Expression TexelFetch(Operation operation) {
1699 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
1700 UNIMPLEMENTED_IF(meta.depth_compare);
1701
1702 const Id image = GetTextureImage(operation);
1703 const Id coords = GetCoordinates(operation, Type::Int);
1704 Id fetch;
1705 if (meta.lod && !meta.sampler.IsBuffer()) {
1706 fetch = OpImageFetch(t_float4, image, coords, spv::ImageOperandsMask::Lod,
1707 AsInt(Visit(meta.lod)));
1708 } else {
1709 fetch = OpImageFetch(t_float4, image, coords);
1710 }
1711 return GetTextureElement(operation, fetch, Type::Float);
1712 }
1713
1714 Expression TextureGradient(Operation operation) {
1715 const auto& meta = std::get<MetaTexture>(operation.GetMeta());
1716 UNIMPLEMENTED_IF(!meta.aoffi.empty());
1717
1718 const Id sampler = GetTextureSampler(operation);
1719 const Id coords = GetCoordinates(operation, Type::Float);
1720 const auto [dx, dy] = GetDerivatives(operation);
1721 const std::vector grad = {dx, dy};
1722
1723 static constexpr auto mask = spv::ImageOperandsMask::Grad;
1724 const Id texture = OpImageSampleImplicitLod(t_float4, sampler, coords, mask, grad);
1725 return GetTextureElement(operation, texture, Type::Float);
1726 }
1727
1728 Expression ImageLoad(Operation operation) {
982 UNIMPLEMENTED(); 1729 UNIMPLEMENTED();
983 return {}; 1730 return {};
984 } 1731 }
985 1732
986 Id TextureGradient(Operation operation) { 1733 Expression ImageStore(Operation operation) {
987 UNIMPLEMENTED(); 1734 const auto meta{std::get<MetaImage>(operation.GetMeta())};
1735 std::vector<Id> colors;
1736 for (const auto& value : meta.values) {
1737 colors.push_back(AsUint(Visit(value)));
1738 }
1739
1740 const Id coords = GetCoordinates(operation, Type::Int);
1741 const Id texel = OpCompositeConstruct(t_uint4, colors);
1742
1743 OpImageWrite(GetImage(operation), coords, texel, {});
988 return {}; 1744 return {};
989 } 1745 }
990 1746
991 Id ImageLoad(Operation operation) { 1747 Expression AtomicImageAdd(Operation operation) {
992 UNIMPLEMENTED(); 1748 UNIMPLEMENTED();
993 return {}; 1749 return {};
994 } 1750 }
995 1751
996 Id ImageStore(Operation operation) { 1752 Expression AtomicImageMin(Operation operation) {
997 UNIMPLEMENTED(); 1753 UNIMPLEMENTED();
998 return {}; 1754 return {};
999 } 1755 }
1000 1756
1001 Id AtomicImageAdd(Operation operation) { 1757 Expression AtomicImageMax(Operation operation) {
1002 UNIMPLEMENTED(); 1758 UNIMPLEMENTED();
1003 return {}; 1759 return {};
1004 } 1760 }
1005 1761
1006 Id AtomicImageAnd(Operation operation) { 1762 Expression AtomicImageAnd(Operation operation) {
1007 UNIMPLEMENTED(); 1763 UNIMPLEMENTED();
1008 return {}; 1764 return {};
1009 } 1765 }
1010 1766
1011 Id AtomicImageOr(Operation operation) { 1767 Expression AtomicImageOr(Operation operation) {
1012 UNIMPLEMENTED(); 1768 UNIMPLEMENTED();
1013 return {}; 1769 return {};
1014 } 1770 }
1015 1771
1016 Id AtomicImageXor(Operation operation) { 1772 Expression AtomicImageXor(Operation operation) {
1017 UNIMPLEMENTED(); 1773 UNIMPLEMENTED();
1018 return {}; 1774 return {};
1019 } 1775 }
1020 1776
1021 Id AtomicImageExchange(Operation operation) { 1777 Expression AtomicImageExchange(Operation operation) {
1022 UNIMPLEMENTED(); 1778 UNIMPLEMENTED();
1023 return {}; 1779 return {};
1024 } 1780 }
1025 1781
1026 Id Branch(Operation operation) { 1782 Expression Branch(Operation operation) {
1027 const auto target = std::get_if<ImmediateNode>(&*operation[0]); 1783 const auto& target = std::get<ImmediateNode>(*operation[0]);
1028 UNIMPLEMENTED_IF(!target); 1784 OpStore(jmp_to, Constant(t_uint, target.GetValue()));
1029 1785 OpBranch(continue_label);
1030 Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue()))); 1786 inside_branch = true;
1031 Emit(OpBranch(continue_label)); 1787 if (!conditional_branch_set) {
1032 inside_branch = conditional_nest_count; 1788 AddLabel();
1033 if (conditional_nest_count == 0) {
1034 Emit(OpLabel());
1035 } 1789 }
1036 return {}; 1790 return {};
1037 } 1791 }
1038 1792
1039 Id BranchIndirect(Operation operation) { 1793 Expression BranchIndirect(Operation operation) {
1040 const Id op_a = VisitOperand<Type::Uint>(operation, 0); 1794 const Id op_a = AsUint(Visit(operation[0]));
1041 1795
1042 Emit(OpStore(jmp_to, op_a)); 1796 OpStore(jmp_to, op_a);
1043 Emit(OpBranch(continue_label)); 1797 OpBranch(continue_label);
1044 inside_branch = conditional_nest_count; 1798 inside_branch = true;
1045 if (conditional_nest_count == 0) { 1799 if (!conditional_branch_set) {
1046 Emit(OpLabel()); 1800 AddLabel();
1047 } 1801 }
1048 return {}; 1802 return {};
1049 } 1803 }
1050 1804
1051 Id PushFlowStack(Operation operation) { 1805 Expression PushFlowStack(Operation operation) {
1052 const auto target = std::get_if<ImmediateNode>(&*operation[0]); 1806 const auto& target = std::get<ImmediateNode>(*operation[0]);
1053 ASSERT(target);
1054
1055 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation); 1807 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
1056 const Id current = Emit(OpLoad(t_uint, flow_stack_top)); 1808 const Id current = OpLoad(t_uint, flow_stack_top);
1057 const Id next = Emit(OpIAdd(t_uint, current, Constant(t_uint, 1))); 1809 const Id next = OpIAdd(t_uint, current, Constant(t_uint, 1));
1058 const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, current)); 1810 const Id access = OpAccessChain(t_func_uint, flow_stack, current);
1059 1811
1060 Emit(OpStore(access, Constant(t_uint, target->GetValue()))); 1812 OpStore(access, Constant(t_uint, target.GetValue()));
1061 Emit(OpStore(flow_stack_top, next)); 1813 OpStore(flow_stack_top, next);
1062 return {}; 1814 return {};
1063 } 1815 }
1064 1816
1065 Id PopFlowStack(Operation operation) { 1817 Expression PopFlowStack(Operation operation) {
1066 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation); 1818 const auto [flow_stack, flow_stack_top] = GetFlowStack(operation);
1067 const Id current = Emit(OpLoad(t_uint, flow_stack_top)); 1819 const Id current = OpLoad(t_uint, flow_stack_top);
1068 const Id previous = Emit(OpISub(t_uint, current, Constant(t_uint, 1))); 1820 const Id previous = OpISub(t_uint, current, Constant(t_uint, 1));
1069 const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, previous)); 1821 const Id access = OpAccessChain(t_func_uint, flow_stack, previous);
1070 const Id target = Emit(OpLoad(t_uint, access)); 1822 const Id target = OpLoad(t_uint, access);
1071 1823
1072 Emit(OpStore(flow_stack_top, previous)); 1824 OpStore(flow_stack_top, previous);
1073 Emit(OpStore(jmp_to, target)); 1825 OpStore(jmp_to, target);
1074 Emit(OpBranch(continue_label)); 1826 OpBranch(continue_label);
1075 inside_branch = conditional_nest_count; 1827 inside_branch = true;
1076 if (conditional_nest_count == 0) { 1828 if (!conditional_branch_set) {
1077 Emit(OpLabel()); 1829 AddLabel();
1078 } 1830 }
1079 return {}; 1831 return {};
1080 } 1832 }
1081 1833
1082 Id PreExit() { 1834 void PreExit() {
1083 switch (stage) { 1835 if (stage == ShaderType::Vertex) {
1084 case ShaderType::Vertex: { 1836 const u32 position_index = out_indices.position.value();
1085 // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't 1837 const Id z_pointer = AccessElement(t_out_float, out_vertex, position_index, 2U);
1086 // seem to be working on Nvidia's drivers and Intel (mesa and blob) doesn't support it. 1838 const Id w_pointer = AccessElement(t_out_float, out_vertex, position_index, 3U);
1087 const Id z_pointer = AccessElement(t_out_float, per_vertex, position_index, 2u); 1839 Id depth = OpLoad(t_float, z_pointer);
1088 Id depth = Emit(OpLoad(t_float, z_pointer)); 1840 depth = OpFAdd(t_float, depth, OpLoad(t_float, w_pointer));
1089 depth = Emit(OpFAdd(t_float, depth, Constant(t_float, 1.0f))); 1841 depth = OpFMul(t_float, depth, Constant(t_float, 0.5f));
1090 depth = Emit(OpFMul(t_float, depth, Constant(t_float, 0.5f))); 1842 OpStore(z_pointer, depth);
1091 Emit(OpStore(z_pointer, depth));
1092 break;
1093 } 1843 }
1094 case ShaderType::Fragment: { 1844 if (stage == ShaderType::Fragment) {
1095 const auto SafeGetRegister = [&](u32 reg) { 1845 const auto SafeGetRegister = [&](u32 reg) {
1096 // TODO(Rodrigo): Replace with contains once C++20 releases 1846 // TODO(Rodrigo): Replace with contains once C++20 releases
1097 if (const auto it = registers.find(reg); it != registers.end()) { 1847 if (const auto it = registers.find(reg); it != registers.end()) {
1098 return Emit(OpLoad(t_float, it->second)); 1848 return OpLoad(t_float, it->second);
1099 } 1849 }
1100 return Constant(t_float, 0.0f); 1850 return v_float_zero;
1101 }; 1851 };
1102 1852
1103 UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0, 1853 UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0,
@@ -1112,8 +1862,8 @@ private:
1112 // TODO(Subv): Figure out how dual-source blending is configured in the Switch. 1862 // TODO(Subv): Figure out how dual-source blending is configured in the Switch.
1113 for (u32 component = 0; component < 4; ++component) { 1863 for (u32 component = 0; component < 4; ++component) {
1114 if (header.ps.IsColorComponentOutputEnabled(rt, component)) { 1864 if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
1115 Emit(OpStore(AccessElement(t_out_float, frag_colors.at(rt), component), 1865 OpStore(AccessElement(t_out_float, frag_colors.at(rt), component),
1116 SafeGetRegister(current_reg))); 1866 SafeGetRegister(current_reg));
1117 ++current_reg; 1867 ++current_reg;
1118 } 1868 }
1119 } 1869 }
@@ -1121,110 +1871,117 @@ private:
1121 if (header.ps.omap.depth) { 1871 if (header.ps.omap.depth) {
1122 // The depth output is always 2 registers after the last color output, and 1872 // The depth output is always 2 registers after the last color output, and
1123 // current_reg already contains one past the last color register. 1873 // current_reg already contains one past the last color register.
1124 Emit(OpStore(frag_depth, SafeGetRegister(current_reg + 1))); 1874 OpStore(frag_depth, SafeGetRegister(current_reg + 1));
1125 } 1875 }
1126 break;
1127 }
1128 } 1876 }
1129
1130 return {};
1131 } 1877 }
1132 1878
1133 Id Exit(Operation operation) { 1879 Expression Exit(Operation operation) {
1134 PreExit(); 1880 PreExit();
1135 inside_branch = conditional_nest_count; 1881 inside_branch = true;
1136 if (conditional_nest_count > 0) { 1882 if (conditional_branch_set) {
1137 Emit(OpReturn()); 1883 OpReturn();
1138 } else { 1884 } else {
1139 const Id dummy = OpLabel(); 1885 const Id dummy = OpLabel();
1140 Emit(OpBranch(dummy)); 1886 OpBranch(dummy);
1141 Emit(dummy); 1887 AddLabel(dummy);
1142 Emit(OpReturn()); 1888 OpReturn();
1143 Emit(OpLabel()); 1889 AddLabel();
1144 } 1890 }
1145 return {}; 1891 return {};
1146 } 1892 }
1147 1893
1148 Id Discard(Operation operation) { 1894 Expression Discard(Operation operation) {
1149 inside_branch = conditional_nest_count; 1895 inside_branch = true;
1150 if (conditional_nest_count > 0) { 1896 if (conditional_branch_set) {
1151 Emit(OpKill()); 1897 OpKill();
1152 } else { 1898 } else {
1153 const Id dummy = OpLabel(); 1899 const Id dummy = OpLabel();
1154 Emit(OpBranch(dummy)); 1900 OpBranch(dummy);
1155 Emit(dummy); 1901 AddLabel(dummy);
1156 Emit(OpKill()); 1902 OpKill();
1157 Emit(OpLabel()); 1903 AddLabel();
1158 } 1904 }
1159 return {}; 1905 return {};
1160 } 1906 }
1161 1907
1162 Id EmitVertex(Operation operation) { 1908 Expression EmitVertex(Operation) {
1163 UNIMPLEMENTED(); 1909 OpEmitVertex();
1164 return {}; 1910 return {};
1165 } 1911 }
1166 1912
1167 Id EndPrimitive(Operation operation) { 1913 Expression EndPrimitive(Operation operation) {
1168 UNIMPLEMENTED(); 1914 OpEndPrimitive();
1169 return {}; 1915 return {};
1170 } 1916 }
1171 1917
1172 Id YNegate(Operation operation) { 1918 Expression InvocationId(Operation) {
1173 UNIMPLEMENTED(); 1919 return {OpLoad(t_int, invocation_id), Type::Int};
1174 return {};
1175 } 1920 }
1176 1921
1177 template <u32 element> 1922 Expression YNegate(Operation) {
1178 Id LocalInvocationId(Operation) {
1179 UNIMPLEMENTED(); 1923 UNIMPLEMENTED();
1180 return {}; 1924 return {Constant(t_float, 1.0f), Type::Float};
1181 } 1925 }
1182 1926
1183 template <u32 element> 1927 template <u32 element>
1184 Id WorkGroupId(Operation) { 1928 Expression LocalInvocationId(Operation) {
1185 UNIMPLEMENTED(); 1929 const Id id = OpLoad(t_uint3, local_invocation_id);
1186 return {}; 1930 return {OpCompositeExtract(t_uint, id, element), Type::Uint};
1187 } 1931 }
1188 1932
1189 Id BallotThread(Operation) { 1933 template <u32 element>
1190 UNIMPLEMENTED(); 1934 Expression WorkGroupId(Operation operation) {
1191 return {}; 1935 const Id id = OpLoad(t_uint3, workgroup_id);
1936 return {OpCompositeExtract(t_uint, id, element), Type::Uint};
1192 } 1937 }
1193 1938
1194 Id VoteAll(Operation) { 1939 Expression BallotThread(Operation operation) {
1195 UNIMPLEMENTED(); 1940 const Id predicate = AsBool(Visit(operation[0]));
1196 return {}; 1941 const Id ballot = OpSubgroupBallotKHR(t_uint4, predicate);
1197 }
1198 1942
1199 Id VoteAny(Operation) { 1943 if (!device.IsWarpSizePotentiallyBiggerThanGuest()) {
1200 UNIMPLEMENTED(); 1944 // Guest-like devices can just return the first index.
1201 return {}; 1945 return {OpCompositeExtract(t_uint, ballot, 0U), Type::Uint};
1946 }
1947
1948 // The others will have to return what is local to the current thread.
1949 // For instance a device with a warp size of 64 will return the upper uint when the current
1950 // thread is 38.
1951 const Id tid = OpLoad(t_uint, thread_id);
1952 const Id thread_index = OpShiftRightLogical(t_uint, tid, Constant(t_uint, 5));
1953 return {OpVectorExtractDynamic(t_uint, ballot, thread_index), Type::Uint};
1202 } 1954 }
1203 1955
1204 Id VoteEqual(Operation) { 1956 template <Id (Module::*func)(Id, Id)>
1205 UNIMPLEMENTED(); 1957 Expression Vote(Operation operation) {
1206 return {}; 1958 // TODO(Rodrigo): Handle devices with different warp sizes
1959 const Id predicate = AsBool(Visit(operation[0]));
1960 return {(this->*func)(t_bool, predicate), Type::Bool};
1207 } 1961 }
1208 1962
1209 Id ThreadId(Operation) { 1963 Expression ThreadId(Operation) {
1210 UNIMPLEMENTED(); 1964 return {OpLoad(t_uint, thread_id), Type::Uint};
1211 return {};
1212 } 1965 }
1213 1966
1214 Id ShuffleIndexed(Operation) { 1967 Expression ShuffleIndexed(Operation operation) {
1215 UNIMPLEMENTED(); 1968 const Id value = AsFloat(Visit(operation[0]));
1216 return {}; 1969 const Id index = AsUint(Visit(operation[1]));
1970 return {OpSubgroupReadInvocationKHR(t_float, value, index), Type::Float};
1217 } 1971 }
1218 1972
1219 Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, 1973 Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type, std::string name) {
1220 const std::string& name) {
1221 const Id id = OpVariable(type, storage); 1974 const Id id = OpVariable(type, storage);
1222 Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin)); 1975 Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin));
1223 AddGlobalVariable(Name(id, name)); 1976 AddGlobalVariable(Name(id, std::move(name)));
1224 interfaces.push_back(id); 1977 interfaces.push_back(id);
1225 return id; 1978 return id;
1226 } 1979 }
1227 1980
1981 Id DeclareInputBuiltIn(spv::BuiltIn builtin, Id type, std::string name) {
1982 return DeclareBuiltIn(builtin, spv::StorageClass::Input, type, std::move(name));
1983 }
1984
1228 bool IsRenderTargetUsed(u32 rt) const { 1985 bool IsRenderTargetUsed(u32 rt) const {
1229 for (u32 component = 0; component < 4; ++component) { 1986 for (u32 component = 0; component < 4; ++component) {
1230 if (header.ps.IsColorComponentOutputEnabled(rt, component)) { 1987 if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
@@ -1242,66 +1999,148 @@ private:
1242 members.push_back(Constant(t_uint, element)); 1999 members.push_back(Constant(t_uint, element));
1243 } 2000 }
1244 2001
1245 return Emit(OpAccessChain(pointer_type, composite, members)); 2002 return OpAccessChain(pointer_type, composite, members);
1246 } 2003 }
1247 2004
1248 template <Type type> 2005 Id As(Expression expr, Type wanted_type) {
1249 Id VisitOperand(Operation operation, std::size_t operand_index) { 2006 switch (wanted_type) {
1250 const Id value = Visit(operation[operand_index]);
1251
1252 switch (type) {
1253 case Type::Bool: 2007 case Type::Bool:
2008 return AsBool(expr);
1254 case Type::Bool2: 2009 case Type::Bool2:
2010 return AsBool2(expr);
1255 case Type::Float: 2011 case Type::Float:
1256 return value; 2012 return AsFloat(expr);
1257 case Type::Int: 2013 case Type::Int:
1258 return Emit(OpBitcast(t_int, value)); 2014 return AsInt(expr);
1259 case Type::Uint: 2015 case Type::Uint:
1260 return Emit(OpBitcast(t_uint, value)); 2016 return AsUint(expr);
1261 case Type::HalfFloat: 2017 case Type::HalfFloat:
1262 UNIMPLEMENTED(); 2018 return AsHalfFloat(expr);
2019 default:
2020 UNREACHABLE();
2021 return expr.id;
1263 } 2022 }
1264 UNREACHABLE();
1265 return value;
1266 } 2023 }
1267 2024
1268 template <Type type> 2025 Id AsBool(Expression expr) {
1269 Id BitcastFrom(Id value) { 2026 ASSERT(expr.type == Type::Bool);
1270 switch (type) { 2027 return expr.id;
1271 case Type::Bool: 2028 }
1272 case Type::Bool2: 2029
2030 Id AsBool2(Expression expr) {
2031 ASSERT(expr.type == Type::Bool2);
2032 return expr.id;
2033 }
2034
2035 Id AsFloat(Expression expr) {
2036 switch (expr.type) {
1273 case Type::Float: 2037 case Type::Float:
1274 return value; 2038 return expr.id;
1275 case Type::Int: 2039 case Type::Int:
1276 case Type::Uint: 2040 case Type::Uint:
1277 return Emit(OpBitcast(t_float, value)); 2041 return OpBitcast(t_float, expr.id);
1278 case Type::HalfFloat: 2042 case Type::HalfFloat:
1279 UNIMPLEMENTED(); 2043 if (device.IsFloat16Supported()) {
2044 return OpBitcast(t_float, expr.id);
2045 }
2046 return OpBitcast(t_float, OpPackHalf2x16(t_uint, expr.id));
2047 default:
2048 UNREACHABLE();
2049 return expr.id;
1280 } 2050 }
1281 UNREACHABLE();
1282 return value;
1283 } 2051 }
1284 2052
1285 template <Type type> 2053 Id AsInt(Expression expr) {
1286 Id BitcastTo(Id value) { 2054 switch (expr.type) {
1287 switch (type) { 2055 case Type::Int:
1288 case Type::Bool: 2056 return expr.id;
1289 case Type::Bool2: 2057 case Type::Float:
2058 case Type::Uint:
2059 return OpBitcast(t_int, expr.id);
2060 case Type::HalfFloat:
2061 if (device.IsFloat16Supported()) {
2062 return OpBitcast(t_int, expr.id);
2063 }
2064 return OpPackHalf2x16(t_int, expr.id);
2065 default:
1290 UNREACHABLE(); 2066 UNREACHABLE();
2067 return expr.id;
2068 }
2069 }
2070
2071 Id AsUint(Expression expr) {
2072 switch (expr.type) {
2073 case Type::Uint:
2074 return expr.id;
1291 case Type::Float: 2075 case Type::Float:
1292 return Emit(OpBitcast(t_float, value));
1293 case Type::Int: 2076 case Type::Int:
1294 return Emit(OpBitcast(t_int, value)); 2077 return OpBitcast(t_uint, expr.id);
1295 case Type::Uint:
1296 return Emit(OpBitcast(t_uint, value));
1297 case Type::HalfFloat: 2078 case Type::HalfFloat:
1298 UNIMPLEMENTED(); 2079 if (device.IsFloat16Supported()) {
2080 return OpBitcast(t_uint, expr.id);
2081 }
2082 return OpPackHalf2x16(t_uint, expr.id);
2083 default:
2084 UNREACHABLE();
2085 return expr.id;
2086 }
2087 }
2088
2089 Id AsHalfFloat(Expression expr) {
2090 switch (expr.type) {
2091 case Type::HalfFloat:
2092 return expr.id;
2093 case Type::Float:
2094 case Type::Int:
2095 case Type::Uint:
2096 if (device.IsFloat16Supported()) {
2097 return OpBitcast(t_half, expr.id);
2098 }
2099 return OpUnpackHalf2x16(t_half, AsUint(expr));
2100 default:
2101 UNREACHABLE();
2102 return expr.id;
2103 }
2104 }
2105
2106 Id GetHalfScalarFromFloat(Id value) {
2107 if (device.IsFloat16Supported()) {
2108 return OpFConvert(t_scalar_half, value);
1299 } 2109 }
1300 UNREACHABLE();
1301 return value; 2110 return value;
1302 } 2111 }
1303 2112
1304 Id GetTypeDefinition(Type type) { 2113 Id GetFloatFromHalfScalar(Id value) {
2114 if (device.IsFloat16Supported()) {
2115 return OpFConvert(t_float, value);
2116 }
2117 return value;
2118 }
2119
2120 AttributeType GetAttributeType(u32 location) const {
2121 if (stage != ShaderType::Vertex) {
2122 return {Type::Float, t_in_float, t_in_float4};
2123 }
2124 switch (specialization.attribute_types.at(location)) {
2125 case Maxwell::VertexAttribute::Type::SignedNorm:
2126 case Maxwell::VertexAttribute::Type::UnsignedNorm:
2127 case Maxwell::VertexAttribute::Type::Float:
2128 return {Type::Float, t_in_float, t_in_float4};
2129 case Maxwell::VertexAttribute::Type::SignedInt:
2130 return {Type::Int, t_in_int, t_in_int4};
2131 case Maxwell::VertexAttribute::Type::UnsignedInt:
2132 return {Type::Uint, t_in_uint, t_in_uint4};
2133 case Maxwell::VertexAttribute::Type::UnsignedScaled:
2134 case Maxwell::VertexAttribute::Type::SignedScaled:
2135 UNIMPLEMENTED();
2136 return {Type::Float, t_in_float, t_in_float4};
2137 default:
2138 UNREACHABLE();
2139 return {Type::Float, t_in_float, t_in_float4};
2140 }
2141 }
2142
2143 Id GetTypeDefinition(Type type) const {
1305 switch (type) { 2144 switch (type) {
1306 case Type::Bool: 2145 case Type::Bool:
1307 return t_bool; 2146 return t_bool;
@@ -1314,10 +2153,25 @@ private:
1314 case Type::Uint: 2153 case Type::Uint:
1315 return t_uint; 2154 return t_uint;
1316 case Type::HalfFloat: 2155 case Type::HalfFloat:
2156 return t_half;
2157 default:
2158 UNREACHABLE();
2159 return {};
2160 }
2161 }
2162
2163 std::array<Id, 4> GetTypeVectorDefinitionLut(Type type) const {
2164 switch (type) {
2165 case Type::Float:
2166 return {nullptr, t_float2, t_float3, t_float4};
2167 case Type::Int:
2168 return {nullptr, t_int2, t_int3, t_int4};
2169 case Type::Uint:
2170 return {nullptr, t_uint2, t_uint3, t_uint4};
2171 default:
1317 UNIMPLEMENTED(); 2172 UNIMPLEMENTED();
2173 return {};
1318 } 2174 }
1319 UNREACHABLE();
1320 return {};
1321 } 2175 }
1322 2176
1323 std::tuple<Id, Id> CreateFlowStack() { 2177 std::tuple<Id, Id> CreateFlowStack() {
@@ -1327,9 +2181,11 @@ private:
1327 constexpr auto storage_class = spv::StorageClass::Function; 2181 constexpr auto storage_class = spv::StorageClass::Function;
1328 2182
1329 const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE)); 2183 const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE));
1330 const Id stack = Emit(OpVariable(TypePointer(storage_class, flow_stack_type), storage_class, 2184 const Id stack = OpVariable(TypePointer(storage_class, flow_stack_type), storage_class,
1331 ConstantNull(flow_stack_type))); 2185 ConstantNull(flow_stack_type));
1332 const Id top = Emit(OpVariable(t_func_uint, storage_class, Constant(t_uint, 0))); 2186 const Id top = OpVariable(t_func_uint, storage_class, Constant(t_uint, 0));
2187 AddLocalVariable(stack);
2188 AddLocalVariable(top);
1333 return std::tie(stack, top); 2189 return std::tie(stack, top);
1334 } 2190 }
1335 2191
@@ -1358,8 +2214,8 @@ private:
1358 &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>, 2214 &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
1359 &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>, 2215 &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
1360 &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>, 2216 &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
1361 &SPIRVDecompiler::FCastHalf0, 2217 &SPIRVDecompiler::FCastHalf<0>,
1362 &SPIRVDecompiler::FCastHalf1, 2218 &SPIRVDecompiler::FCastHalf<1>,
1363 &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>, 2219 &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
1364 &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>, 2220 &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
1365 &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>, 2221 &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
@@ -1407,7 +2263,7 @@ private:
1407 &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>, 2263 &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
1408 &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>, 2264 &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
1409 &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>, 2265 &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
1410 &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>, 2266 &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
1411 &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>, 2267 &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
1412 &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>, 2268 &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
1413 &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>, 2269 &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
@@ -1426,8 +2282,8 @@ private:
1426 &SPIRVDecompiler::HCastFloat, 2282 &SPIRVDecompiler::HCastFloat,
1427 &SPIRVDecompiler::HUnpack, 2283 &SPIRVDecompiler::HUnpack,
1428 &SPIRVDecompiler::HMergeF32, 2284 &SPIRVDecompiler::HMergeF32,
1429 &SPIRVDecompiler::HMergeH0, 2285 &SPIRVDecompiler::HMergeHN<0>,
1430 &SPIRVDecompiler::HMergeH1, 2286 &SPIRVDecompiler::HMergeHN<1>,
1431 &SPIRVDecompiler::HPack2, 2287 &SPIRVDecompiler::HPack2,
1432 2288
1433 &SPIRVDecompiler::LogicalAssign, 2289 &SPIRVDecompiler::LogicalAssign,
@@ -1435,8 +2291,9 @@ private:
1435 &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>, 2291 &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
1436 &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>, 2292 &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
1437 &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>, 2293 &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
1438 &SPIRVDecompiler::LogicalPick2, 2294 &SPIRVDecompiler::Binary<&Module::OpVectorExtractDynamic, Type::Bool, Type::Bool2,
1439 &SPIRVDecompiler::LogicalAnd2, 2295 Type::Uint>,
2296 &SPIRVDecompiler::Unary<&Module::OpAll, Type::Bool, Type::Bool2>,
1440 2297
1441 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>, 2298 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
1442 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>, 2299 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
@@ -1444,7 +2301,7 @@ private:
1444 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>, 2301 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
1445 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>, 2302 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
1446 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>, 2303 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
1447 &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>, 2304 &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool, Type::Float>,
1448 2305
1449 &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>, 2306 &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
1450 &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>, 2307 &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
@@ -1460,19 +2317,19 @@ private:
1460 &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>, 2317 &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
1461 &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>, 2318 &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,
1462 2319
1463 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>, 2320 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
1464 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>, 2321 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
1465 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>, 2322 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
1466 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>, 2323 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
1467 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>, 2324 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
1468 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>, 2325 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
1469 // TODO(Rodrigo): Should these use the OpFUnord* variants? 2326 // TODO(Rodrigo): Should these use the OpFUnord* variants?
1470 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>, 2327 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool2, Type::HalfFloat>,
1471 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>, 2328 &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool2, Type::HalfFloat>,
1472 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>, 2329 &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool2, Type::HalfFloat>,
1473 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>, 2330 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool2, Type::HalfFloat>,
1474 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>, 2331 &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool2, Type::HalfFloat>,
1475 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>, 2332 &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool2, Type::HalfFloat>,
1476 2333
1477 &SPIRVDecompiler::Texture, 2334 &SPIRVDecompiler::Texture,
1478 &SPIRVDecompiler::TextureLod, 2335 &SPIRVDecompiler::TextureLod,
@@ -1509,9 +2366,9 @@ private:
1509 &SPIRVDecompiler::WorkGroupId<2>, 2366 &SPIRVDecompiler::WorkGroupId<2>,
1510 2367
1511 &SPIRVDecompiler::BallotThread, 2368 &SPIRVDecompiler::BallotThread,
1512 &SPIRVDecompiler::VoteAll, 2369 &SPIRVDecompiler::Vote<&Module::OpSubgroupAllKHR>,
1513 &SPIRVDecompiler::VoteAny, 2370 &SPIRVDecompiler::Vote<&Module::OpSubgroupAnyKHR>,
1514 &SPIRVDecompiler::VoteEqual, 2371 &SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
1515 2372
1516 &SPIRVDecompiler::ThreadId, 2373 &SPIRVDecompiler::ThreadId,
1517 &SPIRVDecompiler::ShuffleIndexed, 2374 &SPIRVDecompiler::ShuffleIndexed,
@@ -1522,8 +2379,7 @@ private:
1522 const ShaderIR& ir; 2379 const ShaderIR& ir;
1523 const ShaderType stage; 2380 const ShaderType stage;
1524 const Tegra::Shader::Header header; 2381 const Tegra::Shader::Header header;
1525 u64 conditional_nest_count{}; 2382 const Specialization& specialization;
1526 u64 inside_branch{};
1527 2383
1528 const Id t_void = Name(TypeVoid(), "void"); 2384 const Id t_void = Name(TypeVoid(), "void");
1529 2385
@@ -1551,20 +2407,28 @@ private:
1551 const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint"); 2407 const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint");
1552 2408
1553 const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool"); 2409 const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool");
2410 const Id t_in_int = Name(TypePointer(spv::StorageClass::Input, t_int), "in_int");
2411 const Id t_in_int4 = Name(TypePointer(spv::StorageClass::Input, t_int4), "in_int4");
1554 const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint"); 2412 const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint");
2413 const Id t_in_uint3 = Name(TypePointer(spv::StorageClass::Input, t_uint3), "in_uint3");
2414 const Id t_in_uint4 = Name(TypePointer(spv::StorageClass::Input, t_uint4), "in_uint4");
1555 const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float"); 2415 const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float");
2416 const Id t_in_float2 = Name(TypePointer(spv::StorageClass::Input, t_float2), "in_float2");
2417 const Id t_in_float3 = Name(TypePointer(spv::StorageClass::Input, t_float3), "in_float3");
1556 const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4"); 2418 const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4");
1557 2419
2420 const Id t_out_int = Name(TypePointer(spv::StorageClass::Output, t_int), "out_int");
2421
1558 const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float"); 2422 const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float");
1559 const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4"); 2423 const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4");
1560 2424
1561 const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float); 2425 const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float);
1562 const Id t_cbuf_std140 = Decorate( 2426 const Id t_cbuf_std140 = Decorate(
1563 Name(TypeArray(t_float4, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS)), "CbufStd140Array"), 2427 Name(TypeArray(t_float4, Constant(t_uint, MaxConstBufferElements)), "CbufStd140Array"),
1564 spv::Decoration::ArrayStride, 16u); 2428 spv::Decoration::ArrayStride, 16U);
1565 const Id t_cbuf_scalar = Decorate( 2429 const Id t_cbuf_scalar = Decorate(
1566 Name(TypeArray(t_float, Constant(t_uint, MAX_CONSTBUFFER_FLOATS)), "CbufScalarArray"), 2430 Name(TypeArray(t_float, Constant(t_uint, MaxConstBufferFloats)), "CbufScalarArray"),
1567 spv::Decoration::ArrayStride, 4u); 2431 spv::Decoration::ArrayStride, 4U);
1568 const Id t_cbuf_std140_struct = MemberDecorate( 2432 const Id t_cbuf_std140_struct = MemberDecorate(
1569 Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0); 2433 Decorate(TypeStruct(t_cbuf_std140), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
1570 const Id t_cbuf_scalar_struct = MemberDecorate( 2434 const Id t_cbuf_scalar_struct = MemberDecorate(
@@ -1572,28 +2436,43 @@ private:
1572 const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct); 2436 const Id t_cbuf_std140_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_std140_struct);
1573 const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct); 2437 const Id t_cbuf_scalar_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_scalar_struct);
1574 2438
2439 Id t_smem_uint{};
2440
1575 const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float); 2441 const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float);
1576 const Id t_gmem_array = 2442 const Id t_gmem_array =
1577 Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4u), "GmemArray"); 2443 Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4U), "GmemArray");
1578 const Id t_gmem_struct = MemberDecorate( 2444 const Id t_gmem_struct = MemberDecorate(
1579 Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0); 2445 Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
1580 const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct); 2446 const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct);
1581 2447
1582 const Id v_float_zero = Constant(t_float, 0.0f); 2448 const Id v_float_zero = Constant(t_float, 0.0f);
2449 const Id v_float_one = Constant(t_float, 1.0f);
2450
2451 // Nvidia uses these defaults for varyings (e.g. position and generic attributes)
2452 const Id v_varying_default =
2453 ConstantComposite(t_float4, v_float_zero, v_float_zero, v_float_zero, v_float_one);
2454
1583 const Id v_true = ConstantTrue(t_bool); 2455 const Id v_true = ConstantTrue(t_bool);
1584 const Id v_false = ConstantFalse(t_bool); 2456 const Id v_false = ConstantFalse(t_bool);
1585 2457
1586 Id per_vertex{}; 2458 Id t_scalar_half{};
2459 Id t_half{};
2460
2461 Id out_vertex{};
2462 Id in_vertex{};
1587 std::map<u32, Id> registers; 2463 std::map<u32, Id> registers;
1588 std::map<Tegra::Shader::Pred, Id> predicates; 2464 std::map<Tegra::Shader::Pred, Id> predicates;
1589 std::map<u32, Id> flow_variables; 2465 std::map<u32, Id> flow_variables;
1590 Id local_memory{}; 2466 Id local_memory{};
2467 Id shared_memory{};
1591 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{}; 2468 std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
1592 std::map<Attribute::Index, Id> input_attributes; 2469 std::map<Attribute::Index, Id> input_attributes;
1593 std::map<Attribute::Index, Id> output_attributes; 2470 std::map<Attribute::Index, Id> output_attributes;
1594 std::map<u32, Id> constant_buffers; 2471 std::map<u32, Id> constant_buffers;
1595 std::map<GlobalMemoryBase, Id> global_buffers; 2472 std::map<GlobalMemoryBase, Id> global_buffers;
1596 std::map<u32, SamplerImage> sampler_images; 2473 std::map<u32, TexelBuffer> texel_buffers;
2474 std::map<u32, SampledImage> sampled_images;
2475 std::map<u32, StorageImage> images;
1597 2476
1598 Id instance_index{}; 2477 Id instance_index{};
1599 Id vertex_index{}; 2478 Id vertex_index{};
@@ -1601,18 +2480,20 @@ private:
1601 Id frag_depth{}; 2480 Id frag_depth{};
1602 Id frag_coord{}; 2481 Id frag_coord{};
1603 Id front_facing{}; 2482 Id front_facing{};
1604 2483 Id point_coord{};
1605 u32 position_index{}; 2484 Id tess_level_outer{};
1606 u32 point_size_index{}; 2485 Id tess_level_inner{};
1607 u32 clip_distances_index{}; 2486 Id tess_coord{};
2487 Id invocation_id{};
2488 Id workgroup_id{};
2489 Id local_invocation_id{};
2490 Id thread_id{};
2491
2492 VertexIndices in_indices;
2493 VertexIndices out_indices;
1608 2494
1609 std::vector<Id> interfaces; 2495 std::vector<Id> interfaces;
1610 2496
1611 u32 const_buffers_base_binding{};
1612 u32 global_buffers_base_binding{};
1613 u32 samplers_base_binding{};
1614
1615 Id execute_function{};
1616 Id jmp_to{}; 2497 Id jmp_to{};
1617 Id ssy_flow_stack_top{}; 2498 Id ssy_flow_stack_top{};
1618 Id pbk_flow_stack_top{}; 2499 Id pbk_flow_stack_top{};
@@ -1620,6 +2501,9 @@ private:
1620 Id pbk_flow_stack{}; 2501 Id pbk_flow_stack{};
1621 Id continue_label{}; 2502 Id continue_label{};
1622 std::map<u32, Id> labels; 2503 std::map<u32, Id> labels;
2504
2505 bool conditional_branch_set{};
2506 bool inside_branch{};
1623}; 2507};
1624 2508
1625class ExprDecompiler { 2509class ExprDecompiler {
@@ -1630,25 +2514,25 @@ public:
1630 const Id type_def = decomp.GetTypeDefinition(Type::Bool); 2514 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1631 const Id op1 = Visit(expr.operand1); 2515 const Id op1 = Visit(expr.operand1);
1632 const Id op2 = Visit(expr.operand2); 2516 const Id op2 = Visit(expr.operand2);
1633 return decomp.Emit(decomp.OpLogicalAnd(type_def, op1, op2)); 2517 return decomp.OpLogicalAnd(type_def, op1, op2);
1634 } 2518 }
1635 2519
1636 Id operator()(const ExprOr& expr) { 2520 Id operator()(const ExprOr& expr) {
1637 const Id type_def = decomp.GetTypeDefinition(Type::Bool); 2521 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1638 const Id op1 = Visit(expr.operand1); 2522 const Id op1 = Visit(expr.operand1);
1639 const Id op2 = Visit(expr.operand2); 2523 const Id op2 = Visit(expr.operand2);
1640 return decomp.Emit(decomp.OpLogicalOr(type_def, op1, op2)); 2524 return decomp.OpLogicalOr(type_def, op1, op2);
1641 } 2525 }
1642 2526
1643 Id operator()(const ExprNot& expr) { 2527 Id operator()(const ExprNot& expr) {
1644 const Id type_def = decomp.GetTypeDefinition(Type::Bool); 2528 const Id type_def = decomp.GetTypeDefinition(Type::Bool);
1645 const Id op1 = Visit(expr.operand1); 2529 const Id op1 = Visit(expr.operand1);
1646 return decomp.Emit(decomp.OpLogicalNot(type_def, op1)); 2530 return decomp.OpLogicalNot(type_def, op1);
1647 } 2531 }
1648 2532
1649 Id operator()(const ExprPredicate& expr) { 2533 Id operator()(const ExprPredicate& expr) {
1650 const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate); 2534 const auto pred = static_cast<Tegra::Shader::Pred>(expr.predicate);
1651 return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred))); 2535 return decomp.OpLoad(decomp.t_bool, decomp.predicates.at(pred));
1652 } 2536 }
1653 2537
1654 Id operator()(const ExprCondCode& expr) { 2538 Id operator()(const ExprCondCode& expr) {
@@ -1670,12 +2554,15 @@ public:
1670 } 2554 }
1671 } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) { 2555 } else if (const auto flag = std::get_if<InternalFlagNode>(&*cc)) {
1672 target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag())); 2556 target = decomp.internal_flags.at(static_cast<u32>(flag->GetFlag()));
2557 } else {
2558 UNREACHABLE();
1673 } 2559 }
1674 return decomp.Emit(decomp.OpLoad(decomp.t_bool, target)); 2560
2561 return decomp.OpLoad(decomp.t_bool, target);
1675 } 2562 }
1676 2563
1677 Id operator()(const ExprVar& expr) { 2564 Id operator()(const ExprVar& expr) {
1678 return decomp.Emit(decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index))); 2565 return decomp.OpLoad(decomp.t_bool, decomp.flow_variables.at(expr.var_index));
1679 } 2566 }
1680 2567
1681 Id operator()(const ExprBoolean& expr) { 2568 Id operator()(const ExprBoolean& expr) {
@@ -1684,9 +2571,9 @@ public:
1684 2571
1685 Id operator()(const ExprGprEqual& expr) { 2572 Id operator()(const ExprGprEqual& expr) {
1686 const Id target = decomp.Constant(decomp.t_uint, expr.value); 2573 const Id target = decomp.Constant(decomp.t_uint, expr.value);
1687 const Id gpr = decomp.BitcastTo<Type::Uint>( 2574 Id gpr = decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr));
1688 decomp.Emit(decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr)))); 2575 gpr = decomp.OpBitcast(decomp.t_uint, gpr);
1689 return decomp.Emit(decomp.OpLogicalEqual(decomp.t_uint, gpr, target)); 2576 return decomp.OpLogicalEqual(decomp.t_uint, gpr, target);
1690 } 2577 }
1691 2578
1692 Id Visit(const Expr& node) { 2579 Id Visit(const Expr& node) {
@@ -1714,16 +2601,16 @@ public:
1714 const Id condition = expr_parser.Visit(ast.condition); 2601 const Id condition = expr_parser.Visit(ast.condition);
1715 const Id then_label = decomp.OpLabel(); 2602 const Id then_label = decomp.OpLabel();
1716 const Id endif_label = decomp.OpLabel(); 2603 const Id endif_label = decomp.OpLabel();
1717 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); 2604 decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
1718 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); 2605 decomp.OpBranchConditional(condition, then_label, endif_label);
1719 decomp.Emit(then_label); 2606 decomp.AddLabel(then_label);
1720 ASTNode current = ast.nodes.GetFirst(); 2607 ASTNode current = ast.nodes.GetFirst();
1721 while (current) { 2608 while (current) {
1722 Visit(current); 2609 Visit(current);
1723 current = current->GetNext(); 2610 current = current->GetNext();
1724 } 2611 }
1725 decomp.Emit(decomp.OpBranch(endif_label)); 2612 decomp.OpBranch(endif_label);
1726 decomp.Emit(endif_label); 2613 decomp.AddLabel(endif_label);
1727 } 2614 }
1728 2615
1729 void operator()([[maybe_unused]] const ASTIfElse& ast) { 2616 void operator()([[maybe_unused]] const ASTIfElse& ast) {
@@ -1741,7 +2628,7 @@ public:
1741 void operator()(const ASTVarSet& ast) { 2628 void operator()(const ASTVarSet& ast) {
1742 ExprDecompiler expr_parser{decomp}; 2629 ExprDecompiler expr_parser{decomp};
1743 const Id condition = expr_parser.Visit(ast.condition); 2630 const Id condition = expr_parser.Visit(ast.condition);
1744 decomp.Emit(decomp.OpStore(decomp.flow_variables.at(ast.index), condition)); 2631 decomp.OpStore(decomp.flow_variables.at(ast.index), condition);
1745 } 2632 }
1746 2633
1747 void operator()([[maybe_unused]] const ASTLabel& ast) { 2634 void operator()([[maybe_unused]] const ASTLabel& ast) {
@@ -1758,12 +2645,11 @@ public:
1758 const Id loop_start_block = decomp.OpLabel(); 2645 const Id loop_start_block = decomp.OpLabel();
1759 const Id loop_end_block = decomp.OpLabel(); 2646 const Id loop_end_block = decomp.OpLabel();
1760 current_loop_exit = endloop_label; 2647 current_loop_exit = endloop_label;
1761 decomp.Emit(decomp.OpBranch(loop_label)); 2648 decomp.OpBranch(loop_label);
1762 decomp.Emit(loop_label); 2649 decomp.AddLabel(loop_label);
1763 decomp.Emit( 2650 decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone);
1764 decomp.OpLoopMerge(endloop_label, loop_end_block, spv::LoopControlMask::MaskNone)); 2651 decomp.OpBranch(loop_start_block);
1765 decomp.Emit(decomp.OpBranch(loop_start_block)); 2652 decomp.AddLabel(loop_start_block);
1766 decomp.Emit(loop_start_block);
1767 ASTNode current = ast.nodes.GetFirst(); 2653 ASTNode current = ast.nodes.GetFirst();
1768 while (current) { 2654 while (current) {
1769 Visit(current); 2655 Visit(current);
@@ -1771,8 +2657,8 @@ public:
1771 } 2657 }
1772 ExprDecompiler expr_parser{decomp}; 2658 ExprDecompiler expr_parser{decomp};
1773 const Id condition = expr_parser.Visit(ast.condition); 2659 const Id condition = expr_parser.Visit(ast.condition);
1774 decomp.Emit(decomp.OpBranchConditional(condition, loop_label, endloop_label)); 2660 decomp.OpBranchConditional(condition, loop_label, endloop_label);
1775 decomp.Emit(endloop_label); 2661 decomp.AddLabel(endloop_label);
1776 } 2662 }
1777 2663
1778 void operator()(const ASTReturn& ast) { 2664 void operator()(const ASTReturn& ast) {
@@ -1781,27 +2667,27 @@ public:
1781 const Id condition = expr_parser.Visit(ast.condition); 2667 const Id condition = expr_parser.Visit(ast.condition);
1782 const Id then_label = decomp.OpLabel(); 2668 const Id then_label = decomp.OpLabel();
1783 const Id endif_label = decomp.OpLabel(); 2669 const Id endif_label = decomp.OpLabel();
1784 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); 2670 decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
1785 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); 2671 decomp.OpBranchConditional(condition, then_label, endif_label);
1786 decomp.Emit(then_label); 2672 decomp.AddLabel(then_label);
1787 if (ast.kills) { 2673 if (ast.kills) {
1788 decomp.Emit(decomp.OpKill()); 2674 decomp.OpKill();
1789 } else { 2675 } else {
1790 decomp.PreExit(); 2676 decomp.PreExit();
1791 decomp.Emit(decomp.OpReturn()); 2677 decomp.OpReturn();
1792 } 2678 }
1793 decomp.Emit(endif_label); 2679 decomp.AddLabel(endif_label);
1794 } else { 2680 } else {
1795 const Id next_block = decomp.OpLabel(); 2681 const Id next_block = decomp.OpLabel();
1796 decomp.Emit(decomp.OpBranch(next_block)); 2682 decomp.OpBranch(next_block);
1797 decomp.Emit(next_block); 2683 decomp.AddLabel(next_block);
1798 if (ast.kills) { 2684 if (ast.kills) {
1799 decomp.Emit(decomp.OpKill()); 2685 decomp.OpKill();
1800 } else { 2686 } else {
1801 decomp.PreExit(); 2687 decomp.PreExit();
1802 decomp.Emit(decomp.OpReturn()); 2688 decomp.OpReturn();
1803 } 2689 }
1804 decomp.Emit(decomp.OpLabel()); 2690 decomp.AddLabel(decomp.OpLabel());
1805 } 2691 }
1806 } 2692 }
1807 2693
@@ -1811,17 +2697,17 @@ public:
1811 const Id condition = expr_parser.Visit(ast.condition); 2697 const Id condition = expr_parser.Visit(ast.condition);
1812 const Id then_label = decomp.OpLabel(); 2698 const Id then_label = decomp.OpLabel();
1813 const Id endif_label = decomp.OpLabel(); 2699 const Id endif_label = decomp.OpLabel();
1814 decomp.Emit(decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone)); 2700 decomp.OpSelectionMerge(endif_label, spv::SelectionControlMask::MaskNone);
1815 decomp.Emit(decomp.OpBranchConditional(condition, then_label, endif_label)); 2701 decomp.OpBranchConditional(condition, then_label, endif_label);
1816 decomp.Emit(then_label); 2702 decomp.AddLabel(then_label);
1817 decomp.Emit(decomp.OpBranch(current_loop_exit)); 2703 decomp.OpBranch(current_loop_exit);
1818 decomp.Emit(endif_label); 2704 decomp.AddLabel(endif_label);
1819 } else { 2705 } else {
1820 const Id next_block = decomp.OpLabel(); 2706 const Id next_block = decomp.OpLabel();
1821 decomp.Emit(decomp.OpBranch(next_block)); 2707 decomp.OpBranch(next_block);
1822 decomp.Emit(next_block); 2708 decomp.AddLabel(next_block);
1823 decomp.Emit(decomp.OpBranch(current_loop_exit)); 2709 decomp.OpBranch(current_loop_exit);
1824 decomp.Emit(decomp.OpLabel()); 2710 decomp.AddLabel(decomp.OpLabel());
1825 } 2711 }
1826 } 2712 }
1827 2713
@@ -1842,20 +2728,51 @@ void SPIRVDecompiler::DecompileAST() {
1842 flow_variables.emplace(i, AddGlobalVariable(id)); 2728 flow_variables.emplace(i, AddGlobalVariable(id));
1843 } 2729 }
1844 2730
2731 DefinePrologue();
2732
1845 const ASTNode program = ir.GetASTProgram(); 2733 const ASTNode program = ir.GetASTProgram();
1846 ASTDecompiler decompiler{*this}; 2734 ASTDecompiler decompiler{*this};
1847 decompiler.Visit(program); 2735 decompiler.Visit(program);
1848 2736
1849 const Id next_block = OpLabel(); 2737 const Id next_block = OpLabel();
1850 Emit(OpBranch(next_block)); 2738 OpBranch(next_block);
1851 Emit(next_block); 2739 AddLabel(next_block);
2740}
2741
2742} // Anonymous namespace
2743
2744ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir) {
2745 ShaderEntries entries;
2746 for (const auto& cbuf : ir.GetConstantBuffers()) {
2747 entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
2748 }
2749 for (const auto& [base, usage] : ir.GetGlobalMemory()) {
2750 entries.global_buffers.emplace_back(base.cbuf_index, base.cbuf_offset, usage.is_written);
2751 }
2752 for (const auto& sampler : ir.GetSamplers()) {
2753 if (sampler.IsBuffer()) {
2754 entries.texel_buffers.emplace_back(sampler);
2755 } else {
2756 entries.samplers.emplace_back(sampler);
2757 }
2758 }
2759 for (const auto& image : ir.GetImages()) {
2760 entries.images.emplace_back(image);
2761 }
2762 for (const auto& attribute : ir.GetInputAttributes()) {
2763 if (IsGenericAttribute(attribute)) {
2764 entries.attributes.insert(GetGenericAttributeLocation(attribute));
2765 }
2766 }
2767 entries.clip_distances = ir.GetClipDistances();
2768 entries.shader_length = ir.GetLength();
2769 entries.uses_warps = ir.UsesWarps();
2770 return entries;
1852} 2771}
1853 2772
1854DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, 2773std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
1855 ShaderType stage) { 2774 ShaderType stage, const Specialization& specialization) {
1856 auto decompiler = std::make_unique<SPIRVDecompiler>(device, ir, stage); 2775 return SPIRVDecompiler(device, ir, stage, specialization).Assemble();
1857 decompiler->Decompile();
1858 return {std::move(decompiler), decompiler->GetShaderEntries()};
1859} 2776}
1860 2777
1861} // namespace Vulkan::VKShader 2778} // namespace Vulkan
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.h b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
index 203fc00d0..2b01321b6 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.h
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
@@ -5,29 +5,28 @@
5#pragma once 5#pragma once
6 6
7#include <array> 7#include <array>
8#include <bitset>
8#include <memory> 9#include <memory>
9#include <set> 10#include <set>
11#include <type_traits>
10#include <utility> 12#include <utility>
11#include <vector> 13#include <vector>
12 14
13#include <sirit/sirit.h>
14
15#include "common/common_types.h" 15#include "common/common_types.h"
16#include "video_core/engines/maxwell_3d.h" 16#include "video_core/engines/maxwell_3d.h"
17#include "video_core/engines/shader_type.h"
17#include "video_core/shader/shader_ir.h" 18#include "video_core/shader/shader_ir.h"
18 19
19namespace VideoCommon::Shader {
20class ShaderIR;
21}
22
23namespace Vulkan { 20namespace Vulkan {
24class VKDevice; 21class VKDevice;
25} 22}
26 23
27namespace Vulkan::VKShader { 24namespace Vulkan {
28 25
29using Maxwell = Tegra::Engines::Maxwell3D::Regs; 26using Maxwell = Tegra::Engines::Maxwell3D::Regs;
27using TexelBufferEntry = VideoCommon::Shader::Sampler;
30using SamplerEntry = VideoCommon::Shader::Sampler; 28using SamplerEntry = VideoCommon::Shader::Sampler;
29using ImageEntry = VideoCommon::Shader::Image;
31 30
32constexpr u32 DESCRIPTOR_SET = 0; 31constexpr u32 DESCRIPTOR_SET = 0;
33 32
@@ -46,39 +45,74 @@ private:
46 45
47class GlobalBufferEntry { 46class GlobalBufferEntry {
48public: 47public:
49 explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset) 48 constexpr explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset, bool is_written)
50 : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset} {} 49 : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset}, is_written{is_written} {}
51 50
52 u32 GetCbufIndex() const { 51 constexpr u32 GetCbufIndex() const {
53 return cbuf_index; 52 return cbuf_index;
54 } 53 }
55 54
56 u32 GetCbufOffset() const { 55 constexpr u32 GetCbufOffset() const {
57 return cbuf_offset; 56 return cbuf_offset;
58 } 57 }
59 58
59 constexpr bool IsWritten() const {
60 return is_written;
61 }
62
60private: 63private:
61 u32 cbuf_index{}; 64 u32 cbuf_index{};
62 u32 cbuf_offset{}; 65 u32 cbuf_offset{};
66 bool is_written{};
63}; 67};
64 68
65struct ShaderEntries { 69struct ShaderEntries {
66 u32 const_buffers_base_binding{}; 70 u32 NumBindings() const {
67 u32 global_buffers_base_binding{}; 71 return static_cast<u32>(const_buffers.size() + global_buffers.size() +
68 u32 samplers_base_binding{}; 72 texel_buffers.size() + samplers.size() + images.size());
73 }
74
69 std::vector<ConstBufferEntry> const_buffers; 75 std::vector<ConstBufferEntry> const_buffers;
70 std::vector<GlobalBufferEntry> global_buffers; 76 std::vector<GlobalBufferEntry> global_buffers;
77 std::vector<TexelBufferEntry> texel_buffers;
71 std::vector<SamplerEntry> samplers; 78 std::vector<SamplerEntry> samplers;
79 std::vector<ImageEntry> images;
72 std::set<u32> attributes; 80 std::set<u32> attributes;
73 std::array<bool, Maxwell::NumClipDistances> clip_distances{}; 81 std::array<bool, Maxwell::NumClipDistances> clip_distances{};
74 std::size_t shader_length{}; 82 std::size_t shader_length{};
75 Sirit::Id entry_function{}; 83 bool uses_warps{};
76 std::vector<Sirit::Id> interfaces; 84};
85
86struct Specialization final {
87 u32 base_binding{};
88
89 // Compute specific
90 std::array<u32, 3> workgroup_size{};
91 u32 shared_memory_size{};
92
93 // Graphics specific
94 Maxwell::PrimitiveTopology primitive_topology{};
95 std::optional<float> point_size{};
96 std::array<Maxwell::VertexAttribute::Type, Maxwell::NumVertexAttributes> attribute_types{};
97
98 // Tessellation specific
99 struct {
100 Maxwell::TessellationPrimitive primitive{};
101 Maxwell::TessellationSpacing spacing{};
102 bool clockwise{};
103 } tessellation;
104};
105// Old gcc versions don't consider this trivially copyable.
106// static_assert(std::is_trivially_copyable_v<Specialization>);
107
108struct SPIRVShader {
109 std::vector<u32> code;
110 ShaderEntries entries;
77}; 111};
78 112
79using DecompilerResult = std::pair<std::unique_ptr<Sirit::Module>, ShaderEntries>; 113ShaderEntries GenerateShaderEntries(const VideoCommon::Shader::ShaderIR& ir);
80 114
81DecompilerResult Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir, 115std::vector<u32> Decompile(const VKDevice& device, const VideoCommon::Shader::ShaderIR& ir,
82 Tegra::Engines::ShaderType stage); 116 Tegra::Engines::ShaderType stage, const Specialization& specialization);
83 117
84} // namespace Vulkan::VKShader 118} // namespace Vulkan