summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/video_core/CMakeLists.txt7
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp1379
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.h80
3 files changed, 1465 insertions, 1 deletions
diff --git a/src/video_core/CMakeLists.txt b/src/video_core/CMakeLists.txt
index 5c8ca429e..114bed20d 100644
--- a/src/video_core/CMakeLists.txt
+++ b/src/video_core/CMakeLists.txt
@@ -129,12 +129,14 @@ if (ENABLE_VULKAN)
129 renderer_vulkan/vk_sampler_cache.h 129 renderer_vulkan/vk_sampler_cache.h
130 renderer_vulkan/vk_scheduler.cpp 130 renderer_vulkan/vk_scheduler.cpp
131 renderer_vulkan/vk_scheduler.h 131 renderer_vulkan/vk_scheduler.h
132 renderer_vulkan/vk_shader_decompiler.cpp
133 renderer_vulkan/vk_shader_decompiler.h
132 renderer_vulkan/vk_stream_buffer.cpp 134 renderer_vulkan/vk_stream_buffer.cpp
133 renderer_vulkan/vk_stream_buffer.h 135 renderer_vulkan/vk_stream_buffer.h
134 renderer_vulkan/vk_swapchain.cpp 136 renderer_vulkan/vk_swapchain.cpp
135 renderer_vulkan/vk_swapchain.h) 137 renderer_vulkan/vk_swapchain.h)
136 138
137 target_include_directories(video_core PRIVATE ../../externals/Vulkan-Headers/include) 139 target_include_directories(video_core PRIVATE sirit ../../externals/Vulkan-Headers/include)
138 target_compile_definitions(video_core PRIVATE HAS_VULKAN) 140 target_compile_definitions(video_core PRIVATE HAS_VULKAN)
139endif() 141endif()
140 142
@@ -142,3 +144,6 @@ create_target_directory_groups(video_core)
142 144
143target_link_libraries(video_core PUBLIC common core) 145target_link_libraries(video_core PUBLIC common core)
144target_link_libraries(video_core PRIVATE glad) 146target_link_libraries(video_core PRIVATE glad)
147if (ENABLE_VULKAN)
148 target_link_libraries(video_core PRIVATE sirit)
149endif()
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
new file mode 100644
index 000000000..e0a6f5e87
--- /dev/null
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -0,0 +1,1379 @@
1// Copyright 2019 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#include <functional>
6#include <map>
7#include <set>
8
9#include <fmt/format.h>
10
11#include <sirit/sirit.h>
12
13#include "common/alignment.h"
14#include "common/assert.h"
15#include "common/common_types.h"
16#include "common/logging/log.h"
17#include "video_core/engines/maxwell_3d.h"
18#include "video_core/engines/shader_bytecode.h"
19#include "video_core/engines/shader_header.h"
20#include "video_core/renderer_vulkan/vk_shader_decompiler.h"
21#include "video_core/shader/shader_ir.h"
22
23namespace Vulkan::VKShader {
24
25using Sirit::Id;
26using Tegra::Shader::Attribute;
27using Tegra::Shader::AttributeUse;
28using Tegra::Shader::Register;
29using namespace VideoCommon::Shader;
30
31using Maxwell = Tegra::Engines::Maxwell3D::Regs;
32using ShaderStage = Tegra::Engines::Maxwell3D::Regs::ShaderStage;
33using Operation = const OperationNode&;
34
// TODO(Rodrigo): Use rasterizer's value
// Maximum number of vec4 elements addressable in a constant buffer; indirect cbuf
// accesses are wrapped to this range (see the CbufNode handling in Visit).
constexpr u32 MAX_CONSTBUFFER_ELEMENTS = 0x1000;
// Each shader stage gets a disjoint range of descriptor bindings of this size
// (see AllocateBindings).
constexpr u32 STAGE_BINDING_STRIDE = 0x100;

// Value categories used to select SPIR-V type definitions and bitcasts.
enum class Type { Bool, Bool2, Float, Int, Uint, HalfFloat };

// SPIR-V ids describing one declared sampler: the image type, the sampled-image
// type wrapping it, and the variable holding the combined image sampler.
struct SamplerImage {
    Id image_type;
    Id sampled_image_type;
    Id sampler;
};
46
47namespace {
48
49spv::Dim GetSamplerDim(const Sampler& sampler) {
50 switch (sampler.GetType()) {
51 case Tegra::Shader::TextureType::Texture1D:
52 return spv::Dim::Dim1D;
53 case Tegra::Shader::TextureType::Texture2D:
54 return spv::Dim::Dim2D;
55 case Tegra::Shader::TextureType::Texture3D:
56 return spv::Dim::Dim3D;
57 case Tegra::Shader::TextureType::TextureCube:
58 return spv::Dim::Cube;
59 default:
60 UNIMPLEMENTED_MSG("Unimplemented sampler type={}", static_cast<u32>(sampler.GetType()));
61 return spv::Dim::Dim2D;
62 }
63}
64
65/// Returns true if an attribute index is one of the 32 generic attributes
66constexpr bool IsGenericAttribute(Attribute::Index attribute) {
67 return attribute >= Attribute::Index::Attribute_0 &&
68 attribute <= Attribute::Index::Attribute_31;
69}
70
71/// Returns the location of a generic attribute
72constexpr u32 GetGenericAttributeLocation(Attribute::Index attribute) {
73 ASSERT(IsGenericAttribute(attribute));
74 return static_cast<u32>(attribute) - static_cast<u32>(Attribute::Index::Attribute_0);
75}
76
77/// Returns true if an object has to be treated as precise
78bool IsPrecise(Operation operand) {
79 const auto& meta = operand.GetMeta();
80
81 if (std::holds_alternative<MetaArithmetic>(meta)) {
82 return std::get<MetaArithmetic>(meta).precise;
83 }
84 if (std::holds_alternative<MetaHalfArithmetic>(meta)) {
85 return std::get<MetaHalfArithmetic>(meta).precise;
86 }
87 return false;
88}
89
90} // namespace
91
92class SPIRVDecompiler : public Sirit::Module {
93public:
    /// Builds a decompiler for the given shader IR and pipeline stage.
    /// 0x00010300 is the SPIR-V binary version number (1.3).
    explicit SPIRVDecompiler(const ShaderIR& ir, ShaderStage stage)
        : Module(0x00010300), ir{ir}, stage{stage}, header{ir.GetHeader()} {
        AddCapability(spv::Capability::Shader);
        // Needed for the StorageBuffer storage class and variable pointer usage below.
        AddExtension("SPV_KHR_storage_buffer_storage_class");
        AddExtension("SPV_KHR_variable_pointers");
    }
100
    /// Emits the whole SPIR-V module: resource declarations plus a single "execute"
    /// function that dispatches the shader's basic blocks with a switch-in-a-loop,
    /// using jmp_to as the virtual program counter and flow_stack for SSY/PBK.
    void Decompile() {
        AllocateBindings();
        AllocateLabels();

        DeclareVertex();
        DeclareGeometry();
        DeclareFragment();
        DeclareRegisters();
        DeclarePredicates();
        DeclareLocalMemory();
        DeclareInternalFlags();
        DeclareInputAttributes();
        DeclareOutputAttributes();
        DeclareConstantBuffers();
        DeclareGlobalBuffers();
        DeclareSamplers();

        execute_function =
            Emit(OpFunction(t_void, spv::FunctionControlMask::Inline, TypeFunction(t_void)));
        Emit(OpLabel());

        // Execution starts at the first basic block's address.
        const u32 first_address = ir.GetBasicBlocks().begin()->first;
        const Id loop_label = OpLabel("loop");
        const Id merge_label = OpLabel("merge");
        const Id dummy_label = OpLabel();
        const Id jump_label = OpLabel();
        continue_label = OpLabel("continue");

        // Collect (address, label) pairs for the OpSwitch dispatcher below.
        std::vector<Sirit::Literal> literals;
        std::vector<Id> branch_labels;
        for (const auto& pair : labels) {
            const auto [literal, label] = pair;
            literals.push_back(literal);
            branch_labels.push_back(label);
        }

        // TODO(Rodrigo): Figure out the actual depth of the flow stack, for now it seems unlikely
        // that shaders will use 20 nested SSYs and PBKs.
        constexpr u32 FLOW_STACK_SIZE = 20;
        const Id flow_stack_type = TypeArray(t_uint, Constant(t_uint, FLOW_STACK_SIZE));
        // jmp_to acts as the program counter, initialized to the entry address.
        jmp_to = Emit(OpVariable(TypePointer(spv::StorageClass::Function, t_uint),
                                 spv::StorageClass::Function, Constant(t_uint, first_address)));
        flow_stack = Emit(OpVariable(TypePointer(spv::StorageClass::Function, flow_stack_type),
                                     spv::StorageClass::Function, ConstantNull(flow_stack_type)));
        flow_stack_top =
            Emit(OpVariable(t_func_uint, spv::StorageClass::Function, Constant(t_uint, 0)));

        Name(jmp_to, "jmp_to");
        Name(flow_stack, "flow_stack");
        Name(flow_stack_top, "flow_stack_top");

        // Structured infinite loop; each iteration dispatches one basic block.
        Emit(OpBranch(loop_label));
        Emit(loop_label);
        Emit(OpLoopMerge(merge_label, continue_label, spv::LoopControlMask::Unroll));
        Emit(OpBranch(dummy_label));

        Emit(dummy_label);
        const Id default_branch = OpLabel();
        const Id jmp_to_load = Emit(OpLoad(t_uint, jmp_to));
        Emit(OpSelectionMerge(jump_label, spv::SelectionControlMask::MaskNone));
        Emit(OpSwitch(jmp_to_load, default_branch, literals, branch_labels));

        // Unknown addresses end the shader.
        Emit(default_branch);
        Emit(OpReturn());

        for (const auto& pair : ir.GetBasicBlocks()) {
            const auto& [address, bb] = pair;
            Emit(labels.at(address));

            VisitBasicBlock(bb);

            // Fall through to the next block in address order; past the last block,
            // fall into the returning default branch.
            const auto next_it = labels.lower_bound(address + 1);
            const Id next_label = next_it != labels.end() ? next_it->second : default_branch;
            Emit(OpBranch(next_label));
        }

        Emit(jump_label);
        Emit(OpBranch(continue_label));
        Emit(continue_label);
        Emit(OpBranch(loop_label));
        Emit(merge_label);
        Emit(OpReturn());
        Emit(OpFunctionEnd());
    }
185
186 ShaderEntries GetShaderEntries() const {
187 ShaderEntries entries;
188 entries.const_buffers_base_binding = const_buffers_base_binding;
189 entries.global_buffers_base_binding = global_buffers_base_binding;
190 entries.samplers_base_binding = samplers_base_binding;
191 for (const auto& cbuf : ir.GetConstantBuffers()) {
192 entries.const_buffers.emplace_back(cbuf.second, cbuf.first);
193 }
194 for (const auto& gmem : ir.GetGlobalMemoryBases()) {
195 entries.global_buffers.emplace_back(gmem.cbuf_index, gmem.cbuf_offset);
196 }
197 for (const auto& sampler : ir.GetSamplers()) {
198 entries.samplers.emplace_back(sampler);
199 }
200 for (const auto& attr : ir.GetInputAttributes()) {
201 entries.attributes.insert(GetGenericAttributeLocation(attr.first));
202 }
203 entries.clip_distances = ir.GetClipDistances();
204 entries.shader_length = ir.GetLength();
205 entries.entry_function = execute_function;
206 entries.interfaces = interfaces;
207 return entries;
208 }
209
private:
    /// Pointer-to-member type of a per-operation decompiler method.
    using OperationDecompilerFn = Id (SPIRVDecompiler::*)(Operation);
    /// Dispatch table with one entry per OperationCode.
    using OperationDecompilersArray =
        std::array<OperationDecompilerFn, static_cast<std::size_t>(OperationCode::Amount)>;

    static constexpr auto INTERNAL_FLAGS_COUNT = static_cast<std::size_t>(InternalFlag::Amount);
    // Constant buffers are exposed as arrays of vec4, hence a 16-byte stride.
    static constexpr u32 CBUF_STRIDE = 16;
217
218 void AllocateBindings() {
219 const u32 binding_base = static_cast<u32>(stage) * STAGE_BINDING_STRIDE;
220 u32 binding_iterator = binding_base;
221
222 const auto Allocate = [&binding_iterator](std::size_t count) {
223 const u32 current_binding = binding_iterator;
224 binding_iterator += static_cast<u32>(count);
225 return current_binding;
226 };
227 const_buffers_base_binding = Allocate(ir.GetConstantBuffers().size());
228 global_buffers_base_binding = Allocate(ir.GetGlobalMemoryBases().size());
229 samplers_base_binding = Allocate(ir.GetSamplers().size());
230
231 ASSERT_MSG(binding_iterator - binding_base < STAGE_BINDING_STRIDE,
232 "Stage binding stride is too small");
233 }
234
235 void AllocateLabels() {
236 for (const auto& pair : ir.GetBasicBlocks()) {
237 const u32 address = pair.first;
238 labels.emplace(address, OpLabel(fmt::format("label_0x{:x}", address)));
239 }
240 }
241
242 void DeclareVertex() {
243 if (stage != ShaderStage::Vertex)
244 return;
245
246 DeclareVertexRedeclarations();
247 }
248
249 void DeclareGeometry() {
250 if (stage != ShaderStage::Geometry)
251 return;
252
253 UNIMPLEMENTED();
254 }
255
    /// Declares fragment-stage outputs (color targets, optional depth) and the
    /// FragCoord/FrontFacing builtin inputs.
    void DeclareFragment() {
        if (stage != ShaderStage::Fragment)
            return;

        // One vec4 output per render target the shader actually writes.
        for (u32 rt = 0; rt < static_cast<u32>(frag_colors.size()); ++rt) {
            if (!IsRenderTargetUsed(rt)) {
                continue;
            }

            const Id id = AddGlobalVariable(OpVariable(t_out_float4, spv::StorageClass::Output));
            Name(id, fmt::format("frag_color{}", rt));
            Decorate(id, spv::Decoration::Location, rt);

            frag_colors[rt] = id;
            interfaces.push_back(id);
        }

        // Depth output only when the shader header says depth is written.
        if (header.ps.omap.depth) {
            frag_depth = AddGlobalVariable(OpVariable(t_out_float, spv::StorageClass::Output));
            Name(frag_depth, "frag_depth");
            Decorate(frag_depth, spv::Decoration::BuiltIn,
                     static_cast<u32>(spv::BuiltIn::FragDepth));

            interfaces.push_back(frag_depth);
        }

        frag_coord = DeclareBuiltIn(spv::BuiltIn::FragCoord, spv::StorageClass::Input, t_in_float4,
                                    "frag_coord");
        front_facing = DeclareBuiltIn(spv::BuiltIn::FrontFacing, spv::StorageClass::Input,
                                      t_in_bool, "front_facing");
    }
287
288 void DeclareRegisters() {
289 for (const u32 gpr : ir.GetRegisters()) {
290 const Id id = OpVariable(t_prv_float, spv::StorageClass::Private, v_float_zero);
291 Name(id, fmt::format("gpr_{}", gpr));
292 registers.emplace(gpr, AddGlobalVariable(id));
293 }
294 }
295
296 void DeclarePredicates() {
297 for (const auto pred : ir.GetPredicates()) {
298 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
299 Name(id, fmt::format("pred_{}", static_cast<u32>(pred)));
300 predicates.emplace(pred, AddGlobalVariable(id));
301 }
302 }
303
    /// Declares the private float array backing local memory, when the shader header
    /// requests a non-zero local memory size.
    void DeclareLocalMemory() {
        if (const u64 local_memory_size = header.GetLocalMemorySize(); local_memory_size > 0) {
            // Local memory is accessed in 32-bit words; round the byte size up to words.
            const auto element_count = static_cast<u32>(Common::AlignUp(local_memory_size, 4) / 4);
            const Id type_array = TypeArray(t_float, Constant(t_uint, element_count));
            const Id type_pointer = TypePointer(spv::StorageClass::Private, type_array);
            Name(type_pointer, "LocalMemory");

            local_memory =
                OpVariable(type_pointer, spv::StorageClass::Private, ConstantNull(type_array));
            AddGlobalVariable(Name(local_memory, "local_memory"));
        }
    }
316
317 void DeclareInternalFlags() {
318 constexpr std::array<const char*, INTERNAL_FLAGS_COUNT> names = {"zero", "sign", "carry",
319 "overflow"};
320 for (std::size_t flag = 0; flag < INTERNAL_FLAGS_COUNT; ++flag) {
321 const auto flag_code = static_cast<InternalFlag>(flag);
322 const Id id = OpVariable(t_prv_bool, spv::StorageClass::Private, v_false);
323 internal_flags[flag] = AddGlobalVariable(Name(id, names[flag]));
324 }
325 }
326
    /// Declares input variables for the generic attributes read by the shader, applying
    /// the interpolation decoration the fragment header requests.
    void DeclareInputAttributes() {
        for (const auto element : ir.GetInputAttributes()) {
            const Attribute::Index index = element.first;
            // Builtins (Position, etc.) are handled elsewhere; only generic
            // attributes become Location-decorated input variables.
            if (!IsGenericAttribute(index)) {
                continue;
            }

            UNIMPLEMENTED_IF(stage == ShaderStage::Geometry);

            const u32 location = GetGenericAttributeLocation(index);
            const Id id = OpVariable(t_in_float4, spv::StorageClass::Input);
            Name(AddGlobalVariable(id), fmt::format("in_attr{}", location));
            input_attributes.emplace(index, id);
            interfaces.push_back(id);

            Decorate(id, spv::Decoration::Location, location);

            // Interpolation qualifiers only apply to fragment inputs.
            if (stage != ShaderStage::Fragment) {
                continue;
            }
            switch (header.ps.GetAttributeUse(location)) {
            case AttributeUse::Constant:
                Decorate(id, spv::Decoration::Flat);
                break;
            case AttributeUse::ScreenLinear:
                Decorate(id, spv::Decoration::NoPerspective);
                break;
            case AttributeUse::Perspective:
                // Default
                break;
            default:
                UNREACHABLE_MSG("Unused attribute being fetched");
            }
        }
    }
362
363 void DeclareOutputAttributes() {
364 for (const auto index : ir.GetOutputAttributes()) {
365 if (!IsGenericAttribute(index)) {
366 continue;
367 }
368 const auto location = GetGenericAttributeLocation(index);
369 const Id id = OpVariable(t_out_float4, spv::StorageClass::Output);
370 Name(AddGlobalVariable(id), fmt::format("out_attr{}", location));
371 output_attributes.emplace(index, id);
372 interfaces.push_back(id);
373
374 Decorate(id, spv::Decoration::Location, location);
375 }
376 }
377
378 void DeclareConstantBuffers() {
379 u32 binding = const_buffers_base_binding;
380 for (const auto& entry : ir.GetConstantBuffers()) {
381 const auto [index, size] = entry;
382 const Id id = OpVariable(t_cbuf_ubo, spv::StorageClass::Uniform);
383 AddGlobalVariable(Name(id, fmt::format("cbuf_{}", index)));
384
385 Decorate(id, spv::Decoration::Binding, binding++);
386 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
387 constant_buffers.emplace(index, id);
388 }
389 }
390
391 void DeclareGlobalBuffers() {
392 u32 binding = global_buffers_base_binding;
393 for (const auto& entry : ir.GetGlobalMemoryBases()) {
394 const Id id = OpVariable(t_gmem_ssbo, spv::StorageClass::StorageBuffer);
395 AddGlobalVariable(
396 Name(id, fmt::format("gmem_{}_{}", entry.cbuf_index, entry.cbuf_offset)));
397
398 Decorate(id, spv::Decoration::Binding, binding++);
399 Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
400 global_buffers.emplace(entry, id);
401 }
402 }
403
    /// Declares one combined image sampler per texture used by the shader, recording
    /// its type ids in sampler_images for the texture operations.
    void DeclareSamplers() {
        u32 binding = samplers_base_binding;
        for (const auto& sampler : ir.GetSamplers()) {
            const auto dim = GetSamplerDim(sampler);
            const int depth = sampler.IsShadow() ? 1 : 0;
            const int arrayed = sampler.IsArray() ? 1 : 0;
            // TODO(Rodrigo): Sampled 1 indicates that the image will be used with a sampler. When
            // SULD and SUST instructions are implemented, replace this value.
            const int sampled = 1;
            const Id image_type =
                TypeImage(t_float, dim, depth, arrayed, false, sampled, spv::ImageFormat::Unknown);
            const Id sampled_image_type = TypeSampledImage(image_type);
            const Id pointer_type =
                TypePointer(spv::StorageClass::UniformConstant, sampled_image_type);
            const Id id = OpVariable(pointer_type, spv::StorageClass::UniformConstant);
            AddGlobalVariable(Name(id, fmt::format("sampler_{}", sampler.GetIndex())));

            sampler_images.insert(
                {static_cast<u32>(sampler.GetIndex()), {image_type, sampled_image_type, id}});

            Decorate(id, spv::Decoration::Binding, binding++);
            Decorate(id, spv::Decoration::DescriptorSet, DESCRIPTOR_SET);
        }
    }
428
    /// Declares the vertex builtins: VertexIndex/InstanceIndex inputs and a gl_PerVertex-style
    /// output block containing Position and, when used, PointSize and ClipDistance.
    void DeclareVertexRedeclarations() {
        vertex_index = DeclareBuiltIn(spv::BuiltIn::VertexIndex, spv::StorageClass::Input,
                                      t_in_uint, "vertex_index");
        instance_index = DeclareBuiltIn(spv::BuiltIn::InstanceIndex, spv::StorageClass::Input,
                                        t_in_uint, "instance_index");

        // Only add optional members the shader actually writes.
        bool is_point_size_declared = false;
        bool is_clip_distances_declared = false;
        for (const auto index : ir.GetOutputAttributes()) {
            if (index == Attribute::Index::PointSize) {
                is_point_size_declared = true;
            } else if (index == Attribute::Index::ClipDistances0123 ||
                       index == Attribute::Index::ClipDistances4567) {
                is_clip_distances_declared = true;
            }
        }

        // Member order here must match the *_index values assigned below.
        std::vector<Id> members;
        members.push_back(t_float4);
        if (is_point_size_declared) {
            members.push_back(t_float);
        }
        if (is_clip_distances_declared) {
            members.push_back(TypeArray(t_float, Constant(t_uint, 8)));
        }

        const Id gl_per_vertex_struct = Name(TypeStruct(members), "PerVertex");
        Decorate(gl_per_vertex_struct, spv::Decoration::Block);

        u32 declaration_index = 0;
        // Decorates the next struct member as the given builtin and returns its index.
        // NOTE(review): returns u32{} (0) when the member is absent, aliasing the
        // position member's index — callers must not use the index in that case.
        const auto MemberDecorateBuiltIn = [&](spv::BuiltIn builtin, std::string name,
                                               bool condition) {
            if (!condition)
                return u32{};
            MemberName(gl_per_vertex_struct, declaration_index, name);
            MemberDecorate(gl_per_vertex_struct, declaration_index, spv::Decoration::BuiltIn,
                           static_cast<u32>(builtin));
            return declaration_index++;
        };

        position_index = MemberDecorateBuiltIn(spv::BuiltIn::Position, "position", true);
        point_size_index =
            MemberDecorateBuiltIn(spv::BuiltIn::PointSize, "point_size", is_point_size_declared);
        clip_distances_index = MemberDecorateBuiltIn(spv::BuiltIn::ClipDistance, "clip_distances",
                                                     is_clip_distances_declared);

        const Id type_pointer = TypePointer(spv::StorageClass::Output, gl_per_vertex_struct);
        per_vertex = OpVariable(type_pointer, spv::StorageClass::Output);
        AddGlobalVariable(Name(per_vertex, "per_vertex"));
        interfaces.push_back(per_vertex);
    }
480
481 void VisitBasicBlock(const NodeBlock& bb) {
482 for (const Node node : bb) {
483 static_cast<void>(Visit(node));
484 }
485 }
486
    /// Lowers one IR node to SPIR-V and returns the resulting value id (a null id for
    /// nodes that only have side effects, e.g. conditionals and comments).
    Id Visit(Node node) {
        if (const auto operation = std::get_if<OperationNode>(node)) {
            // Dispatch through the per-operation decompiler table.
            const auto operation_index = static_cast<std::size_t>(operation->GetCode());
            const auto decompiler = operation_decompilers[operation_index];
            if (decompiler == nullptr) {
                UNREACHABLE_MSG("Operation decompiler {} not defined", operation_index);
            }
            return (this->*decompiler)(*operation);

        } else if (const auto gpr = std::get_if<GprNode>(node)) {
            // GPR read; register 255 (RZ) always reads as zero.
            const u32 index = gpr->GetIndex();
            if (index == Register::ZeroIndex) {
                return Constant(t_float, 0.0f);
            }
            return Emit(OpLoad(t_float, registers.at(index)));

        } else if (const auto immediate = std::get_if<ImmediateNode>(node)) {
            // Immediates carry raw bits; reinterpret them as float.
            return BitcastTo<Type::Float>(Constant(t_uint, immediate->GetValue()));

        } else if (const auto predicate = std::get_if<PredicateNode>(node)) {
            // Predicate read, with the two special always-true/never-true indices.
            const auto value = [&]() -> Id {
                switch (const auto index = predicate->GetIndex(); index) {
                case Tegra::Shader::Pred::UnusedIndex:
                    return v_true;
                case Tegra::Shader::Pred::NeverExecute:
                    return v_false;
                default:
                    return Emit(OpLoad(t_bool, predicates.at(index)));
                }
            }();
            if (predicate->IsNegated()) {
                return Emit(OpLogicalNot(t_bool, value));
            }
            return value;

        } else if (const auto abuf = std::get_if<AbufNode>(node)) {
            // Input attribute read: builtins first, then generic attributes.
            const auto attribute = abuf->GetIndex();
            const auto element = abuf->GetElement();

            switch (attribute) {
            case Attribute::Index::Position:
                if (stage != ShaderStage::Fragment) {
                    UNIMPLEMENTED();
                    break;
                } else {
                    // In fragment shaders Position maps to FragCoord; w reads as 1.0.
                    if (element == 3) {
                        return Constant(t_float, 1.0f);
                    }
                    return Emit(OpLoad(t_float, AccessElement(t_in_float, frag_coord, element)));
                }
            case Attribute::Index::TessCoordInstanceIDVertexID:
                // TODO(Subv): Find out what the values are for the first two elements when inside a
                // vertex shader, and what's the value of the fourth element when inside a Tess Eval
                // shader.
                ASSERT(stage == ShaderStage::Vertex);
                switch (element) {
                case 2:
                    return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, instance_index)));
                case 3:
                    return BitcastFrom<Type::Uint>(Emit(OpLoad(t_uint, vertex_index)));
                }
                UNIMPLEMENTED_MSG("Unmanaged TessCoordInstanceIDVertexID element={}", element);
                return Constant(t_float, 0);
            case Attribute::Index::FrontFacing:
                // TODO(Subv): Find out what the values are for the other elements.
                ASSERT(stage == ShaderStage::Fragment);
                if (element == 3) {
                    // Front facing reads as -1 (all bits set) when true, 0 when false.
                    const Id is_front_facing = Emit(OpLoad(t_bool, front_facing));
                    const Id true_value =
                        BitcastTo<Type::Float>(Constant(t_int, static_cast<s32>(-1)));
                    const Id false_value = BitcastTo<Type::Float>(Constant(t_int, 0));
                    return Emit(OpSelect(t_float, is_front_facing, true_value, false_value));
                }
                UNIMPLEMENTED_MSG("Unmanaged FrontFacing element={}", element);
                return Constant(t_float, 0.0f);
            default:
                if (IsGenericAttribute(attribute)) {
                    const Id pointer =
                        AccessElement(t_in_float, input_attributes.at(attribute), element);
                    return Emit(OpLoad(t_float, pointer));
                }
                break;
            }
            // Falls through to the trailing UNREACHABLE below.
            UNIMPLEMENTED_MSG("Unhandled input attribute: {}", static_cast<u32>(attribute));

        } else if (const auto cbuf = std::get_if<CbufNode>(node)) {
            const Node offset = cbuf->GetOffset();
            const Id buffer_id = constant_buffers.at(cbuf->GetIndex());

            // The UBO is an array of vec4; split the byte offset into vec4 index and
            // component index.
            Id buffer_index{};
            Id buffer_element{};

            if (const auto immediate = std::get_if<ImmediateNode>(offset)) {
                // Direct access
                const u32 offset_imm = immediate->GetValue();
                ASSERT(offset_imm % 4 == 0);
                buffer_index = Constant(t_uint, offset_imm / 16);
                buffer_element = Constant(t_uint, (offset_imm / 4) % 4);

            } else if (std::holds_alternative<OperationNode>(*offset)) {
                // Indirect access
                // TODO(Rodrigo): Use a uniform buffer stride of 4 and drop this slow math (which
                // emits sub-optimal code on GLSL from my testing).
                const Id offset_id = BitcastTo<Type::Uint>(Visit(offset));
                const Id unsafe_offset = Emit(OpUDiv(t_uint, offset_id, Constant(t_uint, 4)));
                const Id final_offset = Emit(
                    OpUMod(t_uint, unsafe_offset, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS - 1)));
                buffer_index = Emit(OpUDiv(t_uint, final_offset, Constant(t_uint, 4)));
                buffer_element = Emit(OpUMod(t_uint, final_offset, Constant(t_uint, 4)));

            } else {
                UNREACHABLE_MSG("Unmanaged offset node type");
            }

            const Id pointer = Emit(OpAccessChain(t_cbuf_float, buffer_id, Constant(t_uint, 0),
                                                  buffer_index, buffer_element));
            return Emit(OpLoad(t_float, pointer));

        } else if (const auto gmem = std::get_if<GmemNode>(node)) {
            // Global memory read: word index is (real - base) / 4 into the SSBO.
            const Id gmem_buffer = global_buffers.at(gmem->GetDescriptor());
            const Id real = BitcastTo<Type::Uint>(Visit(gmem->GetRealAddress()));
            const Id base = BitcastTo<Type::Uint>(Visit(gmem->GetBaseAddress()));

            Id offset = Emit(OpISub(t_uint, real, base));
            offset = Emit(OpUDiv(t_uint, offset, Constant(t_uint, 4u)));
            return Emit(OpLoad(t_float, Emit(OpAccessChain(t_gmem_float, gmem_buffer,
                                                           Constant(t_uint, 0u), offset))));

        } else if (const auto conditional = std::get_if<ConditionalNode>(node)) {
            // It's invalid to call conditional on nested nodes, use an operation instead
            const Id true_label = OpLabel();
            const Id skip_label = OpLabel();
            Emit(OpBranchConditional(Visit(conditional->GetCondition()), true_label, skip_label));
            Emit(true_label);

            VisitBasicBlock(conditional->GetCode());

            Emit(OpBranch(skip_label));
            Emit(skip_label);
            return {};

        } else if (const auto comment = std::get_if<CommentNode>(node)) {
            // Attach the comment text as a name on a dummy undef value.
            Name(Emit(OpUndef(t_void)), comment->GetText());
            return {};
        }

        UNREACHABLE();
        return {};
    }
636
637 template <Id (Module::*func)(Id, Id), Type result_type, Type type_a = result_type>
638 Id Unary(Operation operation) {
639 const Id type_def = GetTypeDefinition(result_type);
640 const Id op_a = VisitOperand<type_a>(operation, 0);
641
642 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a)));
643 if (IsPrecise(operation)) {
644 Decorate(value, spv::Decoration::NoContraction);
645 }
646 return value;
647 }
648
649 template <Id (Module::*func)(Id, Id, Id), Type result_type, Type type_a = result_type,
650 Type type_b = type_a>
651 Id Binary(Operation operation) {
652 const Id type_def = GetTypeDefinition(result_type);
653 const Id op_a = VisitOperand<type_a>(operation, 0);
654 const Id op_b = VisitOperand<type_b>(operation, 1);
655
656 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b)));
657 if (IsPrecise(operation)) {
658 Decorate(value, spv::Decoration::NoContraction);
659 }
660 return value;
661 }
662
663 template <Id (Module::*func)(Id, Id, Id, Id), Type result_type, Type type_a = result_type,
664 Type type_b = type_a, Type type_c = type_b>
665 Id Ternary(Operation operation) {
666 const Id type_def = GetTypeDefinition(result_type);
667 const Id op_a = VisitOperand<type_a>(operation, 0);
668 const Id op_b = VisitOperand<type_b>(operation, 1);
669 const Id op_c = VisitOperand<type_c>(operation, 2);
670
671 const Id value = BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c)));
672 if (IsPrecise(operation)) {
673 Decorate(value, spv::Decoration::NoContraction);
674 }
675 return value;
676 }
677
678 template <Id (Module::*func)(Id, Id, Id, Id, Id), Type result_type, Type type_a = result_type,
679 Type type_b = type_a, Type type_c = type_b, Type type_d = type_c>
680 Id Quaternary(Operation operation) {
681 const Id type_def = GetTypeDefinition(result_type);
682 const Id op_a = VisitOperand<type_a>(operation, 0);
683 const Id op_b = VisitOperand<type_b>(operation, 1);
684 const Id op_c = VisitOperand<type_c>(operation, 2);
685 const Id op_d = VisitOperand<type_d>(operation, 3);
686
687 const Id value =
688 BitcastFrom<result_type>(Emit((this->*func)(type_def, op_a, op_b, op_c, op_d)));
689 if (IsPrecise(operation)) {
690 Decorate(value, spv::Decoration::NoContraction);
691 }
692 return value;
693 }
694
    /// Stores operand 1 into the destination described by operand 0 (GPR, output
    /// attribute, or local memory). Always returns a null id.
    Id Assign(Operation operation) {
        const Node dest = operation[0];
        const Node src = operation[1];

        Id target{};
        if (const auto gpr = std::get_if<GprNode>(dest)) {
            if (gpr->GetIndex() == Register::ZeroIndex) {
                // Writing to Register::ZeroIndex is a no op
                return {};
            }
            target = registers.at(gpr->GetIndex());

        } else if (const auto abuf = std::get_if<AbufNode>(dest)) {
            // Output attribute write: builtins live inside the PerVertex block.
            target = [&]() -> Id {
                switch (const auto attribute = abuf->GetIndex(); attribute) {
                case Attribute::Index::Position:
                    return AccessElement(t_out_float, per_vertex, position_index,
                                         abuf->GetElement());
                case Attribute::Index::PointSize:
                    return AccessElement(t_out_float, per_vertex, point_size_index);
                case Attribute::Index::ClipDistances0123:
                    return AccessElement(t_out_float, per_vertex, clip_distances_index,
                                         abuf->GetElement());
                case Attribute::Index::ClipDistances4567:
                    // Second half of the 8-entry clip distance array.
                    return AccessElement(t_out_float, per_vertex, clip_distances_index,
                                         abuf->GetElement() + 4);
                default:
                    if (IsGenericAttribute(attribute)) {
                        return AccessElement(t_out_float, output_attributes.at(attribute),
                                             abuf->GetElement());
                    }
                    UNIMPLEMENTED_MSG("Unhandled output attribute: {}",
                                      static_cast<u32>(attribute));
                    return {};
                }
            }();

        } else if (const auto lmem = std::get_if<LmemNode>(dest)) {
            // Local memory write: byte address divided down to a word index.
            Id address = BitcastTo<Type::Uint>(Visit(lmem->GetAddress()));
            address = Emit(OpUDiv(t_uint, address, Constant(t_uint, 4)));
            target = Emit(OpAccessChain(t_prv_float, local_memory, {address}));
        }

        Emit(OpStore(target, Visit(src)));
        return {};
    }
741
    /// Half-precision operation stub; aborts via UNIMPLEMENTED when reached.
    Id HNegate(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }

    /// Half-precision operation stub; aborts via UNIMPLEMENTED when reached.
    Id HMergeF32(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }

    /// Half-precision operation stub; aborts via UNIMPLEMENTED when reached.
    Id HMergeH0(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }

    /// Half-precision operation stub; aborts via UNIMPLEMENTED when reached.
    Id HMergeH1(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }

    /// Half-precision operation stub; aborts via UNIMPLEMENTED when reached.
    Id HPack2(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }
766
    /// Stores the boolean in operand 1 into the predicate or internal flag described by
    /// operand 0. Always returns a null id.
    Id LogicalAssign(Operation operation) {
        const Node dest = operation[0];
        const Node src = operation[1];

        Id target{};
        if (const auto pred = std::get_if<PredicateNode>(dest)) {
            ASSERT_MSG(!pred->IsNegated(), "Negating logical assignment");

            const auto index = pred->GetIndex();
            switch (index) {
            case Tegra::Shader::Pred::NeverExecute:
            case Tegra::Shader::Pred::UnusedIndex:
                // Writing to these predicates is a no-op
                return {};
            }
            target = predicates.at(index);

        } else if (const auto flag = std::get_if<InternalFlagNode>(dest)) {
            target = internal_flags.at(static_cast<u32>(flag->GetFlag()));
        }

        Emit(OpStore(target, Visit(src)));
        return {};
    }
791
    /// Boolean-pair operation stub; aborts via UNIMPLEMENTED when reached.
    Id LogicalPick2(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }

    /// Boolean-pair operation stub; aborts via UNIMPLEMENTED when reached.
    Id LogicalAll2(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }

    /// Boolean-pair operation stub; aborts via UNIMPLEMENTED when reached.
    Id LogicalAny2(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }
806
807 Id GetTextureSampler(Operation operation) {
808 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
809 const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex()));
810 return Emit(OpLoad(entry.sampled_image_type, entry.sampler));
811 }
812
813 Id GetTextureImage(Operation operation) {
814 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
815 const auto entry = sampler_images.at(static_cast<u32>(meta->sampler.GetIndex()));
816 return Emit(OpImage(entry.image_type, GetTextureSampler(operation)));
817 }
818
    /// Builds the coordinate operand for a texture instruction: the operation's operands,
    /// then the array layer (converted to float) and the depth-compare reference when the
    /// sampler needs them, packed into a vector if there is more than one component.
    Id GetTextureCoordinates(Operation operation) {
        // NOTE(review): meta is dereferenced without the ASSERT the sibling helpers use.
        const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
        std::vector<Id> coords;
        for (std::size_t i = 0; i < operation.GetOperandsCount(); ++i) {
            coords.push_back(Visit(operation[i]));
        }
        if (meta->sampler.IsArray()) {
            const Id array_integer = BitcastTo<Type::Int>(Visit(meta->array));
            coords.push_back(Emit(OpConvertSToF(t_float, array_integer)));
        }
        if (meta->sampler.IsShadow()) {
            coords.push_back(Visit(meta->depth_compare));
        }

        // Index 0 is unused: a single coordinate is returned as a scalar.
        const std::array<Id, 4> t_float_lut = {nullptr, t_float2, t_float3, t_float4};
        return coords.size() == 1
                   ? coords[0]
                   : Emit(OpCompositeConstruct(t_float_lut.at(coords.size() - 1), coords));
    }
838
839 Id GetTextureElement(Operation operation, Id sample_value) {
840 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
841 ASSERT(meta);
842 return Emit(OpCompositeExtract(t_float, sample_value, meta->element));
843 }
844
845 Id Texture(Operation operation) {
846 const Id texture = Emit(OpImageSampleImplicitLod(t_float4, GetTextureSampler(operation),
847 GetTextureCoordinates(operation)));
848 return GetTextureElement(operation, texture);
849 }
850
851 Id TextureLod(Operation operation) {
852 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
853 const Id texture = Emit(OpImageSampleExplicitLod(
854 t_float4, GetTextureSampler(operation), GetTextureCoordinates(operation),
855 spv::ImageOperandsMask::Lod, Visit(meta->lod)));
856 return GetTextureElement(operation, texture);
857 }
858
859 Id TextureGather(Operation operation) {
860 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
861 const auto coords = GetTextureCoordinates(operation);
862
863 Id texture;
864 if (meta->sampler.IsShadow()) {
865 texture = Emit(OpImageDrefGather(t_float4, GetTextureSampler(operation), coords,
866 Visit(meta->component)));
867 } else {
868 u32 component_value = 0;
869 if (meta->component) {
870 const auto component = std::get_if<ImmediateNode>(meta->component);
871 ASSERT_MSG(component, "Component is not an immediate value");
872 component_value = component->GetValue();
873 }
874 texture = Emit(OpImageGather(t_float4, GetTextureSampler(operation), coords,
875 Constant(t_uint, component_value)));
876 }
877
878 return GetTextureElement(operation, texture);
879 }
880
881 Id TextureQueryDimensions(Operation operation) {
882 const auto meta = std::get_if<MetaTexture>(&operation.GetMeta());
883 const auto image_id = GetTextureImage(operation);
884 AddCapability(spv::Capability::ImageQuery);
885
886 if (meta->element == 3) {
887 return BitcastTo<Type::Float>(Emit(OpImageQueryLevels(t_int, image_id)));
888 }
889
890 const Id lod = VisitOperand<Type::Uint>(operation, 0);
891 const std::size_t coords_count = [&]() {
892 switch (const auto type = meta->sampler.GetType(); type) {
893 case Tegra::Shader::TextureType::Texture1D:
894 return 1;
895 case Tegra::Shader::TextureType::Texture2D:
896 case Tegra::Shader::TextureType::TextureCube:
897 return 2;
898 case Tegra::Shader::TextureType::Texture3D:
899 return 3;
900 default:
901 UNREACHABLE_MSG("Invalid texture type={}", static_cast<u32>(type));
902 return 2;
903 }
904 }();
905
906 if (meta->element >= coords_count) {
907 return Constant(t_float, 0.0f);
908 }
909
910 const std::array<Id, 3> types = {t_int, t_int2, t_int3};
911 const Id sizes = Emit(OpImageQuerySizeLod(types.at(coords_count - 1), image_id, lod));
912 const Id size = Emit(OpCompositeExtract(t_int, sizes, meta->element));
913 return BitcastTo<Type::Float>(size);
914 }
915
    // TODO: Implement (would map to OpImageQueryLod). Currently stubbed out.
    Id TextureQueryLod(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }
920
    // TODO: Implement (would map to OpImageFetch). Currently stubbed out.
    Id TexelFetch(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }
925
    /// Unconditional jump: stores the target PC into jmp_to and branches to
    /// continue_label — presumably the dispatch loop's continue target that
    /// re-selects the label matching jmp_to (set up outside this chunk).
    Id Branch(Operation operation) {
        // Only immediate branch targets are supported so far.
        const auto target = std::get_if<ImmediateNode>(operation[0]);
        UNIMPLEMENTED_IF(!target);

        Emit(OpStore(jmp_to, Constant(t_uint, target->GetValue())));
        // Wrap in a structured selection so the block terminator is valid SPIR-V.
        BranchingOp([&]() { Emit(OpBranch(continue_label)); });
        return {};
    }
934
935 Id PushFlowStack(Operation operation) {
936 const auto target = std::get_if<ImmediateNode>(operation[0]);
937 ASSERT(target);
938
939 const Id current = Emit(OpLoad(t_uint, flow_stack_top));
940 const Id next = Emit(OpIAdd(t_uint, current, Constant(t_uint, 1)));
941 const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, current));
942
943 Emit(OpStore(access, Constant(t_uint, target->GetValue())));
944 Emit(OpStore(flow_stack_top, next));
945 return {};
946 }
947
948 Id PopFlowStack(Operation operation) {
949 const Id current = Emit(OpLoad(t_uint, flow_stack_top));
950 const Id previous = Emit(OpISub(t_uint, current, Constant(t_uint, 1)));
951 const Id access = Emit(OpAccessChain(t_func_uint, flow_stack, previous));
952 const Id target = Emit(OpLoad(t_uint, access));
953
954 Emit(OpStore(flow_stack_top, previous));
955 Emit(OpStore(jmp_to, target));
956 BranchingOp([&]() { Emit(OpBranch(continue_label)); });
957 return {};
958 }
959
    /// Emits per-stage epilogue work and returns from the shader.
    /// Vertex: remaps depth; Fragment: writes color/depth outputs from registers.
    Id Exit(Operation operation) {
        switch (stage) {
        case ShaderStage::Vertex: {
            // TODO(Rodrigo): We should use VK_EXT_depth_range_unrestricted instead, but it doesn't
            // seem to be working on Nvidia's drivers and Intel (mesa and blob) doesn't support it.
            // Remap gl_Position.z from [-1, 1] to Vulkan's [0, 1]: z = (z + 1) * 0.5.
            const Id position = AccessElement(t_float4, per_vertex, position_index);
            Id depth = Emit(OpLoad(t_float, AccessElement(t_out_float, position, 2)));
            depth = Emit(OpFAdd(t_float, depth, Constant(t_float, 1.0f)));
            depth = Emit(OpFMul(t_float, depth, Constant(t_float, 0.5f)));
            Emit(OpStore(AccessElement(t_out_float, position, 2), depth));
            break;
        }
        case ShaderStage::Fragment: {
            // Reads a register's value if it was ever written, 0.0 otherwise.
            const auto SafeGetRegister = [&](u32 reg) {
                // TODO(Rodrigo): Replace with contains once C++20 releases
                if (const auto it = registers.find(reg); it != registers.end()) {
                    return Emit(OpLoad(t_float, it->second));
                }
                return Constant(t_float, 0.0f);
            };

            UNIMPLEMENTED_IF_MSG(header.ps.omap.sample_mask != 0,
                                 "Sample mask write is unimplemented");

            // TODO(Rodrigo): Alpha testing

            // Write the color outputs using the data in the shader registers, disabled
            // rendertargets/components are skipped in the register assignment.
            u32 current_reg = 0;
            for (u32 rt = 0; rt < Maxwell::NumRenderTargets; ++rt) {
                // TODO(Subv): Figure out how dual-source blending is configured in the Switch.
                for (u32 component = 0; component < 4; ++component) {
                    if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
                        Emit(OpStore(AccessElement(t_out_float, frag_colors.at(rt), component),
                                     SafeGetRegister(current_reg)));
                        ++current_reg;
                    }
                }
            }
            if (header.ps.omap.depth) {
                // The depth output is always 2 registers after the last color output, and
                // current_reg already contains one past the last color register.
                Emit(OpStore(frag_depth, SafeGetRegister(current_reg + 1)));
            }
            break;
        }
        }

        // Wrap in a structured selection so the block terminator is valid SPIR-V.
        BranchingOp([&]() { Emit(OpReturn()); });
        return {};
    }
1011
    /// Discards the fragment (OpKill), wrapped so the terminator is valid SPIR-V.
    Id Discard(Operation operation) {
        BranchingOp([&]() { Emit(OpKill()); });
        return {};
    }
1016
    // TODO: Implement geometry shader vertex emission (OpEmitVertex).
    Id EmitVertex(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }
1021
    // TODO: Implement geometry shader primitive termination (OpEndPrimitive).
    Id EndPrimitive(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }
1026
    // TODO: Implement Y-direction negation. Currently stubbed out.
    Id YNegate(Operation operation) {
        UNIMPLEMENTED();
        return {};
    }
1031
    /// Declares a global variable decorated as the given SPIR-V built-in,
    /// registers it as a module global and adds it to the entry-point interface.
    Id DeclareBuiltIn(spv::BuiltIn builtin, spv::StorageClass storage, Id type,
                      const std::string& name) {
        const Id id = OpVariable(type, storage);
        Decorate(id, spv::Decoration::BuiltIn, static_cast<u32>(builtin));
        AddGlobalVariable(Name(id, name));
        interfaces.push_back(id);
        return id;
    }
1040
1041 bool IsRenderTargetUsed(u32 rt) const {
1042 for (u32 component = 0; component < 4; ++component) {
1043 if (header.ps.IsColorComponentOutputEnabled(rt, component)) {
1044 return true;
1045 }
1046 }
1047 return false;
1048 }
1049
1050 template <typename... Args>
1051 Id AccessElement(Id pointer_type, Id composite, Args... elements_) {
1052 std::vector<Id> members;
1053 auto elements = {elements_...};
1054 for (const auto element : elements) {
1055 members.push_back(Constant(t_uint, element));
1056 }
1057
1058 return Emit(OpAccessChain(pointer_type, composite, members));
1059 }
1060
    /// Visits the operand at operand_index and bitcasts the resulting value
    /// (registers are stored as floats) to the representation requested by the
    /// compile-time `type` parameter.
    template <Type type>
    Id VisitOperand(Operation operation, std::size_t operand_index) {
        const Id value = Visit(operation[operand_index]);

        // `type` is a template parameter, so this switch resolves at compile time.
        switch (type) {
        case Type::Bool:
        case Type::Bool2:
        case Type::Float:
            return value;
        case Type::Int:
            return Emit(OpBitcast(t_int, value));
        case Type::Uint:
            return Emit(OpBitcast(t_uint, value));
        case Type::HalfFloat:
            UNIMPLEMENTED();
        }
        UNREACHABLE();
        return value;
    }
1080
    /// Bitcasts a value of the given compile-time type back to the float
    /// representation used for register storage.
    template <Type type>
    Id BitcastFrom(Id value) {
        switch (type) {
        case Type::Bool:
        case Type::Bool2:
        case Type::Float:
            return value;
        case Type::Int:
        case Type::Uint:
            return Emit(OpBitcast(t_float, value));
        case Type::HalfFloat:
            UNIMPLEMENTED();
        }
        UNREACHABLE();
        return value;
    }
1097
    /// Bitcasts a register-stored (float) value to the requested compile-time
    /// type. Booleans cannot be bitcast and are rejected.
    template <Type type>
    Id BitcastTo(Id value) {
        switch (type) {
        case Type::Bool:
        case Type::Bool2:
            UNREACHABLE();
        case Type::Float:
            return Emit(OpBitcast(t_float, value));
        case Type::Int:
            return Emit(OpBitcast(t_int, value));
        case Type::Uint:
            return Emit(OpBitcast(t_uint, value));
        case Type::HalfFloat:
            UNIMPLEMENTED();
        }
        UNREACHABLE();
        return value;
    }
1116
1117 Id GetTypeDefinition(Type type) {
1118 switch (type) {
1119 case Type::Bool:
1120 return t_bool;
1121 case Type::Bool2:
1122 return t_bool2;
1123 case Type::Float:
1124 return t_float;
1125 case Type::Int:
1126 return t_int;
1127 case Type::Uint:
1128 return t_uint;
1129 case Type::HalfFloat:
1130 UNIMPLEMENTED();
1131 }
1132 UNREACHABLE();
1133 return {};
1134 }
1135
    /// Emits `call` inside a trivially-true structured selection. This keeps
    /// block terminators (OpReturn, OpKill, OpBranch) valid: execution enters
    /// true_label, runs `call`, and control is considered merged at skip_label.
    void BranchingOp(std::function<void()> call) {
        const Id true_label = OpLabel();
        const Id skip_label = OpLabel();
        Emit(OpSelectionMerge(skip_label, spv::SelectionControlMask::Flatten));
        // The trailing 1, 0 literals are branch weights favoring the true path.
        Emit(OpBranchConditional(v_true, true_label, skip_label, 1, 0));
        Emit(true_label);
        call();

        Emit(skip_label);
    }
1146
    /// Dispatch table mapping IR operations to decompiler member functions.
    /// NOTE(review): entry order must line up with the IR's OperationCode
    /// enumeration (video_core/shader/shader_ir.h) — verify when editing.
    static constexpr OperationDecompilersArray operation_decompilers = {
        &SPIRVDecompiler::Assign,

        // Select
        &SPIRVDecompiler::Ternary<&Module::OpSelect, Type::Float, Type::Bool, Type::Float,
                                  Type::Float>,

        // Floating-point arithmetic
        &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFMul, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFDiv, Type::Float>,
        &SPIRVDecompiler::Ternary<&Module::OpFma, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpFNegate, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::Float>,
        &SPIRVDecompiler::Ternary<&Module::OpFClamp, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFMin, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFMax, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpCos, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpSin, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpExp2, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpLog2, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpInverseSqrt, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpSqrt, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpRoundEven, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpFloor, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpCeil, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpTrunc, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpConvertSToF, Type::Float, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpConvertUToF, Type::Float, Type::Uint>,

        // Signed integer arithmetic
        &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSDiv, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpSNegate, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpSAbs, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSMin, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSMax, Type::Int>,

        // Signed integer conversions and bit operations
        &SPIRVDecompiler::Unary<&Module::OpConvertFToS, Type::Int, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Int, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Int, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Int, Type::Int, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpNot, Type::Int>,
        &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Int>,
        &SPIRVDecompiler::Ternary<&Module::OpBitFieldSExtract, Type::Int>,
        &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Int>,

        // Unsigned integer arithmetic and bit operations
        &SPIRVDecompiler::Binary<&Module::OpIAdd, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpIMul, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUDiv, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUMin, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUMax, Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpConvertFToU, Type::Uint, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpBitcast, Type::Uint, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpShiftLeftLogical, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightLogical, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpShiftRightArithmetic, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseAnd, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseOr, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpBitwiseXor, Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpNot, Type::Uint>,
        &SPIRVDecompiler::Quaternary<&Module::OpBitFieldInsert, Type::Uint>,
        &SPIRVDecompiler::Ternary<&Module::OpBitFieldUExtract, Type::Uint>,
        &SPIRVDecompiler::Unary<&Module::OpBitCount, Type::Uint>,

        // Half-float arithmetic and packing
        &SPIRVDecompiler::Binary<&Module::OpFAdd, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFMul, Type::HalfFloat>,
        &SPIRVDecompiler::Ternary<&Module::OpFma, Type::HalfFloat>,
        &SPIRVDecompiler::Unary<&Module::OpFAbs, Type::HalfFloat>,
        &SPIRVDecompiler::HNegate,
        &SPIRVDecompiler::HMergeF32,
        &SPIRVDecompiler::HMergeH0,
        &SPIRVDecompiler::HMergeH1,
        &SPIRVDecompiler::HPack2,

        // Logical (predicate) operations
        &SPIRVDecompiler::LogicalAssign,
        &SPIRVDecompiler::Binary<&Module::OpLogicalAnd, Type::Bool>,
        &SPIRVDecompiler::Binary<&Module::OpLogicalOr, Type::Bool>,
        &SPIRVDecompiler::Binary<&Module::OpLogicalNotEqual, Type::Bool>,
        &SPIRVDecompiler::Unary<&Module::OpLogicalNot, Type::Bool>,
        &SPIRVDecompiler::LogicalPick2,
        &SPIRVDecompiler::LogicalAll2,
        &SPIRVDecompiler::LogicalAny2,

        // Floating-point comparisons
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::Float>,
        &SPIRVDecompiler::Unary<&Module::OpIsNan, Type::Bool>,

        // Signed integer comparisons
        &SPIRVDecompiler::Binary<&Module::OpSLessThan, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSLessThanEqual, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSGreaterThan, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Int>,
        &SPIRVDecompiler::Binary<&Module::OpSGreaterThanEqual, Type::Bool, Type::Int>,

        // Unsigned integer comparisons
        &SPIRVDecompiler::Binary<&Module::OpULessThan, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpIEqual, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpULessThanEqual, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUGreaterThan, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpINotEqual, Type::Bool, Type::Uint>,
        &SPIRVDecompiler::Binary<&Module::OpUGreaterThanEqual, Type::Bool, Type::Uint>,

        // Half-float comparisons
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThan, Type::Bool, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdEqual, Type::Bool, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdLessThanEqual, Type::Bool, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThan, Type::Bool, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdNotEqual, Type::Bool, Type::HalfFloat>,
        &SPIRVDecompiler::Binary<&Module::OpFOrdGreaterThanEqual, Type::Bool, Type::HalfFloat>,

        // Texture operations
        &SPIRVDecompiler::Texture,
        &SPIRVDecompiler::TextureLod,
        &SPIRVDecompiler::TextureGather,
        &SPIRVDecompiler::TextureQueryDimensions,
        &SPIRVDecompiler::TextureQueryLod,
        &SPIRVDecompiler::TexelFetch,

        // Control flow
        &SPIRVDecompiler::Branch,
        &SPIRVDecompiler::PushFlowStack,
        &SPIRVDecompiler::PopFlowStack,
        &SPIRVDecompiler::Exit,
        &SPIRVDecompiler::Discard,

        // Geometry shaders
        &SPIRVDecompiler::EmitVertex,
        &SPIRVDecompiler::EndPrimitive,

        &SPIRVDecompiler::YNegate,
    };
1280
    const ShaderIR& ir;              // Decompiled shader's intermediate representation.
    const ShaderStage stage;         // Pipeline stage being decompiled.
    const Tegra::Shader::Header header; // Maxwell shader program header.

    // SPIR-V scalar and vector type ids, named for readable disassembly.
    const Id t_void = Name(TypeVoid(), "void");

    const Id t_bool = Name(TypeBool(), "bool");
    const Id t_bool2 = Name(TypeVector(t_bool, 2), "bool2");

    const Id t_int = Name(TypeInt(32, true), "int");
    const Id t_int2 = Name(TypeVector(t_int, 2), "int2");
    const Id t_int3 = Name(TypeVector(t_int, 3), "int3");
    const Id t_int4 = Name(TypeVector(t_int, 4), "int4");

    const Id t_uint = Name(TypeInt(32, false), "uint");
    const Id t_uint2 = Name(TypeVector(t_uint, 2), "uint2");
    const Id t_uint3 = Name(TypeVector(t_uint, 3), "uint3");
    const Id t_uint4 = Name(TypeVector(t_uint, 4), "uint4");

    const Id t_float = Name(TypeFloat(32), "float");
    const Id t_float2 = Name(TypeVector(t_float, 2), "float2");
    const Id t_float3 = Name(TypeVector(t_float, 3), "float3");
    const Id t_float4 = Name(TypeVector(t_float, 4), "float4");

    // Pointer types, one per storage class they are used with.
    const Id t_prv_bool = Name(TypePointer(spv::StorageClass::Private, t_bool), "prv_bool");
    const Id t_prv_float = Name(TypePointer(spv::StorageClass::Private, t_float), "prv_float");

    const Id t_func_uint = Name(TypePointer(spv::StorageClass::Function, t_uint), "func_uint");

    const Id t_in_bool = Name(TypePointer(spv::StorageClass::Input, t_bool), "in_bool");
    const Id t_in_uint = Name(TypePointer(spv::StorageClass::Input, t_uint), "in_uint");
    const Id t_in_float = Name(TypePointer(spv::StorageClass::Input, t_float), "in_float");
    const Id t_in_float4 = Name(TypePointer(spv::StorageClass::Input, t_float4), "in_float4");

    const Id t_out_float = Name(TypePointer(spv::StorageClass::Output, t_float), "out_float");
    const Id t_out_float4 = Name(TypePointer(spv::StorageClass::Output, t_float4), "out_float4");

    // Constant buffer (UBO) types: a Block-decorated struct wrapping a strided
    // float4 array.
    const Id t_cbuf_float = TypePointer(spv::StorageClass::Uniform, t_float);
    const Id t_cbuf_array =
        Decorate(Name(TypeArray(t_float4, Constant(t_uint, MAX_CONSTBUFFER_ELEMENTS)), "CbufArray"),
                 spv::Decoration::ArrayStride, CBUF_STRIDE);
    const Id t_cbuf_struct = MemberDecorate(
        Decorate(TypeStruct(t_cbuf_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
    const Id t_cbuf_ubo = TypePointer(spv::StorageClass::Uniform, t_cbuf_struct);

    // Global memory (SSBO) types: a Block-decorated struct wrapping a runtime
    // float array with 4-byte stride.
    const Id t_gmem_float = TypePointer(spv::StorageClass::StorageBuffer, t_float);
    const Id t_gmem_array =
        Name(Decorate(TypeRuntimeArray(t_float), spv::Decoration::ArrayStride, 4u), "GmemArray");
    const Id t_gmem_struct = MemberDecorate(
        Decorate(TypeStruct(t_gmem_array), spv::Decoration::Block), 0, spv::Decoration::Offset, 0);
    const Id t_gmem_ssbo = TypePointer(spv::StorageClass::StorageBuffer, t_gmem_struct);

    // Frequently used constants.
    const Id v_float_zero = Constant(t_float, 0.0f);
    const Id v_true = ConstantTrue(t_bool);
    const Id v_false = ConstantFalse(t_bool);

    // Per-shader declared variables, keyed by their IR index where applicable.
    Id per_vertex{};                        // gl_PerVertex output block (vertex stage).
    std::map<u32, Id> registers;            // General purpose registers, stored as floats.
    std::map<Tegra::Shader::Pred, Id> predicates;
    Id local_memory{};
    std::array<Id, INTERNAL_FLAGS_COUNT> internal_flags{};
    std::map<Attribute::Index, Id> input_attributes;
    std::map<Attribute::Index, Id> output_attributes;
    std::map<u32, Id> constant_buffers;
    std::map<GlobalMemoryBase, Id> global_buffers;
    std::map<u32, SamplerImage> sampler_images;

    // Stage built-ins.
    Id instance_index{};
    Id vertex_index{};
    std::array<Id, Maxwell::NumRenderTargets> frag_colors{};
    Id frag_depth{};
    Id frag_coord{};
    Id front_facing{};

    // Member indices inside the gl_PerVertex struct.
    u32 position_index{};
    u32 point_size_index{};
    u32 clip_distances_index{};

    std::vector<Id> interfaces;             // Entry point interface variables.

    // First binding number of each resource kind within DESCRIPTOR_SET.
    u32 const_buffers_base_binding{};
    u32 global_buffers_base_binding{};
    u32 samplers_base_binding{};

    // Control flow machinery — presumably set up by the dispatch loop emitted
    // outside this chunk; confirm in Decompile().
    Id execute_function{};
    Id jmp_to{};                            // Next program counter to dispatch.
    Id flow_stack_top{};                    // Index of the next free stack slot.
    Id flow_stack{};                        // In-function array backing SSY/PBK.
    Id continue_label{};
    std::map<u32, Id> labels;               // PC -> label of the emitted block.
1371};
1372
1373DecompilerResult Decompile(const VideoCommon::Shader::ShaderIR& ir, Maxwell::ShaderStage stage) {
1374 auto decompiler = std::make_unique<SPIRVDecompiler>(ir, stage);
1375 decompiler->Decompile();
1376 return {std::move(decompiler), decompiler->GetShaderEntries()};
1377}
1378
1379} // namespace Vulkan::VKShader
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.h b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
new file mode 100644
index 000000000..329d8fa38
--- /dev/null
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.h
@@ -0,0 +1,80 @@
1// Copyright 2019 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#pragma once
6
7#include <array>
8#include <memory>
9#include <set>
10#include <utility>
11#include <vector>
12
13#include <sirit/sirit.h>
14
15#include "common/common_types.h"
16#include "video_core/engines/maxwell_3d.h"
17#include "video_core/shader/shader_ir.h"
18
19namespace VideoCommon::Shader {
20class ShaderIR;
21}
22
23namespace Vulkan::VKShader {
24
using Maxwell = Tegra::Engines::Maxwell3D::Regs;

// The IR sampler metadata is reused directly as the Vulkan sampler entry type.
using SamplerEntry = VideoCommon::Shader::Sampler;

// Descriptor set index shared by all shader resources — presumably consumed
// by the pipeline layout; confirm at usage sites.
constexpr u32 DESCRIPTOR_SET = 0;
30
/// A constant buffer used by a shader, extending the IR description with the
/// hardware constant buffer index it was declared at.
class ConstBufferEntry : public VideoCommon::Shader::ConstBuffer {
public:
    explicit constexpr ConstBufferEntry(const VideoCommon::Shader::ConstBuffer& entry, u32 index)
        : VideoCommon::Shader::ConstBuffer{entry}, index{index} {}

    /// Returns the hardware constant buffer index.
    constexpr u32 GetIndex() const {
        return index;
    }

private:
    u32 index{};
};
43
44class GlobalBufferEntry {
45public:
46 explicit GlobalBufferEntry(u32 cbuf_index, u32 cbuf_offset)
47 : cbuf_index{cbuf_index}, cbuf_offset{cbuf_offset} {}
48
49 u32 GetCbufIndex() const {
50 return cbuf_index;
51 }
52
53 u32 GetCbufOffset() const {
54 return cbuf_offset;
55 }
56
57private:
58 u32 cbuf_index{};
59 u32 cbuf_offset{};
60};
61
/// Everything the Vulkan pipeline needs to know about a decompiled shader:
/// resource bindings, used attributes and the generated module's entry data.
struct ShaderEntries {
    u32 const_buffers_base_binding{};  // First binding used by constant buffers.
    u32 global_buffers_base_binding{}; // First binding used by global buffers.
    u32 samplers_base_binding{};       // First binding used by samplers.
    std::vector<ConstBufferEntry> const_buffers;
    std::vector<GlobalBufferEntry> global_buffers;
    std::vector<SamplerEntry> samplers;
    std::set<u32> attributes;          // Input attribute indices read by the shader.
    std::array<bool, Maxwell::NumClipDistances> clip_distances{};
    std::size_t shader_length{};       // Length of the source program in bytes.
    Sirit::Id entry_function{};        // SPIR-V id of the shader's entry function.
    std::vector<Sirit::Id> interfaces; // Entry point interface variables.
};
75
/// Result of a decompilation: the generated SPIR-V module plus the metadata
/// describing the resources and entry point it declares.
using DecompilerResult = std::pair<std::unique_ptr<Sirit::Module>, ShaderEntries>;

/// Decompiles the given shader IR into SPIR-V for the given pipeline stage.
DecompilerResult Decompile(const VideoCommon::Shader::ShaderIR& ir, Maxwell::ShaderStage stage);
79
80} // namespace Vulkan::VKShader