summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/backend
diff options
context:
space:
mode:
authorGravatar ReinUsesLisp2021-04-04 05:17:17 -0300
committerGravatar ameerj2021-07-22 21:51:26 -0400
commitda6cf2632cd4dc0d2b0278353fcaee0789b418c0 (patch)
tree90c2d6f6fa724365a4a23c888389e525e316a4fd /src/shader_recompiler/backend
parentshader: Implement BAR and fix memory barriers (diff)
downloadyuzu-da6cf2632cd4dc0d2b0278353fcaee0789b418c0.tar.gz
yuzu-da6cf2632cd4dc0d2b0278353fcaee0789b418c0.tar.xz
yuzu-da6cf2632cd4dc0d2b0278353fcaee0789b418c0.zip
shader: Add subgroup masks
Diffstat (limited to 'src/shader_recompiler/backend')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp10
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.h5
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv.h5
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp46
4 files changed, 56 insertions, 10 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index e70b78a28..5ef637fe7 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -390,8 +390,16 @@ void EmitContext::DefineInputs(const Info& info) {
390 if (info.uses_local_invocation_id) { 390 if (info.uses_local_invocation_id) {
391 local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); 391 local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId);
392 } 392 }
393 if (info.uses_subgroup_mask) {
394 subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR);
395 subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR);
396 subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR);
397 subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR);
398 subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR);
399 }
393 if (info.uses_subgroup_invocation_id || 400 if (info.uses_subgroup_invocation_id ||
394 (profile.warp_size_potentially_larger_than_guest && info.uses_subgroup_vote)) { 401 (profile.warp_size_potentially_larger_than_guest &&
402 (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
395 subgroup_local_invocation_id = 403 subgroup_local_invocation_id =
396 DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); 404 DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId);
397 } 405 }
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index 3a686a78c..03c5a6aba 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -97,6 +97,11 @@ public:
97 Id workgroup_id{}; 97 Id workgroup_id{};
98 Id local_invocation_id{}; 98 Id local_invocation_id{};
99 Id subgroup_local_invocation_id{}; 99 Id subgroup_local_invocation_id{};
100 Id subgroup_mask_eq{};
101 Id subgroup_mask_lt{};
102 Id subgroup_mask_le{};
103 Id subgroup_mask_gt{};
104 Id subgroup_mask_ge{};
100 Id instance_id{}; 105 Id instance_id{};
101 Id instance_index{}; 106 Id instance_index{};
102 Id base_instance{}; 107 Id base_instance{};
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index 032b0b2f9..712c5e61f 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -401,6 +401,11 @@ Id EmitVoteAll(EmitContext& ctx, Id pred);
401Id EmitVoteAny(EmitContext& ctx, Id pred); 401Id EmitVoteAny(EmitContext& ctx, Id pred);
402Id EmitVoteEqual(EmitContext& ctx, Id pred); 402Id EmitVoteEqual(EmitContext& ctx, Id pred);
403Id EmitSubgroupBallot(EmitContext& ctx, Id pred); 403Id EmitSubgroupBallot(EmitContext& ctx, Id pred);
404Id EmitSubgroupEqMask(EmitContext& ctx);
405Id EmitSubgroupLtMask(EmitContext& ctx);
406Id EmitSubgroupLeMask(EmitContext& ctx);
407Id EmitSubgroupGtMask(EmitContext& ctx);
408Id EmitSubgroupGeMask(EmitContext& ctx);
404Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, 409Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,
405 Id segmentation_mask); 410 Id segmentation_mask);
406Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, 411Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp
index cbc5b1c96..c57bd291d 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp
@@ -6,10 +6,18 @@
6 6
7namespace Shader::Backend::SPIRV { 7namespace Shader::Backend::SPIRV {
8namespace { 8namespace {
9Id LargeWarpBallot(EmitContext& ctx, Id ballot) { 9Id WarpExtract(EmitContext& ctx, Id value) {
10 const Id shift{ctx.Constant(ctx.U32[1], 5)}; 10 const Id shift{ctx.Constant(ctx.U32[1], 5)};
11 const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; 11 const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)};
12 return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index); 12 return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index);
13}
14
15Id LoadMask(EmitContext& ctx, Id mask) {
16 const Id value{ctx.OpLoad(ctx.U32[4], mask)};
17 if (!ctx.profile.warp_size_potentially_larger_than_guest) {
18 return ctx.OpCompositeExtract(ctx.U32[1], value, 0U);
19 }
20 return WarpExtract(ctx, value);
13} 21}
14 22
15void SetInBoundsFlag(IR::Inst* inst, Id result) { 23void SetInBoundsFlag(IR::Inst* inst, Id result) {
@@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) {
47 return ctx.OpSubgroupAllKHR(ctx.U1, pred); 55 return ctx.OpSubgroupAllKHR(ctx.U1, pred);
48 } 56 }
49 const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; 57 const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
50 const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; 58 const Id active_mask{WarpExtract(ctx, mask_ballot)};
51 const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; 59 const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
52 const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; 60 const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
53 return ctx.OpIEqual(ctx.U1, lhs, active_mask); 61 return ctx.OpIEqual(ctx.U1, lhs, active_mask);
54} 62}
@@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) {
58 return ctx.OpSubgroupAnyKHR(ctx.U1, pred); 66 return ctx.OpSubgroupAnyKHR(ctx.U1, pred);
59 } 67 }
60 const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; 68 const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
61 const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; 69 const Id active_mask{WarpExtract(ctx, mask_ballot)};
62 const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; 70 const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
63 const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; 71 const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
64 return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value); 72 return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value);
65} 73}
@@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) {
69 return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred); 77 return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred);
70 } 78 }
71 const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; 79 const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
72 const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; 80 const Id active_mask{WarpExtract(ctx, mask_ballot)};
73 const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; 81 const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
74 const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)}; 82 const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)};
75 return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value), 83 return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value),
76 ctx.OpIEqual(ctx.U1, lhs, active_mask)); 84 ctx.OpIEqual(ctx.U1, lhs, active_mask));
@@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) {
81 if (!ctx.profile.warp_size_potentially_larger_than_guest) { 89 if (!ctx.profile.warp_size_potentially_larger_than_guest) {
82 return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U); 90 return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U);
83 } 91 }
84 return LargeWarpBallot(ctx, ballot); 92 return WarpExtract(ctx, ballot);
93}
94
95Id EmitSubgroupEqMask(EmitContext& ctx) {
96 return LoadMask(ctx, ctx.subgroup_mask_eq);
97}
98
99Id EmitSubgroupLtMask(EmitContext& ctx) {
100 return LoadMask(ctx, ctx.subgroup_mask_lt);
101}
102
103Id EmitSubgroupLeMask(EmitContext& ctx) {
104 return LoadMask(ctx, ctx.subgroup_mask_le);
105}
106
107Id EmitSubgroupGtMask(EmitContext& ctx) {
108 return LoadMask(ctx, ctx.subgroup_mask_gt);
109}
110
111Id EmitSubgroupGeMask(EmitContext& ctx) {
112 return LoadMask(ctx, ctx.subgroup_mask_ge);
85} 113}
86 114
87Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, 115Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,