summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to '')
m---------externals/sirit0
-rw-r--r--src/shader_recompiler/CMakeLists.txt4
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.cpp160
-rw-r--r--src/shader_recompiler/backend/spirv/emit_context.h67
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv.cpp189
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv.h84
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp4
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp2
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp20
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp26
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp18
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp16
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp36
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp4
-rw-r--r--src/shader_recompiler/frontend/ir/basic_block.h16
-rw-r--r--src/shader_recompiler/frontend/ir/program.h2
-rw-r--r--src/shader_recompiler/frontend/maxwell/program.cpp7
-rw-r--r--src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp81
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp76
-rw-r--r--src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp110
-rw-r--r--src/shader_recompiler/ir_opt/passes.h4
-rw-r--r--src/shader_recompiler/main.cpp4
-rw-r--r--src/shader_recompiler/shader_info.h33
23 files changed, 671 insertions, 292 deletions
diff --git a/externals/sirit b/externals/sirit
Subproject f819ade0efe925a782090dea9e1bf300fedffb3 Subproject 200310e8faa756b9869dd6dfc902c255246ac74
diff --git a/src/shader_recompiler/CMakeLists.txt b/src/shader_recompiler/CMakeLists.txt
index e1f4276a1..84be94a8d 100644
--- a/src/shader_recompiler/CMakeLists.txt
+++ b/src/shader_recompiler/CMakeLists.txt
@@ -1,4 +1,6 @@
1add_executable(shader_recompiler 1add_executable(shader_recompiler
2 backend/spirv/emit_context.cpp
3 backend/spirv/emit_context.h
2 backend/spirv/emit_spirv.cpp 4 backend/spirv/emit_spirv.cpp
3 backend/spirv/emit_spirv.h 5 backend/spirv/emit_spirv.h
4 backend/spirv/emit_spirv_bitwise_conversion.cpp 6 backend/spirv/emit_spirv_bitwise_conversion.cpp
@@ -75,6 +77,7 @@ add_executable(shader_recompiler
75 frontend/maxwell/translate/impl/move_special_register.cpp 77 frontend/maxwell/translate/impl/move_special_register.cpp
76 frontend/maxwell/translate/translate.cpp 78 frontend/maxwell/translate/translate.cpp
77 frontend/maxwell/translate/translate.h 79 frontend/maxwell/translate/translate.h
80 ir_opt/collect_shader_info_pass.cpp
78 ir_opt/constant_propagation_pass.cpp 81 ir_opt/constant_propagation_pass.cpp
79 ir_opt/dead_code_elimination_pass.cpp 82 ir_opt/dead_code_elimination_pass.cpp
80 ir_opt/global_memory_to_storage_buffer_pass.cpp 83 ir_opt/global_memory_to_storage_buffer_pass.cpp
@@ -84,6 +87,7 @@ add_executable(shader_recompiler
84 ir_opt/verification_pass.cpp 87 ir_opt/verification_pass.cpp
85 main.cpp 88 main.cpp
86 object_pool.h 89 object_pool.h
90 shader_info.h
87) 91)
88 92
89target_include_directories(video_core PRIVATE sirit) 93target_include_directories(video_core PRIVATE sirit)
diff --git a/src/shader_recompiler/backend/spirv/emit_context.cpp b/src/shader_recompiler/backend/spirv/emit_context.cpp
new file mode 100644
index 000000000..1c985aff8
--- /dev/null
+++ b/src/shader_recompiler/backend/spirv/emit_context.cpp
@@ -0,0 +1,160 @@
1// Copyright 2021 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#include <algorithm>
6#include <array>
7#include <string_view>
8
9#include <fmt/format.h>
10
11#include "common/common_types.h"
12#include "shader_recompiler/backend/spirv/emit_context.h"
13
14namespace Shader::Backend::SPIRV {
15
16void VectorTypes::Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) {
17 defs[0] = sirit_ctx.Name(base_type, name);
18
19 std::array<char, 6> def_name;
20 for (int i = 1; i < 4; ++i) {
21 const std::string_view def_name_view(
22 def_name.data(),
23 fmt::format_to_n(def_name.data(), def_name.size(), "{}x{}", name, i + 1).size);
24 defs[i] = sirit_ctx.Name(sirit_ctx.TypeVector(base_type, i + 1), def_name_view);
25 }
26}
27
28EmitContext::EmitContext(IR::Program& program) : Sirit::Module(0x00010000) {
29 AddCapability(spv::Capability::Shader);
30 DefineCommonTypes(program.info);
31 DefineCommonConstants();
32 DefineSpecialVariables(program.info);
33 DefineConstantBuffers(program.info);
34 DefineStorageBuffers(program.info);
35 DefineLabels(program);
36}
37
38EmitContext::~EmitContext() = default;
39
40Id EmitContext::Def(const IR::Value& value) {
41 if (!value.IsImmediate()) {
42 return value.Inst()->Definition<Id>();
43 }
44 switch (value.Type()) {
45 case IR::Type::U1:
46 return value.U1() ? true_value : false_value;
47 case IR::Type::U32:
48 return Constant(U32[1], value.U32());
49 case IR::Type::F32:
50 return Constant(F32[1], value.F32());
51 default:
52 throw NotImplementedException("Immediate type {}", value.Type());
53 }
54}
55
56void EmitContext::DefineCommonTypes(const Info& info) {
57 void_id = TypeVoid();
58
59 U1 = Name(TypeBool(), "u1");
60
61 F32.Define(*this, TypeFloat(32), "f32");
62 U32.Define(*this, TypeInt(32, false), "u32");
63
64 if (info.uses_fp16) {
65 AddCapability(spv::Capability::Float16);
66 F16.Define(*this, TypeFloat(16), "f16");
67 }
68 if (info.uses_fp64) {
69 AddCapability(spv::Capability::Float64);
70 F64.Define(*this, TypeFloat(64), "f64");
71 }
72}
73
74void EmitContext::DefineCommonConstants() {
75 true_value = ConstantTrue(U1);
76 false_value = ConstantFalse(U1);
77 u32_zero_value = Constant(U32[1], 0U);
78}
79
80void EmitContext::DefineSpecialVariables(const Info& info) {
81 const auto define{[this](Id type, spv::BuiltIn builtin, spv::StorageClass storage_class) {
82 const Id pointer_type{TypePointer(storage_class, type)};
83 const Id id{AddGlobalVariable(pointer_type, spv::StorageClass::Input)};
84 Decorate(id, spv::Decoration::BuiltIn, builtin);
85 return id;
86 }};
87 using namespace std::placeholders;
88 const auto define_input{std::bind(define, _1, _2, spv::StorageClass::Input)};
89
90 if (info.uses_workgroup_id) {
91 workgroup_id = define_input(U32[3], spv::BuiltIn::WorkgroupId);
92 }
93 if (info.uses_local_invocation_id) {
94 local_invocation_id = define_input(U32[3], spv::BuiltIn::LocalInvocationId);
95 }
96}
97
98void EmitContext::DefineConstantBuffers(const Info& info) {
99 if (info.constant_buffer_descriptors.empty()) {
100 return;
101 }
102 const Id array_type{TypeArray(U32[1], Constant(U32[1], 4096))};
103 Decorate(array_type, spv::Decoration::ArrayStride, 16U);
104
105 const Id struct_type{TypeStruct(array_type)};
106 Name(struct_type, "cbuf_block");
107 Decorate(struct_type, spv::Decoration::Block);
108 MemberName(struct_type, 0, "data");
109 MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
110
111 const Id uniform_type{TypePointer(spv::StorageClass::Uniform, struct_type)};
112 uniform_u32 = TypePointer(spv::StorageClass::Uniform, U32[1]);
113
114 u32 binding{};
115 for (const Info::ConstantBufferDescriptor& desc : info.constant_buffer_descriptors) {
116 const Id id{AddGlobalVariable(uniform_type, spv::StorageClass::Uniform)};
117 Decorate(id, spv::Decoration::Binding, binding);
118 Name(id, fmt::format("c{}", desc.index));
119 std::fill_n(cbufs.data() + desc.index, desc.count, id);
120 binding += desc.count;
121 }
122}
123
124void EmitContext::DefineStorageBuffers(const Info& info) {
125 if (info.storage_buffers_descriptors.empty()) {
126 return;
127 }
128 AddExtension("SPV_KHR_storage_buffer_storage_class");
129
130 const Id array_type{TypeRuntimeArray(U32[1])};
131 Decorate(array_type, spv::Decoration::ArrayStride, 4U);
132
133 const Id struct_type{TypeStruct(array_type)};
134 Name(struct_type, "ssbo_block");
135 Decorate(struct_type, spv::Decoration::Block);
136 MemberName(struct_type, 0, "data");
137 MemberDecorate(struct_type, 0, spv::Decoration::Offset, 0U);
138
139 const Id storage_type{TypePointer(spv::StorageClass::StorageBuffer, struct_type)};
140 storage_u32 = TypePointer(spv::StorageClass::StorageBuffer, U32[1]);
141
142 u32 binding{};
143 for (const Info::StorageBufferDescriptor& desc : info.storage_buffers_descriptors) {
144 const Id id{AddGlobalVariable(storage_type, spv::StorageClass::StorageBuffer)};
145 Decorate(id, spv::Decoration::Binding, binding);
146 Name(id, fmt::format("ssbo{}", binding));
147 std::fill_n(ssbos.data() + binding, desc.count, id);
148 binding += desc.count;
149 }
150}
151
152void EmitContext::DefineLabels(IR::Program& program) {
153 for (const IR::Function& function : program.functions) {
154 for (IR::Block* const block : function.blocks) {
155 block->SetDefinition(OpLabel());
156 }
157 }
158}
159
160} // namespace Shader::Backend::SPIRV
diff --git a/src/shader_recompiler/backend/spirv/emit_context.h b/src/shader_recompiler/backend/spirv/emit_context.h
new file mode 100644
index 000000000..c4b84759d
--- /dev/null
+++ b/src/shader_recompiler/backend/spirv/emit_context.h
@@ -0,0 +1,67 @@
1// Copyright 2021 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#pragma once
6
7#include <array>
8#include <string_view>
9
10#include <sirit/sirit.h>
11
12#include "shader_recompiler/frontend/ir/program.h"
13#include "shader_recompiler/shader_info.h"
14
15namespace Shader::Backend::SPIRV {
16
17using Sirit::Id;
18
19class VectorTypes {
20public:
21 void Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name);
22
23 [[nodiscard]] Id operator[](size_t size) const noexcept {
24 return defs[size - 1];
25 }
26
27private:
28 std::array<Id, 4> defs{};
29};
30
31class EmitContext final : public Sirit::Module {
32public:
33 explicit EmitContext(IR::Program& program);
34 ~EmitContext();
35
36 [[nodiscard]] Id Def(const IR::Value& value);
37
38 Id void_id{};
39 Id U1{};
40 VectorTypes F32;
41 VectorTypes U32;
42 VectorTypes F16;
43 VectorTypes F64;
44
45 Id true_value{};
46 Id false_value{};
47 Id u32_zero_value{};
48
49 Id uniform_u32{};
50 Id storage_u32{};
51
52 std::array<Id, Info::MAX_CBUFS> cbufs{};
53 std::array<Id, Info::MAX_SSBOS> ssbos{};
54
55 Id workgroup_id{};
56 Id local_invocation_id{};
57
58private:
59 void DefineCommonTypes(const Info& info);
60 void DefineCommonConstants();
61 void DefineSpecialVariables(const Info& info);
62 void DefineConstantBuffers(const Info& info);
63 void DefineStorageBuffers(const Info& info);
64 void DefineLabels(IR::Program& program);
65};
66
67} // namespace Shader::Backend::SPIRV
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 0895414b4..c79c09774 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -12,31 +12,83 @@
12#include "shader_recompiler/frontend/ir/program.h" 12#include "shader_recompiler/frontend/ir/program.h"
13 13
14namespace Shader::Backend::SPIRV { 14namespace Shader::Backend::SPIRV {
15namespace {
16template <class Func>
17struct FuncTraits : FuncTraits<decltype(&Func::operator())> {};
15 18
16EmitContext::EmitContext(IR::Program& program) { 19template <class ClassType, class ReturnType_, class... Args>
17 AddCapability(spv::Capability::Shader); 20struct FuncTraits<ReturnType_ (ClassType::*)(Args...)> {
18 AddCapability(spv::Capability::Float16); 21 using ReturnType = ReturnType_;
19 AddCapability(spv::Capability::Float64);
20 void_id = TypeVoid();
21 22
22 u1 = Name(TypeBool(), "u1"); 23 static constexpr size_t NUM_ARGS = sizeof...(Args);
23 f32.Define(*this, TypeFloat(32), "f32");
24 u32.Define(*this, TypeInt(32, false), "u32");
25 f16.Define(*this, TypeFloat(16), "f16");
26 f64.Define(*this, TypeFloat(64), "f64");
27 24
28 true_value = ConstantTrue(u1); 25 template <size_t I>
29 false_value = ConstantFalse(u1); 26 using ArgType = std::tuple_element_t<I, std::tuple<Args...>>;
27};
30 28
31 for (const IR::Function& function : program.functions) { 29template <auto method, typename... Args>
32 for (IR::Block* const block : function.blocks) { 30void SetDefinition(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, Args... args) {
33 block_label_map.emplace_back(block, OpLabel()); 31 const Id forward_id{inst->Definition<Id>()};
32 const bool has_forward_id{Sirit::ValidId(forward_id)};
33 Id current_id{};
34 if (has_forward_id) {
35 current_id = ctx.ExchangeCurrentId(forward_id);
36 }
37 const Id new_id{(emit.*method)(ctx, std::forward<Args>(args)...)};
38 if (has_forward_id) {
39 ctx.ExchangeCurrentId(current_id);
40 } else {
41 inst->SetDefinition<Id>(new_id);
42 }
43}
44
45template <typename ArgType>
46ArgType Arg(EmitContext& ctx, const IR::Value& arg) {
47 if constexpr (std::is_same_v<ArgType, Id>) {
48 return ctx.Def(arg);
49 } else if constexpr (std::is_same_v<ArgType, const IR::Value&>) {
50 return arg;
51 } else if constexpr (std::is_same_v<ArgType, u32>) {
52 return arg.U32();
53 } else if constexpr (std::is_same_v<ArgType, IR::Block*>) {
54 return arg.Label();
55 }
56}
57
58template <auto method, bool is_first_arg_inst, size_t... I>
59void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, std::index_sequence<I...>) {
60 using Traits = FuncTraits<decltype(method)>;
61 if constexpr (std::is_same_v<Traits::ReturnType, Id>) {
62 if constexpr (is_first_arg_inst) {
63 SetDefinition<method>(emit, ctx, inst, inst,
64 Arg<Traits::ArgType<I + 2>>(ctx, inst->Arg(I))...);
65 } else {
66 SetDefinition<method>(emit, ctx, inst,
67 Arg<Traits::ArgType<I + 1>>(ctx, inst->Arg(I))...);
68 }
69 } else {
70 if constexpr (is_first_arg_inst) {
71 (emit.*method)(ctx, inst, Arg<Traits::ArgType<I + 2>>(ctx, inst->Arg(I))...);
72 } else {
73 (emit.*method)(ctx, Arg<Traits::ArgType<I + 1>>(ctx, inst->Arg(I))...);
34 } 74 }
35 } 75 }
36 std::ranges::sort(block_label_map, {}, &std::pair<IR::Block*, Id>::first);
37} 76}
38 77
39EmitContext::~EmitContext() = default; 78template <auto method>
79void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst) {
80 using Traits = FuncTraits<decltype(method)>;
81 static_assert(Traits::NUM_ARGS >= 1, "Insufficient arguments");
82 if constexpr (Traits::NUM_ARGS == 1) {
83 Invoke<method, false>(emit, ctx, inst, std::make_index_sequence<0>{});
84 } else {
85 using FirstArgType = typename Traits::template ArgType<1>;
86 static constexpr bool is_first_arg_inst = std::is_same_v<FirstArgType, IR::Inst*>;
87 using Indices = std::make_index_sequence<Traits::NUM_ARGS - (is_first_arg_inst ? 2 : 1)>;
88 Invoke<method, is_first_arg_inst>(emit, ctx, inst, Indices{});
89 }
90}
91} // Anonymous namespace
40 92
41EmitSPIRV::EmitSPIRV(IR::Program& program) { 93EmitSPIRV::EmitSPIRV(IR::Program& program) {
42 EmitContext ctx{program}; 94 EmitContext ctx{program};
@@ -46,74 +98,32 @@ EmitSPIRV::EmitSPIRV(IR::Program& program) {
46 for (IR::Function& function : program.functions) { 98 for (IR::Function& function : program.functions) {
47 func = ctx.OpFunction(ctx.void_id, spv::FunctionControlMask::MaskNone, void_function); 99 func = ctx.OpFunction(ctx.void_id, spv::FunctionControlMask::MaskNone, void_function);
48 for (IR::Block* const block : function.blocks) { 100 for (IR::Block* const block : function.blocks) {
49 ctx.AddLabel(ctx.BlockLabel(block)); 101 ctx.AddLabel(block->Definition<Id>());
50 for (IR::Inst& inst : block->Instructions()) { 102 for (IR::Inst& inst : block->Instructions()) {
51 EmitInst(ctx, &inst); 103 EmitInst(ctx, &inst);
52 } 104 }
53 } 105 }
54 ctx.OpFunctionEnd(); 106 ctx.OpFunctionEnd();
55 } 107 }
56 ctx.AddEntryPoint(spv::ExecutionModel::GLCompute, func, "main"); 108 boost::container::small_vector<Id, 32> interfaces;
109 if (program.info.uses_workgroup_id) {
110 interfaces.push_back(ctx.workgroup_id);
111 }
112 if (program.info.uses_local_invocation_id) {
113 interfaces.push_back(ctx.local_invocation_id);
114 }
115
116 const std::span interfaces_span(interfaces.data(), interfaces.size());
117 ctx.AddEntryPoint(spv::ExecutionModel::Fragment, func, "main", interfaces_span);
118 ctx.AddExecutionMode(func, spv::ExecutionMode::OriginUpperLeft);
57 119
58 std::vector<u32> result{ctx.Assemble()}; 120 std::vector<u32> result{ctx.Assemble()};
59 std::FILE* file{std::fopen("shader.spv", "wb")}; 121 std::FILE* file{std::fopen("D:\\shader.spv", "wb")};
60 std::fwrite(result.data(), sizeof(u32), result.size(), file); 122 std::fwrite(result.data(), sizeof(u32), result.size(), file);
61 std::fclose(file); 123 std::fclose(file);
62 std::system("spirv-dis shader.spv"); 124 std::system("spirv-dis D:\\shader.spv") == 0 &&
63 std::system("spirv-val shader.spv"); 125 std::system("spirv-val --uniform-buffer-standard-layout D:\\shader.spv") == 0 &&
64 std::system("spirv-cross shader.spv"); 126 std::system("spirv-cross -V D:\\shader.spv") == 0;
65}
66
67template <auto method, typename... Args>
68static void SetDefinition(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, Args... args) {
69 const Id forward_id{inst->Definition<Id>()};
70 const bool has_forward_id{Sirit::ValidId(forward_id)};
71 Id current_id{};
72 if (has_forward_id) {
73 current_id = ctx.ExchangeCurrentId(forward_id);
74 }
75 const Id new_id{(emit.*method)(ctx, std::forward<Args>(args)...)};
76 if (has_forward_id) {
77 ctx.ExchangeCurrentId(current_id);
78 } else {
79 inst->SetDefinition<Id>(new_id);
80 }
81}
82
83template <auto method>
84static void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst) {
85 using M = decltype(method);
86 using std::is_invocable_r_v;
87 if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&>) {
88 SetDefinition<method>(emit, ctx, inst);
89 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id>) {
90 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)));
91 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id, Id>) {
92 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)));
93 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id, Id, Id>) {
94 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)),
95 ctx.Def(inst->Arg(2)));
96 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, IR::Inst*>) {
97 SetDefinition<method>(emit, ctx, inst, inst);
98 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, IR::Inst*, Id, Id>) {
99 SetDefinition<method>(emit, ctx, inst, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)));
100 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, IR::Inst*, Id, Id, Id>) {
101 SetDefinition<method>(emit, ctx, inst, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)),
102 ctx.Def(inst->Arg(2)));
103 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id, u32>) {
104 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)), inst->Arg(1).U32());
105 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, const IR::Value&>) {
106 SetDefinition<method>(emit, ctx, inst, inst->Arg(0));
107 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, const IR::Value&,
108 const IR::Value&>) {
109 SetDefinition<method>(emit, ctx, inst, inst->Arg(0), inst->Arg(1));
110 } else if constexpr (is_invocable_r_v<void, M, EmitSPIRV&, EmitContext&, IR::Inst*>) {
111 (emit.*method)(ctx, inst);
112 } else if constexpr (is_invocable_r_v<void, M, EmitSPIRV&, EmitContext&>) {
113 (emit.*method)(ctx);
114 } else {
115 static_assert(false, "Bad format");
116 }
117} 127}
118 128
119void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) { 129void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) {
@@ -130,9 +140,9 @@ void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) {
130static Id TypeId(const EmitContext& ctx, IR::Type type) { 140static Id TypeId(const EmitContext& ctx, IR::Type type) {
131 switch (type) { 141 switch (type) {
132 case IR::Type::U1: 142 case IR::Type::U1:
133 return ctx.u1; 143 return ctx.U1;
134 case IR::Type::U32: 144 case IR::Type::U32:
135 return ctx.u32[1]; 145 return ctx.U32[1];
136 default: 146 default:
137 throw NotImplementedException("Phi node type {}", type); 147 throw NotImplementedException("Phi node type {}", type);
138 } 148 }
@@ -162,7 +172,7 @@ Id EmitSPIRV::EmitPhi(EmitContext& ctx, IR::Inst* inst) {
162 } 172 }
163 IR::Block* const phi_block{inst->PhiBlock(index)}; 173 IR::Block* const phi_block{inst->PhiBlock(index)};
164 operands.push_back(def); 174 operands.push_back(def);
165 operands.push_back(ctx.BlockLabel(phi_block)); 175 operands.push_back(phi_block->Definition<Id>());
166 } 176 }
167 const Id result_type{TypeId(ctx, inst->Arg(0).Type())}; 177 const Id result_type{TypeId(ctx, inst->Arg(0).Type())};
168 return ctx.OpPhi(result_type, std::span(operands.data(), operands.size())); 178 return ctx.OpPhi(result_type, std::span(operands.data(), operands.size()));
@@ -174,29 +184,6 @@ void EmitSPIRV::EmitIdentity(EmitContext&) {
174 throw NotImplementedException("SPIR-V Instruction"); 184 throw NotImplementedException("SPIR-V Instruction");
175} 185}
176 186
177// FIXME: Move to its own file
178void EmitSPIRV::EmitBranch(EmitContext& ctx, IR::Inst* inst) {
179 ctx.OpBranch(ctx.BlockLabel(inst->Arg(0).Label()));
180}
181
182void EmitSPIRV::EmitBranchConditional(EmitContext& ctx, IR::Inst* inst) {
183 ctx.OpBranchConditional(ctx.Def(inst->Arg(0)), ctx.BlockLabel(inst->Arg(1).Label()),
184 ctx.BlockLabel(inst->Arg(2).Label()));
185}
186
187void EmitSPIRV::EmitLoopMerge(EmitContext& ctx, IR::Inst* inst) {
188 ctx.OpLoopMerge(ctx.BlockLabel(inst->Arg(0).Label()), ctx.BlockLabel(inst->Arg(1).Label()),
189 spv::LoopControlMask::MaskNone);
190}
191
192void EmitSPIRV::EmitSelectionMerge(EmitContext& ctx, IR::Inst* inst) {
193 ctx.OpSelectionMerge(ctx.BlockLabel(inst->Arg(0).Label()), spv::SelectionControlMask::MaskNone);
194}
195
196void EmitSPIRV::EmitReturn(EmitContext& ctx) {
197 ctx.OpReturn();
198}
199
200void EmitSPIRV::EmitGetZeroFromOp(EmitContext&) { 187void EmitSPIRV::EmitGetZeroFromOp(EmitContext&) {
201 throw LogicError("Unreachable instruction"); 188 throw LogicError("Unreachable instruction");
202} 189}
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.h b/src/shader_recompiler/backend/spirv/emit_spirv.h
index 7d76377b5..a5d0e1ec0 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.h
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.h
@@ -7,82 +7,12 @@
7#include <sirit/sirit.h> 7#include <sirit/sirit.h>
8 8
9#include "common/common_types.h" 9#include "common/common_types.h"
10#include "shader_recompiler/backend/spirv/emit_context.h"
10#include "shader_recompiler/frontend/ir/microinstruction.h" 11#include "shader_recompiler/frontend/ir/microinstruction.h"
11#include "shader_recompiler/frontend/ir/program.h" 12#include "shader_recompiler/frontend/ir/program.h"
12 13
13namespace Shader::Backend::SPIRV { 14namespace Shader::Backend::SPIRV {
14 15
15using Sirit::Id;
16
17class VectorTypes {
18public:
19 void Define(Sirit::Module& sirit_ctx, Id base_type, std::string_view name) {
20 defs[0] = sirit_ctx.Name(base_type, name);
21
22 std::array<char, 6> def_name;
23 for (int i = 1; i < 4; ++i) {
24 const std::string_view def_name_view(
25 def_name.data(),
26 fmt::format_to_n(def_name.data(), def_name.size(), "{}x{}", name, i + 1).size);
27 defs[i] = sirit_ctx.Name(sirit_ctx.TypeVector(base_type, i + 1), def_name_view);
28 }
29 }
30
31 [[nodiscard]] Id operator[](size_t size) const noexcept {
32 return defs[size - 1];
33 }
34
35private:
36 std::array<Id, 4> defs;
37};
38
39class EmitContext final : public Sirit::Module {
40public:
41 explicit EmitContext(IR::Program& program);
42 ~EmitContext();
43
44 [[nodiscard]] Id Def(const IR::Value& value) {
45 if (!value.IsImmediate()) {
46 return value.Inst()->Definition<Id>();
47 }
48 switch (value.Type()) {
49 case IR::Type::U1:
50 return value.U1() ? true_value : false_value;
51 case IR::Type::U32:
52 return Constant(u32[1], value.U32());
53 case IR::Type::F32:
54 return Constant(f32[1], value.F32());
55 default:
56 throw NotImplementedException("Immediate type {}", value.Type());
57 }
58 }
59
60 [[nodiscard]] Id BlockLabel(IR::Block* block) const {
61 const auto it{std::ranges::lower_bound(block_label_map, block, {},
62 &std::pair<IR::Block*, Id>::first)};
63 if (it == block_label_map.end()) {
64 throw LogicError("Undefined block");
65 }
66 return it->second;
67 }
68
69 Id void_id{};
70 Id u1{};
71 VectorTypes f32;
72 VectorTypes u32;
73 VectorTypes f16;
74 VectorTypes f64;
75
76 Id true_value{};
77 Id false_value{};
78
79 Id workgroup_id{};
80 Id local_invocation_id{};
81
82private:
83 std::vector<std::pair<IR::Block*, Id>> block_label_map;
84};
85
86class EmitSPIRV { 16class EmitSPIRV {
87public: 17public:
88 explicit EmitSPIRV(IR::Program& program); 18 explicit EmitSPIRV(IR::Program& program);
@@ -94,10 +24,11 @@ private:
94 Id EmitPhi(EmitContext& ctx, IR::Inst* inst); 24 Id EmitPhi(EmitContext& ctx, IR::Inst* inst);
95 void EmitVoid(EmitContext& ctx); 25 void EmitVoid(EmitContext& ctx);
96 void EmitIdentity(EmitContext& ctx); 26 void EmitIdentity(EmitContext& ctx);
97 void EmitBranch(EmitContext& ctx, IR::Inst* inst); 27 void EmitBranch(EmitContext& ctx, IR::Block* label);
98 void EmitBranchConditional(EmitContext& ctx, IR::Inst* inst); 28 void EmitBranchConditional(EmitContext& ctx, Id condition, IR::Block* true_label,
99 void EmitLoopMerge(EmitContext& ctx, IR::Inst* inst); 29 IR::Block* false_label);
100 void EmitSelectionMerge(EmitContext& ctx, IR::Inst* inst); 30 void EmitLoopMerge(EmitContext& ctx, IR::Block* merge_label, IR::Block* continue_label);
31 void EmitSelectionMerge(EmitContext& ctx, IR::Block* merge_label);
101 void EmitReturn(EmitContext& ctx); 32 void EmitReturn(EmitContext& ctx);
102 void EmitGetRegister(EmitContext& ctx); 33 void EmitGetRegister(EmitContext& ctx);
103 void EmitSetRegister(EmitContext& ctx); 34 void EmitSetRegister(EmitContext& ctx);
@@ -150,7 +81,8 @@ private:
150 void EmitWriteStorageS8(EmitContext& ctx); 81 void EmitWriteStorageS8(EmitContext& ctx);
151 void EmitWriteStorageU16(EmitContext& ctx); 82 void EmitWriteStorageU16(EmitContext& ctx);
152 void EmitWriteStorageS16(EmitContext& ctx); 83 void EmitWriteStorageS16(EmitContext& ctx);
153 void EmitWriteStorage32(EmitContext& ctx); 84 void EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding, const IR::Value& offset,
85 Id value);
154 void EmitWriteStorage64(EmitContext& ctx); 86 void EmitWriteStorage64(EmitContext& ctx);
155 void EmitWriteStorage128(EmitContext& ctx); 87 void EmitWriteStorage128(EmitContext& ctx);
156 void EmitCompositeConstructU32x2(EmitContext& ctx); 88 void EmitCompositeConstructU32x2(EmitContext& ctx);
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp
index 447df5b8c..af82df99c 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_bitwise_conversion.cpp
@@ -11,7 +11,7 @@ void EmitSPIRV::EmitBitCastU16F16(EmitContext&) {
11} 11}
12 12
13Id EmitSPIRV::EmitBitCastU32F32(EmitContext& ctx, Id value) { 13Id EmitSPIRV::EmitBitCastU32F32(EmitContext& ctx, Id value) {
14 return ctx.OpBitcast(ctx.u32[1], value); 14 return ctx.OpBitcast(ctx.U32[1], value);
15} 15}
16 16
17void EmitSPIRV::EmitBitCastU64F64(EmitContext&) { 17void EmitSPIRV::EmitBitCastU64F64(EmitContext&) {
@@ -23,7 +23,7 @@ void EmitSPIRV::EmitBitCastF16U16(EmitContext&) {
23} 23}
24 24
25Id EmitSPIRV::EmitBitCastF32U32(EmitContext& ctx, Id value) { 25Id EmitSPIRV::EmitBitCastF32U32(EmitContext& ctx, Id value) {
26 return ctx.OpBitcast(ctx.f32[1], value); 26 return ctx.OpBitcast(ctx.F32[1], value);
27} 27}
28 28
29void EmitSPIRV::EmitBitCastF64U64(EmitContext&) { 29void EmitSPIRV::EmitBitCastF64U64(EmitContext&) {
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp
index b190cf876..a7374c89d 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_composite.cpp
@@ -23,7 +23,7 @@ void EmitSPIRV::EmitCompositeExtractU32x2(EmitContext&) {
23} 23}
24 24
25Id EmitSPIRV::EmitCompositeExtractU32x3(EmitContext& ctx, Id vector, u32 index) { 25Id EmitSPIRV::EmitCompositeExtractU32x3(EmitContext& ctx, Id vector, u32 index) {
26 return ctx.OpCompositeExtract(ctx.u32[1], vector, index); 26 return ctx.OpCompositeExtract(ctx.U32[1], vector, index);
27} 27}
28 28
29void EmitSPIRV::EmitCompositeExtractU32x4(EmitContext&) { 29void EmitSPIRV::EmitCompositeExtractU32x4(EmitContext&) {
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
index 1eab739ed..f4c9970eb 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_context_get_set.cpp
@@ -37,7 +37,10 @@ Id EmitSPIRV::EmitGetCbuf(EmitContext& ctx, const IR::Value& binding, const IR::
37 if (!offset.IsImmediate()) { 37 if (!offset.IsImmediate()) {
38 throw NotImplementedException("Variable constant buffer offset"); 38 throw NotImplementedException("Variable constant buffer offset");
39 } 39 }
40 return ctx.Name(ctx.OpUndef(ctx.u32[1]), "unimplemented_cbuf"); 40 const Id imm_offset{ctx.Constant(ctx.U32[1], offset.U32() / 4)};
41 const Id cbuf{ctx.cbufs[binding.U32()]};
42 const Id access_chain{ctx.OpAccessChain(ctx.uniform_u32, cbuf, ctx.u32_zero_value, imm_offset)};
43 return ctx.OpLoad(ctx.U32[1], access_chain);
41} 44}
42 45
43void EmitSPIRV::EmitGetAttribute(EmitContext&) { 46void EmitSPIRV::EmitGetAttribute(EmitContext&) {
@@ -89,22 +92,11 @@ void EmitSPIRV::EmitSetOFlag(EmitContext&) {
89} 92}
90 93
91Id EmitSPIRV::EmitWorkgroupId(EmitContext& ctx) { 94Id EmitSPIRV::EmitWorkgroupId(EmitContext& ctx) {
92 if (ctx.workgroup_id.value == 0) { 95 return ctx.OpLoad(ctx.U32[3], ctx.workgroup_id);
93 ctx.workgroup_id = ctx.AddGlobalVariable(
94 ctx.TypePointer(spv::StorageClass::Input, ctx.u32[3]), spv::StorageClass::Input);
95 ctx.Decorate(ctx.workgroup_id, spv::Decoration::BuiltIn, spv::BuiltIn::WorkgroupId);
96 }
97 return ctx.OpLoad(ctx.u32[3], ctx.workgroup_id);
98} 96}
99 97
100Id EmitSPIRV::EmitLocalInvocationId(EmitContext& ctx) { 98Id EmitSPIRV::EmitLocalInvocationId(EmitContext& ctx) {
101 if (ctx.local_invocation_id.value == 0) { 99 return ctx.OpLoad(ctx.U32[3], ctx.local_invocation_id);
102 ctx.local_invocation_id = ctx.AddGlobalVariable(
103 ctx.TypePointer(spv::StorageClass::Input, ctx.u32[3]), spv::StorageClass::Input);
104 ctx.Decorate(ctx.local_invocation_id, spv::Decoration::BuiltIn,
105 spv::BuiltIn::LocalInvocationId);
106 }
107 return ctx.OpLoad(ctx.u32[3], ctx.local_invocation_id);
108} 100}
109 101
110} // namespace Shader::Backend::SPIRV 102} // namespace Shader::Backend::SPIRV
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp
index 66ce6c8c5..549c1907a 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_control_flow.cpp
@@ -3,3 +3,29 @@
3// Refer to the license.txt file included. 3// Refer to the license.txt file included.
4 4
5#include "shader_recompiler/backend/spirv/emit_spirv.h" 5#include "shader_recompiler/backend/spirv/emit_spirv.h"
6
7namespace Shader::Backend::SPIRV {
8
9void EmitSPIRV::EmitBranch(EmitContext& ctx, IR::Block* label) {
10 ctx.OpBranch(label->Definition<Id>());
11}
12
13void EmitSPIRV::EmitBranchConditional(EmitContext& ctx, Id condition, IR::Block* true_label,
14 IR::Block* false_label) {
15 ctx.OpBranchConditional(condition, true_label->Definition<Id>(), false_label->Definition<Id>());
16}
17
18void EmitSPIRV::EmitLoopMerge(EmitContext& ctx, IR::Block* merge_label, IR::Block* continue_label) {
19 ctx.OpLoopMerge(merge_label->Definition<Id>(), continue_label->Definition<Id>(),
20 spv::LoopControlMask::MaskNone);
21}
22
23void EmitSPIRV::EmitSelectionMerge(EmitContext& ctx, IR::Block* merge_label) {
24 ctx.OpSelectionMerge(merge_label->Definition<Id>(), spv::SelectionControlMask::MaskNone);
25}
26
27void EmitSPIRV::EmitReturn(EmitContext& ctx) {
28 ctx.OpReturn();
29}
30
31} // namespace Shader::Backend::SPIRV
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp
index 9c39537e2..c9bc121f8 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_floating_point.cpp
@@ -46,27 +46,27 @@ void EmitSPIRV::EmitFPAbs64(EmitContext&) {
46} 46}
47 47
48Id EmitSPIRV::EmitFPAdd16(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { 48Id EmitSPIRV::EmitFPAdd16(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
49 return Decorate(ctx, inst, ctx.OpFAdd(ctx.f16[1], a, b)); 49 return Decorate(ctx, inst, ctx.OpFAdd(ctx.F16[1], a, b));
50} 50}
51 51
52Id EmitSPIRV::EmitFPAdd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { 52Id EmitSPIRV::EmitFPAdd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
53 return Decorate(ctx, inst, ctx.OpFAdd(ctx.f32[1], a, b)); 53 return Decorate(ctx, inst, ctx.OpFAdd(ctx.F32[1], a, b));
54} 54}
55 55
56Id EmitSPIRV::EmitFPAdd64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { 56Id EmitSPIRV::EmitFPAdd64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
57 return Decorate(ctx, inst, ctx.OpFAdd(ctx.f64[1], a, b)); 57 return Decorate(ctx, inst, ctx.OpFAdd(ctx.F64[1], a, b));
58} 58}
59 59
60Id EmitSPIRV::EmitFPFma16(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) { 60Id EmitSPIRV::EmitFPFma16(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) {
61 return Decorate(ctx, inst, ctx.OpFma(ctx.f16[1], a, b, c)); 61 return Decorate(ctx, inst, ctx.OpFma(ctx.F16[1], a, b, c));
62} 62}
63 63
64Id EmitSPIRV::EmitFPFma32(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) { 64Id EmitSPIRV::EmitFPFma32(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) {
65 return Decorate(ctx, inst, ctx.OpFma(ctx.f32[1], a, b, c)); 65 return Decorate(ctx, inst, ctx.OpFma(ctx.F32[1], a, b, c));
66} 66}
67 67
68Id EmitSPIRV::EmitFPFma64(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) { 68Id EmitSPIRV::EmitFPFma64(EmitContext& ctx, IR::Inst* inst, Id a, Id b, Id c) {
69 return Decorate(ctx, inst, ctx.OpFma(ctx.f64[1], a, b, c)); 69 return Decorate(ctx, inst, ctx.OpFma(ctx.F64[1], a, b, c));
70} 70}
71 71
72void EmitSPIRV::EmitFPMax32(EmitContext&) { 72void EmitSPIRV::EmitFPMax32(EmitContext&) {
@@ -86,15 +86,15 @@ void EmitSPIRV::EmitFPMin64(EmitContext&) {
86} 86}
87 87
88Id EmitSPIRV::EmitFPMul16(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { 88Id EmitSPIRV::EmitFPMul16(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
89 return Decorate(ctx, inst, ctx.OpFMul(ctx.f16[1], a, b)); 89 return Decorate(ctx, inst, ctx.OpFMul(ctx.F16[1], a, b));
90} 90}
91 91
92Id EmitSPIRV::EmitFPMul32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { 92Id EmitSPIRV::EmitFPMul32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
93 return Decorate(ctx, inst, ctx.OpFMul(ctx.f32[1], a, b)); 93 return Decorate(ctx, inst, ctx.OpFMul(ctx.F32[1], a, b));
94} 94}
95 95
96Id EmitSPIRV::EmitFPMul64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) { 96Id EmitSPIRV::EmitFPMul64(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
97 return Decorate(ctx, inst, ctx.OpFMul(ctx.f64[1], a, b)); 97 return Decorate(ctx, inst, ctx.OpFMul(ctx.F64[1], a, b));
98} 98}
99 99
100void EmitSPIRV::EmitFPNeg16(EmitContext&) { 100void EmitSPIRV::EmitFPNeg16(EmitContext&) {
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp
index e811a63ab..32af94a73 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_integer.cpp
@@ -10,7 +10,7 @@ Id EmitSPIRV::EmitIAdd32(EmitContext& ctx, IR::Inst* inst, Id a, Id b) {
10 if (inst->HasAssociatedPseudoOperation()) { 10 if (inst->HasAssociatedPseudoOperation()) {
11 throw NotImplementedException("Pseudo-operations on IAdd32"); 11 throw NotImplementedException("Pseudo-operations on IAdd32");
12 } 12 }
13 return ctx.OpIAdd(ctx.u32[1], a, b); 13 return ctx.OpIAdd(ctx.U32[1], a, b);
14} 14}
15 15
16void EmitSPIRV::EmitIAdd64(EmitContext&) { 16void EmitSPIRV::EmitIAdd64(EmitContext&) {
@@ -18,7 +18,7 @@ void EmitSPIRV::EmitIAdd64(EmitContext&) {
18} 18}
19 19
20Id EmitSPIRV::EmitISub32(EmitContext& ctx, Id a, Id b) { 20Id EmitSPIRV::EmitISub32(EmitContext& ctx, Id a, Id b) {
21 return ctx.OpISub(ctx.u32[1], a, b); 21 return ctx.OpISub(ctx.U32[1], a, b);
22} 22}
23 23
24void EmitSPIRV::EmitISub64(EmitContext&) { 24void EmitSPIRV::EmitISub64(EmitContext&) {
@@ -26,7 +26,7 @@ void EmitSPIRV::EmitISub64(EmitContext&) {
26} 26}
27 27
28Id EmitSPIRV::EmitIMul32(EmitContext& ctx, Id a, Id b) { 28Id EmitSPIRV::EmitIMul32(EmitContext& ctx, Id a, Id b) {
29 return ctx.OpIMul(ctx.u32[1], a, b); 29 return ctx.OpIMul(ctx.U32[1], a, b);
30} 30}
31 31
32void EmitSPIRV::EmitINeg32(EmitContext&) { 32void EmitSPIRV::EmitINeg32(EmitContext&) {
@@ -38,7 +38,7 @@ void EmitSPIRV::EmitIAbs32(EmitContext&) {
38} 38}
39 39
40Id EmitSPIRV::EmitShiftLeftLogical32(EmitContext& ctx, Id base, Id shift) { 40Id EmitSPIRV::EmitShiftLeftLogical32(EmitContext& ctx, Id base, Id shift) {
41 return ctx.OpShiftLeftLogical(ctx.u32[1], base, shift); 41 return ctx.OpShiftLeftLogical(ctx.U32[1], base, shift);
42} 42}
43 43
44void EmitSPIRV::EmitShiftRightLogical32(EmitContext&) { 44void EmitSPIRV::EmitShiftRightLogical32(EmitContext&) {
@@ -70,11 +70,11 @@ void EmitSPIRV::EmitBitFieldSExtract(EmitContext&) {
70} 70}
71 71
72Id EmitSPIRV::EmitBitFieldUExtract(EmitContext& ctx, Id base, Id offset, Id count) { 72Id EmitSPIRV::EmitBitFieldUExtract(EmitContext& ctx, Id base, Id offset, Id count) {
73 return ctx.OpBitFieldUExtract(ctx.u32[1], base, offset, count); 73 return ctx.OpBitFieldUExtract(ctx.U32[1], base, offset, count);
74} 74}
75 75
76Id EmitSPIRV::EmitSLessThan(EmitContext& ctx, Id lhs, Id rhs) { 76Id EmitSPIRV::EmitSLessThan(EmitContext& ctx, Id lhs, Id rhs) {
77 return ctx.OpSLessThan(ctx.u1, lhs, rhs); 77 return ctx.OpSLessThan(ctx.U1, lhs, rhs);
78} 78}
79 79
80void EmitSPIRV::EmitULessThan(EmitContext&) { 80void EmitSPIRV::EmitULessThan(EmitContext&) {
@@ -94,7 +94,7 @@ void EmitSPIRV::EmitULessThanEqual(EmitContext&) {
94} 94}
95 95
96Id EmitSPIRV::EmitSGreaterThan(EmitContext& ctx, Id lhs, Id rhs) { 96Id EmitSPIRV::EmitSGreaterThan(EmitContext& ctx, Id lhs, Id rhs) {
97 return ctx.OpSGreaterThan(ctx.u1, lhs, rhs); 97 return ctx.OpSGreaterThan(ctx.U1, lhs, rhs);
98} 98}
99 99
100void EmitSPIRV::EmitUGreaterThan(EmitContext&) { 100void EmitSPIRV::EmitUGreaterThan(EmitContext&) {
@@ -110,7 +110,7 @@ void EmitSPIRV::EmitSGreaterThanEqual(EmitContext&) {
110} 110}
111 111
112Id EmitSPIRV::EmitUGreaterThanEqual(EmitContext& ctx, Id lhs, Id rhs) { 112Id EmitSPIRV::EmitUGreaterThanEqual(EmitContext& ctx, Id lhs, Id rhs) {
113 return ctx.OpUGreaterThanEqual(ctx.u1, lhs, rhs); 113 return ctx.OpUGreaterThanEqual(ctx.U1, lhs, rhs);
114} 114}
115 115
116void EmitSPIRV::EmitLogicalOr(EmitContext&) { 116void EmitSPIRV::EmitLogicalOr(EmitContext&) {
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
index 21a0d72fa..5769a3c95 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_memory.cpp
@@ -2,10 +2,26 @@
2// Licensed under GPLv2 or any later version 2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included. 3// Refer to the license.txt file included.
4 4
5#include <bit>
6
5#include "shader_recompiler/backend/spirv/emit_spirv.h" 7#include "shader_recompiler/backend/spirv/emit_spirv.h"
6 8
7namespace Shader::Backend::SPIRV { 9namespace Shader::Backend::SPIRV {
8 10
11static Id StorageIndex(EmitContext& ctx, const IR::Value& offset, size_t element_size) {
12 if (offset.IsImmediate()) {
13 const u32 imm_offset{static_cast<u32>(offset.U32() / element_size)};
14 return ctx.Constant(ctx.U32[1], imm_offset);
15 }
16 const u32 shift{static_cast<u32>(std::countr_zero(element_size))};
17 const Id index{ctx.Def(offset)};
18 if (shift == 0) {
19 return index;
20 }
21 const Id shift_id{ctx.Constant(ctx.U32[1], shift)};
22 return ctx.OpShiftRightLogical(ctx.U32[1], index, shift_id);
23}
24
9void EmitSPIRV::EmitLoadGlobalU8(EmitContext&) { 25void EmitSPIRV::EmitLoadGlobalU8(EmitContext&) {
10 throw NotImplementedException("SPIR-V Instruction"); 26 throw NotImplementedException("SPIR-V Instruction");
11} 27}
@@ -79,11 +95,14 @@ void EmitSPIRV::EmitLoadStorageS16(EmitContext&) {
79} 95}
80 96
81Id EmitSPIRV::EmitLoadStorage32(EmitContext& ctx, const IR::Value& binding, 97Id EmitSPIRV::EmitLoadStorage32(EmitContext& ctx, const IR::Value& binding,
82 [[maybe_unused]] const IR::Value& offset) { 98 const IR::Value& offset) {
83 if (!binding.IsImmediate()) { 99 if (!binding.IsImmediate()) {
84 throw NotImplementedException("Storage buffer indexing"); 100 throw NotImplementedException("Dynamic storage buffer indexing");
85 } 101 }
86 return ctx.Name(ctx.OpUndef(ctx.u32[1]), "unimplemented_sbuf"); 102 const Id ssbo{ctx.ssbos[binding.U32()]};
103 const Id index{StorageIndex(ctx, offset, sizeof(u32))};
104 const Id pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index)};
105 return ctx.OpLoad(ctx.U32[1], pointer);
87} 106}
88 107
89void EmitSPIRV::EmitLoadStorage64(EmitContext&) { 108void EmitSPIRV::EmitLoadStorage64(EmitContext&) {
@@ -110,8 +129,15 @@ void EmitSPIRV::EmitWriteStorageS16(EmitContext&) {
110 throw NotImplementedException("SPIR-V Instruction"); 129 throw NotImplementedException("SPIR-V Instruction");
111} 130}
112 131
113void EmitSPIRV::EmitWriteStorage32(EmitContext& ctx) { 132void EmitSPIRV::EmitWriteStorage32(EmitContext& ctx, const IR::Value& binding,
114 ctx.Name(ctx.OpUndef(ctx.u32[1]), "unimplemented_sbuf_store"); 133 const IR::Value& offset, Id value) {
134 if (!binding.IsImmediate()) {
135 throw NotImplementedException("Dynamic storage buffer indexing");
136 }
137 const Id ssbo{ctx.ssbos[binding.U32()]};
138 const Id index{StorageIndex(ctx, offset, sizeof(u32))};
139 const Id pointer{ctx.OpAccessChain(ctx.storage_u32, ssbo, ctx.u32_zero_value, index)};
140 ctx.OpStore(pointer, value);
115} 141}
116 142
117void EmitSPIRV::EmitWriteStorage64(EmitContext&) { 143void EmitSPIRV::EmitWriteStorage64(EmitContext&) {
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp b/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp
index a6f542360..c1ed8f281 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv_undefined.cpp
@@ -7,7 +7,7 @@
7namespace Shader::Backend::SPIRV { 7namespace Shader::Backend::SPIRV {
8 8
9Id EmitSPIRV::EmitUndefU1(EmitContext& ctx) { 9Id EmitSPIRV::EmitUndefU1(EmitContext& ctx) {
10 return ctx.OpUndef(ctx.u1); 10 return ctx.OpUndef(ctx.U1);
11} 11}
12 12
13Id EmitSPIRV::EmitUndefU8(EmitContext&) { 13Id EmitSPIRV::EmitUndefU8(EmitContext&) {
@@ -19,7 +19,7 @@ Id EmitSPIRV::EmitUndefU16(EmitContext&) {
19} 19}
20 20
21Id EmitSPIRV::EmitUndefU32(EmitContext& ctx) { 21Id EmitSPIRV::EmitUndefU32(EmitContext& ctx) {
22 return ctx.OpUndef(ctx.u32[1]); 22 return ctx.OpUndef(ctx.U32[1]);
23} 23}
24 24
25Id EmitSPIRV::EmitUndefU64(EmitContext&) { 25Id EmitSPIRV::EmitUndefU64(EmitContext&) {
diff --git a/src/shader_recompiler/frontend/ir/basic_block.h b/src/shader_recompiler/frontend/ir/basic_block.h
index 778b32e43..b14a35ec5 100644
--- a/src/shader_recompiler/frontend/ir/basic_block.h
+++ b/src/shader_recompiler/frontend/ir/basic_block.h
@@ -11,6 +11,7 @@
11 11
12#include <boost/intrusive/list.hpp> 12#include <boost/intrusive/list.hpp>
13 13
14#include "common/bit_cast.h"
14#include "shader_recompiler/frontend/ir/condition.h" 15#include "shader_recompiler/frontend/ir/condition.h"
15#include "shader_recompiler/frontend/ir/microinstruction.h" 16#include "shader_recompiler/frontend/ir/microinstruction.h"
16#include "shader_recompiler/frontend/ir/value.h" 17#include "shader_recompiler/frontend/ir/value.h"
@@ -68,6 +69,18 @@ public:
68 /// Gets an immutable span to the immediate predecessors. 69 /// Gets an immutable span to the immediate predecessors.
69 [[nodiscard]] std::span<Block* const> ImmediatePredecessors() const noexcept; 70 [[nodiscard]] std::span<Block* const> ImmediatePredecessors() const noexcept;
70 71
72 /// Intrusively store the host definition of this instruction.
73 template <typename DefinitionType>
74 void SetDefinition(DefinitionType def) {
75 definition = Common::BitCast<u32>(def);
76 }
77
78 /// Return the intrusively stored host definition of this instruction.
79 template <typename DefinitionType>
80 [[nodiscard]] DefinitionType Definition() const noexcept {
81 return Common::BitCast<DefinitionType>(definition);
82 }
83
71 [[nodiscard]] Condition BranchCondition() const noexcept { 84 [[nodiscard]] Condition BranchCondition() const noexcept {
72 return branch_cond; 85 return branch_cond;
73 } 86 }
@@ -161,6 +174,9 @@ private:
161 Block* branch_false{nullptr}; 174 Block* branch_false{nullptr};
162 /// Block immediate predecessors 175 /// Block immediate predecessors
163 std::vector<Block*> imm_predecessors; 176 std::vector<Block*> imm_predecessors;
177
178 /// Intrusively stored host definition of this block.
179 u32 definition{};
164}; 180};
165 181
166using BlockList = std::vector<Block*>; 182using BlockList = std::vector<Block*>;
diff --git a/src/shader_recompiler/frontend/ir/program.h b/src/shader_recompiler/frontend/ir/program.h
index efaf1aa1e..98aab2dc6 100644
--- a/src/shader_recompiler/frontend/ir/program.h
+++ b/src/shader_recompiler/frontend/ir/program.h
@@ -9,11 +9,13 @@
9#include <boost/container/small_vector.hpp> 9#include <boost/container/small_vector.hpp>
10 10
11#include "shader_recompiler/frontend/ir/function.h" 11#include "shader_recompiler/frontend/ir/function.h"
12#include "shader_recompiler/shader_info.h"
12 13
13namespace Shader::IR { 14namespace Shader::IR {
14 15
15struct Program { 16struct Program {
16 boost::container::small_vector<Function, 1> functions; 17 boost::container::small_vector<Function, 1> functions;
18 Info info;
17}; 19};
18 20
19[[nodiscard]] std::string DumpProgram(const Program& program); 21[[nodiscard]] std::string DumpProgram(const Program& program);
diff --git a/src/shader_recompiler/frontend/maxwell/program.cpp b/src/shader_recompiler/frontend/maxwell/program.cpp
index dab6d68c0..8331d576c 100644
--- a/src/shader_recompiler/frontend/maxwell/program.cpp
+++ b/src/shader_recompiler/frontend/maxwell/program.cpp
@@ -53,21 +53,22 @@ IR::Program TranslateProgram(ObjectPool<IR::Inst>& inst_pool, ObjectPool<IR::Blo
53 for (Flow::Function& cfg_function : cfg.Functions()) { 53 for (Flow::Function& cfg_function : cfg.Functions()) {
54 functions.push_back(IR::Function{ 54 functions.push_back(IR::Function{
55 .blocks{TranslateCode(inst_pool, block_pool, env, cfg_function)}, 55 .blocks{TranslateCode(inst_pool, block_pool, env, cfg_function)},
56 .post_order_blocks{},
56 }); 57 });
57 } 58 }
58
59 fmt::print(stdout, "No optimizations: {}", IR::DumpProgram(program));
60 for (IR::Function& function : functions) { 59 for (IR::Function& function : functions) {
61 function.post_order_blocks = PostOrder(function.blocks); 60 function.post_order_blocks = PostOrder(function.blocks);
62 Optimization::SsaRewritePass(function.post_order_blocks); 61 Optimization::SsaRewritePass(function.post_order_blocks);
63 } 62 }
63 fmt::print(stdout, "{}\n", IR::DumpProgram(program));
64 Optimization::GlobalMemoryToStorageBufferPass(program);
64 for (IR::Function& function : functions) { 65 for (IR::Function& function : functions) {
65 Optimization::PostOrderInvoke(Optimization::GlobalMemoryToStorageBufferPass, function);
66 Optimization::PostOrderInvoke(Optimization::ConstantPropagationPass, function); 66 Optimization::PostOrderInvoke(Optimization::ConstantPropagationPass, function);
67 Optimization::PostOrderInvoke(Optimization::DeadCodeEliminationPass, function); 67 Optimization::PostOrderInvoke(Optimization::DeadCodeEliminationPass, function);
68 Optimization::IdentityRemovalPass(function); 68 Optimization::IdentityRemovalPass(function);
69 Optimization::VerificationPass(function); 69 Optimization::VerificationPass(function);
70 } 70 }
71 Optimization::CollectShaderInfoPass(program);
71 //*/ 72 //*/
72 return program; 73 return program;
73} 74}
diff --git a/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
new file mode 100644
index 000000000..f2326dea1
--- /dev/null
+++ b/src/shader_recompiler/ir_opt/collect_shader_info_pass.cpp
@@ -0,0 +1,81 @@
1// Copyright 2021 yuzu Emulator Project
2// Licensed under GPLv2 or any later version
3// Refer to the license.txt file included.
4
5#include "shader_recompiler/frontend/ir/program.h"
6#include "shader_recompiler/shader_info.h"
7
8namespace Shader::Optimization {
9namespace {
10void AddConstantBufferDescriptor(Info& info, u32 index) {
11 auto& descriptor{info.constant_buffers.at(index)};
12 if (descriptor) {
13 return;
14 }
15 descriptor = &info.constant_buffer_descriptors.emplace_back(Info::ConstantBufferDescriptor{
16 .index{index},
17 .count{1},
18 });
19}
20
21void Visit(Info& info, IR::Inst& inst) {
22 switch (inst.Opcode()) {
23 case IR::Opcode::WorkgroupId:
24 info.uses_workgroup_id = true;
25 break;
26 case IR::Opcode::LocalInvocationId:
27 info.uses_local_invocation_id = true;
28 break;
29 case IR::Opcode::FPAbs16:
30 case IR::Opcode::FPAdd16:
31 case IR::Opcode::FPCeil16:
32 case IR::Opcode::FPFloor16:
33 case IR::Opcode::FPFma16:
34 case IR::Opcode::FPMul16:
35 case IR::Opcode::FPNeg16:
36 case IR::Opcode::FPRoundEven16:
37 case IR::Opcode::FPSaturate16:
38 case IR::Opcode::FPTrunc16:
39 info.uses_fp16;
40 break;
41 case IR::Opcode::FPAbs64:
42 case IR::Opcode::FPAdd64:
43 case IR::Opcode::FPCeil64:
44 case IR::Opcode::FPFloor64:
45 case IR::Opcode::FPFma64:
46 case IR::Opcode::FPMax64:
47 case IR::Opcode::FPMin64:
48 case IR::Opcode::FPMul64:
49 case IR::Opcode::FPNeg64:
50 case IR::Opcode::FPRecip64:
51 case IR::Opcode::FPRecipSqrt64:
52 case IR::Opcode::FPRoundEven64:
53 case IR::Opcode::FPSaturate64:
54 case IR::Opcode::FPTrunc64:
55 info.uses_fp64 = true;
56 break;
57 case IR::Opcode::GetCbuf:
58 if (const IR::Value index{inst.Arg(0)}; index.IsImmediate()) {
59 AddConstantBufferDescriptor(info, index.U32());
60 } else {
61 throw NotImplementedException("Constant buffer with non-immediate index");
62 }
63 break;
64 default:
65 break;
66 }
67}
68} // Anonymous namespace
69
70void CollectShaderInfoPass(IR::Program& program) {
71 Info& info{program.info};
72 for (IR::Function& function : program.functions) {
73 for (IR::Block* const block : function.post_order_blocks) {
74 for (IR::Inst& inst : block->Instructions()) {
75 Visit(info, inst);
76 }
77 }
78 }
79}
80
81} // namespace Shader::Optimization
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index cbde65b9b..f1ad16d60 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -77,6 +77,16 @@ bool FoldCommutative(IR::Inst& inst, ImmFn&& imm_fn) {
77 return true; 77 return true;
78} 78}
79 79
80template <typename Func>
81bool FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
82 if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
83 return false;
84 }
85 using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>;
86 inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
87 return true;
88}
89
80void FoldGetRegister(IR::Inst& inst) { 90void FoldGetRegister(IR::Inst& inst) {
81 if (inst.Arg(0).Reg() == IR::Reg::RZ) { 91 if (inst.Arg(0).Reg() == IR::Reg::RZ) {
82 inst.ReplaceUsesWith(IR::Value{u32{0}}); 92 inst.ReplaceUsesWith(IR::Value{u32{0}});
@@ -103,6 +113,52 @@ void FoldAdd(IR::Inst& inst) {
103 } 113 }
104} 114}
105 115
116void FoldISub32(IR::Inst& inst) {
117 if (FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a - b; })) {
118 return;
119 }
120 if (inst.Arg(0).IsImmediate() || inst.Arg(1).IsImmediate()) {
121 return;
122 }
123 // ISub32 is generally used to subtract two constant buffers, compare and replace this with
124 // zero if they equal.
125 const auto equal_cbuf{[](IR::Inst* a, IR::Inst* b) {
126 return a->Opcode() == IR::Opcode::GetCbuf && b->Opcode() == IR::Opcode::GetCbuf &&
127 a->Arg(0) == b->Arg(0) && a->Arg(1) == b->Arg(1);
128 }};
129 IR::Inst* op_a{inst.Arg(0).InstRecursive()};
130 IR::Inst* op_b{inst.Arg(1).InstRecursive()};
131 if (equal_cbuf(op_a, op_b)) {
132 inst.ReplaceUsesWith(IR::Value{u32{0}});
133 return;
134 }
135 // It's also possible a value is being added to a cbuf and then subtracted
136 if (op_b->Opcode() == IR::Opcode::IAdd32) {
137 // Canonicalize local variables to simplify the following logic
138 std::swap(op_a, op_b);
139 }
140 if (op_b->Opcode() != IR::Opcode::GetCbuf) {
141 return;
142 }
143 IR::Inst* const inst_cbuf{op_b};
144 if (op_a->Opcode() != IR::Opcode::IAdd32) {
145 return;
146 }
147 IR::Value add_op_a{op_a->Arg(0)};
148 IR::Value add_op_b{op_a->Arg(1)};
149 if (add_op_b.IsImmediate()) {
150 // Canonicalize
151 std::swap(add_op_a, add_op_b);
152 }
153 if (add_op_b.IsImmediate()) {
154 return;
155 }
156 IR::Inst* const add_cbuf{add_op_b.InstRecursive()};
157 if (equal_cbuf(add_cbuf, inst_cbuf)) {
158 inst.ReplaceUsesWith(add_op_a);
159 }
160}
161
106template <typename T> 162template <typename T>
107void FoldSelect(IR::Inst& inst) { 163void FoldSelect(IR::Inst& inst) {
108 const IR::Value cond{inst.Arg(0)}; 164 const IR::Value cond{inst.Arg(0)};
@@ -170,15 +226,6 @@ IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<
170 return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)}; 226 return IR::Value{func(Arg<Traits::ArgType<I>>(inst.Arg(I))...)};
171} 227}
172 228
173template <typename Func>
174void FoldWhenAllImmediates(IR::Inst& inst, Func&& func) {
175 if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) {
176 return;
177 }
178 using Indices = std::make_index_sequence<LambdaTraits<decltype(func)>::NUM_ARGS>;
179 inst.ReplaceUsesWith(EvalImmediates(inst, func, Indices{}));
180}
181
182void FoldBranchConditional(IR::Inst& inst) { 229void FoldBranchConditional(IR::Inst& inst) {
183 const IR::U1 cond{inst.Arg(0)}; 230 const IR::U1 cond{inst.Arg(0)};
184 if (cond.IsImmediate()) { 231 if (cond.IsImmediate()) {
@@ -205,6 +252,8 @@ void ConstantPropagation(IR::Inst& inst) {
205 return FoldGetPred(inst); 252 return FoldGetPred(inst);
206 case IR::Opcode::IAdd32: 253 case IR::Opcode::IAdd32:
207 return FoldAdd<u32>(inst); 254 return FoldAdd<u32>(inst);
255 case IR::Opcode::ISub32:
256 return FoldISub32(inst);
208 case IR::Opcode::BitCastF32U32: 257 case IR::Opcode::BitCastF32U32:
209 return FoldBitCast<f32, u32>(inst, IR::Opcode::BitCastU32F32); 258 return FoldBitCast<f32, u32>(inst, IR::Opcode::BitCastU32F32);
210 case IR::Opcode::BitCastU32F32: 259 case IR::Opcode::BitCastU32F32:
@@ -220,17 +269,20 @@ void ConstantPropagation(IR::Inst& inst) {
220 case IR::Opcode::LogicalNot: 269 case IR::Opcode::LogicalNot:
221 return FoldLogicalNot(inst); 270 return FoldLogicalNot(inst);
222 case IR::Opcode::SLessThan: 271 case IR::Opcode::SLessThan:
223 return FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; }); 272 FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a < b; });
273 return;
224 case IR::Opcode::ULessThan: 274 case IR::Opcode::ULessThan:
225 return FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); 275 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
276 return;
226 case IR::Opcode::BitFieldUExtract: 277 case IR::Opcode::BitFieldUExtract:
227 return FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) { 278 FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
228 if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) { 279 if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {
229 throw LogicError("Undefined result in {}({}, {}, {})", IR::Opcode::BitFieldUExtract, 280 throw LogicError("Undefined result in {}({}, {}, {})", IR::Opcode::BitFieldUExtract,
230 base, shift, count); 281 base, shift, count);
231 } 282 }
232 return (base >> shift) & ((1U << count) - 1); 283 return (base >> shift) & ((1U << count) - 1);
233 }); 284 });
285 return;
234 case IR::Opcode::BranchConditional: 286 case IR::Opcode::BranchConditional:
235 return FoldBranchConditional(inst); 287 return FoldBranchConditional(inst);
236 default: 288 default:
diff --git a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp
index b40c0c57b..bf230a850 100644
--- a/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp
+++ b/src/shader_recompiler/ir_opt/global_memory_to_storage_buffer_pass.cpp
@@ -28,7 +28,8 @@ struct StorageBufferAddr {
28/// Block iterator to a global memory instruction and the storage buffer it uses 28/// Block iterator to a global memory instruction and the storage buffer it uses
29struct StorageInst { 29struct StorageInst {
30 StorageBufferAddr storage_buffer; 30 StorageBufferAddr storage_buffer;
31 IR::Block::iterator inst; 31 IR::Inst* inst;
32 IR::Block* block;
32}; 33};
33 34
34/// Bias towards a certain range of constant buffers when looking for storage buffers 35/// Bias towards a certain range of constant buffers when looking for storage buffers
@@ -41,7 +42,7 @@ struct Bias {
41using StorageBufferSet = 42using StorageBufferSet =
42 boost::container::flat_set<StorageBufferAddr, std::less<StorageBufferAddr>, 43 boost::container::flat_set<StorageBufferAddr, std::less<StorageBufferAddr>,
43 boost::container::small_vector<StorageBufferAddr, 16>>; 44 boost::container::small_vector<StorageBufferAddr, 16>>;
44using StorageInstVector = boost::container::small_vector<StorageInst, 32>; 45using StorageInstVector = boost::container::small_vector<StorageInst, 24>;
45 46
46/// Returns true when the instruction is a global memory instruction 47/// Returns true when the instruction is a global memory instruction
47bool IsGlobalMemory(const IR::Inst& inst) { 48bool IsGlobalMemory(const IR::Inst& inst) {
@@ -109,23 +110,22 @@ bool MeetsBias(const StorageBufferAddr& storage_buffer, const Bias& bias) noexce
109} 110}
110 111
111/// Discards a global memory operation, reads return zero and writes are ignored 112/// Discards a global memory operation, reads return zero and writes are ignored
112void DiscardGlobalMemory(IR::Block& block, IR::Block::iterator inst) { 113void DiscardGlobalMemory(IR::Block& block, IR::Inst& inst) {
114 IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
113 const IR::Value zero{u32{0}}; 115 const IR::Value zero{u32{0}};
114 switch (inst->Opcode()) { 116 switch (inst.Opcode()) {
115 case IR::Opcode::LoadGlobalS8: 117 case IR::Opcode::LoadGlobalS8:
116 case IR::Opcode::LoadGlobalU8: 118 case IR::Opcode::LoadGlobalU8:
117 case IR::Opcode::LoadGlobalS16: 119 case IR::Opcode::LoadGlobalS16:
118 case IR::Opcode::LoadGlobalU16: 120 case IR::Opcode::LoadGlobalU16:
119 case IR::Opcode::LoadGlobal32: 121 case IR::Opcode::LoadGlobal32:
120 inst->ReplaceUsesWith(zero); 122 inst.ReplaceUsesWith(zero);
121 break; 123 break;
122 case IR::Opcode::LoadGlobal64: 124 case IR::Opcode::LoadGlobal64:
123 inst->ReplaceUsesWith(IR::Value{ 125 inst.ReplaceUsesWith(IR::Value{ir.CompositeConstruct(zero, zero)});
124 &*block.PrependNewInst(inst, IR::Opcode::CompositeConstructU32x2, {zero, zero})});
125 break; 126 break;
126 case IR::Opcode::LoadGlobal128: 127 case IR::Opcode::LoadGlobal128:
127 inst->ReplaceUsesWith(IR::Value{&*block.PrependNewInst( 128 inst.ReplaceUsesWith(IR::Value{ir.CompositeConstruct(zero, zero, zero, zero)});
128 inst, IR::Opcode::CompositeConstructU32x4, {zero, zero, zero, zero})});
129 break; 129 break;
130 case IR::Opcode::WriteGlobalS8: 130 case IR::Opcode::WriteGlobalS8:
131 case IR::Opcode::WriteGlobalU8: 131 case IR::Opcode::WriteGlobalU8:
@@ -134,11 +134,10 @@ void DiscardGlobalMemory(IR::Block& block, IR::Block::iterator inst) {
134 case IR::Opcode::WriteGlobal32: 134 case IR::Opcode::WriteGlobal32:
135 case IR::Opcode::WriteGlobal64: 135 case IR::Opcode::WriteGlobal64:
136 case IR::Opcode::WriteGlobal128: 136 case IR::Opcode::WriteGlobal128:
137 inst->Invalidate(); 137 inst.Invalidate();
138 break; 138 break;
139 default: 139 default:
140 throw LogicError("Invalid opcode to discard its global memory operation {}", 140 throw LogicError("Invalid opcode to discard its global memory operation {}", inst.Opcode());
141 inst->Opcode());
142 } 141 }
143} 142}
144 143
@@ -232,8 +231,8 @@ std::optional<StorageBufferAddr> Track(const IR::Value& value, const Bias* bias)
232} 231}
233 232
234/// Collects the storage buffer used by a global memory instruction and the instruction itself 233/// Collects the storage buffer used by a global memory instruction and the instruction itself
235void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst, 234void CollectStorageBuffers(IR::Block& block, IR::Inst& inst, StorageBufferSet& storage_buffer_set,
236 StorageBufferSet& storage_buffer_set, StorageInstVector& to_replace) { 235 StorageInstVector& to_replace) {
237 // NVN puts storage buffers in a specific range, we have to bias towards these addresses to 236 // NVN puts storage buffers in a specific range, we have to bias towards these addresses to
238 // avoid getting false positives 237 // avoid getting false positives
239 static constexpr Bias nvn_bias{ 238 static constexpr Bias nvn_bias{
@@ -241,19 +240,13 @@ void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst,
241 .offset_begin{0x110}, 240 .offset_begin{0x110},
242 .offset_end{0x610}, 241 .offset_end{0x610},
243 }; 242 };
244 // First try to find storage buffers in the NVN address
245 const IR::U64 addr{inst->Arg(0)};
246 if (addr.IsImmediate()) {
247 // Immediate addresses can't be lowered to a storage buffer
248 DiscardGlobalMemory(block, inst);
249 return;
250 }
251 // Track the low address of the instruction 243 // Track the low address of the instruction
252 const std::optional<LowAddrInfo> low_addr_info{TrackLowAddress(addr.InstRecursive())}; 244 const std::optional<LowAddrInfo> low_addr_info{TrackLowAddress(&inst)};
253 if (!low_addr_info) { 245 if (!low_addr_info) {
254 DiscardGlobalMemory(block, inst); 246 DiscardGlobalMemory(block, inst);
255 return; 247 return;
256 } 248 }
249 // First try to find storage buffers in the NVN address
257 const IR::U32 low_addr{low_addr_info->value}; 250 const IR::U32 low_addr{low_addr_info->value};
258 std::optional<StorageBufferAddr> storage_buffer{Track(low_addr, &nvn_bias)}; 251 std::optional<StorageBufferAddr> storage_buffer{Track(low_addr, &nvn_bias)};
259 if (!storage_buffer) { 252 if (!storage_buffer) {
@@ -269,21 +262,22 @@ void CollectStorageBuffers(IR::Block& block, IR::Block::iterator inst,
269 storage_buffer_set.insert(*storage_buffer); 262 storage_buffer_set.insert(*storage_buffer);
270 to_replace.push_back(StorageInst{ 263 to_replace.push_back(StorageInst{
271 .storage_buffer{*storage_buffer}, 264 .storage_buffer{*storage_buffer},
272 .inst{inst}, 265 .inst{&inst},
266 .block{&block},
273 }); 267 });
274} 268}
275 269
276/// Returns the offset in indices (not bytes) for an equivalent storage instruction 270/// Returns the offset in indices (not bytes) for an equivalent storage instruction
277IR::U32 StorageOffset(IR::Block& block, IR::Block::iterator inst, StorageBufferAddr buffer) { 271IR::U32 StorageOffset(IR::Block& block, IR::Inst& inst, StorageBufferAddr buffer) {
278 IR::IREmitter ir{block, inst}; 272 IR::IREmitter ir{block, IR::Block::InstructionList::s_iterator_to(inst)};
279 IR::U32 offset; 273 IR::U32 offset;
280 if (const std::optional<LowAddrInfo> low_addr{TrackLowAddress(&*inst)}) { 274 if (const std::optional<LowAddrInfo> low_addr{TrackLowAddress(&inst)}) {
281 offset = low_addr->value; 275 offset = low_addr->value;
282 if (low_addr->imm_offset != 0) { 276 if (low_addr->imm_offset != 0) {
283 offset = ir.IAdd(offset, ir.Imm32(low_addr->imm_offset)); 277 offset = ir.IAdd(offset, ir.Imm32(low_addr->imm_offset));
284 } 278 }
285 } else { 279 } else {
286 offset = ir.ConvertU(32, IR::U64{inst->Arg(0)}); 280 offset = ir.ConvertU(32, IR::U64{inst.Arg(0)});
287 } 281 }
288 // Subtract the least significant 32 bits from the guest offset. The result is the storage 282 // Subtract the least significant 32 bits from the guest offset. The result is the storage
289 // buffer offset in bytes. 283 // buffer offset in bytes.
@@ -292,25 +286,27 @@ IR::U32 StorageOffset(IR::Block& block, IR::Block::iterator inst, StorageBufferA
292} 286}
293 287
294/// Replace a global memory load instruction with its storage buffer equivalent 288/// Replace a global memory load instruction with its storage buffer equivalent
295void ReplaceLoad(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index, 289void ReplaceLoad(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index,
296 const IR::U32& offset) { 290 const IR::U32& offset) {
297 const IR::Opcode new_opcode{GlobalToStorage(inst->Opcode())}; 291 const IR::Opcode new_opcode{GlobalToStorage(inst.Opcode())};
298 const IR::Value value{&*block.PrependNewInst(inst, new_opcode, {storage_index, offset})}; 292 const auto it{IR::Block::InstructionList::s_iterator_to(inst)};
299 inst->ReplaceUsesWith(value); 293 const IR::Value value{&*block.PrependNewInst(it, new_opcode, {storage_index, offset})};
294 inst.ReplaceUsesWith(value);
300} 295}
301 296
302/// Replace a global memory write instruction with its storage buffer equivalent 297/// Replace a global memory write instruction with its storage buffer equivalent
303void ReplaceWrite(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index, 298void ReplaceWrite(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index,
304 const IR::U32& offset) { 299 const IR::U32& offset) {
305 const IR::Opcode new_opcode{GlobalToStorage(inst->Opcode())}; 300 const IR::Opcode new_opcode{GlobalToStorage(inst.Opcode())};
306 block.PrependNewInst(inst, new_opcode, {storage_index, offset, inst->Arg(1)}); 301 const auto it{IR::Block::InstructionList::s_iterator_to(inst)};
307 inst->Invalidate(); 302 block.PrependNewInst(it, new_opcode, {storage_index, offset, inst.Arg(1)});
303 inst.Invalidate();
308} 304}
309 305
310/// Replace a global memory instruction with its storage buffer equivalent 306/// Replace a global memory instruction with its storage buffer equivalent
311void Replace(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_index, 307void Replace(IR::Block& block, IR::Inst& inst, const IR::U32& storage_index,
312 const IR::U32& offset) { 308 const IR::U32& offset) {
313 switch (inst->Opcode()) { 309 switch (inst.Opcode()) {
314 case IR::Opcode::LoadGlobalS8: 310 case IR::Opcode::LoadGlobalS8:
315 case IR::Opcode::LoadGlobalU8: 311 case IR::Opcode::LoadGlobalU8:
316 case IR::Opcode::LoadGlobalS16: 312 case IR::Opcode::LoadGlobalS16:
@@ -328,26 +324,44 @@ void Replace(IR::Block& block, IR::Block::iterator inst, const IR::U32& storage_
328 case IR::Opcode::WriteGlobal128: 324 case IR::Opcode::WriteGlobal128:
329 return ReplaceWrite(block, inst, storage_index, offset); 325 return ReplaceWrite(block, inst, storage_index, offset);
330 default: 326 default:
331 throw InvalidArgument("Invalid global memory opcode {}", inst->Opcode()); 327 throw InvalidArgument("Invalid global memory opcode {}", inst.Opcode());
332 } 328 }
333} 329}
334} // Anonymous namespace 330} // Anonymous namespace
335 331
336void GlobalMemoryToStorageBufferPass(IR::Block& block) { 332void GlobalMemoryToStorageBufferPass(IR::Program& program) {
337 StorageBufferSet storage_buffers; 333 StorageBufferSet storage_buffers;
338 StorageInstVector to_replace; 334 StorageInstVector to_replace;
339 335
340 for (IR::Block::iterator inst{block.begin()}; inst != block.end(); ++inst) { 336 for (IR::Function& function : program.functions) {
341 if (!IsGlobalMemory(*inst)) { 337 for (IR::Block* const block : function.post_order_blocks) {
342 continue; 338 for (IR::Inst& inst : block->Instructions()) {
339 if (!IsGlobalMemory(inst)) {
340 continue;
341 }
342 CollectStorageBuffers(*block, inst, storage_buffers, to_replace);
343 }
343 } 344 }
344 CollectStorageBuffers(block, inst, storage_buffers, to_replace);
345 } 345 }
346 for (const auto [storage_buffer, inst] : to_replace) { 346 Info& info{program.info};
347 const auto it{storage_buffers.find(storage_buffer)}; 347 u32 storage_index{};
348 const IR::U32 storage_index{IR::Value{static_cast<u32>(storage_buffers.index_of(it))}}; 348 for (const StorageBufferAddr& storage_buffer : storage_buffers) {
349 const IR::U32 offset{StorageOffset(block, inst, storage_buffer)}; 349 info.storage_buffers_descriptors.push_back({
350 Replace(block, inst, storage_index, offset); 350 .cbuf_index{storage_buffer.index},
351 .cbuf_offset{storage_buffer.offset},
352 .count{1},
353 });
354 info.storage_buffers[storage_index] = &info.storage_buffers_descriptors.back();
355 ++storage_index;
356 }
357 for (const StorageInst& storage_inst : to_replace) {
358 const StorageBufferAddr storage_buffer{storage_inst.storage_buffer};
359 const auto it{storage_buffers.find(storage_inst.storage_buffer)};
360 const IR::U32 index{IR::Value{static_cast<u32>(storage_buffers.index_of(it))}};
361 IR::Block* const block{storage_inst.block};
362 IR::Inst* const inst{storage_inst.inst};
363 const IR::U32 offset{StorageOffset(*block, *inst, storage_buffer)};
364 Replace(*block, *inst, index, offset);
351 } 365 }
352} 366}
353 367
diff --git a/src/shader_recompiler/ir_opt/passes.h b/src/shader_recompiler/ir_opt/passes.h
index 30eb31588..89e5811d3 100644
--- a/src/shader_recompiler/ir_opt/passes.h
+++ b/src/shader_recompiler/ir_opt/passes.h
@@ -8,6 +8,7 @@
8 8
9#include "shader_recompiler/frontend/ir/basic_block.h" 9#include "shader_recompiler/frontend/ir/basic_block.h"
10#include "shader_recompiler/frontend/ir/function.h" 10#include "shader_recompiler/frontend/ir/function.h"
11#include "shader_recompiler/frontend/ir/program.h"
11 12
12namespace Shader::Optimization { 13namespace Shader::Optimization {
13 14
@@ -18,9 +19,10 @@ void PostOrderInvoke(Func&& func, IR::Function& function) {
18 } 19 }
19} 20}
20 21
22void CollectShaderInfoPass(IR::Program& program);
21void ConstantPropagationPass(IR::Block& block); 23void ConstantPropagationPass(IR::Block& block);
22void DeadCodeEliminationPass(IR::Block& block); 24void DeadCodeEliminationPass(IR::Block& block);
23void GlobalMemoryToStorageBufferPass(IR::Block& block); 25void GlobalMemoryToStorageBufferPass(IR::Program& program);
24void IdentityRemovalPass(IR::Function& function); 26void IdentityRemovalPass(IR::Function& function);
25void SsaRewritePass(std::span<IR::Block* const> post_order_blocks); 27void SsaRewritePass(std::span<IR::Block* const> post_order_blocks);
26void VerificationPass(const IR::Function& function); 28void VerificationPass(const IR::Function& function);
diff --git a/src/shader_recompiler/main.cpp b/src/shader_recompiler/main.cpp
index 216345e91..1610bb34e 100644
--- a/src/shader_recompiler/main.cpp
+++ b/src/shader_recompiler/main.cpp
@@ -67,8 +67,8 @@ int main() {
67 ObjectPool<IR::Inst> inst_pool; 67 ObjectPool<IR::Inst> inst_pool;
68 ObjectPool<IR::Block> block_pool; 68 ObjectPool<IR::Block> block_pool;
69 69
70 // FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"}; 70 FileEnvironment env{"D:\\Shaders\\Database\\Oninaki\\CS8F146B41DB6BD826.bin"};
71 FileEnvironment env{"D:\\Shaders\\shader.bin"}; 71 // FileEnvironment env{"D:\\Shaders\\shader.bin"};
72 block_pool.ReleaseContents(); 72 block_pool.ReleaseContents();
73 inst_pool.ReleaseContents(); 73 inst_pool.ReleaseContents();
74 flow_block_pool.ReleaseContents(); 74 flow_block_pool.ReleaseContents();
diff --git a/src/shader_recompiler/shader_info.h b/src/shader_recompiler/shader_info.h
index 1760bf4a9..f49a79368 100644
--- a/src/shader_recompiler/shader_info.h
+++ b/src/shader_recompiler/shader_info.h
@@ -6,23 +6,40 @@
6 6
7#include <array> 7#include <array>
8 8
9#include "common/common_types.h"
10
9#include <boost/container/static_vector.hpp> 11#include <boost/container/static_vector.hpp>
10 12
11namespace Shader { 13namespace Shader {
12 14
13struct Info { 15struct Info {
14 struct ConstantBuffer { 16 static constexpr size_t MAX_CBUFS{18};
17 static constexpr size_t MAX_SSBOS{16};
18
19 struct ConstantBufferDescriptor {
20 u32 index;
21 u32 count;
22 };
15 23
24 struct StorageBufferDescriptor {
25 u32 cbuf_index;
26 u32 cbuf_offset;
27 u32 count;
16 }; 28 };
17 29
18 struct { 30 bool uses_workgroup_id{};
19 bool workgroup_id{}; 31 bool uses_local_invocation_id{};
20 bool local_invocation_id{}; 32 bool uses_fp16{};
21 bool fp16{}; 33 bool uses_fp64{};
22 bool fp64{}; 34
23 } uses; 35 u32 constant_buffer_mask{};
36
37 std::array<ConstantBufferDescriptor*, MAX_CBUFS> constant_buffers{};
38 boost::container::static_vector<ConstantBufferDescriptor, MAX_CBUFS>
39 constant_buffer_descriptors;
24 40
25 std::array<18 41 std::array<StorageBufferDescriptor*, MAX_SSBOS> storage_buffers{};
42 boost::container::static_vector<StorageBufferDescriptor, MAX_SSBOS> storage_buffers_descriptors;
26}; 43};
27 44
28} // namespace Shader 45} // namespace Shader