blob: b9b33fcfb4eb0adfc7ad04ab8cf8fa37480968a3 [file] [log] [blame]
// Copyright 2015 The Cobalt Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "starboard/nplb/socket_helpers.h"
#include "starboard/common/scoped_ptr.h"
#include "starboard/common/socket.h"
#include "starboard/once.h"
#include "starboard/socket_waiter.h"
#include "starboard/thread.h"
#include "starboard/time.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace starboard {
namespace nplb {
namespace {
int port_number_for_tests = 0;
SbOnceControl valid_port_once_control = SB_ONCE_INITIALIZER;
void InitializePortNumberForTests() {
// Create a listening socket. Let the system choose a port for us.
SbSocket socket = CreateListeningTcpSocket(kSbSocketAddressTypeIpv4, 0);
SB_DCHECK(socket != kSbSocketInvalid);
// Query which port this socket was bound to and save it to valid_port_number.
SbSocketAddress socket_address = {0};
bool result = SbSocketGetLocalAddress(socket, &socket_address);
SB_DCHECK(result);
port_number_for_tests = socket_address.port;
// Clean up the socket.
result = SbSocketDestroy(socket);
SB_DCHECK(result);
}
} // namespace
// Returns the TCP port tests should use. A build-time override
// (SB_SOCKET_OVERRIDE_PORT_FOR_TESTS) wins; otherwise a system-assigned port
// is discovered once, on first call, and reused on every later call.
int GetPortNumberForTests() {
#if !defined(SB_SOCKET_OVERRIDE_PORT_FOR_TESTS)
  SbOnce(&valid_port_once_control, &InitializePortNumberForTests);
  return port_number_for_tests;
#else
  return SB_SOCKET_OVERRIDE_PORT_FOR_TESTS;
#endif
}
// Returns true if |address| is the all-zero ("unspecified") address for its
// type. For IPv4 only the first 4 address bytes are examined. For any other
// type all 16 bytes are examined, so every examined byte must be zero.
bool IsUnspecified(const SbSocketAddress* address) {
  const int length = (address->type == kSbSocketAddressTypeIpv4) ? 4 : 16;
  for (int i = 0; i < length; ++i) {
    // Any non-zero byte means the address is specified; stop early.
    if (address->address[i] != 0) {
      return false;
    }
  }
  return true;
}
// Returns true if |address| is a loopback address. For IPv4 that means any
// address whose first byte is 127. For IPv6 it means exactly ::1: fifteen
// zero bytes followed by a final 1. Any other address type is not loopback.
bool IsLocalhost(const SbSocketAddress* address) {
  switch (address->type) {
    case kSbSocketAddressTypeIpv4:
      return address->address[0] == 127;
    case kSbSocketAddressTypeIpv6: {
      for (int byte_index = 0; byte_index < 15; ++byte_index) {
        if (address->address[byte_index] != 0) {
          return false;
        }
      }
      return address->address[15] == 1;
    }
    default:
      return false;
  }
}
// Creates a TCP socket of |address_type| with address reuse enabled. On any
// failure it records a non-fatal gtest failure, releases what it created, and
// returns kSbSocketInvalid.
SbSocket CreateServerTcpSocket(SbSocketAddressType address_type) {
  SbSocket sock = SbSocketCreate(address_type, kSbSocketProtocolTcp);
  if (!SbSocketIsValid(sock)) {
    ADD_FAILURE() << "SbSocketCreate failed";
    return kSbSocketInvalid;
  }

  const bool reuse_enabled = SbSocketSetReuseAddress(sock, true);
  if (reuse_enabled) {
    return sock;
  }

  ADD_FAILURE() << "SbSocketSetReuseAddress failed";
  SbSocketDestroy(sock);
  return kSbSocketInvalid;
}
// Wrapped counterpart of CreateServerTcpSocket(): returns an owned Socket with
// address reuse enabled, or an empty scoped_ptr (after a non-fatal gtest
// failure) if creation or SetReuseAddress fails.
scoped_ptr<Socket> CreateServerTcpSocketWrapped(
    SbSocketAddressType address_type) {
  scoped_ptr<Socket> sock(new Socket(address_type, kSbSocketProtocolTcp));
  if (!sock->IsValid()) {
    ADD_FAILURE() << "SbSocketCreate failed";
    return scoped_ptr<Socket>().Pass();
  }

  if (sock->SetReuseAddress(true)) {
    return sock.Pass();
  }

  // |sock| is released when it goes out of scope.
  ADD_FAILURE() << "SbSocketSetReuseAddress failed";
  return scoped_ptr<Socket>().Pass();
}
// Creates a reusable TCP server socket and binds it to the unspecified address
// of |address_type| on |port|. Returns kSbSocketInvalid on failure, recording a
// non-fatal gtest failure if the bind step fails.
SbSocket CreateBoundTcpSocket(SbSocketAddressType address_type, int port) {
  SbSocket sock = CreateServerTcpSocket(address_type);
  if (!SbSocketIsValid(sock)) {
    // CreateServerTcpSocket() already reported the failure.
    return kSbSocketInvalid;
  }

  SbSocketAddress bind_address = GetUnspecifiedAddress(address_type, port);
  const SbSocketError bind_error = SbSocketBind(sock, &bind_address);
  if (bind_error == kSbSocketOk) {
    return sock;
  }

  ADD_FAILURE() << "SbSocketBind to " << port << " failed: " << bind_error;
  SbSocketDestroy(sock);
  return kSbSocketInvalid;
}
// Wrapped counterpart of CreateBoundTcpSocket(): returns an owned Socket bound
// to the unspecified address on |port|, or an empty scoped_ptr on failure.
scoped_ptr<Socket> CreateBoundTcpSocketWrapped(SbSocketAddressType address_type,
                                               int port) {
  scoped_ptr<Socket> sock = CreateServerTcpSocketWrapped(address_type);
  if (!sock) {
    // CreateServerTcpSocketWrapped() already reported the failure.
    return scoped_ptr<Socket>().Pass();
  }

  SbSocketAddress bind_address = GetUnspecifiedAddress(address_type, port);
  const SbSocketError bind_error = sock->Bind(&bind_address);
  if (bind_error == kSbSocketOk) {
    return sock.Pass();
  }

  ADD_FAILURE() << "SbSocketBind to " << port << " failed: " << bind_error;
  return scoped_ptr<Socket>().Pass();
}
// Creates a bound TCP socket on |port| and puts it into the listening state.
// Returns kSbSocketInvalid on failure, recording a non-fatal gtest failure if
// the listen step fails.
SbSocket CreateListeningTcpSocket(SbSocketAddressType address_type, int port) {
  SbSocket sock = CreateBoundTcpSocket(address_type, port);
  if (!SbSocketIsValid(sock)) {
    // CreateBoundTcpSocket() already reported the failure.
    return kSbSocketInvalid;
  }

  const SbSocketError listen_error = SbSocketListen(sock);
  if (listen_error == kSbSocketOk) {
    return sock;
  }

  ADD_FAILURE() << "SbSocketListen failed: " << listen_error;
  SbSocketDestroy(sock);
  return kSbSocketInvalid;
}
// Wrapped counterpart of CreateListeningTcpSocket(): returns an owned,
// listening Socket on |port|, or an empty scoped_ptr on failure.
scoped_ptr<Socket> CreateListeningTcpSocketWrapped(
    SbSocketAddressType address_type,
    int port) {
  scoped_ptr<Socket> sock = CreateBoundTcpSocketWrapped(address_type, port);
  if (!sock) {
    // CreateBoundTcpSocketWrapped() already reported the failure.
    return scoped_ptr<Socket>().Pass();
  }

  const SbSocketError listen_error = sock->Listen();
  if (listen_error == kSbSocketOk) {
    return sock.Pass();
  }

  ADD_FAILURE() << "SbSocketListen failed: " << listen_error;
  return scoped_ptr<Socket>().Pass();
}
namespace {
// Creates a TCP socket of |address_type| and starts connecting it to
// localhost:|port|. Both an immediate success and kSbSocketPending are
// accepted. Returns kSbSocketInvalid (after a non-fatal gtest failure) if the
// socket cannot be created, the localhost address cannot be resolved, or the
// connect reports any other error. The socket is destroyed on every failure
// path after creation.
SbSocket CreateConnectingTcpSocket(SbSocketAddressType address_type, int port) {
  SbSocket client_socket = SbSocketCreate(address_type, kSbSocketProtocolTcp);
  if (!SbSocketIsValid(client_socket)) {
    ADD_FAILURE() << "SbSocketCreate failed";
    return kSbSocketInvalid;
  }
  // Connect to localhost:<port>.
  SbSocketAddress address = {};
  bool success = GetLocalhostAddress(address_type, port, &address);
  if (!success) {
    ADD_FAILURE() << "GetLocalhostAddress failed";
    // Release the socket created above so this early return does not leak it.
    SbSocketDestroy(client_socket);
    return kSbSocketInvalid;
  }
  // This connect will probably return pending, but we'll assume it will connect
  // eventually.
  SbSocketError result = SbSocketConnect(client_socket, &address);
  if (result != kSbSocketOk && result != kSbSocketPending) {
    ADD_FAILURE() << "SbSocketConnect failed: " << result;
    SbSocketDestroy(client_socket);
    return kSbSocketInvalid;
  }
  return client_socket;
}
// Wrapped counterpart of CreateConnectingTcpSocket(): returns an owned Socket
// whose connect to localhost:|port| either succeeded or is pending, or an empty
// scoped_ptr (after a non-fatal gtest failure) otherwise.
scoped_ptr<Socket> CreateConnectingTcpSocketWrapped(
    SbSocketAddressType address_type,
    int port) {
  scoped_ptr<Socket> client(new Socket(address_type, kSbSocketProtocolTcp));
  if (!client->IsValid()) {
    ADD_FAILURE() << "SbSocketCreate failed";
    return scoped_ptr<Socket>().Pass();
  }

  // Target localhost:<port>.
  SbSocketAddress target = {};
  if (!GetLocalhostAddress(address_type, port, &target)) {
    ADD_FAILURE() << "GetLocalhostAddress failed";
    return scoped_ptr<Socket>().Pass();
  }

  // The connect will probably report pending. That is accepted alongside OK,
  // on the assumption that the connection completes later.
  const SbSocketError connect_result = client->Connect(&target);
  const bool connect_started =
      connect_result == kSbSocketOk || connect_result == kSbSocketPending;
  if (!connect_started) {
    ADD_FAILURE() << "SbSocketConnect failed: " << connect_result;
    return scoped_ptr<Socket>().Pass();
  }
  return client.Pass();
}
} // namespace
// Repeatedly calls SbSocketAccept() on |server_socket|, yielding between
// attempts, until a connection is accepted or |timeout| has elapsed. Each
// failed attempt is expected to report kSbSocketPending. Returns the accepted
// socket, or kSbSocketInvalid on timeout.
SbSocket AcceptBySpinning(SbSocket server_socket, SbTime timeout) {
  const SbTimeMonotonic began = SbTimeGetMonotonicNow();
  for (;;) {
    SbSocket accepted = SbSocketAccept(server_socket);
    if (SbSocketIsValid(accepted)) {
      return accepted;
    }

    // No connection yet: the only acceptable reason is "pending".
    EXPECT_EQ(kSbSocketPending, SbSocketGetLastError(server_socket));

    const bool out_of_time = SbTimeGetMonotonicNow() - began >= timeout;
    if (out_of_time) {
      return kSbSocketInvalid;
    }

    // Give other threads a chance to run between attempts.
    SbThreadYield();
  }
}
// Wrapped counterpart of AcceptBySpinning(SbSocket, SbTime): repeatedly calls
// |server_socket|->Accept(), yielding between attempts, until a valid Socket is
// accepted or |timeout| has elapsed. Each failed attempt is expected to leave
// |server_socket| pending. Returns the accepted Socket, or an empty scoped_ptr
// on timeout.
scoped_ptr<Socket> AcceptBySpinning(Socket* server_socket, SbTime timeout) {
  SbTimeMonotonic start = SbTimeGetMonotonicNow();
  while (true) {
    // Take ownership immediately, so that a non-NULL but invalid result is
    // freed at the end of this iteration rather than leaked.
    scoped_ptr<Socket> accepted_socket(server_socket->Accept());
    if (accepted_socket && accepted_socket->IsValid()) {
      return accepted_socket.Pass();
    }
    // If we didn't get a socket, it should be pending.
    EXPECT_TRUE(server_socket->IsPending());
    // Check if we have passed our timeout.
    if (SbTimeGetMonotonicNow() - start >= timeout) {
      break;
    }
    // Just being polite.
    SbThreadYield();
  }
  return scoped_ptr<Socket>().Pass();
}
// Sends all |data_size| bytes of |data| on |socket|. While the socket reports
// kSbSocketPending (or a send makes no progress), it yields and retries.
// Returns true once every byte has been sent. Returns false on any other error,
// or when |timeout| elapses first.
bool WriteBySpinning(SbSocket socket,
                     const char* data,
                     int data_size,
                     SbTime timeout) {
  SbTimeMonotonic start = SbTimeGetMonotonicNow();
  int total = 0;
  while (total < data_size) {
    int sent = SbSocketSendTo(socket, data + total, data_size - total, NULL);
    if (sent > 0) {
      total += sent;
      continue;
    }
    // A zero-byte send makes no progress. It must reach the timeout check
    // below instead of looping forever.
    if (sent < 0 && SbSocketGetLastError(socket) != kSbSocketPending) {
      return false;
    }
    if (SbTimeGetMonotonicNow() - start >= timeout) {
      return false;
    }
    SbThreadYield();
  }
  return true;
}
// Wrapped counterpart of WriteBySpinning(SbSocket, ...): sends all |data_size|
// bytes of |data| on |socket|. While the socket is pending (or a send makes no
// progress), it yields and retries. Returns true once every byte has been sent.
// Returns false on any other error, or when |timeout| elapses first.
bool WriteBySpinning(Socket* socket,
                     const char* data,
                     int data_size,
                     SbTime timeout) {
  SbTimeMonotonic start = SbTimeGetMonotonicNow();
  int total = 0;
  while (total < data_size) {
    int sent = socket->SendTo(data + total, data_size - total, NULL);
    if (sent > 0) {
      total += sent;
      continue;
    }
    // A zero-byte send makes no progress. It must reach the timeout check
    // below instead of looping forever.
    if (sent < 0 && !socket->IsPending()) {
      return false;
    }
    if (SbTimeGetMonotonicNow() - start >= timeout) {
      return false;
    }
    SbThreadYield();
  }
  return true;
}
// Reads exactly |data_size| bytes from |socket| into |out_data|. While the
// socket reports kSbSocketPending (or a read makes no progress), it yields and
// retries. Returns true once all bytes have been read. Returns false on any
// other error, or when |timeout| elapses first.
bool ReadBySpinning(SbSocket socket,
                    char* out_data,
                    int data_size,
                    SbTime timeout) {
  SbTimeMonotonic start = SbTimeGetMonotonicNow();
  int total = 0;
  while (total < data_size) {
    int received =
        SbSocketReceiveFrom(socket, out_data + total, data_size - total, NULL);
    if (received > 0) {
      total += received;
      continue;
    }
    // A zero-byte read makes no progress; for TCP it typically means the peer
    // closed the connection. It must reach the timeout check below instead of
    // looping forever.
    if (received < 0 && SbSocketGetLastError(socket) != kSbSocketPending) {
      return false;
    }
    if (SbTimeGetMonotonicNow() - start >= timeout) {
      return false;
    }
    SbThreadYield();
  }
  return true;
}
// Wrapped counterpart of ReadBySpinning(SbSocket, ...): reads exactly
// |data_size| bytes from |socket| into |out_data|. While the socket is pending
// (or a read makes no progress), it yields and retries. Returns true once all
// bytes have been read. Returns false on any other error, or when |timeout|
// elapses first.
bool ReadBySpinning(Socket* socket,
                    char* out_data,
                    int data_size,
                    SbTime timeout) {
  SbTimeMonotonic start = SbTimeGetMonotonicNow();
  int total = 0;
  while (total < data_size) {
    int received =
        socket->ReceiveFrom(out_data + total, data_size - total, NULL);
    if (received > 0) {
      total += received;
      continue;
    }
    // A zero-byte read makes no progress; for TCP it typically means the peer
    // closed the connection. It must reach the timeout check below instead of
    // looping forever.
    if (received < 0 && !socket->IsPending()) {
      return false;
    }
    if (SbTimeGetMonotonicNow() - start >= timeout) {
      return false;
    }
    SbThreadYield();
  }
  return true;
}
// Pushes |size| bytes from |send_data| through |send_socket| while draining
// |receive_socket| into |out_data|, alternating one send attempt (while data
// remains unsent) with one receive attempt per iteration. Returns |size| once
// all bytes have been received, or -1 if either socket reports an error other
// than kSbSocketPending.
// NOTE(review): there is no timeout and no yield. If the data never arrives
// (or reads keep returning 0), this loops forever.
int Transfer(SbSocket receive_socket,
             char* out_data,
             SbSocket send_socket,
             const char* send_data,
             int size) {
  int sent_so_far = 0;
  int received_so_far = 0;
  while (received_so_far < size) {
    if (sent_so_far < size) {
      const int sent = SbSocketSendTo(send_socket, send_data + sent_so_far,
                                      size - sent_so_far, NULL);
      if (sent >= 0) {
        sent_so_far += sent;
      } else if (SbSocketGetLastError(send_socket) != kSbSocketPending) {
        return -1;
      }
    }

    const int received =
        SbSocketReceiveFrom(receive_socket, out_data + received_so_far,
                            size - received_so_far, NULL);
    if (received >= 0) {
      received_so_far += received;
    } else if (SbSocketGetLastError(receive_socket) != kSbSocketPending) {
      return -1;
    }
  }
  return size;
}
// Wrapped counterpart of Transfer(SbSocket, ...): pushes |size| bytes from
// |send_data| through |send_socket| while draining |receive_socket| into
// |out_data|, alternating one send attempt (while data remains unsent) with one
// receive attempt per iteration. Returns |size| once all bytes have been
// received, or -1 if either socket fails with anything other than pending.
// NOTE(review): there is no timeout and no yield. If the data never arrives
// (or reads keep returning 0), this loops forever.
int Transfer(Socket* receive_socket,
             char* out_data,
             Socket* send_socket,
             const char* send_data,
             int size) {
  int sent_so_far = 0;
  int received_so_far = 0;
  while (received_so_far < size) {
    if (sent_so_far < size) {
      const int sent = send_socket->SendTo(send_data + sent_so_far,
                                           size - sent_so_far, NULL);
      if (sent >= 0) {
        sent_so_far += sent;
      } else if (!send_socket->IsPending()) {
        return -1;
      }
    }

    const int received = receive_socket->ReceiveFrom(
        out_data + received_so_far, size - received_so_far, NULL);
    if (received >= 0) {
      received_so_far += received;
    } else if (!receive_socket->IsPending()) {
      return -1;
    }
  }
  return size;
}
// Builds a connected TCP trio on |port|. The steps are:
//   1. Create a listening socket of |server_address_type|.
//   2. Start a client connection of |client_address_type| to localhost.
//   3. Spin for up to |timeout| to accept the server end of the connection.
// On any failure it records a non-fatal gtest failure, destroys the sockets
// created so far, and returns a default-constructed ConnectedTrio.
ConnectedTrio CreateAndConnect(SbSocketAddressType server_address_type,
                               SbSocketAddressType client_address_type,
                               int port,
                               SbTime timeout) {
  // Verify the listening socket.
  SbSocket listen_socket = CreateListeningTcpSocket(server_address_type, port);
  if (!SbSocketIsValid(listen_socket)) {
    ADD_FAILURE() << "Could not create listen socket.";
    return ConnectedTrio();
  }
  // Verify the socket to connect to the listening socket.
  SbSocket client_socket = CreateConnectingTcpSocket(client_address_type, port);
  if (!SbSocketIsValid(client_socket)) {
    ADD_FAILURE() << "Could not create client socket.";
    EXPECT_TRUE(SbSocketDestroy(listen_socket));
    return ConnectedTrio();
  }
  // Spin until the accept happens (or we get impatient). AcceptBySpinning()
  // tracks the timeout itself.
  SbSocket server_socket = AcceptBySpinning(listen_socket, timeout);
  if (!SbSocketIsValid(server_socket)) {
    ADD_FAILURE() << "Failed to accept within " << timeout;
    EXPECT_TRUE(SbSocketDestroy(listen_socket));
    EXPECT_TRUE(SbSocketDestroy(client_socket));
    return ConnectedTrio();
  }
  return ConnectedTrio(listen_socket, client_socket, server_socket);
}
// Wrapped counterpart of CreateAndConnect(): builds a listening Socket, a
// client Socket connecting to localhost:|port|, and the accepted server
// Socket, all owned by the returned ConnectedTrioWrapped. On any failure it
// records a non-fatal gtest failure and returns an empty scoped_ptr; sockets
// created so far are released by their scoped_ptrs.
scoped_ptr<ConnectedTrioWrapped> CreateAndConnectWrapped(
    SbSocketAddressType server_address_type,
    SbSocketAddressType client_address_type,
    int port,
    SbTime timeout) {
  // Verify the listening socket.
  scoped_ptr<Socket> listen_socket =
      CreateListeningTcpSocketWrapped(server_address_type, port);
  if (!listen_socket || !listen_socket->IsValid()) {
    ADD_FAILURE() << "Could not create listen socket.";
    return scoped_ptr<ConnectedTrioWrapped>().Pass();
  }
  // Verify the socket to connect to the listening socket.
  scoped_ptr<Socket> client_socket =
      CreateConnectingTcpSocketWrapped(client_address_type, port);
  if (!client_socket || !client_socket->IsValid()) {
    ADD_FAILURE() << "Could not create client socket.";
    return scoped_ptr<ConnectedTrioWrapped>().Pass();
  }
  // Spin until the accept happens (or we get impatient). AcceptBySpinning()
  // tracks the timeout itself.
  scoped_ptr<Socket> server_socket =
      AcceptBySpinning(listen_socket.get(), timeout);
  if (!server_socket || !server_socket->IsValid()) {
    ADD_FAILURE() << "Failed to accept within " << timeout;
    return scoped_ptr<ConnectedTrioWrapped>().Pass();
  }
  return make_scoped_ptr(new ConnectedTrioWrapped(
      listen_socket.Pass(), client_socket.Pass(), server_socket.Pass()));
}
// Calls SbSocketWaiterWait() on |waiter| and returns how long that call took,
// measured with the monotonic clock.
SbTimeMonotonic TimedWait(SbSocketWaiter waiter) {
  const SbTimeMonotonic before = SbTimeGetMonotonicNow();
  SbSocketWaiterWait(waiter);
  const SbTimeMonotonic after = SbTimeGetMonotonicNow();
  return after - before;
}
// Calls SbSocketWaiterWaitTimed() on |waiter| with |timeout| and returns how
// long that call took, measured with the monotonic clock.
SbTimeMonotonic TimedWaitTimed(SbSocketWaiter waiter, SbTime timeout) {
  const SbTimeMonotonic before = SbTimeGetMonotonicNow();
  SbSocketWaiterWaitTimed(waiter, timeout);
  const SbTimeMonotonic after = SbTimeGetMonotonicNow();
  return after - before;
}
} // namespace nplb
} // namespace starboard