Added the new version for dummynet.
[ipfw.git] / dummynet / ip_fw2.c
index 50e8701..21d1b41 100644 (file)
@@ -70,11 +70,6 @@ __FBSDID("$FreeBSD: src/sys/netinet/ip_fw2.c,v 1.175.2.13 2008/10/30 16:29:04 bz
 #include <net/pf_mtag.h>
 #include <net/vnet.h>
 
-#ifdef linux
-#define INP_LOCK_ASSERT                /* define before missing.h otherwise ? */
-#include "missing.h"
-#endif
-
 #define        IPFW_INTERNAL   /* Access to protected data structures in ip_fw.h. */
 
 #include <netinet/in.h>
@@ -135,21 +130,6 @@ static int default_to_accept;
 #endif
 static uma_zone_t ipfw_dyn_rule_zone;
 
-struct ip_fw *ip_fw_default_rule;
-
-/*
- * Data structure to cache our ucred related
- * information. This structure only gets used if
- * the user specified UID/GID based constraints in
- * a firewall rule.
- */
-struct ip_fw_ugid {
-       gid_t           fw_groups[NGROUPS];
-       int             fw_ngroups;
-       uid_t           fw_uid;
-       int             fw_prid;
-};
-
 /*
  * list of rules for layer 3
  */
@@ -194,11 +174,13 @@ SYSCTL_VNET_INT(_net_inet_ip_fw, OID_AUTO, verbose,
 SYSCTL_VNET_INT(_net_inet_ip_fw, OID_AUTO, verbose_limit,
     CTLFLAG_RW, &VNET_NAME(verbose_limit), 0,
     "Set upper limit of matches of ipfw rules logged");
+unsigned int dummy_default_rule = IPFW_DEFAULT_RULE;
 SYSCTL_UINT(_net_inet_ip_fw, OID_AUTO, default_rule, CTLFLAG_RD,
-    NULL, IPFW_DEFAULT_RULE,
+    &dummy_default_rule, IPFW_DEFAULT_RULE,
     "The default/max possible rule number.");
+unsigned int dummy_tables_max = IPFW_TABLES_MAX;
 SYSCTL_UINT(_net_inet_ip_fw, OID_AUTO, tables_max, CTLFLAG_RD,
-    NULL, IPFW_TABLES_MAX,
+    &dummy_tables_max, IPFW_TABLES_MAX,
     "The maximum number of tables.");
 SYSCTL_INT(_net_inet_ip_fw, OID_AUTO, default_to_accept, CTLFLAG_RDTUN,
     &default_to_accept, 0,
@@ -218,9 +200,6 @@ SYSCTL_VNET_INT(_net_inet6_ip6_fw, OID_AUTO, deny_unknown_exthdrs,
 
 #endif /* SYSCTL_NODE */
 
-#ifndef IPFW_NEWTABLES_MAX
-#define IPFW_NEWTABLES_MAX     256
-#endif
 /*
  * Description of dynamic rules.
  *
@@ -1915,7 +1894,7 @@ lookup_next_rule(struct ip_fw *me, u_int32_t tablearg)
        struct ip_fw *rule = NULL;
        ipfw_insn *cmd;
        u_int16_t       rulenum;
-printf("%s called\n", __FUNCTION__);
+
        /* look for action, in case it is a skipto */
        cmd = ACTION_PTR(me);
        if (cmd->opcode == O_LOG)
@@ -1936,44 +1915,12 @@ printf("%s called\n", __FUNCTION__);
                        }
                }
        }
-       if (rule == NULL)                       /* failure or not a skipto */
+       if (rule == NULL)               /* failure or not a skipto */
                rule = me->next;
        me->next_rule = rule;
        return rule;
 }
 
-#ifdef IPFW_HAVE_SKIPTO_TABLE
-struct ip_fw *lookup_skipto_table(struct ip_fw_chain *chain, uint16_t num);
-
-struct ip_fw *
-lookup_skipto_table(struct ip_fw_chain *chain, uint16_t num)
-{
-       struct ip_fw *f;
-
-       printf("--%s called\n", __FUNCTION__);
-       if (1)
-               return NULL;
-       if (chain->skipto_pointers[num].id == chain->id) {
-               printf("-- %s pointer ok, return it\n", __FUNCTION__);
-               return chain->skipto_pointers[num].rule;
-       }
-       printf("-- %s search pointer\n", __FUNCTION__);
-
-       for (f = chain->rules; f ; f = f->next) {
-               if (f->rulenum == num) {
-                       chain->skipto_pointers[num].id = chain->id;
-                       chain->skipto_pointers[num].rule = f;
-                       printf("-- %s found, set and return\n", __FUNCTION__);
-                       return f;
-               }
-       }
-       printf("-- %s NOT found return NULL\n", __FUNCTION__);
-
-       return NULL;
-}
-#endif /* IPFW_HAVE_SKIPTO_TABLE */
-
-#ifdef radix
 static int
 add_table_entry(struct ip_fw_chain *ch, uint16_t tbl, in_addr_t addr,
     uint8_t mlen, uint32_t value)
@@ -1989,7 +1936,16 @@ add_table_entry(struct ip_fw_chain *ch, uint16_t tbl, in_addr_t addr,
        if (ent == NULL)
                return (ENOMEM);
        ent->value = value;
+#ifdef linux
+       /* there is no sin_len on linux, and the code assumes the first
+        * byte in the sockaddr to contain the length in bits.
+        * So we just dump the number right there
+        */
+       *((uint8_t *)&(ent->addr)) = 8;
+       *((uint8_t *)&(ent->mask)) = 8;
+#else
        ent->addr.sin_len = ent->mask.sin_len = 8;
+#endif
        ent->mask.sin_addr.s_addr = htonl(mlen ? ~((1 << (32 - mlen)) - 1) : 0);
        ent->addr.sin_addr.s_addr = addr & ent->mask.sin_addr.s_addr;
        IPFW_WLOCK(ch);
@@ -2014,7 +1970,13 @@ del_table_entry(struct ip_fw_chain *ch, uint16_t tbl, in_addr_t addr,
        if (tbl >= IPFW_TABLES_MAX)
                return (EINVAL);
        rnh = ch->tables[tbl];
+#ifdef linux
+       /* there is no sin_len on linux, see above */
+       *((uint8_t *)&sa) = 8;
+       *((uint8_t *)&mask) = 8;
+#else
        sa.sin_len = mask.sin_len = 8;
+#endif
        mask.sin_addr.s_addr = htonl(mlen ? ~((1 << (32 - mlen)) - 1) : 0);
        sa.sin_addr.s_addr = addr & mask.sin_addr.s_addr;
        IPFW_WLOCK(ch);
@@ -2055,18 +2017,6 @@ flush_table(struct ip_fw_chain *ch, uint16_t tbl)
        rnh->rnh_walktree(rnh, flush_table_entry, rnh);
        return (0);
 }
-#else
-extern int add_table_entry(struct ip_fw_chain *ch, uint16_t tbl,
-    in_addr_t addr, uint8_t mlen, uint32_t value);
-extern int del_table_entry(struct ip_fw_chain *ch, uint16_t tbl,
-    in_addr_t addr, uint8_t mlen);
-extern int flush_table(struct ip_fw_chain *ch, uint16_t tbl);
-extern int count_table(struct ip_fw_chain *ch, uint32_t tbl, uint32_t *cnt);
-extern int dump_table(struct ip_fw_chain *ch, ipfw_table *tbl);
-extern int lookup_table(struct ip_fw_chain *ch, uint16_t tbl, in_addr_t addr,
-    uint32_t *val);
-extern int init_tables(struct ip_fw_chain *ch);
-#endif
 
 static void
 flush_tables(struct ip_fw_chain *ch)
@@ -2075,11 +2025,10 @@ flush_tables(struct ip_fw_chain *ch)
 
        IPFW_WLOCK_ASSERT(ch);
 
-       for (tbl = IPFW_TABLES_MAX -1; tbl < IPFW_NEWTABLES_MAX; tbl++)
+       for (tbl = 0; tbl < IPFW_TABLES_MAX; tbl++)
                flush_table(ch, tbl);
 }
 
-#ifdef radix
 static int
 init_tables(struct ip_fw_chain *ch)
 { 
@@ -2108,7 +2057,12 @@ lookup_table(struct ip_fw_chain *ch, uint16_t tbl, in_addr_t addr,
        if (tbl >= IPFW_TABLES_MAX)
                return (0);
        rnh = ch->tables[tbl];
+#ifdef linux
+       /* there is no sin_len on linux, see above */
+       *((uint8_t *)&sa) = 8;
+#else
        sa.sin_len = 8;
+#endif
        sa.sin_addr.s_addr = addr;
        ent = (struct table_entry *)(rnh->rnh_lookup(&sa, NULL, rnh));
        if (ent != NULL) {
@@ -2117,9 +2071,7 @@ lookup_table(struct ip_fw_chain *ch, uint16_t tbl, in_addr_t addr,
        }
        return (0);
 }
-#endif
 
-#ifdef radix
 static int
 count_table_entry(struct radix_node *rn, void *arg)
 {
@@ -2175,52 +2127,37 @@ dump_table(struct ip_fw_chain *ch, ipfw_table *tbl)
        rnh->rnh_walktree(rnh, dump_table_entry, tbl);
        return (0);
 }
-#endif
-
-#ifndef linux /* FreeBSD */
-static void
-fill_ugid_cache(struct inpcb *inp, struct ip_fw_ugid *ugp)
-{
-       struct ucred *cr;
-
-       cr = inp->inp_cred;
-       ugp->fw_prid = jailed(cr) ? cr->cr_prison->pr_id : -1;
-       ugp->fw_uid = cr->cr_uid;
-       ugp->fw_ngroups = cr->cr_ngroups;
-       bcopy(cr->cr_groups, ugp->fw_groups, sizeof(ugp->fw_groups));
-}
-#endif
 
 static int
 check_uidgid(ipfw_insn_u32 *insn, int proto, struct ifnet *oif,
     struct in_addr dst_ip, u_int16_t dst_port, struct in_addr src_ip,
-    u_int16_t src_port, struct ip_fw_ugid *ugp, int *ugid_lookupp,
+    u_int16_t src_port, struct ucred **uc, int *ugid_lookup,
     struct inpcb *inp)
 {
 #ifdef linux
        int match = 0;
        struct sk_buff *skb = ((struct mbuf *)inp)->m_skb;
+       struct bsd_ucred *u = (struct bsd_ucred *)uc;
 
-       if (*ugid_lookupp == 0) {       /* actively lookup and copy in cache */
-
+       if (*ugid_lookup == 0) {        /* actively lookup and copy in cache */
                /* returns null if any element of the chain up to file is null.
                 * if sk != NULL then we also have a reference 
                 */
-               *ugid_lookupp = linux_lookup(proto,
+               *ugid_lookup = linux_lookup(proto,
                        src_ip.s_addr, htons(src_port),
                        dst_ip.s_addr, htons(dst_port),
-                       skb, oif ? 1 : 0, ugp);
+                       skb, oif ? 1 : 0, u);
 
        }
-       if (*ugid_lookupp < 0)
+       if (*ugid_lookup < 0)
                return 0;
 
        if (insn->o.opcode == O_UID)
-               match = (ugp->fw_uid == (uid_t)insn->d[0]);
+               match = (u->uid == (uid_t)insn->d[0]);
        else if (insn->o.opcode == O_JAIL)
-               match = (ugp->fw_groups[1] == (uid_t)insn->d[0]);
+               match = (u->xid == (uid_t)insn->d[0]);
        else if (insn->o.opcode == O_GID)
-               match = (ugp->fw_groups[0] == (uid_t)insn->d[0]);
+               match = (u->gid == (uid_t)insn->d[0]);
 
        return match;
 
@@ -2230,7 +2167,6 @@ check_uidgid(ipfw_insn_u32 *insn, int proto, struct ifnet *oif,
        int wildcard;
        struct inpcb *pcb;
        int match;
-       gid_t *gp;
 
        /*
         * Check to see if the UDP or TCP stack supplied us with
@@ -2240,7 +2176,7 @@ check_uidgid(ipfw_insn_u32 *insn, int proto, struct ifnet *oif,
        if (inp && *ugid_lookupp == 0) {
                INP_LOCK_ASSERT(inp);
                if (inp->inp_socket != NULL) {
-                       fill_ugid_cache(inp, ugp);
+                       *uc = crhold(inp->inp_cred);
                        *ugid_lookupp = 1;
                } else
                        *ugid_lookupp = -1;
@@ -2273,7 +2209,7 @@ check_uidgid(ipfw_insn_u32 *insn, int proto, struct ifnet *oif,
                                dst_ip, htons(dst_port),
                                wildcard, NULL);
                if (pcb != NULL) {
-                       fill_ugid_cache(pcb, ugp);
+                       *uc = crhold(pcb->inp_cred);
                        *ugid_lookupp = 1;
                }
                INP_INFO_RUNLOCK(pi);
@@ -2289,16 +2225,11 @@ check_uidgid(ipfw_insn_u32 *insn, int proto, struct ifnet *oif,
                }
        } 
        if (insn->o.opcode == O_UID)
-               match = (ugp->fw_uid == (uid_t)insn->d[0]);
-       else if (insn->o.opcode == O_GID) {
-               for (gp = ugp->fw_groups;
-                       gp < &ugp->fw_groups[ugp->fw_ngroups]; gp++)
-                       if (*gp == (gid_t)insn->d[0]) {
-                               match = 1;
-                               break;
-                       }
-       } else if (insn->o.opcode == O_JAIL)
-               match = (ugp->fw_prid == (int)insn->d[0]);
+               match = ((*uc)->cr_uid == (uid_t)insn->d[0]);
+       else if (insn->o.opcode == O_GID)
+               match = groupmember((gid_t)insn->d[0], *uc);
+       else if (insn->o.opcode == O_JAIL)
+               match = ((*uc)->cr_prison->pr_id == (int)insn->d[0]);
        return match;
 #endif
 }
@@ -2375,8 +2306,8 @@ ipfw_chk(struct ip_fw_args *args)
         * these types of constraints, as well as decrease contention
         * on pcb related locks.
         */
-       struct ip_fw_ugid fw_ugid_cache;
-       int ugid_lookup = 0;
+       struct bsd_ucred ucred_cache;
+       int ucred_lookup = 0;
 
        /*
         * divinput_flags       If non-zero, set to the IP_FW_DIVERT_*_FLAG
@@ -2735,8 +2666,17 @@ do {                                                                     \
                        IPFW_RUNLOCK(chain);
                        return (IP_FW_PASS);
                }
+               if (chain->id != args->chain_id) {
+                       for (f = chain->rules; f != NULL; f = f->next)
+                               if (f == args->rule && f->id == args->rule_id)
+                                       break;
 
-               f = args->rule->next_rule;
+                       if (f != NULL)
+                               f = f->next_rule;
+                       else
+                               f = chain->default_rule;
+               } else
+                       f = args->rule->next_rule;
 
                if (f == NULL)
                        f = lookup_next_rule(args->rule, 0);
@@ -2753,12 +2693,9 @@ do {                                                                     \
                                IPFW_RUNLOCK(chain);
                                return (IP_FW_DENY); /* invalid */
                        }
+//                     f = rule2ptr(chain, skipto+1);
                        while (f && f->rulenum <= skipto)
                                f = f->next;
-                       if (f == NULL) {        /* drop packet */
-                               IPFW_RUNLOCK(chain);
-                               return (IP_FW_DENY);
-                       }
                }
        }
        /* reset divert rule to avoid confusion later */
@@ -2857,8 +2794,8 @@ do {                                                                      \
                                                    (ipfw_insn_u32 *)cmd,
                                                    proto, oif,
                                                    dst_ip, dst_port,
-                                                   src_ip, src_port, &fw_ugid_cache,
-                                                   &ugid_lookup, (struct inpcb *)args->m);
+                                                   src_ip, src_port, (struct ucred **)&ucred_cache,
+                                                   &ucred_lookup, (struct inpcb *)args->m);
                                break;
 
                        case O_RECV:
@@ -2959,19 +2896,24 @@ do {                                                                    \
                                            a = dst_port;
                                        else if (v == 3)
                                            a = src_port;
-                                       else if (v >= 4 && v <= 6) {
+                                       else if (v == 4 || v == 5) {
                                            check_uidgid(
                                                    (ipfw_insn_u32 *)cmd,
                                                    proto, oif,
                                                    dst_ip, dst_port,
-                                                   src_ip, src_port, &fw_ugid_cache,
-                                                   &ugid_lookup, (struct inpcb *)args->m);
+                                                   src_ip, src_port, (struct ucred **)&ucred_cache,
+                                                   &ucred_lookup, (struct inpcb *)args->m);
+#ifdef linux
                                            if (v ==4 /* O_UID */)
-                                               a = fw_ugid_cache.fw_uid;
-                                           else if (v == 5 /* O_GID */)
-                                               a = fw_ugid_cache.fw_groups[0];
-                                           else if (v == 6 /* O_JAIL */)
-                                               a = fw_ugid_cache.fw_groups[1];
+                                               a = ucred_cache.uid;
+                                           else if (v == 5 /* O_JAIL */)
+                                               a = ucred_cache.xid;
+#else
+                                           if (v ==4 /* O_UID */)
+                                               a = (*uc)->cr_uid;
+                                           else if (v == 5 /* O_JAIL */)
+                                               a = (*uc)->cr_prison->pr_id;
+#endif
                                        } else
                                            break;
                                    }
@@ -3555,37 +3497,13 @@ do {                                                                    \
                                        break;
                                }
                                /* handle skipto */
-#ifdef IPFW_HAVE_SKIPTO_TABLE
-                               /* NOTE: lookup_skipto_table can return NULL
-                                *       if the rule isn't found, so the
-                                *       standard lookup function must be
-                                *       called XXX
-                                */
-                               if (cmd->arg1 == IP_FW_TABLEARG) {
-                                       f = lookup_skipto_table(chain,
-                                                                tablearg);
-                                       if (f == NULL)
-                                               f = lookup_next_rule(f, tablearg);
-                               }
-                                else {
-                                       f = lookup_skipto_table(chain,
-                                                                cmd->arg1);
-                                       if (f == NULL) {
-                                               if (f->next_rule == NULL)
-                                                       lookup_next_rule(f, 0);
-                                               f = f->next_rule;
-                                       }
-                                }
-
-#else
                                if (cmd->arg1 == IP_FW_TABLEARG) {
                                        f = lookup_next_rule(f, tablearg);
-                               } else {
+                               } else { // XXX ?
                                        if (f->next_rule == NULL)
                                                lookup_next_rule(f, 0);
                                        f = f->next_rule;
                                }
-#endif
                                /*
                                 * Skip disabled rules, and
                                 * re-enter the inner loop
@@ -3809,6 +3727,10 @@ do {                                                                     \
                printf("ipfw: ouch!, skip past end of rules, denying packet\n");
        }
        IPFW_RUNLOCK(chain);
+#ifdef __FreeBSD__
+       if (ucred_cache != NULL)
+               crfree(ucred_cache);
+#endif
        return (retval);
 
 pullup_failed:
@@ -3953,12 +3875,6 @@ remove_rule(struct ip_fw_chain *chain, struct ip_fw *rule,
 }
 
 /*
- * Hook for cleaning up dummynet when an ipfw rule is deleted.
- * Set/cleared when dummynet module is loaded/unloaded.
- */
-void   (*ip_dn_ruledel_ptr)(void *) = NULL;
-
-/**
  * Reclaim storage associated with a list of rules.  This is
  * typically the list created using remove_rule.
  * A NULL pointer on input is handled correctly.
@@ -3970,8 +3886,6 @@ reap_rules(struct ip_fw *head)
 
        while ((rule = head) != NULL) {
                head = head->next;
-               if (ip_dn_ruledel_ptr)
-                       ip_dn_ruledel_ptr(rule);
                free(rule, M_IPFW);
        }
 }
@@ -3988,6 +3902,7 @@ free_chain(struct ip_fw_chain *chain, int kill_default)
 
        IPFW_WLOCK_ASSERT(chain);
 
+       chain->reap = NULL;
        flush_rule_ptrs(chain); /* more efficient to do outside the loop */
        for (prev = NULL, rule = chain->rules; rule ; )
                if (kill_default || rule->set != RESVD_SET)
@@ -4115,10 +4030,8 @@ del_entry(struct ip_fw_chain *chain, u_int32_t arg)
         * avoid a LOR with dummynet.
         */
        rule = chain->reap;
-       chain->reap = NULL;
        IPFW_WUNLOCK(chain);
-       if (rule)
-               reap_rules(rule);
+       reap_rules(rule);
        return 0;
 }
 
@@ -4531,7 +4444,7 @@ ipfw_getrules(struct ip_fw_chain *chain, void *buf, size_t space)
        int i;
        time_t  boot_seconds;
 
-        boot_seconds = boottime.tv_sec;
+       boot_seconds = boottime.tv_sec;
        /* XXX this can take a long time and locking will block packet flow */
        IPFW_RLOCK(chain);
        for (rule = chain->rules; rule ; rule = rule->next) {
@@ -4619,7 +4532,6 @@ ipfw_getdynrules(struct ip_fw_chain *chain, void *buf, size_t space)
                if (last != NULL) /* mark last dynamic rule */
                        bzero(&last->next, sizeof(last));
        }
-
        return (bp - (char *)buf);
 }
 
@@ -4706,13 +4618,10 @@ ipfw_ctl(struct sockopt *sopt)
                 */
 
                IPFW_WLOCK(&V_layer3_chain);
-               V_layer3_chain.reap = NULL;
                free_chain(&V_layer3_chain, 0 /* keep default rule */);
                rule = V_layer3_chain.reap;
-               V_layer3_chain.reap = NULL;
                IPFW_WUNLOCK(&V_layer3_chain);
-               if (rule != NULL)
-                       reap_rules(rule);
+               reap_rules(rule);
                break;
 
        case IP_FW_ADD:
@@ -4905,14 +4814,6 @@ ipfw_ctl(struct sockopt *sopt)
 #undef RULE_MAXSIZE
 }
 
-/**
- * dummynet needs a reference to the default rule, because rules can be
- * deleted while packets hold a reference to them. When this happens,
- * dummynet changes the reference to the default rule (it could well be a
- * NULL pointer, but this way we do not need to check for the special
- * case, plus here he have info on the default behaviour).
- */
-//struct ip_fw *ip_fw_default_rule;
 
 /*
  * This procedure is only used to handle keepalives. It is invoked
@@ -5010,7 +4911,7 @@ ipfw_tick(void * vnetx)
 #endif
 done:
        callout_reset(&V_ipfw_timeout, V_dyn_keepalive_period*hz,
-               ipfw_tick, NULL);
+               ipfw_tick, vnetx);
        CURVNET_RESTORE();
 }
 
@@ -5096,17 +4997,12 @@ ipfw_destroy(void)
        IPFW_WUNLOCK(&V_layer3_chain);
        if (reap != NULL)
                reap_rules(reap);
-       IPFW_DYN_LOCK_DESTROY();
        uma_zdestroy(ipfw_dyn_rule_zone);
+       IPFW_DYN_LOCK_DESTROY();
        if (V_ipfw_dyn_v != NULL)
                free(V_ipfw_dyn_v, M_IPFW);
        IPFW_LOCK_DESTROY(&V_layer3_chain);
 
-#ifdef INET6
-       /* Free IPv6 fw sysctl tree. */
-       sysctl_ctx_free(&ip6_fw_sysctl_ctx);
-#endif
-
        printf("IP firewall unloaded\n");
 }
 
@@ -5132,12 +5028,6 @@ vnet_ipfw_init(const void *unused)
        if (error) {
                panic("init_tables"); /* XXX Marko fix this ! */
        }
-
-#ifdef IPFW_HAVE_SKIPTO_TABLE
-//     for (error = 0; error < 64*1024; error++)
-//             V_layer3_chain.skipto_pointers[error].id = -1;
-#endif /* IPFW_HAVE_SKIPTO_TABLE */
-
 #ifdef IPFIREWALL_NAT
        LIST_INIT(&V_layer3_chain.nat);
 #endif
@@ -5184,7 +5074,7 @@ vnet_ipfw_init(const void *unused)
                return (error);
        }
 
-       ip_fw_default_rule = V_layer3_chain.rules;
+       V_layer3_chain.default_rule = V_layer3_chain.rules;
 
        /* curvnet is NULL in the !VIMAGE case */
        callout_reset(&V_ipfw_timeout, hz, ipfw_tick, curvnet);