/*
 * kgraft_patch_bsc1064392
 *
 * Fix for CVE-2017-15649, bsc#1064392
 *
 *  Upstream commits:
 *  4971613c1639 ("packet: in packet_do_bind, test fanout with bind_lock held")
 *  008ba2a13f2d ("packet: hold bind lock when rebinding to fanout hook")
 *
 *  SLE12(-SP1) commits:
 *  cdca61fab3721a489f206121e7a4d539ca7db032
 *  10a43e3ca8b6c4c85708a7c0796397be2c8cb1f3
 *  e8d983aa9457e32040adc164be8c58f08fa28644
 *
 *  SLE12-SP2 commits:
 *  d360124864db51fcbcc831114031fa0cc5689f1a
 *  205d967847640d7dcf13152063fef0ca88172b6d
 *
 *  SLE12-SP3 commits:
 *  fbf37987471f9f37e9c1b0a7463ea5613ee5d92d
 *  4741c8495f017586043f65b83cf88aa8579befac
 *
 *  Copyright (c) 2017 SUSE
 *  Author: Nicolai Stange <nstange@suse.de>
 *
 *  Based on the original Linux kernel code. Other copyrights apply.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, see <http://www.gnu.org/licenses/>.
 */

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/kallsyms.h>
#include <linux/if_packet.h>
#include <net/sock.h>
#include "kgr_patch_bsc1064392.h"

/*
 * Every symbol looked up by this patch is named with an "af_packet:"
 * module prefix, so the build is refused unless CONFIG_PACKET is modular.
 */
#if !IS_MODULE(CONFIG_PACKET)
#error "KGR patch supports only CONFIG_PACKET=m."
#endif

/* Module name compared against in the module notifier and in init. */
#define KGR_PATCHED_MODULE "af_packet"

/*
 * Pointers to data objects living inside af_packet. They are filled in at
 * runtime by kgr_patch_bsc1064392_kallsyms() through the kgr_funcs[] table.
 */
static struct mutex *kgr_fanout_mutex;
static struct list_head *kgr_fanout_list;

/* Forward declarations; the full definitions follow further down. */
struct packet_mclist;
struct packet_sock;

/*
 * Pointers to functions living inside af_packet, resolved the same way as
 * the data pointers above. They stay NULL until the lookup has succeeded.
 */
static int (*kgr_packet_dev_mc)(struct net_device *dev, struct packet_mclist *i,
				int what);
static int (*kgr_packet_set_ring)(struct sock *sk, union tpacket_req_u *req_u,
				  int closing, int tx_ring);
static void (*kgr__fanout_link)(struct sock *sk, struct packet_sock *po);
static void (*kgr__unregister_prot_hook)(struct sock *sk, bool sync);
static int (*kgr_packet_rcv_fanout)(struct sk_buff *skb, struct net_device *dev,
				    struct packet_type *pt,
				    struct net_device *orig_dev);
static bool (*kgr_match_fanout_group)(struct packet_type *ptype,
				      struct sock * sk);

/*
 * Lookup table driving kgr_patch_bsc1064392_kallsyms().
 *
 * @name: symbol name handed to kallsyms_lookup_name(), in "module:symbol"
 *        form.
 * @addr: location of the local pointer variable that receives the resolved
 *        address.
 */
static struct {
	char *name;
	void **addr;
} kgr_funcs[] = {
	{ "af_packet:fanout_mutex", (void *)&kgr_fanout_mutex },
	{ "af_packet:fanout_list", (void *)&kgr_fanout_list },
	{ "af_packet:packet_dev_mc", (void *)&kgr_packet_dev_mc },
	{ "af_packet:packet_set_ring", (void *)&kgr_packet_set_ring },
	{ "af_packet:__fanout_link", (void *)&kgr__fanout_link },
	{ "af_packet:__unregister_prot_hook",
		(void *)&kgr__unregister_prot_hook },
	{ "af_packet:packet_rcv_fanout", (void *)&kgr_packet_rcv_fanout },
	{ "af_packet:match_fanout_group", (void *)&kgr_match_fanout_group },
};


/* from net/packet/internal.h */
/*
 * One entry of a packet socket's singly linked membership list
 * (packet_sock::mclist). Entries are matched on ifindex, type, alen and the
 * first alen bytes of addr, and are reference counted through count.
 *
 * NOTE(review): local copy of a structure owned by af_packet; presumably the
 * field order must stay identical to the module's own definition.
 */
struct packet_mclist {
	struct packet_mclist	*next;
	int			ifindex;
	int			count;
	unsigned short		type;
	unsigned short		alen;
	unsigned char		addr[MAX_ADDR_LEN];
};

/*
 * Block-descriptor queue state embedded in struct packet_ring_buffer as
 * prb_bdqc. None of its fields are accessed by the code in this file; the
 * definition is needed so that the enclosing structures have their full size.
 *
 * NOTE(review): presumably must stay layout-identical to the definition used
 * by the af_packet module being patched - do not reorder or resize fields.
 */
struct tpacket_kbdq_core {
	struct pgv	*pkbdq;
	unsigned int	feature_req_word;
	unsigned int	hdrlen;
	unsigned char	reset_pending_on_curr_blk;
	unsigned char   delete_blk_timer;
	unsigned short	kactive_blk_num;
	unsigned short	blk_sizeof_priv;

	/* last_kactive_blk_num:
	 * trick to see if user-space has caught up
	 * in order to avoid refreshing timer when every single pkt arrives.
	 */
	unsigned short	last_kactive_blk_num;

	char		*pkblk_start;
	char		*pkblk_end;
	int		kblk_size;
	unsigned int	max_frame_len;
	unsigned int	knum_blocks;
	uint64_t	knxt_seq_num;
	char		*prev;
	char		*nxt_offset;
	struct sk_buff	*skb;

	atomic_t	blk_fill_in_prog;

	/* Default is set to 8ms */
#define DEFAULT_PRB_RETIRE_TOV	(8)

	unsigned short  retire_blk_tov;
	unsigned short  version;
	unsigned long	tov_in_jiffies;

	/* timer to retire an outstanding block */
	struct timer_list retire_blk_timer;
};

/*
 * Element type of the pg_vec array referenced from struct
 * packet_ring_buffer; holds a single buffer pointer.
 */
struct pgv {
	char *buffer;
};

/*
 * Per-direction ring state embedded twice in struct packet_sock (rx_ring and
 * tx_ring). The only member read in this file is pg_vec: a non-NULL value is
 * treated by kgr_packet_setsockopt() as "ring is set up" and makes several
 * options fail with -EBUSY.
 *
 * NOTE(review): presumably must stay layout-identical to the af_packet
 * module's own definition.
 */
struct packet_ring_buffer {
	struct pgv		*pg_vec;

	unsigned int		head;
	unsigned int		frames_per_block;
	unsigned int		frame_size;
	unsigned int		frame_max;

	unsigned int		pg_vec_order;
	unsigned int		pg_vec_pages;
	unsigned int		pg_vec_len;

	atomic_t		pending;

	struct tpacket_kbdq_core	prb_bdqc;
};

/*
 * Upper bound on sk_ref checked by kgr_fanout_add() before a socket joins a
 * group; also the size of the arr[] and next[] arrays below.
 */
#define PACKET_FANOUT_MAX	256

/*
 * A fanout group. Groups are kept on the list behind kgr_fanout_list and are
 * looked up by (id, net) in kgr_fanout_add().
 *
 * Members used in this file:
 *  net        - namespace of the group (only with CONFIG_NET_NS)
 *  id         - group id requested by user space
 *  type/flags - low and high byte of the user supplied type_flags value
 *  list       - linkage on the fanout list
 *  sk_ref     - number of sockets that joined the group
 *  prot_hook  - packet hook of the group, set up when the group is created
 *
 * NOTE(review): presumably must stay layout-identical to the af_packet
 * module's own definition.
 */
struct packet_fanout {
#ifdef CONFIG_NET_NS
	struct net		*net;
#endif
	unsigned int		num_members;
	u16			id;
	u8			type;
	u8			flags;
	atomic_t		rr_cur;
	struct list_head	list;
	struct sock		*arr[PACKET_FANOUT_MAX];
	int			next[PACKET_FANOUT_MAX];
	spinlock_t		lock;
	atomic_t		sk_ref;
	struct packet_type	prot_hook ____cacheline_aligned_in_smp;
};

/*
 * Protocol private socket structure. kgr_pkt_sk() obtains it by casting a
 * struct sock pointer, which is why sk has to be the first member.
 *
 * NOTE(review): the objects behind these pointers are created by af_packet
 * itself, so this copy presumably must match the module's layout exactly -
 * do not reorder or resize fields.
 */
struct packet_sock {
	/* struct sock has to be the first member of packet_sock */
	struct sock		sk;
	struct packet_fanout	*fanout;	/* non-NULL once in a fanout group */
	union  tpacket_stats_u	stats;
	struct packet_ring_buffer	rx_ring;
	struct packet_ring_buffer	tx_ring;
	int			copy_thresh;
	spinlock_t		bind_lock;	/* held around running/prot_hook updates */
	struct mutex		pg_vec_lock;
	unsigned int		running:1,	/* prot_hook is attached*/
				auxdata:1,
				origdev:1,
				has_vnet_hdr:1;
	int			ifindex;	/* bound device		*/
	__be16			num;		/* bound protocol, see kgr_packet_do_bind() */
	struct packet_mclist	*mclist;	/* head of the membership list */
	atomic_t		mapped;
	enum tpacket_versions	tp_version;
	unsigned int		tp_hdrlen;
	unsigned int		tp_reserve;
	unsigned int		tp_loss:1;
	unsigned int		tp_tx_has_off:1;
	unsigned int		tp_tstamp;
	struct net_device __rcu	*cached_dev;
	struct packet_type	prot_hook ____cacheline_aligned_in_smp;
};

/*
 * Convert a generic socket pointer into the packet socket that contains it.
 * This is a plain pointer cast: the returned pointer has the same address as
 * @sk, which relies on struct sock being the first member of packet_sock.
 */
static struct packet_sock *kgr_pkt_sk(struct sock *sk)
{
	struct packet_sock *po = (struct packet_sock *)sk;

	return po;
}

/* from net/packet/af_packet.c */
/*
 * Membership request with room for an address of up to MAX_ADDR_LEN bytes.
 * kgr_packet_setsockopt() zeroes one of these and copies the user supplied
 * request into it before calling kgr_packet_mc_add()/kgr_packet_mc_drop().
 */
struct packet_mreq_max {
	int		mr_ifindex;
	unsigned short	mr_type;
	unsigned short	mr_alen;
	unsigned char	mr_address[MAX_ADDR_LEN];
};

/* inlined */
/*
 * Publish @dev (may be NULL) as the cached device of @po. The store goes
 * through rcu_assign_pointer() because cached_dev is an __rcu pointer.
 */
static void kgr_packet_cached_dev_assign(struct packet_sock *po,
					 struct net_device *dev)
{
	rcu_assign_pointer(po->cached_dev, dev);
}

/* inlined */
/*
 * Attach the socket's packet hook unless it is already attached.
 *
 * A socket that belongs to a fanout group is linked into the group through
 * af_packet's __fanout_link(); any other socket gets its own prot_hook
 * registered with dev_add_pack(). In both cases a socket reference is taken
 * and the socket is marked as running. Nothing happens if it already runs.
 */
static void kgr_register_prot_hook(struct sock *sk)
{
	struct packet_sock *po = kgr_pkt_sk(sk);

	/* Already attached - nothing to do. */
	if (po->running)
		return;

	if (!po->fanout)
		dev_add_pack(&po->prot_hook);
	else
		kgr__fanout_link(sk, po);

	sock_hold(sk);
	po->running = 1;
}

/* inlined */
/*
 * Detach the socket's packet hook if it is currently attached.
 *
 * The real work is delegated to af_packet's __unregister_prot_hook(), to
 * which @sync is passed through unchanged. A socket that is not running is
 * left untouched.
 */
static void kgr_unregister_prot_hook(struct sock *sk, bool sync)
{
	if (!kgr_pkt_sk(sk)->running)
		return;

	kgr__unregister_prot_hook(sk, sync);
}

/* inlined */
/*
 * Add a membership described by @mreq to the socket's membership list.
 *
 * Runs entirely under the RTNL lock. If an entry with the same ifindex,
 * type, address length and address already exists, only its use count is
 * incremented. Otherwise a new entry is put at the head of the list and
 * announced to the device through af_packet's packet_dev_mc(); should that
 * fail, the entry is taken off the list again and freed.
 *
 * Returns 0 on success, -ENODEV if the interface does not exist, -EINVAL if
 * the address is longer than the device's address length, -ENOBUFS if the
 * entry cannot be allocated, or the error reported by packet_dev_mc().
 */
static int kgr_packet_mc_add(struct sock *sk, struct packet_mreq_max *mreq)
{
	struct packet_sock *po = kgr_pkt_sk(sk);
	struct packet_mclist *entry, *cur;
	struct net_device *dev;
	int err;

	rtnl_lock();

	dev = __dev_get_by_index(sock_net(sk), mreq->mr_ifindex);
	if (!dev) {
		err = -ENODEV;
		goto done;
	}

	if (mreq->mr_alen > dev->addr_len) {
		err = -EINVAL;
		goto done;
	}

	/* Allocated up front, before it is known whether it will be needed. */
	entry = kmalloc(sizeof(*entry), GFP_KERNEL);
	if (!entry) {
		err = -ENOBUFS;
		goto done;
	}

	for (cur = po->mclist; cur; cur = cur->next) {
		if (cur->ifindex != mreq->mr_ifindex ||
		    cur->type != mreq->mr_type ||
		    cur->alen != mreq->mr_alen ||
		    memcmp(cur->addr, mreq->mr_address, cur->alen) != 0)
			continue;

		/* Same membership requested again: count it, drop the spare. */
		cur->count++;
		kfree(entry);
		err = 0;
		goto done;
	}

	entry->type = mreq->mr_type;
	entry->ifindex = mreq->mr_ifindex;
	entry->alen = mreq->mr_alen;
	memcpy(entry->addr, mreq->mr_address, entry->alen);
	/* Zero the unused tail of the address buffer. */
	memset(entry->addr + entry->alen, 0, sizeof(entry->addr) - entry->alen);
	entry->count = 1;

	/* Link at the head first, then tell the device. */
	entry->next = po->mclist;
	po->mclist = entry;

	err = kgr_packet_dev_mc(dev, entry, 1);
	if (err) {
		/* Undo the insertion. */
		po->mclist = entry->next;
		kfree(entry);
	}

done:
	rtnl_unlock();
	return err;
}

/* inlined */
/*
 * Drop one use of the membership described by @mreq.
 *
 * Runs under the RTNL lock. The first list entry matching ifindex, type,
 * address length and address has its use count decremented. When the count
 * reaches zero the entry is unlinked, the device (if it still exists) is
 * informed through af_packet's packet_dev_mc(), and the entry is freed.
 *
 * Returns 0 if a matching entry was found, -EADDRNOTAVAIL otherwise.
 */
static int kgr_packet_mc_drop(struct sock *sk, struct packet_mreq_max *mreq)
{
	struct packet_mclist **link, *cur;
	int ret = -EADDRNOTAVAIL;

	rtnl_lock();

	link = &kgr_pkt_sk(sk)->mclist;
	while ((cur = *link) != NULL) {
		if (cur->ifindex != mreq->mr_ifindex ||
		    cur->type != mreq->mr_type ||
		    cur->alen != mreq->mr_alen ||
		    memcmp(cur->addr, mreq->mr_address, cur->alen) != 0) {
			link = &cur->next;
			continue;
		}

		ret = 0;
		if (--cur->count == 0) {
			struct net_device *dev;

			/* Last user gone: unlink before notifying the device. */
			*link = cur->next;
			dev = __dev_get_by_index(sock_net(sk), cur->ifindex);
			if (dev)
				kgr_packet_dev_mc(dev, cur, -1);
			kfree(cur);
		}
		break;
	}

	rtnl_unlock();
	return ret;
}


/* patched, inlined from packet_setsockopt() */
/*
 * Make the socket a member of the fanout group @id, creating the group if
 * no group with that id exists in the socket's network namespace.
 *
 * @sk:         packet socket that wants to join
 * @id:         group id
 * @type_flags: fanout type in the low byte, fanout flags in the high byte
 *
 * The whole operation runs under the af_packet fanout mutex. The check that
 * the socket is running and the switch from the socket's own hook to the
 * group hook are done with po->bind_lock held.
 *
 * Returns 0 on success or
 *  -EINVAL   unknown type, PACKET_FANOUT_ROLLOVER combined with the rollover
 *            flag, flags differing from those of the existing group, socket
 *            not running, or type/protocol/device not matching the group
 *  -EALREADY the socket already belongs to a fanout group
 *  -ENOMEM   a new group could not be allocated
 *  -ENOSPC   the group already has PACKET_FANOUT_MAX members
 */
static int kgr_fanout_add(struct sock *sk, u16 id, u16 type_flags)
{
	struct packet_sock *po = kgr_pkt_sk(sk);
	struct packet_fanout *f, *match;
	u8 type = type_flags & 0xff;
	u8 flags = type_flags >> 8;
	int err;

	switch (type) {
	case PACKET_FANOUT_ROLLOVER:
		if (type_flags & PACKET_FANOUT_FLAG_ROLLOVER)
			return -EINVAL;
		/* fall through */
	case PACKET_FANOUT_HASH:
	case PACKET_FANOUT_LB:
	case PACKET_FANOUT_CPU:
	case PACKET_FANOUT_RND:
		break;
	default:
		return -EINVAL;
	}

	mutex_lock(kgr_fanout_mutex);

	/*
	 * Fix CVE-2017-15649
	 *  -4 lines
	 */
	err = -EALREADY;
	if (po->fanout)
		goto out;

	/* Look for an existing group with this id in the same namespace. */
	match = NULL;
	list_for_each_entry(f, kgr_fanout_list, list) {
		if (f->id == id &&
		    read_pnet(&f->net) == sock_net(sk)) {
			match = f;
			break;
		}
	}
	err = -EINVAL;
	if (match && match->flags != flags)
		goto out;
	if (!match) {
		/* No such group yet: create one modelled on this socket's hook. */
		err = -ENOMEM;
		match = kzalloc(sizeof(*match), GFP_KERNEL);
		if (!match)
			goto out;
		write_pnet(&match->net, sock_net(sk));
		match->id = id;
		match->type = type;
		match->flags = flags;
		atomic_set(&match->rr_cur, 0);
		INIT_LIST_HEAD(&match->list);
		spin_lock_init(&match->lock);
		atomic_set(&match->sk_ref, 0);
		match->prot_hook.type = po->prot_hook.type;
		match->prot_hook.dev = po->prot_hook.dev;
		match->prot_hook.func = kgr_packet_rcv_fanout;
		match->prot_hook.af_packet_priv = match;
		match->prot_hook.id_match = kgr_match_fanout_group;
		list_add(&match->list, kgr_fanout_list);
	}
	err = -EINVAL;
	/*
	 * Fix CVE-2017-15649
	 *  -1 line, +3 lines
	 *
	 * po->running and po->prot_hook are evaluated with bind_lock held,
	 * the same lock kgr_packet_do_bind() holds while changing them.
	 */
	spin_lock(&po->bind_lock);
	if (po->running &&
	    match->type == type &&
	    match->prot_hook.type == po->prot_hook.type &&
	    match->prot_hook.dev == po->prot_hook.dev) {
		err = -ENOSPC;
		if (atomic_read(&match->sk_ref) < PACKET_FANOUT_MAX) {
			/* Swap the socket's private hook for group membership. */
			__dev_remove_pack(&po->prot_hook);
			po->fanout = match;
			atomic_inc(&match->sk_ref);
			kgr__fanout_link(sk, po);
			err = 0;
		}
	}
	/*
	 * Fix CVE-2017-15649
	 *  +7 lines
	 */
	spin_unlock(&po->bind_lock);

	/* On failure, dispose of a group that has no members. */
	if (err && !atomic_read(&match->sk_ref)) {
		list_del(&match->list);
		kfree(match);
	}

out:
	mutex_unlock(kgr_fanout_mutex);
	return err;
}

/* patched, inlined from packet_bind_spkt() and packet_bind() */
/*
 * Bind the socket to @dev (NULL means no specific device) and @protocol.
 *
 * @sk:       packet socket to bind
 * @dev:      device to bind to, or NULL; the reference passed in by the
 *            caller is consumed - it is either dropped on the -EINVAL path
 *            or stored in po->prot_hook.dev
 * @protocol: protocol number stored in po->num and po->prot_hook.type
 *
 * The socket lock and po->bind_lock are taken before po->fanout is tested,
 * so the test and the rebinding happen under the same locks.
 *
 * Returns 0, or -EINVAL if the socket is a member of a fanout group. A bound
 * device that is not up does not fail the bind: ENETDOWN is recorded in
 * sk->sk_err instead.
 */
static int kgr_packet_do_bind(struct sock *sk, struct net_device *dev, __be16 protocol)
{
	struct packet_sock *po = kgr_pkt_sk(sk);
	/*
	 * Fix CVE-2017-15649
	 *  +4 lines
	 */
	int ret = 0;

	lock_sock(sk);
	spin_lock(&po->bind_lock);
	if (po->fanout) {
		/* Rebinding a fanout member is refused; drop the new device. */
		if (dev)
			dev_put(dev);

		/*
		 * Fix CVE-2017-15649
		 *  -1 line, +2 lines
		 */
		ret = -EINVAL;
		goto out_unlock;
	}

	/*
	 * Fix CVE-2017-15649
	 *  -3 lines
	 */
	kgr_unregister_prot_hook(sk, true);

	po->num = protocol;
	po->prot_hook.type = protocol;
	/* Release the previously bound device before storing the new one. */
	if (po->prot_hook.dev)
		dev_put(po->prot_hook.dev);

	po->prot_hook.dev = dev;
	po->ifindex = dev ? dev->ifindex : 0;

	kgr_packet_cached_dev_assign(po, dev);

	/* Protocol 0: leave the hook unregistered. */
	if (protocol == 0)
		goto out_unlock;

	if (!dev || (dev->flags & IFF_UP)) {
		kgr_register_prot_hook(sk);
	} else {
		/* Device is down: report it through the socket error. */
		sk->sk_err = ENETDOWN;
		if (!sock_flag(sk, SOCK_DEAD))
			sk->sk_error_report(sk);
	}

out_unlock:
	spin_unlock(&po->bind_lock);
	release_sock(sk);
	/*
	 * Fix CVE-2017-15649
	 *  -1 line, +1 line
	 */
	return ret;
}

/* patched caller of packet_do_bind() */
/*
 * Bind a socket to the device whose name is carried in uaddr->sa_data.
 *
 * @addr_len has to equal sizeof(struct sockaddr), otherwise -EINVAL is
 * returned. The name is copied into a local buffer one byte larger than
 * sa_data and terminated there. If no device of that name exists, -ENODEV
 * is returned; otherwise the result of kgr_packet_do_bind() with the
 * socket's current protocol number is returned.
 */
int kgr_packet_bind_spkt(struct socket *sock, struct sockaddr *uaddr,
			 int addr_len)
{
	struct sock *sk = sock->sk;
	char name[sizeof(uaddr->sa_data) + 1];
	struct net_device *dev;

	if (addr_len != sizeof(struct sockaddr))
		return -EINVAL;

	/*
	 * sa_data was supplied by user space and need not contain a
	 * terminating NUL, so one is appended to the local copy.
	 */
	memcpy(name, uaddr->sa_data, sizeof(uaddr->sa_data));
	name[sizeof(name) - 1] = 0;

	dev = dev_get_by_name(sock_net(sk), name);
	if (!dev)
		return -ENODEV;

	return kgr_packet_do_bind(sk, dev, kgr_pkt_sk(sk)->num);
}

/* patched caller of packet_do_bind() */
int kgr_packet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{
	struct sockaddr_ll *sll = (struct sockaddr_ll *)uaddr;
	struct sock *sk = sock->sk;
	struct net_device *dev = NULL;
	int err;


	/*
	 *	Check legality
	 */

	if (addr_len < sizeof(struct sockaddr_ll))
		return -EINVAL;
	if (sll->sll_family != AF_PACKET)
		return -EINVAL;

	if (sll->sll_ifindex) {
		err = -ENODEV;
		dev = dev_get_by_index(sock_net(sk), sll->sll_ifindex);
		if (dev == NULL)
			goto out;
	}
	err = kgr_packet_do_bind(sk, dev, sll->sll_protocol ? : kgr_pkt_sk(sk)->num);

out:
	return err;
}

/* patched caller of fanout_add() */
/*
 * setsockopt() handler for packet sockets.
 *
 * @sock:    socket the option is set on
 * @level:   must be SOL_PACKET, anything else yields -ENOPROTOOPT
 * @optname: one of the PACKET_* options handled in the switch below
 * @optval:  user space pointer to the option value
 * @optlen:  size of the option value in bytes
 *
 * Every case validates @optlen, copies the value from user space and then
 * either stores it in the packet_sock or hands it to a helper. PACKET_FANOUT
 * is routed to the patched kgr_fanout_add().
 *
 * Returns 0 on success, -ENOPROTOOPT for an unknown level or option,
 * -EINVAL for a bad length or value, -EFAULT if the value cannot be copied
 * from user space, -EBUSY if a ring is already set up for an option that
 * does not allow this, or the error returned by the helper that was called.
 */
int kgr_packet_setsockopt(struct socket *sock, int level, int optname,
			  char __user *optval, unsigned int optlen)
{
	struct sock *sk = sock->sk;
	struct packet_sock *po = kgr_pkt_sk(sk);
	int ret;

	if (level != SOL_PACKET)
		return -ENOPROTOOPT;

	switch (optname) {
	case PACKET_ADD_MEMBERSHIP:
	case PACKET_DROP_MEMBERSHIP:
	{
		struct packet_mreq_max mreq;
		int len = optlen;
		memset(&mreq, 0, sizeof(mreq));
		if (len < sizeof(struct packet_mreq))
			return -EINVAL;
		/* Never copy more than the local request can hold. */
		if (len > sizeof(mreq))
			len = sizeof(mreq);
		if (copy_from_user(&mreq, optval, len))
			return -EFAULT;
		/* The copied data has to cover the announced address length. */
		if (len < (mreq.mr_alen + offsetof(struct packet_mreq, mr_address)))
			return -EINVAL;
		if (optname == PACKET_ADD_MEMBERSHIP)
			ret = kgr_packet_mc_add(sk, &mreq);
		else
			ret = kgr_packet_mc_drop(sk, &mreq);
		return ret;
	}

	case PACKET_RX_RING:
	case PACKET_TX_RING:
	{
		union tpacket_req_u req_u;
		int len;

		/* The request layout depends on the selected TPACKET version. */
		switch (po->tp_version) {
		case TPACKET_V1:
		case TPACKET_V2:
			len = sizeof(req_u.req);
			break;
		case TPACKET_V3:
		default:
			len = sizeof(req_u.req3);
			break;
		}
		if (optlen < len)
			return -EINVAL;
		if (kgr_pkt_sk(sk)->has_vnet_hdr)
			return -EINVAL;
		if (copy_from_user(&req_u.req, optval, len))
			return -EFAULT;
		return kgr_packet_set_ring(sk, &req_u, 0,
			optname == PACKET_TX_RING);
	}
	case PACKET_COPY_THRESH:
	{
		int val;

		if (optlen != sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;

		kgr_pkt_sk(sk)->copy_thresh = val;
		return 0;
	}
	case PACKET_VERSION:
	{
		int val;

		if (optlen != sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;
		switch (val) {
		case TPACKET_V1:
		case TPACKET_V2:
		case TPACKET_V3:
			break;
		default:
			return -EINVAL;
		}
		/* The ring check and the version update share the socket lock. */
		lock_sock(sk);
		if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
			ret = -EBUSY;
		} else {
			po->tp_version = val;
			ret = 0;
		}
		release_sock(sk);
		return ret;
	}
	case PACKET_RESERVE:
	{
		unsigned int val;

		if (optlen != sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;
		if (val > INT_MAX)
			return -EINVAL;
		/* The ring check and the update share the socket lock. */
		lock_sock(sk);
		if (po->rx_ring.pg_vec || po->tx_ring.pg_vec)
			ret = -EBUSY;
		else {
			po->tp_reserve = val;
			ret = 0;
		}
		release_sock(sk);
		return ret;
	}
	case PACKET_LOSS:
	{
		unsigned int val;

		if (optlen != sizeof(val))
			return -EINVAL;
		if (po->rx_ring.pg_vec || po->tx_ring.pg_vec)
			return -EBUSY;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;
		po->tp_loss = !!val;
		return 0;
	}
	case PACKET_AUXDATA:
	{
		int val;

		if (optlen < sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;

		po->auxdata = !!val;
		return 0;
	}
	case PACKET_ORIGDEV:
	{
		int val;

		if (optlen < sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;

		po->origdev = !!val;
		return 0;
	}
	case PACKET_VNET_HDR:
	{
		int val;

		if (sock->type != SOCK_RAW)
			return -EINVAL;
		if (po->rx_ring.pg_vec || po->tx_ring.pg_vec)
			return -EBUSY;
		if (optlen < sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;

		po->has_vnet_hdr = !!val;
		return 0;
	}
	case PACKET_TIMESTAMP:
	{
		int val;

		if (optlen != sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;

		po->tp_tstamp = val;
		return 0;
	}
	case PACKET_FANOUT:
	{
		int val;

		if (optlen != sizeof(val))
			return -EINVAL;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;

		/* Low 16 bits: group id; high 16 bits: type and flags. */
		return kgr_fanout_add(sk, val & 0xffff, val >> 16);
	}
	case PACKET_TX_HAS_OFF:
	{
		unsigned int val;

		if (optlen != sizeof(val))
			return -EINVAL;
		if (po->rx_ring.pg_vec || po->tx_ring.pg_vec)
			return -EBUSY;
		if (copy_from_user(&val, optval, sizeof(val)))
			return -EFAULT;
		po->tp_tx_has_off = !!val;
		return 0;
	}
	default:
		return -ENOPROTOOPT;
	}
}


/*
 * Resolve every symbol listed in kgr_funcs[] and store its address in the
 * pointer variable the table entry refers to.
 *
 * The walk stops at the first symbol that cannot be found: an error naming
 * the symbol is logged and -ENOENT is returned. Pointers resolved before
 * that point keep their new values. Returns 0 once all symbols are resolved.
 */
static int kgr_patch_bsc1064392_kallsyms(void)
{
	int idx = 0;

	while (idx < ARRAY_SIZE(kgr_funcs)) {
		/* mod_find_symname would be nice, but it is not exported */
		unsigned long resolved =
			kallsyms_lookup_name(kgr_funcs[idx].name);

		if (resolved == 0) {
			pr_err("kgraft-patch: symbol %s not resolved\n",
				kgr_funcs[idx].name);
			return -ENOENT;
		}

		*(kgr_funcs[idx].addr) = (void *)resolved;
		idx++;
	}

	return 0;
}

/*
 * Module notifier callback: performs the symbol lookup when the patched
 * module is being loaded.
 *
 * Events other than MODULE_STATE_COMING, and events for modules other than
 * KGR_PATCHED_MODULE, are ignored with a return value of 0. For the patched
 * module the result of kgr_patch_bsc1064392_kallsyms() is returned, and a
 * warning is raised if it is non-zero.
 *
 * NOTE(review): the raw errno value is returned rather than a NOTIFY_* code;
 * confirm how the notifier chain caller interprets it.
 */
static int kgr_patch_bsc1064392_module_notify(struct notifier_block *nb,
	unsigned long action, void *data)
{
	struct module *mod = data;
	int ret;

	if (action != MODULE_STATE_COMING)
		return 0;

	if (strcmp(mod->name, KGR_PATCHED_MODULE) != 0)
		return 0;

	ret = kgr_patch_bsc1064392_kallsyms();
	WARN(ret, "kgraft-patch: delayed kallsyms lookup failed. System is broken and can crash.\n");

	return ret;
}

/*
 * Notifier block registered by kgr_patch_bsc1064392_init() and removed by
 * kgr_patch_bsc1064392_cleanup(). The priority is one above INT_MIN.
 * NOTE(review): presumably chosen so that this callback runs after most
 * other module notifiers - confirm against the notifier chain ordering.
 */
static struct notifier_block kgr_patch_bsc1064392_module_nb = {
	.notifier_call = kgr_patch_bsc1064392_module_notify,
	.priority = INT_MIN+1,
};

/*
 * Set up the patch: resolve the af_packet symbols right away if the module
 * is already loaded, then register the module notifier that handles a later
 * load.
 *
 * Both steps run with module_mutex held. If the immediate lookup fails, the
 * notifier is not registered and the lookup error is returned; otherwise the
 * result of register_module_notifier() is returned.
 */
int kgr_patch_bsc1064392_init(void)
{
	int ret = 0;

	mutex_lock(&module_mutex);

	if (find_module(KGR_PATCHED_MODULE))
		ret = kgr_patch_bsc1064392_kallsyms();

	if (!ret)
		ret = register_module_notifier(&kgr_patch_bsc1064392_module_nb);

	mutex_unlock(&module_mutex);
	return ret;
}

/*
 * Tear down the patch by unregistering the module notifier that
 * kgr_patch_bsc1064392_init() registered.
 */
void kgr_patch_bsc1064392_cleanup(void)
{
	unregister_module_notifier(&kgr_patch_bsc1064392_module_nb);
}
