summaryrefslogtreecommitdiff
path: root/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/shader_recompiler/ir_opt/constant_propagation_pass.cpp')
-rw-r--r--src/shader_recompiler/ir_opt/constant_propagation_pass.cpp42
1 file changed, 41 insertions(+), 1 deletion(-)
diff --git a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
index 7da4d50ef..15e16956e 100644
--- a/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
+++ b/src/shader_recompiler/ir_opt/constant_propagation_pass.cpp
@@ -3,9 +3,9 @@
3// Refer to the license.txt file included. 3// Refer to the license.txt file included.
4 4
5#include <algorithm> 5#include <algorithm>
6#include <ranges>
6#include <tuple> 7#include <tuple>
7#include <type_traits> 8#include <type_traits>
8#include <ranges>
9 9
10#include "common/bit_cast.h" 10#include "common/bit_cast.h"
11#include "common/bit_util.h" 11#include "common/bit_util.h"
@@ -332,6 +332,18 @@ void FoldBitCast(IR::Inst& inst, IR::Opcode reverse) {
332 } 332 }
333} 333}
334 334
335void FoldInverseFunc(IR::Inst& inst, IR::Opcode reverse) {
336 const IR::Value value{inst.Arg(0)};
337 if (value.IsImmediate()) {
338 return;
339 }
340 IR::Inst* const arg_inst{value.InstRecursive()};
341 if (arg_inst->Opcode() == reverse) {
342 inst.ReplaceUsesWith(arg_inst->Arg(0));
343 return;
344 }
345}
346
335template <typename Func, size_t... I> 347template <typename Func, size_t... I>
336IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) { 348IR::Value EvalImmediates(const IR::Inst& inst, Func&& func, std::index_sequence<I...>) {
337 using Traits = LambdaTraits<decltype(func)>; 349 using Traits = LambdaTraits<decltype(func)>;
@@ -372,6 +384,10 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
372 return FoldBitCast<IR::Opcode::BitCastU32F32, u32, f32>(inst, IR::Opcode::BitCastF32U32); 384 return FoldBitCast<IR::Opcode::BitCastU32F32, u32, f32>(inst, IR::Opcode::BitCastF32U32);
373 case IR::Opcode::IAdd64: 385 case IR::Opcode::IAdd64:
374 return FoldAdd<u64>(block, inst); 386 return FoldAdd<u64>(block, inst);
387 case IR::Opcode::PackHalf2x16:
388 return FoldInverseFunc(inst, IR::Opcode::UnpackHalf2x16);
389 case IR::Opcode::UnpackHalf2x16:
390 return FoldInverseFunc(inst, IR::Opcode::PackHalf2x16);
375 case IR::Opcode::SelectU1: 391 case IR::Opcode::SelectU1:
376 case IR::Opcode::SelectU8: 392 case IR::Opcode::SelectU8:
377 case IR::Opcode::SelectU16: 393 case IR::Opcode::SelectU16:
@@ -395,6 +411,30 @@ void ConstantPropagation(IR::Block& block, IR::Inst& inst) {
395 case IR::Opcode::ULessThan: 411 case IR::Opcode::ULessThan:
396 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; }); 412 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a < b; });
397 return; 413 return;
414 case IR::Opcode::SLessThanEqual:
415 FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a <= b; });
416 return;
417 case IR::Opcode::ULessThanEqual:
418 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a <= b; });
419 return;
420 case IR::Opcode::SGreaterThan:
421 FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a > b; });
422 return;
423 case IR::Opcode::UGreaterThan:
424 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a > b; });
425 return;
426 case IR::Opcode::SGreaterThanEqual:
427 FoldWhenAllImmediates(inst, [](s32 a, s32 b) { return a >= b; });
428 return;
429 case IR::Opcode::UGreaterThanEqual:
430 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a >= b; });
431 return;
432 case IR::Opcode::IEqual:
433 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a == b; });
434 return;
435 case IR::Opcode::INotEqual:
436 FoldWhenAllImmediates(inst, [](u32 a, u32 b) { return a != b; });
437 return;
398 case IR::Opcode::BitFieldUExtract: 438 case IR::Opcode::BitFieldUExtract:
399 FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) { 439 FoldWhenAllImmediates(inst, [](u32 base, u32 shift, u32 count) {
400 if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) { 440 if (static_cast<size_t>(shift) + static_cast<size_t>(count) > Common::BitSize<u32>()) {