diff options
| author | 2021-04-04 05:17:17 -0300 | |
|---|---|---|
| committer | 2021-07-22 21:51:26 -0400 | |
| commit | da6cf2632cd4dc0d2b0278353fcaee0789b418c0 (patch) | |
| tree | 90c2d6f6fa724365a4a23c888389e525e316a4fd /src/shader_recompiler/backend | |
| parent | shader: Implement BAR and fix memory barriers (diff) | |
| download | yuzu-da6cf2632cd4dc0d2b0278353fcaee0789b418c0.tar.gz yuzu-da6cf2632cd4dc0d2b0278353fcaee0789b418c0.tar.xz yuzu-da6cf2632cd4dc0d2b0278353fcaee0789b418c0.zip | |
shader: Add subgroup masks
Diffstat (limited to 'src/shader_recompiler/backend')
4 files changed, 56 insertions, 10 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp index e70b78a28..5ef637fe7 100644 --- a/src/shader_recompiler/backend/spirv/emit_context.cpp +++ b/src/shader_recompiler/backend/spirv/emit_context.cpp | |||
| @@ -390,8 +390,16 @@ void EmitContext::DefineInputs(const Info& info) { | |||
| 390 | if (info.uses_local_invocation_id) { | 390 | if (info.uses_local_invocation_id) { |
| 391 | local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); | 391 | local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); |
| 392 | } | 392 | } |
| 393 | if (info.uses_subgroup_mask) { | ||
| 394 | subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR); | ||
| 395 | subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR); | ||
| 396 | subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR); | ||
| 397 | subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR); | ||
| 398 | subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR); | ||
| 399 | } | ||
| 393 | if (info.uses_subgroup_invocation_id || | 400 | if (info.uses_subgroup_invocation_id || |
| 394 | (profile.warp_size_potentially_larger_than_guest && info.uses_subgroup_vote)) { | 401 | (profile.warp_size_potentially_larger_than_guest && |
| 402 | (info.uses_subgroup_vote || info.uses_subgroup_mask))) { | ||
| 395 | subgroup_local_invocation_id = | 403 | subgroup_local_invocation_id = |
| 396 | DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); | 404 | DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); |
| 397 | } | 405 | } |
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h index 3a686a78c..03c5a6aba 100644 --- a/src/shader_recompiler/backend/spirv/emit_context.h +++ b/src/shader_recompiler/backend/spirv/emit_context.h | |||
| @@ -97,6 +97,11 @@ public: | |||
| 97 | Id workgroup_id{}; | 97 | Id workgroup_id{}; |
| 98 | Id local_invocation_id{}; | 98 | Id local_invocation_id{}; |
| 99 | Id subgroup_local_invocation_id{}; | 99 | Id subgroup_local_invocation_id{}; |
| 100 | Id subgroup_mask_eq{}; | ||
| 101 | Id subgroup_mask_lt{}; | ||
| 102 | Id subgroup_mask_le{}; | ||
| 103 | Id subgroup_mask_gt{}; | ||
| 104 | Id subgroup_mask_ge{}; | ||
| 100 | Id instance_id{}; | 105 | Id instance_id{}; |
| 101 | Id instance_index{}; | 106 | Id instance_index{}; |
| 102 | Id base_instance{}; | 107 | Id base_instance{}; |
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h index 032b0b2f9..712c5e61f 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv.h +++ b/src/shader_recompiler/backend/spirv/emit_spirv.h | |||
| @@ -401,6 +401,11 @@ Id EmitVoteAll(EmitContext& ctx, Id pred); | |||
| 401 | Id EmitVoteAny(EmitContext& ctx, Id pred); | 401 | Id EmitVoteAny(EmitContext& ctx, Id pred); |
| 402 | Id EmitVoteEqual(EmitContext& ctx, Id pred); | 402 | Id EmitVoteEqual(EmitContext& ctx, Id pred); |
| 403 | Id EmitSubgroupBallot(EmitContext& ctx, Id pred); | 403 | Id EmitSubgroupBallot(EmitContext& ctx, Id pred); |
| 404 | Id EmitSubgroupEqMask(EmitContext& ctx); | ||
| 405 | Id EmitSubgroupLtMask(EmitContext& ctx); | ||
| 406 | Id EmitSubgroupLeMask(EmitContext& ctx); | ||
| 407 | Id EmitSubgroupGtMask(EmitContext& ctx); | ||
| 408 | Id EmitSubgroupGeMask(EmitContext& ctx); | ||
| 404 | Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | 409 | Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, |
| 405 | Id segmentation_mask); | 410 | Id segmentation_mask); |
| 406 | Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | 411 | Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, |
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp index cbc5b1c96..c57bd291d 100644 --- a/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp +++ b/src/shader_recompiler/backend/spirv/emit_spirv_warp.cpp | |||
| @@ -6,10 +6,18 @@ | |||
| 6 | 6 | ||
| 7 | namespace Shader::Backend::SPIRV { | 7 | namespace Shader::Backend::SPIRV { |
| 8 | namespace { | 8 | namespace { |
| 9 | Id LargeWarpBallot(EmitContext& ctx, Id ballot) { | 9 | Id WarpExtract(EmitContext& ctx, Id value) { |
| 10 | const Id shift{ctx.Constant(ctx.U32[1], 5)}; | 10 | const Id shift{ctx.Constant(ctx.U32[1], 5)}; |
| 11 | const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; | 11 | const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)}; |
| 12 | return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index); | 12 | return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index); |
| 13 | } | ||
| 14 | |||
| 15 | Id LoadMask(EmitContext& ctx, Id mask) { | ||
| 16 | const Id value{ctx.OpLoad(ctx.U32[4], mask)}; | ||
| 17 | if (!ctx.profile.warp_size_potentially_larger_than_guest) { | ||
| 18 | return ctx.OpCompositeExtract(ctx.U32[1], value, 0U); | ||
| 19 | } | ||
| 20 | return WarpExtract(ctx, value); | ||
| 13 | } | 21 | } |
| 14 | 22 | ||
| 15 | void SetInBoundsFlag(IR::Inst* inst, Id result) { | 23 | void SetInBoundsFlag(IR::Inst* inst, Id result) { |
| @@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) { | |||
| 47 | return ctx.OpSubgroupAllKHR(ctx.U1, pred); | 55 | return ctx.OpSubgroupAllKHR(ctx.U1, pred); |
| 48 | } | 56 | } |
| 49 | const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | 57 | const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; |
| 50 | const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | 58 | const Id active_mask{WarpExtract(ctx, mask_ballot)}; |
| 51 | const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | 59 | const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; |
| 52 | const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; | 60 | const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; |
| 53 | return ctx.OpIEqual(ctx.U1, lhs, active_mask); | 61 | return ctx.OpIEqual(ctx.U1, lhs, active_mask); |
| 54 | } | 62 | } |
| @@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) { | |||
| 58 | return ctx.OpSubgroupAnyKHR(ctx.U1, pred); | 66 | return ctx.OpSubgroupAnyKHR(ctx.U1, pred); |
| 59 | } | 67 | } |
| 60 | const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | 68 | const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; |
| 61 | const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | 69 | const Id active_mask{WarpExtract(ctx, mask_ballot)}; |
| 62 | const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | 70 | const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; |
| 63 | const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; | 71 | const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)}; |
| 64 | return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value); | 72 | return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value); |
| 65 | } | 73 | } |
| @@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) { | |||
| 69 | return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred); | 77 | return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred); |
| 70 | } | 78 | } |
| 71 | const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; | 79 | const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)}; |
| 72 | const Id active_mask{LargeWarpBallot(ctx, mask_ballot)}; | 80 | const Id active_mask{WarpExtract(ctx, mask_ballot)}; |
| 73 | const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; | 81 | const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))}; |
| 74 | const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)}; | 82 | const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)}; |
| 75 | return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value), | 83 | return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value), |
| 76 | ctx.OpIEqual(ctx.U1, lhs, active_mask)); | 84 | ctx.OpIEqual(ctx.U1, lhs, active_mask)); |
| @@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) { | |||
| 81 | if (!ctx.profile.warp_size_potentially_larger_than_guest) { | 89 | if (!ctx.profile.warp_size_potentially_larger_than_guest) { |
| 82 | return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U); | 90 | return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U); |
| 83 | } | 91 | } |
| 84 | return LargeWarpBallot(ctx, ballot); | 92 | return WarpExtract(ctx, ballot); |
| 93 | } | ||
| 94 | |||
| 95 | Id EmitSubgroupEqMask(EmitContext& ctx) { | ||
| 96 | return LoadMask(ctx, ctx.subgroup_mask_eq); | ||
| 97 | } | ||
| 98 | |||
| 99 | Id EmitSubgroupLtMask(EmitContext& ctx) { | ||
| 100 | return LoadMask(ctx, ctx.subgroup_mask_lt); | ||
| 101 | } | ||
| 102 | |||
| 103 | Id EmitSubgroupLeMask(EmitContext& ctx) { | ||
| 104 | return LoadMask(ctx, ctx.subgroup_mask_le); | ||
| 105 | } | ||
| 106 | |||
| 107 | Id EmitSubgroupGtMask(EmitContext& ctx) { | ||
| 108 | return LoadMask(ctx, ctx.subgroup_mask_gt); | ||
| 109 | } | ||
| 110 | |||
| 111 | Id EmitSubgroupGeMask(EmitContext& ctx) { | ||
| 112 | return LoadMask(ctx, ctx.subgroup_mask_ge); | ||
| 85 | } | 113 | } |
| 86 | 114 | ||
| 87 | Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, | 115 | Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp, |