summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Fernando Sahmkow2019-09-23 22:55:25 -0400
committerGravatar FernandoS272019-10-25 09:01:30 -0400
commit8909f52166bf9c27d52b5a722efbd46d1a11e876 (patch)
tree2e21bbc3c3f5422325d8003d72cdb4bb120a26e5 /src
parentShader_Cache: setup connection of ConstBufferLocker (diff)
downloadyuzu-8909f52166bf9c27d52b5a722efbd46d1a11e876.tar.gz
yuzu-8909f52166bf9c27d52b5a722efbd46d1a11e876.tar.xz
yuzu-8909f52166bf9c27d52b5a722efbd46d1a11e876.zip
Shader_IR: Implement Fast BRX and allow multi-branches in the CFG.
Diffstat (limited to 'src')
-rw-r--r--src/video_core/renderer_opengl/gl_shader_decompiler.cpp5
-rw-r--r--src/video_core/renderer_vulkan/vk_shader_decompiler.cpp7
-rw-r--r--src/video_core/shader/ast.cpp4
-rw-r--r--src/video_core/shader/control_flow.cpp262
-rw-r--r--src/video_core/shader/control_flow.h59
-rw-r--r--src/video_core/shader/decode.cpp34
-rw-r--r--src/video_core/shader/expr.h17
7 files changed, 258 insertions, 130 deletions
diff --git a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
index baec66ff0..71d7389cb 100644
--- a/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
+++ b/src/video_core/renderer_opengl/gl_shader_decompiler.cpp
@@ -2338,6 +2338,11 @@ public:
2338 inner += expr.value ? "true" : "false"; 2338 inner += expr.value ? "true" : "false";
2339 } 2339 }
2340 2340
2341 void operator()(VideoCommon::Shader::ExprGprEqual& expr) {
2342 inner +=
2343 "( ftou(" + decomp.GetRegister(expr.gpr) + ") == " + std::to_string(expr.value) + ')';
2344 }
2345
2341 const std::string& GetResult() const { 2346 const std::string& GetResult() const {
2342 return inner; 2347 return inner;
2343 } 2348 }
diff --git a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
index 0d943a826..42cf068b6 100644
--- a/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
+++ b/src/video_core/renderer_vulkan/vk_shader_decompiler.cpp
@@ -1704,6 +1704,13 @@ public:
1704 return expr.value ? decomp.v_true : decomp.v_false; 1704 return expr.value ? decomp.v_true : decomp.v_false;
1705 } 1705 }
1706 1706
1707 Id operator()(const ExprGprEqual& expr) {
1708 const Id target = decomp.Constant(decomp.t_uint, expr.value);
1709 const Id gpr = decomp.BitcastTo<Type::Uint>(
1710 decomp.Emit(decomp.OpLoad(decomp.t_float, decomp.registers.at(expr.gpr))));
1711 return decomp.Emit(decomp.OpLogicalEqual(decomp.t_uint, gpr, target));
1712 }
1713
1707 Id Visit(const Expr& node) { 1714 Id Visit(const Expr& node) {
1708 return std::visit(*this, *node); 1715 return std::visit(*this, *node);
1709 } 1716 }
diff --git a/src/video_core/shader/ast.cpp b/src/video_core/shader/ast.cpp
index e43aecc18..2fa3a3f7d 100644
--- a/src/video_core/shader/ast.cpp
+++ b/src/video_core/shader/ast.cpp
@@ -228,6 +228,10 @@ public:
228 inner += expr.value ? "true" : "false"; 228 inner += expr.value ? "true" : "false";
229 } 229 }
230 230
231 void operator()(ExprGprEqual const& expr) {
232 inner += "( gpr_" + std::to_string(expr.gpr) + " == " + std::to_string(expr.value) + ')';
233 }
234
231 const std::string& GetResult() const { 235 const std::string& GetResult() const {
232 return inner; 236 return inner;
233 } 237 }
diff --git a/src/video_core/shader/control_flow.cpp b/src/video_core/shader/control_flow.cpp
index dac2e4272..d1c269ea7 100644
--- a/src/video_core/shader/control_flow.cpp
+++ b/src/video_core/shader/control_flow.cpp
@@ -35,14 +35,24 @@ struct BlockStack {
35 std::stack<u32> pbk_stack{}; 35 std::stack<u32> pbk_stack{};
36}; 36};
37 37
38struct BlockBranchInfo { 38template <typename T, typename... Args>
39 Condition condition{}; 39BlockBranchInfo MakeBranchInfo(Args&&... args) {
40 s32 address{exit_branch}; 40 static_assert(std::is_convertible_v<T, BranchData>);
41 bool kill{}; 41 return std::make_shared<BranchData>(T(std::forward<Args>(args)...));
42 bool is_sync{}; 42}
43 bool is_brk{}; 43
44 bool ignore{}; 44bool BlockBranchInfoAreEqual(BlockBranchInfo first, BlockBranchInfo second) {
45}; 45 return false; //(*first) == (*second);
46}
47
48bool BlockBranchIsIgnored(BlockBranchInfo first) {
49 bool ignore = false;
50 if (std::holds_alternative<SingleBranch>(*first)) {
51 auto branch = std::get_if<SingleBranch>(first.get());
52 ignore = branch->ignore;
53 }
54 return ignore;
55}
46 56
47struct BlockInfo { 57struct BlockInfo {
48 u32 start{}; 58 u32 start{};
@@ -234,6 +244,7 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
234 u32 offset = static_cast<u32>(address); 244 u32 offset = static_cast<u32>(address);
235 const u32 end_address = static_cast<u32>(state.program_size / sizeof(Instruction)); 245 const u32 end_address = static_cast<u32>(state.program_size / sizeof(Instruction));
236 ParseInfo parse_info{}; 246 ParseInfo parse_info{};
247 SingleBranch single_branch{};
237 248
238 const auto insert_label = [](CFGRebuildState& state, u32 address) { 249 const auto insert_label = [](CFGRebuildState& state, u32 address) {
239 const auto pair = state.labels.emplace(address); 250 const auto pair = state.labels.emplace(address);
@@ -246,13 +257,14 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
246 if (offset >= end_address) { 257 if (offset >= end_address) {
247 // ASSERT_OR_EXECUTE can't be used, as it ignores the break 258 // ASSERT_OR_EXECUTE can't be used, as it ignores the break
248 ASSERT_MSG(false, "Shader passed the current limit!"); 259 ASSERT_MSG(false, "Shader passed the current limit!");
249 parse_info.branch_info.address = exit_branch; 260
250 parse_info.branch_info.ignore = false; 261 single_branch.address = exit_branch;
262 single_branch.ignore = false;
251 break; 263 break;
252 } 264 }
253 if (state.registered.count(offset) != 0) { 265 if (state.registered.count(offset) != 0) {
254 parse_info.branch_info.address = offset; 266 single_branch.address = offset;
255 parse_info.branch_info.ignore = true; 267 single_branch.ignore = true;
256 break; 268 break;
257 } 269 }
258 if (IsSchedInstruction(offset, state.start)) { 270 if (IsSchedInstruction(offset, state.start)) {
@@ -269,24 +281,26 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
269 switch (opcode->get().GetId()) { 281 switch (opcode->get().GetId()) {
270 case OpCode::Id::EXIT: { 282 case OpCode::Id::EXIT: {
271 const auto pred_index = static_cast<u32>(instr.pred.pred_index); 283 const auto pred_index = static_cast<u32>(instr.pred.pred_index);
272 parse_info.branch_info.condition.predicate = 284 single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
273 GetPredicate(pred_index, instr.negate_pred != 0); 285 if (single_branch.condition.predicate == Pred::NeverExecute) {
274 if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
275 offset++; 286 offset++;
276 continue; 287 continue;
277 } 288 }
278 const ConditionCode cc = instr.flow_condition_code; 289 const ConditionCode cc = instr.flow_condition_code;
279 parse_info.branch_info.condition.cc = cc; 290 single_branch.condition.cc = cc;
280 if (cc == ConditionCode::F) { 291 if (cc == ConditionCode::F) {
281 offset++; 292 offset++;
282 continue; 293 continue;
283 } 294 }
284 parse_info.branch_info.address = exit_branch; 295 single_branch.address = exit_branch;
285 parse_info.branch_info.kill = false; 296 single_branch.kill = false;
286 parse_info.branch_info.is_sync = false; 297 single_branch.is_sync = false;
287 parse_info.branch_info.is_brk = false; 298 single_branch.is_brk = false;
288 parse_info.branch_info.ignore = false; 299 single_branch.ignore = false;
289 parse_info.end_address = offset; 300 parse_info.end_address = offset;
301 parse_info.branch_info = MakeBranchInfo<SingleBranch>(
302 single_branch.condition, single_branch.address, single_branch.kill,
303 single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
290 304
291 return {ParseResult::ControlCaught, parse_info}; 305 return {ParseResult::ControlCaught, parse_info};
292 } 306 }
@@ -295,99 +309,107 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
295 return {ParseResult::AbnormalFlow, parse_info}; 309 return {ParseResult::AbnormalFlow, parse_info};
296 } 310 }
297 const auto pred_index = static_cast<u32>(instr.pred.pred_index); 311 const auto pred_index = static_cast<u32>(instr.pred.pred_index);
298 parse_info.branch_info.condition.predicate = 312 single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
299 GetPredicate(pred_index, instr.negate_pred != 0); 313 if (single_branch.condition.predicate == Pred::NeverExecute) {
300 if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
301 offset++; 314 offset++;
302 continue; 315 continue;
303 } 316 }
304 const ConditionCode cc = instr.flow_condition_code; 317 const ConditionCode cc = instr.flow_condition_code;
305 parse_info.branch_info.condition.cc = cc; 318 single_branch.condition.cc = cc;
306 if (cc == ConditionCode::F) { 319 if (cc == ConditionCode::F) {
307 offset++; 320 offset++;
308 continue; 321 continue;
309 } 322 }
310 const u32 branch_offset = offset + instr.bra.GetBranchTarget(); 323 const u32 branch_offset = offset + instr.bra.GetBranchTarget();
311 if (branch_offset == 0) { 324 if (branch_offset == 0) {
312 parse_info.branch_info.address = exit_branch; 325 single_branch.address = exit_branch;
313 } else { 326 } else {
314 parse_info.branch_info.address = branch_offset; 327 single_branch.address = branch_offset;
315 } 328 }
316 insert_label(state, branch_offset); 329 insert_label(state, branch_offset);
317 parse_info.branch_info.kill = false; 330 single_branch.kill = false;
318 parse_info.branch_info.is_sync = false; 331 single_branch.is_sync = false;
319 parse_info.branch_info.is_brk = false; 332 single_branch.is_brk = false;
320 parse_info.branch_info.ignore = false; 333 single_branch.ignore = false;
321 parse_info.end_address = offset; 334 parse_info.end_address = offset;
335 parse_info.branch_info = MakeBranchInfo<SingleBranch>(
336 single_branch.condition, single_branch.address, single_branch.kill,
337 single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
322 338
323 return {ParseResult::ControlCaught, parse_info}; 339 return {ParseResult::ControlCaught, parse_info};
324 } 340 }
325 case OpCode::Id::SYNC: { 341 case OpCode::Id::SYNC: {
326 const auto pred_index = static_cast<u32>(instr.pred.pred_index); 342 const auto pred_index = static_cast<u32>(instr.pred.pred_index);
327 parse_info.branch_info.condition.predicate = 343 single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
328 GetPredicate(pred_index, instr.negate_pred != 0); 344 if (single_branch.condition.predicate == Pred::NeverExecute) {
329 if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
330 offset++; 345 offset++;
331 continue; 346 continue;
332 } 347 }
333 const ConditionCode cc = instr.flow_condition_code; 348 const ConditionCode cc = instr.flow_condition_code;
334 parse_info.branch_info.condition.cc = cc; 349 single_branch.condition.cc = cc;
335 if (cc == ConditionCode::F) { 350 if (cc == ConditionCode::F) {
336 offset++; 351 offset++;
337 continue; 352 continue;
338 } 353 }
339 parse_info.branch_info.address = unassigned_branch; 354 single_branch.address = unassigned_branch;
340 parse_info.branch_info.kill = false; 355 single_branch.kill = false;
341 parse_info.branch_info.is_sync = true; 356 single_branch.is_sync = true;
342 parse_info.branch_info.is_brk = false; 357 single_branch.is_brk = false;
343 parse_info.branch_info.ignore = false; 358 single_branch.ignore = false;
344 parse_info.end_address = offset; 359 parse_info.end_address = offset;
360 parse_info.branch_info = MakeBranchInfo<SingleBranch>(
361 single_branch.condition, single_branch.address, single_branch.kill,
362 single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
345 363
346 return {ParseResult::ControlCaught, parse_info}; 364 return {ParseResult::ControlCaught, parse_info};
347 } 365 }
348 case OpCode::Id::BRK: { 366 case OpCode::Id::BRK: {
349 const auto pred_index = static_cast<u32>(instr.pred.pred_index); 367 const auto pred_index = static_cast<u32>(instr.pred.pred_index);
350 parse_info.branch_info.condition.predicate = 368 single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
351 GetPredicate(pred_index, instr.negate_pred != 0); 369 if (single_branch.condition.predicate == Pred::NeverExecute) {
352 if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
353 offset++; 370 offset++;
354 continue; 371 continue;
355 } 372 }
356 const ConditionCode cc = instr.flow_condition_code; 373 const ConditionCode cc = instr.flow_condition_code;
357 parse_info.branch_info.condition.cc = cc; 374 single_branch.condition.cc = cc;
358 if (cc == ConditionCode::F) { 375 if (cc == ConditionCode::F) {
359 offset++; 376 offset++;
360 continue; 377 continue;
361 } 378 }
362 parse_info.branch_info.address = unassigned_branch; 379 single_branch.address = unassigned_branch;
363 parse_info.branch_info.kill = false; 380 single_branch.kill = false;
364 parse_info.branch_info.is_sync = false; 381 single_branch.is_sync = false;
365 parse_info.branch_info.is_brk = true; 382 single_branch.is_brk = true;
366 parse_info.branch_info.ignore = false; 383 single_branch.ignore = false;
367 parse_info.end_address = offset; 384 parse_info.end_address = offset;
385 parse_info.branch_info = MakeBranchInfo<SingleBranch>(
386 single_branch.condition, single_branch.address, single_branch.kill,
387 single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
368 388
369 return {ParseResult::ControlCaught, parse_info}; 389 return {ParseResult::ControlCaught, parse_info};
370 } 390 }
371 case OpCode::Id::KIL: { 391 case OpCode::Id::KIL: {
372 const auto pred_index = static_cast<u32>(instr.pred.pred_index); 392 const auto pred_index = static_cast<u32>(instr.pred.pred_index);
373 parse_info.branch_info.condition.predicate = 393 single_branch.condition.predicate = GetPredicate(pred_index, instr.negate_pred != 0);
374 GetPredicate(pred_index, instr.negate_pred != 0); 394 if (single_branch.condition.predicate == Pred::NeverExecute) {
375 if (parse_info.branch_info.condition.predicate == Pred::NeverExecute) {
376 offset++; 395 offset++;
377 continue; 396 continue;
378 } 397 }
379 const ConditionCode cc = instr.flow_condition_code; 398 const ConditionCode cc = instr.flow_condition_code;
380 parse_info.branch_info.condition.cc = cc; 399 single_branch.condition.cc = cc;
381 if (cc == ConditionCode::F) { 400 if (cc == ConditionCode::F) {
382 offset++; 401 offset++;
383 continue; 402 continue;
384 } 403 }
385 parse_info.branch_info.address = exit_branch; 404 single_branch.address = exit_branch;
386 parse_info.branch_info.kill = true; 405 single_branch.kill = true;
387 parse_info.branch_info.is_sync = false; 406 single_branch.is_sync = false;
388 parse_info.branch_info.is_brk = false; 407 single_branch.is_brk = false;
389 parse_info.branch_info.ignore = false; 408 single_branch.ignore = false;
390 parse_info.end_address = offset; 409 parse_info.end_address = offset;
410 parse_info.branch_info = MakeBranchInfo<SingleBranch>(
411 single_branch.condition, single_branch.address, single_branch.kill,
412 single_branch.is_sync, single_branch.is_brk, single_branch.ignore);
391 413
392 return {ParseResult::ControlCaught, parse_info}; 414 return {ParseResult::ControlCaught, parse_info};
393 } 415 }
@@ -407,16 +429,25 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
407 auto tmp = TrackBranchIndirectInfo(state, address, offset); 429 auto tmp = TrackBranchIndirectInfo(state, address, offset);
408 if (tmp) { 430 if (tmp) {
409 auto result = *tmp; 431 auto result = *tmp;
410 std::string entries{}; 432 std::vector<CaseBranch> branches{};
433 s32 pc_target = offset + result.relative_position;
411 for (u32 i = 0; i < result.entries; i++) { 434 for (u32 i = 0; i < result.entries; i++) {
412 auto k = locker.ObtainKey(result.buffer, result.offset + i * 4); 435 auto k = state.locker.ObtainKey(result.buffer, result.offset + i * 4);
413 entries = entries + std::to_string(*k) + '\n'; 436 if (!k) {
437 return {ParseResult::AbnormalFlow, parse_info};
438 }
439 u32 value = *k;
440 u32 target = static_cast<u32>((value >> 3) + pc_target);
441 insert_label(state, target);
442 branches.emplace_back(value, target);
414 } 443 }
415 LOG_CRITICAL(HW_GPU, 444 parse_info.end_address = offset;
416 "Track Successful, BRX: buffer:{}, offset:{}, entries:{}, inner:\n{}", 445 parse_info.branch_info =
417 result.buffer, result.offset, result.entries, entries); 446 MakeBranchInfo<MultiBranch>(static_cast<u32>(instr.gpr8.Value()), branches);
447
448 return {ParseResult::ControlCaught, parse_info};
418 } else { 449 } else {
419 LOG_CRITICAL(HW_GPU, "Track Unsuccesful"); 450 LOG_WARNING(HW_GPU, "BRX Track Unsuccesful");
420 } 451 }
421 return {ParseResult::AbnormalFlow, parse_info}; 452 return {ParseResult::AbnormalFlow, parse_info};
422 } 453 }
@@ -426,10 +457,13 @@ std::pair<ParseResult, ParseInfo> ParseCode(CFGRebuildState& state, u32 address)
426 457
427 offset++; 458 offset++;
428 } 459 }
429 parse_info.branch_info.kill = false; 460 single_branch.kill = false;
430 parse_info.branch_info.is_sync = false; 461 single_branch.is_sync = false;
431 parse_info.branch_info.is_brk = false; 462 single_branch.is_brk = false;
432 parse_info.end_address = offset - 1; 463 parse_info.end_address = offset - 1;
464 parse_info.branch_info = MakeBranchInfo<SingleBranch>(
465 single_branch.condition, single_branch.address, single_branch.kill, single_branch.is_sync,
466 single_branch.is_brk, single_branch.ignore);
433 return {ParseResult::BlockEnd, parse_info}; 467 return {ParseResult::BlockEnd, parse_info};
434} 468}
435 469
@@ -453,9 +487,10 @@ bool TryInspectAddress(CFGRebuildState& state) {
453 BlockInfo& current_block = state.block_info[block_index]; 487 BlockInfo& current_block = state.block_info[block_index];
454 current_block.end = address - 1; 488 current_block.end = address - 1;
455 new_block.branch = current_block.branch; 489 new_block.branch = current_block.branch;
456 BlockBranchInfo forward_branch{}; 490 BlockBranchInfo forward_branch = MakeBranchInfo<SingleBranch>();
457 forward_branch.address = address; 491 auto branch = std::get_if<SingleBranch>(forward_branch.get());
458 forward_branch.ignore = true; 492 branch->address = address;
493 branch->ignore = true;
459 current_block.branch = forward_branch; 494 current_block.branch = forward_branch;
460 return true; 495 return true;
461 } 496 }
@@ -470,12 +505,15 @@ bool TryInspectAddress(CFGRebuildState& state) {
470 505
471 BlockInfo& block_info = CreateBlockInfo(state, address, parse_info.end_address); 506 BlockInfo& block_info = CreateBlockInfo(state, address, parse_info.end_address);
472 block_info.branch = parse_info.branch_info; 507 block_info.branch = parse_info.branch_info;
473 if (parse_info.branch_info.condition.IsUnconditional()) { 508 if (std::holds_alternative<SingleBranch>(*block_info.branch)) {
509 auto branch = std::get_if<SingleBranch>(block_info.branch.get());
510 if (branch->condition.IsUnconditional()) {
511 return true;
512 }
513 const u32 fallthrough_address = parse_info.end_address + 1;
514 state.inspect_queries.push_front(fallthrough_address);
474 return true; 515 return true;
475 } 516 }
476
477 const u32 fallthrough_address = parse_info.end_address + 1;
478 state.inspect_queries.push_front(fallthrough_address);
479 return true; 517 return true;
480} 518}
481 519
@@ -513,31 +551,41 @@ bool TryQuery(CFGRebuildState& state) {
513 state.queries.pop_front(); 551 state.queries.pop_front();
514 gather_labels(q2.ssy_stack, state.ssy_labels, block); 552 gather_labels(q2.ssy_stack, state.ssy_labels, block);
515 gather_labels(q2.pbk_stack, state.pbk_labels, block); 553 gather_labels(q2.pbk_stack, state.pbk_labels, block);
516 if (!block.branch.condition.IsUnconditional()) { 554 if (std::holds_alternative<SingleBranch>(*block.branch)) {
517 q2.address = block.end + 1; 555 auto branch = std::get_if<SingleBranch>(block.branch.get());
518 state.queries.push_back(q2); 556 if (!branch->condition.IsUnconditional()) {
519 } 557 q2.address = block.end + 1;
558 state.queries.push_back(q2);
559 }
520 560
521 Query conditional_query{q2}; 561 Query conditional_query{q2};
522 if (block.branch.is_sync) { 562 if (branch->is_sync) {
523 if (block.branch.address == unassigned_branch) { 563 if (branch->address == unassigned_branch) {
524 block.branch.address = conditional_query.ssy_stack.top(); 564 branch->address = conditional_query.ssy_stack.top();
565 }
566 conditional_query.ssy_stack.pop();
525 } 567 }
526 conditional_query.ssy_stack.pop(); 568 if (branch->is_brk) {
527 } 569 if (branch->address == unassigned_branch) {
528 if (block.branch.is_brk) { 570 branch->address = conditional_query.pbk_stack.top();
529 if (block.branch.address == unassigned_branch) { 571 }
530 block.branch.address = conditional_query.pbk_stack.top(); 572 conditional_query.pbk_stack.pop();
531 } 573 }
532 conditional_query.pbk_stack.pop(); 574 conditional_query.address = branch->address;
575 state.queries.push_back(std::move(conditional_query));
576 return true;
577 }
578 auto multi_branch = std::get_if<MultiBranch>(block.branch.get());
579 for (auto& branch_case : multi_branch->branches) {
580 Query conditional_query{q2};
581 conditional_query.address = branch_case.address;
582 state.queries.push_back(std::move(conditional_query));
533 } 583 }
534 conditional_query.address = block.branch.address;
535 state.queries.push_back(std::move(conditional_query));
536 return true; 584 return true;
537} 585}
538} // Anonymous namespace 586} // Anonymous namespace
539 587
540void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch) { 588void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch_info) {
541 const auto get_expr = ([&](const Condition& cond) -> Expr { 589 const auto get_expr = ([&](const Condition& cond) -> Expr {
542 Expr result{}; 590 Expr result{};
543 if (cond.cc != ConditionCode::T) { 591 if (cond.cc != ConditionCode::T) {
@@ -564,15 +612,24 @@ void InsertBranch(ASTManager& mm, const BlockBranchInfo& branch) {
564 } 612 }
565 return MakeExpr<ExprBoolean>(true); 613 return MakeExpr<ExprBoolean>(true);
566 }); 614 });
567 if (branch.address < 0) { 615 if (std::holds_alternative<SingleBranch>(*branch_info)) {
568 if (branch.kill) { 616 auto branch = std::get_if<SingleBranch>(branch_info.get());
569 mm.InsertReturn(get_expr(branch.condition), true); 617 if (branch->address < 0) {
618 if (branch->kill) {
619 mm.InsertReturn(get_expr(branch->condition), true);
620 return;
621 }
622 mm.InsertReturn(get_expr(branch->condition), false);
570 return; 623 return;
571 } 624 }
572 mm.InsertReturn(get_expr(branch.condition), false); 625 mm.InsertGoto(get_expr(branch->condition), branch->address);
573 return; 626 return;
574 } 627 }
575 mm.InsertGoto(get_expr(branch.condition), branch.address); 628 auto multi_branch = std::get_if<MultiBranch>(branch_info.get());
629 for (auto& branch_case : multi_branch->branches) {
630 mm.InsertGoto(MakeExpr<ExprGprEqual>(multi_branch->gpr, branch_case.cmp_value),
631 branch_case.address);
632 }
576} 633}
577 634
578void DecompileShader(CFGRebuildState& state) { 635void DecompileShader(CFGRebuildState& state) {
@@ -584,9 +641,10 @@ void DecompileShader(CFGRebuildState& state) {
584 if (state.labels.count(block.start) != 0) { 641 if (state.labels.count(block.start) != 0) {
585 state.manager->InsertLabel(block.start); 642 state.manager->InsertLabel(block.start);
586 } 643 }
587 u32 end = block.branch.ignore ? block.end + 1 : block.end; 644 const bool ignore = BlockBranchIsIgnored(block.branch);
645 u32 end = ignore ? block.end + 1 : block.end;
588 state.manager->InsertBlock(block.start, end); 646 state.manager->InsertBlock(block.start, end);
589 if (!block.branch.ignore) { 647 if (!ignore) {
590 InsertBranch(*state.manager, block.branch); 648 InsertBranch(*state.manager, block.branch);
591 } 649 }
592 } 650 }
@@ -668,11 +726,9 @@ std::unique_ptr<ShaderCharacteristics> ScanFlow(const ProgramCode& program_code,
668 ShaderBlock new_block{}; 726 ShaderBlock new_block{};
669 new_block.start = block.start; 727 new_block.start = block.start;
670 new_block.end = block.end; 728 new_block.end = block.end;
671 new_block.ignore_branch = block.branch.ignore; 729 new_block.ignore_branch = BlockBranchIsIgnored(block.branch);
672 if (!new_block.ignore_branch) { 730 if (!new_block.ignore_branch) {
673 new_block.branch.cond = block.branch.condition; 731 new_block.branch = block.branch;
674 new_block.branch.kills = block.branch.kill;
675 new_block.branch.address = block.branch.address;
676 } 732 }
677 result_out->end = std::max(result_out->end, block.end); 733 result_out->end = std::max(result_out->end, block.end);
678 result_out->blocks.push_back(new_block); 734 result_out->blocks.push_back(new_block);
diff --git a/src/video_core/shader/control_flow.h b/src/video_core/shader/control_flow.h
index 6d0e50d7c..369ca255b 100644
--- a/src/video_core/shader/control_flow.h
+++ b/src/video_core/shader/control_flow.h
@@ -7,6 +7,7 @@
7#include <list> 7#include <list>
8#include <optional> 8#include <optional>
9#include <set> 9#include <set>
10#include <variant>
10 11
11#include "video_core/engines/shader_bytecode.h" 12#include "video_core/engines/shader_bytecode.h"
12#include "video_core/shader/ast.h" 13#include "video_core/shader/ast.h"
@@ -37,29 +38,57 @@ struct Condition {
37 } 38 }
38}; 39};
39 40
40struct ShaderBlock { 41class SingleBranch {
41 struct Branch { 42public:
42 Condition cond{}; 43 SingleBranch() = default;
43 bool kills{}; 44 SingleBranch(Condition condition, s32 address, bool kill, bool is_sync, bool is_brk,
44 s32 address{}; 45 bool ignore)
46 : condition{condition}, address{address}, kill{kill}, is_sync{is_sync}, is_brk{is_brk},
47 ignore{ignore} {}
48
49 bool operator==(const SingleBranch& b) const {
50 return std::tie(condition, address, kill, is_sync, is_brk, ignore) ==
51 std::tie(b.condition, b.address, b.kill, b.is_sync, b.is_brk, b.ignore);
52 }
53
54 Condition condition{};
55 s32 address{exit_branch};
56 bool kill{};
57 bool is_sync{};
58 bool is_brk{};
59 bool ignore{};
60};
61
62struct CaseBranch {
63 CaseBranch(u32 cmp_value, u32 address) : cmp_value{cmp_value}, address{address} {}
64 u32 cmp_value;
65 u32 address;
66};
67
68class MultiBranch {
69public:
70 MultiBranch(u32 gpr, std::vector<CaseBranch>& branches)
71 : gpr{gpr}, branches{std::move(branches)} {}
45 72
46 bool operator==(const Branch& b) const { 73 u32 gpr{};
47 return std::tie(cond, kills, address) == std::tie(b.cond, b.kills, b.address); 74 std::vector<CaseBranch> branches{};
48 } 75};
49 76
50 bool operator!=(const Branch& b) const { 77using BranchData = std::variant<SingleBranch, MultiBranch>;
51 return !operator==(b); 78using BlockBranchInfo = std::shared_ptr<BranchData>;
52 }
53 };
54 79
80bool BlockBranchInfoAreEqual(BlockBranchInfo first, BlockBranchInfo second);
81
82struct ShaderBlock {
55 u32 start{}; 83 u32 start{};
56 u32 end{}; 84 u32 end{};
57 bool ignore_branch{}; 85 bool ignore_branch{};
58 Branch branch{}; 86 BlockBranchInfo branch{};
59 87
60 bool operator==(const ShaderBlock& sb) const { 88 bool operator==(const ShaderBlock& sb) const {
61 return std::tie(start, end, ignore_branch, branch) == 89 return std::tie(start, end, ignore_branch) ==
62 std::tie(sb.start, sb.end, sb.ignore_branch, sb.branch); 90 std::tie(sb.start, sb.end, sb.ignore_branch) &&
91 BlockBranchInfoAreEqual(branch, sb.branch);
63 } 92 }
64 93
65 bool operator!=(const ShaderBlock& sb) const { 94 bool operator!=(const ShaderBlock& sb) const {
diff --git a/src/video_core/shader/decode.cpp b/src/video_core/shader/decode.cpp
index 3f87b87ca..053241128 100644
--- a/src/video_core/shader/decode.cpp
+++ b/src/video_core/shader/decode.cpp
@@ -198,24 +198,38 @@ void ShaderIR::InsertControlFlow(NodeBlock& bb, const ShaderBlock& block) {
198 } 198 }
199 return result; 199 return result;
200 }; 200 };
201 if (block.branch.address < 0) { 201 if (std::holds_alternative<SingleBranch>(*block.branch)) {
202 if (block.branch.kills) { 202 auto branch = std::get_if<SingleBranch>(block.branch.get());
203 Node n = Operation(OperationCode::Discard); 203 if (branch->address < 0) {
204 n = apply_conditions(block.branch.cond, n); 204 if (branch->kill) {
205 Node n = Operation(OperationCode::Discard);
206 n = apply_conditions(branch->condition, n);
207 bb.push_back(n);
208 global_code.push_back(n);
209 return;
210 }
211 Node n = Operation(OperationCode::Exit);
212 n = apply_conditions(branch->condition, n);
205 bb.push_back(n); 213 bb.push_back(n);
206 global_code.push_back(n); 214 global_code.push_back(n);
207 return; 215 return;
208 } 216 }
209 Node n = Operation(OperationCode::Exit); 217 Node n = Operation(OperationCode::Branch, Immediate(branch->address));
210 n = apply_conditions(block.branch.cond, n); 218 n = apply_conditions(branch->condition, n);
211 bb.push_back(n); 219 bb.push_back(n);
212 global_code.push_back(n); 220 global_code.push_back(n);
213 return; 221 return;
214 } 222 }
215 Node n = Operation(OperationCode::Branch, Immediate(block.branch.address)); 223 auto multi_branch = std::get_if<MultiBranch>(block.branch.get());
216 n = apply_conditions(block.branch.cond, n); 224 Node op_a = GetRegister(multi_branch->gpr);
217 bb.push_back(n); 225 for (auto& branch_case : multi_branch->branches) {
218 global_code.push_back(n); 226 Node n = Operation(OperationCode::Branch, Immediate(branch_case.address));
227 Node op_b = Immediate(branch_case.cmp_value);
228 Node condition = GetPredicateComparisonInteger(Tegra::Shader::PredCondition::Equal, false, op_a, op_b);
229 auto result = Conditional(condition, {n});
230 bb.push_back(result);
231 global_code.push_back(result);
232 }
219} 233}
220 234
221u32 ShaderIR::DecodeInstr(NodeBlock& bb, u32 pc) { 235u32 ShaderIR::DecodeInstr(NodeBlock& bb, u32 pc) {
diff --git a/src/video_core/shader/expr.h b/src/video_core/shader/expr.h
index d3dcd00ec..e41d23e93 100644
--- a/src/video_core/shader/expr.h
+++ b/src/video_core/shader/expr.h
@@ -17,13 +17,14 @@ using Tegra::Shader::Pred;
17class ExprAnd; 17class ExprAnd;
18class ExprBoolean; 18class ExprBoolean;
19class ExprCondCode; 19class ExprCondCode;
20class ExprGprEqual;
20class ExprNot; 21class ExprNot;
21class ExprOr; 22class ExprOr;
22class ExprPredicate; 23class ExprPredicate;
23class ExprVar; 24class ExprVar;
24 25
25using ExprData = 26using ExprData = std::variant<ExprVar, ExprCondCode, ExprPredicate, ExprNot, ExprOr, ExprAnd,
26 std::variant<ExprVar, ExprCondCode, ExprPredicate, ExprNot, ExprOr, ExprAnd, ExprBoolean>; 27 ExprBoolean, ExprGprEqual>;
27using Expr = std::shared_ptr<ExprData>; 28using Expr = std::shared_ptr<ExprData>;
28 29
29class ExprAnd final { 30class ExprAnd final {
@@ -118,6 +119,18 @@ public:
118 bool value; 119 bool value;
119}; 120};
120 121
122class ExprGprEqual final {
123public:
124 ExprGprEqual(u32 gpr, u32 value) : gpr{gpr}, value{value} {}
125
126 bool operator==(const ExprGprEqual& b) const {
127 return gpr == b.gpr && value == b.value;
128 }
129
130 u32 gpr;
131 u32 value;
132};
133
121template <typename T, typename... Args> 134template <typename T, typename... Args>
122Expr MakeExpr(Args&&... args) { 135Expr MakeExpr(Args&&... args) {
123 static_assert(std::is_convertible_v<T, ExprData>); 136 static_assert(std::is_convertible_v<T, ExprData>);