summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/backend
diff options
context:
space:
mode:
authorGravatar ReinUsesLisp2021-04-15 22:46:11 -0300
committerGravatar ameerj2021-07-22 21:51:27 -0400
commit183855e396cc6918d36fbf3e38ea426e934b4e3e (patch)
treea665794753520c09a1d34d8a086352894ec1cb72 /src/shader_recompiler/backend
parentshader: Mark atomic instructions as writes (diff)
downloadyuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar.gz
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.tar.xz
yuzu-183855e396cc6918d36fbf3e38ea426e934b4e3e.zip
shader: Implement tessellation shaders, polygon mode and invocation id
Diffstat (limited to 'src/shader_recompiler/backend')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp147
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.h10
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv.cpp39
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv.h3
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp88
5 files changed, 232 insertions, 55 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 032cf5e03..067f61613 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -125,19 +125,36 @@ u32 NumVertices(InputTopology input_topology) {
125 throw InvalidArgument("Invalid input topology {}", input_topology); 125 throw InvalidArgument("Invalid input topology {}", input_topology);
126} 126}
127 127
128Id DefineInput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) { 128Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
129 if (ctx.stage == Stage::Geometry) { 129 std::optional<spv::BuiltIn> builtin = std::nullopt) {
130 const u32 num_vertices{NumVertices(ctx.profile.input_topology)}; 130 switch (ctx.stage) {
131 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices)); 131 case Stage::TessellationControl:
132 case Stage::TessellationEval:
133 if (per_invocation) {
134 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], 32u));
135 }
136 break;
137 case Stage::Geometry:
138 if (per_invocation) {
139 const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
140 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
141 }
142 break;
143 default:
144 break;
132 } 145 }
133 return DefineVariable(ctx, type, builtin, spv::StorageClass::Input); 146 return DefineVariable(ctx, type, builtin, spv::StorageClass::Input);
134} 147}
135 148
136Id DefineOutput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) { 149Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations,
150 std::optional<spv::BuiltIn> builtin = std::nullopt) {
151 if (invocations && ctx.stage == Stage::TessellationControl) {
152 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], *invocations));
153 }
137 return DefineVariable(ctx, type, builtin, spv::StorageClass::Output); 154 return DefineVariable(ctx, type, builtin, spv::StorageClass::Output);
138} 155}
139 156
140void DefineGenericOutput(EmitContext& ctx, size_t index) { 157void DefineGenericOutput(EmitContext& ctx, size_t index, std::optional<u32> invocations) {
141 static constexpr std::string_view swizzle{"xyzw"}; 158 static constexpr std::string_view swizzle{"xyzw"};
142 const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4}; 159 const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4};
143 u32 element{0}; 160 u32 element{0};
@@ -150,7 +167,7 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
150 } 167 }
151 const u32 num_components{xfb_varying ? xfb_varying->components : remainder}; 168 const u32 num_components{xfb_varying ? xfb_varying->components : remainder};
152 169
153 const Id id{DefineOutput(ctx, ctx.F32[num_components])}; 170 const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)};
154 ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); 171 ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
155 if (element > 0) { 172 if (element > 0) {
156 ctx.Decorate(id, spv::Decoration::Component, element); 173 ctx.Decorate(id, spv::Decoration::Component, element);
@@ -161,10 +178,10 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
161 ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset); 178 ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset);
162 } 179 }
163 if (num_components < 4 || element > 0) { 180 if (num_components < 4 || element > 0) {
164 ctx.Name(id, fmt::format("out_attr{}", index));
165 } else {
166 const std::string_view subswizzle{swizzle.substr(element, num_components)}; 181 const std::string_view subswizzle{swizzle.substr(element, num_components)};
167 ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle)); 182 ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle));
183 } else {
184 ctx.Name(id, fmt::format("out_attr{}", index));
168 } 185 }
169 const GenericElementInfo info{ 186 const GenericElementInfo info{
170 .id = id, 187 .id = id,
@@ -383,7 +400,7 @@ EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& bin
383 AddCapability(spv::Capability::Shader); 400 AddCapability(spv::Capability::Shader);
384 DefineCommonTypes(program.info); 401 DefineCommonTypes(program.info);
385 DefineCommonConstants(); 402 DefineCommonConstants();
386 DefineInterfaces(program.info); 403 DefineInterfaces(program);
387 DefineLocalMemory(program); 404 DefineLocalMemory(program);
388 DefineSharedMemory(program); 405 DefineSharedMemory(program);
389 DefineSharedMemoryFunctions(program); 406 DefineSharedMemoryFunctions(program);
@@ -472,9 +489,9 @@ void EmitContext::DefineCommonConstants() {
472 f32_zero_value = Constant(F32[1], 0.0f); 489 f32_zero_value = Constant(F32[1], 0.0f);
473} 490}
474 491
475void EmitContext::DefineInterfaces(const Info& info) { 492void EmitContext::DefineInterfaces(const IR::Program& program) {
476 DefineInputs(info); 493 DefineInputs(program.info);
477 DefineOutputs(info); 494 DefineOutputs(program);
478} 495}
479 496
480void EmitContext::DefineLocalMemory(const IR::Program& program) { 497void EmitContext::DefineLocalMemory(const IR::Program& program) {
@@ -972,26 +989,29 @@ void EmitContext::DefineLabels(IR::Program& program) {
972 989
973void EmitContext::DefineInputs(const Info& info) { 990void EmitContext::DefineInputs(const Info& info) {
974 if (info.uses_workgroup_id) { 991 if (info.uses_workgroup_id) {
975 workgroup_id = DefineInput(*this, U32[3], spv::BuiltIn::WorkgroupId); 992 workgroup_id = DefineInput(*this, U32[3], false, spv::BuiltIn::WorkgroupId);
976 } 993 }
977 if (info.uses_local_invocation_id) { 994 if (info.uses_local_invocation_id) {
978 local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); 995 local_invocation_id = DefineInput(*this, U32[3], false, spv::BuiltIn::LocalInvocationId);
996 }
997 if (info.uses_invocation_id) {
998 invocation_id = DefineInput(*this, U32[1], false, spv::BuiltIn::InvocationId);
979 } 999 }
980 if (info.uses_is_helper_invocation) { 1000 if (info.uses_is_helper_invocation) {
981 is_helper_invocation = DefineInput(*this, U1, spv::BuiltIn::HelperInvocation); 1001 is_helper_invocation = DefineInput(*this, U1, false, spv::BuiltIn::HelperInvocation);
982 } 1002 }
983 if (info.uses_subgroup_mask) { 1003 if (info.uses_subgroup_mask) {
984 subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR); 1004 subgroup_mask_eq = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupEqMaskKHR);
985 subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR); 1005 subgroup_mask_lt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLtMaskKHR);
986 subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR); 1006 subgroup_mask_le = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLeMaskKHR);
987 subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR); 1007 subgroup_mask_gt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGtMaskKHR);
988 subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR); 1008 subgroup_mask_ge = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGeMaskKHR);
989 } 1009 }
990 if (info.uses_subgroup_invocation_id || 1010 if (info.uses_subgroup_invocation_id ||
991 (profile.warp_size_potentially_larger_than_guest && 1011 (profile.warp_size_potentially_larger_than_guest &&
992 (info.uses_subgroup_vote || info.uses_subgroup_mask))) { 1012 (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
993 subgroup_local_invocation_id = 1013 subgroup_local_invocation_id =
994 DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); 1014 DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId);
995 } 1015 }
996 if (info.uses_fswzadd) { 1016 if (info.uses_fswzadd) {
997 const Id f32_one{Constant(F32[1], 1.0f)}; 1017 const Id f32_one{Constant(F32[1], 1.0f)};
@@ -1004,29 +1024,32 @@ void EmitContext::DefineInputs(const Info& info) {
1004 if (info.loads_position) { 1024 if (info.loads_position) {
1005 const bool is_fragment{stage != Stage::Fragment}; 1025 const bool is_fragment{stage != Stage::Fragment};
1006 const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord}; 1026 const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
1007 input_position = DefineInput(*this, F32[4], built_in); 1027 input_position = DefineInput(*this, F32[4], true, built_in);
1008 } 1028 }
1009 if (info.loads_instance_id) { 1029 if (info.loads_instance_id) {
1010 if (profile.support_vertex_instance_id) { 1030 if (profile.support_vertex_instance_id) {
1011 instance_id = DefineInput(*this, U32[1], spv::BuiltIn::InstanceId); 1031 instance_id = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceId);
1012 } else { 1032 } else {
1013 instance_index = DefineInput(*this, U32[1], spv::BuiltIn::InstanceIndex); 1033 instance_index = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceIndex);
1014 base_instance = DefineInput(*this, U32[1], spv::BuiltIn::BaseInstance); 1034 base_instance = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseInstance);
1015 } 1035 }
1016 } 1036 }
1017 if (info.loads_vertex_id) { 1037 if (info.loads_vertex_id) {
1018 if (profile.support_vertex_instance_id) { 1038 if (profile.support_vertex_instance_id) {
1019 vertex_id = DefineInput(*this, U32[1], spv::BuiltIn::VertexId); 1039 vertex_id = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexId);
1020 } else { 1040 } else {
1021 vertex_index = DefineInput(*this, U32[1], spv::BuiltIn::VertexIndex); 1041 vertex_index = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexIndex);
1022 base_vertex = DefineInput(*this, U32[1], spv::BuiltIn::BaseVertex); 1042 base_vertex = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseVertex);
1023 } 1043 }
1024 } 1044 }
1025 if (info.loads_front_face) { 1045 if (info.loads_front_face) {
1026 front_face = DefineInput(*this, U1, spv::BuiltIn::FrontFacing); 1046 front_face = DefineInput(*this, U1, true, spv::BuiltIn::FrontFacing);
1027 } 1047 }
1028 if (info.loads_point_coord) { 1048 if (info.loads_point_coord) {
1029 point_coord = DefineInput(*this, F32[2], spv::BuiltIn::PointCoord); 1049 point_coord = DefineInput(*this, F32[2], true, spv::BuiltIn::PointCoord);
1050 }
1051 if (info.loads_tess_coord) {
1052 tess_coord = DefineInput(*this, F32[3], false, spv::BuiltIn::TessCoord);
1030 } 1053 }
1031 for (size_t index = 0; index < info.input_generics.size(); ++index) { 1054 for (size_t index = 0; index < info.input_generics.size(); ++index) {
1032 const InputVarying generic{info.input_generics[index]}; 1055 const InputVarying generic{info.input_generics[index]};
@@ -1038,7 +1061,7 @@ void EmitContext::DefineInputs(const Info& info) {
1038 continue; 1061 continue;
1039 } 1062 }
1040 const Id type{GetAttributeType(*this, input_type)}; 1063 const Id type{GetAttributeType(*this, input_type)};
1041 const Id id{DefineInput(*this, type)}; 1064 const Id id{DefineInput(*this, type, true)};
1042 Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); 1065 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1043 Name(id, fmt::format("in_attr{}", index)); 1066 Name(id, fmt::format("in_attr{}", index));
1044 input_generics[index] = id; 1067 input_generics[index] = id;
@@ -1059,58 +1082,98 @@ void EmitContext::DefineInputs(const Info& info) {
1059 break; 1082 break;
1060 } 1083 }
1061 } 1084 }
1085 if (stage == Stage::TessellationEval) {
1086 for (size_t index = 0; index < info.uses_patches.size(); ++index) {
1087 if (!info.uses_patches[index]) {
1088 continue;
1089 }
1090 const Id id{DefineInput(*this, F32[4], false)};
1091 Decorate(id, spv::Decoration::Patch);
1092 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1093 patches[index] = id;
1094 }
1095 }
1062} 1096}
1063 1097
1064void EmitContext::DefineOutputs(const Info& info) { 1098void EmitContext::DefineOutputs(const IR::Program& program) {
1099 const Info& info{program.info};
1100 const std::optional<u32> invocations{program.invocations};
1065 if (info.stores_position || stage == Stage::VertexB) { 1101 if (info.stores_position || stage == Stage::VertexB) {
1066 output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position); 1102 output_position = DefineOutput(*this, F32[4], invocations, spv::BuiltIn::Position);
1067 } 1103 }
1068 if (info.stores_point_size || profile.fixed_state_point_size) { 1104 if (info.stores_point_size || profile.fixed_state_point_size) {
1069 if (stage == Stage::Fragment) { 1105 if (stage == Stage::Fragment) {
1070 throw NotImplementedException("Storing PointSize in fragment stage"); 1106 throw NotImplementedException("Storing PointSize in fragment stage");
1071 } 1107 }
1072 output_point_size = DefineOutput(*this, F32[1], spv::BuiltIn::PointSize); 1108 output_point_size = DefineOutput(*this, F32[1], invocations, spv::BuiltIn::PointSize);
1073 } 1109 }
1074 if (info.stores_clip_distance) { 1110 if (info.stores_clip_distance) {
1075 if (stage == Stage::Fragment) { 1111 if (stage == Stage::Fragment) {
1076 throw NotImplementedException("Storing ClipDistance in fragment stage"); 1112 throw NotImplementedException("Storing ClipDistance in fragment stage");
1077 } 1113 }
1078 const Id type{TypeArray(F32[1], Constant(U32[1], 8U))}; 1114 const Id type{TypeArray(F32[1], Constant(U32[1], 8U))};
1079 clip_distances = DefineOutput(*this, type, spv::BuiltIn::ClipDistance); 1115 clip_distances = DefineOutput(*this, type, invocations, spv::BuiltIn::ClipDistance);
1080 } 1116 }
1081 if (info.stores_layer && 1117 if (info.stores_layer &&
1082 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) { 1118 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
1083 if (stage == Stage::Fragment) { 1119 if (stage == Stage::Fragment) {
1084 throw NotImplementedException("Storing Layer in fragment stage"); 1120 throw NotImplementedException("Storing Layer in fragment stage");
1085 } 1121 }
1086 layer = DefineOutput(*this, U32[1], spv::BuiltIn::Layer); 1122 layer = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::Layer);
1087 } 1123 }
1088 if (info.stores_viewport_index && 1124 if (info.stores_viewport_index &&
1089 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) { 1125 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
1090 if (stage == Stage::Fragment) { 1126 if (stage == Stage::Fragment) {
1091 throw NotImplementedException("Storing ViewportIndex in fragment stage"); 1127 throw NotImplementedException("Storing ViewportIndex in fragment stage");
1092 } 1128 }
1093 viewport_index = DefineOutput(*this, U32[1], spv::BuiltIn::ViewportIndex); 1129 viewport_index = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::ViewportIndex);
1094 } 1130 }
1095 for (size_t index = 0; index < info.stores_generics.size(); ++index) { 1131 for (size_t index = 0; index < info.stores_generics.size(); ++index) {
1096 if (info.stores_generics[index]) { 1132 if (info.stores_generics[index]) {
1097 DefineGenericOutput(*this, index); 1133 DefineGenericOutput(*this, index, invocations);
1098 } 1134 }
1099 } 1135 }
1100 if (stage == Stage::Fragment) { 1136 switch (stage) {
1137 case Stage::TessellationControl:
1138 if (info.stores_tess_level_outer) {
1139 const Id type{TypeArray(F32[1], Constant(U32[1], 4))};
1140 output_tess_level_outer =
1141 DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelOuter);
1142 Decorate(output_tess_level_outer, spv::Decoration::Patch);
1143 }
1144 if (info.stores_tess_level_inner) {
1145 const Id type{TypeArray(F32[1], Constant(U32[1], 2))};
1146 output_tess_level_inner =
1147 DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelInner);
1148 Decorate(output_tess_level_inner, spv::Decoration::Patch);
1149 }
1150 for (size_t index = 0; index < info.uses_patches.size(); ++index) {
1151 if (!info.uses_patches[index]) {
1152 continue;
1153 }
1154 const Id id{DefineOutput(*this, F32[4], std::nullopt)};
1155 Decorate(id, spv::Decoration::Patch);
1156 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1157 patches[index] = id;
1158 }
1159 break;
1160 case Stage::Fragment:
1101 for (u32 index = 0; index < 8; ++index) { 1161 for (u32 index = 0; index < 8; ++index) {
1102 if (!info.stores_frag_color[index]) { 1162 if (!info.stores_frag_color[index]) {
1103 continue; 1163 continue;
1104 } 1164 }
1105 frag_color[index] = DefineOutput(*this, F32[4]); 1165 frag_color[index] = DefineOutput(*this, F32[4], std::nullopt);
1106 Decorate(frag_color[index], spv::Decoration::Location, index); 1166 Decorate(frag_color[index], spv::Decoration::Location, index);
1107 Name(frag_color[index], fmt::format("frag_color{}", index)); 1167 Name(frag_color[index], fmt::format("frag_color{}", index));
1108 } 1168 }
1109 if (info.stores_frag_depth) { 1169 if (info.stores_frag_depth) {
1110 frag_depth = DefineOutput(*this, F32[1]); 1170 frag_depth = DefineOutput(*this, F32[1], std::nullopt);
1111 Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth); 1171 Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth);
1112 Name(frag_depth, "frag_depth"); 1172 Name(frag_depth, "frag_depth");
1113 } 1173 }
1174 break;
1175 default:
1176 break;
1114 } 1177 }
1115} 1178}
1116 1179
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index 0da14d5f8..ba0a253b3 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -147,6 +147,7 @@ public:
147 147
148 Id workgroup_id{}; 148 Id workgroup_id{};
149 Id local_invocation_id{}; 149 Id local_invocation_id{};
150 Id invocation_id{};
150 Id is_helper_invocation{}; 151 Id is_helper_invocation{};
151 Id subgroup_local_invocation_id{}; 152 Id subgroup_local_invocation_id{};
152 Id subgroup_mask_eq{}; 153 Id subgroup_mask_eq{};
@@ -162,6 +163,7 @@ public:
162 Id base_vertex{}; 163 Id base_vertex{};
163 Id front_face{}; 164 Id front_face{};
164 Id point_coord{}; 165 Id point_coord{};
166 Id tess_coord{};
165 Id clip_distances{}; 167 Id clip_distances{};
166 Id layer{}; 168 Id layer{};
167 Id viewport_index{}; 169 Id viewport_index{};
@@ -204,6 +206,10 @@ public:
204 Id output_position{}; 206 Id output_position{};
205 std::array<std::array<GenericElementInfo, 4>, 32> output_generics{}; 207 std::array<std::array<GenericElementInfo, 4>, 32> output_generics{};
206 208
209 Id output_tess_level_outer{};
210 Id output_tess_level_inner{};
211 std::array<Id, 30> patches{};
212
207 std::array<Id, 8> frag_color{}; 213 std::array<Id, 8> frag_color{};
208 Id frag_depth{}; 214 Id frag_depth{};
209 215
@@ -212,7 +218,7 @@ public:
212private: 218private:
213 void DefineCommonTypes(const Info& info); 219 void DefineCommonTypes(const Info& info);
214 void DefineCommonConstants(); 220 void DefineCommonConstants();
215 void DefineInterfaces(const Info& info); 221 void DefineInterfaces(const IR::Program& program);
216 void DefineLocalMemory(const IR::Program& program); 222 void DefineLocalMemory(const IR::Program& program);
217 void DefineSharedMemory(const IR::Program& program); 223 void DefineSharedMemory(const IR::Program& program);
218 void DefineSharedMemoryFunctions(const IR::Program& program); 224 void DefineSharedMemoryFunctions(const IR::Program& program);
@@ -226,7 +232,7 @@ private:
226 void DefineLabels(IR::Program& program); 232 void DefineLabels(IR::Program& program);
227 233
228 void DefineInputs(const Info& info); 234 void DefineInputs(const Info& info);
229 void DefineOutputs(const Info& info); 235 void DefineOutputs(const IR::Program& program);
230}; 236};
231 237
232} // namespace Shader::Backend::SPIRV 238} // namespace Shader::Backend::SPIRV
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 3bf4c6a9e..105602ccf 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -45,6 +45,8 @@ ArgType Arg(EmitContext& ctx, const IR::Value& arg) {
45 return arg.Label(); 45 return arg.Label();
46 } else if constexpr (std::is_same_v<ArgType, IR::Attribute>) { 46 } else if constexpr (std::is_same_v<ArgType, IR::Attribute>) {
47 return arg.Attribute(); 47 return arg.Attribute();
48 } else if constexpr (std::is_same_v<ArgType, IR::Patch>) {
49 return arg.Patch();
48 } else if constexpr (std::is_same_v<ArgType, IR::Reg>) { 50 } else if constexpr (std::is_same_v<ArgType, IR::Reg>) {
49 return arg.Reg(); 51 return arg.Reg();
50 } 52 }
@@ -120,6 +122,30 @@ Id DefineMain(EmitContext& ctx, IR::Program& program) {
120 return main; 122 return main;
121} 123}
122 124
125spv::ExecutionMode ExecutionMode(TessPrimitive primitive) {
126 switch (primitive) {
127 case TessPrimitive::Isolines:
128 return spv::ExecutionMode::Isolines;
129 case TessPrimitive::Triangles:
130 return spv::ExecutionMode::Triangles;
131 case TessPrimitive::Quads:
132 return spv::ExecutionMode::Quads;
133 }
134 throw InvalidArgument("Tessellation primitive {}", primitive);
135}
136
137spv::ExecutionMode ExecutionMode(TessSpacing spacing) {
138 switch (spacing) {
139 case TessSpacing::Equal:
140 return spv::ExecutionMode::SpacingEqual;
141 case TessSpacing::FractionalOdd:
142 return spv::ExecutionMode::SpacingFractionalOdd;
143 case TessSpacing::FractionalEven:
144 return spv::ExecutionMode::SpacingFractionalEven;
145 }
146 throw InvalidArgument("Tessellation spacing {}", spacing);
147}
148
123void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) { 149void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
124 const std::span interfaces(ctx.interfaces.data(), ctx.interfaces.size()); 150 const std::span interfaces(ctx.interfaces.data(), ctx.interfaces.size());
125 spv::ExecutionModel execution_model{}; 151 spv::ExecutionModel execution_model{};
@@ -134,6 +160,19 @@ void DefineEntryPoint(const IR::Program& program, EmitContext& ctx, Id main) {
134 case Stage::VertexB: 160 case Stage::VertexB:
135 execution_model = spv::ExecutionModel::Vertex; 161 execution_model = spv::ExecutionModel::Vertex;
136 break; 162 break;
163 case Stage::TessellationControl:
164 execution_model = spv::ExecutionModel::TessellationControl;
165 ctx.AddCapability(spv::Capability::Tessellation);
166 ctx.AddExecutionMode(main, spv::ExecutionMode::OutputVertices, program.invocations);
167 break;
168 case Stage::TessellationEval:
169 execution_model = spv::ExecutionModel::TessellationEvaluation;
170 ctx.AddCapability(spv::Capability::Tessellation);
171 ctx.AddExecutionMode(main, ExecutionMode(ctx.profile.tess_primitive));
172 ctx.AddExecutionMode(main, ExecutionMode(ctx.profile.tess_spacing));
173 ctx.AddExecutionMode(main, ctx.profile.tess_clockwise ? spv::ExecutionMode::VertexOrderCw
174 : spv::ExecutionMode::VertexOrderCcw);
175 break;
137 case Stage::Geometry: 176 case Stage::Geometry:
138 execution_model = spv::ExecutionModel::Geometry; 177 execution_model = spv::ExecutionModel::Geometry;
139 ctx.AddCapability(spv::Capability::Geometry); 178 ctx.AddCapability(spv::Capability::Geometry);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index 55b2edba0..8caf30f1b 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -55,6 +55,8 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex);
55void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, Id vertex); 55void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, Id vertex);
56Id EmitGetAttributeIndexed(EmitContext& ctx, Id offset, Id vertex); 56Id EmitGetAttributeIndexed(EmitContext& ctx, Id offset, Id vertex);
57void EmitSetAttributeIndexed(EmitContext& ctx, Id offset, Id value, Id vertex); 57void EmitSetAttributeIndexed(EmitContext& ctx, Id offset, Id value, Id vertex);
58Id EmitGetPatch(EmitContext& ctx, IR::Patch patch);
59void EmitSetPatch(EmitContext& ctx, IR::Patch patch, Id value);
58void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, Id value); 60void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, Id value);
59void EmitSetFragDepth(EmitContext& ctx, Id value); 61void EmitSetFragDepth(EmitContext& ctx, Id value);
60void EmitGetZFlag(EmitContext& ctx); 62void EmitGetZFlag(EmitContext& ctx);
@@ -67,6 +69,7 @@ void EmitSetCFlag(EmitContext& ctx);
67void EmitSetOFlag(EmitContext& ctx); 69void EmitSetOFlag(EmitContext& ctx);
68Id EmitWorkgroupId(EmitContext& ctx); 70Id EmitWorkgroupId(EmitContext& ctx);
69Id EmitLocalInvocationId(EmitContext& ctx); 71Id EmitLocalInvocationId(EmitContext& ctx);
72Id EmitInvocationId(EmitContext& ctx);
70Id EmitIsHelperInvocation(EmitContext& ctx); 73Id EmitIsHelperInvocation(EmitContext& ctx);
71Id EmitLoadLocal(EmitContext& ctx, Id word_offset); 74Id EmitLoadLocal(EmitContext& ctx, Id word_offset);
72void EmitWriteLocal(EmitContext& ctx, Id word_offset, Id value); 75void EmitWriteLocal(EmitContext& ctx, Id word_offset, Id value);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index 59c56c5ba..4a1aeece5 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -32,13 +32,26 @@ std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
32 32
33template <typename... Args> 33template <typename... Args>
34Id AttrPointer(EmitContext& ctx, Id pointer_type, Id vertex, Id base, Args&&... args) { 34Id AttrPointer(EmitContext& ctx, Id pointer_type, Id vertex, Id base, Args&&... args) {
35 if (ctx.stage == Stage::Geometry) { 35 switch (ctx.stage) {
36 case Stage::TessellationControl:
37 case Stage::TessellationEval:
38 case Stage::Geometry:
36 return ctx.OpAccessChain(pointer_type, base, vertex, std::forward<Args>(args)...); 39 return ctx.OpAccessChain(pointer_type, base, vertex, std::forward<Args>(args)...);
37 } else { 40 default:
38 return ctx.OpAccessChain(pointer_type, base, std::forward<Args>(args)...); 41 return ctx.OpAccessChain(pointer_type, base, std::forward<Args>(args)...);
39 } 42 }
40} 43}
41 44
45template <typename... Args>
46Id OutputAccessChain(EmitContext& ctx, Id result_type, Id base, Args&&... args) {
47 if (ctx.stage == Stage::TessellationControl) {
48 const Id invocation_id{ctx.OpLoad(ctx.U32[1], ctx.invocation_id)};
49 return ctx.OpAccessChain(result_type, base, invocation_id, std::forward<Args>(args)...);
50 } else {
51 return ctx.OpAccessChain(result_type, base, std::forward<Args>(args)...);
52 }
53}
54
42std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) { 55std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
43 if (IR::IsGeneric(attr)) { 56 if (IR::IsGeneric(attr)) {
44 const u32 index{IR::GenericAttributeIndex(attr)}; 57 const u32 index{IR::GenericAttributeIndex(attr)};
@@ -49,7 +62,7 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
49 } else { 62 } else {
50 const u32 index_element{element - info.first_element}; 63 const u32 index_element{element - info.first_element};
51 const Id index_id{ctx.Constant(ctx.U32[1], index_element)}; 64 const Id index_id{ctx.Constant(ctx.U32[1], index_element)};
52 return ctx.OpAccessChain(ctx.output_f32, info.id, index_id); 65 return OutputAccessChain(ctx, ctx.output_f32, info.id, index_id);
53 } 66 }
54 } 67 }
55 switch (attr) { 68 switch (attr) {
@@ -61,7 +74,7 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
61 case IR::Attribute::PositionW: { 74 case IR::Attribute::PositionW: {
62 const u32 element{static_cast<u32>(attr) % 4}; 75 const u32 element{static_cast<u32>(attr) % 4};
63 const Id element_id{ctx.Constant(ctx.U32[1], element)}; 76 const Id element_id{ctx.Constant(ctx.U32[1], element)};
64 return ctx.OpAccessChain(ctx.output_f32, ctx.output_position, element_id); 77 return OutputAccessChain(ctx, ctx.output_f32, ctx.output_position, element_id);
65 } 78 }
66 case IR::Attribute::ClipDistance0: 79 case IR::Attribute::ClipDistance0:
67 case IR::Attribute::ClipDistance1: 80 case IR::Attribute::ClipDistance1:
@@ -74,7 +87,7 @@ std::optional<Id> OutputAttrPointer(EmitContext& ctx, IR::Attribute attr) {
74 const u32 base{static_cast<u32>(IR::Attribute::ClipDistance0)}; 87 const u32 base{static_cast<u32>(IR::Attribute::ClipDistance0)};
75 const u32 index{static_cast<u32>(attr) - base}; 88 const u32 index{static_cast<u32>(attr) - base};
76 const Id clip_num{ctx.Constant(ctx.U32[1], index)}; 89 const Id clip_num{ctx.Constant(ctx.U32[1], index)};
77 return ctx.OpAccessChain(ctx.output_f32, ctx.clip_distances, clip_num); 90 return OutputAccessChain(ctx, ctx.output_f32, ctx.clip_distances, clip_num);
78 } 91 }
79 case IR::Attribute::Layer: 92 case IR::Attribute::Layer:
80 return ctx.profile.support_viewport_index_layer_non_geometry || 93 return ctx.profile.support_viewport_index_layer_non_geometry ||
@@ -222,11 +235,18 @@ Id EmitGetAttribute(EmitContext& ctx, IR::Attribute attr, Id vertex) {
222 ctx.Constant(ctx.U32[1], std::numeric_limits<u32>::max()), 235 ctx.Constant(ctx.U32[1], std::numeric_limits<u32>::max()),
223 ctx.u32_zero_value); 236 ctx.u32_zero_value);
224 case IR::Attribute::PointSpriteS: 237 case IR::Attribute::PointSpriteS:
225 return ctx.OpLoad(ctx.F32[1], AttrPointer(ctx, ctx.input_f32, vertex, ctx.point_coord, 238 return ctx.OpLoad(ctx.F32[1],
226 ctx.u32_zero_value)); 239 ctx.OpAccessChain(ctx.input_f32, ctx.point_coord, ctx.u32_zero_value));
227 case IR::Attribute::PointSpriteT: 240 case IR::Attribute::PointSpriteT:
228 return ctx.OpLoad(ctx.F32[1], AttrPointer(ctx, ctx.input_f32, vertex, ctx.point_coord, 241 return ctx.OpLoad(ctx.F32[1], ctx.OpAccessChain(ctx.input_f32, ctx.point_coord,
229 ctx.Constant(ctx.U32[1], 1U))); 242 ctx.Constant(ctx.U32[1], 1U)));
243 case IR::Attribute::TessellationEvaluationPointU:
244 return ctx.OpLoad(ctx.F32[1],
245 ctx.OpAccessChain(ctx.input_f32, ctx.tess_coord, ctx.u32_zero_value));
246 case IR::Attribute::TessellationEvaluationPointV:
247 return ctx.OpLoad(ctx.F32[1], ctx.OpAccessChain(ctx.input_f32, ctx.tess_coord,
248 ctx.Constant(ctx.U32[1], 1U)));
249
230 default: 250 default:
231 throw NotImplementedException("Read attribute {}", attr); 251 throw NotImplementedException("Read attribute {}", attr);
232 } 252 }
@@ -240,9 +260,12 @@ void EmitSetAttribute(EmitContext& ctx, IR::Attribute attr, Id value, [[maybe_un
240} 260}
241 261
242Id EmitGetAttributeIndexed(EmitContext& ctx, Id offset, Id vertex) { 262Id EmitGetAttributeIndexed(EmitContext& ctx, Id offset, Id vertex) {
243 if (ctx.stage == Stage::Geometry) { 263 switch (ctx.stage) {
264 case Stage::TessellationControl:
265 case Stage::TessellationEval:
266 case Stage::Geometry:
244 return ctx.OpFunctionCall(ctx.F32[1], ctx.indexed_load_func, offset, vertex); 267 return ctx.OpFunctionCall(ctx.F32[1], ctx.indexed_load_func, offset, vertex);
245 } else { 268 default:
246 return ctx.OpFunctionCall(ctx.F32[1], ctx.indexed_load_func, offset); 269 return ctx.OpFunctionCall(ctx.F32[1], ctx.indexed_load_func, offset);
247 } 270 }
248} 271}
@@ -251,6 +274,45 @@ void EmitSetAttributeIndexed(EmitContext& ctx, Id offset, Id value, [[maybe_unus
251 ctx.OpFunctionCall(ctx.void_id, ctx.indexed_store_func, offset, value); 274 ctx.OpFunctionCall(ctx.void_id, ctx.indexed_store_func, offset, value);
252} 275}
253 276
277Id EmitGetPatch(EmitContext& ctx, IR::Patch patch) {
278 if (!IR::IsGeneric(patch)) {
279 throw NotImplementedException("Non-generic patch load");
280 }
281 const u32 index{IR::GenericPatchIndex(patch)};
282 const Id element{ctx.Constant(ctx.U32[1], IR::GenericPatchElement(patch))};
283 const Id pointer{ctx.OpAccessChain(ctx.input_f32, ctx.patches.at(index), element)};
284 return ctx.OpLoad(ctx.F32[1], pointer);
285}
286
287void EmitSetPatch(EmitContext& ctx, IR::Patch patch, Id value) {
288 const Id pointer{[&] {
289 if (IR::IsGeneric(patch)) {
290 const u32 index{IR::GenericPatchIndex(patch)};
291 const Id element{ctx.Constant(ctx.U32[1], IR::GenericPatchElement(patch))};
292 return ctx.OpAccessChain(ctx.output_f32, ctx.patches.at(index), element);
293 }
294 switch (patch) {
295 case IR::Patch::TessellationLodLeft:
296 case IR::Patch::TessellationLodRight:
297 case IR::Patch::TessellationLodTop:
298 case IR::Patch::TessellationLodBottom: {
299 const u32 index{static_cast<u32>(patch) - u32(IR::Patch::TessellationLodLeft)};
300 const Id index_id{ctx.Constant(ctx.U32[1], index)};
301 return ctx.OpAccessChain(ctx.output_f32, ctx.output_tess_level_outer, index_id);
302 }
303 case IR::Patch::TessellationLodInteriorU:
304 return ctx.OpAccessChain(ctx.output_f32, ctx.output_tess_level_inner,
305 ctx.u32_zero_value);
306 case IR::Patch::TessellationLodInteriorV:
307 return ctx.OpAccessChain(ctx.output_f32, ctx.output_tess_level_inner,
308 ctx.Constant(ctx.U32[1], 1u));
309 default:
310 throw NotImplementedException("Patch {}", patch);
311 }
312 }()};
313 ctx.OpStore(pointer, value);
314}
315
254void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, Id value) { 316void EmitSetFragColor(EmitContext& ctx, u32 index, u32 component, Id value) {
255 const Id component_id{ctx.Constant(ctx.U32[1], component)}; 317 const Id component_id{ctx.Constant(ctx.U32[1], component)};
256 const Id pointer{ctx.OpAccessChain(ctx.output_f32, ctx.frag_color.at(index), component_id)}; 318 const Id pointer{ctx.OpAccessChain(ctx.output_f32, ctx.frag_color.at(index), component_id)};
@@ -301,6 +363,10 @@ Id EmitLocalInvocationId(EmitContext& ctx) {
301 return ctx.OpLoad(ctx.U32[3], ctx.local_invocation_id); 363 return ctx.OpLoad(ctx.U32[3], ctx.local_invocation_id);
302} 364}
303 365
366Id EmitInvocationId(EmitContext& ctx) {
367 return ctx.OpLoad(ctx.U32[1], ctx.invocation_id);
368}
369
304Id EmitIsHelperInvocation(EmitContext& ctx) { 370Id EmitIsHelperInvocation(EmitContext& ctx) {
305 return ctx.OpLoad(ctx.U1, ctx.is_helper_invocation); 371 return ctx.OpLoad(ctx.U1, ctx.is_helper_invocation);
306} 372}