// Copyright (c) 2012 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "net/dns/address_sorter_posix.h"

#if defined(OS_STARBOARD)
#include "starboard/client_porting/poem/inet_poem.h"
#else
#include <netinet/in.h>
#endif

#if defined(OS_MACOSX) || defined(OS_BSD)
#include <sys/socket.h>  // Must be included before ifaddrs.h.
#include <ifaddrs.h>
#include <net/if.h>
#include <netinet/in_var.h>
#include <string.h>
#include <sys/ioctl.h>
#endif

#include <algorithm>

#include "base/logging.h"
#include "base/memory/scoped_vector.h"
#include "base/posix/eintr_wrapper.h"
#include "net/socket/client_socket_factory.h"
#include "net/udp/datagram_client_socket.h"

#if defined(OS_LINUX)
#include "net/base/address_tracker_linux.h"
#endif

namespace net {

namespace {

// Address sorting is performed according to RFC3484 with revisions.
// http://tools.ietf.org/html/draft-ietf-6man-rfc3484bis-06
// Precedence and label are separate to support override through /etc/gai.conf.

// Ordering predicate for the policy table: |p1| sorts ahead of |p2| when its
// prefix is longer. With longer prefixes first, the first entry that matches
// during a front-to-back scan is the longest-prefix match.
bool ComparePolicy(const AddressSorterPosix::PolicyEntry& p1,
                   const AddressSorterPosix::PolicyEntry& p2) {
  return p2.prefix_length < p1.prefix_length;
}

// Copies the |size| entries starting at |table| into a new PolicyTable and
// orders the copy with ComparePolicy (longest prefix first). The input array
// is left untouched.
AddressSorterPosix::PolicyTable LoadPolicy(
    AddressSorterPosix::PolicyEntry* table,
    size_t size) {
  AddressSorterPosix::PolicyEntry* const table_end = table + size;
  AddressSorterPosix::PolicyTable sorted(table, table_end);
  std::sort(sorted.begin(), sorted.end(), ComparePolicy);
  return sorted;
}

// Returns the value of the first entry in |table| whose prefix matches
// |address|. |table| must be ordered by descending prefix length (a prefix of
// another prefix comes later), so the first hit is the longest match.
// IPv4 addresses are converted with ConvertIPv4NumberToIPv6Number() and then
// looked up like any other IPv6 address.
unsigned GetPolicyValue(const AddressSorterPosix::PolicyTable& table,
                        const IPAddressNumber& address) {
  if (address.size() == kIPv4AddressSize)
    return GetPolicyValue(table, ConvertIPv4NumberToIPv6Number(address));

  typedef AddressSorterPosix::PolicyTable::const_iterator EntryIterator;
  for (EntryIterator it = table.begin(); it != table.end(); ++it) {
    const IPAddressNumber candidate(it->prefix,
                                    it->prefix + kIPv6AddressSize);
    if (IPNumberMatchesPrefix(address, candidate, it->prefix_length))
      return it->value;
  }

  // Every table is expected to match something before the scan runs out.
  NOTREACHED();
  // The last entry is the least restrictive, so treat it as the default.
  return table.back().value;
}

// True when the first byte of the IPv6 |address| is 0xFF (ff00::/8).
bool IsIPv6Multicast(const IPAddressNumber& address) {
  DCHECK_EQ(kIPv6AddressSize, address.size());
  const unsigned char kMulticastLeadByte = 0xFF;
  return address[0] == kMulticastLeadByte;
}

// Extracts the scope of an IPv6 multicast |address|: the low four bits of its
// second byte, reinterpreted as an AddressScope value.
AddressSorterPosix::AddressScope GetIPv6MulticastScope(
    const IPAddressNumber& address) {
  DCHECK_EQ(kIPv6AddressSize, address.size());
  const int scope_bits = address[1] & 0x0F;
  return static_cast<AddressSorterPosix::AddressScope>(scope_bits);
}

bool IsIPv6Loopback(const IPAddressNumber& address) {
  DCHECK_EQ(kIPv6AddressSize, address.size());
  // IN6_IS_ADDR_LOOPBACK
  unsigned char kLoopback[kIPv6AddressSize] = {
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 1,
  };
  return address == IPAddressNumber(kLoopback, kLoopback + kIPv6AddressSize);
}

// True when |address| lies in fe80::/10 (the equivalent of
// IN6_IS_ADDR_LINKLOCAL): first byte 0xFE, top two bits of the second byte 10.
bool IsIPv6LinkLocal(const IPAddressNumber& address) {
  DCHECK_EQ(kIPv6AddressSize, address.size());
  if (address[0] != 0xFE)
    return false;
  return (address[1] & 0xC0) == 0x80;
}

// True when |address| lies in fec0::/10 (the equivalent of
// IN6_IS_ADDR_SITELOCAL): first byte 0xFE, top two bits of the second byte 11.
bool IsIPv6SiteLocal(const IPAddressNumber& address) {
  DCHECK_EQ(kIPv6AddressSize, address.size());
  if (address[0] != 0xFE)
    return false;
  return (address[1] & 0xC0) == 0xC0;
}

// Returns the scope of |address|.
// IPv4 addresses are classified by looking them up in |ipv4_scope_table|.
// IPv6 addresses are classified structurally: multicast addresses carry their
// own scope, loopback and link-local map to SCOPE_LINKLOCAL, site-local maps
// to SCOPE_SITELOCAL, and everything else is SCOPE_GLOBAL.
// Any other address size is a programming error (NOTREACHED) and yields
// SCOPE_NODELOCAL.
AddressSorterPosix::AddressScope GetScope(
    const AddressSorterPosix::PolicyTable& ipv4_scope_table,
    const IPAddressNumber& address) {
  if (address.size() == kIPv4AddressSize) {
    return static_cast<AddressSorterPosix::AddressScope>(
        GetPolicyValue(ipv4_scope_table, address));
  }

  if (address.size() != kIPv6AddressSize) {
    NOTREACHED();
    return AddressSorterPosix::SCOPE_NODELOCAL;
  }

  if (IsIPv6Multicast(address))
    return GetIPv6MulticastScope(address);
  if (IsIPv6Loopback(address) || IsIPv6LinkLocal(address))
    return AddressSorterPosix::SCOPE_LINKLOCAL;
  if (IsIPv6SiteLocal(address))
    return AddressSorterPosix::SCOPE_SITELOCAL;
  return AddressSorterPosix::SCOPE_GLOBAL;
}

// Default precedence table: maps an IPv6 prefix to a precedence value, which
// CompareDestinations() uses in Rule 6 (higher precedence is preferred).
// Based on the policy table of RFC 3484, Section 2.1, with the revisions from
// the rfc3484bis draft referenced at the top of this namespace.
// Each entry is { prefix bytes, prefix length in bits, precedence }; prefix
// bytes left out of an initializer are zero. The order of entries here is not
// significant because LoadPolicy() sorts them by decreasing prefix length.
AddressSorterPosix::PolicyEntry kDefaultPrecedenceTable[] = {
  // ::1/128 -- loopback
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 50 },
  // ::/0 -- any
  { { }, 0, 40 },
  // ::ffff:0:0/96 -- IPv4 mapped
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 35 },
  // 2002::/16 -- 6to4
  { { 0x20, 0x02, }, 16, 30 },
  // 2001::/32 -- Teredo
  { { 0x20, 0x01, 0, 0 }, 32, 5 },
  // fc00::/7 -- unique local address
  { { 0xFC }, 7, 3 },
  // ::/96 -- IPv4 compatible
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 1 },
  // fec0::/10 -- site-local expanded scope
  { { 0xFE, 0xC0 }, 10, 1 },
  // 3ffe::/16 -- 6bone
  { { 0x3F, 0xFE }, 16, 1 },
};

// Default label table: maps an IPv6 prefix to a label. Labels are compared
// between a destination and its source address in Rule 5 of
// CompareDestinations() (a destination whose label equals its source's label
// is preferred). Kept separate from the precedence table, but covering the
// same prefixes.
// Each entry is { prefix bytes, prefix length in bits, label }; prefix bytes
// left out of an initializer are zero. LoadPolicy() sorts the entries by
// decreasing prefix length before use.
AddressSorterPosix::PolicyEntry kDefaultLabelTable[] = {
  // ::1/128 -- loopback
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1 }, 128, 0 },
  // ::/0 -- any
  { { }, 0, 1 },
  // ::ffff:0:0/96 -- IPv4 mapped
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF }, 96, 4 },
  // 2002::/16 -- 6to4
  { { 0x20, 0x02, }, 16, 2 },
  // 2001::/32 -- Teredo
  { { 0x20, 0x01, 0, 0 }, 32, 5 },
  // fc00::/7 -- unique local address
  { { 0xFC }, 7, 13 },
  // ::/96 -- IPv4 compatible
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, 96, 3 },
  // fec0::/10 -- site-local expanded scope
  { { 0xFE, 0xC0 }, 10, 11 },
  // 3ffe::/16 -- 6bone
  { { 0x3F, 0xFE }, 16, 12 },
};

// Default mapping of IPv4 addresses to scope.
// Prefixes are written in IPv6 form with 0xFF 0xFF in bytes 10-11
// (::ffff:a.b.c.d), because GetPolicyValue() converts IPv4 addresses with
// ConvertIPv4NumberToIPv6Number() before matching. Each prefix length is
// therefore 96 plus the IPv4 prefix length.
// Each entry is { prefix bytes, prefix length in bits, AddressScope value }.
AddressSorterPosix::PolicyEntry kDefaultIPv4ScopeTable[] = {
  // ::ffff:127.0.0.0/104 -- 127.0.0.0/8
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0x7F }, 104,
      AddressSorterPosix::SCOPE_LINKLOCAL },
  // ::ffff:169.254.0.0/112 -- 169.254.0.0/16
  { { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xFF, 0xFF, 0xA9, 0xFE }, 112,
      AddressSorterPosix::SCOPE_LINKLOCAL },
  // ::/0 -- everything else
  { { }, 0, AddressSorterPosix::SCOPE_GLOBAL },
};

// Counts how many leading bits |a1| and |a2| have in common. Both addresses
// must have the same size. Returns the full bit width when they are equal.
unsigned CommonPrefixLength(const IPAddressNumber& a1,
                            const IPAddressNumber& a2) {
  DCHECK_EQ(a1.size(), a2.size());
  const unsigned kTopBit = 1u << (CHAR_BIT - 1);
  for (size_t byte = 0; byte < a1.size(); ++byte) {
    unsigned diff = a1[byte] ^ a2[byte];
    if (diff == 0)
      continue;
    // |diff| is a non-zero byte value, so shifting left brings its highest
    // set bit to the top position in fewer than CHAR_BIT steps.
    unsigned matching_bits = 0;
    while (!(diff & kTopBit)) {
      diff <<= 1;
      ++matching_bits;
    }
    return byte * CHAR_BIT + matching_bits;
  }
  return a1.size() * CHAR_BIT;
}

// Computes the number of leading 1-bits in |mask|.
unsigned MaskPrefixLength(const IPAddressNumber& mask) {
  IPAddressNumber all_ones(mask.size(), 0xFF);
  return CommonPrefixLength(mask, all_ones);
}

// Per-destination data gathered by AddressSorterPosix::Sort() and compared by
// CompareDestinations().
struct DestinationInfo {
  // The destination address being ranked.
  IPAddressNumber address;
  // Scope of |address|, as computed by GetScope().
  AddressSorterPosix::AddressScope scope;
  // Value for |address| from the precedence table.
  unsigned precedence;
  // Value for |address| from the label table.
  unsigned label;
  // Info about the local source address used to reach |address|. Sort() points
  // this at an entry of the sorter's |source_map_|; it is not owned here.
  const AddressSorterPosix::SourceAddressInfo* src;
  // Bits shared by |address| and its source address, capped at the source's
  // prefix length. Sort() assigns it only when both are the same family.
  unsigned common_prefix_length;
};

// Returns true iff |dst_a| should precede |dst_b| in the address list.
// RFC 3484, section 6.
bool CompareDestinations(const DestinationInfo* dst_a,
                         const DestinationInfo* dst_b) {
  // Rule 1: Avoid unusable destinations.
  // Unusable destinations are already filtered out.
  DCHECK(dst_a->src);
  DCHECK(dst_b->src);

  // Rule 2: Prefer matching scope.
  bool scope_match1 = (dst_a->src->scope == dst_a->scope);
  bool scope_match2 = (dst_b->src->scope == dst_b->scope);
  if (scope_match1 != scope_match2)
    return scope_match1;

  // Rule 3: Avoid deprecated addresses.
  if (dst_a->src->deprecated != dst_b->src->deprecated)
    return !dst_a->src->deprecated;

  // Rule 4: Prefer home addresses.
  if (dst_a->src->home != dst_b->src->home)
    return dst_a->src->home;

  // Rule 5: Prefer matching label.
  bool label_match1 = (dst_a->src->label == dst_a->label);
  bool label_match2 = (dst_b->src->label == dst_b->label);
  if (label_match1 != label_match2)
    return label_match1;

  // Rule 6: Prefer higher precedence.
  if (dst_a->precedence != dst_b->precedence)
    return dst_a->precedence > dst_b->precedence;

  // Rule 7: Prefer native transport.
  if (dst_a->src->native != dst_b->src->native)
    return dst_a->src->native;

  // Rule 8: Prefer smaller scope.
  if (dst_a->scope != dst_b->scope)
    return dst_a->scope < dst_b->scope;

  // Rule 9: Use longest matching prefix. Only for matching address families.
  if (dst_a->address.size() == dst_b->address.size()) {
    if (dst_a->common_prefix_length != dst_b->common_prefix_length)
      return dst_a->common_prefix_length > dst_b->common_prefix_length;
  }

  // Rule 10: Leave the order unchanged.
  // stable_sort takes care of that.
  return false;
}

}  // namespace

// Stores |socket_factory| (used by Sort() to create probe sockets) and
// installs the default precedence, label and IPv4-scope tables, each sorted by
// LoadPolicy(). Then registers this object as an IP address observer and calls
// OnIPAddressChanged() once directly so |source_map_| is populated right away.
AddressSorterPosix::AddressSorterPosix(ClientSocketFactory* socket_factory)
    : socket_factory_(socket_factory),
      precedence_table_(LoadPolicy(kDefaultPrecedenceTable,
                                   arraysize(kDefaultPrecedenceTable))),
      label_table_(LoadPolicy(kDefaultLabelTable,
                              arraysize(kDefaultLabelTable))),
      ipv4_scope_table_(LoadPolicy(kDefaultIPv4ScopeTable,
                              arraysize(kDefaultIPv4ScopeTable))) {
  NetworkChangeNotifier::AddIPAddressObserver(this);
  OnIPAddressChanged();
}

// Undoes the observer registration made in the constructor.
AddressSorterPosix::~AddressSorterPosix() {
  NetworkChangeNotifier::RemoveIPAddressObserver(this);
}

// Orders the addresses in |list| with CompareDestinations() and hands the
// result to |callback| before returning.
//
// For each destination a UDP socket is connected (no packets are sent) to
// learn which local source address would be used. Destinations for which the
// connect or the local-address query fails are dropped from the result.
// |callback| is always run with true as its first argument, and every endpoint
// in the result carries port 0.
void AddressSorterPosix::Sort(const AddressList& list,
                              const CallbackType& callback) const {
  DCHECK(CalledOnValidThread());
  ScopedVector<DestinationInfo> sort_list;

  for (size_t i = 0; i < list.size(); ++i) {
    // new DestinationInfo() value-initializes, so fields not assigned below
    // (e.g. |common_prefix_length| for mismatched families) start at zero.
    scoped_ptr<DestinationInfo> info(new DestinationInfo());
    info->address = list[i].address();
    info->scope = GetScope(ipv4_scope_table_, info->address);
    info->precedence = GetPolicyValue(precedence_table_, info->address);
    info->label = GetPolicyValue(label_table_, info->address);

    // Each socket can only be bound once.
    scoped_ptr<DatagramClientSocket> socket(
        socket_factory_->CreateDatagramClientSocket(
            DatagramSocket::DEFAULT_BIND,
            RandIntCallback(),
            NULL /* NetLog */,
            NetLog::Source()));

    // Even though no packets are sent, cannot use port 0 in Connect.
    IPEndPoint dest(info->address, 80 /* port */);
    int rv = socket->Connect(dest);
    if (rv != OK) {
      LOG(WARNING) << "Could not connect to " << dest.ToStringWithoutPort()
                   << " reason " << rv;
      continue;
    }
    // Filter out unusable destinations.
    IPEndPoint src;
    rv = socket->GetLocalAddress(&src);
    if (rv != OK) {
      // |src| was not filled in on failure, so report the destination that was
      // being probed rather than an empty endpoint.
      LOG(WARNING) << "Could not get local address for "
                   << dest.ToStringWithoutPort() << " reason " << rv;
      continue;
    }

    // operator[] inserts a fresh entry when |src| is not in the map yet.
    SourceAddressInfo& src_info = source_map_[src.address()];
    if (src_info.scope == SCOPE_UNDEFINED) {
      // If |source_map_| is out of date, |src| might be missing, but we still
      // want to sort, even though the HostCache will be cleared soon.
      FillPolicy(src.address(), &src_info);
    }
    info->src = &src_info;

    // Rule 9 input: only meaningful when source and destination are of the
    // same address family. Capped at the source's prefix length.
    if (info->address.size() == src.address().size()) {
      info->common_prefix_length = std::min(
          CommonPrefixLength(info->address, src.address()),
          info->src->prefix_length);
    }
    sort_list.push_back(info.release());
  }

  // Stable sort implements Rule 10 (leave the order unchanged on ties).
  std::stable_sort(sort_list.begin(), sort_list.end(), CompareDestinations);

  AddressList result;
  for (size_t i = 0; i < sort_list.size(); ++i)
    result.push_back(IPEndPoint(sort_list[i]->address, 0 /* port */));

  callback.Run(true, result);
}

// Rebuilds |source_map_| from the addresses currently configured on the host.
// Called once from the constructor; the object is also registered there via
// NetworkChangeNotifier::AddIPAddressObserver().
//
// The map is cleared first. On Linux the entries come from the
// NetworkChangeNotifier address tracker; on Mac/BSD from getifaddrs() plus a
// SIOCGIFAFLAG_IN6 ioctl per IPv6 address to read the deprecated flag. On
// other platforms, or when the data source is unavailable, the map stays
// empty and Sort() fills entries lazily.
void AddressSorterPosix::OnIPAddressChanged() {
  DCHECK(CalledOnValidThread());
  source_map_.clear();
#if defined(OS_LINUX)
  const internal::AddressTrackerLinux* tracker =
      NetworkChangeNotifier::GetAddressTracker();
  if (!tracker)
    return;
  typedef internal::AddressTrackerLinux::AddressMap AddressMap;
  AddressMap map = tracker->GetAddressMap();
  for (AddressMap::const_iterator it = map.begin(); it != map.end(); ++it) {
    const IPAddressNumber& address = it->first;
    const struct ifaddrmsg& msg = it->second;
    SourceAddressInfo& info = source_map_[address];
    info.native = false;  // TODO(szym): obtain this via netlink.
    info.deprecated = msg.ifa_flags & IFA_F_DEPRECATED;
    info.home = msg.ifa_flags & IFA_F_HOMEADDRESS;
    info.prefix_length = msg.ifa_prefixlen;
    FillPolicy(address, &info);
  }
#elif defined(OS_MACOSX) || defined(OS_BSD)
  // It's not clear we will receive notification when deprecated flag changes.
  // Socket for ioctl.
  int ioctl_socket = socket(AF_INET6, SOCK_DGRAM, 0);
  if (ioctl_socket < 0)
    return;

  struct ifaddrs* addrs;
  int rv = getifaddrs(&addrs);
  if (rv < 0) {
    LOG(WARNING) << "getifaddrs failed " << rv;
    close(ioctl_socket);
    return;
  }

  for (struct ifaddrs* ifa = addrs; ifa != NULL; ifa = ifa->ifa_next) {
    // getifaddrs() may report interfaces that have no address; there is
    // nothing to record for those, and |ifa_addr| must not be dereferenced.
    if (!ifa->ifa_addr)
      continue;
    IPEndPoint src;
    if (!src.FromSockAddr(ifa->ifa_addr, ifa->ifa_addr->sa_len)) {
      LOG(WARNING) << "FromSockAddr failed";
      continue;
    }
    SourceAddressInfo& info = source_map_[src.address()];
    // Note: no known way to fill in |native| and |home|.
    info.native = info.home = info.deprecated = false;
    if (ifa->ifa_addr->sa_family == AF_INET6) {
      struct in6_ifreq ifr = {};
      // |ifr| is zero-initialized, so copying at most size - 1 bytes keeps the
      // name NUL-terminated.
      strncpy(ifr.ifr_name, ifa->ifa_name, sizeof(ifr.ifr_name) - 1);
      DCHECK_LE(ifa->ifa_addr->sa_len, sizeof(ifr.ifr_ifru.ifru_addr));
      memcpy(&ifr.ifr_ifru.ifru_addr, ifa->ifa_addr, ifa->ifa_addr->sa_len);
      // ioctl() returns 0 on success and -1 on failure, so any non-negative
      // result means the flags were filled in.
      int ioctl_rv = ioctl(ioctl_socket, SIOCGIFAFLAG_IN6, &ifr);
      if (ioctl_rv >= 0) {
        info.deprecated = ifr.ifr_ifru.ifru_flags & IN6_IFF_DEPRECATED;
      } else {
        LOG(WARNING) << "SIOCGIFAFLAG_IN6 failed " << ioctl_rv;
      }
    }
    if (ifa->ifa_netmask) {
      IPEndPoint netmask;
      // The address's sa_len is passed on purpose: the netmask is parsed with
      // the same length as the address it belongs to.
      if (netmask.FromSockAddr(ifa->ifa_netmask, ifa->ifa_addr->sa_len)) {
        info.prefix_length = MaskPrefixLength(netmask.address());
      } else {
        LOG(WARNING) << "FromSockAddr failed on netmask";
      }
    }
    FillPolicy(src.address(), &info);
  }
  freeifaddrs(addrs);
  close(ioctl_socket);
#endif
}

void AddressSorterPosix::FillPolicy(const IPAddressNumber& address,
                                    SourceAddressInfo* info) const {
  DCHECK(CalledOnValidThread());
  info->scope = GetScope(ipv4_scope_table_, address);
  info->label = GetPolicyValue(label_table_, address);
}

// static
// Creates the POSIX address sorter, wired to the default client socket
// factory.
scoped_ptr<AddressSorter> AddressSorter::CreateAddressSorter() {
  ClientSocketFactory* const default_factory =
      ClientSocketFactory::GetDefaultFactory();
  return scoped_ptr<AddressSorter>(new AddressSorterPosix(default_factory));
}

}  // namespace net
