blob: 3751128aa7d2edb0c7e13f3900a09f934d482f3f [file] [log] [blame]
//
// Copyright (c) 2002-2013 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// UnfoldShortCircuitToIf is an AST traverser to convert short-circuiting operators to if-else
// statements.
// The results are assigned to s# temporaries, which are used by the main translator instead of
// the original expression.
//
#include "compiler/translator/UnfoldShortCircuitToIf.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/IntermNodePatternMatcher.h"
namespace sh
{
namespace
{
// Traverser that unfolds one short-circuiting operation at a time.
class UnfoldShortCircuitTraverser : public TIntermTraverser
{
public:
UnfoldShortCircuitTraverser();
bool visitBinary(Visit visit, TIntermBinary *node) override;
bool visitTernary(Visit visit, TIntermTernary *node) override;
void nextIteration();
bool foundShortCircuit() const { return mFoundShortCircuit; }
protected:
// Marked to true once an operation that needs to be unfolded has been found.
// After that, no more unfolding is performed on that traversal.
bool mFoundShortCircuit;
IntermNodePatternMatcher mPatternToUnfoldMatcher;
};
UnfoldShortCircuitTraverser::UnfoldShortCircuitTraverser()
: TIntermTraverser(true, false, true),
mFoundShortCircuit(false),
mPatternToUnfoldMatcher(IntermNodePatternMatcher::kUnfoldedShortCircuitExpression)
{
}
bool UnfoldShortCircuitTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
if (mFoundShortCircuit)
return false;
if (visit != PreVisit)
return true;
if (!mPatternToUnfoldMatcher.match(node, getParentNode()))
return true;
// If our right node doesn't have side effects, we know we don't need to unfold this
// expression: there will be no short-circuiting side effects to avoid
// (note: unfolding doesn't depend on the left node -- it will always be evaluated)
ASSERT(node->getRight()->hasSideEffects());
mFoundShortCircuit = true;
switch (node->getOp())
{
case EOpLogicalOr:
{
// "x || y" is equivalent to "x ? true : y", which unfolds to "bool s; if(x) s = true;
// else s = y;",
// and then further simplifies down to "bool s = x; if(!s) s = y;".
TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary);
ASSERT(node->getLeft()->getType() == boolType);
insertions.push_back(createTempInitDeclaration(node->getLeft()));
TIntermBlock *assignRightBlock = new TIntermBlock();
ASSERT(node->getRight()->getType() == boolType);
assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight()));
TIntermUnary *notTempSymbol =
new TIntermUnary(EOpLogicalNot, createTempSymbol(boolType));
TIntermIfElse *ifNode = new TIntermIfElse(notTempSymbol, assignRightBlock, nullptr);
insertions.push_back(ifNode);
insertStatementsInParentBlock(insertions);
queueReplacement(node, createTempSymbol(boolType), OriginalNode::IS_DROPPED);
return false;
}
case EOpLogicalAnd:
{
// "x && y" is equivalent to "x ? y : false", which unfolds to "bool s; if(x) s = y;
// else s = false;",
// and then further simplifies down to "bool s = x; if(s) s = y;".
TIntermSequence insertions;
TType boolType(EbtBool, EbpUndefined, EvqTemporary);
ASSERT(node->getLeft()->getType() == boolType);
insertions.push_back(createTempInitDeclaration(node->getLeft()));
TIntermBlock *assignRightBlock = new TIntermBlock();
ASSERT(node->getRight()->getType() == boolType);
assignRightBlock->getSequence()->push_back(createTempAssignment(node->getRight()));
TIntermIfElse *ifNode =
new TIntermIfElse(createTempSymbol(boolType), assignRightBlock, nullptr);
insertions.push_back(ifNode);
insertStatementsInParentBlock(insertions);
queueReplacement(node, createTempSymbol(boolType), OriginalNode::IS_DROPPED);
return false;
}
default:
UNREACHABLE();
return true;
}
}
bool UnfoldShortCircuitTraverser::visitTernary(Visit visit, TIntermTernary *node)
{
if (mFoundShortCircuit)
return false;
if (visit != PreVisit)
return true;
if (!mPatternToUnfoldMatcher.match(node))
return true;
mFoundShortCircuit = true;
// Unfold "b ? x : y" into "type s; if(b) s = x; else s = y;"
TIntermSequence insertions;
TIntermDeclaration *tempDeclaration = createTempDeclaration(node->getType());
insertions.push_back(tempDeclaration);
TIntermBlock *trueBlock = new TIntermBlock();
TIntermBinary *trueAssignment = createTempAssignment(node->getTrueExpression());
trueBlock->getSequence()->push_back(trueAssignment);
TIntermBlock *falseBlock = new TIntermBlock();
TIntermBinary *falseAssignment = createTempAssignment(node->getFalseExpression());
falseBlock->getSequence()->push_back(falseAssignment);
TIntermIfElse *ifNode =
new TIntermIfElse(node->getCondition()->getAsTyped(), trueBlock, falseBlock);
insertions.push_back(ifNode);
insertStatementsInParentBlock(insertions);
TIntermSymbol *ternaryResult = createTempSymbol(node->getType());
queueReplacement(node, ternaryResult, OriginalNode::IS_DROPPED);
return false;
}
void UnfoldShortCircuitTraverser::nextIteration()
{
mFoundShortCircuit = false;
nextTemporaryIndex();
}
} // namespace
void UnfoldShortCircuitToIf(TIntermNode *root, unsigned int *temporaryIndex)
{
UnfoldShortCircuitTraverser traverser;
ASSERT(temporaryIndex != nullptr);
traverser.useTemporaryIndex(temporaryIndex);
// Unfold one operator at a time, and reset the traverser between iterations.
do
{
traverser.nextIteration();
root->traverse(&traverser);
if (traverser.foundShortCircuit())
traverser.updateTree();
} while (traverser.foundShortCircuit());
}
} // namespace sh