// blob: 789643945ae9f16d3153e1c743b34489610d6846 [file] [log] [blame]
//
// Copyright 2002 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//
// The ArrayReturnValueToOutParameter function changes return values of an array type to out
// parameters in function definitions, prototypes, and call sites.
#include "compiler/translator/tree_ops/ArrayReturnValueToOutParameter.h"
#include <map>
#include "compiler/translator/StaticType.h"
#include "compiler/translator/SymbolTable.h"
#include "compiler/translator/tree_util/IntermNode_util.h"
#include "compiler/translator/tree_util/IntermTraverse.h"
namespace sh
{
namespace
{
// Name given to the out parameter variable that carries the former array return value.
constexpr ImmutableString kReturnValueVariableName("angle_return");
// Traverser that rewrites functions returning arrays so that the array is handed back through
// an additional out parameter instead: prototypes get the extra parameter and a void return
// type, return statements become an assignment followed by a plain return, and call sites pass
// the destination as the extra argument.
class ArrayReturnValueToOutParameterTraverser : private TIntermTraverser
{
  public:
    // Traverses |root| with a freshly constructed traverser and then applies the queued
    // replacements through updateTree(). Returns the result of updateTree().
    ANGLE_NO_DISCARD static bool apply(TCompiler *compiler,
                                       TIntermNode *root,
                                       TSymbolTable *symbolTable);

  private:
    // explicit: a TSymbolTable pointer must never convert implicitly into a traverser.
    explicit ArrayReturnValueToOutParameterTraverser(TSymbolTable *symbolTable);

    void visitFunctionPrototype(TIntermFunctionPrototype *node) override;
    bool visitFunctionDefinition(Visit visit, TIntermFunctionDefinition *node) override;
    bool visitAggregate(Visit visit, TIntermAggregate *node) override;
    bool visitBranch(Visit visit, TIntermBranch *node) override;
    bool visitBinary(Visit visit, TIntermBinary *node) override;

    // Builds a call to the changed version of the function called by |originalCall|, passing the
    // original arguments followed by |returnValueTarget|.
    TIntermAggregate *createReplacementCall(TIntermAggregate *originalCall,
                                            TIntermTyped *returnValueTarget);

    // Set when traversal is inside a function with array return value.
    TIntermFunctionDefinition *mFunctionWithArrayReturnValue = nullptr;

    // Bookkeeping for one rewritten function. Both pointers default to nullptr so that a
    // default-constructed entry never holds indeterminate values.
    struct ChangedFunction
    {
        // The out parameter variable that replaces the array return value.
        const TVariable *returnValueVariable = nullptr;
        // The replacement function: same parameters plus the out parameter, returning void.
        const TFunction *func = nullptr;
    };

    // Map from function symbol ids to the changed function.
    std::map<int, ChangedFunction> mChangedFunctions;
};
// Builds a call to the changed (void-returning) version of the function called by
// |originalCall|. The new call receives every original argument followed by
// |returnValueTarget|, which is passed for the added out parameter. The line information of
// the original call is copied to the replacement.
//
// The called function must already have an entry in mChangedFunctions, i.e. its prototype must
// have been visited before this call site.
TIntermAggregate *ArrayReturnValueToOutParameterTraverser::createReplacementCall(
    TIntermAggregate *originalCall,
    TIntermTyped *returnValueTarget)
{
    ASSERT(originalCall->getFunction());
    const TSymbolUniqueId &originalId = originalCall->getFunction()->uniqueId();

    // Use find() rather than operator[]: operator[] would silently insert an empty entry for an
    // unknown id, and its null func pointer would then be dereferenced below.
    const auto changedFunction = mChangedFunctions.find(originalId.get());
    ASSERT(changedFunction != mChangedFunctions.end());
    ASSERT(changedFunction->second.func != nullptr);

    // Original arguments first, the return value target last.
    TIntermSequence *replacementArguments = new TIntermSequence();
    for (auto &arg : *originalCall->getSequence())
    {
        replacementArguments->push_back(arg);
    }
    replacementArguments->push_back(returnValueTarget);

    TIntermAggregate *replacementCall = TIntermAggregate::CreateFunctionCall(
        *changedFunction->second.func, replacementArguments);
    replacementCall->setLine(originalCall->getLine());
    return replacementCall;
}
// Runs the transformation on the tree rooted at |root|: first traverse the tree to collect the
// replacements, then apply them with updateTree(). Returns the result of updateTree().
bool ArrayReturnValueToOutParameterTraverser::apply(TCompiler *compiler,
                                                    TIntermNode *root,
                                                    TSymbolTable *symbolTable)
{
    ArrayReturnValueToOutParameterTraverser traverser(symbolTable);
    root->traverse(&traverser);

    const bool treeUpdated = traverser.updateTree(compiler, root);
    return treeUpdated;
}
// Sets up the base traverser with |symbolTable| and starts outside of any function that returns
// an array.
// NOTE(review): the three boolean flags presumably enable pre-visit, disable in-visit and enable
// post-visit callbacks - confirm against the TIntermTraverser constructor.
ArrayReturnValueToOutParameterTraverser::ArrayReturnValueToOutParameterTraverser(
    TSymbolTable *symbolTable)
    : TIntermTraverser(true, false, true, symbolTable),
      mFunctionWithArrayReturnValue(nullptr)
{
}
// Tracks whether the traversal is currently inside the definition of a function that returns
// an array, so that visitBranch() knows which return statements to rewrite. Always returns
// true.
bool ArrayReturnValueToOutParameterTraverser::visitFunctionDefinition(
    Visit visit,
    TIntermFunctionDefinition *node)
{
    switch (visit)
    {
        case PreVisit:
            if (node->getFunctionPrototype()->isArray())
            {
                // Only remember the definition here; the function header itself is replaced in
                // visitFunctionPrototype().
                mFunctionWithArrayReturnValue = node;
            }
            break;
        case PostVisit:
            // Leaving the definition: no function is being processed any more.
            mFunctionWithArrayReturnValue = nullptr;
            break;
        default:
            break;
    }
    return true;
}
// Replaces the prototype of a function returning an array with a prototype of a void function
// that has one more parameter: an out parameter of the former return type. The replacement
// function is created the first time a given function id is seen and reused afterwards.
void ArrayReturnValueToOutParameterTraverser::visitFunctionPrototype(TIntermFunctionPrototype *node)
{
    // Prototypes that do not return an array are left untouched.
    if (!node->isArray())
    {
        return;
    }

    const TFunction *originalFunction = node->getFunction();
    const int functionKey             = originalFunction->uniqueId().get();

    auto changedEntry = mChangedFunctions.find(functionKey);
    if (changedEntry == mChangedFunctions.end())
    {
        // The out parameter has the type of the old return value, qualified as "out".
        TType *outParameterType = new TType(node->getType());
        outParameterType->setQualifier(EvqOut);

        ChangedFunction changed;
        changed.returnValueVariable =
            new TVariable(mSymbolTable, kReturnValueVariableName, outParameterType,
                          SymbolType::AngleInternal);

        // Same name and symbol type as the original function, but returning void.
        TFunction *voidFunction =
            new TFunction(mSymbolTable, originalFunction->name(), originalFunction->symbolType(),
                          StaticType::GetBasic<EbtVoid>(), false);

        // Keep the original parameters in order and append the out parameter last.
        const size_t originalParamCount = originalFunction->getParamCount();
        for (size_t paramIndex = 0; paramIndex < originalParamCount; ++paramIndex)
        {
            voidFunction->addParameter(originalFunction->getParam(paramIndex));
        }
        voidFunction->addParameter(changed.returnValueVariable);
        changed.func = voidFunction;

        changedEntry = mChangedFunctions.emplace(functionKey, changed).first;
    }

    // Swap the whole prototype node for one that refers to the changed function.
    TIntermFunctionPrototype *newPrototype =
        new TIntermFunctionPrototype(changedEntry->second.func);
    newPrototype->setLine(node->getLine());
    queueReplacement(newPrototype, OriginalNode::IS_DROPPED);
}
// Rewrites calls to array-returning functions whose result is not assigned anywhere.
// Returns false for such calls (visited in PreVisit) and true for every other aggregate.
bool ArrayReturnValueToOutParameterTraverser::visitAggregate(Visit visit, TIntermAggregate *node)
{
    ASSERT(!node->isArray() || node->getOp() != EOpCallInternalRawFunction);

    const bool isArrayReturningCall =
        visit == PreVisit && node->isArray() && node->getOp() == EOpCallFunctionInAST;
    if (!isArrayReturningCall)
    {
        return true;
    }

    // Handle call sites where the returned array is not assigned.
    // Examples where f() is a function returning an array:
    // 1. f();
    // 2. another_array == f();
    // 3. another_function(f());
    // 4. return f();
    // Cases 2 to 4 are already converted to simpler cases by
    // SeparateExpressionsReturningArrays, so we only need to worry about the case where a
    // function call returning an array forms an expression by itself.
    TIntermBlock *enclosingBlock = getParentNode()->getAsBlock();
    if (enclosingBlock == nullptr)
    {
        return false;
    }

    // replace
    //   f();
    // with
    //   type s0[size]; f(s0);

    // type s0[size];
    TIntermDeclaration *tempDeclaration = nullptr;
    TVariable *tempVariable = DeclareTempVariable(mSymbolTable, new TType(node->getType()),
                                                  EvqTemporary, &tempDeclaration);

    // f(s0);
    TIntermSymbol *tempSymbol = CreateTempSymbolNode(tempVariable);

    TIntermSequence replacements;
    replacements.push_back(tempDeclaration);
    replacements.push_back(createReplacementCall(node, tempSymbol));
    mMultiReplacements.push_back(NodeReplaceWithMultipleEntry(enclosingBlock, node, replacements));
    return false;
}
// Inside a function that returns an array, turns
//   return expr;
// into
//   angle_return = expr; return;
// where angle_return is the out parameter added to the function. Always returns false.
bool ArrayReturnValueToOutParameterTraverser::visitBranch(Visit visit, TIntermBranch *node)
{
    // Only return statements inside a function with array return value are rewritten.
    if (mFunctionWithArrayReturnValue == nullptr || node->getFlowOp() != EOpReturn)
    {
        return false;
    }

    TIntermTyped *returnedExpression = node->getExpression();
    ASSERT(returnedExpression != nullptr);

    const TSymbolUniqueId &functionId = mFunctionWithArrayReturnValue->getFunction()->uniqueId();
    ASSERT(mChangedFunctions.find(functionId.get()) != mChangedFunctions.end());

    // Instead of returning a value, assign to the out parameter...
    TIntermSymbol *outParameterSymbol =
        new TIntermSymbol(mChangedFunctions[functionId.get()].returnValueVariable);
    TIntermBinary *assignToOutParameter =
        new TIntermBinary(EOpAssign, outParameterSymbol, returnedExpression);
    assignToOutParameter->setLine(returnedExpression->getLine());

    // ...and then return without a value.
    TIntermBranch *plainReturn = new TIntermBranch(EOpReturn, nullptr);
    plainReturn->setLine(node->getLine());

    TIntermSequence replacements;
    replacements.push_back(assignToOutParameter);
    replacements.push_back(plainReturn);
    mMultiReplacements.push_back(
        NodeReplaceWithMultipleEntry(getParentNode()->getAsBlock(), node, replacements));
    return false;
}
// Rewrites an assignment of the form
//   array = f(args);
// where f is a function call in the AST, into the single call
//   f(args, array);
// The assignment node is dropped in favor of the new call. Always returns false.
bool ArrayReturnValueToOutParameterTraverser::visitBinary(Visit visit, TIntermBinary *node)
{
    // Only assignments to arrays are of interest.
    if (node->getOp() != EOpAssign || !node->getLeft()->isArray())
    {
        return false;
    }

    TIntermAggregate *assignedCall = node->getRight()->getAsAggregate();
    ASSERT(assignedCall == nullptr || assignedCall->getOp() != EOpCallInternalRawFunction);

    // The right-hand side must be a call to a function defined in the AST.
    if (assignedCall == nullptr || assignedCall->getOp() != EOpCallFunctionInAST)
    {
        return false;
    }

    // The assignment target becomes the argument for the out parameter.
    queueReplacement(createReplacementCall(assignedCall, node->getLeft()),
                     OriginalNode::IS_DROPPED);
    return false;
}
} // namespace
// Public entry point: changes array return values into out parameters in the tree rooted at
// |root| by delegating to ArrayReturnValueToOutParameterTraverser::apply() and returning its
// result.
bool ArrayReturnValueToOutParameter(TCompiler *compiler,
                                    TIntermNode *root,
                                    TSymbolTable *symbolTable)
{
    const bool succeeded =
        ArrayReturnValueToOutParameterTraverser::apply(compiler, root, symbolTable);
    return succeeded;
}
} // namespace sh