Merge git://git.kernel.org/pub/scm/linux/kernel/git/davem/net
[linux-drm-fsl-dcu.git] / net / ipv4 / udp.c
index 0ca44df51ee94a2875427e5c9d3b1c77434f5f40..5944d7d668dd91da21e945eac748bbbbbb11d67a 100644 (file)
 #include <linux/seq_file.h>
 #include <net/net_namespace.h>
 #include <net/icmp.h>
+#include <net/inet_hashtables.h>
 #include <net/route.h>
 #include <net/checksum.h>
 #include <net/xfrm.h>
@@ -219,7 +220,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
                unsigned short first, last;
                DECLARE_BITMAP(bitmap, PORTS_PER_CHAIN);
 
-               inet_get_local_port_range(&low, &high);
+               inet_get_local_port_range(net, &low, &high);
                remaining = (high - low) + 1;
 
                rand = net_random();
@@ -406,6 +407,18 @@ static inline int compute_score2(struct sock *sk, struct net *net,
        return score;
 }
 
+static unsigned int udp_ehashfn(struct net *net, const __be32 laddr,
+                                const __u16 lport, const __be32 faddr,
+                                const __be16 fport)
+{
+       static u32 udp_ehash_secret __read_mostly;
+
+       net_get_random_once(&udp_ehash_secret, sizeof(udp_ehash_secret));
+
+       return __inet_ehashfn(laddr, lport, faddr, fport,
+                             udp_ehash_secret + net_hash_mix(net));
+}
+
 
 /* called with read_rcu_lock() */
 static struct sock *udp4_lib_lookup2(struct net *net,
@@ -429,8 +442,8 @@ begin:
                        badness = score;
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
-                               hash = inet_ehashfn(net, daddr, hnum,
-                                                   saddr, sport);
+                               hash = udp_ehashfn(net, daddr, hnum,
+                                                  saddr, sport);
                                matches = 1;
                        }
                } else if (score == badness && reuseport) {
@@ -510,8 +523,8 @@ begin:
                        badness = score;
                        reuseport = sk->sk_reuseport;
                        if (reuseport) {
-                               hash = inet_ehashfn(net, daddr, hnum,
-                                                   saddr, sport);
+                               hash = udp_ehashfn(net, daddr, hnum,
+                                                  saddr, sport);
                                matches = 1;
                        }
                } else if (score == badness && reuseport) {
@@ -565,6 +578,26 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
 }
 EXPORT_SYMBOL_GPL(udp4_lib_lookup);
 
+static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
+                                      __be16 loc_port, __be32 loc_addr,
+                                      __be16 rmt_port, __be32 rmt_addr,
+                                      int dif, unsigned short hnum)
+{
+       struct inet_sock *inet = inet_sk(sk);
+
+       if (!net_eq(sock_net(sk), net) ||
+           udp_sk(sk)->udp_port_hash != hnum ||
+           (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
+           (inet->inet_dport != rmt_port && inet->inet_dport) ||
+           (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
+           ipv6_only_sock(sk) ||
+           (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif))
+               return false;
+       if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif))
+               return false;
+       return true;
+}
+
 static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
                                             __be16 loc_port, __be32 loc_addr,
                                             __be16 rmt_port, __be32 rmt_addr,
@@ -575,20 +608,11 @@ static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
        unsigned short hnum = ntohs(loc_port);
 
        sk_nulls_for_each_from(s, node) {
-               struct inet_sock *inet = inet_sk(s);
-
-               if (!net_eq(sock_net(s), net) ||
-                   udp_sk(s)->udp_port_hash != hnum ||
-                   (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
-                   (inet->inet_dport != rmt_port && inet->inet_dport) ||
-                   (inet->inet_rcv_saddr &&
-                    inet->inet_rcv_saddr != loc_addr) ||
-                   ipv6_only_sock(s) ||
-                   (s->sk_bound_dev_if && s->sk_bound_dev_if != dif))
-                       continue;
-               if (!ip_mc_sf_allow(s, loc_addr, rmt_addr, dif))
-                       continue;
-               goto found;
+               if (__udp_is_mcast_sock(net, s,
+                                       loc_port, loc_addr,
+                                       rmt_port, rmt_addr,
+                                       dif, hnum))
+                       goto found;
        }
        s = NULL;
 found:
@@ -855,6 +879,8 @@ int udp_sendmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
 
        ipc.opt = NULL;
        ipc.tx_flags = 0;
+       ipc.ttl = 0;
+       ipc.tos = -1;
 
        getfrag = is_udplite ? udplite_getfrag : ip_generic_getfrag;
 
@@ -938,7 +964,7 @@ int udp_sendmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
                faddr = ipc.opt->opt.faddr;
                connected = 0;
        }
-       tos = RT_TOS(inet->tos);
+       tos = get_rttos(&ipc, inet);
        if (sock_flag(sk, SOCK_LOCALROUTE) ||
            (msg->msg_flags & MSG_DONTROUTE) ||
            (ipc.opt && ipc.opt->opt.is_strictroute)) {
@@ -1209,12 +1235,6 @@ int udp_recvmsg(struct kiocb *iocb, struct sock *sk, struct msghdr *msg,
        int is_udplite = IS_UDPLITE(sk);
        bool slow;
 
-       /*
-        *      Check any passed addresses
-        */
-       if (addr_len)
-               *addr_len = sizeof(*sin);
-
        if (flags & MSG_ERRQUEUE)
                return ip_recv_error(sk, msg, len);
 
@@ -1276,6 +1296,7 @@ try_again:
                sin->sin_port = udp_hdr(skb)->source;
                sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
                memset(sin->sin_zero, 0, sizeof(sin->sin_zero));
+               *addr_len = sizeof(*sin);
        }
        if (inet->cmsg_flags)
                ip_cmsg_recv(msg, skb);
@@ -1403,8 +1424,10 @@ static int __udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 {
        int rc;
 
-       if (inet_sk(sk)->inet_daddr)
+       if (inet_sk(sk)->inet_daddr) {
                sock_rps_save_rxhash(sk, skb);
+               sk_mark_napi_id(sk, skb);
+       }
 
        rc = sock_queue_rcv_skb(sk, skb);
        if (rc < 0) {
@@ -1528,7 +1551,7 @@ int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 
        rc = 0;
 
-       ipv4_pktinfo_prepare(skb);
+       ipv4_pktinfo_prepare(sk, skb);
        bh_lock_sock(sk);
        if (!sock_owned_by_user(sk))
                rc = __udp_queue_rcv_skb(sk, skb);
@@ -1577,6 +1600,14 @@ static void flush_stack(struct sock **stack, unsigned int count,
                kfree_skb(skb1);
 }
 
+static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
+{
+       struct dst_entry *dst = skb_dst(skb);
+
+       dst_hold(dst);
+       sk->sk_rx_dst = dst;
+}
+
 /*
  *     Multicasts and broadcasts go to each listener.
  *
@@ -1705,16 +1736,32 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
        if (udp4_csum_init(skb, uh, proto))
                goto csum_error;
 
-       if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
-               return __udp4_lib_mcast_deliver(net, skb, uh,
-                               saddr, daddr, udptable);
+       if (skb->sk) {
+               int ret;
+               sk = skb->sk;
+
+               if (unlikely(sk->sk_rx_dst == NULL))
+                       udp_sk_rx_dst_set(sk, skb);
 
-       sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
+               ret = udp_queue_rcv_skb(sk, skb);
+
+               /* a return value > 0 means to resubmit the input, but
+                * it wants the return to be -protocol, or 0
+                */
+               if (ret > 0)
+                       return -ret;
+               return 0;
+       } else {
+               if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
+                       return __udp4_lib_mcast_deliver(net, skb, uh,
+                                       saddr, daddr, udptable);
+
+               sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
+       }
 
        if (sk != NULL) {
                int ret;
 
-               sk_mark_napi_id(sk, skb);
                ret = udp_queue_rcv_skb(sk, skb);
                sock_put(sk);
 
@@ -1768,6 +1815,135 @@ drop:
        return 0;
 }
 
+/* We can only early demux multicast if there is a single matching socket.
+ * If more than one socket found returns NULL
+ */
+static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
+                                                 __be16 loc_port, __be32 loc_addr,
+                                                 __be16 rmt_port, __be32 rmt_addr,
+                                                 int dif)
+{
+       struct sock *sk, *result;
+       struct hlist_nulls_node *node;
+       unsigned short hnum = ntohs(loc_port);
+       unsigned int count, slot = udp_hashfn(net, hnum, udp_table.mask);
+       struct udp_hslot *hslot = &udp_table.hash[slot];
+
+       rcu_read_lock();
+begin:
+       count = 0;
+       result = NULL;
+       sk_nulls_for_each_rcu(sk, node, &hslot->head) {
+               if (__udp_is_mcast_sock(net, sk,
+                                       loc_port, loc_addr,
+                                       rmt_port, rmt_addr,
+                                       dif, hnum)) {
+                       result = sk;
+                       ++count;
+               }
+       }
+       /*
+        * if the nulls value we got at the end of this lookup is
+        * not the expected one, we must restart lookup.
+        * We probably met an item that was moved to another chain.
+        */
+       if (get_nulls_value(node) != slot)
+               goto begin;
+
+       if (result) {
+               if (count != 1 ||
+                   unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
+                       result = NULL;
+               else if (unlikely(!__udp_is_mcast_sock(net, result,
+                                                      loc_port, loc_addr,
+                                                      rmt_port, rmt_addr,
+                                                      dif, hnum))) {
+                       sock_put(result);
+                       result = NULL;
+               }
+       }
+       rcu_read_unlock();
+       return result;
+}
+
+/* For unicast we should only early demux connected sockets or we can
+ * break forwarding setups.  The chains here can be long so only check
+ * if the first socket is an exact match and if not move on.
+ */
+static struct sock *__udp4_lib_demux_lookup(struct net *net,
+                                           __be16 loc_port, __be32 loc_addr,
+                                           __be16 rmt_port, __be32 rmt_addr,
+                                           int dif)
+{
+       struct sock *sk, *result;
+       struct hlist_nulls_node *node;
+       unsigned short hnum = ntohs(loc_port);
+       unsigned int hash2 = udp4_portaddr_hash(net, loc_addr, hnum);
+       unsigned int slot2 = hash2 & udp_table.mask;
+       struct udp_hslot *hslot2 = &udp_table.hash2[slot2];
+       INET_ADDR_COOKIE(acookie, rmt_addr, loc_addr)
+       const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum);
+
+       rcu_read_lock();
+       result = NULL;
+       udp_portaddr_for_each_entry_rcu(sk, node, &hslot2->head) {
+               if (INET_MATCH(sk, net, acookie,
+                              rmt_addr, loc_addr, ports, dif))
+                       result = sk;
+               /* Only check first socket in chain */
+               break;
+       }
+
+       if (result) {
+               if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
+                       result = NULL;
+               else if (unlikely(!INET_MATCH(sk, net, acookie,
+                                             rmt_addr, loc_addr,
+                                             ports, dif))) {
+                       sock_put(result);
+                       result = NULL;
+               }
+       }
+       rcu_read_unlock();
+       return result;
+}
+
+void udp_v4_early_demux(struct sk_buff *skb)
+{
+       const struct iphdr *iph = ip_hdr(skb);
+       const struct udphdr *uh = udp_hdr(skb);
+       struct sock *sk;
+       struct dst_entry *dst;
+       struct net *net = dev_net(skb->dev);
+       int dif = skb->dev->ifindex;
+
+       /* validate the packet */
+       if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr)))
+               return;
+
+       if (skb->pkt_type == PACKET_BROADCAST ||
+           skb->pkt_type == PACKET_MULTICAST)
+               sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
+                                                  uh->source, iph->saddr, dif);
+       else if (skb->pkt_type == PACKET_HOST)
+               sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr,
+                                            uh->source, iph->saddr, dif);
+       else
+               return;
+
+       if (!sk)
+               return;
+
+       skb->sk = sk;
+       skb->destructor = sock_edemux;
+       dst = sk->sk_rx_dst;
+
+       if (dst)
+               dst = dst_check(dst, 0);
+       if (dst)
+               skb_dst_set_noref(skb, dst);
+}
+
 int udp_rcv(struct sk_buff *skb)
 {
        return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP);
@@ -2150,7 +2326,7 @@ EXPORT_SYMBOL(udp_proc_unregister);
 
 /* ------------------------------------------------------------------------ */
 static void udp4_format_sock(struct sock *sp, struct seq_file *f,
-               int bucket, int *len)
+               int bucket)
 {
        struct inet_sock *inet = inet_sk(sp);
        __be32 dest = inet->inet_daddr;
@@ -2159,7 +2335,7 @@ static void udp4_format_sock(struct sock *sp, struct seq_file *f,
        __u16 srcp        = ntohs(inet->inet_sport);
 
        seq_printf(f, "%5d: %08X:%04X %08X:%04X"
-               " %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %d%n",
+               " %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %d",
                bucket, src, srcp, dest, destp, sp->sk_state,
                sk_wmem_alloc_get(sp),
                sk_rmem_alloc_get(sp),
@@ -2167,23 +2343,22 @@ static void udp4_format_sock(struct sock *sp, struct seq_file *f,
                from_kuid_munged(seq_user_ns(f), sock_i_uid(sp)),
                0, sock_i_ino(sp),
                atomic_read(&sp->sk_refcnt), sp,
-               atomic_read(&sp->sk_drops), len);
+               atomic_read(&sp->sk_drops));
 }
 
 int udp4_seq_show(struct seq_file *seq, void *v)
 {
+       seq_setwidth(seq, 127);
        if (v == SEQ_START_TOKEN)
-               seq_printf(seq, "%-127s\n",
-                          "  sl  local_address rem_address   st tx_queue "
+               seq_puts(seq, "  sl  local_address rem_address   st tx_queue "
                           "rx_queue tr tm->when retrnsmt   uid  timeout "
                           "inode ref pointer drops");
        else {
                struct udp_iter_state *state = seq->private;
-               int len;
 
-               udp4_format_sock(v, seq, state->bucket, &len);
-               seq_printf(seq, "%*s\n", 127 - len, "");
+               udp4_format_sock(v, seq, state->bucket);
        }
+       seq_pad(seq, '\n');
        return 0;
 }