summary | refs | log | tree | commit | diff
path: root/src/shader_recompiler/backend/spirv/emit_context.cpp
diff options
context:
space:
mode:
author: ReinUsesLisp, 2021-04-13 05:32:21 -0300
committer: ameerj, 2021-07-22 21:51:27 -0400
commit: fa75b9b0626c8e118e27207dd1e82e2f415fc0bc (patch)
tree: 29738f645876c19fd561a39b8f9d62799bf92ef9 /src/shader_recompiler/backend/spirv/emit_context.cpp
parent: shader: Fix fixed pipeline point size on geometry shaders (diff)
download: yuzu-fa75b9b0626c8e118e27207dd1e82e2f415fc0bc.tar.gz
yuzu-fa75b9b0626c8e118e27207dd1e82e2f415fc0bc.tar.xz
yuzu-fa75b9b0626c8e118e27207dd1e82e2f415fc0bc.zip
spirv: Rework storage buffers and shader memory
Diffstat (limited to 'src/shader_recompiler/backend/spirv/emit_context.cpp')
-rw-r--r-- src/shader_recompiler/backend/spirv/emit_context.cpp 440
1 files changed, 249 insertions, 191 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 01b77a7d1..df53e58a8 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -15,7 +15,7 @@
15 15
16namespace Shader::Backend::SPIRV { 16namespace Shader::Backend::SPIRV {
17namespace { 17namespace {
18enum class CasFunctionType { 18enum class Operation {
19 Increment, 19 Increment,
20 Decrement, 20 Decrement,
21 FPAdd, 21 FPAdd,
@@ -23,44 +23,11 @@ enum class CasFunctionType {
23 FPMax, 23 FPMax,
24}; 24};
25 25
26Id CasFunction(EmitContext& ctx, CasFunctionType function_type, Id value_type) { 26struct AttrInfo {
27 const Id func_type{ctx.TypeFunction(value_type, value_type, value_type)}; 27 Id pointer;
28 const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)}; 28 Id id;
29 const Id op_a{ctx.OpFunctionParameter(value_type)}; 29 bool needs_cast;
30 const Id op_b{ctx.OpFunctionParameter(value_type)}; 30};
31 ctx.AddLabel();
32 Id result{};
33 switch (function_type) {
34 case CasFunctionType::Increment: {
35 const Id pred{ctx.OpUGreaterThanEqual(ctx.U1, op_a, op_b)};
36 const Id incr{ctx.OpIAdd(value_type, op_a, ctx.Constant(value_type, 1))};
37 result = ctx.OpSelect(value_type, pred, ctx.u32_zero_value, incr);
38 break;
39 }
40 case CasFunctionType::Decrement: {
41 const Id lhs{ctx.OpIEqual(ctx.U1, op_a, ctx.Constant(value_type, 0u))};
42 const Id rhs{ctx.OpUGreaterThan(ctx.U1, op_a, op_b)};
43 const Id pred{ctx.OpLogicalOr(ctx.U1, lhs, rhs)};
44 const Id decr{ctx.OpISub(value_type, op_a, ctx.Constant(value_type, 1))};
45 result = ctx.OpSelect(value_type, pred, op_b, decr);
46 break;
47 }
48 case CasFunctionType::FPAdd:
49 result = ctx.OpFAdd(value_type, op_a, op_b);
50 break;
51 case CasFunctionType::FPMin:
52 result = ctx.OpFMin(value_type, op_a, op_b);
53 break;
54 case CasFunctionType::FPMax:
55 result = ctx.OpFMax(value_type, op_a, op_b);
56 break;
57 default:
58 break;
59 }
60 ctx.OpReturnValue(result);
61 ctx.OpFunctionEnd();
62 return func;
63}
64 31
65Id ImageType(EmitContext& ctx, const TextureDescriptor& desc) { 32Id ImageType(EmitContext& ctx, const TextureDescriptor& desc) {
66 const spv::ImageFormat format{spv::ImageFormat::Unknown}; 33 const spv::ImageFormat format{spv::ImageFormat::Unknown};
@@ -182,12 +149,6 @@ Id GetAttributeType(EmitContext& ctx, AttributeType type) {
182 throw InvalidArgument("Invalid attribute type {}", type); 149 throw InvalidArgument("Invalid attribute type {}", type);
183} 150}
184 151
185struct AttrInfo {
186 Id pointer;
187 Id id;
188 bool needs_cast;
189};
190
191std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) { 152std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
192 const AttributeType type{ctx.profile.generic_input_types.at(index)}; 153 const AttributeType type{ctx.profile.generic_input_types.at(index)};
193 switch (type) { 154 switch (type) {
@@ -203,6 +164,164 @@ std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
203 throw InvalidArgument("Invalid attribute type {}", type); 164 throw InvalidArgument("Invalid attribute type {}", type);
204} 165}
205 166
167void DefineConstBuffers(EmitContext& ctx, const Info& info, Id UniformDefinitions::*member_type,
168 u32 binding, Id type, char type_char, u32 element_size) {
169 const Id array_type{ctx.TypeArray(type, ctx.Constant(ctx.U32[1], 65536U / element_size))};
170 ctx.Decorate(array_type, spv::Decoration::ArrayStride, element_size);
171
172 const Id struct_type{ctx.TypeStruct(array_type)};
173 ctx.Name(struct_type, fmt::format("cbuf_block_{}{}", type_char, element_size * CHAR_BIT));
174 ctx.Decorate(struct_type, spv::Decoration::Block);
175 ctx.MemberName(struct_type, 0, "data");
176 ctx.MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
177
178 const Id struct_pointer_type{ctx.TypePointer(spv::StorageClass::Uniform, struct_type)};
179 const Id uniform_type{ctx.TypePointer(spv::StorageClass::Uniform, type)};
180 ctx.uniform_types.*member_type = uniform_type;
181
182 for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
183 const Id id{ctx.AddGlobalVariable(struct_pointer_type, spv::StorageClass::Uniform)};
184 ctx.Decorate(id, spv::Decoration::Binding, binding);
185 ctx.Decorate(id, spv::Decoration::DescriptorSet, 0U);
186 ctx.Name(id, fmt::format("c{}", desc.index));
187 for (size_t i = 0; i < desc.count; ++i) {
188 ctx.cbufs[desc.index + i].*member_type = id;
189 }
190 if (ctx.profile.supported_spirv >= 0x00010400) {
191 ctx.interfaces.push_back(id);
192 }
193 binding += desc.count;
194 }
195}
196
197void DefineSsbos(EmitContext& ctx, StorageTypeDefinition& type_def,
198 Id StorageDefinitions::*member_type, const Info& info, u32 binding, Id type,
199 u32 stride) {
200 const Id array_type{ctx.TypeRuntimeArray(type)};
201 ctx.Decorate(array_type, spv::Decoration::ArrayStride, stride);
202
203 const Id struct_type{ctx.TypeStruct(array_type)};
204 ctx.Decorate(struct_type, spv::Decoration::Block);
205 ctx.MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
206
207 const Id struct_pointer{ctx.TypePointer(spv::StorageClass::StorageBuffer, struct_type)};
208 type_def.array = struct_pointer;
209 type_def.element = ctx.TypePointer(spv::StorageClass::StorageBuffer, type);
210
211 u32 index{};
212 for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
213 const Id id{ctx.AddGlobalVariable(struct_pointer, spv::StorageClass::StorageBuffer)};
214 ctx.Decorate(id, spv::Decoration::Binding, binding);
215 ctx.Decorate(id, spv::Decoration::DescriptorSet, 0U);
216 ctx.Name(id, fmt::format("ssbo{}", index));
217 if (ctx.profile.supported_spirv >= 0x00010400) {
218 ctx.interfaces.push_back(id);
219 }
220 for (size_t i = 0; i < desc.count; ++i) {
221 ctx.ssbos[index + i].*member_type = id;
222 }
223 index += desc.count;
224 binding += desc.count;
225 }
226}
227
228Id CasFunction(EmitContext& ctx, Operation operation, Id value_type) {
229 const Id func_type{ctx.TypeFunction(value_type, value_type, value_type)};
230 const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
231 const Id op_a{ctx.OpFunctionParameter(value_type)};
232 const Id op_b{ctx.OpFunctionParameter(value_type)};
233 ctx.AddLabel();
234 Id result{};
235 switch (operation) {
236 case Operation::Increment: {
237 const Id pred{ctx.OpUGreaterThanEqual(ctx.U1, op_a, op_b)};
238 const Id incr{ctx.OpIAdd(value_type, op_a, ctx.Constant(value_type, 1))};
239 result = ctx.OpSelect(value_type, pred, ctx.u32_zero_value, incr);
240 break;
241 }
242 case Operation::Decrement: {
243 const Id lhs{ctx.OpIEqual(ctx.U1, op_a, ctx.Constant(value_type, 0u))};
244 const Id rhs{ctx.OpUGreaterThan(ctx.U1, op_a, op_b)};
245 const Id pred{ctx.OpLogicalOr(ctx.U1, lhs, rhs)};
246 const Id decr{ctx.OpISub(value_type, op_a, ctx.Constant(value_type, 1))};
247 result = ctx.OpSelect(value_type, pred, op_b, decr);
248 break;
249 }
250 case Operation::FPAdd:
251 result = ctx.OpFAdd(value_type, op_a, op_b);
252 break;
253 case Operation::FPMin:
254 result = ctx.OpFMin(value_type, op_a, op_b);
255 break;
256 case Operation::FPMax:
257 result = ctx.OpFMax(value_type, op_a, op_b);
258 break;
259 default:
260 break;
261 }
262 ctx.OpReturnValue(result);
263 ctx.OpFunctionEnd();
264 return func;
265}
266
267Id CasLoop(EmitContext& ctx, Operation operation, Id array_pointer, Id element_pointer,
268 Id value_type, Id memory_type, spv::Scope scope) {
269 const bool is_shared{scope == spv::Scope::Workgroup};
270 const bool is_struct{!is_shared || ctx.profile.support_explicit_workgroup_layout};
271 const Id cas_func{CasFunction(ctx, operation, value_type)};
272 const Id zero{ctx.u32_zero_value};
273 const Id scope_id{ctx.Constant(ctx.U32[1], static_cast<u32>(scope))};
274
275 const Id loop_header{ctx.OpLabel()};
276 const Id continue_block{ctx.OpLabel()};
277 const Id merge_block{ctx.OpLabel()};
278 const Id func_type{is_shared
279 ? ctx.TypeFunction(value_type, ctx.U32[1], value_type)
280 : ctx.TypeFunction(value_type, ctx.U32[1], value_type, array_pointer)};
281
282 const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
283 const Id index{ctx.OpFunctionParameter(ctx.U32[1])};
284 const Id op_b{ctx.OpFunctionParameter(value_type)};
285 const Id base{is_shared ? ctx.shared_memory_u32 : ctx.OpFunctionParameter(array_pointer)};
286 ctx.AddLabel();
287 ctx.OpBranch(loop_header);
288 ctx.AddLabel(loop_header);
289
290 ctx.OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
291 ctx.OpBranch(continue_block);
292
293 ctx.AddLabel(continue_block);
294 const Id word_pointer{is_struct ? ctx.OpAccessChain(element_pointer, base, zero, index)
295 : ctx.OpAccessChain(element_pointer, base, index)};
296 if (value_type.value == ctx.F32[2].value) {
297 const Id u32_value{ctx.OpLoad(ctx.U32[1], word_pointer)};
298 const Id value{ctx.OpUnpackHalf2x16(ctx.F32[2], u32_value)};
299 const Id new_value{ctx.OpFunctionCall(value_type, cas_func, value, op_b)};
300 const Id u32_new_value{ctx.OpPackHalf2x16(ctx.U32[1], new_value)};
301 const Id atomic_res{ctx.OpAtomicCompareExchange(ctx.U32[1], word_pointer, scope_id, zero,
302 zero, u32_new_value, u32_value)};
303 const Id success{ctx.OpIEqual(ctx.U1, atomic_res, u32_value)};
304 ctx.OpBranchConditional(success, merge_block, loop_header);
305
306 ctx.AddLabel(merge_block);
307 ctx.OpReturnValue(ctx.OpUnpackHalf2x16(ctx.F32[2], atomic_res));
308 } else {
309 const Id value{ctx.OpLoad(memory_type, word_pointer)};
310 const bool matching_type{value_type.value == memory_type.value};
311 const Id bitcast_value{matching_type ? value : ctx.OpBitcast(value_type, value)};
312 const Id cal_res{ctx.OpFunctionCall(value_type, cas_func, bitcast_value, op_b)};
313 const Id new_value{matching_type ? cal_res : ctx.OpBitcast(memory_type, cal_res)};
314 const Id atomic_res{ctx.OpAtomicCompareExchange(ctx.U32[1], word_pointer, scope_id, zero,
315 zero, new_value, value)};
316 const Id success{ctx.OpIEqual(ctx.U1, atomic_res, value)};
317 ctx.OpBranchConditional(success, merge_block, loop_header);
318
319 ctx.AddLabel(merge_block);
320 ctx.OpReturnValue(ctx.OpBitcast(value_type, atomic_res));
321 }
322 ctx.OpFunctionEnd();
323 return func;
324}
206} // Anonymous namespace 325} // Anonymous namespace
207 326
208void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) { 327void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) {
@@ -226,6 +345,7 @@ EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& bin
226 DefineInterfaces(program.info); 345 DefineInterfaces(program.info);
227 DefineLocalMemory(program); 346 DefineLocalMemory(program);
228 DefineSharedMemory(program); 347 DefineSharedMemory(program);
348 DefineSharedMemoryFunctions(program);
229 DefineConstantBuffers(program.info, binding); 349 DefineConstantBuffers(program.info, binding);
230 DefineStorageBuffers(program.info, binding); 350 DefineStorageBuffers(program.info, binding);
231 DefineTextureBuffers(program.info, binding); 351 DefineTextureBuffers(program.info, binding);
@@ -263,56 +383,6 @@ Id EmitContext::Def(const IR::Value& value) {
263 } 383 }
264} 384}
265 385
266Id EmitContext::CasLoop(Id function, CasPointerType pointer_type, Id value_type) {
267 const Id loop_header{OpLabel()};
268 const Id continue_block{OpLabel()};
269 const Id merge_block{OpLabel()};
270 const Id storage_type{pointer_type == CasPointerType::Shared ? shared_memory_u32_type
271 : storage_memory_u32};
272 const Id func_type{TypeFunction(value_type, U32[1], value_type, storage_type)};
273 const Id func{OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
274 const Id index{OpFunctionParameter(U32[1])};
275 const Id op_b{OpFunctionParameter(value_type)};
276 const Id base{OpFunctionParameter(storage_type)};
277 AddLabel();
278 const Id one{Constant(U32[1], 1)};
279 OpBranch(loop_header);
280 AddLabel(loop_header);
281 OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
282 OpBranch(continue_block);
283
284 AddLabel(continue_block);
285 const Id word_pointer{pointer_type == CasPointerType::Shared
286 ? OpAccessChain(shared_u32, base, index)
287 : OpAccessChain(storage_u32, base, u32_zero_value, index)};
288 if (value_type.value == F32[2].value) {
289 const Id u32_value{OpLoad(U32[1], word_pointer)};
290 const Id value{OpUnpackHalf2x16(F32[2], u32_value)};
291 const Id new_value{OpFunctionCall(value_type, function, value, op_b)};
292 const Id u32_new_value{OpPackHalf2x16(U32[1], new_value)};
293 const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, one, u32_zero_value,
294 u32_zero_value, u32_new_value, u32_value)};
295 const Id success{OpIEqual(U1, atomic_res, u32_value)};
296 OpBranchConditional(success, merge_block, loop_header);
297
298 AddLabel(merge_block);
299 OpReturnValue(OpUnpackHalf2x16(F32[2], atomic_res));
300 } else {
301 const Id value{OpLoad(U32[1], word_pointer)};
302 const Id new_value{OpBitcast(
303 U32[1], OpFunctionCall(value_type, function, OpBitcast(value_type, value), op_b))};
304 const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, one, u32_zero_value,
305 u32_zero_value, new_value, value)};
306 const Id success{OpIEqual(U1, atomic_res, value)};
307 OpBranchConditional(success, merge_block, loop_header);
308
309 AddLabel(merge_block);
310 OpReturnValue(OpBitcast(value_type, atomic_res));
311 }
312 OpFunctionEnd();
313 return func;
314}
315
316void EmitContext::DefineCommonTypes(const Info& info) { 386void EmitContext::DefineCommonTypes(const Info& info) {
317 void_id = TypeVoid(); 387 void_id = TypeVoid();
318 388
@@ -397,27 +467,31 @@ void EmitContext::DefineSharedMemory(const IR::Program& program) {
397 Decorate(variable, spv::Decoration::Aliased); 467 Decorate(variable, spv::Decoration::Aliased);
398 interfaces.push_back(variable); 468 interfaces.push_back(variable);
399 469
400 return std::make_pair(variable, element_pointer); 470 return std::make_tuple(variable, element_pointer, pointer);
401 }}; 471 }};
402 if (profile.support_explicit_workgroup_layout) { 472 if (profile.support_explicit_workgroup_layout) {
403 AddExtension("SPV_KHR_workgroup_memory_explicit_layout"); 473 AddExtension("SPV_KHR_workgroup_memory_explicit_layout");
404 AddCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR); 474 AddCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR);
405 if (program.info.uses_int8) { 475 if (program.info.uses_int8) {
406 AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR); 476 AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
407 std::tie(shared_memory_u8, shared_u8) = make(U8, 1); 477 std::tie(shared_memory_u8, shared_u8, std::ignore) = make(U8, 1);
408 } 478 }
409 if (program.info.uses_int16) { 479 if (program.info.uses_int16) {
410 AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR); 480 AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
411 std::tie(shared_memory_u16, shared_u16) = make(U16, 2); 481 std::tie(shared_memory_u16, shared_u16, std::ignore) = make(U16, 2);
482 }
483 if (program.info.uses_int64) {
484 std::tie(shared_memory_u64, shared_u64, std::ignore) = make(U64, 8);
412 } 485 }
413 std::tie(shared_memory_u32, shared_u32) = make(U32[1], 4); 486 std::tie(shared_memory_u32, shared_u32, shared_memory_u32_type) = make(U32[1], 4);
414 std::tie(shared_memory_u32x2, shared_u32x2) = make(U32[2], 8); 487 std::tie(shared_memory_u32x2, shared_u32x2, std::ignore) = make(U32[2], 8);
415 std::tie(shared_memory_u32x4, shared_u32x4) = make(U32[4], 16); 488 std::tie(shared_memory_u32x4, shared_u32x4, std::ignore) = make(U32[4], 16);
416 return; 489 return;
417 } 490 }
418 const u32 num_elements{Common::DivCeil(program.shared_memory_size, 4U)}; 491 const u32 num_elements{Common::DivCeil(program.shared_memory_size, 4U)};
419 const Id type{TypeArray(U32[1], Constant(U32[1], num_elements))}; 492 const Id type{TypeArray(U32[1], Constant(U32[1], num_elements))};
420 shared_memory_u32_type = TypePointer(spv::StorageClass::Workgroup, type); 493 shared_memory_u32_type = TypePointer(spv::StorageClass::Workgroup, type);
494
421 shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]); 495 shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]);
422 shared_memory_u32 = AddGlobalVariable(shared_memory_u32_type, spv::StorageClass::Workgroup); 496 shared_memory_u32 = AddGlobalVariable(shared_memory_u32_type, spv::StorageClass::Workgroup);
423 interfaces.push_back(shared_memory_u32); 497 interfaces.push_back(shared_memory_u32);
@@ -463,13 +537,16 @@ void EmitContext::DefineSharedMemory(const IR::Program& program) {
463 if (program.info.uses_int16) { 537 if (program.info.uses_int16) {
464 shared_store_u16_func = make_function(16, 16); 538 shared_store_u16_func = make_function(16, 16);
465 } 539 }
540}
541
542void EmitContext::DefineSharedMemoryFunctions(const IR::Program& program) {
466 if (program.info.uses_shared_increment) { 543 if (program.info.uses_shared_increment) {
467 const Id inc_func{CasFunction(*this, CasFunctionType::Increment, U32[1])}; 544 increment_cas_shared = CasLoop(*this, Operation::Increment, shared_memory_u32_type,
468 increment_cas_shared = CasLoop(inc_func, CasPointerType::Shared, U32[1]); 545 shared_u32, U32[1], U32[1], spv::Scope::Workgroup);
469 } 546 }
470 if (program.info.uses_shared_decrement) { 547 if (program.info.uses_shared_decrement) {
471 const Id dec_func{CasFunction(*this, CasFunctionType::Decrement, U32[1])}; 548 decrement_cas_shared = CasLoop(*this, Operation::Decrement, shared_memory_u32_type,
472 decrement_cas_shared = CasLoop(dec_func, CasPointerType::Shared, U32[1]); 549 shared_u32, U32[1], U32[1], spv::Scope::Workgroup);
473 } 550 }
474} 551}
475 552
@@ -628,21 +705,24 @@ void EmitContext::DefineConstantBuffers(const Info& info, u32& binding) {
628 return; 705 return;
629 } 706 }
630 if (True(info.used_constant_buffer_types & IR::Type::U8)) { 707 if (True(info.used_constant_buffer_types & IR::Type::U8)) {
631 DefineConstantBuffers(info, &UniformDefinitions::U8, binding, U8, 'u', sizeof(u8)); 708 DefineConstBuffers(*this, info, &UniformDefinitions::U8, binding, U8, 'u', sizeof(u8));
632 DefineConstantBuffers(info, &UniformDefinitions::S8, binding, S8, 's', sizeof(s8)); 709 DefineConstBuffers(*this, info, &UniformDefinitions::S8, binding, S8, 's', sizeof(s8));
633 } 710 }
634 if (True(info.used_constant_buffer_types & IR::Type::U16)) { 711 if (True(info.used_constant_buffer_types & IR::Type::U16)) {
635 DefineConstantBuffers(info, &UniformDefinitions::U16, binding, U16, 'u', sizeof(u16)); 712 DefineConstBuffers(*this, info, &UniformDefinitions::U16, binding, U16, 'u', sizeof(u16));
636 DefineConstantBuffers(info, &UniformDefinitions::S16, binding, S16, 's', sizeof(s16)); 713 DefineConstBuffers(*this, info, &UniformDefinitions::S16, binding, S16, 's', sizeof(s16));
637 } 714 }
638 if (True(info.used_constant_buffer_types & IR::Type::U32)) { 715 if (True(info.used_constant_buffer_types & IR::Type::U32)) {
639 DefineConstantBuffers(info, &UniformDefinitions::U32, binding, U32[1], 'u', sizeof(u32)); 716 DefineConstBuffers(*this, info, &UniformDefinitions::U32, binding, U32[1], 'u',
717 sizeof(u32));
640 } 718 }
641 if (True(info.used_constant_buffer_types & IR::Type::F32)) { 719 if (True(info.used_constant_buffer_types & IR::Type::F32)) {
642 DefineConstantBuffers(info, &UniformDefinitions::F32, binding, F32[1], 'f', sizeof(f32)); 720 DefineConstBuffers(*this, info, &UniformDefinitions::F32, binding, F32[1], 'f',
721 sizeof(f32));
643 } 722 }
644 if (True(info.used_constant_buffer_types & IR::Type::U32x2)) { 723 if (True(info.used_constant_buffer_types & IR::Type::U32x2)) {
645 DefineConstantBuffers(info, &UniformDefinitions::U32x2, binding, U32[2], 'u', sizeof(u64)); 724 DefineConstBuffers(*this, info, &UniformDefinitions::U32x2, binding, U32[2], 'u',
725 sizeof(u32[2]));
646 } 726 }
647 for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) { 727 for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
648 binding += desc.count; 728 binding += desc.count;
@@ -655,75 +735,83 @@ void EmitContext::DefineStorageBuffers(const Info& info, u32& binding) {
655 } 735 }
656 AddExtension("SPV_KHR_storage_buffer_storage_class"); 736 AddExtension("SPV_KHR_storage_buffer_storage_class");
657 737
658 const Id array_type{TypeRuntimeArray(U32[1])}; 738 if (True(info.used_storage_buffer_types & IR::Type::U8)) {
659 Decorate(array_type, spv::Decoration::ArrayStride, 4U); 739 DefineSsbos(*this, storage_types.U8, &StorageDefinitions::U8, info, binding, U8,
660 740 sizeof(u8));
661 const Id struct_type{TypeStruct(array_type)}; 741 DefineSsbos(*this, storage_types.S8, &StorageDefinitions::S8, info, binding, S8,
662 Name(struct_type, "ssbo_block"); 742 sizeof(u8));
663 Decorate(struct_type, spv::Decoration::Block); 743 }
664 MemberName(struct_type, 0, "data"); 744 if (True(info.used_storage_buffer_types & IR::Type::U16)) {
665 MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U); 745 DefineSsbos(*this, storage_types.U16, &StorageDefinitions::U16, info, binding, U16,
666 746 sizeof(u16));
667 storage_memory_u32 = TypePointer(spv::StorageClass::StorageBuffer, struct_type); 747 DefineSsbos(*this, storage_types.S16, &StorageDefinitions::S16, info, binding, S16,
668 storage_u32 = TypePointer(spv::StorageClass::StorageBuffer, U32[1]); 748 sizeof(u16));
669 749 }
670 u32 index{}; 750 if (True(info.used_storage_buffer_types & IR::Type::U32)) {
751 DefineSsbos(*this, storage_types.U32, &StorageDefinitions::U32, info, binding, U32[1],
752 sizeof(u32));
753 }
754 if (True(info.used_storage_buffer_types & IR::Type::F32)) {
755 DefineSsbos(*this, storage_types.F32, &StorageDefinitions::F32, info, binding, F32[1],
756 sizeof(f32));
757 }
758 if (True(info.used_storage_buffer_types & IR::Type::U64)) {
759 DefineSsbos(*this, storage_types.U64, &StorageDefinitions::U64, info, binding, U64,
760 sizeof(u64));
761 }
762 if (True(info.used_storage_buffer_types & IR::Type::U32x2)) {
763 DefineSsbos(*this, storage_types.U32x2, &StorageDefinitions::U32x2, info, binding, U32[2],
764 sizeof(u32[2]));
765 }
766 if (True(info.used_storage_buffer_types & IR::Type::U32x4)) {
767 DefineSsbos(*this, storage_types.U32x4, &StorageDefinitions::U32x4, info, binding, U32[4],
768 sizeof(u32[4]));
769 }
671 for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) { 770 for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
672 const Id id{AddGlobalVariable(storage_memory_u32, spv::StorageClass::StorageBuffer)};
673 Decorate(id, spv::Decoration::Binding, binding);
674 Decorate(id, spv::Decoration::DescriptorSet, 0U);
675 Name(id, fmt::format("ssbo{}", index));
676 if (profile.supported_spirv >= 0x00010400) {
677 interfaces.push_back(id);
678 }
679 std::fill_n(ssbos.data() + index, desc.count, id);
680 index += desc.count;
681 binding += desc.count; 771 binding += desc.count;
682 } 772 }
683 if (info.uses_global_increment) { 773 const bool needs_function{
774 info.uses_global_increment || info.uses_global_decrement || info.uses_atomic_f32_add ||
775 info.uses_atomic_f16x2_add || info.uses_atomic_f16x2_min || info.uses_atomic_f16x2_max ||
776 info.uses_atomic_f32x2_add || info.uses_atomic_f32x2_min || info.uses_atomic_f32x2_max};
777 if (needs_function) {
684 AddCapability(spv::Capability::VariablePointersStorageBuffer); 778 AddCapability(spv::Capability::VariablePointersStorageBuffer);
685 const Id inc_func{CasFunction(*this, CasFunctionType::Increment, U32[1])}; 779 }
686 increment_cas_ssbo = CasLoop(inc_func, CasPointerType::Ssbo, U32[1]); 780 if (info.uses_global_increment) {
781 increment_cas_ssbo = CasLoop(*this, Operation::Increment, storage_types.U32.array,
782 storage_types.U32.element, U32[1], U32[1], spv::Scope::Device);
687 } 783 }
688 if (info.uses_global_decrement) { 784 if (info.uses_global_decrement) {
689 AddCapability(spv::Capability::VariablePointersStorageBuffer); 785 decrement_cas_ssbo = CasLoop(*this, Operation::Decrement, storage_types.U32.array,
690 const Id dec_func{CasFunction(*this, CasFunctionType::Decrement, U32[1])}; 786 storage_types.U32.element, U32[1], U32[1], spv::Scope::Device);
691 decrement_cas_ssbo = CasLoop(dec_func, CasPointerType::Ssbo, U32[1]);
692 } 787 }
693 if (info.uses_atomic_f32_add) { 788 if (info.uses_atomic_f32_add) {
694 AddCapability(spv::Capability::VariablePointersStorageBuffer); 789 f32_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
695 const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F32[1])}; 790 storage_types.U32.element, F32[1], U32[1], spv::Scope::Device);
696 f32_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F32[1]);
697 } 791 }
698 if (info.uses_atomic_f16x2_add) { 792 if (info.uses_atomic_f16x2_add) {
699 AddCapability(spv::Capability::VariablePointersStorageBuffer); 793 f16x2_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
700 const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F16[2])}; 794 storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
701 f16x2_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F16[2]);
702 } 795 }
703 if (info.uses_atomic_f16x2_min) { 796 if (info.uses_atomic_f16x2_min) {
704 AddCapability(spv::Capability::VariablePointersStorageBuffer); 797 f16x2_min_cas = CasLoop(*this, Operation::FPMin, storage_types.U32.array,
705 const Id func{CasFunction(*this, CasFunctionType::FPMin, F16[2])}; 798 storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
706 f16x2_min_cas = CasLoop(func, CasPointerType::Ssbo, F16[2]);
707 } 799 }
708 if (info.uses_atomic_f16x2_max) { 800 if (info.uses_atomic_f16x2_max) {
709 AddCapability(spv::Capability::VariablePointersStorageBuffer); 801 f16x2_max_cas = CasLoop(*this, Operation::FPMax, storage_types.U32.array,
710 const Id func{CasFunction(*this, CasFunctionType::FPMax, F16[2])}; 802 storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
711 f16x2_max_cas = CasLoop(func, CasPointerType::Ssbo, F16[2]);
712 } 803 }
713 if (info.uses_atomic_f32x2_add) { 804 if (info.uses_atomic_f32x2_add) {
714 AddCapability(spv::Capability::VariablePointersStorageBuffer); 805 f32x2_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
715 const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F32[2])}; 806 storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
716 f32x2_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F32[2]);
717 } 807 }
718 if (info.uses_atomic_f32x2_min) { 808 if (info.uses_atomic_f32x2_min) {
719 AddCapability(spv::Capability::VariablePointersStorageBuffer); 809 f32x2_min_cas = CasLoop(*this, Operation::FPMin, storage_types.U32.array,
720 const Id func{CasFunction(*this, CasFunctionType::FPMin, F32[2])}; 810 storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
721 f32x2_min_cas = CasLoop(func, CasPointerType::Ssbo, F32[2]);
722 } 811 }
723 if (info.uses_atomic_f32x2_max) { 812 if (info.uses_atomic_f32x2_max) {
724 AddCapability(spv::Capability::VariablePointersStorageBuffer); 813 f32x2_max_cas = CasLoop(*this, Operation::FPMax, storage_types.U32.array,
725 const Id func{CasFunction(*this, CasFunctionType::FPMax, F32[2])}; 814 storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
726 f32x2_max_cas = CasLoop(func, CasPointerType::Ssbo, F32[2]);
727 } 815 }
728} 816}
729 817
@@ -903,36 +991,6 @@ void EmitContext::DefineInputs(const Info& info) {
903 } 991 }
904} 992}
905 993
906void EmitContext::DefineConstantBuffers(const Info& info, Id UniformDefinitions::*member_type,
907 u32 binding, Id type, char type_char, u32 element_size) {
908 const Id array_type{TypeArray(type, Constant(U32[1], 65536U / element_size))};
909 Decorate(array_type, spv::Decoration::ArrayStride, element_size);
910
911 const Id struct_type{TypeStruct(array_type)};
912 Name(struct_type, fmt::format("cbuf_block_{}{}", type_char, element_size * CHAR_BIT));
913 Decorate(struct_type, spv::Decoration::Block);
914 MemberName(struct_type, 0, "data");
915 MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
916
917 const Id struct_pointer_type{TypePointer(spv::StorageClass::Uniform, struct_type)};
918 const Id uniform_type{TypePointer(spv::StorageClass::Uniform, type)};
919 uniform_types.*member_type = uniform_type;
920
921 for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
922 const Id id{AddGlobalVariable(struct_pointer_type, spv::StorageClass::Uniform)};
923 Decorate(id, spv::Decoration::Binding, binding);
924 Decorate(id, spv::Decoration::DescriptorSet, 0U);
925 Name(id, fmt::format("c{}", desc.index));
926 for (size_t i = 0; i < desc.count; ++i) {
927 cbufs[desc.index + i].*member_type = id;
928 }
929 if (profile.supported_spirv >= 0x00010400) {
930 interfaces.push_back(id);
931 }
932 binding += desc.count;
933 }
934}
935
936void EmitContext::DefineOutputs(const Info& info) { 994void EmitContext::DefineOutputs(const Info& info) {
937 if (info.stores_position || stage == Stage::VertexB) { 995 if (info.stores_position || stage == Stage::VertexB) {
938 output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position); 996 output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position);