summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/backend/spirv/emit_spirv.cpp
diff options
context:
space:
mode:
authorGravatar ReinUsesLisp2021-02-16 04:10:22 -0300
committerGravatar ameerj2021-07-22 21:51:22 -0400
commitb5d7279d878211654b4abb165d94af763a365f47 (patch)
tree9b3a7b6e9d7d2b8945fe87d27ff75f1712ef06aa /src/shader_recompiler/backend/spirv/emit_spirv.cpp
parentshader: Improve object pool (diff)
downloadyuzu-b5d7279d878211654b4abb165d94af763a365f47.tar.gz
yuzu-b5d7279d878211654b4abb165d94af763a365f47.tar.xz
yuzu-b5d7279d878211654b4abb165d94af763a365f47.zip
spirv: Initial bindings support
Diffstat (limited to 'src/shader_recompiler/backend/spirv/emit_spirv.cpp')
-rw-r--r--src/shader_recompiler/backend/spirv/emit_spirv.cpp189
1 files changed, 88 insertions, 101 deletions
diff --git a/src/shader_recompiler/backend/spirv/emit_spirv.cpp b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
index 0895414b4..c79c09774 100644
--- a/src/shader_recompiler/backend/spirv/emit_spirv.cpp
+++ b/src/shader_recompiler/backend/spirv/emit_spirv.cpp
@@ -12,31 +12,83 @@
12#include "shader_recompiler/frontend/ir/program.h" 12#include "shader_recompiler/frontend/ir/program.h"
13 13
14namespace Shader::Backend::SPIRV { 14namespace Shader::Backend::SPIRV {
15namespace {
16template <class Func>
17struct FuncTraits : FuncTraits<decltype(&Func::operator())> {};
15 18
16EmitContext::EmitContext(IR::Program& program) { 19template <class ClassType, class ReturnType_, class... Args>
17 AddCapability(spv::Capability::Shader); 20struct FuncTraits<ReturnType_ (ClassType::*)(Args...)> {
18 AddCapability(spv::Capability::Float16); 21 using ReturnType = ReturnType_;
19 AddCapability(spv::Capability::Float64);
20 void_id = TypeVoid();
21 22
22 u1 = Name(TypeBool(), "u1"); 23 static constexpr size_t NUM_ARGS = sizeof...(Args);
23 f32.Define(*this, TypeFloat(32), "f32");
24 u32.Define(*this, TypeInt(32, false), "u32");
25 f16.Define(*this, TypeFloat(16), "f16");
26 f64.Define(*this, TypeFloat(64), "f64");
27 24
28 true_value = ConstantTrue(u1); 25 template <size_t I>
29 false_value = ConstantFalse(u1); 26 using ArgType = std::tuple_element_t<I, std::tuple<Args...>>;
27};
30 28
31 for (const IR::Function& function : program.functions) { 29template <auto method, typename... Args>
32 for (IR::Block* const block : function.blocks) { 30void SetDefinition(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, Args... args) {
33 block_label_map.emplace_back(block, OpLabel()); 31 const Id forward_id{inst->Definition<Id>()};
32 const bool has_forward_id{Sirit::ValidId(forward_id)};
33 Id current_id{};
34 if (has_forward_id) {
35 current_id = ctx.ExchangeCurrentId(forward_id);
36 }
37 const Id new_id{(emit.*method)(ctx, std::forward<Args>(args)...)};
38 if (has_forward_id) {
39 ctx.ExchangeCurrentId(current_id);
40 } else {
41 inst->SetDefinition<Id>(new_id);
42 }
43}
44
45template <typename ArgType>
46ArgType Arg(EmitContext& ctx, const IR::Value& arg) {
47 if constexpr (std::is_same_v<ArgType, Id>) {
48 return ctx.Def(arg);
49 } else if constexpr (std::is_same_v<ArgType, const IR::Value&>) {
50 return arg;
51 } else if constexpr (std::is_same_v<ArgType, u32>) {
52 return arg.U32();
53 } else if constexpr (std::is_same_v<ArgType, IR::Block*>) {
54 return arg.Label();
55 }
56}
57
58template <auto method, bool is_first_arg_inst, size_t... I>
59void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, std::index_sequence<I...>) {
60 using Traits = FuncTraits<decltype(method)>;
61 if constexpr (std::is_same_v<Traits::ReturnType, Id>) {
62 if constexpr (is_first_arg_inst) {
63 SetDefinition<method>(emit, ctx, inst, inst,
64 Arg<Traits::ArgType<I + 2>>(ctx, inst->Arg(I))...);
65 } else {
66 SetDefinition<method>(emit, ctx, inst,
67 Arg<Traits::ArgType<I + 1>>(ctx, inst->Arg(I))...);
68 }
69 } else {
70 if constexpr (is_first_arg_inst) {
71 (emit.*method)(ctx, inst, Arg<Traits::ArgType<I + 2>>(ctx, inst->Arg(I))...);
72 } else {
73 (emit.*method)(ctx, Arg<Traits::ArgType<I + 1>>(ctx, inst->Arg(I))...);
34 } 74 }
35 } 75 }
36 std::ranges::sort(block_label_map, {}, &std::pair<IR::Block*, Id>::first);
37} 76}
38 77
39EmitContext::~EmitContext() = default; 78template <auto method>
79void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst) {
80 using Traits = FuncTraits<decltype(method)>;
81 static_assert(Traits::NUM_ARGS >= 1, "Insufficient arguments");
82 if constexpr (Traits::NUM_ARGS == 1) {
83 Invoke<method, false>(emit, ctx, inst, std::make_index_sequence<0>{});
84 } else {
85 using FirstArgType = typename Traits::template ArgType<1>;
86 static constexpr bool is_first_arg_inst = std::is_same_v<FirstArgType, IR::Inst*>;
87 using Indices = std::make_index_sequence<Traits::NUM_ARGS - (is_first_arg_inst ? 2 : 1)>;
88 Invoke<method, is_first_arg_inst>(emit, ctx, inst, Indices{});
89 }
90}
91} // Anonymous namespace
40 92
41EmitSPIRV::EmitSPIRV(IR::Program& program) { 93EmitSPIRV::EmitSPIRV(IR::Program& program) {
42 EmitContext ctx{program}; 94 EmitContext ctx{program};
@@ -46,74 +98,32 @@ EmitSPIRV::EmitSPIRV(IR::Program& program) {
46 for (IR::Function& function : program.functions) { 98 for (IR::Function& function : program.functions) {
47 func = ctx.OpFunction(ctx.void_id, spv::FunctionControlMask::MaskNone, void_function); 99 func = ctx.OpFunction(ctx.void_id, spv::FunctionControlMask::MaskNone, void_function);
48 for (IR::Block* const block : function.blocks) { 100 for (IR::Block* const block : function.blocks) {
49 ctx.AddLabel(ctx.BlockLabel(block)); 101 ctx.AddLabel(block->Definition<Id>());
50 for (IR::Inst& inst : block->Instructions()) { 102 for (IR::Inst& inst : block->Instructions()) {
51 EmitInst(ctx, &inst); 103 EmitInst(ctx, &inst);
52 } 104 }
53 } 105 }
54 ctx.OpFunctionEnd(); 106 ctx.OpFunctionEnd();
55 } 107 }
56 ctx.AddEntryPoint(spv::ExecutionModel::GLCompute, func, "main"); 108 boost::container::small_vector<Id, 32> interfaces;
109 if (program.info.uses_workgroup_id) {
110 interfaces.push_back(ctx.workgroup_id);
111 }
112 if (program.info.uses_local_invocation_id) {
113 interfaces.push_back(ctx.local_invocation_id);
114 }
115
116 const std::span interfaces_span(interfaces.data(), interfaces.size());
117 ctx.AddEntryPoint(spv::ExecutionModel::Fragment, func, "main", interfaces_span);
118 ctx.AddExecutionMode(func, spv::ExecutionMode::OriginUpperLeft);
57 119
58 std::vector<u32> result{ctx.Assemble()}; 120 std::vector<u32> result{ctx.Assemble()};
59 std::FILE* file{std::fopen("shader.spv", "wb")}; 121 std::FILE* file{std::fopen("D:\\shader.spv", "wb")};
60 std::fwrite(result.data(), sizeof(u32), result.size(), file); 122 std::fwrite(result.data(), sizeof(u32), result.size(), file);
61 std::fclose(file); 123 std::fclose(file);
62 std::system("spirv-dis shader.spv"); 124 std::system("spirv-dis D:\\shader.spv") == 0 &&
63 std::system("spirv-val shader.spv"); 125 std::system("spirv-val --uniform-buffer-standard-layout D:\\shader.spv") == 0 &&
64 std::system("spirv-cross shader.spv"); 126 std::system("spirv-cross -V D:\\shader.spv") == 0;
65}
66
67template <auto method, typename... Args>
68static void SetDefinition(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst, Args... args) {
69 const Id forward_id{inst->Definition<Id>()};
70 const bool has_forward_id{Sirit::ValidId(forward_id)};
71 Id current_id{};
72 if (has_forward_id) {
73 current_id = ctx.ExchangeCurrentId(forward_id);
74 }
75 const Id new_id{(emit.*method)(ctx, std::forward<Args>(args)...)};
76 if (has_forward_id) {
77 ctx.ExchangeCurrentId(current_id);
78 } else {
79 inst->SetDefinition<Id>(new_id);
80 }
81}
82
83template <auto method>
84static void Invoke(EmitSPIRV& emit, EmitContext& ctx, IR::Inst* inst) {
85 using M = decltype(method);
86 using std::is_invocable_r_v;
87 if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&>) {
88 SetDefinition<method>(emit, ctx, inst);
89 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id>) {
90 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)));
91 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id, Id>) {
92 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)));
93 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id, Id, Id>) {
94 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)),
95 ctx.Def(inst->Arg(2)));
96 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, IR::Inst*>) {
97 SetDefinition<method>(emit, ctx, inst, inst);
98 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, IR::Inst*, Id, Id>) {
99 SetDefinition<method>(emit, ctx, inst, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)));
100 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, IR::Inst*, Id, Id, Id>) {
101 SetDefinition<method>(emit, ctx, inst, inst, ctx.Def(inst->Arg(0)), ctx.Def(inst->Arg(1)),
102 ctx.Def(inst->Arg(2)));
103 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, Id, u32>) {
104 SetDefinition<method>(emit, ctx, inst, ctx.Def(inst->Arg(0)), inst->Arg(1).U32());
105 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, const IR::Value&>) {
106 SetDefinition<method>(emit, ctx, inst, inst->Arg(0));
107 } else if constexpr (is_invocable_r_v<Id, M, EmitSPIRV&, EmitContext&, const IR::Value&,
108 const IR::Value&>) {
109 SetDefinition<method>(emit, ctx, inst, inst->Arg(0), inst->Arg(1));
110 } else if constexpr (is_invocable_r_v<void, M, EmitSPIRV&, EmitContext&, IR::Inst*>) {
111 (emit.*method)(ctx, inst);
112 } else if constexpr (is_invocable_r_v<void, M, EmitSPIRV&, EmitContext&>) {
113 (emit.*method)(ctx);
114 } else {
115 static_assert(false, "Bad format");
116 }
117} 127}
118 128
119void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) { 129void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) {
@@ -130,9 +140,9 @@ void EmitSPIRV::EmitInst(EmitContext& ctx, IR::Inst* inst) {
130static Id TypeId(const EmitContext& ctx, IR::Type type) { 140static Id TypeId(const EmitContext& ctx, IR::Type type) {
131 switch (type) { 141 switch (type) {
132 case IR::Type::U1: 142 case IR::Type::U1:
133 return ctx.u1; 143 return ctx.U1;
134 case IR::Type::U32: 144 case IR::Type::U32:
135 return ctx.u32[1]; 145 return ctx.U32[1];
136 default: 146 default:
137 throw NotImplementedException("Phi node type {}", type); 147 throw NotImplementedException("Phi node type {}", type);
138 } 148 }
@@ -162,7 +172,7 @@ Id EmitSPIRV::EmitPhi(EmitContext& ctx, IR::Inst* inst) {
162 } 172 }
163 IR::Block* const phi_block{inst->PhiBlock(index)}; 173 IR::Block* const phi_block{inst->PhiBlock(index)};
164 operands.push_back(def); 174 operands.push_back(def);
165 operands.push_back(ctx.BlockLabel(phi_block)); 175 operands.push_back(phi_block->Definition<Id>());
166 } 176 }
167 const Id result_type{TypeId(ctx, inst->Arg(0).Type())}; 177 const Id result_type{TypeId(ctx, inst->Arg(0).Type())};
168 return ctx.OpPhi(result_type, std::span(operands.data(), operands.size())); 178 return ctx.OpPhi(result_type, std::span(operands.data(), operands.size()));
@@ -174,29 +184,6 @@ void EmitSPIRV::EmitIdentity(EmitContext&) {
174 throw NotImplementedException("SPIR-V Instruction"); 184 throw NotImplementedException("SPIR-V Instruction");
175} 185}
176 186
177// FIXME: Move to its own file
178void EmitSPIRV::EmitBranch(EmitContext& ctx, IR::Inst* inst) {
179 ctx.OpBranch(ctx.BlockLabel(inst->Arg(0).Label()));
180}
181
182void EmitSPIRV::EmitBranchConditional(EmitContext& ctx, IR::Inst* inst) {
183 ctx.OpBranchConditional(ctx.Def(inst->Arg(0)), ctx.BlockLabel(inst->Arg(1).Label()),
184 ctx.BlockLabel(inst->Arg(2).Label()));
185}
186
187void EmitSPIRV::EmitLoopMerge(EmitContext& ctx, IR::Inst* inst) {
188 ctx.OpLoopMerge(ctx.BlockLabel(inst->Arg(0).Label()), ctx.BlockLabel(inst->Arg(1).Label()),
189 spv::LoopControlMask::MaskNone);
190}
191
192void EmitSPIRV::EmitSelectionMerge(EmitContext& ctx, IR::Inst* inst) {
193 ctx.OpSelectionMerge(ctx.BlockLabel(inst->Arg(0).Label()), spv::SelectionControlMask::MaskNone);
194}
195
196void EmitSPIRV::EmitReturn(EmitContext& ctx) {
197 ctx.OpReturn();
198}
199
200void EmitSPIRV::EmitGetZeroFromOp(EmitContext&) { 187void EmitSPIRV::EmitGetZeroFromOp(EmitContext&) {
201 throw LogicError("Unreachable instruction"); 188 throw LogicError("Unreachable instruction");
202} 189}