summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/backend/spirv/emit_context.cpp
diff options
context:
space:
mode:
authorGravatar ameerj2021-04-11 02:07:02 -0400
committerGravatar ameerj2021-07-22 21:51:27 -0400
commit3db2b3effa953ae66457b7a19b419fc4db2c4801 (patch)
tree04c73897a74be053a610edf60703c72e985ee590 /src/shader_recompiler/backend/spirv/emit_context.cpp
parentnsight_aftermath_tracker: Report used shaders to Nsight Aftermath (diff)
downloadyuzu-3db2b3effa953ae66457b7a19b419fc4db2c4801.tar.gz
yuzu-3db2b3effa953ae66457b7a19b419fc4db2c4801.tar.xz
yuzu-3db2b3effa953ae66457b7a19b419fc4db2c4801.zip
shader: Implement ATOM/S and RED
Diffstat (limited to 'src/shader_recompiler/backend/spirv/emit_context.cpp')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp158
1 files changed, 154 insertions, 4 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 32f8c4508..e5d83e9b4 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -15,6 +15,53 @@
15 15
16namespace Shader::Backend::SPIRV { 16namespace Shader::Backend::SPIRV {
17namespace { 17namespace {
18enum class CasFunctionType {
19 Increment,
20 Decrement,
21 FPAdd,
22 FPMin,
23 FPMax,
24};
25
26Id CasFunction(EmitContext& ctx, CasFunctionType function_type, Id value_type) {
27 const Id func_type{ctx.TypeFunction(value_type, value_type, value_type)};
28 const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
29 const Id op_a{ctx.OpFunctionParameter(value_type)};
30 const Id op_b{ctx.OpFunctionParameter(value_type)};
31 ctx.AddLabel();
32 Id result{};
33 switch (function_type) {
34 case CasFunctionType::Increment: {
35 const Id pred{ctx.OpUGreaterThanEqual(ctx.U1, op_a, op_b)};
36 const Id incr{ctx.OpIAdd(value_type, op_a, ctx.Constant(value_type, 1))};
37 result = ctx.OpSelect(value_type, pred, ctx.u32_zero_value, incr);
38 break;
39 }
40 case CasFunctionType::Decrement: {
41 const Id lhs{ctx.OpIEqual(ctx.U1, op_a, ctx.Constant(value_type, 0u))};
42 const Id rhs{ctx.OpUGreaterThan(ctx.U1, op_a, op_b)};
43 const Id pred{ctx.OpLogicalOr(ctx.U1, lhs, rhs)};
44 const Id decr{ctx.OpISub(value_type, op_a, ctx.Constant(value_type, 1))};
45 result = ctx.OpSelect(value_type, pred, op_b, decr);
46 break;
47 }
48 case CasFunctionType::FPAdd:
49 result = ctx.OpFAdd(value_type, op_a, op_b);
50 break;
51 case CasFunctionType::FPMin:
52 result = ctx.OpFMin(value_type, op_a, op_b);
53 break;
54 case CasFunctionType::FPMax:
55 result = ctx.OpFMax(value_type, op_a, op_b);
56 break;
57 default:
58 break;
59 }
60 ctx.OpReturnValue(result);
61 ctx.OpFunctionEnd();
62 return func;
63}
64
18Id ImageType(EmitContext& ctx, const TextureDescriptor& desc) { 65Id ImageType(EmitContext& ctx, const TextureDescriptor& desc) {
19 const spv::ImageFormat format{spv::ImageFormat::Unknown}; 66 const spv::ImageFormat format{spv::ImageFormat::Unknown};
20 const Id type{ctx.F32[1]}; 67 const Id type{ctx.F32[1]};
@@ -196,6 +243,56 @@ Id EmitContext::Def(const IR::Value& value) {
196 } 243 }
197} 244}
198 245
246Id EmitContext::CasLoop(Id function, CasPointerType pointer_type, Id value_type) {
247 const Id loop_header{OpLabel()};
248 const Id continue_block{OpLabel()};
249 const Id merge_block{OpLabel()};
250 const Id storage_type{pointer_type == CasPointerType::Shared ? shared_memory_u32_type
251 : storage_memory_u32};
252 const Id func_type{TypeFunction(value_type, U32[1], value_type, storage_type)};
253 const Id func{OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
254 const Id index{OpFunctionParameter(U32[1])};
255 const Id op_b{OpFunctionParameter(value_type)};
256 const Id base{OpFunctionParameter(storage_type)};
257 AddLabel();
258 const Id one{Constant(U32[1], 1)};
259 OpBranch(loop_header);
260 AddLabel(loop_header);
261 OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
262 OpBranch(continue_block);
263
264 AddLabel(continue_block);
265 const Id word_pointer{pointer_type == CasPointerType::Shared
266 ? OpAccessChain(shared_u32, base, index)
267 : OpAccessChain(storage_u32, base, u32_zero_value, index)};
268 if (value_type.value == F32[2].value) {
269 const Id u32_value{OpLoad(U32[1], word_pointer)};
270 const Id value{OpUnpackHalf2x16(F32[2], u32_value)};
271 const Id new_value{OpFunctionCall(value_type, function, value, op_b)};
272 const Id u32_new_value{OpPackHalf2x16(U32[1], new_value)};
273 const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, one, u32_zero_value,
274 u32_zero_value, u32_new_value, u32_value)};
275 const Id success{OpIEqual(U1, atomic_res, u32_value)};
276 OpBranchConditional(success, merge_block, loop_header);
277
278 AddLabel(merge_block);
279 OpReturnValue(OpUnpackHalf2x16(F32[2], atomic_res));
280 } else {
281 const Id value{OpLoad(U32[1], word_pointer)};
282 const Id new_value{OpBitcast(
283 U32[1], OpFunctionCall(value_type, function, OpBitcast(value_type, value), op_b))};
284 const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, one, u32_zero_value,
285 u32_zero_value, new_value, value)};
286 const Id success{OpIEqual(U1, atomic_res, value)};
287 OpBranchConditional(success, merge_block, loop_header);
288
289 AddLabel(merge_block);
290 OpReturnValue(OpBitcast(value_type, atomic_res));
291 }
292 OpFunctionEnd();
293 return func;
294}
295
199void EmitContext::DefineCommonTypes(const Info& info) { 296void EmitContext::DefineCommonTypes(const Info& info) {
200 void_id = TypeVoid(); 297 void_id = TypeVoid();
201 298
@@ -300,9 +397,9 @@ void EmitContext::DefineSharedMemory(const IR::Program& program) {
300 } 397 }
301 const u32 num_elements{Common::DivCeil(program.shared_memory_size, 4U)}; 398 const u32 num_elements{Common::DivCeil(program.shared_memory_size, 4U)};
302 const Id type{TypeArray(U32[1], Constant(U32[1], num_elements))}; 399 const Id type{TypeArray(U32[1], Constant(U32[1], num_elements))};
303 const Id pointer_type{TypePointer(spv::StorageClass::Workgroup, type)}; 400 shared_memory_u32_type = TypePointer(spv::StorageClass::Workgroup, type);
304 shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]); 401 shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]);
305 shared_memory_u32 = AddGlobalVariable(pointer_type, spv::StorageClass::Workgroup); 402 shared_memory_u32 = AddGlobalVariable(shared_memory_u32_type, spv::StorageClass::Workgroup);
306 interfaces.push_back(shared_memory_u32); 403 interfaces.push_back(shared_memory_u32);
307 404
308 const Id func_type{TypeFunction(void_id, U32[1], U32[1])}; 405 const Id func_type{TypeFunction(void_id, U32[1], U32[1])};
@@ -346,6 +443,14 @@ void EmitContext::DefineSharedMemory(const IR::Program& program) {
346 if (program.info.uses_int16) { 443 if (program.info.uses_int16) {
347 shared_store_u16_func = make_function(16, 16); 444 shared_store_u16_func = make_function(16, 16);
348 } 445 }
446 if (program.info.uses_shared_increment) {
447 const Id inc_func{CasFunction(*this, CasFunctionType::Increment, U32[1])};
448 increment_cas_shared = CasLoop(inc_func, CasPointerType::Shared, U32[1]);
449 }
450 if (program.info.uses_shared_decrement) {
451 const Id dec_func{CasFunction(*this, CasFunctionType::Decrement, U32[1])};
452 decrement_cas_shared = CasLoop(dec_func, CasPointerType::Shared, U32[1]);
453 }
349} 454}
350 455
351void EmitContext::DefineAttributeMemAccess(const Info& info) { 456void EmitContext::DefineAttributeMemAccess(const Info& info) {
@@ -530,12 +635,12 @@ void EmitContext::DefineStorageBuffers(const Info& info, u32& binding) {
530 MemberName(struct_type, 0, "data"); 635 MemberName(struct_type, 0, "data");
531 MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U); 636 MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
532 637
533 const Id storage_type{TypePointer(spv::StorageClass::StorageBuffer, struct_type)}; 638 storage_memory_u32 = TypePointer(spv::StorageClass::StorageBuffer, struct_type);
534 storage_u32 = TypePointer(spv::StorageClass::StorageBuffer, U32[1]); 639 storage_u32 = TypePointer(spv::StorageClass::StorageBuffer, U32[1]);
535 640
536 u32 index{}; 641 u32 index{};
537 for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) { 642 for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
538 const Id id{AddGlobalVariable(storage_type, spv::StorageClass::StorageBuffer)}; 643 const Id id{AddGlobalVariable(storage_memory_u32, spv::StorageClass::StorageBuffer)};
539 Decorate(id, spv::Decoration::Binding, binding); 644 Decorate(id, spv::Decoration::Binding, binding);
540 Decorate(id, spv::Decoration::DescriptorSet, 0U); 645 Decorate(id, spv::Decoration::DescriptorSet, 0U);
541 Name(id, fmt::format("ssbo{}", index)); 646 Name(id, fmt::format("ssbo{}", index));
@@ -546,6 +651,51 @@ void EmitContext::DefineStorageBuffers(const Info& info, u32& binding) {
546 index += desc.count; 651 index += desc.count;
547 binding += desc.count; 652 binding += desc.count;
548 } 653 }
654 if (info.uses_global_increment) {
655 AddCapability(spv::Capability::VariablePointersStorageBuffer);
656 const Id inc_func{CasFunction(*this, CasFunctionType::Increment, U32[1])};
657 increment_cas_ssbo = CasLoop(inc_func, CasPointerType::Ssbo, U32[1]);
658 }
659 if (info.uses_global_decrement) {
660 AddCapability(spv::Capability::VariablePointersStorageBuffer);
661 const Id dec_func{CasFunction(*this, CasFunctionType::Decrement, U32[1])};
662 decrement_cas_ssbo = CasLoop(dec_func, CasPointerType::Ssbo, U32[1]);
663 }
664 if (info.uses_atomic_f32_add) {
665 AddCapability(spv::Capability::VariablePointersStorageBuffer);
666 const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F32[1])};
667 f32_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F32[1]);
668 }
669 if (info.uses_atomic_f16x2_add) {
670 AddCapability(spv::Capability::VariablePointersStorageBuffer);
671 const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F16[2])};
672 f16x2_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F16[2]);
673 }
674 if (info.uses_atomic_f16x2_min) {
675 AddCapability(spv::Capability::VariablePointersStorageBuffer);
676 const Id func{CasFunction(*this, CasFunctionType::FPMin, F16[2])};
677 f16x2_min_cas = CasLoop(func, CasPointerType::Ssbo, F16[2]);
678 }
679 if (info.uses_atomic_f16x2_max) {
680 AddCapability(spv::Capability::VariablePointersStorageBuffer);
681 const Id func{CasFunction(*this, CasFunctionType::FPMax, F16[2])};
682 f16x2_max_cas = CasLoop(func, CasPointerType::Ssbo, F16[2]);
683 }
684 if (info.uses_atomic_f32x2_add) {
685 AddCapability(spv::Capability::VariablePointersStorageBuffer);
686 const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F32[2])};
687 f32x2_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F32[2]);
688 }
689 if (info.uses_atomic_f32x2_min) {
690 AddCapability(spv::Capability::VariablePointersStorageBuffer);
691 const Id func{CasFunction(*this, CasFunctionType::FPMin, F32[2])};
692 f32x2_min_cas = CasLoop(func, CasPointerType::Ssbo, F32[2]);
693 }
694 if (info.uses_atomic_f32x2_max) {
695 AddCapability(spv::Capability::VariablePointersStorageBuffer);
696 const Id func{CasFunction(*this, CasFunctionType::FPMax, F32[2])};
697 f32x2_max_cas = CasLoop(func, CasPointerType::Ssbo, F32[2]);
698 }
549} 699}
550 700
551void EmitContext::DefineTextureBuffers(const Info& info, u32& binding) { 701void EmitContext::DefineTextureBuffers(const Info& info, u32& binding) {