summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/backend/spirv/emit_context.cpp
diff options
context:
space:
mode:
authorGravatar bunnei2021-07-25 11:39:04 -0700
committerGravatar GitHub2021-07-25 11:39:04 -0700
commit98b26b6e126d4775fdf3f773fe8a8ac808a8ff8f (patch)
tree816faa96c2c4d291825063433331a8ea4b3d08f1 /src/shader_recompiler/backend/spirv/emit_context.cpp
parentMerge pull request #6699 from lat9nq/common-threads (diff)
parentshader: Support out of bound local memory reads and immediate writes (diff)
downloadyuzu-98b26b6e126d4775fdf3f773fe8a8ac808a8ff8f.tar.gz
yuzu-98b26b6e126d4775fdf3f773fe8a8ac808a8ff8f.tar.xz
yuzu-98b26b6e126d4775fdf3f773fe8a8ac808a8ff8f.zip
Merge pull request #6585 from ameerj/hades
Shader Decompiler Rewrite
Diffstat (limited to 'src/shader_recompiler/backend/spirv/emit_context.cpp')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp1368
1 files changed, 1368 insertions, 0 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
new file mode 100644
index 000000000..2d29d8c14
--- /dev/null
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -0,0 +1,1368 @@
1// Copyright 2021 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#include <algorithm>
6#include <array>
7#include <climits>
8#include <string_view>
9
10#include <fmt/format.h>
11
12#include "common/common_types.h"
13#include "common/div_ceil.h"
14#include "shader_recompiler/backend/spirv/emit_context.h"
15
16namespace Shader::Backend::SPIRV {
17namespace {
// Read-modify-write operations that are emulated with compare-and-swap
// loops (see CasFunction/CasLoop below) when no native atomic exists.
enum class Operation {
    Increment,
    Decrement,
    FPAdd,
    FPMin,
    FPMax,
};
25
// Describes how one generic input attribute is accessed in SPIR-V.
struct AttrInfo {
    Id pointer;      // Input pointer type used for OpAccessChain
    Id id;           // Scalar type of the loaded element
    bool needs_cast; // True when the loaded value must be bitcast to F32
};
31
32Id ImageType(EmitContext& ctx, const TextureDescriptor& desc) {
33 const spv::ImageFormat format{spv::ImageFormat::Unknown};
34 const Id type{ctx.F32[1]};
35 const bool depth{desc.is_depth};
36 switch (desc.type) {
37 case TextureType::Color1D:
38 return ctx.TypeImage(type, spv::Dim::Dim1D, depth, false, false, 1, format);
39 case TextureType::ColorArray1D:
40 return ctx.TypeImage(type, spv::Dim::Dim1D, depth, true, false, 1, format);
41 case TextureType::Color2D:
42 return ctx.TypeImage(type, spv::Dim::Dim2D, depth, false, false, 1, format);
43 case TextureType::ColorArray2D:
44 return ctx.TypeImage(type, spv::Dim::Dim2D, depth, true, false, 1, format);
45 case TextureType::Color3D:
46 return ctx.TypeImage(type, spv::Dim::Dim3D, depth, false, false, 1, format);
47 case TextureType::ColorCube:
48 return ctx.TypeImage(type, spv::Dim::Cube, depth, false, false, 1, format);
49 case TextureType::ColorArrayCube:
50 return ctx.TypeImage(type, spv::Dim::Cube, depth, true, false, 1, format);
51 case TextureType::Buffer:
52 break;
53 }
54 throw InvalidArgument("Invalid texture type {}", desc.type);
55}
56
57spv::ImageFormat GetImageFormat(ImageFormat format) {
58 switch (format) {
59 case ImageFormat::Typeless:
60 return spv::ImageFormat::Unknown;
61 case ImageFormat::R8_UINT:
62 return spv::ImageFormat::R8ui;
63 case ImageFormat::R8_SINT:
64 return spv::ImageFormat::R8i;
65 case ImageFormat::R16_UINT:
66 return spv::ImageFormat::R16ui;
67 case ImageFormat::R16_SINT:
68 return spv::ImageFormat::R16i;
69 case ImageFormat::R32_UINT:
70 return spv::ImageFormat::R32ui;
71 case ImageFormat::R32G32_UINT:
72 return spv::ImageFormat::Rg32ui;
73 case ImageFormat::R32G32B32A32_UINT:
74 return spv::ImageFormat::Rgba32ui;
75 }
76 throw InvalidArgument("Invalid image format {}", format);
77}
78
79Id ImageType(EmitContext& ctx, const ImageDescriptor& desc) {
80 const spv::ImageFormat format{GetImageFormat(desc.format)};
81 const Id type{ctx.U32[1]};
82 switch (desc.type) {
83 case TextureType::Color1D:
84 return ctx.TypeImage(type, spv::Dim::Dim1D, false, false, false, 2, format);
85 case TextureType::ColorArray1D:
86 return ctx.TypeImage(type, spv::Dim::Dim1D, false, true, false, 2, format);
87 case TextureType::Color2D:
88 return ctx.TypeImage(type, spv::Dim::Dim2D, false, false, false, 2, format);
89 case TextureType::ColorArray2D:
90 return ctx.TypeImage(type, spv::Dim::Dim2D, false, true, false, 2, format);
91 case TextureType::Color3D:
92 return ctx.TypeImage(type, spv::Dim::Dim3D, false, false, false, 2, format);
93 case TextureType::Buffer:
94 throw NotImplementedException("Image buffer");
95 default:
96 break;
97 }
98 throw InvalidArgument("Invalid texture type {}", desc.type);
99}
100
101Id DefineVariable(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin,
102 spv::StorageClass storage_class) {
103 const Id pointer_type{ctx.TypePointer(storage_class, type)};
104 const Id id{ctx.AddGlobalVariable(pointer_type, storage_class)};
105 if (builtin) {
106 ctx.Decorate(id, spv::Decoration::BuiltIn, *builtin);
107 }
108 ctx.interfaces.push_back(id);
109 return id;
110}
111
112u32 NumVertices(InputTopology input_topology) {
113 switch (input_topology) {
114 case InputTopology::Points:
115 return 1;
116 case InputTopology::Lines:
117 return 2;
118 case InputTopology::LinesAdjacency:
119 return 4;
120 case InputTopology::Triangles:
121 return 3;
122 case InputTopology::TrianglesAdjacency:
123 return 6;
124 }
125 throw InvalidArgument("Invalid input topology {}", input_topology);
126}
127
128Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
129 std::optional<spv::BuiltIn> builtin = std::nullopt) {
130 switch (ctx.stage) {
131 case Stage::TessellationControl:
132 case Stage::TessellationEval:
133 if (per_invocation) {
134 type = ctx.TypeArray(type, ctx.Const(32u));
135 }
136 break;
137 case Stage::Geometry:
138 if (per_invocation) {
139 const u32 num_vertices{NumVertices(ctx.runtime_info.input_topology)};
140 type = ctx.TypeArray(type, ctx.Const(num_vertices));
141 }
142 break;
143 default:
144 break;
145 }
146 return DefineVariable(ctx, type, builtin, spv::StorageClass::Input);
147}
148
149Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations,
150 std::optional<spv::BuiltIn> builtin = std::nullopt) {
151 if (invocations && ctx.stage == Stage::TessellationControl) {
152 type = ctx.TypeArray(type, ctx.Const(*invocations));
153 }
154 return DefineVariable(ctx, type, builtin, spv::StorageClass::Output);
155}
156
// Declares the output variable(s) for one generic attribute ("out_attr<i>").
// Normally a single vec4 is emitted, but transform feedback varyings may
// force the attribute to be split into several smaller vectors, each with
// explicit Xfb decorations and a Component offset.
void DefineGenericOutput(EmitContext& ctx, size_t index, std::optional<u32> invocations) {
    static constexpr std::string_view swizzle{"xyzw"};
    const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4};
    u32 element{0};
    while (element < 4) {
        const u32 remainder{4 - element};
        const TransformFeedbackVarying* xfb_varying{};
        if (!ctx.runtime_info.xfb_varyings.empty()) {
            xfb_varying = &ctx.runtime_info.xfb_varyings[base_attr_index + element];
            // Treat zero-component varyings as absent
            xfb_varying = xfb_varying && xfb_varying->components > 0 ? xfb_varying : nullptr;
        }
        // Without an xfb varying, one vector covers all remaining elements
        const u32 num_components{xfb_varying ? xfb_varying->components : remainder};

        const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)};
        ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
        if (element > 0) {
            // Vectors after the first need an explicit component offset
            ctx.Decorate(id, spv::Decoration::Component, element);
        }
        if (xfb_varying) {
            ctx.Decorate(id, spv::Decoration::XfbBuffer, xfb_varying->buffer);
            ctx.Decorate(id, spv::Decoration::XfbStride, xfb_varying->stride);
            ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset);
        }
        if (num_components < 4 || element > 0) {
            // Partial vectors are named with their swizzle, e.g. "out_attr0_zw"
            const std::string_view subswizzle{swizzle.substr(element, num_components)};
            ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle));
        } else {
            ctx.Name(id, fmt::format("out_attr{}", index));
        }
        const GenericElementInfo info{
            .id = id,
            .first_element = element,
            .num_components = num_components,
        };
        // Record this variable for every element it covers
        std::fill_n(ctx.output_generics[index].begin() + element, num_components, info);
        element += num_components;
    }
}
195
196Id GetAttributeType(EmitContext& ctx, AttributeType type) {
197 switch (type) {
198 case AttributeType::Float:
199 return ctx.F32[4];
200 case AttributeType::SignedInt:
201 return ctx.TypeVector(ctx.TypeInt(32, true), 4);
202 case AttributeType::UnsignedInt:
203 return ctx.U32[4];
204 case AttributeType::Disabled:
205 break;
206 }
207 throw InvalidArgument("Invalid attribute type {}", type);
208}
209
210std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
211 const AttributeType type{ctx.runtime_info.generic_input_types.at(index)};
212 switch (type) {
213 case AttributeType::Float:
214 return AttrInfo{ctx.input_f32, ctx.F32[1], false};
215 case AttributeType::UnsignedInt:
216 return AttrInfo{ctx.input_u32, ctx.U32[1], true};
217 case AttributeType::SignedInt:
218 return AttrInfo{ctx.input_s32, ctx.TypeInt(32, true), true};
219 case AttributeType::Disabled:
220 return std::nullopt;
221 }
222 throw InvalidArgument("Invalid attribute type {}", type);
223}
224
225std::string_view StageName(Stage stage) {
226 switch (stage) {
227 case Stage::VertexA:
228 return "vs_a";
229 case Stage::VertexB:
230 return "vs";
231 case Stage::TessellationControl:
232 return "tcs";
233 case Stage::TessellationEval:
234 return "tes";
235 case Stage::Geometry:
236 return "gs";
237 case Stage::Fragment:
238 return "fs";
239 case Stage::Compute:
240 return "cs";
241 }
242 throw InvalidArgument("Invalid stage {}", stage);
243}
244
245template <typename... Args>
246void Name(EmitContext& ctx, Id object, std::string_view format_str, Args&&... args) {
247 ctx.Name(object, fmt::format(fmt::runtime(format_str), StageName(ctx.stage),
248 std::forward<Args>(args)...)
249 .c_str());
250}
251
// Defines the uniform (constant) buffer bindings for one element type.
// Each constant buffer is a Block struct holding a single fixed-size array
// of `type`; `member_type` selects which EmitContext::uniform_types slot and
// which per-cbuf pointer field to fill in.
void DefineConstBuffers(EmitContext& ctx, const Info& info, Id UniformDefinitions::*member_type,
                        u32 binding, Id type, char type_char, u32 element_size) {
    // 64 KiB worth of elements of the given size
    const Id array_type{ctx.TypeArray(type, ctx.Const(65536U / element_size))};
    ctx.Decorate(array_type, spv::Decoration::ArrayStride, element_size);

    const Id struct_type{ctx.TypeStruct(array_type)};
    // NOTE(review): Name() already prepends StageName(ctx.stage) as the first
    // format argument, so ctx.stage here fills the second placeholder and the
    // element_size * CHAR_BIT argument goes unused — confirm the intended name.
    Name(ctx, struct_type, "{}_cbuf_block_{}{}", ctx.stage, type_char, element_size * CHAR_BIT);
    ctx.Decorate(struct_type, spv::Decoration::Block);
    ctx.MemberName(struct_type, 0, "data");
    ctx.MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);

    const Id struct_pointer_type{ctx.TypePointer(spv::StorageClass::Uniform, struct_type)};
    const Id uniform_type{ctx.TypePointer(spv::StorageClass::Uniform, type)};
    ctx.uniform_types.*member_type = uniform_type;

    for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
        const Id id{ctx.AddGlobalVariable(struct_pointer_type, spv::StorageClass::Uniform)};
        ctx.Decorate(id, spv::Decoration::Binding, binding);
        ctx.Decorate(id, spv::Decoration::DescriptorSet, 0U);
        ctx.Name(id, fmt::format("c{}", desc.index));
        // One descriptor may span several consecutive cbuf slots
        for (size_t i = 0; i < desc.count; ++i) {
            ctx.cbufs[desc.index + i].*member_type = id;
        }
        // SPIR-V 1.4+ requires all global variables in the entry point interface
        if (ctx.profile.supported_spirv >= 0x00010400) {
            ctx.interfaces.push_back(id);
        }
        binding += desc.count;
    }
}
281
// Defines storage buffer bindings that share a single runtime-array Block
// type. `type_def` receives the struct/element pointer types and
// `member_type` selects which per-SSBO pointer field to fill in.
void DefineSsbos(EmitContext& ctx, StorageTypeDefinition& type_def,
                 Id StorageDefinitions::*member_type, const Info& info, u32 binding, Id type,
                 u32 stride) {
    const Id array_type{ctx.TypeRuntimeArray(type)};
    ctx.Decorate(array_type, spv::Decoration::ArrayStride, stride);

    const Id struct_type{ctx.TypeStruct(array_type)};
    ctx.Decorate(struct_type, spv::Decoration::Block);
    ctx.MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);

    const Id struct_pointer{ctx.TypePointer(spv::StorageClass::StorageBuffer, struct_type)};
    type_def.array = struct_pointer;
    type_def.element = ctx.TypePointer(spv::StorageClass::StorageBuffer, type);

    u32 index{};
    for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
        const Id id{ctx.AddGlobalVariable(struct_pointer, spv::StorageClass::StorageBuffer)};
        ctx.Decorate(id, spv::Decoration::Binding, binding);
        ctx.Decorate(id, spv::Decoration::DescriptorSet, 0U);
        ctx.Name(id, fmt::format("ssbo{}", index));
        // SPIR-V 1.4+ requires all global variables in the entry point interface
        if (ctx.profile.supported_spirv >= 0x00010400) {
            ctx.interfaces.push_back(id);
        }
        // One descriptor may span several consecutive buffer slots
        for (size_t i = 0; i < desc.count; ++i) {
            ctx.ssbos[index + i].*member_type = id;
        }
        index += desc.count;
        binding += desc.count;
    }
}
312
// Builds the value-combining function used by a CAS loop. The function
// takes (current, operand) of `value_type` and returns the value to attempt
// to store atomically.
Id CasFunction(EmitContext& ctx, Operation operation, Id value_type) {
    const Id func_type{ctx.TypeFunction(value_type, value_type, value_type)};
    const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
    const Id op_a{ctx.OpFunctionParameter(value_type)};
    const Id op_b{ctx.OpFunctionParameter(value_type)};
    ctx.AddLabel();
    Id result{};
    switch (operation) {
    case Operation::Increment: {
        // Wrapping increment: (a >= b) ? 0 : a + 1
        const Id pred{ctx.OpUGreaterThanEqual(ctx.U1, op_a, op_b)};
        const Id incr{ctx.OpIAdd(value_type, op_a, ctx.Constant(value_type, 1))};
        result = ctx.OpSelect(value_type, pred, ctx.u32_zero_value, incr);
        break;
    }
    case Operation::Decrement: {
        // Wrapping decrement: (a == 0 || a > b) ? b : a - 1
        const Id lhs{ctx.OpIEqual(ctx.U1, op_a, ctx.Constant(value_type, 0u))};
        const Id rhs{ctx.OpUGreaterThan(ctx.U1, op_a, op_b)};
        const Id pred{ctx.OpLogicalOr(ctx.U1, lhs, rhs)};
        const Id decr{ctx.OpISub(value_type, op_a, ctx.Constant(value_type, 1))};
        result = ctx.OpSelect(value_type, pred, op_b, decr);
        break;
    }
    case Operation::FPAdd:
        result = ctx.OpFAdd(value_type, op_a, op_b);
        break;
    case Operation::FPMin:
        result = ctx.OpFMin(value_type, op_a, op_b);
        break;
    case Operation::FPMax:
        result = ctx.OpFMax(value_type, op_a, op_b);
        break;
    default:
        break;
    }
    ctx.OpReturnValue(result);
    ctx.OpFunctionEnd();
    return func;
}
351
// Builds a function that emulates an atomic operation with a
// compare-exchange loop over 32-bit memory words. `scope` selects shared
// (Workgroup) versus storage-buffer memory; for shared memory the base is
// the workgroup variable rather than a function parameter.
Id CasLoop(EmitContext& ctx, Operation operation, Id array_pointer, Id element_pointer,
           Id value_type, Id memory_type, spv::Scope scope) {
    const bool is_shared{scope == spv::Scope::Workgroup};
    // Struct-wrapped arrays need an extra leading zero index in the chain
    const bool is_struct{!is_shared || ctx.profile.support_explicit_workgroup_layout};
    const Id cas_func{CasFunction(ctx, operation, value_type)};
    const Id zero{ctx.u32_zero_value};
    const Id scope_id{ctx.Const(static_cast<u32>(scope))};

    const Id loop_header{ctx.OpLabel()};
    const Id continue_block{ctx.OpLabel()};
    const Id merge_block{ctx.OpLabel()};
    const Id func_type{is_shared
                           ? ctx.TypeFunction(value_type, ctx.U32[1], value_type)
                           : ctx.TypeFunction(value_type, ctx.U32[1], value_type, array_pointer)};

    const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
    const Id index{ctx.OpFunctionParameter(ctx.U32[1])};
    const Id op_b{ctx.OpFunctionParameter(value_type)};
    const Id base{is_shared ? ctx.shared_memory_u32 : ctx.OpFunctionParameter(array_pointer)};
    ctx.AddLabel();
    ctx.OpBranch(loop_header);
    ctx.AddLabel(loop_header);

    ctx.OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
    ctx.OpBranch(continue_block);

    ctx.AddLabel(continue_block);
    const Id word_pointer{is_struct ? ctx.OpAccessChain(element_pointer, base, zero, index)
                                    : ctx.OpAccessChain(element_pointer, base, index)};
    if (value_type.value == ctx.F32[2].value) {
        // Two packed half floats: unpack, combine, repack and CAS the raw word
        const Id u32_value{ctx.OpLoad(ctx.U32[1], word_pointer)};
        const Id value{ctx.OpUnpackHalf2x16(ctx.F32[2], u32_value)};
        const Id new_value{ctx.OpFunctionCall(value_type, cas_func, value, op_b)};
        const Id u32_new_value{ctx.OpPackHalf2x16(ctx.U32[1], new_value)};
        const Id atomic_res{ctx.OpAtomicCompareExchange(ctx.U32[1], word_pointer, scope_id, zero,
                                                        zero, u32_new_value, u32_value)};
        // Retry until the word was unchanged between load and exchange
        const Id success{ctx.OpIEqual(ctx.U1, atomic_res, u32_value)};
        ctx.OpBranchConditional(success, merge_block, loop_header);

        ctx.AddLabel(merge_block);
        ctx.OpReturnValue(ctx.OpUnpackHalf2x16(ctx.F32[2], atomic_res));
    } else {
        const Id value{ctx.OpLoad(memory_type, word_pointer)};
        // Bitcast between value and memory representations when they differ
        const bool matching_type{value_type.value == memory_type.value};
        const Id bitcast_value{matching_type ? value : ctx.OpBitcast(value_type, value)};
        const Id cal_res{ctx.OpFunctionCall(value_type, cas_func, bitcast_value, op_b)};
        const Id new_value{matching_type ? cal_res : ctx.OpBitcast(memory_type, cal_res)};
        const Id atomic_res{ctx.OpAtomicCompareExchange(ctx.U32[1], word_pointer, scope_id, zero,
                                                        zero, new_value, value)};
        const Id success{ctx.OpIEqual(ctx.U1, atomic_res, value)};
        ctx.OpBranchConditional(success, merge_block, loop_header);

        ctx.AddLabel(merge_block);
        ctx.OpReturnValue(ctx.OpBitcast(value_type, atomic_res));
    }
    ctx.OpFunctionEnd();
    return func;
}
410
411template <typename Desc>
412std::string NameOf(Stage stage, const Desc& desc, std::string_view prefix) {
413 if (desc.count > 1) {
414 return fmt::format("{}_{}{}_{:02x}x{}", StageName(stage), prefix, desc.cbuf_index,
415 desc.cbuf_offset, desc.count);
416 } else {
417 return fmt::format("{}_{}{}_{:02x}", StageName(stage), prefix, desc.cbuf_index,
418 desc.cbuf_offset);
419 }
420}
421
422Id DescType(EmitContext& ctx, Id sampled_type, Id pointer_type, u32 count) {
423 if (count > 1) {
424 const Id array_type{ctx.TypeArray(sampled_type, ctx.Const(count))};
425 return ctx.TypePointer(spv::StorageClass::UniformConstant, array_type);
426 } else {
427 return pointer_type;
428 }
429}
430} // Anonymous namespace
431
// Names the scalar base type and defines (and names) its 2- to 4-component
// vector variants, e.g. "f32", "f32x2", "f32x3", "f32x4".
void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) {
    defs[0] = sirit_ctx.Name(base_type, name);

    std::array<char, 6> def_name;
    for (int i = 1; i < 4; ++i) {
        // format_to_n writes into the fixed buffer without a null terminator;
        // the returned size bounds the string_view over it.
        const std::string_view def_name_view(
            def_name.data(),
            fmt::format_to_n(def_name.data(), def_name.size(), "{}x{}", name, i + 1).size);
        defs[static_cast<size_t>(i)] =
            sirit_ctx.Name(sirit_ctx.TypeVector(base_type, i + 1), def_name_view);
    }
}
444
// Builds the SPIR-V module skeleton: common types and constants, stage
// interfaces, memory declarations, resource bindings and helper functions.
// `bindings` tracks the next free binding slot per resource class.
EmitContext::EmitContext(const Profile& profile_, const RuntimeInfo& runtime_info_,
                         IR::Program& program, Bindings& bindings)
    : Sirit::Module(profile_.supported_spirv), profile{profile_},
      runtime_info{runtime_info_}, stage{program.stage} {
    // With unified descriptor bindings, all resource classes draw binding
    // numbers from one shared counter instead of per-class counters.
    const bool is_unified{profile.unified_descriptor_binding};
    u32& uniform_binding{is_unified ? bindings.unified : bindings.uniform_buffer};
    u32& storage_binding{is_unified ? bindings.unified : bindings.storage_buffer};
    u32& texture_binding{is_unified ? bindings.unified : bindings.texture};
    u32& image_binding{is_unified ? bindings.unified : bindings.image};
    AddCapability(spv::Capability::Shader);
    // Order matters: types/constants must exist before definitions that use
    // them, and texture/image buffers share counters with textures/images.
    DefineCommonTypes(program.info);
    DefineCommonConstants();
    DefineInterfaces(program);
    DefineLocalMemory(program);
    DefineSharedMemory(program);
    DefineSharedMemoryFunctions(program);
    DefineConstantBuffers(program.info, uniform_binding);
    DefineStorageBuffers(program.info, storage_binding);
    DefineTextureBuffers(program.info, texture_binding);
    DefineImageBuffers(program.info, image_binding);
    DefineTextures(program.info, texture_binding);
    DefineImages(program.info, image_binding);
    DefineAttributeMemAccess(program.info);
    DefineGlobalMemoryFunctions(program.info);
}

EmitContext::~EmitContext() = default;
472
// Returns the SPIR-V id of an IR value: the cached definition for
// instruction results, or a freshly materialized constant for immediates.
Id EmitContext::Def(const IR::Value& value) {
    if (!value.IsImmediate()) {
        return value.InstRecursive()->Definition<Id>();
    }
    switch (value.Type()) {
    case IR::Type::Void:
        // Void instructions are used for optional arguments (e.g. texture offsets)
        // They are not meant to be used in the SPIR-V module
        return Id{};
    case IR::Type::U1:
        return value.U1() ? true_value : false_value;
    case IR::Type::U32:
        return Const(value.U32());
    case IR::Type::U64:
        return Constant(U64, value.U64());
    case IR::Type::F32:
        return Const(value.F32());
    case IR::Type::F64:
        return Constant(F64[1], value.F64());
    default:
        throw NotImplementedException("Immediate type {}", value.Type());
    }
}
496
497Id EmitContext::BitOffset8(const IR::Value& offset) {
498 if (offset.IsImmediate()) {
499 return Const((offset.U32() % 4) * 8);
500 }
501 return OpBitwiseAnd(U32[1], OpShiftLeftLogical(U32[1], Def(offset), Const(3u)), Const(24u));
502}
503
504Id EmitContext::BitOffset16(const IR::Value& offset) {
505 if (offset.IsImmediate()) {
506 return Const(((offset.U32() / 2) % 2) * 16);
507 }
508 return OpBitwiseAnd(U32[1], OpShiftLeftLogical(U32[1], Def(offset), Const(3u)), Const(16u));
509}
510
// Declares the scalar/vector/pointer types used throughout the module.
// Optional widths (8/16/64-bit integers, fp16/fp64) are declared only when
// the program uses them and, where applicable, the host supports them.
void EmitContext::DefineCommonTypes(const Info& info) {
    void_id = TypeVoid();

    U1 = Name(TypeBool(), "u1");

    F32.Define(*this, TypeFloat(32), "f32");
    U32.Define(*this, TypeInt(32, false), "u32");
    S32.Define(*this, TypeInt(32, true), "s32");

    private_u32 = Name(TypePointer(spv::StorageClass::Private, U32[1]), "private_u32");

    input_f32 = Name(TypePointer(spv::StorageClass::Input, F32[1]), "input_f32");
    input_u32 = Name(TypePointer(spv::StorageClass::Input, U32[1]), "input_u32");
    input_s32 = Name(TypePointer(spv::StorageClass::Input, TypeInt(32, true)), "input_s32");

    output_f32 = Name(TypePointer(spv::StorageClass::Output, F32[1]), "output_f32");
    output_u32 = Name(TypePointer(spv::StorageClass::Output, U32[1]), "output_u32");

    if (info.uses_int8 && profile.support_int8) {
        AddCapability(spv::Capability::Int8);
        U8 = Name(TypeInt(8, false), "u8");
        S8 = Name(TypeInt(8, true), "s8");
    }
    if (info.uses_int16 && profile.support_int16) {
        AddCapability(spv::Capability::Int16);
        U16 = Name(TypeInt(16, false), "u16");
        S16 = Name(TypeInt(16, true), "s16");
    }
    if (info.uses_int64) {
        AddCapability(spv::Capability::Int64);
        U64 = Name(TypeInt(64, false), "u64");
    }
    if (info.uses_fp16) {
        AddCapability(spv::Capability::Float16);
        F16.Define(*this, TypeFloat(16), "f16");
    }
    if (info.uses_fp64) {
        AddCapability(spv::Capability::Float64);
        F64.Define(*this, TypeFloat(64), "f64");
    }
}
552
// Declares the boolean and zero constants shared across emitted code.
void EmitContext::DefineCommonConstants() {
    true_value = ConstantTrue(U1);
    false_value = ConstantFalse(U1);
    u32_zero_value = Const(0U);
    f32_zero_value = Const(0.0f);
}
559
// Declares the stage's input and output interface variables.
void EmitContext::DefineInterfaces(const IR::Program& program) {
    DefineInputs(program);
    DefineOutputs(program);
}
564
// Emulates guest local memory with a Private-storage array of u32 words,
// sized to cover program.local_memory_size bytes.
void EmitContext::DefineLocalMemory(const IR::Program& program) {
    if (program.local_memory_size == 0) {
        return;
    }
    const u32 num_elements{Common::DivCeil(program.local_memory_size, 4U)};
    const Id type{TypeArray(U32[1], Const(num_elements))};
    const Id pointer{TypePointer(spv::StorageClass::Private, type)};
    local_memory = AddGlobalVariable(pointer, spv::StorageClass::Private);
    // SPIR-V 1.4+ requires all global variables in the entry point interface
    if (profile.supported_spirv >= 0x00010400) {
        interfaces.push_back(local_memory);
    }
}
577
// Declares workgroup shared memory. With explicit-workgroup-layout support,
// aliased Block-decorated views of several element widths are declared over
// the same memory. Otherwise a single u32 array is used and 8/16-bit stores
// are emulated with CAS-based bit-insert helper functions.
void EmitContext::DefineSharedMemory(const IR::Program& program) {
    if (program.shared_memory_size == 0) {
        return;
    }
    // Declares one Block-wrapped workgroup array of `element_type` and
    // returns {variable, element pointer type, struct pointer type}.
    const auto make{[&](Id element_type, u32 element_size) {
        const u32 num_elements{Common::DivCeil(program.shared_memory_size, element_size)};
        const Id array_type{TypeArray(element_type, Const(num_elements))};
        Decorate(array_type, spv::Decoration::ArrayStride, element_size);

        const Id struct_type{TypeStruct(array_type)};
        MemberDecorate(struct_type, 0U, spv::Decoration::Offset, 0U);
        Decorate(struct_type, spv::Decoration::Block);

        const Id pointer{TypePointer(spv::StorageClass::Workgroup, struct_type)};
        const Id element_pointer{TypePointer(spv::StorageClass::Workgroup, element_type)};
        const Id variable{AddGlobalVariable(pointer, spv::StorageClass::Workgroup)};
        // The differently-typed views alias the same workgroup memory
        Decorate(variable, spv::Decoration::Aliased);
        interfaces.push_back(variable);

        return std::make_tuple(variable, element_pointer, pointer);
    }};
    if (profile.support_explicit_workgroup_layout) {
        AddExtension("SPV_KHR_workgroup_memory_explicit_layout");
        AddCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR);
        if (program.info.uses_int8) {
            AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
            std::tie(shared_memory_u8, shared_u8, std::ignore) = make(U8, 1);
        }
        if (program.info.uses_int16) {
            AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
            std::tie(shared_memory_u16, shared_u16, std::ignore) = make(U16, 2);
        }
        if (program.info.uses_int64) {
            std::tie(shared_memory_u64, shared_u64, std::ignore) = make(U64, 8);
        }
        std::tie(shared_memory_u32, shared_u32, shared_memory_u32_type) = make(U32[1], 4);
        std::tie(shared_memory_u32x2, shared_u32x2, std::ignore) = make(U32[2], 8);
        std::tie(shared_memory_u32x4, shared_u32x4, std::ignore) = make(U32[4], 16);
        return;
    }
    // Fallback: one plain (non-Block) u32 array for all accesses
    const u32 num_elements{Common::DivCeil(program.shared_memory_size, 4U)};
    const Id type{TypeArray(U32[1], Const(num_elements))};
    shared_memory_u32_type = TypePointer(spv::StorageClass::Workgroup, type);

    shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]);
    shared_memory_u32 = AddGlobalVariable(shared_memory_u32_type, spv::StorageClass::Workgroup);
    interfaces.push_back(shared_memory_u32);

    // Builds a function emulating a sub-word store: CAS loop that bit-field
    // inserts `size` bits of insert_value at the byte offset within the word.
    const Id func_type{TypeFunction(void_id, U32[1], U32[1])};
    const auto make_function{[&](u32 mask, u32 size) {
        const Id loop_header{OpLabel()};
        const Id continue_block{OpLabel()};
        const Id merge_block{OpLabel()};

        const Id func{OpFunction(void_id, spv::FunctionControlMask::MaskNone, func_type)};
        const Id offset{OpFunctionParameter(U32[1])};
        const Id insert_value{OpFunctionParameter(U32[1])};
        AddLabel();
        OpBranch(loop_header);

        AddLabel(loop_header);
        const Id word_offset{OpShiftRightArithmetic(U32[1], offset, Const(2U))};
        const Id shift_offset{OpShiftLeftLogical(U32[1], offset, Const(3U))};
        const Id bit_offset{OpBitwiseAnd(U32[1], shift_offset, Const(mask))};
        const Id count{Const(size)};
        OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
        OpBranch(continue_block);

        AddLabel(continue_block);
        const Id word_pointer{OpAccessChain(shared_u32, shared_memory_u32, word_offset)};
        const Id old_value{OpLoad(U32[1], word_pointer)};
        const Id new_value{OpBitFieldInsert(U32[1], old_value, insert_value, bit_offset, count)};
        const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, Const(1U), u32_zero_value,
                                                    u32_zero_value, new_value, old_value)};
        // Retry until no other invocation modified the word concurrently
        const Id success{OpIEqual(U1, atomic_res, old_value)};
        OpBranchConditional(success, merge_block, loop_header);

        AddLabel(merge_block);
        OpReturn();
        OpFunctionEnd();
        return func;
    }};
    if (program.info.uses_int8) {
        shared_store_u8_func = make_function(24, 8);
    }
    if (program.info.uses_int16) {
        shared_store_u16_func = make_function(16, 16);
    }
}
667
// Builds the CAS-loop helpers for shared-memory atomic increment/decrement
// when the program uses them.
void EmitContext::DefineSharedMemoryFunctions(const IR::Program& program) {
    if (program.info.uses_shared_increment) {
        increment_cas_shared = CasLoop(*this, Operation::Increment, shared_memory_u32_type,
                                       shared_u32, U32[1], U32[1], spv::Scope::Workgroup);
    }
    if (program.info.uses_shared_decrement) {
        decrement_cas_shared = CasLoop(*this, Operation::Decrement, shared_memory_u32_type,
                                       shared_u32, U32[1], U32[1], spv::Scope::Workgroup);
    }
}
678
// Builds helper functions for dynamically indexed attribute access. Each
// helper switches on the attribute word index (offset / 16) and accesses
// the matching input/output variable with the element index (offset / 4) % 4.
// Geometry-shader loads additionally take a vertex index parameter.
void EmitContext::DefineAttributeMemAccess(const Info& info) {
    // Builds: f32 load(u32 offset [, u32 vertex]) over the loaded attributes
    const auto make_load{[&] {
        const bool is_array{stage == Stage::Geometry};
        const Id end_block{OpLabel()};
        const Id default_label{OpLabel()};

        const Id func_type_load{is_array ? TypeFunction(F32[1], U32[1], U32[1])
                                         : TypeFunction(F32[1], U32[1])};
        const Id func{OpFunction(F32[1], spv::FunctionControlMask::MaskNone, func_type_load)};
        const Id offset{OpFunctionParameter(U32[1])};
        const Id vertex{is_array ? OpFunctionParameter(U32[1]) : Id{}};

        AddLabel();
        const Id base_index{OpShiftRightArithmetic(U32[1], offset, Const(2U))};
        const Id masked_index{OpBitwiseAnd(U32[1], base_index, Const(3U))};
        const Id compare_index{OpShiftRightArithmetic(U32[1], base_index, Const(2U))};
        // literals[i] pairs with labels[i]; both vectors are built in the
        // same order they are later visited below
        std::vector<Sirit::Literal> literals;
        std::vector<Id> labels;
        if (info.loads.AnyComponent(IR::Attribute::PositionX)) {
            literals.push_back(static_cast<u32>(IR::Attribute::PositionX) >> 2);
            labels.push_back(OpLabel());
        }
        const u32 base_attribute_value = static_cast<u32>(IR::Attribute::Generic0X) >> 2;
        for (u32 index = 0; index < static_cast<u32>(IR::NUM_GENERICS); ++index) {
            if (!info.loads.Generic(index)) {
                continue;
            }
            literals.push_back(base_attribute_value + index);
            labels.push_back(OpLabel());
        }
        OpSelectionMerge(end_block, spv::SelectionControlMask::MaskNone);
        OpSwitch(compare_index, default_label, literals, labels);
        AddLabel(default_label);
        // Unknown attributes read as zero
        OpReturnValue(Const(0.0f));
        size_t label_index{0};
        if (info.loads.AnyComponent(IR::Attribute::PositionX)) {
            AddLabel(labels[label_index]);
            const Id pointer{is_array
                                 ? OpAccessChain(input_f32, input_position, vertex, masked_index)
                                 : OpAccessChain(input_f32, input_position, masked_index)};
            const Id result{OpLoad(F32[1], pointer)};
            OpReturnValue(result);
            ++label_index;
        }
        for (size_t index = 0; index < IR::NUM_GENERICS; ++index) {
            if (!info.loads.Generic(index)) {
                continue;
            }
            AddLabel(labels[label_index]);
            const auto type{AttrTypes(*this, static_cast<u32>(index))};
            if (!type) {
                // Disabled attributes read as zero
                OpReturnValue(Const(0.0f));
                ++label_index;
                continue;
            }
            const Id generic_id{input_generics.at(index)};
            const Id pointer{is_array
                                 ? OpAccessChain(type->pointer, generic_id, vertex, masked_index)
                                 : OpAccessChain(type->pointer, generic_id, masked_index)};
            const Id value{OpLoad(type->id, pointer)};
            const Id result{type->needs_cast ? OpBitcast(F32[1], value) : value};
            OpReturnValue(result);
            ++label_index;
        }
        AddLabel(end_block);
        // All cases return; the merge block is unreachable
        OpUnreachable();
        OpFunctionEnd();
        return func;
    }};
    // Builds: void store(u32 offset, f32 value) over the stored attributes
    const auto make_store{[&] {
        const Id end_block{OpLabel()};
        const Id default_label{OpLabel()};

        const Id func_type_store{TypeFunction(void_id, U32[1], F32[1])};
        const Id func{OpFunction(void_id, spv::FunctionControlMask::MaskNone, func_type_store)};
        const Id offset{OpFunctionParameter(U32[1])};
        const Id store_value{OpFunctionParameter(F32[1])};
        AddLabel();
        const Id base_index{OpShiftRightArithmetic(U32[1], offset, Const(2U))};
        const Id masked_index{OpBitwiseAnd(U32[1], base_index, Const(3U))};
        const Id compare_index{OpShiftRightArithmetic(U32[1], base_index, Const(2U))};
        // literals[i] pairs with labels[i], built in visit order as above
        std::vector<Sirit::Literal> literals;
        std::vector<Id> labels;
        if (info.stores.AnyComponent(IR::Attribute::PositionX)) {
            literals.push_back(static_cast<u32>(IR::Attribute::PositionX) >> 2);
            labels.push_back(OpLabel());
        }
        const u32 base_attribute_value = static_cast<u32>(IR::Attribute::Generic0X) >> 2;
        for (size_t index = 0; index < IR::NUM_GENERICS; ++index) {
            if (!info.stores.Generic(index)) {
                continue;
            }
            literals.push_back(base_attribute_value + static_cast<u32>(index));
            labels.push_back(OpLabel());
        }
        if (info.stores.ClipDistances()) {
            // Clip distances span two attribute words (0-3 and 4-7)
            literals.push_back(static_cast<u32>(IR::Attribute::ClipDistance0) >> 2);
            labels.push_back(OpLabel());
            literals.push_back(static_cast<u32>(IR::Attribute::ClipDistance4) >> 2);
            labels.push_back(OpLabel());
        }
        OpSelectionMerge(end_block, spv::SelectionControlMask::MaskNone);
        OpSwitch(compare_index, default_label, literals, labels);
        AddLabel(default_label);
        // Stores to unknown attributes are dropped
        OpReturn();
        size_t label_index{0};
        if (info.stores.AnyComponent(IR::Attribute::PositionX)) {
            AddLabel(labels[label_index]);
            const Id pointer{OpAccessChain(output_f32, output_position, masked_index)};
            OpStore(pointer, store_value);
            OpReturn();
            ++label_index;
        }
        for (size_t index = 0; index < IR::NUM_GENERICS; ++index) {
            if (!info.stores.Generic(index)) {
                continue;
            }
            if (output_generics[index][0].num_components != 4) {
                throw NotImplementedException("Physical stores and transform feedbacks");
            }
            AddLabel(labels[label_index]);
            const Id generic_id{output_generics[index][0].id};
            const Id pointer{OpAccessChain(output_f32, generic_id, masked_index)};
            OpStore(pointer, store_value);
            OpReturn();
            ++label_index;
        }
        if (info.stores.ClipDistances()) {
            AddLabel(labels[label_index]);
            const Id pointer{OpAccessChain(output_f32, clip_distances, masked_index)};
            OpStore(pointer, store_value);
            OpReturn();
            ++label_index;
            AddLabel(labels[label_index]);
            // Second word: bias the element index into clip distances 4-7
            const Id fixed_index{OpIAdd(U32[1], masked_index, Const(4U))};
            const Id pointer2{OpAccessChain(output_f32, clip_distances, fixed_index)};
            OpStore(pointer2, store_value);
            OpReturn();
            ++label_index;
        }
        AddLabel(end_block);
        // All cases return; the merge block is unreachable
        OpUnreachable();
        OpFunctionEnd();
        return func;
    }};
    if (info.loads_indexed_attributes) {
        indexed_load_func = make_load();
    }
    if (info.stores_indexed_attributes) {
        indexed_store_func = make_store();
    }
}
831
// Emits helper functions used to service raw 64-bit global memory loads and
// stores. Each helper linearly scans every NVN-tracked storage buffer,
// compares the incoming address against the buffer's [base, base + size)
// range read from the driver constant buffer, and redirects the access to the
// matching SSBO element. Skipped entirely when the shader performs no global
// memory accesses or the host lacks 64-bit integer support.
void EmitContext::DefineGlobalMemoryFunctions(const Info& info) {
    if (!info.uses_global_memory || !profile.support_int64) {
        return;
    }
    using DefPtr = Id StorageDefinitions::*;
    const Id zero{u32_zero_value};
    // Shared body for both load and store helpers: for each candidate SSBO,
    // read its GPU address and size from the constant buffer, branch on
    // whether 'addr' falls inside the buffer, and invoke 'callback' with an
    // access chain to the addressed element when it does.
    const auto define_body{[&](DefPtr ssbo_member, Id addr, Id element_pointer, u32 shift,
                               auto&& callback) {
        AddLabel();
        const size_t num_buffers{info.storage_buffers_descriptors.size()};
        for (size_t index = 0; index < num_buffers; ++index) {
            if (!info.nvn_buffer_used[index]) {
                continue;
            }
            const auto& ssbo{info.storage_buffers_descriptors[index]};
            // The base address is read as a uvec2 (8-byte stride into the
            // U32x2 view); the 32-bit size lives two u32 words after it.
            const Id ssbo_addr_cbuf_offset{Const(ssbo.cbuf_offset / 8)};
            const Id ssbo_size_cbuf_offset{Const(ssbo.cbuf_offset / 4 + 2)};
            const Id ssbo_addr_pointer{OpAccessChain(
                uniform_types.U32x2, cbufs[ssbo.cbuf_index].U32x2, zero, ssbo_addr_cbuf_offset)};
            const Id ssbo_size_pointer{OpAccessChain(uniform_types.U32, cbufs[ssbo.cbuf_index].U32,
                                                     zero, ssbo_size_cbuf_offset)};

            const Id ssbo_addr{OpBitcast(U64, OpLoad(U32[2], ssbo_addr_pointer))};
            const Id ssbo_size{OpUConvert(U64, OpLoad(U32[1], ssbo_size_pointer))};
            const Id ssbo_end{OpIAdd(U64, ssbo_addr, ssbo_size)};
            // In-range test: ssbo_addr <= addr < ssbo_addr + ssbo_size.
            const Id cond{OpLogicalAnd(U1, OpUGreaterThanEqual(U1, addr, ssbo_addr),
                                       OpULessThan(U1, addr, ssbo_end))};
            const Id then_label{OpLabel()};
            const Id else_label{OpLabel()};
            OpSelectionMerge(else_label, spv::SelectionControlMask::MaskNone);
            OpBranchConditional(cond, then_label, else_label);
            AddLabel(then_label);
            const Id ssbo_id{ssbos[index].*ssbo_member};
            // Byte offset inside the buffer, converted to an element index by
            // shifting with the element type's log2 size.
            const Id ssbo_offset{OpUConvert(U32[1], OpISub(U64, addr, ssbo_addr))};
            const Id ssbo_index{OpShiftRightLogical(U32[1], ssbo_offset, Const(shift))};
            const Id ssbo_pointer{OpAccessChain(element_pointer, ssbo_id, zero, ssbo_index)};
            callback(ssbo_pointer);
            AddLabel(else_label);
        }
    }};
    // Load helper: T load(u64 addr). Falls through to a null constant when no
    // buffer contains the address.
    const auto define_load{[&](DefPtr ssbo_member, Id element_pointer, Id type, u32 shift) {
        const Id function_type{TypeFunction(type, U64)};
        const Id func_id{OpFunction(type, spv::FunctionControlMask::MaskNone, function_type)};
        const Id addr{OpFunctionParameter(U64)};
        define_body(ssbo_member, addr, element_pointer, shift,
                    [&](Id ssbo_pointer) { OpReturnValue(OpLoad(type, ssbo_pointer)); });
        OpReturnValue(ConstantNull(type));
        OpFunctionEnd();
        return func_id;
    }};
    // Store helper: void store(u64 addr, T data). Out-of-range writes are
    // silently dropped.
    const auto define_write{[&](DefPtr ssbo_member, Id element_pointer, Id type, u32 shift) {
        const Id function_type{TypeFunction(void_id, U64, type)};
        const Id func_id{OpFunction(void_id, spv::FunctionControlMask::MaskNone, function_type)};
        const Id addr{OpFunctionParameter(U64)};
        const Id data{OpFunctionParameter(type)};
        define_body(ssbo_member, addr, element_pointer, shift, [&](Id ssbo_pointer) {
            OpStore(ssbo_pointer, data);
            OpReturn();
        });
        OpReturn();
        OpFunctionEnd();
        return func_id;
    }};
    // Defines the load/store pair for one element width; 'size' must be a
    // power of two so the byte->element shift is exact.
    const auto define{
        [&](DefPtr ssbo_member, const StorageTypeDefinition& type_def, Id type, size_t size) {
            const Id element_type{type_def.element};
            const u32 shift{static_cast<u32>(std::countr_zero(size))};
            const Id load_func{define_load(ssbo_member, element_type, type, shift)};
            const Id write_func{define_write(ssbo_member, element_type, type, shift)};
            return std::make_pair(load_func, write_func);
        }};
    std::tie(load_global_func_u32, write_global_func_u32) =
        define(&StorageDefinitions::U32, storage_types.U32, U32[1], sizeof(u32));
    std::tie(load_global_func_u32x2, write_global_func_u32x2) =
        define(&StorageDefinitions::U32x2, storage_types.U32x2, U32[2], sizeof(u32[2]));
    std::tie(load_global_func_u32x4, write_global_func_u32x4) =
        define(&StorageDefinitions::U32x4, storage_types.U32x4, U32[4], sizeof(u32[4]));
}
910
// Declares the uniform (constant) buffer views used by the shader and
// advances 'binding' past every consumed binding slot. The order of the
// DefineConstBuffers calls below is significant: it fixes the SPIR-V IDs and
// binding layout the rest of the backend relies on.
void EmitContext::DefineConstantBuffers(const Info& info, u32& binding) {
    if (info.constant_buffer_descriptors.empty()) {
        return;
    }
    if (!profile.support_descriptor_aliasing) {
        // Without descriptor aliasing only a single uvec4 view per buffer is
        // declared; narrower accesses are emulated on top of it.
        DefineConstBuffers(*this, info, &UniformDefinitions::U32x4, binding, U32[4], 'u',
                           sizeof(u32[4]));
        for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
            binding += desc.count;
        }
        return;
    }
    IR::Type types{info.used_constant_buffer_types};
    if (True(types & IR::Type::U8)) {
        if (profile.support_int8) {
            DefineConstBuffers(*this, info, &UniformDefinitions::U8, binding, U8, 'u', sizeof(u8));
            DefineConstBuffers(*this, info, &UniformDefinitions::S8, binding, S8, 's', sizeof(s8));
        } else {
            // No 8-bit integer support: widen these accesses to 32-bit views.
            types |= IR::Type::U32;
        }
    }
    if (True(types & IR::Type::U16)) {
        if (profile.support_int16) {
            DefineConstBuffers(*this, info, &UniformDefinitions::U16, binding, U16, 'u',
                               sizeof(u16));
            DefineConstBuffers(*this, info, &UniformDefinitions::S16, binding, S16, 's',
                               sizeof(s16));
        } else {
            // No 16-bit integer support: widen these accesses to 32-bit views.
            types |= IR::Type::U32;
        }
    }
    if (True(types & IR::Type::U32)) {
        DefineConstBuffers(*this, info, &UniformDefinitions::U32, binding, U32[1], 'u',
                           sizeof(u32));
    }
    if (True(types & IR::Type::F32)) {
        DefineConstBuffers(*this, info, &UniformDefinitions::F32, binding, F32[1], 'f',
                           sizeof(f32));
    }
    if (True(types & IR::Type::U32x2)) {
        DefineConstBuffers(*this, info, &UniformDefinitions::U32x2, binding, U32[2], 'u',
                           sizeof(u32[2]));
    }
    // All aliased views of one descriptor share a binding, so advance by the
    // descriptor count rather than per view.
    binding += static_cast<u32>(info.constant_buffer_descriptors.size());
}
956
// Declares the storage buffer (SSBO) views used by the shader and, when
// needed, emits compare-and-swap loop helpers for atomic operations the host
// cannot perform natively. Advances 'binding' past every consumed slot.
void EmitContext::DefineStorageBuffers(const Info& info, u32& binding) {
    if (info.storage_buffers_descriptors.empty()) {
        return;
    }
    AddExtension("SPV_KHR_storage_buffer_storage_class");

    // Without descriptor aliasing only the u32 view is declared; other widths
    // are emulated through it.
    const IR::Type used_types{profile.support_descriptor_aliasing ? info.used_storage_buffer_types
                                                                  : IR::Type::U32};
    if (profile.support_int8 && True(used_types & IR::Type::U8)) {
        DefineSsbos(*this, storage_types.U8, &StorageDefinitions::U8, info, binding, U8,
                    sizeof(u8));
        DefineSsbos(*this, storage_types.S8, &StorageDefinitions::S8, info, binding, S8,
                    sizeof(u8));
    }
    if (profile.support_int16 && True(used_types & IR::Type::U16)) {
        DefineSsbos(*this, storage_types.U16, &StorageDefinitions::U16, info, binding, U16,
                    sizeof(u16));
        DefineSsbos(*this, storage_types.S16, &StorageDefinitions::S16, info, binding, S16,
                    sizeof(u16));
    }
    if (True(used_types & IR::Type::U32)) {
        DefineSsbos(*this, storage_types.U32, &StorageDefinitions::U32, info, binding, U32[1],
                    sizeof(u32));
    }
    if (True(used_types & IR::Type::F32)) {
        DefineSsbos(*this, storage_types.F32, &StorageDefinitions::F32, info, binding, F32[1],
                    sizeof(f32));
    }
    if (True(used_types & IR::Type::U64)) {
        DefineSsbos(*this, storage_types.U64, &StorageDefinitions::U64, info, binding, U64,
                    sizeof(u64));
    }
    if (True(used_types & IR::Type::U32x2)) {
        DefineSsbos(*this, storage_types.U32x2, &StorageDefinitions::U32x2, info, binding, U32[2],
                    sizeof(u32[2]));
    }
    if (True(used_types & IR::Type::U32x4)) {
        DefineSsbos(*this, storage_types.U32x4, &StorageDefinitions::U32x4, info, binding, U32[4],
                    sizeof(u32[4]));
    }
    for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
        binding += desc.count;
    }
    // Emulated atomics (increment/decrement and floating-point add/min/max)
    // are implemented as CAS loops over the u32 view and need variable
    // pointers into storage buffers.
    const bool needs_function{
        info.uses_global_increment || info.uses_global_decrement || info.uses_atomic_f32_add ||
        info.uses_atomic_f16x2_add || info.uses_atomic_f16x2_min || info.uses_atomic_f16x2_max ||
        info.uses_atomic_f32x2_add || info.uses_atomic_f32x2_min || info.uses_atomic_f32x2_max};
    if (needs_function) {
        AddCapability(spv::Capability::VariablePointersStorageBuffer);
    }
    if (info.uses_global_increment) {
        increment_cas_ssbo = CasLoop(*this, Operation::Increment, storage_types.U32.array,
                                     storage_types.U32.element, U32[1], U32[1], spv::Scope::Device);
    }
    if (info.uses_global_decrement) {
        decrement_cas_ssbo = CasLoop(*this, Operation::Decrement, storage_types.U32.array,
                                     storage_types.U32.element, U32[1], U32[1], spv::Scope::Device);
    }
    if (info.uses_atomic_f32_add) {
        f32_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
                              storage_types.U32.element, F32[1], U32[1], spv::Scope::Device);
    }
    if (info.uses_atomic_f16x2_add) {
        f16x2_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
                                storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
    }
    if (info.uses_atomic_f16x2_min) {
        f16x2_min_cas = CasLoop(*this, Operation::FPMin, storage_types.U32.array,
                                storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
    }
    if (info.uses_atomic_f16x2_max) {
        f16x2_max_cas = CasLoop(*this, Operation::FPMax, storage_types.U32.array,
                                storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
    }
    if (info.uses_atomic_f32x2_add) {
        f32x2_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
                                storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
    }
    if (info.uses_atomic_f32x2_min) {
        f32x2_min_cas = CasLoop(*this, Operation::FPMin, storage_types.U32.array,
                                storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
    }
    if (info.uses_atomic_f32x2_max) {
        f32x2_max_cas = CasLoop(*this, Operation::FPMax, storage_types.U32.array,
                                storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
    }
}
1044
1045void EmitContext::DefineTextureBuffers(const Info& info, u32& binding) {
1046 if (info.texture_buffer_descriptors.empty()) {
1047 return;
1048 }
1049 const spv::ImageFormat format{spv::ImageFormat::Unknown};
1050 image_buffer_type = TypeImage(F32[1], spv::Dim::Buffer, 0U, false, false, 1, format);
1051 sampled_texture_buffer_type = TypeSampledImage(image_buffer_type);
1052
1053 const Id type{TypePointer(spv::StorageClass::UniformConstant, sampled_texture_buffer_type)};
1054 texture_buffers.reserve(info.texture_buffer_descriptors.size());
1055 for (const TextureBufferDescriptor& desc : info.texture_buffer_descriptors) {
1056 if (desc.count != 1) {
1057 throw NotImplementedException("Array of texture buffers");
1058 }
1059 const Id id{AddGlobalVariable(type, spv::StorageClass::UniformConstant)};
1060 Decorate(id, spv::Decoration::Binding, binding);
1061 Decorate(id, spv::Decoration::DescriptorSet, 0U);
1062 Name(id, NameOf(stage, desc, "texbuf"));
1063 texture_buffers.push_back({
1064 .id = id,
1065 .count = desc.count,
1066 });
1067 if (profile.supported_spirv >= 0x00010400) {
1068 interfaces.push_back(id);
1069 }
1070 ++binding;
1071 }
1072}
1073
1074void EmitContext::DefineImageBuffers(const Info& info, u32& binding) {
1075 image_buffers.reserve(info.image_buffer_descriptors.size());
1076 for (const ImageBufferDescriptor& desc : info.image_buffer_descriptors) {
1077 if (desc.count != 1) {
1078 throw NotImplementedException("Array of image buffers");
1079 }
1080 const spv::ImageFormat format{GetImageFormat(desc.format)};
1081 const Id image_type{TypeImage(U32[1], spv::Dim::Buffer, false, false, false, 2, format)};
1082 const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, image_type)};
1083 const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant)};
1084 Decorate(id, spv::Decoration::Binding, binding);
1085 Decorate(id, spv::Decoration::DescriptorSet, 0U);
1086 Name(id, NameOf(stage, desc, "imgbuf"));
1087 image_buffers.push_back({
1088 .id = id,
1089 .image_type = image_type,
1090 .count = desc.count,
1091 });
1092 if (profile.supported_spirv >= 0x00010400) {
1093 interfaces.push_back(id);
1094 }
1095 ++binding;
1096 }
1097}
1098
1099void EmitContext::DefineTextures(const Info& info, u32& binding) {
1100 textures.reserve(info.texture_descriptors.size());
1101 for (const TextureDescriptor& desc : info.texture_descriptors) {
1102 const Id image_type{ImageType(*this, desc)};
1103 const Id sampled_type{TypeSampledImage(image_type)};
1104 const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, sampled_type)};
1105 const Id desc_type{DescType(*this, sampled_type, pointer_type, desc.count)};
1106 const Id id{AddGlobalVariable(desc_type, spv::StorageClass::UniformConstant)};
1107 Decorate(id, spv::Decoration::Binding, binding);
1108 Decorate(id, spv::Decoration::DescriptorSet, 0U);
1109 Name(id, NameOf(stage, desc, "tex"));
1110 textures.push_back({
1111 .id = id,
1112 .sampled_type = sampled_type,
1113 .pointer_type = pointer_type,
1114 .image_type = image_type,
1115 .count = desc.count,
1116 });
1117 if (profile.supported_spirv >= 0x00010400) {
1118 interfaces.push_back(id);
1119 }
1120 ++binding;
1121 }
1122 if (info.uses_atomic_image_u32) {
1123 image_u32 = TypePointer(spv::StorageClass::Image, U32[1]);
1124 }
1125}
1126
1127void EmitContext::DefineImages(const Info& info, u32& binding) {
1128 images.reserve(info.image_descriptors.size());
1129 for (const ImageDescriptor& desc : info.image_descriptors) {
1130 if (desc.count != 1) {
1131 throw NotImplementedException("Array of images");
1132 }
1133 const Id image_type{ImageType(*this, desc)};
1134 const Id pointer_type{TypePointer(spv::StorageClass::UniformConstant, image_type)};
1135 const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::UniformConstant)};
1136 Decorate(id, spv::Decoration::Binding, binding);
1137 Decorate(id, spv::Decoration::DescriptorSet, 0U);
1138 Name(id, NameOf(stage, desc, "img"));
1139 images.push_back({
1140 .id = id,
1141 .image_type = image_type,
1142 .count = desc.count,
1143 });
1144 if (profile.supported_spirv >= 0x00010400) {
1145 interfaces.push_back(id);
1146 }
1147 ++binding;
1148 }
1149}
1150
1151void EmitContext::DefineInputs(const IR::Program& program) {
1152 const Info& info{program.info};
1153 const VaryingState loads{info.loads.mask | info.passthrough.mask};
1154
1155 if (info.uses_workgroup_id) {
1156 workgroup_id = DefineInput(*this, U32[3], false, spv::BuiltIn::WorkgroupId);
1157 }
1158 if (info.uses_local_invocation_id) {
1159 local_invocation_id = DefineInput(*this, U32[3], false, spv::BuiltIn::LocalInvocationId);
1160 }
1161 if (info.uses_invocation_id) {
1162 invocation_id = DefineInput(*this, U32[1], false, spv::BuiltIn::InvocationId);
1163 }
1164 if (info.uses_sample_id) {
1165 sample_id = DefineInput(*this, U32[1], false, spv::BuiltIn::SampleId);
1166 }
1167 if (info.uses_is_helper_invocation) {
1168 is_helper_invocation = DefineInput(*this, U1, false, spv::BuiltIn::HelperInvocation);
1169 }
1170 if (info.uses_subgroup_mask) {
1171 subgroup_mask_eq = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupEqMaskKHR);
1172 subgroup_mask_lt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLtMaskKHR);
1173 subgroup_mask_le = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLeMaskKHR);
1174 subgroup_mask_gt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGtMaskKHR);
1175 subgroup_mask_ge = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGeMaskKHR);
1176 }
1177 if (info.uses_subgroup_invocation_id || info.uses_subgroup_shuffles ||
1178 (profile.warp_size_potentially_larger_than_guest &&
1179 (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
1180 subgroup_local_invocation_id =
1181 DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId);
1182 }
1183 if (info.uses_fswzadd) {
1184 const Id f32_one{Const(1.0f)};
1185 const Id f32_minus_one{Const(-1.0f)};
1186 const Id f32_zero{Const(0.0f)};
1187 fswzadd_lut_a = ConstantComposite(F32[4], f32_minus_one, f32_one, f32_minus_one, f32_zero);
1188 fswzadd_lut_b =
1189 ConstantComposite(F32[4], f32_minus_one, f32_minus_one, f32_one, f32_minus_one);
1190 }
1191 if (loads[IR::Attribute::PrimitiveId]) {
1192 primitive_id = DefineInput(*this, U32[1], false, spv::BuiltIn::PrimitiveId);
1193 }
1194 if (loads.AnyComponent(IR::Attribute::PositionX)) {
1195 const bool is_fragment{stage != Stage::Fragment};
1196 const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
1197 input_position = DefineInput(*this, F32[4], true, built_in);
1198 if (profile.support_geometry_shader_passthrough) {
1199 if (info.passthrough.AnyComponent(IR::Attribute::PositionX)) {
1200 Decorate(input_position, spv::Decoration::PassthroughNV);
1201 }
1202 }
1203 }
1204 if (loads[IR::Attribute::InstanceId]) {
1205 if (profile.support_vertex_instance_id) {
1206 instance_id = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceId);
1207 } else {
1208 instance_index = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceIndex);
1209 base_instance = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseInstance);
1210 }
1211 }
1212 if (loads[IR::Attribute::VertexId]) {
1213 if (profile.support_vertex_instance_id) {
1214 vertex_id = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexId);
1215 } else {
1216 vertex_index = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexIndex);
1217 base_vertex = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseVertex);
1218 }
1219 }
1220 if (loads[IR::Attribute::FrontFace]) {
1221 front_face = DefineInput(*this, U1, true, spv::BuiltIn::FrontFacing);
1222 }
1223 if (loads[IR::Attribute::PointSpriteS] || loads[IR::Attribute::PointSpriteT]) {
1224 point_coord = DefineInput(*this, F32[2], true, spv::BuiltIn::PointCoord);
1225 }
1226 if (loads[IR::Attribute::TessellationEvaluationPointU] ||
1227 loads[IR::Attribute::TessellationEvaluationPointV]) {
1228 tess_coord = DefineInput(*this, F32[3], false, spv::BuiltIn::TessCoord);
1229 }
1230 for (size_t index = 0; index < IR::NUM_GENERICS; ++index) {
1231 const AttributeType input_type{runtime_info.generic_input_types[index]};
1232 if (!runtime_info.previous_stage_stores.Generic(index)) {
1233 continue;
1234 }
1235 if (!loads.Generic(index)) {
1236 continue;
1237 }
1238 if (input_type == AttributeType::Disabled) {
1239 continue;
1240 }
1241 const Id type{GetAttributeType(*this, input_type)};
1242 const Id id{DefineInput(*this, type, true)};
1243 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1244 Name(id, fmt::format("in_attr{}", index));
1245 input_generics[index] = id;
1246
1247 if (info.passthrough.Generic(index) && profile.support_geometry_shader_passthrough) {
1248 Decorate(id, spv::Decoration::PassthroughNV);
1249 }
1250 if (stage != Stage::Fragment) {
1251 continue;
1252 }
1253 switch (info.interpolation[index]) {
1254 case Interpolation::Smooth:
1255 // Default
1256 // Decorate(id, spv::Decoration::Smooth);
1257 break;
1258 case Interpolation::NoPerspective:
1259 Decorate(id, spv::Decoration::NoPerspective);
1260 break;
1261 case Interpolation::Flat:
1262 Decorate(id, spv::Decoration::Flat);
1263 break;
1264 }
1265 }
1266 if (stage == Stage::TessellationEval) {
1267 for (size_t index = 0; index < info.uses_patches.size(); ++index) {
1268 if (!info.uses_patches[index]) {
1269 continue;
1270 }
1271 const Id id{DefineInput(*this, F32[4], false)};
1272 Decorate(id, spv::Decoration::Patch);
1273 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1274 patches[index] = id;
1275 }
1276 }
1277}
1278
// Declares every output variable (built-ins, generic attributes, per-stage
// outputs) the shader writes. 'invocations' sizes per-invocation output
// arrays when present (geometry shaders).
void EmitContext::DefineOutputs(const IR::Program& program) {
    const Info& info{program.info};
    const std::optional<u32> invocations{program.invocations};
    // VertexB always defines Position even when unwritten.
    if (info.stores.AnyComponent(IR::Attribute::PositionX) || stage == Stage::VertexB) {
        output_position = DefineOutput(*this, F32[4], invocations, spv::BuiltIn::Position);
    }
    if (info.stores[IR::Attribute::PointSize] || runtime_info.fixed_state_point_size) {
        if (stage == Stage::Fragment) {
            throw NotImplementedException("Storing PointSize in fragment stage");
        }
        output_point_size = DefineOutput(*this, F32[1], invocations, spv::BuiltIn::PointSize);
    }
    if (info.stores.ClipDistances()) {
        if (stage == Stage::Fragment) {
            throw NotImplementedException("Storing ClipDistance in fragment stage");
        }
        // Always declare all 8 clip distances when any are stored.
        const Id type{TypeArray(F32[1], Const(8U))};
        clip_distances = DefineOutput(*this, type, invocations, spv::BuiltIn::ClipDistance);
    }
    // Layer/ViewportIndex are only valid from geometry shaders unless the host
    // supports writing them from other stages.
    if (info.stores[IR::Attribute::Layer] &&
        (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
        if (stage == Stage::Fragment) {
            throw NotImplementedException("Storing Layer in fragment stage");
        }
        layer = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::Layer);
    }
    if (info.stores[IR::Attribute::ViewportIndex] &&
        (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
        if (stage == Stage::Fragment) {
            throw NotImplementedException("Storing ViewportIndex in fragment stage");
        }
        viewport_index = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::ViewportIndex);
    }
    if (info.stores[IR::Attribute::ViewportMask] && profile.support_viewport_mask) {
        viewport_mask = DefineOutput(*this, TypeArray(U32[1], Const(1u)), std::nullopt,
                                     spv::BuiltIn::ViewportMaskNV);
    }
    for (size_t index = 0; index < IR::NUM_GENERICS; ++index) {
        if (info.stores.Generic(index)) {
            DefineGenericOutput(*this, index, invocations);
        }
    }
    switch (stage) {
    case Stage::TessellationControl:
        if (info.stores_tess_level_outer) {
            const Id type{TypeArray(F32[1], Const(4U))};
            output_tess_level_outer =
                DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelOuter);
            Decorate(output_tess_level_outer, spv::Decoration::Patch);
        }
        if (info.stores_tess_level_inner) {
            const Id type{TypeArray(F32[1], Const(2U))};
            output_tess_level_inner =
                DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelInner);
            Decorate(output_tess_level_inner, spv::Decoration::Patch);
        }
        // Per-patch user attributes written by the control shader.
        for (size_t index = 0; index < info.uses_patches.size(); ++index) {
            if (!info.uses_patches[index]) {
                continue;
            }
            const Id id{DefineOutput(*this, F32[4], std::nullopt)};
            Decorate(id, spv::Decoration::Patch);
            Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
            patches[index] = id;
        }
        break;
    case Stage::Fragment:
        // Some hosts require declaring every color output even when unwritten
        // (need_declared_frag_colors).
        for (u32 index = 0; index < 8; ++index) {
            if (!info.stores_frag_color[index] && !profile.need_declared_frag_colors) {
                continue;
            }
            frag_color[index] = DefineOutput(*this, F32[4], std::nullopt);
            Decorate(frag_color[index], spv::Decoration::Location, index);
            Name(frag_color[index], fmt::format("frag_color{}", index));
        }
        if (info.stores_frag_depth) {
            frag_depth = DefineOutput(*this, F32[1], std::nullopt);
            Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth);
        }
        if (info.stores_sample_mask) {
            sample_mask = DefineOutput(*this, U32[1], std::nullopt);
            Decorate(sample_mask, spv::Decoration::BuiltIn, spv::BuiltIn::SampleMask);
        }
        break;
    default:
        break;
    }
}
1367
1368} // namespace Shader::Backend::SPIRV