// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
*
* Copyright (c) 2020, Red Hat, Inc.
*/
#define pr_fmt(fmt) "MPTCP: " fmt
#include <linux/inet.h>
#include <linux/kernel.h>
#include <net/tcp.h>
#include <net/netns/generic.h>
#include <net/mptcp.h>
#include <net/genetlink.h>
#include <uapi/linux/mptcp.h>
#include "protocol.h"
/* Forward declaration: the generic netlink family object is declared here so
 * code earlier in the file can refer to it (NOTE(review): its initialised
 * definition is not visible in this chunk).
 */
static struct genl_family mptcp_genl_family;
/* Per-network-namespace id; presumably the key passed to net_generic() to
 * reach this namespace's struct pm_nl_pernet — TODO confirm against the
 * pernet_operations registration.
 */
static int pm_nl_pernet_id;
/* One configured local endpoint of the in-kernel path manager. Entries are
 * linked on pm_nl_pernet.local_addr_list and walked under RCU by
 * select_local_address() and select_signal_address().
 */
struct mptcp_pm_addr_entry {
/* link in pm_nl_pernet.local_addr_list */
struct list_head list;
/* MPTCP_PM_ADDR_FLAG_* bits, e.g. MPTCP_PM_ADDR_FLAG_SIGNAL / _SUBFLOW */
unsigned int flags;
/* presumably the network interface this endpoint is bound to; TODO confirm */
int ifindex;
/* the endpoint address (family, IP address, port) */
struct mptcp_addr_info addr;
/* presumably used to defer freeing until after an RCU grace period, since
 * readers walk the list under rcu_read_lock(); TODO confirm at free site
 */
struct rcu_head rcu;
};
/* Per-network-namespace state of the in-kernel (netlink) path manager. */
struct pm_nl_pernet {
/* protects pernet updates; note that the list walkers visible in this
 * chunk only take rcu_read_lock() when reading local_addr_list
 */
spinlock_t lock;
/* list of struct mptcp_pm_addr_entry, linked through their ->list member */
struct list_head local_addr_list;
/* presumably the number of entries on local_addr_list; TODO confirm */
unsigned int addrs;
/* Presumably the per-namespace limits that seed each msk's pm limits
 * (compare msk->pm.add_addr_signal_max, local_addr_max, subflows_max read
 * in check_work_pending()); TODO confirm where they are copied.
 */
unsigned int add_addr_signal_max;
unsigned int add_addr_accept_max;
unsigned int local_addr_max;
unsigned int subflows_max;
/* presumably the next endpoint id to assign to a new entry; TODO confirm */
unsigned int next_id;
};
/* NOTE(review): presumably the upper bound on configured endpoints and/or
 * on the pm_nl_pernet *_max limits; its users are outside this chunk.
 */
#define MPTCP_PM_ADDR_MAX 8
/*
 * addresses_equal - compare two MPTCP endpoint addresses
 * @a: first address
 * @b: second address
 * @use_port: when true, the ports must match as well
 *
 * Returns true when both addresses have the same family and the same IP
 * address, and (only if @use_port is set) the same port. Any family other
 * than AF_INET is compared as IPv6 when CONFIG_MPTCP_IPV6 is enabled;
 * without it, non-IPv4 pairs always compare unequal.
 *
 * Neither argument is modified, so both are const-qualified.
 */
static bool addresses_equal(const struct mptcp_addr_info *a,
			    const struct mptcp_addr_info *b, bool use_port)
{
	bool addr_equals = false;

	if (a->family != b->family)
		return false;

	if (a->family == AF_INET)
		addr_equals = a->addr.s_addr == b->addr.s_addr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else
		addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
#endif

	if (!addr_equals)
		return false;
	if (!use_port)
		return true;

	return a->port == b->port;
}
/*
 * local_address - extract the local endpoint of a socket
 * @skc: socket to inspect
 * @addr: receives the family and local (receive-side source) IP of @skc
 *
 * The port is always reported as 0. For families other than AF_INET (and
 * AF_INET6 when CONFIG_MPTCP_IPV6 is enabled) the address part of @addr is
 * left untouched.
 */
static void local_address(const struct sock_common *skc,
			  struct mptcp_addr_info *addr)
{
	addr->family = skc->skc_family;
	addr->port = 0;

	switch (addr->family) {
	case AF_INET:
		addr->addr.s_addr = skc->skc_rcv_saddr;
		break;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	case AF_INET6:
		addr->addr6 = skc->skc_v6_rcv_saddr;
		break;
#endif
	default:
		break;
	}
}
/*
 * remote_address - extract the peer endpoint of a socket
 * @skc: socket to inspect
 * @addr: receives the family, destination IP and destination port of @skc
 *
 * For families other than AF_INET (and AF_INET6 when CONFIG_MPTCP_IPV6 is
 * enabled) only the family and port are written.
 */
static void remote_address(const struct sock_common *skc,
			   struct mptcp_addr_info *addr)
{
	addr->port = skc->skc_dport;
	addr->family = skc->skc_family;

	switch (addr->family) {
	case AF_INET:
		addr->addr.s_addr = skc->skc_daddr;
		break;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	case AF_INET6:
		addr->addr6 = skc->skc_v6_daddr;
		break;
#endif
	default:
		break;
	}
}
/*
 * lookup_subflow_by_saddr - is @saddr already used as a subflow source?
 * @list: list of struct mptcp_subflow_context, linked through ->node
 * @saddr: address to look for (ports are ignored)
 *
 * Returns true as soon as one subflow's local address equals @saddr.
 */
static bool lookup_subflow_by_saddr(const struct list_head *list,
				    struct mptcp_addr_info *saddr)
{
	struct mptcp_subflow_context *subflow;

	list_for_each_entry(subflow, list, node) {
		const struct sock_common *ssk_common;
		struct mptcp_addr_info local;

		ssk_common = (const struct sock_common *)mptcp_subflow_tcp_sock(subflow);
		local_address(ssk_common, &local);
		if (addresses_equal(&local, saddr, false))
			return true;
	}

	return false;
}
/*
 * select_local_address - pick a local endpoint usable for a new subflow
 * @pernet: per-namespace path-manager state
 * @msk: MPTCP socket the subflow would belong to
 *
 * Returns the first MPTCP_PM_ADDR_FLAG_SUBFLOW entry whose family matches
 * @msk and whose address is not already the local address of a subflow on
 * msk->conn_list or on the pending msk->join_list, or NULL if none does.
 *
 * msk->join_list_lock is held while both lists are scanned.
 * NOTE(review): conn_list is not locked here; presumably the caller owns
 * the msk socket lock — confirm. The entry pointer is returned after
 * rcu_read_unlock(); confirm callers keep it alive.
 */
static struct mptcp_pm_addr_entry *
select_local_address(const struct pm_nl_pernet *pernet,
		     struct mptcp_sock *msk)
{
	struct mptcp_pm_addr_entry *entry, *found = NULL;
	struct sock *sk = (struct sock *)msk;

	rcu_read_lock();
	spin_lock_bh(&msk->join_list_lock);
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
			continue;
		if (entry->addr.family != sk->sk_family)
			continue;
		/* skip addresses already used by established subflows ... */
		if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr))
			continue;
		/* ... and by subflows still waiting to join */
		if (lookup_subflow_by_saddr(&msk->join_list, &entry->addr))
			continue;

		found = entry;
		break;
	}
	spin_unlock_bh(&msk->join_list_lock);
	rcu_read_unlock();

	return found;
}
static struct mptcp_pm_addr_entry *
select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
{
struct mptcp_pm_addr_entry *entry, *ret = NULL;
int i = 0;
rcu_read_lock();
/* do not keep any additional per socket state, just signal
* the address list in order.
* Note: removal from the local address list during the msk life-cycle
* can lead to additional addresses not being announced.
*/
list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
if (!(entry->flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
continue;
if (i++ == pos) {
ret = entry;
break;
}
}
rcu_read_unlock();
return ret;
}
/*
 * check_work_pending - clear msk->pm.work_pending when the PM limits are hit
 * @msk: MPTCP socket whose path-manager counters are inspected
 *
 * The flag is cleared only when every address to be signalled has been
 * signalled AND either the local-address budget or the subflow budget is
 * exhausted; otherwise it is left as is.
 */
static void check_work_pending(struct mptcp_sock *msk)
{
	/* addresses still to announce: more work may follow */
	if (msk->pm.add_addr_signaled != msk->pm.add_addr_signal_max)
		return;

	/* both budgets still open: more subflows may be created */
	if (msk->pm.local_addr_used != msk->pm.local_addr_max &&
	    msk->pm.subflows != msk->pm.subflows_max)
		return;

	WRITE_ONCE(msk->pm.work_pending, false);
}
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
struct sock