-rw-r--r--  src/shader_recompiler/backend/spirv/emit_context.cpp        440
-rw-r--r--  src/shader_recompiler/backend/spirv/emit_context.h           49
-rw-r--r--  src/shader_recompiler/backend/spirv/emit_spirv.cpp            2
-rw-r--r--  src/shader_recompiler/backend/spirv/emit_spirv.h             20
-rw-r--r--  src/shader_recompiler/backend/spirv/emit_spirv_atomic.cpp   333
-rw-r--r--  src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp   135
-rw-r--r--  src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp    69
-rw-r--r--  src/shader_recompiler/shader_info.h                            4
-rw-r--r--  src/video_core/vulkan_common/vulkan_device.cpp               29
9 files changed, 581 insertions, 500 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
index 01b77a7d1..df53e58a8 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -15,7 +15,7 @@
 
 namespace Shader::Backend::SPIRV {
 namespace {
-enum class CasFunctionType {
+enum class Operation {
     Increment,
     Decrement,
     FPAdd,
@@ -23,44 +23,11 @@ enum class CasFunctionType {
     FPMax,
 };
 
-Id CasFunction(EmitContext& ctx, CasFunctionType function_type, Id value_type) {
-    const Id func_type{ctx.TypeFunction(value_type, value_type, value_type)};
-    const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
-    const Id op_a{ctx.OpFunctionParameter(value_type)};
-    const Id op_b{ctx.OpFunctionParameter(value_type)};
-    ctx.AddLabel();
-    Id result{};
-    switch (function_type) {
-    case CasFunctionType::Increment: {
-        const Id pred{ctx.OpUGreaterThanEqual(ctx.U1, op_a, op_b)};
-        const Id incr{ctx.OpIAdd(value_type, op_a, ctx.Constant(value_type, 1))};
-        result = ctx.OpSelect(value_type, pred, ctx.u32_zero_value, incr);
-        break;
-    }
-    case CasFunctionType::Decrement: {
-        const Id lhs{ctx.OpIEqual(ctx.U1, op_a, ctx.Constant(value_type, 0u))};
-        const Id rhs{ctx.OpUGreaterThan(ctx.U1, op_a, op_b)};
-        const Id pred{ctx.OpLogicalOr(ctx.U1, lhs, rhs)};
-        const Id decr{ctx.OpISub(value_type, op_a, ctx.Constant(value_type, 1))};
-        result = ctx.OpSelect(value_type, pred, op_b, decr);
-        break;
-    }
-    case CasFunctionType::FPAdd:
-        result = ctx.OpFAdd(value_type, op_a, op_b);
-        break;
-    case CasFunctionType::FPMin:
-        result = ctx.OpFMin(value_type, op_a, op_b);
-        break;
-    case CasFunctionType::FPMax:
-        result = ctx.OpFMax(value_type, op_a, op_b);
-        break;
-    default:
-        break;
-    }
-    ctx.OpReturnValue(result);
-    ctx.OpFunctionEnd();
-    return func;
-}
+struct AttrInfo {
+    Id pointer;
+    Id id;
+    bool needs_cast;
+};
 
 Id ImageType(EmitContext& ctx, const TextureDescriptor& desc) {
     const spv::ImageFormat format{spv::ImageFormat::Unknown};
@@ -182,12 +149,6 @@ Id GetAttributeType(EmitContext& ctx, AttributeType type) {
     throw InvalidArgument("Invalid attribute type {}", type);
 }
 
-struct AttrInfo {
-    Id pointer;
-    Id id;
-    bool needs_cast;
-};
-
 std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
     const AttributeType type{ctx.profile.generic_input_types.at(index)};
     switch (type) {
@@ -203,6 +164,164 @@ std::optional<AttrInfo> AttrTypes(EmitContext& ctx, u32 index) {
     throw InvalidArgument("Invalid attribute type {}", type);
 }
 
+void DefineConstBuffers(EmitContext& ctx, const Info& info, Id UniformDefinitions::*member_type,
+                        u32 binding, Id type, char type_char, u32 element_size) {
+    const Id array_type{ctx.TypeArray(type, ctx.Constant(ctx.U32[1], 65536U / element_size))};
+    ctx.Decorate(array_type, spv::Decoration::ArrayStride, element_size);
+
+    const Id struct_type{ctx.TypeStruct(array_type)};
+    ctx.Name(struct_type, fmt::format("cbuf_block_{}{}", type_char, element_size * CHAR_BIT));
+    ctx.Decorate(struct_type, spv::Decoration::Block);
+    ctx.MemberName(struct_type, 0, "data");
+    ctx.MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
+
+    const Id struct_pointer_type{ctx.TypePointer(spv::StorageClass::Uniform, struct_type)};
+    const Id uniform_type{ctx.TypePointer(spv::StorageClass::Uniform, type)};
+    ctx.uniform_types.*member_type = uniform_type;
+
+    for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
+        const Id id{ctx.AddGlobalVariable(struct_pointer_type, spv::StorageClass::Uniform)};
+        ctx.Decorate(id, spv::Decoration::Binding, binding);
+        ctx.Decorate(id, spv::Decoration::DescriptorSet, 0U);
+        ctx.Name(id, fmt::format("c{}", desc.index));
+        for (size_t i = 0; i < desc.count; ++i) {
+            ctx.cbufs[desc.index + i].*member_type = id;
+        }
+        if (ctx.profile.supported_spirv >= 0x00010400) {
+            ctx.interfaces.push_back(id);
+        }
+        binding += desc.count;
+    }
+}
+
+void DefineSsbos(EmitContext& ctx, StorageTypeDefinition& type_def,
+                 Id StorageDefinitions::*member_type, const Info& info, u32 binding, Id type,
+                 u32 stride) {
+    const Id array_type{ctx.TypeRuntimeArray(type)};
+    ctx.Decorate(array_type, spv::Decoration::ArrayStride, stride);
+
+    const Id struct_type{ctx.TypeStruct(array_type)};
+    ctx.Decorate(struct_type, spv::Decoration::Block);
+    ctx.MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
+
+    const Id struct_pointer{ctx.TypePointer(spv::StorageClass::StorageBuffer, struct_type)};
+    type_def.array = struct_pointer;
+    type_def.element = ctx.TypePointer(spv::StorageClass::StorageBuffer, type);
+
+    u32 index{};
+    for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
+        const Id id{ctx.AddGlobalVariable(struct_pointer, spv::StorageClass::StorageBuffer)};
+        ctx.Decorate(id, spv::Decoration::Binding, binding);
+        ctx.Decorate(id, spv::Decoration::DescriptorSet, 0U);
+        ctx.Name(id, fmt::format("ssbo{}", index));
+        if (ctx.profile.supported_spirv >= 0x00010400) {
+            ctx.interfaces.push_back(id);
+        }
+        for (size_t i = 0; i < desc.count; ++i) {
+            ctx.ssbos[index + i].*member_type = id;
+        }
+        index += desc.count;
+        binding += desc.count;
+    }
+}
+
+Id CasFunction(EmitContext& ctx, Operation operation, Id value_type) {
+    const Id func_type{ctx.TypeFunction(value_type, value_type, value_type)};
+    const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
+    const Id op_a{ctx.OpFunctionParameter(value_type)};
+    const Id op_b{ctx.OpFunctionParameter(value_type)};
+    ctx.AddLabel();
+    Id result{};
+    switch (operation) {
+    case Operation::Increment: {
+        const Id pred{ctx.OpUGreaterThanEqual(ctx.U1, op_a, op_b)};
+        const Id incr{ctx.OpIAdd(value_type, op_a, ctx.Constant(value_type, 1))};
+        result = ctx.OpSelect(value_type, pred, ctx.u32_zero_value, incr);
+        break;
+    }
+    case Operation::Decrement: {
+        const Id lhs{ctx.OpIEqual(ctx.U1, op_a, ctx.Constant(value_type, 0u))};
+        const Id rhs{ctx.OpUGreaterThan(ctx.U1, op_a, op_b)};
+        const Id pred{ctx.OpLogicalOr(ctx.U1, lhs, rhs)};
+        const Id decr{ctx.OpISub(value_type, op_a, ctx.Constant(value_type, 1))};
+        result = ctx.OpSelect(value_type, pred, op_b, decr);
+        break;
+    }
+    case Operation::FPAdd:
+        result = ctx.OpFAdd(value_type, op_a, op_b);
+        break;
+    case Operation::FPMin:
+        result = ctx.OpFMin(value_type, op_a, op_b);
+        break;
+    case Operation::FPMax:
+        result = ctx.OpFMax(value_type, op_a, op_b);
+        break;
+    default:
+        break;
+    }
+    ctx.OpReturnValue(result);
+    ctx.OpFunctionEnd();
+    return func;
+}
+
+Id CasLoop(EmitContext& ctx, Operation operation, Id array_pointer, Id element_pointer,
+           Id value_type, Id memory_type, spv::Scope scope) {
+    const bool is_shared{scope == spv::Scope::Workgroup};
+    const bool is_struct{!is_shared || ctx.profile.support_explicit_workgroup_layout};
+    const Id cas_func{CasFunction(ctx, operation, value_type)};
+    const Id zero{ctx.u32_zero_value};
+    const Id scope_id{ctx.Constant(ctx.U32[1], static_cast<u32>(scope))};
+
+    const Id loop_header{ctx.OpLabel()};
+    const Id continue_block{ctx.OpLabel()};
+    const Id merge_block{ctx.OpLabel()};
+    const Id func_type{is_shared
+                           ? ctx.TypeFunction(value_type, ctx.U32[1], value_type)
+                           : ctx.TypeFunction(value_type, ctx.U32[1], value_type, array_pointer)};
+
+    const Id func{ctx.OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
+    const Id index{ctx.OpFunctionParameter(ctx.U32[1])};
+    const Id op_b{ctx.OpFunctionParameter(value_type)};
+    const Id base{is_shared ? ctx.shared_memory_u32 : ctx.OpFunctionParameter(array_pointer)};
+    ctx.AddLabel();
+    ctx.OpBranch(loop_header);
+    ctx.AddLabel(loop_header);
+
+    ctx.OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
+    ctx.OpBranch(continue_block);
+
+    ctx.AddLabel(continue_block);
+    const Id word_pointer{is_struct ? ctx.OpAccessChain(element_pointer, base, zero, index)
+                                    : ctx.OpAccessChain(element_pointer, base, index)};
+    if (value_type.value == ctx.F32[2].value) {
+        const Id u32_value{ctx.OpLoad(ctx.U32[1], word_pointer)};
+        const Id value{ctx.OpUnpackHalf2x16(ctx.F32[2], u32_value)};
+        const Id new_value{ctx.OpFunctionCall(value_type, cas_func, value, op_b)};
+        const Id u32_new_value{ctx.OpPackHalf2x16(ctx.U32[1], new_value)};
+        const Id atomic_res{ctx.OpAtomicCompareExchange(ctx.U32[1], word_pointer, scope_id, zero,
+                                                        zero, u32_new_value, u32_value)};
+        const Id success{ctx.OpIEqual(ctx.U1, atomic_res, u32_value)};
+        ctx.OpBranchConditional(success, merge_block, loop_header);
+
+        ctx.AddLabel(merge_block);
+        ctx.OpReturnValue(ctx.OpUnpackHalf2x16(ctx.F32[2], atomic_res));
+    } else {
+        const Id value{ctx.OpLoad(memory_type, word_pointer)};
+        const bool matching_type{value_type.value == memory_type.value};
+        const Id bitcast_value{matching_type ? value : ctx.OpBitcast(value_type, value)};
+        const Id cal_res{ctx.OpFunctionCall(value_type, cas_func, bitcast_value, op_b)};
+        const Id new_value{matching_type ? cal_res : ctx.OpBitcast(memory_type, cal_res)};
+        const Id atomic_res{ctx.OpAtomicCompareExchange(ctx.U32[1], word_pointer, scope_id, zero,
+                                                        zero, new_value, value)};
+        const Id success{ctx.OpIEqual(ctx.U1, atomic_res, value)};
+        ctx.OpBranchConditional(success, merge_block, loop_header);
+
+        ctx.AddLabel(merge_block);
+        ctx.OpReturnValue(ctx.OpBitcast(value_type, atomic_res));
+    }
+    ctx.OpFunctionEnd();
+    return func;
+}
 } // Anonymous namespace
 
 void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) {
@@ -226,6 +345,7 @@ EmitContext::EmitContext(const Profile& profile_, IR::Program& program, u32& bin
     DefineInterfaces(program.info);
     DefineLocalMemory(program);
     DefineSharedMemory(program);
+    DefineSharedMemoryFunctions(program);
     DefineConstantBuffers(program.info, binding);
     DefineStorageBuffers(program.info, binding);
     DefineTextureBuffers(program.info, binding);
@@ -263,56 +383,6 @@ Id EmitContext::Def(const IR::Value& value) {
     }
 }
 
-Id EmitContext::CasLoop(Id function, CasPointerType pointer_type, Id value_type) {
-    const Id loop_header{OpLabel()};
-    const Id continue_block{OpLabel()};
-    const Id merge_block{OpLabel()};
-    const Id storage_type{pointer_type == CasPointerType::Shared ? shared_memory_u32_type
-                                                                 : storage_memory_u32};
-    const Id func_type{TypeFunction(value_type, U32[1], value_type, storage_type)};
-    const Id func{OpFunction(value_type, spv::FunctionControlMask::MaskNone, func_type)};
-    const Id index{OpFunctionParameter(U32[1])};
-    const Id op_b{OpFunctionParameter(value_type)};
-    const Id base{OpFunctionParameter(storage_type)};
-    AddLabel();
-    const Id one{Constant(U32[1], 1)};
-    OpBranch(loop_header);
-    AddLabel(loop_header);
-    OpLoopMerge(merge_block, continue_block, spv::LoopControlMask::MaskNone);
-    OpBranch(continue_block);
-
-    AddLabel(continue_block);
-    const Id word_pointer{pointer_type == CasPointerType::Shared
-                              ? OpAccessChain(shared_u32, base, index)
-                              : OpAccessChain(storage_u32, base, u32_zero_value, index)};
-    if (value_type.value == F32[2].value) {
-        const Id u32_value{OpLoad(U32[1], word_pointer)};
-        const Id value{OpUnpackHalf2x16(F32[2], u32_value)};
-        const Id new_value{OpFunctionCall(value_type, function, value, op_b)};
-        const Id u32_new_value{OpPackHalf2x16(U32[1], new_value)};
-        const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, one, u32_zero_value,
-                                                    u32_zero_value, u32_new_value, u32_value)};
-        const Id success{OpIEqual(U1, atomic_res, u32_value)};
-        OpBranchConditional(success, merge_block, loop_header);
-
-        AddLabel(merge_block);
-        OpReturnValue(OpUnpackHalf2x16(F32[2], atomic_res));
-    } else {
-        const Id value{OpLoad(U32[1], word_pointer)};
-        const Id new_value{OpBitcast(
-            U32[1], OpFunctionCall(value_type, function, OpBitcast(value_type, value), op_b))};
-        const Id atomic_res{OpAtomicCompareExchange(U32[1], word_pointer, one, u32_zero_value,
-                                                    u32_zero_value, new_value, value)};
-        const Id success{OpIEqual(U1, atomic_res, value)};
-        OpBranchConditional(success, merge_block, loop_header);
-
-        AddLabel(merge_block);
-        OpReturnValue(OpBitcast(value_type, atomic_res));
-    }
-    OpFunctionEnd();
-    return func;
-}
-
 void EmitContext::DefineCommonTypes(const Info& info) {
     void_id = TypeVoid();
 
@@ -397,27 +467,31 @@ void EmitContext::DefineSharedMemory(const IR::Program& program) {
         Decorate(variable, spv::Decoration::Aliased);
         interfaces.push_back(variable);
 
-        return std::make_pair(variable, element_pointer);
+        return std::make_tuple(variable, element_pointer, pointer);
     }};
     if (profile.support_explicit_workgroup_layout) {
        AddExtension("SPV_KHR_workgroup_memory_explicit_layout");
         AddCapability(spv::Capability::WorkgroupMemoryExplicitLayoutKHR);
         if (program.info.uses_int8) {
             AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
-            std::tie(shared_memory_u8, shared_u8) = make(U8, 1);
+            std::tie(shared_memory_u8, shared_u8, std::ignore) = make(U8, 1);
         }
         if (program.info.uses_int16) {
             AddCapability(spv::Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
-            std::tie(shared_memory_u16, shared_u16) = make(U16, 2);
+            std::tie(shared_memory_u16, shared_u16, std::ignore) = make(U16, 2);
+        }
+        if (program.info.uses_int64) {
+            std::tie(shared_memory_u64, shared_u64, std::ignore) = make(U64, 8);
         }
-        std::tie(shared_memory_u32, shared_u32) = make(U32[1], 4);
-        std::tie(shared_memory_u32x2, shared_u32x2) = make(U32[2], 8);
-        std::tie(shared_memory_u32x4, shared_u32x4) = make(U32[4], 16);
+        std::tie(shared_memory_u32, shared_u32, shared_memory_u32_type) = make(U32[1], 4);
+        std::tie(shared_memory_u32x2, shared_u32x2, std::ignore) = make(U32[2], 8);
+        std::tie(shared_memory_u32x4, shared_u32x4, std::ignore) = make(U32[4], 16);
         return;
     }
     const u32 num_elements{Common::DivCeil(program.shared_memory_size, 4U)};
     const Id type{TypeArray(U32[1], Constant(U32[1], num_elements))};
     shared_memory_u32_type = TypePointer(spv::StorageClass::Workgroup, type);
+
     shared_u32 = TypePointer(spv::StorageClass::Workgroup, U32[1]);
     shared_memory_u32 = AddGlobalVariable(shared_memory_u32_type, spv::StorageClass::Workgroup);
     interfaces.push_back(shared_memory_u32);
@@ -463,13 +537,16 @@ void EmitContext::DefineSharedMemory(const IR::Program& program) {
     if (program.info.uses_int16) {
         shared_store_u16_func = make_function(16, 16);
     }
+}
+
+void EmitContext::DefineSharedMemoryFunctions(const IR::Program& program) {
     if (program.info.uses_shared_increment) {
-        const Id inc_func{CasFunction(*this, CasFunctionType::Increment, U32[1])};
-        increment_cas_shared = CasLoop(inc_func, CasPointerType::Shared, U32[1]);
+        increment_cas_shared = CasLoop(*this, Operation::Increment, shared_memory_u32_type,
+                                       shared_u32, U32[1], U32[1], spv::Scope::Workgroup);
     }
     if (program.info.uses_shared_decrement) {
-        const Id dec_func{CasFunction(*this, CasFunctionType::Decrement, U32[1])};
-        decrement_cas_shared = CasLoop(dec_func, CasPointerType::Shared, U32[1]);
+        decrement_cas_shared = CasLoop(*this, Operation::Decrement, shared_memory_u32_type,
+                                       shared_u32, U32[1], U32[1], spv::Scope::Workgroup);
     }
 }
 
@@ -628,21 +705,24 @@ void EmitContext::DefineConstantBuffers(const Info& info, u32& binding) {
         return;
     }
     if (True(info.used_constant_buffer_types & IR::Type::U8)) {
-        DefineConstantBuffers(info, &UniformDefinitions::U8, binding, U8, 'u', sizeof(u8));
-        DefineConstantBuffers(info, &UniformDefinitions::S8, binding, S8, 's', sizeof(s8));
+        DefineConstBuffers(*this, info, &UniformDefinitions::U8, binding, U8, 'u', sizeof(u8));
+        DefineConstBuffers(*this, info, &UniformDefinitions::S8, binding, S8, 's', sizeof(s8));
     }
     if (True(info.used_constant_buffer_types & IR::Type::U16)) {
-        DefineConstantBuffers(info, &UniformDefinitions::U16, binding, U16, 'u', sizeof(u16));
-        DefineConstantBuffers(info, &UniformDefinitions::S16, binding, S16, 's', sizeof(s16));
+        DefineConstBuffers(*this, info, &UniformDefinitions::U16, binding, U16, 'u', sizeof(u16));
+        DefineConstBuffers(*this, info, &UniformDefinitions::S16, binding, S16, 's', sizeof(s16));
     }
     if (True(info.used_constant_buffer_types & IR::Type::U32)) {
-        DefineConstantBuffers(info, &UniformDefinitions::U32, binding, U32[1], 'u', sizeof(u32));
+        DefineConstBuffers(*this, info, &UniformDefinitions::U32, binding, U32[1], 'u',
+                           sizeof(u32));
     }
     if (True(info.used_constant_buffer_types & IR::Type::F32)) {
-        DefineConstantBuffers(info, &UniformDefinitions::F32, binding, F32[1], 'f', sizeof(f32));
+        DefineConstBuffers(*this, info, &UniformDefinitions::F32, binding, F32[1], 'f',
+                           sizeof(f32));
     }
     if (True(info.used_constant_buffer_types & IR::Type::U32x2)) {
-        DefineConstantBuffers(info, &UniformDefinitions::U32x2, binding, U32[2], 'u', sizeof(u64));
+        DefineConstBuffers(*this, info, &UniformDefinitions::U32x2, binding, U32[2], 'u',
+                           sizeof(u32[2]));
     }
     for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
         binding += desc.count;
@@ -655,75 +735,83 @@ void EmitContext::DefineStorageBuffers(const Info& info, u32& binding) {
     }
     AddExtension("SPV_KHR_storage_buffer_storage_class");
 
-    const Id array_type{TypeRuntimeArray(U32[1])};
-    Decorate(array_type, spv::Decoration::ArrayStride, 4U);
-
-    const Id struct_type{TypeStruct(array_type)};
-    Name(struct_type, "ssbo_block");
-    Decorate(struct_type, spv::Decoration::Block);
-    MemberName(struct_type, 0, "data");
-    MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
-
-    storage_memory_u32 = TypePointer(spv::StorageClass::StorageBuffer, struct_type);
-    storage_u32 = TypePointer(spv::StorageClass::StorageBuffer, U32[1]);
-
-    u32 index{};
+    if (True(info.used_storage_buffer_types & IR::Type::U8)) {
+        DefineSsbos(*this, storage_types.U8, &StorageDefinitions::U8, info, binding, U8,
+                    sizeof(u8));
+        DefineSsbos(*this, storage_types.S8, &StorageDefinitions::S8, info, binding, S8,
+                    sizeof(u8));
+    }
+    if (True(info.used_storage_buffer_types & IR::Type::U16)) {
+        DefineSsbos(*this, storage_types.U16, &StorageDefinitions::U16, info, binding, U16,
+                    sizeof(u16));
+        DefineSsbos(*this, storage_types.S16, &StorageDefinitions::S16, info, binding, S16,
+                    sizeof(u16));
+    }
+    if (True(info.used_storage_buffer_types & IR::Type::U32)) {
+        DefineSsbos(*this, storage_types.U32, &StorageDefinitions::U32, info, binding, U32[1],
+                    sizeof(u32));
+    }
+    if (True(info.used_storage_buffer_types & IR::Type::F32)) {
+        DefineSsbos(*this, storage_types.F32, &StorageDefinitions::F32, info, binding, F32[1],
+                    sizeof(f32));
+    }
+    if (True(info.used_storage_buffer_types & IR::Type::U64)) {
+        DefineSsbos(*this, storage_types.U64, &StorageDefinitions::U64, info, binding, U64,
+                    sizeof(u64));
+    }
+    if (True(info.used_storage_buffer_types & IR::Type::U32x2)) {
+        DefineSsbos(*this, storage_types.U32x2, &StorageDefinitions::U32x2, info, binding, U32[2],
+                    sizeof(u32[2]));
+    }
+    if (True(info.used_storage_buffer_types & IR::Type::U32x4)) {
+        DefineSsbos(*this, storage_types.U32x4, &StorageDefinitions::U32x4, info, binding, U32[4],
+                    sizeof(u32[4]));
+    }
     for (const StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
-        const Id id{AddGlobalVariable(storage_memory_u32, spv::StorageClass::StorageBuffer)};
-        Decorate(id, spv::Decoration::Binding, binding);
-        Decorate(id, spv::Decoration::DescriptorSet, 0U);
-        Name(id, fmt::format("ssbo{}", index));
-        if (profile.supported_spirv >= 0x00010400) {
-            interfaces.push_back(id);
-        }
-        std::fill_n(ssbos.data() + index, desc.count, id);
-        index += desc.count;
         binding += desc.count;
     }
-    if (info.uses_global_increment) {
+    const bool needs_function{
+        info.uses_global_increment || info.uses_global_decrement || info.uses_atomic_f32_add ||
+        info.uses_atomic_f16x2_add || info.uses_atomic_f16x2_min || info.uses_atomic_f16x2_max ||
+        info.uses_atomic_f32x2_add || info.uses_atomic_f32x2_min || info.uses_atomic_f32x2_max};
+    if (needs_function) {
         AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id inc_func{CasFunction(*this, CasFunctionType::Increment, U32[1])};
-        increment_cas_ssbo = CasLoop(inc_func, CasPointerType::Ssbo, U32[1]);
+    }
+    if (info.uses_global_increment) {
+        increment_cas_ssbo = CasLoop(*this, Operation::Increment, storage_types.U32.array,
+                                     storage_types.U32.element, U32[1], U32[1], spv::Scope::Device);
     }
     if (info.uses_global_decrement) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id dec_func{CasFunction(*this, CasFunctionType::Decrement, U32[1])};
-        decrement_cas_ssbo = CasLoop(dec_func, CasPointerType::Ssbo, U32[1]);
+        decrement_cas_ssbo = CasLoop(*this, Operation::Decrement, storage_types.U32.array,
+                                     storage_types.U32.element, U32[1], U32[1], spv::Scope::Device);
     }
     if (info.uses_atomic_f32_add) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F32[1])};
-        f32_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F32[1]);
+        f32_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
+                              storage_types.U32.element, F32[1], U32[1], spv::Scope::Device);
     }
     if (info.uses_atomic_f16x2_add) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F16[2])};
-        f16x2_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F16[2]);
+        f16x2_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
+                                storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
    }
     if (info.uses_atomic_f16x2_min) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id func{CasFunction(*this, CasFunctionType::FPMin, F16[2])};
-        f16x2_min_cas = CasLoop(func, CasPointerType::Ssbo, F16[2]);
+        f16x2_min_cas = CasLoop(*this, Operation::FPMin, storage_types.U32.array,
+                                storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
     }
     if (info.uses_atomic_f16x2_max) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id func{CasFunction(*this, CasFunctionType::FPMax, F16[2])};
-        f16x2_max_cas = CasLoop(func, CasPointerType::Ssbo, F16[2]);
+        f16x2_max_cas = CasLoop(*this, Operation::FPMax, storage_types.U32.array,
+                                storage_types.U32.element, F16[2], F16[2], spv::Scope::Device);
     }
     if (info.uses_atomic_f32x2_add) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id add_func{CasFunction(*this, CasFunctionType::FPAdd, F32[2])};
-        f32x2_add_cas = CasLoop(add_func, CasPointerType::Ssbo, F32[2]);
+        f32x2_add_cas = CasLoop(*this, Operation::FPAdd, storage_types.U32.array,
+                                storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
     }
     if (info.uses_atomic_f32x2_min) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id func{CasFunction(*this, CasFunctionType::FPMin, F32[2])};
-        f32x2_min_cas = CasLoop(func, CasPointerType::Ssbo, F32[2]);
+        f32x2_min_cas = CasLoop(*this, Operation::FPMin, storage_types.U32.array,
+                                storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
     }
     if (info.uses_atomic_f32x2_max) {
-        AddCapability(spv::Capability::VariablePointersStorageBuffer);
-        const Id func{CasFunction(*this, CasFunctionType::FPMax, F32[2])};
-        f32x2_max_cas = CasLoop(func, CasPointerType::Ssbo, F32[2]);
+        f32x2_max_cas = CasLoop(*this, Operation::FPMax, storage_types.U32.array,
+                                storage_types.U32.element, F32[2], F32[2], spv::Scope::Device);
     }
 }
 
@@ -903,36 +991,6 @@ void EmitContext::DefineInputs(const Info& info) {
     }
 }
 
-void EmitContext::DefineConstantBuffers(const Info& info, Id UniformDefinitions::*member_type,
-                                        u32 binding, Id type, char type_char, u32 element_size) {
-    const Id array_type{TypeArray(type, Constant(U32[1], 65536U / element_size))};
-    Decorate(array_type, spv::Decoration::ArrayStride, element_size);
-
-    const Id struct_type{TypeStruct(array_type)};
-    Name(struct_type, fmt::format("cbuf_block_{}{}", type_char, element_size * CHAR_BIT));
-    Decorate(struct_type, spv::Decoration::Block);
-    MemberName(struct_type, 0, "data");
-    MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
-
-    const Id struct_pointer_type{TypePointer(spv::StorageClass::Uniform, struct_type)};
-    const Id uniform_type{TypePointer(spv::StorageClass::Uniform, type)};
-    uniform_types.*member_type = uniform_type;
-
-    for (const ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
-        const Id id{AddGlobalVariable(struct_pointer_type, spv::StorageClass::Uniform)};
-        Decorate(id, spv::Decoration::Binding, binding);
-        Decorate(id, spv::Decoration::DescriptorSet, 0U);
-        Name(id, fmt::format("c{}", desc.index));
-        for (size_t i = 0; i < desc.count; ++i) {
-            cbufs[desc.index + i].*member_type = id;
-        }
-        if (profile.supported_spirv >= 0x00010400) {
-            interfaces.push_back(id);
-        }
-        binding += desc.count;
-    }
-}
-
 void EmitContext::DefineOutputs(const Info& info) {
     if (info.stores_position || stage == Stage::VertexB) {
         output_position = DefineOutput(*this, F32[4], spv::BuiltIn::Position);
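Note: the CasLoop helper introduced above emits a SPIR-V function that emulates atomic operations the backend cannot rely on natively (wrapping increment/decrement and floating-point add/min/max) as an OpAtomicCompareExchange retry loop over 32-bit words. A minimal host-side sketch of the same control flow, using std::atomic and hypothetical names that are not part of this patch:

    // Sketch only: CPU analogue of the compare-and-swap loop CasLoop emits.
    #include <atomic>
    #include <bit>
    #include <cstdint>

    // Emulates an atomic float add on a 32-bit word, mirroring the
    // load -> apply CasFunction -> OpAtomicCompareExchange -> retry pattern.
    float AtomicFAdd(std::atomic<std::uint32_t>& word, float operand) {
        std::uint32_t expected = word.load();
        for (;;) {
            // Bitcast the stored word to float, apply the operation
            // (the FPAdd case of CasFunction), and pack the result back.
            const float value = std::bit_cast<float>(expected);
            const std::uint32_t desired = std::bit_cast<std::uint32_t>(value + operand);
            // On failure 'expected' is refreshed and the loop retries,
            // like the OpBranchConditional back to the loop header.
            if (word.compare_exchange_weak(expected, desired)) {
                return value; // the original value, as the atomic returns it
            }
        }
    }

The F16x2/F32x2 paths differ only in packing: they round-trip the word through OpUnpackHalf2x16/OpPackHalf2x16 around the same loop.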
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
index 98a9140bf..cade1fa0d 100644
--- a/src/shader_recompiler/backend/spirv/emit_context.h
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -50,6 +50,35 @@ struct UniformDefinitions {
     Id U32x2{};
 };
 
+struct StorageTypeDefinition {
+    Id array{};
+    Id element{};
+};
+
+struct StorageTypeDefinitions {
+    StorageTypeDefinition U8{};
+    StorageTypeDefinition S8{};
+    StorageTypeDefinition U16{};
+    StorageTypeDefinition S16{};
+    StorageTypeDefinition U32{};
+    StorageTypeDefinition U64{};
+    StorageTypeDefinition F32{};
+    StorageTypeDefinition U32x2{};
+    StorageTypeDefinition U32x4{};
+};
+
+struct StorageDefinitions {
+    Id U8{};
+    Id S8{};
+    Id U16{};
+    Id S16{};
+    Id U32{};
+    Id F32{};
+    Id U64{};
+    Id U32x2{};
+    Id U32x4{};
+};
+
 class EmitContext final : public Sirit::Module {
 public:
     explicit EmitContext(const Profile& profile, IR::Program& program, u32& binding);
@@ -78,12 +107,14 @@ public:
     Id f32_zero_value{};
 
     UniformDefinitions uniform_types;
+    StorageTypeDefinitions storage_types;
 
     Id private_u32{};
 
     Id shared_u8{};
     Id shared_u16{};
     Id shared_u32{};
+    Id shared_u64{};
     Id shared_u32x2{};
     Id shared_u32x4{};
 
@@ -93,14 +124,11 @@ public:
 
     Id output_f32{};
 
-    Id storage_u32{};
-    Id storage_memory_u32{};
-
     Id image_buffer_type{};
     Id sampled_texture_buffer_type{};
 
     std::array<UniformDefinitions, Info::MAX_CBUFS> cbufs{};
-    std::array<Id, Info::MAX_SSBOS> ssbos{};
+    std::array<StorageDefinitions, Info::MAX_SSBOS> ssbos{};
     std::vector<Id> texture_buffers;
     std::vector<TextureDefinition> textures;
     std::vector<ImageDefinition> images;
@@ -136,8 +164,10 @@ public:
     Id shared_memory_u8{};
     Id shared_memory_u16{};
     Id shared_memory_u32{};
+    Id shared_memory_u64{};
     Id shared_memory_u32x2{};
     Id shared_memory_u32x4{};
+
     Id shared_memory_u32_type{};
 
     Id shared_store_u8_func{};
@@ -167,16 +197,12 @@ public:
     std::vector<Id> interfaces;
 
 private:
-    enum class CasPointerType {
-        Shared,
-        Ssbo,
-    };
-
     void DefineCommonTypes(const Info& info);
     void DefineCommonConstants();
     void DefineInterfaces(const Info& info);
     void DefineLocalMemory(const IR::Program& program);
     void DefineSharedMemory(const IR::Program& program);
+    void DefineSharedMemoryFunctions(const IR::Program& program);
     void DefineConstantBuffers(const Info& info, u32& binding);
     void DefineStorageBuffers(const Info& info, u32& binding);
     void DefineTextureBuffers(const Info& info, u32& binding);
@@ -185,13 +211,8 @@ private:
     void DefineAttributeMemAccess(const Info& info);
     void DefineLabels(IR::Program& program);
 
-    void DefineConstantBuffers(const Info& info, Id UniformDefinitions::*member_type, u32 binding,
-                               Id type, char type_char, u32 element_size);
-
     void DefineInputs(const Info& info);
     void DefineOutputs(const Info& info);
-
-    [[nodiscard]] Id CasLoop(Id function, CasPointerType pointer_type, Id value_type);
 };
 
 } // namespace Shader::Backend::SPIRV
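Note: the new StorageTypeDefinitions/StorageDefinitions slots are populated through pointers to data members (Id StorageDefinitions::*member_type), so one helper can fill whichever typed slot the call site selects. A self-contained sketch of that C++ idiom, with hypothetical names that are not part of this patch:

    // Sketch only: pointer-to-member selection, as DefineSsbos uses it.
    #include <array>
    #include <cstdint>

    struct Slots {
        std::uint32_t U32{};
        std::uint32_t F32{};
    };

    // One generic writer serves every member: the caller picks the slot
    // with a pointer-to-member, like ctx.ssbos[i].*member_type = id.
    void FillSlot(std::array<Slots, 4>& ssbos, std::uint32_t Slots::*member,
                  std::uint32_t id) {
        for (Slots& slot : ssbos) {
            slot.*member = id; // write through the selected member
        }
    }

    int main() {
        std::array<Slots, 4> ssbos{};
        FillSlot(ssbos, &Slots::U32, 7); // select the U32 slot at the call site
        FillSlot(ssbos, &Slots::F32, 9);
    }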
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index d7c5890ab..61a2018d7 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -276,7 +276,7 @@ void SetupCapabilities(const Profile& profile, const Info& info, EmitContext& ct
             ctx.AddCapability(spv::Capability::SubgroupVoteKHR);
         }
     }
-    if (info.uses_64_bit_atomics && profile.support_int64_atomics) {
+    if (info.uses_int64_bit_atomics && profile.support_int64_atomics) {
         ctx.AddCapability(spv::Capability::Int64Atomics);
     }
     if (info.uses_typeless_image_reads && profile.support_typeless_image_loads) {
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index c0e1b8833..55b2edba0 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -89,17 +89,21 @@ void EmitWriteGlobalS16(EmitContext& ctx);
 void EmitWriteGlobal32(EmitContext& ctx);
 void EmitWriteGlobal64(EmitContext& ctx);
 void EmitWriteGlobal128(EmitContext& ctx);
-void EmitLoadStorageU8(EmitContext& ctx);
-void EmitLoadStorageS8(EmitContext& ctx);
-void EmitLoadStorageU16(EmitContext& ctx);
-void EmitLoadStorageS16(EmitContext& ctx);
+Id EmitLoadStorageU8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
+Id EmitLoadStorageS8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
+Id EmitLoadStorageU16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
+Id EmitLoadStorageS16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
 Id EmitLoadStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
 Id EmitLoadStorage64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
 Id EmitLoadStorage128(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset);
-void EmitWriteStorageU8(EmitContext& ctx);
-void EmitWriteStorageS8(EmitContext& ctx);
-void EmitWriteStorageU16(EmitContext& ctx);
-void EmitWriteStorageS16(EmitContext& ctx);
+void EmitWriteStorageU8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
+                        Id value);
+void EmitWriteStorageS8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
+                        Id value);
+void EmitWriteStorageU16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
+                         Id value);
+void EmitWriteStorageS16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
+                         Id value);
 void EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
                         Id value);
 void EmitWriteStorage64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_atomic.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_atomic.cpp
index 03d891419..aab32dc52 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_atomic.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_atomic.cpp
@@ -6,11 +6,12 @@
6 6
7namespace Shader::Backend::SPIRV { 7namespace Shader::Backend::SPIRV {
8namespace { 8namespace {
9 9Id SharedPointer(EmitContext& ctx, Id offset, u32 index_offset = 0) {
10Id GetSharedPointer(EmitContext& ctx, Id offset, u32 index_offset = 0) {
11 const Id shift_id{ctx.Constant(ctx.U32[1], 2U)}; 10 const Id shift_id{ctx.Constant(ctx.U32[1], 2U)};
12 const Id shifted_value{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)}; 11 Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
13 const Id index{ctx.OpIAdd(ctx.U32[1], shifted_value, ctx.Constant(ctx.U32[1], index_offset))}; 12 if (index_offset > 0) {
13 index = ctx.OpIAdd(ctx.U32[1], index, ctx.Constant(ctx.U32[1], index_offset));
14 }
14 return ctx.profile.support_explicit_workgroup_layout 15 return ctx.profile.support_explicit_workgroup_layout
15 ? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index) 16 ? ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, ctx.u32_zero_value, index)
16 : ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index); 17 : ctx.OpAccessChain(ctx.shared_u32, ctx.shared_memory_u32, index);
@@ -30,340 +31,258 @@ Id StorageIndex(EmitContext& ctx, const IR::Value& offset, size_t element_size)
30 return ctx.OpShiftRightLogical(ctx.U32[1], index, shift_id); 31 return ctx.OpShiftRightLogical(ctx.U32[1], index, shift_id);
31} 32}
32 33
33Id GetStoragePointer(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 34Id StoragePointer(EmitContext& ctx, const StorageTypeDefinition& type_def,
34 u32 index_offset = 0) { 35 Id StorageDefinitions::*member_ptr, const IR::Value& binding,
35 // TODO: Support reinterpreting bindings, guaranteed to be aligned 36 const IR::Value& offset, size_t element_size) {
36 if (!binding.IsImmediate()) { 37 if (!binding.IsImmediate()) {
37 throw NotImplementedException("Dynamic storage buffer indexing"); 38 throw NotImplementedException("Dynamic storage buffer indexing");
38 } 39 }
39 const Id ssbo{ctx.ssbos[binding.U32()]}; 40 const Id ssbo{ctx.ssbos[binding.U32()].*member_ptr};
40 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 41 const Id index{StorageIndex(ctx, offset, element_size)};
41 const Id index{ctx.OpIAdd(ctx.U32[1], base_index, ctx.Constant(ctx.U32[1], index_offset))}; 42 return ctx.OpAccessChain(type_def.element, ssbo, ctx.u32_zero_value, index);
42 return ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index);
43} 43}
44 44
45std::pair<Id, Id> GetAtomicArgs(EmitContext& ctx) { 45std::pair<Id, Id> AtomicArgs(EmitContext& ctx) {
46 const Id scope{ctx.Constant(ctx.U32[1], static_cast<u32>(spv::Scope::Device))}; 46 const Id scope{ctx.Constant(ctx.U32[1], static_cast<u32>(spv::Scope::Device))};
47 const Id semantics{ctx.u32_zero_value}; 47 const Id semantics{ctx.u32_zero_value};
48 return {scope, semantics}; 48 return {scope, semantics};
49} 49}
50 50
51Id LoadU64(EmitContext& ctx, Id pointer_1, Id pointer_2) { 51Id SharedAtomicU32(EmitContext& ctx, Id offset, Id value,
52 const Id value_1{ctx.OpLoad(ctx.U32[1], pointer_1)}; 52 Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id, Id)) {
53 const Id value_2{ctx.OpLoad(ctx.U32[1], pointer_2)}; 53 const Id pointer{SharedPointer(ctx, offset)};
54 const Id original_composite{ctx.OpCompositeConstruct(ctx.U32[2], value_1, value_2)}; 54 const auto [scope, semantics]{AtomicArgs(ctx)};
55 return ctx.OpBitcast(ctx.U64, original_composite); 55 return (ctx.*atomic_func)(ctx.U32[1], pointer, scope, semantics, value);
56} 56}
57 57
58void StoreResult(EmitContext& ctx, Id pointer_1, Id pointer_2, Id result) { 58Id StorageAtomicU32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, Id value,
59 const Id composite{ctx.OpBitcast(ctx.U32[2], result)}; 59 Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id, Id)) {
60 ctx.OpStore(pointer_1, ctx.OpCompositeExtract(ctx.U32[1], composite, 0)); 60 const Id pointer{StoragePointer(ctx, ctx.storage_types.U32, &StorageDefinitions::U32, binding,
61 ctx.OpStore(pointer_2, ctx.OpCompositeExtract(ctx.U32[1], composite, 1)); 61 offset, sizeof(u32))};
62 const auto [scope, semantics]{AtomicArgs(ctx)};
63 return (ctx.*atomic_func)(ctx.U32[1], pointer, scope, semantics, value);
64}
65
66Id StorageAtomicU64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, Id value,
67 Id (Sirit::Module::*atomic_func)(Id, Id, Id, Id, Id),
68 Id (Sirit::Module::*non_atomic_func)(Id, Id, Id)) {
69 if (ctx.profile.support_int64_atomics) {
70 const Id pointer{StoragePointer(ctx, ctx.storage_types.U64, &StorageDefinitions::U64,
71 binding, offset, sizeof(u64))};
72 const auto [scope, semantics]{AtomicArgs(ctx)};
73 return (ctx.*atomic_func)(ctx.U64, pointer, scope, semantics, value);
74 }
75 // LOG_WARNING(..., "Int64 Atomics not supported, fallback to non-atomic");
76 const Id pointer{StoragePointer(ctx, ctx.storage_types.U32x2, &StorageDefinitions::U32x2,
77 binding, offset, sizeof(u32[2]))};
78 const Id original_value{ctx.OpBitcast(ctx.U64, ctx.OpLoad(ctx.U32[2], pointer))};
79 const Id result{(ctx.*non_atomic_func)(ctx.U64, value, original_value)};
80 ctx.OpStore(pointer, result);
81 return original_value;
62} 82}
63} // Anonymous namespace 83} // Anonymous namespace
64 84
65Id EmitSharedAtomicIAdd32(EmitContext& ctx, Id pointer_offset, Id value) { 85Id EmitSharedAtomicIAdd32(EmitContext& ctx, Id offset, Id value) {
66 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 86 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicIAdd);
67 const auto [scope, semantics]{GetAtomicArgs(ctx)};
68 return ctx.OpAtomicIAdd(ctx.U32[1], pointer, scope, semantics, value);
69} 87}
70 88
71Id EmitSharedAtomicSMin32(EmitContext& ctx, Id pointer_offset, Id value) { 89Id EmitSharedAtomicSMin32(EmitContext& ctx, Id offset, Id value) {
72 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 90 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicSMin);
73 const auto [scope, semantics]{GetAtomicArgs(ctx)};
74 return ctx.OpAtomicSMin(ctx.U32[1], pointer, scope, semantics, value);
75} 91}
76 92
77Id EmitSharedAtomicUMin32(EmitContext& ctx, Id pointer_offset, Id value) { 93Id EmitSharedAtomicUMin32(EmitContext& ctx, Id offset, Id value) {
78 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 94 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicUMin);
79 const auto [scope, semantics]{GetAtomicArgs(ctx)};
80 return ctx.OpAtomicUMin(ctx.U32[1], pointer, scope, semantics, value);
81} 95}
82 96
83Id EmitSharedAtomicSMax32(EmitContext& ctx, Id pointer_offset, Id value) { 97Id EmitSharedAtomicSMax32(EmitContext& ctx, Id offset, Id value) {
84 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 98 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicSMax);
85 const auto [scope, semantics]{GetAtomicArgs(ctx)};
86 return ctx.OpAtomicSMax(ctx.U32[1], pointer, scope, semantics, value);
87} 99}
88 100
89Id EmitSharedAtomicUMax32(EmitContext& ctx, Id pointer_offset, Id value) { 101Id EmitSharedAtomicUMax32(EmitContext& ctx, Id offset, Id value) {
90 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 102 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicUMax);
91 const auto [scope, semantics]{GetAtomicArgs(ctx)};
92 return ctx.OpAtomicUMax(ctx.U32[1], pointer, scope, semantics, value);
93} 103}
94 104
95Id EmitSharedAtomicInc32(EmitContext& ctx, Id pointer_offset, Id value) { 105Id EmitSharedAtomicInc32(EmitContext& ctx, Id offset, Id value) {
96 const Id shift_id{ctx.Constant(ctx.U32[1], 2U)}; 106 const Id shift_id{ctx.Constant(ctx.U32[1], 2U)};
97 const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], pointer_offset, shift_id)}; 107 const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
98 return ctx.OpFunctionCall(ctx.U32[1], ctx.increment_cas_shared, index, value, 108 return ctx.OpFunctionCall(ctx.U32[1], ctx.increment_cas_shared, index, value);
99 ctx.shared_memory_u32);
100} 109}
101 110
102Id EmitSharedAtomicDec32(EmitContext& ctx, Id pointer_offset, Id value) { 111Id EmitSharedAtomicDec32(EmitContext& ctx, Id offset, Id value) {
103 const Id shift_id{ctx.Constant(ctx.U32[1], 2U)}; 112 const Id shift_id{ctx.Constant(ctx.U32[1], 2U)};
104 const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], pointer_offset, shift_id)}; 113 const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
105 return ctx.OpFunctionCall(ctx.U32[1], ctx.decrement_cas_shared, index, value, 114 return ctx.OpFunctionCall(ctx.U32[1], ctx.decrement_cas_shared, index, value);
106 ctx.shared_memory_u32);
107} 115}
108 116
109Id EmitSharedAtomicAnd32(EmitContext& ctx, Id pointer_offset, Id value) { 117Id EmitSharedAtomicAnd32(EmitContext& ctx, Id offset, Id value) {
110 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 118 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicAnd);
111 const auto [scope, semantics]{GetAtomicArgs(ctx)};
112 return ctx.OpAtomicAnd(ctx.U32[1], pointer, scope, semantics, value);
113} 119}
114 120
115Id EmitSharedAtomicOr32(EmitContext& ctx, Id pointer_offset, Id value) { 121Id EmitSharedAtomicOr32(EmitContext& ctx, Id offset, Id value) {
116 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 122 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicOr);
117 const auto [scope, semantics]{GetAtomicArgs(ctx)};
118 return ctx.OpAtomicOr(ctx.U32[1], pointer, scope, semantics, value);
119} 123}
120 124
121Id EmitSharedAtomicXor32(EmitContext& ctx, Id pointer_offset, Id value) { 125Id EmitSharedAtomicXor32(EmitContext& ctx, Id offset, Id value) {
122 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 126 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicXor);
123 const auto [scope, semantics]{GetAtomicArgs(ctx)};
124 return ctx.OpAtomicXor(ctx.U32[1], pointer, scope, semantics, value);
125} 127}
126 128
127Id EmitSharedAtomicExchange32(EmitContext& ctx, Id pointer_offset, Id value) { 129Id EmitSharedAtomicExchange32(EmitContext& ctx, Id offset, Id value) {
128 const Id pointer{GetSharedPointer(ctx, pointer_offset)}; 130 return SharedAtomicU32(ctx, offset, value, &Sirit::Module::OpAtomicExchange);
129 const auto [scope, semantics]{GetAtomicArgs(ctx)};
130 return ctx.OpAtomicExchange(ctx.U32[1], pointer, scope, semantics, value);
131} 131}
132 132
133Id EmitSharedAtomicExchange64(EmitContext& ctx, Id pointer_offset, Id value) { 133Id EmitSharedAtomicExchange64(EmitContext& ctx, Id offset, Id value) {
134 const Id pointer_1{GetSharedPointer(ctx, pointer_offset)}; 134 if (ctx.profile.support_int64_atomics && ctx.profile.support_explicit_workgroup_layout) {
135 if (ctx.profile.support_int64_atomics) { 135 const Id shift_id{ctx.Constant(ctx.U32[1], 3U)};
136 const auto [scope, semantics]{GetAtomicArgs(ctx)}; 136 const Id index{ctx.OpShiftRightArithmetic(ctx.U32[1], offset, shift_id)};
137 return ctx.OpAtomicExchange(ctx.U64, pointer_1, scope, semantics, value); 137 const Id pointer{
138 ctx.OpAccessChain(ctx.shared_u64, ctx.shared_memory_u64, ctx.u32_zero_value, index)};
139 const auto [scope, semantics]{AtomicArgs(ctx)};
140 return ctx.OpAtomicExchange(ctx.U64, pointer, scope, semantics, value);
138 } 141 }
139 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic"); 142 // LOG_WARNING("Int64 Atomics not supported, fallback to non-atomic");
140 const Id pointer_2{GetSharedPointer(ctx, pointer_offset, 1)}; 143 const Id pointer_1{SharedPointer(ctx, offset, 0)};
141 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)}; 144 const Id pointer_2{SharedPointer(ctx, offset, 1)};
142 StoreResult(ctx, pointer_1, pointer_2, value); 145 const Id value_1{ctx.OpLoad(ctx.U32[1], pointer_1)};
143 return original_value; 146 const Id value_2{ctx.OpLoad(ctx.U32[1], pointer_2)};
147 const Id new_vector{ctx.OpBitcast(ctx.U32[2], value)};
148 ctx.OpStore(pointer_1, ctx.OpCompositeExtract(ctx.U32[1], new_vector, 0U));
149 ctx.OpStore(pointer_2, ctx.OpCompositeExtract(ctx.U32[1], new_vector, 1U));
150 return ctx.OpBitcast(ctx.U64, ctx.OpCompositeConstruct(ctx.U32[2], value_1, value_2));
144} 151}
145 152
146Id EmitStorageAtomicIAdd32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 153Id EmitStorageAtomicIAdd32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
147 Id value) { 154 Id value) {
148 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 155 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicIAdd);
149 const auto [scope, semantics]{GetAtomicArgs(ctx)};
150 return ctx.OpAtomicIAdd(ctx.U32[1], pointer, scope, semantics, value);
151} 156}
152 157
153Id EmitStorageAtomicSMin32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 158Id EmitStorageAtomicSMin32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
154 Id value) { 159 Id value) {
155 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 160 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicSMin);
156 const auto [scope, semantics]{GetAtomicArgs(ctx)};
157 return ctx.OpAtomicSMin(ctx.U32[1], pointer, scope, semantics, value);
158} 161}
159 162
160Id EmitStorageAtomicUMin32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 163Id EmitStorageAtomicUMin32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
161 Id value) { 164 Id value) {
162 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 165 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicUMin);
163 const auto [scope, semantics]{GetAtomicArgs(ctx)};
164 return ctx.OpAtomicUMin(ctx.U32[1], pointer, scope, semantics, value);
165} 166}
166 167
167Id EmitStorageAtomicSMax32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 168Id EmitStorageAtomicSMax32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
168 Id value) { 169 Id value) {
169 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 170 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicSMax);
170 const auto [scope, semantics]{GetAtomicArgs(ctx)};
171 return ctx.OpAtomicSMax(ctx.U32[1], pointer, scope, semantics, value);
172} 171}
173 172
174Id EmitStorageAtomicUMax32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 173Id EmitStorageAtomicUMax32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
175 Id value) { 174 Id value) {
176 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 175 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicUMax);
177 const auto [scope, semantics]{GetAtomicArgs(ctx)};
178 return ctx.OpAtomicUMax(ctx.U32[1], pointer, scope, semantics, value);
179} 176}
180 177
181Id EmitStorageAtomicInc32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 178Id EmitStorageAtomicInc32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
182 Id value) { 179 Id value) {
183 const Id ssbo{ctx.ssbos[binding.U32()]}; 180 const Id ssbo{ctx.ssbos[binding.U32()].U32};
184 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 181 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
185 return ctx.OpFunctionCall(ctx.U32[1], ctx.increment_cas_ssbo, base_index, value, ssbo); 182 return ctx.OpFunctionCall(ctx.U32[1], ctx.increment_cas_ssbo, base_index, value, ssbo);
186} 183}
187 184
188Id EmitStorageAtomicDec32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 185Id EmitStorageAtomicDec32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
189 Id value) { 186 Id value) {
190 const Id ssbo{ctx.ssbos[binding.U32()]}; 187 const Id ssbo{ctx.ssbos[binding.U32()].U32};
191 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 188 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
192 return ctx.OpFunctionCall(ctx.U32[1], ctx.decrement_cas_ssbo, base_index, value, ssbo); 189 return ctx.OpFunctionCall(ctx.U32[1], ctx.decrement_cas_ssbo, base_index, value, ssbo);
193} 190}
194 191
195Id EmitStorageAtomicAnd32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 192Id EmitStorageAtomicAnd32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
196 Id value) { 193 Id value) {
197 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 194 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicAnd);
198 const auto [scope, semantics]{GetAtomicArgs(ctx)};
199 return ctx.OpAtomicAnd(ctx.U32[1], pointer, scope, semantics, value);
200} 195}
201 196
202Id EmitStorageAtomicOr32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 197Id EmitStorageAtomicOr32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
203 Id value) { 198 Id value) {
204 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 199 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicOr);
205 const auto [scope, semantics]{GetAtomicArgs(ctx)};
206 return ctx.OpAtomicOr(ctx.U32[1], pointer, scope, semantics, value);
207} 200}
208 201
209Id EmitStorageAtomicXor32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 202Id EmitStorageAtomicXor32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
210 Id value) { 203 Id value) {
211 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 204 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicXor);
212 const auto [scope, semantics]{GetAtomicArgs(ctx)};
213 return ctx.OpAtomicXor(ctx.U32[1], pointer, scope, semantics, value);
214} 205}
215 206
216Id EmitStorageAtomicExchange32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 207Id EmitStorageAtomicExchange32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
217 Id value) { 208 Id value) {
218 const Id pointer{GetStoragePointer(ctx, binding, offset)}; 209 return StorageAtomicU32(ctx, binding, offset, value, &Sirit::Module::OpAtomicExchange);
219 const auto [scope, semantics]{GetAtomicArgs(ctx)};
220 return ctx.OpAtomicExchange(ctx.U32[1], pointer, scope, semantics, value);
221} 210}
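
Annotation: each 32-bit emitter above now collapses to a one-line call into a shared StorageAtomicU32 helper (defined earlier in this file, outside this hunk), selecting the SPIR-V instruction with a pointer to a Sirit::Module member function. A minimal stand-alone sketch of that dispatch pattern, with stand-in types:

```cpp
#include <cstdint>
#include <cstdio>

// Stand-in for Sirit::Module: two "opcodes" with identical signatures.
struct Module {
    std::uint32_t OpAtomicAnd(std::uint32_t a, std::uint32_t b) { return a & b; }
    std::uint32_t OpAtomicOr(std::uint32_t a, std::uint32_t b) { return a | b; }
};

using AtomicOp = std::uint32_t (Module::*)(std::uint32_t, std::uint32_t);

// One generic emitter replaces a family of near-identical functions;
// the call site names the opcode as a member-function pointer.
std::uint32_t StorageAtomic(Module& module, std::uint32_t a, std::uint32_t b, AtomicOp op) {
    return (module.*op)(a, b);
}

int main() {
    Module module;
    std::printf("%u\n", StorageAtomic(module, 12u, 10u, &Module::OpAtomicAnd)); // 8
    std::printf("%u\n", StorageAtomic(module, 12u, 10u, &Module::OpAtomicOr));  // 14
}
```

The 64-bit variant below takes two member pointers: the atomic opcode for the native int64 path and a plain ALU opcode (e.g. &Sirit::Module::OpIAdd) for the non-atomic fallback.
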
222 211
223Id EmitStorageAtomicIAdd64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 212Id EmitStorageAtomicIAdd64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
224 Id value) { 213 Id value) {
225 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 214 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicIAdd,
226 if (ctx.profile.support_int64_atomics) { 215 &Sirit::Module::OpIAdd);
227 const auto [scope, semantics]{GetAtomicArgs(ctx)};
228 return ctx.OpAtomicIAdd(ctx.U64, pointer_1, scope, semantics, value);
229 }
230 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
231 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
232 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
233 const Id result{ctx.OpIAdd(ctx.U64, value, original_value)};
234 StoreResult(ctx, pointer_1, pointer_2, result);
235 return original_value;
236} 216}
237 217
238Id EmitStorageAtomicSMin64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 218Id EmitStorageAtomicSMin64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
239 Id value) { 219 Id value) {
240 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 220 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicSMin,
241 if (ctx.profile.support_int64_atomics) { 221 &Sirit::Module::OpSMin);
242 const auto [scope, semantics]{GetAtomicArgs(ctx)};
243 return ctx.OpAtomicSMin(ctx.U64, pointer_1, scope, semantics, value);
244 }
245 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
246 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
247 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
248 const Id result{ctx.OpSMin(ctx.U64, value, original_value)};
249 StoreResult(ctx, pointer_1, pointer_2, result);
250 return original_value;
251} 222}
252 223
253Id EmitStorageAtomicUMin64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 224Id EmitStorageAtomicUMin64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
254 Id value) { 225 Id value) {
255 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 226 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicUMin,
256 if (ctx.profile.support_int64_atomics) { 227 &Sirit::Module::OpUMin);
257 const auto [scope, semantics]{GetAtomicArgs(ctx)};
258 return ctx.OpAtomicUMin(ctx.U64, pointer_1, scope, semantics, value);
259 }
260 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
261 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
262 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
263 const Id result{ctx.OpUMin(ctx.U64, value, original_value)};
264 StoreResult(ctx, pointer_1, pointer_2, result);
265 return original_value;
266} 228}
267 229
268Id EmitStorageAtomicSMax64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 230Id EmitStorageAtomicSMax64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
269 Id value) { 231 Id value) {
270 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 232 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicSMax,
271 if (ctx.profile.support_int64_atomics) { 233 &Sirit::Module::OpSMax);
272 const auto [scope, semantics]{GetAtomicArgs(ctx)};
273 return ctx.OpAtomicSMax(ctx.U64, pointer_1, scope, semantics, value);
274 }
275 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
276 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
277 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
278 const Id result{ctx.OpSMax(ctx.U64, value, original_value)};
279 StoreResult(ctx, pointer_1, pointer_2, result);
280 return original_value;
281} 234}
282 235
283Id EmitStorageAtomicUMax64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 236Id EmitStorageAtomicUMax64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
284 Id value) { 237 Id value) {
285 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 238 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicUMax,
286 if (ctx.profile.support_int64_atomics) { 239 &Sirit::Module::OpUMax);
287 const auto [scope, semantics]{GetAtomicArgs(ctx)};
288 return ctx.OpAtomicUMax(ctx.U64, pointer_1, scope, semantics, value);
289 }
290 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
291 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
292 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
293 const Id result{ctx.OpUMax(ctx.U64, value, original_value)};
294 StoreResult(ctx, pointer_1, pointer_2, result);
295 return original_value;
296} 240}
297 241
298Id EmitStorageAtomicAnd64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 242Id EmitStorageAtomicAnd64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
299 Id value) { 243 Id value) {
300 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 244 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicAnd,
301 if (ctx.profile.support_int64_atomics) { 245 &Sirit::Module::OpBitwiseAnd);
302 const auto [scope, semantics]{GetAtomicArgs(ctx)};
303 return ctx.OpAtomicAnd(ctx.U64, pointer_1, scope, semantics, value);
304 }
305 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
306 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
307 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
308 const Id result{ctx.OpBitwiseAnd(ctx.U64, value, original_value)};
309 StoreResult(ctx, pointer_1, pointer_2, result);
310 return original_value;
311} 246}
312 247
313Id EmitStorageAtomicOr64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 248Id EmitStorageAtomicOr64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
314 Id value) { 249 Id value) {
315 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 250 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicOr,
316 if (ctx.profile.support_int64_atomics) { 251 &Sirit::Module::OpBitwiseOr);
317 const auto [scope, semantics]{GetAtomicArgs(ctx)};
318 return ctx.OpAtomicOr(ctx.U64, pointer_1, scope, semantics, value);
319 }
320 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
321 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
322 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
323 const Id result{ctx.OpBitwiseOr(ctx.U64, value, original_value)};
324 StoreResult(ctx, pointer_1, pointer_2, result);
325 return original_value;
326} 252}
327 253
328Id EmitStorageAtomicXor64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 254Id EmitStorageAtomicXor64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
329 Id value) { 255 Id value) {
330 const Id pointer_1{GetStoragePointer(ctx, binding, offset)}; 256 return StorageAtomicU64(ctx, binding, offset, value, &Sirit::Module::OpAtomicXor,
331 if (ctx.profile.support_int64_atomics) { 257 &Sirit::Module::OpBitwiseXor);
332 const auto [scope, semantics]{GetAtomicArgs(ctx)};
333 return ctx.OpAtomicXor(ctx.U64, pointer_1, scope, semantics, value);
334 }
335 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic");
336 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)};
337 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)};
338 const Id result{ctx.OpBitwiseXor(ctx.U64, value, original_value)};
339 StoreResult(ctx, pointer_1, pointer_2, result);
340 return original_value;
341} 258}
342 259
343Id EmitStorageAtomicExchange64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 260Id EmitStorageAtomicExchange64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
344 Id value) { 261 Id value) {
345 const Id pointer_1{GetStoragePointer(ctx, binding, offset)};
346 if (ctx.profile.support_int64_atomics) { 262 if (ctx.profile.support_int64_atomics) {
347 const auto [scope, semantics]{GetAtomicArgs(ctx)}; 263 const Id pointer{StoragePointer(ctx, ctx.storage_types.U64, &StorageDefinitions::U64,
348 return ctx.OpAtomicExchange(ctx.U64, pointer_1, scope, semantics, value); 264 binding, offset, sizeof(u64))};
265 const auto [scope, semantics]{AtomicArgs(ctx)};
266 return ctx.OpAtomicExchange(ctx.U64, pointer, scope, semantics, value);
349 } 267 }
350 // LOG_WARNING(Render_Vulkan, "Int64 Atomics not supported, fallback to non-atomic"); 268 // LOG_WARNING(..., "Int64 Atomics not supported, fallback to non-atomic");
351 const Id pointer_2{GetStoragePointer(ctx, binding, offset, 1)}; 269 const Id pointer{StoragePointer(ctx, ctx.storage_types.U32x2, &StorageDefinitions::U32x2,
352 const Id original_value{LoadU64(ctx, pointer_1, pointer_2)}; 270 binding, offset, sizeof(u32[2]))};
353 StoreResult(ctx, pointer_1, pointer_2, value); 271 const Id original{ctx.OpBitcast(ctx.U64, ctx.OpLoad(ctx.U32[2], pointer))};
354 return original_value; 272 ctx.OpStore(pointer, value);
273 return original;
355} 274}
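
Annotation: when shaderBufferInt64Atomics is unavailable, the new fallback loads through the U32x2 view, bitcasts to U64 to recover the old value, and stores the replacement non-atomically. A host-side analogue of that exchange, with memcpy standing in for OpLoad/OpBitcast/OpStore:

```cpp
#include <cinttypes>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Read the old value through a 2 x 32-bit view, then overwrite it.
// Deliberately not atomic: a concurrent writer can interleave between
// the load and the store, which is the compromise the fallback accepts.
std::uint64_t ExchangeViaU32x2(std::uint32_t words[2], std::uint64_t desired) {
    std::uint64_t original;
    std::memcpy(&original, words, sizeof(original)); // OpBitcast(U64, OpLoad(U32[2], ptr))
    std::memcpy(words, &desired, sizeof(desired));   // OpStore(ptr, value)
    return original;
}

int main() {
    std::uint32_t words[2]{0x1u, 0x0u};
    const std::uint64_t old_value = ExchangeViaU32x2(words, 0x1122334455667788ull);
    std::printf("old=0x%" PRIx64 " lo=0x%x hi=0x%x\n", old_value, words[0], words[1]);
}
```
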
356 275
357Id EmitStorageAtomicAddF32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 276Id EmitStorageAtomicAddF32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
358 Id value) { 277 Id value) {
359 const Id ssbo{ctx.ssbos[binding.U32()]}; 278 const Id ssbo{ctx.ssbos[binding.U32()].U32};
360 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 279 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
361 return ctx.OpFunctionCall(ctx.F32[1], ctx.f32_add_cas, base_index, value, ssbo); 280 return ctx.OpFunctionCall(ctx.F32[1], ctx.f32_add_cas, base_index, value, ssbo);
362} 281}
363 282
364Id EmitStorageAtomicAddF16x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 283Id EmitStorageAtomicAddF16x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
365 Id value) { 284 Id value) {
366 const Id ssbo{ctx.ssbos[binding.U32()]}; 285 const Id ssbo{ctx.ssbos[binding.U32()].U32};
367 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 286 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
368 const Id result{ctx.OpFunctionCall(ctx.F16[2], ctx.f16x2_add_cas, base_index, value, ssbo)}; 287 const Id result{ctx.OpFunctionCall(ctx.F16[2], ctx.f16x2_add_cas, base_index, value, ssbo)};
369 return ctx.OpBitcast(ctx.U32[1], result); 288 return ctx.OpBitcast(ctx.U32[1], result);
@@ -371,7 +290,7 @@ Id EmitStorageAtomicAddF16x2(EmitContext& ctx, const IR::Value& binding, const I
371 290
372Id EmitStorageAtomicAddF32x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 291Id EmitStorageAtomicAddF32x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
373 Id value) { 292 Id value) {
374 const Id ssbo{ctx.ssbos[binding.U32()]}; 293 const Id ssbo{ctx.ssbos[binding.U32()].U32};
375 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 294 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
376 const Id result{ctx.OpFunctionCall(ctx.F32[2], ctx.f32x2_add_cas, base_index, value, ssbo)}; 295 const Id result{ctx.OpFunctionCall(ctx.F32[2], ctx.f32x2_add_cas, base_index, value, ssbo)};
377 return ctx.OpPackHalf2x16(ctx.U32[1], result); 296 return ctx.OpPackHalf2x16(ctx.U32[1], result);
@@ -379,7 +298,7 @@ Id EmitStorageAtomicAddF32x2(EmitContext& ctx, const IR::Value& binding, const I
379 298
380Id EmitStorageAtomicMinF16x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 299Id EmitStorageAtomicMinF16x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
381 Id value) { 300 Id value) {
382 const Id ssbo{ctx.ssbos[binding.U32()]}; 301 const Id ssbo{ctx.ssbos[binding.U32()].U32};
383 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 302 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
384 const Id result{ctx.OpFunctionCall(ctx.F16[2], ctx.f16x2_min_cas, base_index, value, ssbo)}; 303 const Id result{ctx.OpFunctionCall(ctx.F16[2], ctx.f16x2_min_cas, base_index, value, ssbo)};
385 return ctx.OpBitcast(ctx.U32[1], result); 304 return ctx.OpBitcast(ctx.U32[1], result);
@@ -387,7 +306,7 @@ Id EmitStorageAtomicMinF16x2(EmitContext& ctx, const IR::Value& binding, const I
387 306
388Id EmitStorageAtomicMinF32x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 307Id EmitStorageAtomicMinF32x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
389 Id value) { 308 Id value) {
390 const Id ssbo{ctx.ssbos[binding.U32()]}; 309 const Id ssbo{ctx.ssbos[binding.U32()].U32};
391 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 310 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
392 const Id result{ctx.OpFunctionCall(ctx.F32[2], ctx.f32x2_min_cas, base_index, value, ssbo)}; 311 const Id result{ctx.OpFunctionCall(ctx.F32[2], ctx.f32x2_min_cas, base_index, value, ssbo)};
393 return ctx.OpPackHalf2x16(ctx.U32[1], result); 312 return ctx.OpPackHalf2x16(ctx.U32[1], result);
@@ -395,7 +314,7 @@ Id EmitStorageAtomicMinF32x2(EmitContext& ctx, const IR::Value& binding, const I
395 314
396Id EmitStorageAtomicMaxF16x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 315Id EmitStorageAtomicMaxF16x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
397 Id value) { 316 Id value) {
398 const Id ssbo{ctx.ssbos[binding.U32()]}; 317 const Id ssbo{ctx.ssbos[binding.U32()].U32};
399 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 318 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
400 const Id result{ctx.OpFunctionCall(ctx.F16[2], ctx.f16x2_max_cas, base_index, value, ssbo)}; 319 const Id result{ctx.OpFunctionCall(ctx.F16[2], ctx.f16x2_max_cas, base_index, value, ssbo)};
401 return ctx.OpBitcast(ctx.U32[1], result); 320 return ctx.OpBitcast(ctx.U32[1], result);
@@ -403,7 +322,7 @@ Id EmitStorageAtomicMaxF16x2(EmitContext& ctx, const IR::Value& binding, const I
403 322
404Id EmitStorageAtomicMaxF32x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 323Id EmitStorageAtomicMaxF32x2(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
405 Id value) { 324 Id value) {
406 const Id ssbo{ctx.ssbos[binding.U32()]}; 325 const Id ssbo{ctx.ssbos[binding.U32()].U32};
407 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 326 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
408 const Id result{ctx.OpFunctionCall(ctx.F32[2], ctx.f32x2_max_cas, base_index, value, ssbo)}; 327 const Id result{ctx.OpFunctionCall(ctx.F32[2], ctx.f32x2_max_cas, base_index, value, ssbo)};
409 return ctx.OpPackHalf2x16(ctx.U32[1], result); 328 return ctx.OpPackHalf2x16(ctx.U32[1], result);
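
Annotation: the two-component float CAS results travel back to the IR as a single 32-bit word: F16[2] is reinterpreted with OpBitcast, while F32[2] is narrowed and packed with OpPackHalf2x16. A host-side sketch of the reinterpret case only, packing two raw 16-bit lanes (OpPackHalf2x16 additionally converts each f32 lane to f16 first, which this sketch omits):

```cpp
#include <cstdint>
#include <cstdio>

// Pack two 16-bit lanes into one 32-bit word, low lane first, matching
// how an F16[2] value occupies a U32 after a bitcast.
std::uint32_t Pack16x2(std::uint16_t lo, std::uint16_t hi) {
    return static_cast<std::uint32_t>(lo) | (static_cast<std::uint32_t>(hi) << 16);
}

int main() {
    // 0x3c00 and 0xbc00 are the IEEE half encodings of 1.0 and -1.0.
    std::printf("0x%08x\n", Pack16x2(0x3c00u, 0xbc00u)); // 0xbc003c00
}
```
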
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
index 088bd3059..a8f2ea5a0 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
@@ -22,29 +22,29 @@ Id StorageIndex(EmitContext& ctx, const IR::Value& offset, size_t element_size)
22 return ctx.OpShiftRightLogical(ctx.U32[1], index, shift_id); 22 return ctx.OpShiftRightLogical(ctx.U32[1], index, shift_id);
23} 23}
24 24
25Id EmitLoadStorage(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 25Id StoragePointer(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
26 u32 num_components) { 26 const StorageTypeDefinition& type_def, size_t element_size,
27 // TODO: Support reinterpreting bindings, guaranteed to be aligned 27 Id StorageDefinitions::*member_ptr) {
28 if (!binding.IsImmediate()) { 28 if (!binding.IsImmediate()) {
29 throw NotImplementedException("Dynamic storage buffer indexing"); 29 throw NotImplementedException("Dynamic storage buffer indexing");
30 } 30 }
31 const Id ssbo{ctx.ssbos[binding.U32()]}; 31 const Id ssbo{ctx.ssbos[binding.U32()].*member_ptr};
32 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))}; 32 const Id index{StorageIndex(ctx, offset, element_size)};
33 std::array<Id, 4> components; 33 return ctx.OpAccessChain(type_def.element, ssbo, ctx.u32_zero_value, index);
34 for (u32 element = 0; element < num_components; ++element) { 34}
35 Id index{base_index}; 35
36 if (element > 0) { 36Id LoadStorage(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, Id result_type,
37 index = ctx.OpIAdd(ctx.U32[1], base_index, ctx.Constant(ctx.U32[1], element)); 37 const StorageTypeDefinition& type_def, size_t element_size,
38 } 38 Id StorageDefinitions::*member_ptr) {
39 const Id pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index)}; 39 const Id pointer{StoragePointer(ctx, binding, offset, type_def, element_size, member_ptr)};
40 components[element] = ctx.OpLoad(ctx.U32[1], pointer); 40 return ctx.OpLoad(result_type, pointer);
41 } 41}
42 if (num_components == 1) { 42
43 return components[0]; 43void WriteStorage(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, Id value,
44 } else { 44 const StorageTypeDefinition& type_def, size_t element_size,
45 const std::span components_span(components.data(), num_components); 45 Id StorageDefinitions::*member_ptr) {
46 return ctx.OpCompositeConstruct(ctx.U32[num_components], components_span); 46 const Id pointer{StoragePointer(ctx, binding, offset, type_def, element_size, member_ptr)};
47 } 47 ctx.OpStore(pointer, value);
48} 48}
49} // Anonymous namespace 49} // Anonymous namespace
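
Annotation: the rewritten helpers take an `Id StorageDefinitions::*member_ptr`, so a single StoragePointer/LoadStorage/WriteStorage implementation can pick the per-width SSBO definition instead of duplicating the access logic per type. A stand-alone sketch of pointer-to-data-member selection, with illustrative fields and values:

```cpp
#include <cstdio>

// Stand-in for StorageDefinitions: one Id per storage width. Field
// values are arbitrary placeholders for SPIR-V result Ids.
struct StorageDefinitions {
    int U8{10};
    int U32{11};
    int U32x2{12};
};

// A single helper serves every width because the caller names the
// member to read, e.g. &StorageDefinitions::U32x2.
int SelectSsbo(const StorageDefinitions& defs, int StorageDefinitions::*member_ptr) {
    return defs.*member_ptr;
}

int main() {
    const StorageDefinitions defs;
    std::printf("%d\n", SelectSsbo(defs, &StorageDefinitions::U32x2)); // 12
}
```
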
50 50
@@ -104,92 +104,85 @@ void EmitWriteGlobal128(EmitContext&) {
104 throw NotImplementedException("SPIR-V Instruction"); 104 throw NotImplementedException("SPIR-V Instruction");
105} 105}
106 106
107void EmitLoadStorageU8(EmitContext&) { 107Id EmitLoadStorageU8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) {
108 throw NotImplementedException("SPIR-V Instruction"); 108 return ctx.OpUConvert(ctx.U32[1],
109 LoadStorage(ctx, binding, offset, ctx.U8, ctx.storage_types.U8,
110 sizeof(u8), &StorageDefinitions::U8));
109} 111}
110 112
111void EmitLoadStorageS8(EmitContext&) { 113Id EmitLoadStorageS8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) {
112 throw NotImplementedException("SPIR-V Instruction"); 114 return ctx.OpSConvert(ctx.U32[1],
115 LoadStorage(ctx, binding, offset, ctx.S8, ctx.storage_types.S8,
116 sizeof(s8), &StorageDefinitions::S8));
113} 117}
114 118
115void EmitLoadStorageU16(EmitContext&) { 119Id EmitLoadStorageU16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) {
116 throw NotImplementedException("SPIR-V Instruction"); 120 return ctx.OpUConvert(ctx.U32[1],
121 LoadStorage(ctx, binding, offset, ctx.U16, ctx.storage_types.U16,
122 sizeof(u16), &StorageDefinitions::U16));
117} 123}
118 124
119void EmitLoadStorageS16(EmitContext&) { 125Id EmitLoadStorageS16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) {
120 throw NotImplementedException("SPIR-V Instruction"); 126 return ctx.OpSConvert(ctx.U32[1],
127 LoadStorage(ctx, binding, offset, ctx.S16, ctx.storage_types.S16,
128 sizeof(s16), &StorageDefinitions::S16));
121} 129}
122 130
123Id EmitLoadStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) { 131Id EmitLoadStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) {
124 return EmitLoadStorage(ctx, binding, offset, 1); 132 return LoadStorage(ctx, binding, offset, ctx.U32[1], ctx.storage_types.U32, sizeof(u32),
133 &StorageDefinitions::U32);
125} 134}
126 135
127Id EmitLoadStorage64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) { 136Id EmitLoadStorage64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) {
128 return EmitLoadStorage(ctx, binding, offset, 2); 137 return LoadStorage(ctx, binding, offset, ctx.U32[2], ctx.storage_types.U32x2, sizeof(u32[2]),
138 &StorageDefinitions::U32x2);
129} 139}
130 140
131Id EmitLoadStorage128(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) { 141Id EmitLoadStorage128(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset) {
132 return EmitLoadStorage(ctx, binding, offset, 4); 142 return LoadStorage(ctx, binding, offset, ctx.U32[4], ctx.storage_types.U32x4, sizeof(u32[4]),
143 &StorageDefinitions::U32x4);
133} 144}
134 145
135void EmitWriteStorageU8(EmitContext&) { 146void EmitWriteStorageU8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
136 throw NotImplementedException("SPIR-V Instruction"); 147 Id value) {
148 WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.U8, value), ctx.storage_types.U8,
149 sizeof(u8), &StorageDefinitions::U8);
137} 150}
138 151
139void EmitWriteStorageS8(EmitContext&) { 152void EmitWriteStorageS8(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
140 throw NotImplementedException("SPIR-V Instruction"); 153 Id value) {
154 WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.S8, value), ctx.storage_types.S8,
155 sizeof(s8), &StorageDefinitions::S8);
141} 156}
142 157
143void EmitWriteStorageU16(EmitContext&) { 158void EmitWriteStorageU16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
144 throw NotImplementedException("SPIR-V Instruction"); 159 Id value) {
160 WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.U16, value), ctx.storage_types.U16,
161 sizeof(u16), &StorageDefinitions::U16);
145} 162}
146 163
147void EmitWriteStorageS16(EmitContext&) { 164void EmitWriteStorageS16(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
148 throw NotImplementedException("SPIR-V Instruction"); 165 Id value) {
166 WriteStorage(ctx, binding, offset, ctx.OpSConvert(ctx.S16, value), ctx.storage_types.S16,
167 sizeof(s16), &StorageDefinitions::S16);
149} 168}
150 169
151void EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 170void EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
152 Id value) { 171 Id value) {
153 if (!binding.IsImmediate()) { 172 WriteStorage(ctx, binding, offset, value, ctx.storage_types.U32, sizeof(u32),
154 throw NotImplementedException("Dynamic storage buffer indexing"); 173 &StorageDefinitions::U32);
155 }
156 const Id ssbo{ctx.ssbos[binding.U32()]};
157 const Id index{StorageIndex(ctx, offset, sizeof(u32))};
158 const Id pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index)};
159 ctx.OpStore(pointer, value);
160} 174}
161 175
162void EmitWriteStorage64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 176void EmitWriteStorage64(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
163 Id value) { 177 Id value) {
164 if (!binding.IsImmediate()) { 178 WriteStorage(ctx, binding, offset, value, ctx.storage_types.U32x2, sizeof(u32[2]),
165 throw NotImplementedException("Dynamic storage buffer indexing"); 179 &StorageDefinitions::U32x2);
166 }
167 // TODO: Support reinterpreting bindings, guaranteed to be aligned
168 const Id ssbo{ctx.ssbos[binding.U32()]};
169 const Id low_index{StorageIndex(ctx, offset, sizeof(u32))};
170 const Id high_index{ctx.OpIAdd(ctx.U32[1], low_index, ctx.Constant(ctx.U32[1], 1U))};
171 const Id low_pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, low_index)};
172 const Id high_pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, high_index)};
173 ctx.OpStore(low_pointer, ctx.OpCompositeExtract(ctx.U32[1], value, 0U));
174 ctx.OpStore(high_pointer, ctx.OpCompositeExtract(ctx.U32[1], value, 1U));
175} 180}
176 181
177void EmitWriteStorage128(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset, 182void EmitWriteStorage128(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
178 Id value) { 183 Id value) {
179 if (!binding.IsImmediate()) { 184 WriteStorage(ctx, binding, offset, value, ctx.storage_types.U32x4, sizeof(u32[4]),
180 throw NotImplementedException("Dynamic storage buffer indexing"); 185 &StorageDefinitions::U32x4);
181 }
182 // TODO: Support reinterpreting bindings, guaranteed to be aligned
183 const Id ssbo{ctx.ssbos[binding.U32()]};
184 const Id base_index{StorageIndex(ctx, offset, sizeof(u32))};
185 for (u32 element = 0; element < 4; ++element) {
186 Id index = base_index;
187 if (element > 0) {
188 index = ctx.OpIAdd(ctx.U32[1], base_index, ctx.Constant(ctx.U32[1], element));
189 }
190 const Id pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index)};
191 ctx.OpStore(pointer, ctx.OpCompositeExtract(ctx.U32[1], value, element));
192 }
193} 186}
194 187
195} // namespace Shader::Backend::SPIRV 188} // namespace Shader::Backend::SPIRV
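
Annotation: the formerly unimplemented narrow accesses now load through typed pointers and widen to the 32-bit register width, with OpUConvert zero-extending the unsigned cases and OpSConvert sign-extending the signed ones. A host-side analogue of the two load paths:

```cpp
#include <cstdint>
#include <cstdio>

// Unsigned narrow loads zero-extend (OpUConvert); signed ones
// sign-extend (OpSConvert) before landing in a 32-bit register.
std::uint32_t LoadStorageU8(std::uint8_t v) {
    return static_cast<std::uint32_t>(v);
}

std::uint32_t LoadStorageS8(std::int8_t v) {
    return static_cast<std::uint32_t>(static_cast<std::int32_t>(v));
}

int main() {
    std::printf("0x%08x\n", LoadStorageU8(0x80u));                          // 0x00000080
    std::printf("0x%08x\n", LoadStorageS8(static_cast<std::int8_t>(0x80))); // 0xffffff80
}
```
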
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
index ab529e86d..116d93c1c 100644
--- a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -315,6 +315,23 @@ void VisitUsages(Info& info, IR::Inst& inst) {
315 case IR::Opcode::ConvertF32U64: 315 case IR::Opcode::ConvertF32U64:
316 case IR::Opcode::ConvertF64U64: 316 case IR::Opcode::ConvertF64U64:
317 case IR::Opcode::SharedAtomicExchange64: 317 case IR::Opcode::SharedAtomicExchange64:
318 case IR::Opcode::GlobalAtomicIAdd64:
319 case IR::Opcode::GlobalAtomicSMin64:
320 case IR::Opcode::GlobalAtomicUMin64:
321 case IR::Opcode::GlobalAtomicSMax64:
322 case IR::Opcode::GlobalAtomicUMax64:
323 case IR::Opcode::GlobalAtomicAnd64:
324 case IR::Opcode::GlobalAtomicOr64:
325 case IR::Opcode::GlobalAtomicXor64:
326 case IR::Opcode::GlobalAtomicExchange64:
327 case IR::Opcode::StorageAtomicIAdd64:
328 case IR::Opcode::StorageAtomicSMin64:
329 case IR::Opcode::StorageAtomicUMin64:
330 case IR::Opcode::StorageAtomicSMax64:
331 case IR::Opcode::StorageAtomicUMax64:
332 case IR::Opcode::StorageAtomicAnd64:
333 case IR::Opcode::StorageAtomicOr64:
334 case IR::Opcode::StorageAtomicXor64:
318 info.uses_int64 = true; 335 info.uses_int64 = true;
319 break; 336 break;
320 default: 337 default:
@@ -457,46 +474,91 @@ void VisitUsages(Info& info, IR::Inst& inst) {
457 case IR::Opcode::FSwizzleAdd: 474 case IR::Opcode::FSwizzleAdd:
458 info.uses_fswzadd = true; 475 info.uses_fswzadd = true;
459 break; 476 break;
477 case IR::Opcode::LoadStorageU8:
478 case IR::Opcode::LoadStorageS8:
479 case IR::Opcode::WriteStorageU8:
480 case IR::Opcode::WriteStorageS8:
481 info.used_storage_buffer_types |= IR::Type::U8;
482 break;
483 case IR::Opcode::LoadStorageU16:
484 case IR::Opcode::LoadStorageS16:
485 case IR::Opcode::WriteStorageU16:
486 case IR::Opcode::WriteStorageS16:
487 info.used_storage_buffer_types |= IR::Type::U16;
488 break;
489 case IR::Opcode::LoadStorage32:
490 case IR::Opcode::WriteStorage32:
491 case IR::Opcode::StorageAtomicIAdd32:
492 case IR::Opcode::StorageAtomicSMin32:
493 case IR::Opcode::StorageAtomicUMin32:
494 case IR::Opcode::StorageAtomicSMax32:
495 case IR::Opcode::StorageAtomicUMax32:
496 case IR::Opcode::StorageAtomicAnd32:
497 case IR::Opcode::StorageAtomicOr32:
498 case IR::Opcode::StorageAtomicXor32:
499 case IR::Opcode::StorageAtomicExchange32:
500 info.used_storage_buffer_types |= IR::Type::U32;
501 break;
502 case IR::Opcode::LoadStorage64:
503 case IR::Opcode::WriteStorage64:
504 info.used_storage_buffer_types |= IR::Type::U32x2;
505 break;
506 case IR::Opcode::LoadStorage128:
507 case IR::Opcode::WriteStorage128:
508 info.used_storage_buffer_types |= IR::Type::U32x4;
509 break;
460 case IR::Opcode::SharedAtomicInc32: 510 case IR::Opcode::SharedAtomicInc32:
461 info.uses_shared_increment = true; 511 info.uses_shared_increment = true;
462 break; 512 break;
463 case IR::Opcode::SharedAtomicDec32: 513 case IR::Opcode::SharedAtomicDec32:
464 info.uses_shared_decrement = true; 514 info.uses_shared_decrement = true;
465 break; 515 break;
516 case IR::Opcode::SharedAtomicExchange64:
517 info.uses_int64_bit_atomics = true;
518 break;
466 case IR::Opcode::GlobalAtomicInc32: 519 case IR::Opcode::GlobalAtomicInc32:
467 case IR::Opcode::StorageAtomicInc32: 520 case IR::Opcode::StorageAtomicInc32:
521 info.used_storage_buffer_types |= IR::Type::U32;
468 info.uses_global_increment = true; 522 info.uses_global_increment = true;
469 break; 523 break;
470 case IR::Opcode::GlobalAtomicDec32: 524 case IR::Opcode::GlobalAtomicDec32:
471 case IR::Opcode::StorageAtomicDec32: 525 case IR::Opcode::StorageAtomicDec32:
526 info.used_storage_buffer_types |= IR::Type::U32;
472 info.uses_global_decrement = true; 527 info.uses_global_decrement = true;
473 break; 528 break;
474 case IR::Opcode::GlobalAtomicAddF32: 529 case IR::Opcode::GlobalAtomicAddF32:
475 case IR::Opcode::StorageAtomicAddF32: 530 case IR::Opcode::StorageAtomicAddF32:
531 info.used_storage_buffer_types |= IR::Type::U32;
476 info.uses_atomic_f32_add = true; 532 info.uses_atomic_f32_add = true;
477 break; 533 break;
478 case IR::Opcode::GlobalAtomicAddF16x2: 534 case IR::Opcode::GlobalAtomicAddF16x2:
479 case IR::Opcode::StorageAtomicAddF16x2: 535 case IR::Opcode::StorageAtomicAddF16x2:
536 info.used_storage_buffer_types |= IR::Type::U32;
480 info.uses_atomic_f16x2_add = true; 537 info.uses_atomic_f16x2_add = true;
481 break; 538 break;
482 case IR::Opcode::GlobalAtomicAddF32x2: 539 case IR::Opcode::GlobalAtomicAddF32x2:
483 case IR::Opcode::StorageAtomicAddF32x2: 540 case IR::Opcode::StorageAtomicAddF32x2:
541 info.used_storage_buffer_types |= IR::Type::U32;
484 info.uses_atomic_f32x2_add = true; 542 info.uses_atomic_f32x2_add = true;
485 break; 543 break;
486 case IR::Opcode::GlobalAtomicMinF16x2: 544 case IR::Opcode::GlobalAtomicMinF16x2:
487 case IR::Opcode::StorageAtomicMinF16x2: 545 case IR::Opcode::StorageAtomicMinF16x2:
546 info.used_storage_buffer_types |= IR::Type::U32;
488 info.uses_atomic_f16x2_min = true; 547 info.uses_atomic_f16x2_min = true;
489 break; 548 break;
490 case IR::Opcode::GlobalAtomicMinF32x2: 549 case IR::Opcode::GlobalAtomicMinF32x2:
491 case IR::Opcode::StorageAtomicMinF32x2: 550 case IR::Opcode::StorageAtomicMinF32x2:
551 info.used_storage_buffer_types |= IR::Type::U32;
492 info.uses_atomic_f32x2_min = true; 552 info.uses_atomic_f32x2_min = true;
493 break; 553 break;
494 case IR::Opcode::GlobalAtomicMaxF16x2: 554 case IR::Opcode::GlobalAtomicMaxF16x2:
495 case IR::Opcode::StorageAtomicMaxF16x2: 555 case IR::Opcode::StorageAtomicMaxF16x2:
556 info.used_storage_buffer_types |= IR::Type::U32;
496 info.uses_atomic_f16x2_max = true; 557 info.uses_atomic_f16x2_max = true;
497 break; 558 break;
498 case IR::Opcode::GlobalAtomicMaxF32x2: 559 case IR::Opcode::GlobalAtomicMaxF32x2:
499 case IR::Opcode::StorageAtomicMaxF32x2: 560 case IR::Opcode::StorageAtomicMaxF32x2:
561 info.used_storage_buffer_types |= IR::Type::U32;
500 info.uses_atomic_f32x2_max = true; 562 info.uses_atomic_f32x2_max = true;
501 break; 563 break;
502 case IR::Opcode::GlobalAtomicIAdd64: 564 case IR::Opcode::GlobalAtomicIAdd64:
@@ -516,11 +578,8 @@ void VisitUsages(Info& info, IR::Inst& inst) {
516 case IR::Opcode::StorageAtomicAnd64: 578 case IR::Opcode::StorageAtomicAnd64:
517 case IR::Opcode::StorageAtomicOr64: 579 case IR::Opcode::StorageAtomicOr64:
518 case IR::Opcode::StorageAtomicXor64: 580 case IR::Opcode::StorageAtomicXor64:
519 info.uses_64_bit_atomics = true; 581 info.used_storage_buffer_types |= IR::Type::U64;
520 break; 582 info.uses_int64_bit_atomics = true;
521 case IR::Opcode::SharedAtomicExchange64:
522 info.uses_64_bit_atomics = true;
523 info.uses_shared_memory_u32x2 = true;
524 break; 583 break;
525 default: 584 default:
526 break; 585 break;
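
Annotation: used_storage_buffer_types is a flags-style IR::Type bitmask; the pass ORs in one width bit per storage opcode so the SPIR-V backend declares only the typed SSBO views a shader actually touches. A minimal sketch of the accumulation pattern (enumerator values are illustrative, not the real IR::Type encodings):

```cpp
#include <cstdint>
#include <cstdio>

// Flags-style enum standing in for IR::Type.
enum class Type : std::uint32_t {
    U8 = 1u << 0,
    U16 = 1u << 1,
    U32 = 1u << 2,
    U32x2 = 1u << 3,
    U32x4 = 1u << 4,
    U64 = 1u << 5,
};

constexpr Type& operator|=(Type& lhs, Type rhs) {
    lhs = static_cast<Type>(static_cast<std::uint32_t>(lhs) |
                            static_cast<std::uint32_t>(rhs));
    return lhs;
}

constexpr bool Has(Type mask, Type bit) {
    return (static_cast<std::uint32_t>(mask) & static_cast<std::uint32_t>(bit)) != 0;
}

int main() {
    Type used_storage_buffer_types{};
    used_storage_buffer_types |= Type::U32;   // e.g. on LoadStorage32
    used_storage_buffer_types |= Type::U32x2; // e.g. on WriteStorage64
    std::printf("%d %d\n", Has(used_storage_buffer_types, Type::U32),
                Has(used_storage_buffer_types, Type::U8)); // 1 0
}
```
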
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index 6a51aabb5..15cf09c3d 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -141,10 +141,10 @@ struct Info {
141 bool uses_atomic_f32x2_add{}; 141 bool uses_atomic_f32x2_add{};
142 bool uses_atomic_f32x2_min{}; 142 bool uses_atomic_f32x2_min{};
143 bool uses_atomic_f32x2_max{}; 143 bool uses_atomic_f32x2_max{};
144 bool uses_64_bit_atomics{}; 144 bool uses_int64_bit_atomics{};
145 bool uses_shared_memory_u32x2{};
146 145
147 IR::Type used_constant_buffer_types{}; 146 IR::Type used_constant_buffer_types{};
147 IR::Type used_storage_buffer_types{};
148 148
149 u32 constant_buffer_mask{}; 149 u32 constant_buffer_mask{};
150 150
diff --git a/src/video_core/vulkan_common/vulkan_device.cpp b/src/video_core/vulkan_common/vulkan_device.cpp
index 911dfed44..87cfe6312 100644
--- a/src/video_core/vulkan_common/vulkan_device.cpp
+++ b/src/video_core/vulkan_common/vulkan_device.cpp
@@ -44,6 +44,7 @@ constexpr std::array REQUIRED_EXTENSIONS{
44 VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME, 44 VK_KHR_TIMELINE_SEMAPHORE_EXTENSION_NAME,
45 VK_KHR_SAMPLER_MIRROR_CLAMP_TO_EDGE_EXTENSION_NAME, 45 VK_KHR_SAMPLER_MIRROR_CLAMP_TO_EDGE_EXTENSION_NAME,
46 VK_KHR_SHADER_FLOAT_CONTROLS_EXTENSION_NAME, 46 VK_KHR_SHADER_FLOAT_CONTROLS_EXTENSION_NAME,
47 VK_KHR_VARIABLE_POINTERS_EXTENSION_NAME,
47 VK_EXT_VERTEX_ATTRIBUTE_DIVISOR_EXTENSION_NAME, 48 VK_EXT_VERTEX_ATTRIBUTE_DIVISOR_EXTENSION_NAME,
48 VK_EXT_SHADER_SUBGROUP_BALLOT_EXTENSION_NAME, 49 VK_EXT_SHADER_SUBGROUP_BALLOT_EXTENSION_NAME,
49 VK_EXT_SHADER_SUBGROUP_VOTE_EXTENSION_NAME, 50 VK_EXT_SHADER_SUBGROUP_VOTE_EXTENSION_NAME,
@@ -313,6 +314,14 @@ Device::Device(VkInstance instance_, vk::PhysicalDevice physical_, VkSurfaceKHR
313 }; 314 };
314 SetNext(next, host_query_reset); 315 SetNext(next, host_query_reset);
315 316
317 VkPhysicalDeviceVariablePointerFeaturesKHR variable_pointers{
318 .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTERS_FEATURES_KHR,
319 .pNext = nullptr,
320 .variablePointersStorageBuffer = VK_TRUE,
321 .variablePointers = VK_TRUE,
322 };
323 SetNext(next, variable_pointers);
324
316 VkPhysicalDeviceShaderDemoteToHelperInvocationFeaturesEXT demote{ 325 VkPhysicalDeviceShaderDemoteToHelperInvocationFeaturesEXT demote{
317 .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_DEMOTE_TO_HELPER_INVOCATION_FEATURES_EXT, 326 .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_DEMOTE_TO_HELPER_INVOCATION_FEATURES_EXT,
318 .pNext = nullptr, 327 .pNext = nullptr,
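
Annotation: the enabled-features chain is assembled by linking each feature struct into the device-creation pNext list through SetNext, which is how the new variable_pointers struct joins the chain. A minimal sketch of the moving-tail pattern; the real SetNext helper's exact signature is an assumption here, and plain structs stand in for the Vulkan feature structs:

```cpp
#include <cstdio>

// Each node is appended at the current tail, and the tail then
// advances to that node's pNext slot.
struct Node {
    const char* name;
    void* pNext;
};

void SetNext(void**& tail, Node& node) {
    *tail = &node;
    tail = &node.pNext;
}

int main() {
    Node features2{"features2", nullptr};
    Node variable_pointers{"variable_pointers", nullptr};
    Node demote{"demote", nullptr};
    void** tail = &features2.pNext;
    SetNext(tail, variable_pointers);
    SetNext(tail, demote);
    for (const Node* node = &features2; node != nullptr;
         node = static_cast<const Node*>(node->pNext)) {
        std::printf("%s\n", node->name); // features2, variable_pointers, demote
    }
}
```
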
@@ -399,6 +408,17 @@ Device::Device(VkInstance instance_, vk::PhysicalDevice physical_, VkSurfaceKHR
399 LOG_INFO(Render_Vulkan, "Device doesn't support extended dynamic state"); 408 LOG_INFO(Render_Vulkan, "Device doesn't support extended dynamic state");
400 } 409 }
401 410
411 VkPhysicalDeviceShaderAtomicInt64FeaturesKHR atomic_int64;
412 if (ext_shader_atomic_int64) {
413 atomic_int64 = {
414 .sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_ATOMIC_INT64_FEATURES_KHR,
415 .pNext = nullptr,
416 .shaderBufferInt64Atomics = VK_TRUE,
417 .shaderSharedInt64Atomics = VK_TRUE,
418 };
419 SetNext(next, atomic_int64);
420 }
421
402 VkPhysicalDeviceWorkgroupMemoryExplicitLayoutFeaturesKHR workgroup_layout; 422 VkPhysicalDeviceWorkgroupMemoryExplicitLayoutFeaturesKHR workgroup_layout;
403 if (khr_workgroup_memory_explicit_layout) { 423 if (khr_workgroup_memory_explicit_layout) {
404 workgroup_layout = { 424 workgroup_layout = {
@@ -624,9 +644,13 @@ void Device::CheckSuitability(bool requires_swapchain) const {
624 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_DEMOTE_TO_HELPER_INVOCATION_FEATURES_EXT; 644 VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_DEMOTE_TO_HELPER_INVOCATION_FEATURES_EXT;
625 demote.pNext = nullptr; 645 demote.pNext = nullptr;
626 646
647 VkPhysicalDeviceVariablePointerFeaturesKHR variable_pointers{};
648 variable_pointers.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VARIABLE_POINTERS_FEATURES_KHR;
649 variable_pointers.pNext = &demote;
650
627 VkPhysicalDeviceRobustness2FeaturesEXT robustness2{}; 651 VkPhysicalDeviceRobustness2FeaturesEXT robustness2{};
628 robustness2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ROBUSTNESS_2_FEATURES_EXT; 652 robustness2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_ROBUSTNESS_2_FEATURES_EXT;
629 robustness2.pNext = &demote; 653 robustness2.pNext = &variable_pointers;
630 654
631 VkPhysicalDeviceFeatures2KHR features2{}; 655 VkPhysicalDeviceFeatures2KHR features2{};
632 features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; 656 features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
@@ -654,6 +678,9 @@ void Device::CheckSuitability(bool requires_swapchain) const {
654 std::make_pair(features.shaderStorageImageWriteWithoutFormat, 678 std::make_pair(features.shaderStorageImageWriteWithoutFormat,
655 "shaderStorageImageWriteWithoutFormat"), 679 "shaderStorageImageWriteWithoutFormat"),
656 std::make_pair(demote.shaderDemoteToHelperInvocation, "shaderDemoteToHelperInvocation"), 680 std::make_pair(demote.shaderDemoteToHelperInvocation, "shaderDemoteToHelperInvocation"),
681 std::make_pair(variable_pointers.variablePointers, "variablePointers"),
682 std::make_pair(variable_pointers.variablePointersStorageBuffer,
683 "variablePointersStorageBuffer"),
657 std::make_pair(robustness2.robustBufferAccess2, "robustBufferAccess2"), 684 std::make_pair(robustness2.robustBufferAccess2, "robustBufferAccess2"),
658 std::make_pair(robustness2.robustImageAccess2, "robustImageAccess2"), 685 std::make_pair(robustness2.robustImageAccess2, "robustImageAccess2"),
659 std::make_pair(robustness2.nullDescriptor, "nullDescriptor"), 686 std::make_pair(robustness2.nullDescriptor, "nullDescriptor"),
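
Annotation: CheckSuitability collects required features as (VkBool32, name) pairs so a single loop can report exactly which capability a device lacks; the new variablePointers entries slot into that table. A stand-alone sketch of the pattern, with a simplified report loop in place of whatever error handling the real function uses:

```cpp
#include <array>
#include <cstdio>
#include <utility>

using VkBool32 = unsigned int;

int main() {
    // Each required feature bit is paired with a printable name,
    // mirroring the std::make_pair table in CheckSuitability.
    const std::array feature_report{
        std::make_pair(VkBool32{1}, "variablePointers"),
        std::make_pair(VkBool32{0}, "variablePointersStorageBuffer"),
    };
    for (const auto& [is_supported, name] : feature_report) {
        if (!is_supported) {
            std::printf("Missing required feature: %s\n", name);
        }
    }
}
```
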