summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/backend/spirv/emit_context.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/shader_recompiler/backend/spirv/emit_context.cpp')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp147
1 files changed, 105 insertions, 42 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 032cf5e03..067f61613 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -125,19 +125,36 @@ u32 NumVertices(InputTopology input_topology) {
125 throw InvalidArgument("Invalid input topology {}", input_topology); 125 throw InvalidArgument("Invalid input topology {}", input_topology);
126} 126}
127 127
128Id DefineInput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) { 128Id DefineInput(EmitContext& ctx, Id type, bool per_invocation,
129 if (ctx.stage == Stage::Geometry) { 129 std::optional<spv::BuiltIn> builtin = std::nullopt) {
130 const u32 num_vertices{NumVertices(ctx.profile.input_topology)}; 130 switch (ctx.stage) {
131 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices)); 131 case Stage::TessellationControl:
132 case Stage::TessellationEval:
133 if (per_invocation) {
134 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], 32u));
135 }
136 break;
137 case Stage::Geometry:
138 if (per_invocation) {
139 const u32 num_vertices{NumVertices(ctx.profile.input_topology)};
140 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], num_vertices));
141 }
142 break;
143 default:
144 break;
132 } 145 }
133 return DefineVariable(ctx, type, builtin, spv::StorageClass::Input); 146 return DefineVariable(ctx, type, builtin, spv::StorageClass::Input);
134} 147}
135 148
136Id DefineOutput(EmitContext& ctx, Id type, std::optional<spv::BuiltIn> builtin = std::nullopt) { 149Id DefineOutput(EmitContext& ctx, Id type, std::optional<u32> invocations,
150 std::optional<spv::BuiltIn> builtin = std::nullopt) {
151 if (invocations && ctx.stage == Stage::TessellationControl) {
152 type = ctx.TypeArray(type, ctx.Constant(ctx.U32[1], *invocations));
153 }
137 return DefineVariable(ctx, type, builtin, spv::StorageClass::Output); 154 return DefineVariable(ctx, type, builtin, spv::StorageClass::Output);
138} 155}
139 156
140void DefineGenericOutput(EmitContext& ctx, size_t index) { 157void DefineGenericOutput(EmitContext& ctx, size_t index, std::optional<u32> invocations) {
141 static constexpr std::string_view swizzle{"xyzw"}; 158 static constexpr std::string_view swizzle{"xyzw"};
142 const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4}; 159 const size_t base_attr_index{static_cast<size_t>(IR::Attribute::Generic0X) + index * 4};
143 u32 element{0}; 160 u32 element{0};
@@ -150,7 +167,7 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
150 } 167 }
151 const u32 num_components{xfb_varying ? xfb_varying->components : remainder}; 168 const u32 num_components{xfb_varying ? xfb_varying->components : remainder};
152 169
153 const Id id{DefineOutput(ctx, ctx.F32[num_components])}; 170 const Id id{DefineOutput(ctx, ctx.F32[num_components], invocations)};
154 ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); 171 ctx.Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
155 if (element > 0) { 172 if (element > 0) {
156 ctx.Decorate(id, spv::Decoration::Component, element); 173 ctx.Decorate(id, spv::Decoration::Component, element);
@@ -161,10 +178,10 @@ void DefineGenericOutput(EmitContext& ctx, size_t index) {
161 ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset); 178 ctx.Decorate(id, spv::Decoration::Offset, xfb_varying->offset);
162 } 179 }
163 if (num_components < 4 || element > 0) { 180 if (num_components < 4 || element > 0) {
164 ctx.Name(id, fmt::format("out_attr{}", index));
165 } else {
166 const std::string_view subswizzle{swizzle.substr(element, num_components)}; 181 const std::string_view subswizzle{swizzle.substr(element, num_components)};
167 ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle)); 182 ctx.Name(id, fmt::format("out_attr{}_{}", index, subswizzle));
183 } else {
184 ctx.Name(id, fmt::format("out_attr{}", index));
168 } 185 }
169 const GenericElementInfo info{ 186 const GenericElementInfo info{
170 .id = id, 187 .id = id,
@@ -383,7 +400,7 @@ EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& bin
383 AddCapability(spv::Capability::Shader); 400 AddCapability(spv::Capability::Shader);
384 DefineCommonTypes(program.info); 401 DefineCommonTypes(program.info);
385 DefineCommonConstants(); 402 DefineCommonConstants();
386 DefineInterfaces(program.info); 403 DefineInterfaces(program);
387 DefineLocalMemory(program); 404 DefineLocalMemory(program);
388 DefineSharedMemory(program); 405 DefineSharedMemory(program);
389 DefineSharedMemoryFunctions(program); 406 DefineSharedMemoryFunctions(program);
@@ -472,9 +489,9 @@ void EmitContext::DefineCommonConstants() {
472 f32_zero_value = Constant(F32[1], 0.0f); 489 f32_zero_value = Constant(F32[1], 0.0f);
473} 490}
474 491
475void EmitContext::DefineInterfaces(const Info& info) { 492void EmitContext::DefineInterfaces(const IR::Program& program) {
476 DefineInputs(info); 493 DefineInputs(program.info);
477 DefineOutputs(info); 494 DefineOutputs(program);
478} 495}
479 496
480void EmitContext::DefineLocalMemory(const IR::Program& program) { 497void EmitContext::DefineLocalMemory(const IR::Program& program) {
@@ -972,26 +989,29 @@ void EmitContext::DefineLabels(IR::Program& program) {
972 989
973void EmitContext::DefineInputs(const Info& info) { 990void EmitContext::DefineInputs(const Info& info) {
974 if (info.uses_workgroup_id) { 991 if (info.uses_workgroup_id) {
975 workgroup_id = DefineInput(*this, U32[3], spv::BuiltIn::WorkgroupId); 992 workgroup_id = DefineInput(*this, U32[3], false, spv::BuiltIn::WorkgroupId);
976 } 993 }
977 if (info.uses_local_invocation_id) { 994 if (info.uses_local_invocation_id) {
978 local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId); 995 local_invocation_id = DefineInput(*this, U32[3], false, spv::BuiltIn::LocalInvocationId);
996 }
997 if (info.uses_invocation_id) {
998 invocation_id = DefineInput(*this, U32[1], false, spv::BuiltIn::InvocationId);
979 } 999 }
980 if (info.uses_is_helper_invocation) { 1000 if (info.uses_is_helper_invocation) {
981 is_helper_invocation = DefineInput(*this, U1, spv::BuiltIn::HelperInvocation); 1001 is_helper_invocation = DefineInput(*this, U1, false, spv::BuiltIn::HelperInvocation);
982 } 1002 }
983 if (info.uses_subgroup_mask) { 1003 if (info.uses_subgroup_mask) {
984 subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR); 1004 subgroup_mask_eq = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupEqMaskKHR);
985 subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR); 1005 subgroup_mask_lt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLtMaskKHR);
986 subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR); 1006 subgroup_mask_le = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupLeMaskKHR);
987 subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR); 1007 subgroup_mask_gt = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGtMaskKHR);
988 subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR); 1008 subgroup_mask_ge = DefineInput(*this, U32[4], false, spv::BuiltIn::SubgroupGeMaskKHR);
989 } 1009 }
990 if (info.uses_subgroup_invocation_id || 1010 if (info.uses_subgroup_invocation_id ||
991 (profile.warp_size_potentially_larger_than_guest && 1011 (profile.warp_size_potentially_larger_than_guest &&
992 (info.uses_subgroup_vote || info.uses_subgroup_mask))) { 1012 (info.uses_subgroup_vote || info.uses_subgroup_mask))) {
993 subgroup_local_invocation_id = 1013 subgroup_local_invocation_id =
994 DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId); 1014 DefineInput(*this, U32[1], false, spv::BuiltIn::SubgroupLocalInvocationId);
995 } 1015 }
996 if (info.uses_fswzadd) { 1016 if (info.uses_fswzadd) {
997 const Id f32_one{Constant(F32[1], 1.0f)}; 1017 const Id f32_one{Constant(F32[1], 1.0f)};
@@ -1004,29 +1024,32 @@ void EmitContext::DefineInputs(const Info& info) {
1004 if (info.loads_position) { 1024 if (info.loads_position) {
1005 const bool is_fragment{stage != Stage::Fragment}; 1025 const bool is_fragment{stage != Stage::Fragment};
1006 const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord}; 1026 const spv::BuiltIn built_in{is_fragment ? spv::BuiltIn::Position : spv::BuiltIn::FragCoord};
1007 input_position = DefineInput(*this, F32[4], built_in); 1027 input_position = DefineInput(*this, F32[4], true, built_in);
1008 } 1028 }
1009 if (info.loads_instance_id) { 1029 if (info.loads_instance_id) {
1010 if (profile.support_vertex_instance_id) { 1030 if (profile.support_vertex_instance_id) {
1011 instance_id = DefineInput(*this, U32[1], spv::BuiltIn::InstanceId); 1031 instance_id = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceId);
1012 } else { 1032 } else {
1013 instance_index = DefineInput(*this, U32[1], spv::BuiltIn::InstanceIndex); 1033 instance_index = DefineInput(*this, U32[1], true, spv::BuiltIn::InstanceIndex);
1014 base_instance = DefineInput(*this, U32[1], spv::BuiltIn::BaseInstance); 1034 base_instance = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseInstance);
1015 } 1035 }
1016 } 1036 }
1017 if (info.loads_vertex_id) { 1037 if (info.loads_vertex_id) {
1018 if (profile.support_vertex_instance_id) { 1038 if (profile.support_vertex_instance_id) {
1019 vertex_id = DefineInput(*this, U32[1], spv::BuiltIn::VertexId); 1039 vertex_id = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexId);
1020 } else { 1040 } else {
1021 vertex_index = DefineInput(*this, U32[1], spv::BuiltIn::VertexIndex); 1041 vertex_index = DefineInput(*this, U32[1], true, spv::BuiltIn::VertexIndex);
1022 base_vertex = DefineInput(*this, U32[1], spv::BuiltIn::BaseVertex); 1042 base_vertex = DefineInput(*this, U32[1], true, spv::BuiltIn::BaseVertex);
1023 } 1043 }
1024 } 1044 }
1025 if (info.loads_front_face) { 1045 if (info.loads_front_face) {
1026 front_face = DefineInput(*this, U1, spv::BuiltIn::FrontFacing); 1046 front_face = DefineInput(*this, U1, true, spv::BuiltIn::FrontFacing);
1027 } 1047 }
1028 if (info.loads_point_coord) { 1048 if (info.loads_point_coord) {
1029 point_coord = DefineInput(*this, F32[2], spv::BuiltIn::PointCoord); 1049 point_coord = DefineInput(*this, F32[2], true, spv::BuiltIn::PointCoord);
1050 }
1051 if (info.loads_tess_coord) {
1052 tess_coord = DefineInput(*this, F32[3], false, spv::BuiltIn::TessCoord);
1030 } 1053 }
1031 for (size_t index = 0; index < info.input_generics.size(); ++index) { 1054 for (size_t index = 0; index < info.input_generics.size(); ++index) {
1032 const InputVarying generic{info.input_generics[index]}; 1055 const InputVarying generic{info.input_generics[index]};
@@ -1038,7 +1061,7 @@ void EmitContext::DefineInputs(const Info& info) {
1038 continue; 1061 continue;
1039 } 1062 }
1040 const Id type{GetAttributeType(*this, input_type)}; 1063 const Id type{GetAttributeType(*this, input_type)};
1041 const Id id{DefineInput(*this, type)}; 1064 const Id id{DefineInput(*this, type, true)};
1042 Decorate(id, spv::Decoration::Location, static_cast<u32>(index)); 1065 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1043 Name(id, fmt::format("in_attr{}", index)); 1066 Name(id, fmt::format("in_attr{}", index));
1044 input_generics[index] = id; 1067 input_generics[index] = id;
@@ -1059,58 +1082,98 @@ void EmitContext::DefineInputs(const Info& info) {
1059 break; 1082 break;
1060 } 1083 }
1061 } 1084 }
1085 if (stage == Stage::TessellationEval) {
1086 for (size_t index = 0; index < info.uses_patches.size(); ++index) {
1087 if (!info.uses_patches[index]) {
1088 continue;
1089 }
1090 const Id id{DefineInput(*this, F32[4], false)};
1091 Decorate(id, spv::Decoration::Patch);
1092 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1093 patches[index] = id;
1094 }
1095 }
1062} 1096}
1063 1097
1064void EmitContext::DefineOutputs(const Info& info) { 1098void EmitContext::DefineOutputs(const IR::Program& program) {
1099 const Info& info{program.info};
1100 const std::optional<u32> invocations{program.invocations};
1065 if (info.stores_position || stage == Stage::VertexB) { 1101 if (info.stores_position || stage == Stage::VertexB) {
1066 output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position); 1102 output_position = DefineOutput(*this, F32[4], invocations, spv::BuiltIn::Position);
1067 } 1103 }
1068 if (info.stores_point_size || profile.fixed_state_point_size) { 1104 if (info.stores_point_size || profile.fixed_state_point_size) {
1069 if (stage == Stage::Fragment) { 1105 if (stage == Stage::Fragment) {
1070 throw NotImplementedException("Storing PointSize in fragment stage"); 1106 throw NotImplementedException("Storing PointSize in fragment stage");
1071 } 1107 }
1072 output_point_size = DefineOutput(*this, F32[1], spv::BuiltIn::PointSize); 1108 output_point_size = DefineOutput(*this, F32[1], invocations, spv::BuiltIn::PointSize);
1073 } 1109 }
1074 if (info.stores_clip_distance) { 1110 if (info.stores_clip_distance) {
1075 if (stage == Stage::Fragment) { 1111 if (stage == Stage::Fragment) {
1076 throw NotImplementedException("Storing ClipDistance in fragment stage"); 1112 throw NotImplementedException("Storing ClipDistance in fragment stage");
1077 } 1113 }
1078 const Id type{TypeArray(F32[1], Constant(U32[1], 8U))}; 1114 const Id type{TypeArray(F32[1], Constant(U32[1], 8U))};
1079 clip_distances = DefineOutput(*this, type, spv::BuiltIn::ClipDistance); 1115 clip_distances = DefineOutput(*this, type, invocations, spv::BuiltIn::ClipDistance);
1080 } 1116 }
1081 if (info.stores_layer && 1117 if (info.stores_layer &&
1082 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) { 1118 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
1083 if (stage == Stage::Fragment) { 1119 if (stage == Stage::Fragment) {
1084 throw NotImplementedException("Storing Layer in fragment stage"); 1120 throw NotImplementedException("Storing Layer in fragment stage");
1085 } 1121 }
1086 layer = DefineOutput(*this, U32[1], spv::BuiltIn::Layer); 1122 layer = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::Layer);
1087 } 1123 }
1088 if (info.stores_viewport_index && 1124 if (info.stores_viewport_index &&
1089 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) { 1125 (profile.support_viewport_index_layer_non_geometry || stage == Stage::Geometry)) {
1090 if (stage == Stage::Fragment) { 1126 if (stage == Stage::Fragment) {
1091 throw NotImplementedException("Storing ViewportIndex in fragment stage"); 1127 throw NotImplementedException("Storing ViewportIndex in fragment stage");
1092 } 1128 }
1093 viewport_index = DefineOutput(*this, U32[1], spv::BuiltIn::ViewportIndex); 1129 viewport_index = DefineOutput(*this, U32[1], invocations, spv::BuiltIn::ViewportIndex);
1094 } 1130 }
1095 for (size_t index = 0; index < info.stores_generics.size(); ++index) { 1131 for (size_t index = 0; index < info.stores_generics.size(); ++index) {
1096 if (info.stores_generics[index]) { 1132 if (info.stores_generics[index]) {
1097 DefineGenericOutput(*this, index); 1133 DefineGenericOutput(*this, index, invocations);
1098 } 1134 }
1099 } 1135 }
1100 if (stage == Stage::Fragment) { 1136 switch (stage) {
1137 case Stage::TessellationControl:
1138 if (info.stores_tess_level_outer) {
1139 const Id type{TypeArray(F32[1], Constant(U32[1], 4))};
1140 output_tess_level_outer =
1141 DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelOuter);
1142 Decorate(output_tess_level_outer, spv::Decoration::Patch);
1143 }
1144 if (info.stores_tess_level_inner) {
1145 const Id type{TypeArray(F32[1], Constant(U32[1], 2))};
1146 output_tess_level_inner =
1147 DefineOutput(*this, type, std::nullopt, spv::BuiltIn::TessLevelInner);
1148 Decorate(output_tess_level_inner, spv::Decoration::Patch);
1149 }
1150 for (size_t index = 0; index < info.uses_patches.size(); ++index) {
1151 if (!info.uses_patches[index]) {
1152 continue;
1153 }
1154 const Id id{DefineOutput(*this, F32[4], std::nullopt)};
1155 Decorate(id, spv::Decoration::Patch);
1156 Decorate(id, spv::Decoration::Location, static_cast<u32>(index));
1157 patches[index] = id;
1158 }
1159 break;
1160 case Stage::Fragment:
1101 for (u32 index = 0; index < 8; ++index) { 1161 for (u32 index = 0; index < 8; ++index) {
1102 if (!info.stores_frag_color[index]) { 1162 if (!info.stores_frag_color[index]) {
1103 continue; 1163 continue;
1104 } 1164 }
1105 frag_color[index] = DefineOutput(*this, F32[4]); 1165 frag_color[index] = DefineOutput(*this, F32[4], std::nullopt);
1106 Decorate(frag_color[index], spv::Decoration::Location, index); 1166 Decorate(frag_color[index], spv::Decoration::Location, index);
1107 Name(frag_color[index], fmt::format("frag_color{}", index)); 1167 Name(frag_color[index], fmt::format("frag_color{}", index));
1108 } 1168 }
1109 if (info.stores_frag_depth) { 1169 if (info.stores_frag_depth) {
1110 frag_depth = DefineOutput(*this, F32[1]); 1170 frag_depth = DefineOutput(*this, F32[1], std::nullopt);
1111 Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth); 1171 Decorate(frag_depth, spv::Decoration::BuiltIn, spv::BuiltIn::FragDepth);
1112 Name(frag_depth, "frag_depth"); 1172 Name(frag_depth, "frag_depth");
1113 } 1173 }
1174 break;
1175 default:
1176 break;
1114 } 1177 }
1115} 1178}
1116 1179