blob: 5d136a327b8a0d8f28ca20224524f8c14f6e8580 [file] [log] [blame]
//
// Copyright (c) 2002-2015 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// RemoveDynamicIndexing is an AST traverser to remove dynamic indexing of vectors and matrices,
// replacing them with calls to functions that choose which component to return or write.
//
#include "compiler/translator/RemoveDynamicIndexing.h"
#include "compiler/translator/InfoSink.h"
#include "compiler/translator/IntermNode.h"
#include "compiler/translator/IntermNodePatternMatcher.h"
#include "compiler/translator/SymbolTable.h"
namespace sh
{
namespace
{
std::string GetIndexFunctionName(const TType &type, bool write)
{
TInfoSinkBase nameSink;
nameSink << "dyn_index_";
if (write)
{
nameSink << "write_";
}
if (type.isMatrix())
{
nameSink << "mat" << type.getCols() << "x" << type.getRows();
}
else
{
switch (type.getBasicType())
{
case EbtInt:
nameSink << "ivec";
break;
case EbtBool:
nameSink << "bvec";
break;
case EbtUInt:
nameSink << "uvec";
break;
case EbtFloat:
nameSink << "vec";
break;
default:
UNREACHABLE();
}
nameSink << type.getNominalSize();
}
return nameSink.str();
}
// Creates an internal (compiler-generated) parameter symbol named |name| with a copy of |type|
// whose qualifier is overridden with |qualifier|.
TIntermSymbol *MakeInternalParamSymbol(const char *name, const TType &type, TQualifier qualifier)
{
    TIntermSymbol *param = new TIntermSymbol(0, name, type);
    param->setInternal(true);
    param->getTypePointer()->setQualifier(qualifier);
    return param;
}

// The vector/matrix parameter "base" of a dyn_index_* helper.
TIntermSymbol *CreateBaseSymbol(const TType &type, TQualifier qualifier)
{
    return MakeInternalParamSymbol("base", type, qualifier);
}

// The "in highp int index" parameter of a dyn_index_* helper.
TIntermSymbol *CreateIndexSymbol()
{
    return MakeInternalParamSymbol("index", TType(EbtInt, EbpHigh), EvqIn);
}

// The "in value" parameter of a dyn_index_write_* helper.
TIntermSymbol *CreateValueSymbol(const TType &type)
{
    return MakeInternalParamSymbol("value", type, EvqIn);
}
// Returns a highp int constant node holding |i| (used for the switch case labels).
TIntermConstantUnion *CreateIntConstantNode(int i)
{
    const TType highpIntType(EbtInt, EbpHigh);
    TConstantUnion *constantValue = new TConstantUnion();
    constantValue->setIConst(i);
    return new TIntermConstantUnion(constantValue, highpIntType);
}
// Builds the expression "base[index]" with a constant |index|, where "base" is a fresh internal
// symbol of type |indexedType| carrying |baseQualifier|.
// The second (field type) argument is accepted for call-site symmetry but is not used: this
// function never sets the node's type itself (presumably the TIntermBinary constructor derives
// it). The parameter is left unnamed so compilers do not warn about it being unused.
TIntermBinary *CreateIndexDirectBaseSymbolNode(const TType &indexedType,
                                               const TType & /* fieldType */,
                                               const int index,
                                               TQualifier baseQualifier)
{
    TIntermSymbol *baseSymbol = CreateBaseSymbol(indexedType, baseQualifier);
    return new TIntermBinary(EOpIndexDirect, baseSymbol, TIntermTyped::CreateIndexNode(index));
}
// Builds "targetNode = value", where "value" is the internal write-helper parameter of type
// |assignedValueType|.
TIntermBinary *CreateAssignValueSymbolNode(TIntermTyped *targetNode, const TType &assignedValueType)
{
    TIntermSymbol *valueSymbol = CreateValueSymbol(assignedValueType);
    TIntermBinary *assignment  = new TIntermBinary(EOpAssign, targetNode, valueSymbol);
    return assignment;
}
// Returns |node| unchanged if it is already an int; otherwise wraps it in an int(...) constructor
// so it can be passed to the int-indexed dyn_index_* helpers.
TIntermTyped *EnsureSignedInt(TIntermTyped *node)
{
    if (node->getBasicType() != EbtInt)
    {
        TIntermSequence *constructorArgs = new TIntermSequence();
        constructorArgs->push_back(node);
        return TIntermAggregate::CreateConstructor(TType(EbtInt), constructorArgs);
    }
    return node;
}
// Returns the type of one indexed component of |indexedType|. For a matrix this is a column
// vector with getRows() components; for a vector it is a scalar. Basic type and precision are
// copied from |indexedType|.
TType GetFieldType(const TType &indexedType)
{
    TType componentType(indexedType.getBasicType(), indexedType.getPrecision());
    if (indexedType.isMatrix())
    {
        componentType.setPrimarySize(static_cast<unsigned char>(indexedType.getRows()));
    }
    return componentType;
}
// Generate a read or write function for one field in a vector/matrix.
// Out-of-range indices are clamped. This is consistent with how ANGLE handles out-of-range
// indices in other places.
// Note that indices can be either int or uint. We create only int versions of the functions,
// and convert uint indices to int at the call site.
// read function example:
// float dyn_index_vec2(in vec2 base, in int index)
// {
// switch(index)
// {
// case (0):
// return base[0];
// case (1):
// return base[1];
// default:
// break;
// }
// if (index < 0)
// return base[0];
// return base[1];
// }
// write function example:
// void dyn_index_write_vec2(inout vec2 base, in int index, in float value)
// {
// switch(index)
// {
// case (0):
// base[0] = value;
// return;
// case (1):
// base[1] = value;
// return;
// default:
// break;
// }
// if (index < 0)
// {
// base[0] = value;
// return;
// }
// base[1] = value;
// }
// Note that else is not used in above functions to avoid the RewriteElseBlocks transformation.
// Builds the complete definition of one dyn_index_* helper for the vector/matrix |type|.
// |write| selects the writing variant (void return, inout "base", extra "value" parameter) over
// the reading variant (returns the component, in "base"). |functionId| is attached to the
// generated prototype so call nodes created with the same id refer to this definition.
TIntermFunctionDefinition *GetIndexFunctionDefinition(TType type,
                                                      bool write,
                                                      const TSymbolUniqueId &functionId)
{
    ASSERT(!type.isArray());
    // Conservatively use highp here, even if the indexed type is not highp. That way the code can't
    // end up using mediump version of an indexing function for a highp value, if both mediump and
    // highp values are being indexed in the shader. For HLSL precision doesn't matter, but in
    // principle this code could be used with multiple backends.
    type.setPrecision(EbpHigh);

    // Type of one component: a column vector for matrices, a scalar for vectors.
    TType fieldType = GetFieldType(type);

    // One switch case per addressable component: columns for matrices, components for vectors.
    int numCases = 0;
    if (type.isMatrix())
    {
        numCases = type.getCols();
    }
    else
    {
        numCases = type.getNominalSize();
    }

    // Readers return the selected component; writers return void.
    TType returnType(EbtVoid);
    if (!write)
    {
        returnType = fieldType;
    }
    std::string functionName = GetIndexFunctionName(type, write);
    TIntermFunctionPrototype *prototypeNode = TIntermTraverser::CreateInternalFunctionPrototypeNode(
        returnType, functionName.c_str(), functionId);

    // Parameter list: (inout|in) base, in int index[, in value].
    TQualifier baseQualifier = EvqInOut;
    if (!write)
        baseQualifier = EvqIn;
    TIntermSymbol *baseParam = CreateBaseSymbol(type, baseQualifier);
    prototypeNode->getSequence()->push_back(baseParam);
    TIntermSymbol *indexParam = CreateIndexSymbol();
    prototypeNode->getSequence()->push_back(indexParam);
    if (write)
    {
        TIntermSymbol *valueParam = CreateValueSymbol(fieldType);
        prototypeNode->getSequence()->push_back(valueParam);
    }

    // Switch body: for each component i, "case i:" followed by either
    // "return base[i];" (read) or "base[i] = value; return;" (write).
    TIntermBlock *statementList = new TIntermBlock();
    for (int i = 0; i < numCases; ++i)
    {
        TIntermCase *caseNode = new TIntermCase(CreateIntConstantNode(i));
        statementList->getSequence()->push_back(caseNode);
        TIntermBinary *indexNode =
            CreateIndexDirectBaseSymbolNode(type, fieldType, i, baseQualifier);
        if (write)
        {
            TIntermBinary *assignNode = CreateAssignValueSymbolNode(indexNode, fieldType);
            statementList->getSequence()->push_back(assignNode);
            TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
            statementList->getSequence()->push_back(returnNode);
        }
        else
        {
            TIntermBranch *returnNode = new TIntermBranch(EOpReturn, indexNode);
            statementList->getSequence()->push_back(returnNode);
        }
    }
    // Default case
    // Out-of-range indices fall out of the switch and are clamped below.
    TIntermCase *defaultNode = new TIntermCase(nullptr);
    statementList->getSequence()->push_back(defaultNode);
    TIntermBranch *breakNode = new TIntermBranch(EOpBreak, nullptr);
    statementList->getSequence()->push_back(breakNode);
    TIntermSwitch *switchNode = new TIntermSwitch(CreateIndexSymbol(), statementList);
    TIntermBlock *bodyNode = new TIntermBlock();
    bodyNode->getSequence()->push_back(switchNode);

    // Clamping: "if (index < 0)" selects the first component; anything else reaching this point
    // is >= numCases and selects the last one.
    TIntermBinary *cond =
        new TIntermBinary(EOpLessThan, CreateIndexSymbol(), CreateIntConstantNode(0));
    cond->setType(TType(EbtBool, EbpUndefined));
    // Two blocks: one accesses (either reads or writes) the first element and returns,
    // the other accesses the last element.
    TIntermBlock *useFirstBlock = new TIntermBlock();
    TIntermBlock *useLastBlock = new TIntermBlock();
    TIntermBinary *indexFirstNode =
        CreateIndexDirectBaseSymbolNode(type, fieldType, 0, baseQualifier);
    TIntermBinary *indexLastNode =
        CreateIndexDirectBaseSymbolNode(type, fieldType, numCases - 1, baseQualifier);
    if (write)
    {
        TIntermBinary *assignFirstNode = CreateAssignValueSymbolNode(indexFirstNode, fieldType);
        useFirstBlock->getSequence()->push_back(assignFirstNode);
        TIntermBranch *returnNode = new TIntermBranch(EOpReturn, nullptr);
        useFirstBlock->getSequence()->push_back(returnNode);
        // The last-element write needs no explicit return: it is the final statement of a
        // void function.
        TIntermBinary *assignLastNode = CreateAssignValueSymbolNode(indexLastNode, fieldType);
        useLastBlock->getSequence()->push_back(assignLastNode);
    }
    else
    {
        TIntermBranch *returnFirstNode = new TIntermBranch(EOpReturn, indexFirstNode);
        useFirstBlock->getSequence()->push_back(returnFirstNode);
        TIntermBranch *returnLastNode = new TIntermBranch(EOpReturn, indexLastNode);
        useLastBlock->getSequence()->push_back(returnLastNode);
    }
    // No else branch: the last-element block is appended after the if statement instead, so the
    // generated code contains no else (avoids the RewriteElseBlocks transformation).
    TIntermIfElse *ifNode = new TIntermIfElse(cond, useFirstBlock, nullptr);
    bodyNode->getSequence()->push_back(ifNode);
    bodyNode->getSequence()->push_back(useLastBlock);
    TIntermFunctionDefinition *indexingFunction =
        new TIntermFunctionDefinition(prototypeNode, bodyNode);
    return indexingFunction;
}
// Traverser that replaces dynamic (non-constant) indexing of vectors and matrices with calls to
// dyn_index_* helper functions. Once a traversal has queued statement insertions into a parent
// block (mUsedTreeInsertion), it skips the rest of the tree, so the caller re-runs it via
// nextIteration() until a pass completes without insertions.
class RemoveDynamicIndexingTraverser : public TLValueTrackingTraverser
{
  public:
    RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable, int shaderVersion);

    // Rewrites EOpIndexIndirect nodes that index a vector or matrix; returns whether to descend.
    bool visitBinary(Visit visit, TIntermBinary *node) override;

    // Queues the definitions of all helpers recorded during traversal for insertion at position 0
    // of |root|, which must be a TIntermBlock.
    void insertHelperDefinitions(TIntermNode *root);

    // Resets the per-pass flags and advances the temporary index before another traversal.
    void nextIteration();

    // True if the last traversal queued statement insertions, i.e. another pass is needed.
    bool usedTreeInsertion() const { return mUsedTreeInsertion; }

  protected:
    // Maps of types that are indexed to the indexing function ids used for them. Note that these
    // can not store multiple variants of the same type with different precisions - only one
    // precision gets stored.
    // mIndexedVecAndMatrixTypes: types needing a read helper.
    std::map<TType, TSymbolUniqueId> mIndexedVecAndMatrixTypes;
    // mWrittenVecAndMatrixTypes: types needing a write helper.
    std::map<TType, TSymbolUniqueId> mWrittenVecAndMatrixTypes;

    // Set once the current traversal has queued statement insertions; reset by nextIteration().
    bool mUsedTreeInsertion;

    // When true, the traverser will remove side effects from any indexing expression.
    // This is done so that in code like
    //   V[j++][i]++.
    // where V is an array of vectors, j++ will only be evaluated once.
    bool mRemoveIndexSideEffectsInSubtree;
};
RemoveDynamicIndexingTraverser::RemoveDynamicIndexingTraverser(const TSymbolTable &symbolTable,
int shaderVersion)
: TLValueTrackingTraverser(true, false, false, symbolTable, shaderVersion),
mUsedTreeInsertion(false),
mRemoveIndexSideEffectsInSubtree(false)
{
}
// Queues one helper definition per recorded type: all read helpers first, then all write helpers,
// each group in the map's key order. They are inserted at position 0 of the root block.
void RemoveDynamicIndexingTraverser::insertHelperDefinitions(TIntermNode *root)
{
    TIntermBlock *rootBlock = root->getAsBlock();
    ASSERT(rootBlock != nullptr);

    TIntermSequence helperDefinitions;
    auto appendDefinitions = [&helperDefinitions](const std::map<TType, TSymbolUniqueId> &types,
                                                  bool write) {
        for (const auto &entry : types)
        {
            helperDefinitions.push_back(
                GetIndexFunctionDefinition(entry.first, write, entry.second));
        }
    };
    appendDefinitions(mIndexedVecAndMatrixTypes, false);
    appendDefinitions(mWrittenVecAndMatrixTypes, true);

    mInsertions.push_back(
        NodeInsertMultipleEntry(rootBlock, 0, helperDefinitions, TIntermSequence()));
}
// Creates a call to the dyn_index_*() read helper for the EOpIndexIndirect node |node|, passing
// |node|'s left operand (reused, not copied) and |index|. The call gets |node|'s source line.
TIntermAggregate *CreateIndexFunctionCall(TIntermBinary *node,
                                          TIntermTyped *index,
                                          const TSymbolUniqueId &functionId)
{
    ASSERT(node->getOp() == EOpIndexIndirect);
    TIntermTyped *indexedOperand = node->getLeft();
    const TType &indexedType     = indexedOperand->getType();

    TIntermSequence *callArguments = new TIntermSequence();
    callArguments->push_back(indexedOperand);
    callArguments->push_back(index);

    TType componentType          = GetFieldType(indexedType);
    const std::string readerName = GetIndexFunctionName(indexedType, false);
    TIntermAggregate *readCall   = TIntermTraverser::CreateInternalFunctionCallNode(
        componentType, readerName.c_str(), functionId, callArguments);
    readCall->setLine(node->getLine());
    return readCall;
}
// Creates a call to the dyn_index_write_*() helper that stores |writtenValue| into the component
// of |node|'s left operand selected by |index|. |node| must be an EOpIndexIndirect node.
TIntermAggregate *CreateIndexedWriteFunctionCall(TIntermBinary *node,
                                                 TIntermTyped *index,
                                                 TIntermTyped *writtenValue,
                                                 const TSymbolUniqueId &functionId)
{
    ASSERT(node->getOp() == EOpIndexIndirect);
    const TType &indexedType = node->getLeft()->getType();

    // The operand and index are also used by the read call, so they are deep-copied here to
    // keep the same node from appearing twice in the tree.
    TIntermSequence *callArguments = new TIntermSequence();
    callArguments->push_back(node->getLeft()->deepCopy());
    callArguments->push_back(index->deepCopy());
    callArguments->push_back(writtenValue);

    const std::string writerName = GetIndexFunctionName(indexedType, true);
    TIntermAggregate *writeCall  = TIntermTraverser::CreateInternalFunctionCallNode(
        TType(EbtVoid), writerName.c_str(), functionId, callArguments);
    writeCall->setLine(node->getLine());
    return writeCall;
}
// Rewrites dynamic indexing of non-array vectors and matrices at |node|:
//  - reads  v_expr[index_expr]  become  dyn_index_*(v_expr, index_expr), with a uint index
//    converted to int;
//  - writes are split into a temporary index, a temporary holding the component read via the
//    read helper, the original expression operating on that temporary, and a trailing
//    dyn_index_write_*() call storing it back.
// If the indexed operand of a write has side effects, a clean-up mode
// (mRemoveIndexSideEffectsInSubtree) first hoists side-effecting index expressions out of it.
// Returns whether the traversal should descend into |node|'s children.
bool RemoveDynamicIndexingTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
    // Statement insertions were already queued in this pass; leave the rest for the next one.
    if (mUsedTreeInsertion)
        return false;
    if (node->getOp() == EOpIndexIndirect)
    {
        if (mRemoveIndexSideEffectsInSubtree)
        {
            // Only hoist indices that actually have side effects. An l-value like
            //   M[k++][a][b]++   (M an array of matrices)
            // also contains the indirect node M[k++][a], whose own index "a" is side-effect
            // free. Hoisting "a" would not remove the side effect of k++, so every pass would
            // insert another temporary and RemoveDynamicIndexing() would never terminate.
            // Instead, keep descending until the side-effecting index is found.
            if (node->getRight()->hasSideEffects())
            {
                // In case we're just removing index side effects, convert
                //   v_expr[index_expr]
                // to this:
                //   int s0 = index_expr; v_expr[s0];
                // Now v_expr[s0] can be safely executed several times without unintended side
                // effects.
                // Init the temp variable holding the index
                TIntermDeclaration *initIndex = createTempInitDeclaration(node->getRight());
                insertStatementInParentBlock(initIndex);
                mUsedTreeInsertion = true;
                // Replace the index with the temp variable
                TIntermSymbol *tempIndex = createTempSymbol(node->getRight()->getType());
                queueReplacementWithParent(node, node->getRight(), tempIndex,
                                           OriginalNode::IS_DROPPED);
            }
        }
        else if (IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(node))
        {
            bool write = isLValueRequiredHere();
#if defined(ANGLE_ENABLE_ASSERTS)
            // Make sure that IntermNodePatternMatcher is consistent with the slightly differently
            // implemented checks in this traverser.
            IntermNodePatternMatcher matcher(
                IntermNodePatternMatcher::kDynamicIndexingOfVectorOrMatrixInLValue);
            ASSERT(matcher.match(node, getParentNode(), isLValueRequiredHere()) == write);
#endif
            const TType &type = node->getLeft()->getType();
            TSymbolUniqueId indexingFunctionId;
            if (mIndexedVecAndMatrixTypes.find(type) == mIndexedVecAndMatrixTypes.end())
            {
                mIndexedVecAndMatrixTypes[type] = indexingFunctionId;
            }
            else
            {
                indexingFunctionId = mIndexedVecAndMatrixTypes[type];
            }
            if (write)
            {
                // Convert:
                //   v_expr[index_expr]++;
                // to this:
                //   int s0 = index_expr; float s1 = dyn_index(v_expr, s0); s1++;
                //   dyn_index_write(v_expr, s0, s1);
                // This works even if index_expr has some side effects.
                if (node->getLeft()->hasSideEffects())
                {
                    // If v_expr has side effects, those need to be removed before proceeding.
                    // Otherwise the side effects of v_expr would be evaluated twice.
                    // The only case where an l-value can have side effects is when it is
                    // indexing. For example, it can be V[j++] where V is an array of vectors.
                    mRemoveIndexSideEffectsInSubtree = true;
                    return true;
                }
                TIntermBinary *leftBinary = node->getLeft()->getAsBinaryNode();
                if (leftBinary != nullptr &&
                    IntermNodePatternMatcher::IsDynamicIndexingOfVectorOrMatrix(leftBinary))
                {
                    // This is a case like:
                    // mat2 m;
                    // m[a][b]++;
                    // Process the child node m[a] first.
                    return true;
                }
                // TODO(oetuaho@nvidia.com): This is not optimal if the expression using the value
                // only writes it and doesn't need the previous value. http://anglebug.com/1116
                TSymbolUniqueId indexedWriteFunctionId;
                if (mWrittenVecAndMatrixTypes.find(type) == mWrittenVecAndMatrixTypes.end())
                {
                    mWrittenVecAndMatrixTypes[type] = indexedWriteFunctionId;
                }
                else
                {
                    indexedWriteFunctionId = mWrittenVecAndMatrixTypes[type];
                }
                TType fieldType = GetFieldType(type);
                TIntermSequence insertionsBefore;
                TIntermSequence insertionsAfter;
                // Store the index in a temporary signed int variable.
                TIntermTyped *indexInitializer = EnsureSignedInt(node->getRight());
                TIntermDeclaration *initIndex  = createTempInitDeclaration(indexInitializer);
                initIndex->setLine(node->getLine());
                insertionsBefore.push_back(initIndex);
                // Create a node for referring to the index after the nextTemporaryIndex() call
                // below.
                TIntermSymbol *tempIndex = createTempSymbol(indexInitializer->getType());
                TIntermAggregate *indexingCall =
                    CreateIndexFunctionCall(node, tempIndex, indexingFunctionId);
                nextTemporaryIndex();  // From now on, creating temporary symbols that refer to the
                                       // field value.
                insertionsBefore.push_back(createTempInitDeclaration(indexingCall));
                TIntermAggregate *indexedWriteCall = CreateIndexedWriteFunctionCall(
                    node, tempIndex, createTempSymbol(fieldType), indexedWriteFunctionId);
                insertionsAfter.push_back(indexedWriteCall);
                insertStatementsInParentBlock(insertionsBefore, insertionsAfter);
                queueReplacement(node, createTempSymbol(fieldType), OriginalNode::IS_DROPPED);
                mUsedTreeInsertion = true;
            }
            else
            {
                // The indexed value is not being written, so we can simply convert
                //   v_expr[index_expr]
                // into
                //   dyn_index(v_expr, index_expr)
                // If the index_expr is unsigned, we'll convert it to signed.
                ASSERT(!mRemoveIndexSideEffectsInSubtree);
                TIntermAggregate *indexingCall = CreateIndexFunctionCall(
                    node, EnsureSignedInt(node->getRight()), indexingFunctionId);
                queueReplacement(node, indexingCall, OriginalNode::IS_DROPPED);
            }
        }
    }
    return !mUsedTreeInsertion;
}
// Prepares for another pass over the tree: clears both per-pass flags and advances the temporary
// index used when creating temporary symbols and declarations.
void RemoveDynamicIndexingTraverser::nextIteration()
{
    mRemoveIndexSideEffectsInSubtree = false;
    mUsedTreeInsertion               = false;
    nextTemporaryIndex();
}
} // namespace
// Removes dynamic indexing of vectors and matrices from the tree rooted at |root|, replacing it
// with calls to generated dyn_index_* helpers. |temporaryIndex| (must be non-null) is the shared
// counter used for naming temporaries.
void RemoveDynamicIndexing(TIntermNode *root,
                           unsigned int *temporaryIndex,
                           const TSymbolTable &symbolTable,
                           int shaderVersion)
{
    ASSERT(temporaryIndex != nullptr);
    RemoveDynamicIndexingTraverser traverser(symbolTable, shaderVersion);
    traverser.useTemporaryIndex(temporaryIndex);

    // A pass stops visiting as soon as it has queued statement insertions, so keep traversing
    // until a pass completes without inserting anything.
    bool insertedStatements = true;
    while (insertedStatements)
    {
        traverser.nextIteration();
        root->traverse(&traverser);
        traverser.updateTree();
        insertedStatements = traverser.usedTreeInsertion();
    }

    // TODO(oetuaho@nvidia.com): It might be nicer to add the helper definitions also in the middle
    // of traversal. Now the tree ends up in an inconsistent state in the middle, since there are
    // function call nodes with no corresponding definition nodes. This needs special handling in
    // TLValueTrackingTraverser, and creates intricacies that are not easily apparent from a
    // superficial reading of the code.
    traverser.insertHelperDefinitions(root);
    traverser.updateTree();
}
} // namespace sh