summaryrefslogtreecommitdiff
path: root/src/video_core/shader/ast.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/video_core/shader/ast.cpp')
-rw-r--r--src/video_core/shader/ast.cpp122
1 files changed, 47 insertions, 75 deletions
diff --git a/src/video_core/shader/ast.cpp b/src/video_core/shader/ast.cpp
index 2eb065c3d..436d45f4b 100644
--- a/src/video_core/shader/ast.cpp
+++ b/src/video_core/shader/ast.cpp
@@ -17,6 +17,7 @@ void ASTZipper::Init(const ASTNode new_first, const ASTNode parent) {
17 ASSERT(new_first->manager == nullptr); 17 ASSERT(new_first->manager == nullptr);
18 first = new_first; 18 first = new_first;
19 last = new_first; 19 last = new_first;
20
20 ASTNode current = first; 21 ASTNode current = first;
21 while (current) { 22 while (current) {
22 current->manager = this; 23 current->manager = this;
@@ -92,7 +93,7 @@ void ASTZipper::InsertBefore(const ASTNode new_node, const ASTNode at_node) {
92 new_node->manager = this; 93 new_node->manager = this;
93} 94}
94 95
95void ASTZipper::DetachTail(const ASTNode node) { 96void ASTZipper::DetachTail(ASTNode node) {
96 ASSERT(node->manager == this); 97 ASSERT(node->manager == this);
97 if (node == first) { 98 if (node == first) {
98 first.reset(); 99 first.reset();
@@ -103,7 +104,8 @@ void ASTZipper::DetachTail(const ASTNode node) {
103 last = node->previous; 104 last = node->previous;
104 last->next.reset(); 105 last->next.reset();
105 node->previous.reset(); 106 node->previous.reset();
106 ASTNode current = node; 107
108 ASTNode current = std::move(node);
107 while (current) { 109 while (current) {
108 current->manager = nullptr; 110 current->manager = nullptr;
109 current->parent.reset(); 111 current->parent.reset();
@@ -185,9 +187,7 @@ void ASTZipper::Remove(const ASTNode node) {
185 187
186class ExprPrinter final { 188class ExprPrinter final {
187public: 189public:
188 ExprPrinter() = default; 190 void operator()(const ExprAnd& expr) {
189
190 void operator()(ExprAnd const& expr) {
191 inner += "( "; 191 inner += "( ";
192 std::visit(*this, *expr.operand1); 192 std::visit(*this, *expr.operand1);
193 inner += " && "; 193 inner += " && ";
@@ -195,7 +195,7 @@ public:
195 inner += ')'; 195 inner += ')';
196 } 196 }
197 197
198 void operator()(ExprOr const& expr) { 198 void operator()(const ExprOr& expr) {
199 inner += "( "; 199 inner += "( ";
200 std::visit(*this, *expr.operand1); 200 std::visit(*this, *expr.operand1);
201 inner += " || "; 201 inner += " || ";
@@ -203,29 +203,29 @@ public:
203 inner += ')'; 203 inner += ')';
204 } 204 }
205 205
206 void operator()(ExprNot const& expr) { 206 void operator()(const ExprNot& expr) {
207 inner += "!"; 207 inner += "!";
208 std::visit(*this, *expr.operand1); 208 std::visit(*this, *expr.operand1);
209 } 209 }
210 210
211 void operator()(ExprPredicate const& expr) { 211 void operator()(const ExprPredicate& expr) {
212 inner += "P" + std::to_string(expr.predicate); 212 inner += "P" + std::to_string(expr.predicate);
213 } 213 }
214 214
215 void operator()(ExprCondCode const& expr) { 215 void operator()(const ExprCondCode& expr) {
216 u32 cc = static_cast<u32>(expr.cc); 216 u32 cc = static_cast<u32>(expr.cc);
217 inner += "CC" + std::to_string(cc); 217 inner += "CC" + std::to_string(cc);
218 } 218 }
219 219
220 void operator()(ExprVar const& expr) { 220 void operator()(const ExprVar& expr) {
221 inner += "V" + std::to_string(expr.var_index); 221 inner += "V" + std::to_string(expr.var_index);
222 } 222 }
223 223
224 void operator()(ExprBoolean const& expr) { 224 void operator()(const ExprBoolean& expr) {
225 inner += expr.value ? "true" : "false"; 225 inner += expr.value ? "true" : "false";
226 } 226 }
227 227
228 std::string& GetResult() { 228 const std::string& GetResult() const {
229 return inner; 229 return inner;
230 } 230 }
231 231
@@ -234,9 +234,7 @@ public:
234 234
235class ASTPrinter { 235class ASTPrinter {
236public: 236public:
237 ASTPrinter() = default; 237 void operator()(const ASTProgram& ast) {
238
239 void operator()(ASTProgram& ast) {
240 scope++; 238 scope++;
241 inner += "program {\n"; 239 inner += "program {\n";
242 ASTNode current = ast.nodes.GetFirst(); 240 ASTNode current = ast.nodes.GetFirst();
@@ -248,7 +246,7 @@ public:
248 scope--; 246 scope--;
249 } 247 }
250 248
251 void operator()(ASTIfThen& ast) { 249 void operator()(const ASTIfThen& ast) {
252 ExprPrinter expr_parser{}; 250 ExprPrinter expr_parser{};
253 std::visit(expr_parser, *ast.condition); 251 std::visit(expr_parser, *ast.condition);
254 inner += Ident() + "if (" + expr_parser.GetResult() + ") {\n"; 252 inner += Ident() + "if (" + expr_parser.GetResult() + ") {\n";
@@ -262,7 +260,7 @@ public:
262 inner += Ident() + "}\n"; 260 inner += Ident() + "}\n";
263 } 261 }
264 262
265 void operator()(ASTIfElse& ast) { 263 void operator()(const ASTIfElse& ast) {
266 inner += Ident() + "else {\n"; 264 inner += Ident() + "else {\n";
267 scope++; 265 scope++;
268 ASTNode current = ast.nodes.GetFirst(); 266 ASTNode current = ast.nodes.GetFirst();
@@ -274,34 +272,34 @@ public:
274 inner += Ident() + "}\n"; 272 inner += Ident() + "}\n";
275 } 273 }
276 274
277 void operator()(ASTBlockEncoded& ast) { 275 void operator()(const ASTBlockEncoded& ast) {
278 inner += Ident() + "Block(" + std::to_string(ast.start) + ", " + std::to_string(ast.end) + 276 inner += Ident() + "Block(" + std::to_string(ast.start) + ", " + std::to_string(ast.end) +
279 ");\n"; 277 ");\n";
280 } 278 }
281 279
282 void operator()(ASTBlockDecoded& ast) { 280 void operator()(const ASTBlockDecoded& ast) {
283 inner += Ident() + "Block;\n"; 281 inner += Ident() + "Block;\n";
284 } 282 }
285 283
286 void operator()(ASTVarSet& ast) { 284 void operator()(const ASTVarSet& ast) {
287 ExprPrinter expr_parser{}; 285 ExprPrinter expr_parser{};
288 std::visit(expr_parser, *ast.condition); 286 std::visit(expr_parser, *ast.condition);
289 inner += 287 inner +=
290 Ident() + "V" + std::to_string(ast.index) + " := " + expr_parser.GetResult() + ";\n"; 288 Ident() + "V" + std::to_string(ast.index) + " := " + expr_parser.GetResult() + ";\n";
291 } 289 }
292 290
293 void operator()(ASTLabel& ast) { 291 void operator()(const ASTLabel& ast) {
294 inner += "Label_" + std::to_string(ast.index) + ":\n"; 292 inner += "Label_" + std::to_string(ast.index) + ":\n";
295 } 293 }
296 294
297 void operator()(ASTGoto& ast) { 295 void operator()(const ASTGoto& ast) {
298 ExprPrinter expr_parser{}; 296 ExprPrinter expr_parser{};
299 std::visit(expr_parser, *ast.condition); 297 std::visit(expr_parser, *ast.condition);
300 inner += Ident() + "(" + expr_parser.GetResult() + ") -> goto Label_" + 298 inner += Ident() + "(" + expr_parser.GetResult() + ") -> goto Label_" +
301 std::to_string(ast.label) + ";\n"; 299 std::to_string(ast.label) + ";\n";
302 } 300 }
303 301
304 void operator()(ASTDoWhile& ast) { 302 void operator()(const ASTDoWhile& ast) {
305 ExprPrinter expr_parser{}; 303 ExprPrinter expr_parser{};
306 std::visit(expr_parser, *ast.condition); 304 std::visit(expr_parser, *ast.condition);
307 inner += Ident() + "do {\n"; 305 inner += Ident() + "do {\n";
@@ -315,14 +313,14 @@ public:
315 inner += Ident() + "} while (" + expr_parser.GetResult() + ");\n"; 313 inner += Ident() + "} while (" + expr_parser.GetResult() + ");\n";
316 } 314 }
317 315
318 void operator()(ASTReturn& ast) { 316 void operator()(const ASTReturn& ast) {
319 ExprPrinter expr_parser{}; 317 ExprPrinter expr_parser{};
320 std::visit(expr_parser, *ast.condition); 318 std::visit(expr_parser, *ast.condition);
321 inner += Ident() + "(" + expr_parser.GetResult() + ") -> " + 319 inner += Ident() + "(" + expr_parser.GetResult() + ") -> " +
322 (ast.kills ? "discard" : "exit") + ";\n"; 320 (ast.kills ? "discard" : "exit") + ";\n";
323 } 321 }
324 322
325 void operator()(ASTBreak& ast) { 323 void operator()(const ASTBreak& ast) {
326 ExprPrinter expr_parser{}; 324 ExprPrinter expr_parser{};
327 std::visit(expr_parser, *ast.condition); 325 std::visit(expr_parser, *ast.condition);
328 inner += Ident() + "(" + expr_parser.GetResult() + ") -> break;\n"; 326 inner += Ident() + "(" + expr_parser.GetResult() + ") -> break;\n";
@@ -341,7 +339,7 @@ public:
341 std::visit(*this, *node->GetInnerData()); 339 std::visit(*this, *node->GetInnerData());
342 } 340 }
343 341
344 std::string& GetResult() { 342 const std::string& GetResult() const {
345 return inner; 343 return inner;
346 } 344 }
347 345
@@ -352,11 +350,9 @@ private:
352 std::string tabs_memo{}; 350 std::string tabs_memo{};
353 u32 memo_scope{}; 351 u32 memo_scope{};
354 352
355 static std::string tabs; 353 static constexpr std::string_view tabs{" "};
356}; 354};
357 355
358std::string ASTPrinter::tabs = " ";
359
360std::string ASTManager::Print() { 356std::string ASTManager::Print() {
361 ASTPrinter printer{}; 357 ASTPrinter printer{};
362 printer.Visit(main_node); 358 printer.Visit(main_node);
@@ -376,30 +372,6 @@ void ASTManager::Init() {
376 false_condition = MakeExpr<ExprBoolean>(false); 372 false_condition = MakeExpr<ExprBoolean>(false);
377} 373}
378 374
379ASTManager::ASTManager(ASTManager&& other) noexcept
380 : labels_map(std::move(other.labels_map)), labels_count{other.labels_count},
381 gotos(std::move(other.gotos)), labels(std::move(other.labels)), variables{other.variables},
382 program{other.program}, main_node{other.main_node}, false_condition{other.false_condition},
383 disable_else_derivation{other.disable_else_derivation} {
384 other.main_node.reset();
385}
386
387ASTManager& ASTManager::operator=(ASTManager&& other) noexcept {
388 full_decompile = other.full_decompile;
389 labels_map = std::move(other.labels_map);
390 labels_count = other.labels_count;
391 gotos = std::move(other.gotos);
392 labels = std::move(other.labels);
393 variables = other.variables;
394 program = other.program;
395 main_node = other.main_node;
396 false_condition = other.false_condition;
397 disable_else_derivation = other.disable_else_derivation;
398
399 other.main_node.reset();
400 return *this;
401}
402
403void ASTManager::DeclareLabel(u32 address) { 375void ASTManager::DeclareLabel(u32 address) {
404 const auto pair = labels_map.emplace(address, labels_count); 376 const auto pair = labels_map.emplace(address, labels_count);
405 if (pair.second) { 377 if (pair.second) {
@@ -417,19 +389,19 @@ void ASTManager::InsertLabel(u32 address) {
417 389
418void ASTManager::InsertGoto(Expr condition, u32 address) { 390void ASTManager::InsertGoto(Expr condition, u32 address) {
419 const u32 index = labels_map[address]; 391 const u32 index = labels_map[address];
420 const ASTNode goto_node = ASTBase::Make<ASTGoto>(main_node, condition, index); 392 const ASTNode goto_node = ASTBase::Make<ASTGoto>(main_node, std::move(condition), index);
421 gotos.push_back(goto_node); 393 gotos.push_back(goto_node);
422 program->nodes.PushBack(goto_node); 394 program->nodes.PushBack(goto_node);
423} 395}
424 396
425void ASTManager::InsertBlock(u32 start_address, u32 end_address) { 397void ASTManager::InsertBlock(u32 start_address, u32 end_address) {
426 const ASTNode block = ASTBase::Make<ASTBlockEncoded>(main_node, start_address, end_address); 398 ASTNode block = ASTBase::Make<ASTBlockEncoded>(main_node, start_address, end_address);
427 program->nodes.PushBack(block); 399 program->nodes.PushBack(std::move(block));
428} 400}
429 401
430void ASTManager::InsertReturn(Expr condition, bool kills) { 402void ASTManager::InsertReturn(Expr condition, bool kills) {
431 const ASTNode node = ASTBase::Make<ASTReturn>(main_node, condition, kills); 403 ASTNode node = ASTBase::Make<ASTReturn>(main_node, std::move(condition), kills);
432 program->nodes.PushBack(node); 404 program->nodes.PushBack(std::move(node));
433} 405}
434 406
435// The decompile algorithm is based on 407// The decompile algorithm is based on
@@ -496,10 +468,10 @@ void ASTManager::Decompile() {
496 } 468 }
497 labels.clear(); 469 labels.clear();
498 } else { 470 } else {
499 auto it = labels.begin(); 471 auto label_it = labels.begin();
500 while (it != labels.end()) { 472 while (label_it != labels.end()) {
501 bool can_remove = true; 473 bool can_remove = true;
502 ASTNode label = *it; 474 ASTNode label = *label_it;
503 for (const ASTNode& goto_node : gotos) { 475 for (const ASTNode& goto_node : gotos) {
504 const auto label_index = goto_node->GetGotoLabel(); 476 const auto label_index = goto_node->GetGotoLabel();
505 if (!label_index) { 477 if (!label_index) {
@@ -543,11 +515,11 @@ bool ASTManager::IsBackwardsJump(ASTNode goto_node, ASTNode label_node) const {
543 return false; 515 return false;
544} 516}
545 517
546bool ASTManager::IndirectlyRelated(ASTNode first, ASTNode second) { 518bool ASTManager::IndirectlyRelated(const ASTNode& first, const ASTNode& second) const {
547 return !(first->GetParent() == second->GetParent() || DirectlyRelated(first, second)); 519 return !(first->GetParent() == second->GetParent() || DirectlyRelated(first, second));
548} 520}
549 521
550bool ASTManager::DirectlyRelated(ASTNode first, ASTNode second) { 522bool ASTManager::DirectlyRelated(const ASTNode& first, const ASTNode& second) const {
551 if (first->GetParent() == second->GetParent()) { 523 if (first->GetParent() == second->GetParent()) {
552 return false; 524 return false;
553 } 525 }
@@ -577,7 +549,7 @@ bool ASTManager::DirectlyRelated(ASTNode first, ASTNode second) {
577 return min->GetParent() == max->GetParent(); 549 return min->GetParent() == max->GetParent();
578} 550}
579 551
580void ASTManager::ShowCurrentState(std::string state) { 552void ASTManager::ShowCurrentState(std::string_view state) {
581 LOG_CRITICAL(HW_GPU, "\nState {}:\n\n{}\n", state, Print()); 553 LOG_CRITICAL(HW_GPU, "\nState {}:\n\n{}\n", state, Print());
582 SanityCheck(); 554 SanityCheck();
583} 555}
@@ -696,7 +668,7 @@ class ASTClearer {
696public: 668public:
697 ASTClearer() = default; 669 ASTClearer() = default;
698 670
699 void operator()(ASTProgram& ast) { 671 void operator()(const ASTProgram& ast) {
700 ASTNode current = ast.nodes.GetFirst(); 672 ASTNode current = ast.nodes.GetFirst();
701 while (current) { 673 while (current) {
702 Visit(current); 674 Visit(current);
@@ -704,7 +676,7 @@ public:
704 } 676 }
705 } 677 }
706 678
707 void operator()(ASTIfThen& ast) { 679 void operator()(const ASTIfThen& ast) {
708 ASTNode current = ast.nodes.GetFirst(); 680 ASTNode current = ast.nodes.GetFirst();
709 while (current) { 681 while (current) {
710 Visit(current); 682 Visit(current);
@@ -712,7 +684,7 @@ public:
712 } 684 }
713 } 685 }
714 686
715 void operator()(ASTIfElse& ast) { 687 void operator()(const ASTIfElse& ast) {
716 ASTNode current = ast.nodes.GetFirst(); 688 ASTNode current = ast.nodes.GetFirst();
717 while (current) { 689 while (current) {
718 Visit(current); 690 Visit(current);
@@ -720,19 +692,19 @@ public:
720 } 692 }
721 } 693 }
722 694
723 void operator()(ASTBlockEncoded& ast) {} 695 void operator()([[maybe_unused]] const ASTBlockEncoded& ast) {}
724 696
725 void operator()(ASTBlockDecoded& ast) { 697 void operator()(ASTBlockDecoded& ast) {
726 ast.nodes.clear(); 698 ast.nodes.clear();
727 } 699 }
728 700
729 void operator()(ASTVarSet& ast) {} 701 void operator()([[maybe_unused]] const ASTVarSet& ast) {}
730 702
731 void operator()(ASTLabel& ast) {} 703 void operator()([[maybe_unused]] const ASTLabel& ast) {}
732 704
733 void operator()(ASTGoto& ast) {} 705 void operator()([[maybe_unused]] const ASTGoto& ast) {}
734 706
735 void operator()(ASTDoWhile& ast) { 707 void operator()(const ASTDoWhile& ast) {
736 ASTNode current = ast.nodes.GetFirst(); 708 ASTNode current = ast.nodes.GetFirst();
737 while (current) { 709 while (current) {
738 Visit(current); 710 Visit(current);
@@ -740,11 +712,11 @@ public:
740 } 712 }
741 } 713 }
742 714
743 void operator()(ASTReturn& ast) {} 715 void operator()([[maybe_unused]] const ASTReturn& ast) {}
744 716
745 void operator()(ASTBreak& ast) {} 717 void operator()([[maybe_unused]] const ASTBreak& ast) {}
746 718
747 void Visit(ASTNode& node) { 719 void Visit(const ASTNode& node) {
748 std::visit(*this, *node->GetInnerData()); 720 std::visit(*this, *node->GetInnerData());
749 node->Clear(); 721 node->Clear();
750 } 722 }