Grab the lock before reading uid/gid related structure, this will
[ipfw.git] / dummynet / ipfw2_mod.c
index 0765718..ee4eeba 100644 (file)
@@ -490,10 +490,9 @@ linux_lookup(const int proto, const __be32 saddr, const __be16 sport,
 {
        struct sock *sk;
        int ret = -1;   /* default return value */
-       int uid = -1;   /* user id */
        int st = -1;    /* state */
 
-       if (proto != IPPROTO_TCP)
+       if (proto != IPPROTO_TCP)       /* XXX extend for UDP */
                return -1;
 
        if ((dir ? (void *)skb->dst : (void *)skb->dev) == NULL) {
@@ -501,11 +500,17 @@ linux_lookup(const int proto, const __be32 saddr, const __be16 sport,
                return -1;
        }
 
-       /*
-        * inet_lookup above 2.6.24 has an additional 'net' parameter
-        * so we use a macro to conditionally supply it.
-        * Also we need to switch dst and src depending on the direction.
-        */
+       if (skb->sk) {
+               sk = skb->sk;
+       } else {
+               /*
+                * Try a lookup. On a match, sk has a refcount that we must
+                * release on exit (we know it because skb->sk = NULL).
+                *
+                * inet_lookup above 2.6.24 has an additional 'net' parameter
+                * so we use a macro to conditionally supply it.
+                * swap dst and src depending on the direction.
+                */
 #if LINUX_VERSION_CODE <= KERNEL_VERSION(2,6,24)
 #define _OPT_NET_ARG
 #else
@@ -516,35 +521,34 @@ linux_lookup(const int proto, const __be32 saddr, const __be16 sport,
 #define _OPT_NET_ARG dev_net(skb->dev),
 #endif
 #endif
-
-       if (0 && skb->sk) {
-               sk=skb->sk;
-       } else {
-       sk =  (dir) ?
-               inet_lookup(_OPT_NET_ARG &tcp_hashinfo,
-                       daddr, dport, saddr, sport,     // matches outgoing for server sockets
+               sk =  (dir) ? /* dir != 0 on output */
+                   inet_lookup(_OPT_NET_ARG &tcp_hashinfo,
+                       daddr, dport, saddr, sport,     // match outgoing
                        inet_iif(skb)) :
-               inet_lookup(_OPT_NET_ARG &tcp_hashinfo,
-                       saddr, sport, daddr, dport,     // matches incoming for server sockets
+                   inet_lookup(_OPT_NET_ARG &tcp_hashinfo,
+                       saddr, sport, daddr, dport,     // match incoming
                        skb->dev->ifindex);
-       }
-
 #undef _OPT_NET_ARG
-       /* no match, nothing to be done */
-       if (sk == NULL)
-               return -1;
 
+               if (sk == NULL) /* no match, nothing to be done */
+                       return -1;
+       }
+       ret = 1;        /* retrying won't make things better */
+       st = sk->sk_state;
+#ifdef CONFIG_VSERVER
+       ugp->fw_groups[1] = sk->sk_xid;
+       ugp->fw_groups[2] = sk->sk_nid;
+#else
+       ugp->fw_groups[1] = ugp->fw_groups[2] = 0;
+#endif
        /*
-        * On a match, sk is returned with a refcount.
-        * In tcp some states reference a valid struct sock
-        * which is what we want, otherwise the struct sock
-        * referenced can be invalid, as in the case of the
-        * TCP_TIME_WAIT state, when it references a
-        * struct inet_timewait_sock which does not point to credentials.
-        * To be safe we exclude TCP_CLOSE and TCP_LAST_ACK states too.
+        * Exclude tcp states where sk points to a inet_timewait_sock which
+        * has no sk_socket field (surely TCP_TIME_WAIT, perhaps more).
+        * To be safe, use a whitelist and not a blacklist.
+        * Before dereferencing sk_socket grab a lock on sk_callback_lock.
         *
         * Once again we need conditional code because the UID and GID
-        * location changes between the two kernels.
+        * location changes between kernels.
         */
 #if LINUX_VERSION_CODE <= KERNEL_VERSION(2,6,28)
 /* use the current's real uid/gid */
@@ -555,24 +559,37 @@ linux_lookup(const int proto, const __be32 saddr, const __be16 sport,
 #define _CURR_UID f_cred->fsuid
 #define _CURR_GID f_cred->fsgid
 #endif
-       st = sk->sk_state;
-       if (st != TCP_TIME_WAIT && st != TCP_CLOSE && st != TCP_LAST_ACK &&
-                       sk->sk_socket && sk->sk_socket->file) {
-               ugp->fw_uid = sk->sk_socket->file->_CURR_UID;
-               uid = ugp->fw_uid;
-               ugp->fw_groups[0] = sk->sk_socket->file->_CURR_GID;
+
 #ifdef CONFIG_VSERVER
-               ugp->fw_groups[1] = sk->sk_xid;
-               ugp->fw_groups[2] = sk->sk_nid;
+       ugp->fw_groups[1] = sk->sk_xid;
+       ugp->fw_groups[2] = sk->sk_nid;
+#else
+       ugp->fw_groups[1] =
+       ugp->fw_groups[2] = 0;
 #endif
-               ret = 1;
+       ret = 1;
+
+#define GOOD_STATES (  \
+       (1<<TCP_LISTEN) | (1<<TCP_SYN_RECV)   | (1<<TCP_SYN_SENT)   | \
+       (1<<TCP_ESTABLISHED)  | (1<<TCP_FIN_WAIT1) | (1<<TCP_FIN_WAIT2) )
+       // surely exclude TCP_CLOSE, TCP_TIME_WAIT, TCP_LAST_ACK
+       // uncertain TCP_CLOSE_WAIT and TCP_CLOSING
+
+       if ((1<<st) & GOOD_STATES) {
+               read_lock_bh(&sk->sk_callback_lock);
+                       if (sk->sk_socket && sk->sk_socket->file) {
+                               ugp->fw_uid = sk->sk_socket->file->_CURR_UID;
+                               ugp->fw_groups[0] = sk->sk_socket->file->_CURR_GID;
+                       }
+               read_unlock_bh(&sk->sk_callback_lock);
+       } else {
+               ugp->fw_uid = ugp->fw_groups[0] = 0;
        }
-       if (1 || !skb->sk) /* the reference came from the lookup */
+       if (!skb->sk) /* return the reference that came from the lookup */
                sock_put(sk);
+#undef GOOD_STATES
 #undef _CURR_UID
 #undef _CURR_GID
-
-       //printf("%s dir %d sb>dst %p sb>dev %p ret %d id %d st%d\n", __FUNCTION__, dir, skb->dst, skb->dev, ret, uid, st);
        return ret;
 }