This commit was generated by cvs2svn to compensate for changes in r759,
[codemux.git] / vnet_main.c
1 /*
2  * VServer IP isolation.
3  *
4  * This file implements netfilter hooks and AF_INET socket function
5  * overrides.
6  *
7  * Mark Huang <mlhuang@cs.princeton.edu>
8  * Copyright (C) 2004 The Trustees of Princeton University
9  *
10  * $Id: vnet_main.c,v 1.40 2007/03/08 15:46:07 mef Exp $
11  */
12
13 #include <linux/version.h>
14 #include <linux/types.h>
15 #include <linux/module.h>
16 #include <linux/ip.h>
17 #include <linux/netfilter.h>
18 #include <linux/netfilter_ipv4.h>
19 #include <linux/pkt_sched.h>
20 #include <linux/skbuff.h>
21 #include <linux/tcp.h>
22 #include <linux/udp.h>
23 #include <linux/icmp.h>
24 #include <linux/slab.h>
25 #include <net/sock.h>
26 #include <net/route.h>
27 #include <net/tcp.h>
28
29 #include <linux/netfilter_ipv4/ip_conntrack.h>
30 #include <linux/netfilter_ipv4/ip_conntrack_protocol.h>
31 #include <linux/netfilter_ipv4/ip_conntrack_core.h>
32 #include <linux/netfilter_ipv4/ip_tables.h>
33
34 #include "vnet_config.h"
35 #include "vnet.h"
36 #include "vnet_dbg.h"
37 #include "vnet_compat.h"
38
39 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,16)
40
41 #define HAVE_FUNCTIONALITY_REQUIRED_BY_DEMUX
42
43 #include <net/inet_hashtables.h>
44
/* Release a reference on a TIME_WAIT pseudo-socket (2.6.16+ API). */
static inline void
vnet_timewait_put(struct sock *sk)
{
	struct inet_timewait_sock *tw = (struct inet_timewait_sock *)sk;

	inet_twsk_put(tw);
}
50
51 static inline struct sock* 
52 vnet_tcp_lookup(u32 src_ip, u16 src_port, 
53                 u32 ip, u16 port, int dif)
54 {
55   return inet_lookup(&tcp_hashinfo, src_ip, src_port, ip, port, dif);
56 }
57
/* Input interface index of a received packet (2.6.16+ provides inet_iif()). */
static inline int
vnet_iif(const struct sk_buff *skb)
{
	return inet_iif(skb);
}
62 #endif
63
64 #if LINUX_VERSION_CODE == KERNEL_VERSION(2,6,12)
65
66 #define HAVE_FUNCTIONALITY_REQUIRED_BY_DEMUX
67
/* Release a reference on a TIME_WAIT bucket (2.6.12 API, net/tcp.h). */
static inline void
vnet_timewait_put(struct sock *sk)
{
	struct tcp_tw_bucket *tw = (struct tcp_tw_bucket *)sk;

	tcp_tw_put(tw);
}
74
75 static inline struct sock* 
76 vnet_tcp_lookup(u32 saddr, u16 sport, u32 daddr,u16 dport, int dif)
77 {
78   extern struct sock *tcp_v4_lookup(u32, u16, u32, u16, int);
79   return tcp_v4_lookup(saddr, sport, daddr, dport, dif);
80 }
81
82 /* same as tcp_v4_iff() in net/ipv4/tcp_ipv4. */
83 static inline int vnet_iif(const struct sk_buff *skb)
84 {
85         return ((struct rtable *)skb->dst)->rt_iif;
86 }
87 #endif
88
89 #ifndef HAVE_FUNCTIONALITY_REQUIRED_BY_DEMUX
90 #warning DEMUX FUNCTIONALITY NOT SUPPORTED
91 #endif
92
/* Debug verbosity knob read by the dbg() machinery; 0 disables chatter. */
int vnet_verbose = 1;

/* We subdivide the 1: major class into 15 minor subclasses 1:1, 1:2,
 * etc. so that we can represent multiple bandwidth limits. The 1:1
 * subclass has children named 1:1000, 1:1001, etc., one for each
 * context (up to 4096). Similarly, the 1:2 subclass has children
 * named 1:2000, 1:2001, etc. By default, the 1:1 subclass represents
 * the node bandwidth cap and 1:1000 represents the root context's
 * share of it. */
int vnet_root_class = TC_H_MAKE(1 << 16, 0x1000);
103
/* Netfilter chains this module's "vnet" table attaches to. */
#define FILTER_VALID_HOOKS ((1 << NF_IP_LOCAL_IN) | \
                            (1 << NF_IP_LOCAL_OUT) | \
                            (1 << NF_IP_POST_ROUTING))

#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,11)

/* Compatibility: kernels before 2.6.11 do not define these iptables
 * entry layouts in <linux/netfilter_ipv4/ip_tables.h>, so provide
 * byte-compatible local definitions. */

/* Standard entry. */
struct ipt_standard
{
        struct ipt_entry entry;
        struct ipt_standard_target target;
};

struct ipt_error_target
{
        struct ipt_entry_target target;
        char errorname[IPT_FUNCTION_MAXNAMELEN];
};

struct ipt_error
{
        struct ipt_entry entry;
        struct ipt_error_target target;
};

#endif
130
/* Seed ruleset for the "vnet" table: one unconditional ACCEPT rule per
 * hooked chain plus the mandatory terminating ERROR entry.  The
 * hook_entry/underflow offsets and the entry sizes must stay in exact
 * agreement with the struct layout below, as iptables walks the table
 * by byte offset.  A verdict of (-NF_ACCEPT - 1) is the iptables
 * encoding of a standard-target ACCEPT. */
static struct
{
        struct ipt_replace repl;
        struct ipt_standard entries[3];
        struct ipt_error term;
} initial_table __initdata =
{
        .repl =
        {
                .name = "vnet",
                .valid_hooks = FILTER_VALID_HOOKS,
                .num_entries = 4,
                .size = sizeof(struct ipt_standard) * 3 + sizeof(struct ipt_error),
                .hook_entry = { [NF_IP_LOCAL_IN] = 0,
                                [NF_IP_LOCAL_OUT] = sizeof(struct ipt_standard),
                                [NF_IP_POST_ROUTING] = sizeof(struct ipt_standard) * 2, },
                .underflow = { [NF_IP_LOCAL_IN] = 0,
                               [NF_IP_LOCAL_OUT] = sizeof(struct ipt_standard),
                               [NF_IP_POST_ROUTING] = sizeof(struct ipt_standard) * 2, },
        },

        .entries =
        {
                /* LOCAL_IN: currently unused */
                { .entry = { .target_offset = sizeof(struct ipt_entry),
                             .next_offset = sizeof(struct ipt_standard), },
                  .target = { .target = { .u = { .target_size = IPT_ALIGN(sizeof(struct ipt_standard_target)), }, },
                              .verdict = -NF_ACCEPT - 1, },
                },

                /* LOCAL_OUT: used for logging */
                { .entry = { .target_offset = sizeof(struct ipt_entry),
                             .next_offset = sizeof(struct ipt_standard), },
                  .target = { .target = { .u = { .target_size = IPT_ALIGN(sizeof(struct ipt_standard_target)), }, },
                              .verdict = -NF_ACCEPT - 1, },
                },

                /* POST_ROUTING: used for priority classification */
                { .entry = { .target_offset = sizeof(struct ipt_entry),
                             .next_offset = sizeof(struct ipt_standard), },
                  .target = { .target = { .u = { .target_size = IPT_ALIGN(sizeof(struct ipt_standard_target)), }, },
                              .verdict = -NF_ACCEPT - 1, },
                },
        },

        /* ERROR: terminating entry required at the end of every table */
        .term =
        {
                .entry = { .target_offset = sizeof(struct ipt_entry),
                           .next_offset = sizeof(struct ipt_error), },
                .target = { .target = { .u = { .user = { .target_size = IPT_ALIGN(sizeof(struct ipt_error_target)),
                                                         .name = IPT_ERROR_TARGET, }, }, },
                            .errorname = "ERROR", },
        },
};
186
/* The "vnet" iptables table handed to ipt_do_table() from vnet_hook().
 * Field set varies with kernel version: pre-2.6.11 registers the seed
 * ruleset via .table, and 2.6.16+ requires the address family. */
static struct ipt_table vnet_table = {
        .name           = "vnet",
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,11)
        .table          = &initial_table.repl,
#endif
        .valid_hooks    = FILTER_VALID_HOOKS,
        .lock           = RW_LOCK_UNLOCKED,
        .me             = THIS_MODULE,
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,16)
        .af             = AF_INET,
#endif
};
199
200 static inline u_int16_t
201 get_dst_port(struct ip_conntrack_tuple *tuple)
202 {
203         switch (tuple->dst.protonum) {
204         case IPPROTO_GRE:
205                 /* XXX Truncate 32-bit GRE key to 16 bits */
206 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,11)                
207                 return tuple->dst.u.gre.key;
208 #else
209                 return htons(ntohl(tuple->dst.u.gre.key));
210 #endif
211         case IPPROTO_ICMP:
212                 /* Bind on ICMP echo ID */
213                 return tuple->src.u.icmp.id;
214         case IPPROTO_TCP:
215                 return tuple->dst.u.tcp.port;
216         case IPPROTO_UDP:
217                 return tuple->dst.u.udp.port;
218         default:
219                 return tuple->dst.u.all;
220         }
221 }
222
223 static inline u_int16_t
224 get_src_port(struct ip_conntrack_tuple *tuple)
225 {
226         switch (tuple->dst.protonum) {
227         case IPPROTO_GRE:
228                 /* XXX Truncate 32-bit GRE key to 16 bits */
229                 return htons(ntohl(tuple->src.u.gre.key));
230         case IPPROTO_ICMP:
231                 /* Bind on ICMP echo ID */
232                 return tuple->src.u.icmp.id;
233         case IPPROTO_TCP:
234                 return tuple->src.u.tcp.port;
235         case IPPROTO_UDP:
236                 return tuple->src.u.udp.port;
237         default:
238                 return tuple->src.u.all;
239         }
240 }
241
242
243
244 static unsigned int
245 vnet_hook(unsigned int hook,
246           struct sk_buff **pskb,
247           const struct net_device *in,
248           const struct net_device *out,
249           int (*okfn)(struct sk_buff *))
250 {
251         struct ip_conntrack *ct;
252         enum ip_conntrack_info ctinfo;
253         enum ip_conntrack_dir dir;
254         u_int8_t protocol;
255         u_int32_t ip;
256         u_int16_t port;
257         struct bind_key *key;
258         xid_t xid;
259         unsigned int verdict;
260         int priority;
261         struct sock *sk;
262         int need_to_free_sk = 0;
263
264         ct = ip_conntrack_get(*pskb, &ctinfo);
265         dir = CTINFO2DIR(ctinfo);
266
267         /* Default to marking packet with root context ID */
268         xid = 0;
269
270         switch (hook) {
271
272         case NF_IP_LOCAL_IN:
273                 /* Multicast to 224.0.0.1 is one example */
274                 if (!ct)
275                         break;
276
277                 /* Determine if the packet is destined for a bound port */
278                 protocol = ct->tuplehash[dir].tuple.dst.protonum;
279                 assert(ctinfo == IP_CT_RELATED ||
280                        ctinfo == (IP_CT_IS_REPLY + IP_CT_RELATED) ||
281                        protocol == (*pskb)->nh.iph->protocol);
282                 ip = ct->tuplehash[dir].tuple.dst.ip;
283                 port = get_dst_port(&ct->tuplehash[dir].tuple);
284
285                 /* Check if the port is bound */
286                 key = bind_get(protocol, ip, port, NULL);
287
288                 if (key && key->sk != NULL) {
289
290                         /* A new or established connection to a bound port */
291                         sk = key->sk;
292
293 #ifdef HAVE_FUNCTIONALITY_REQUIRED_BY_DEMUX
294                         /* If the bound socket is a real TCP socket, then the context that
295                          * bound the port could have re-assigned an established connection
296                          * socket to another context. See if this is the case.
297                          */
298                         if (protocol == IPPROTO_TCP && sk->sk_type == SOCK_STREAM) {
299                                 struct sock *tcp_sk;
300                                 u_int32_t src_ip = ct->tuplehash[dir].tuple.src.ip;
301                                 u_int16_t src_port = get_src_port(&ct->tuplehash[dir].tuple);
302
303                                 tcp_sk = vnet_tcp_lookup(src_ip, src_port, ip, port, vnet_iif(*pskb));
304                                 if (tcp_sk) {
305                                   if (tcp_sk->sk_state == TCP_TIME_WAIT) {
306                                      sock_put(tcp_sk);
307                                   } else {
308                                     dbg("vnet_in:%d: established TCP socket %u.%u.%u.%u:%u -> %u.%u.%u.%u:%u\n", 
309                                         get_sk_xid(tcp_sk), NIPQUAD(src_ip), ntohs(src_port), NIPQUAD(ip), ntohs(port));
310                                     sk = tcp_sk;
311                                     need_to_free_sk = 1;
312                                   }
313                                   /* Remember to sock_put()! */
314                                 }
315                         }
316 #endif
317
318                         /* Indicate to the stack that the packet was "expected", so that it does
319                          * not generate a TCP RST or ICMP Unreachable message. This requires a
320                          * kernel patch.
321                          */
322                         if (sk->sk_type == SOCK_RAW)
323                           (*pskb)->sk = sk;
324
325                         assert(sk);
326                         xid = get_sk_xid(sk);
327
328                         /* Steal the reply end of the connection */
329                         if (get_ct_xid(ct, !dir) != xid) {
330                                 dbg("vnet_in:%d: stealing %sbound %s connection %u.%u.%u.%u:%u -> %u.%u.%u.%u:%u from context %d\n", xid,
331                                     key ? "" : "un", print_protocol(protocol),
332                                     NIPQUAD(ip), ntohs(port),
333                                     NIPQUAD(ct->tuplehash[!dir].tuple.dst.ip), ntohs(ct->tuplehash[!dir].tuple.dst.u.all),
334                                     get_ct_xid(ct, !dir));
335                                 set_ct_xid(ct, !dir, xid);
336                         }
337
338                         /* Store the owner (if any) of the other side of the connection (if
339                          * localhost) in the peercred struct.
340                          */
341                         sk->sk_peercred.uid = sk->sk_peercred.gid = (__u32) get_ct_xid(ct, dir);
342
343                         if (ctinfo == IP_CT_NEW) {
344                                 dbg("vnet_in: %s port %u.%u.%u.%u:%u bound by context %d\n",
345                                     print_protocol(protocol), NIPQUAD(ip), ntohs(port), xid);
346                         }
347
348 #ifdef HAVE_FUNCTIONALITY_REQUIRED_BY_DEMUX
349                         if (need_to_free_sk) {
350                           /*
351                           if (sk->sk_state == TCP_TIME_WAIT)
352                             vnet_timewait_put(sk);
353                           else*/
354                           sock_put(sk);
355                           need_to_free_sk=0;
356                         }
357 #endif
358                         bind_put(key);
359
360                 } else if ((int) get_ct_xid(ct, !dir) == -1) {
361                         /* A new connection to an unbound port */
362                         if (ctinfo == IP_CT_NEW) {
363                                 dbg("vnet_in: %s port %u.%u.%u.%u:%u not bound\n",
364                                     print_protocol(protocol), NIPQUAD(ip), ntohs(port));
365                         }
366                 } else {
367                         /* A new or established connection to an unbound port that could be
368                          * associated with an active socket ("could be" because the socket
369                          * could be closed and the connection in a WAIT state). In any case,
370                          * give it to the last owner of the connection.
371                          */
372                         xid = get_ct_xid(ct, !dir);
373                 }
374
375                 break;
376
377         case NF_IP_LOCAL_OUT:
378                 /* Get the context ID of the sender */
379                 assert((*pskb)->sk);
380                 xid = get_sk_xid((*pskb)->sk);
381
382                 /* Default class */
383                 priority = vnet_root_class;
384
385                 if (ct) {
386                         protocol = ct->tuplehash[dir].tuple.dst.protonum;
387                         assert(ctinfo == IP_CT_RELATED ||
388                                ctinfo == (IP_CT_IS_REPLY + IP_CT_RELATED) ||
389                                protocol == (*pskb)->nh.iph->protocol);
390                         ip = ct->tuplehash[dir].tuple.src.ip;
391                         assert(ctinfo == IP_CT_RELATED ||
392                                ctinfo == (IP_CT_IS_REPLY + IP_CT_RELATED) ||
393                                ip == __constant_htonl(INADDR_ANY) || ip == (*pskb)->nh.iph->saddr);
394                         port = get_src_port(&ct->tuplehash[dir].tuple);
395                 } else {
396                         protocol = port = 0;
397                 }
398
399                 if (xid) {
400                         /* Multicast to 224.0.0.1 is one example */
401                         if (!ct) {
402                                 dbg("vnet_out:%d: dropping untrackable IP packet\n", xid);
403                                 return NF_DROP;
404                         }
405
406                         /* XXX Is this guaranteed? */
407                         if ((*pskb)->len < sizeof(struct iphdr)) {
408                                 dbg("vnet_out:%d: dropping runt IP packet\n", xid);
409                                 return NF_DROP;
410                         }
411
412                         /* Check source IP address */
413                         if (inet_addr_type(ip) != RTN_LOCAL) {
414                                 dbg("vnet_out:%d: non-local source IP address %u.%u.%u.%u not allowed\n", xid,
415                                     NIPQUAD(ip));
416                                 return NF_DROP;
417                         }
418
419                         /* Sending of ICMP error messages not allowed */
420                         if (protocol == IPPROTO_ICMP) {
421                                 struct icmphdr *icmph = (struct icmphdr *)((*pskb)->nh.raw + ((*pskb)->nh.iph->ihl * 4));
422
423                                 if ((unsigned char *) &icmph[1] > (*pskb)->tail) {
424                                         dbg("vnet_out:%d: dropping runt ICMP packet\n", xid);
425                                         return NF_DROP;
426                                 }
427                                 
428                                 switch (icmph->type) {
429                                 case ICMP_ECHOREPLY:
430                                 case ICMP_ECHO:
431                                 case ICMP_TIMESTAMP:
432                                 case ICMP_TIMESTAMPREPLY:
433                                 case ICMP_INFO_REQUEST:
434                                 case ICMP_INFO_REPLY:
435                                 case ICMP_ADDRESS:
436                                 case ICMP_ADDRESSREPLY:
437                                         /* Guaranteed by icmp_pkt_to_tuple() */
438                                         assert(port == icmph->un.echo.id);
439                                         break;
440                                 default:
441                                         dbg("vnet_out:%d: sending of ICMP error messages not allowed\n", xid);
442                                         return NF_DROP;
443                                 }
444                         }
445                 } else {
446                         /* Let root send anything it wants */
447                 }
448
449                 if (ct) {
450                         /* Check if the port is bound by someone else */
451                         key = bind_get(protocol, ip, port, NULL);
452                 } else {
453                         assert(xid == 0);
454                         key = NULL;
455                 }
456
457                 if (key && key->sk != NULL) {
458                         /* A new or established connection from a bound port */
459                         assert(ct);
460
461                         sk = key->sk;
462
463 #ifdef HAVE_FUNCTIONALITY_REQUIRED_BY_DEMUX
464                         /* If the bound socket is a real TCP socket, then the context that
465                          * bound the port could have re-assigned an established connection
466                          * socket to the sender's context. See if this is the case.
467                          */
468                         if (protocol == IPPROTO_TCP && sk->sk_type == SOCK_STREAM && get_sk_xid(sk) != xid) {
469                                 struct sock *tcp_sk;
470                                 u_int32_t dst_ip = ct->tuplehash[dir].tuple.dst.ip;
471                                 u_int16_t dst_port = get_dst_port(&ct->tuplehash[dir].tuple);
472
473                                 tcp_sk = vnet_tcp_lookup(dst_ip, dst_port, ip, port, vnet_iif(*pskb));
474                                 if (tcp_sk) {
475                                   if (tcp_sk->sk_state == TCP_TIME_WAIT) {
476                                     sock_put(tcp_sk);
477                                     //vnet_timewait_put(tcp_sk);
478                                   } else {
479                                     need_to_free_sk = 1;
480                                     sk = tcp_sk;
481                                     /* Remember to sock_put()! */
482                                   }
483                                 }
484                         }
485 #endif
486
487                         verdict = NF_ACCEPT;
488
489                         /* Stealing connections from established sockets is not allowed */
490                         assert(sk);
491                         if (get_sk_xid(sk) != xid) {
492                                 if (xid) {
493                                         dbg("vnet_out:%d: %s port %u.%u.%u.%u:%u already bound by context %d\n", xid,
494                                             print_protocol(protocol), NIPQUAD(ip), ntohs(port), get_sk_xid(sk));
495                                         verdict = NF_DROP;
496                                 } else {
497                                         /* Let root send whatever it wants but do not steal the packet or
498                                          * connection. Kernel sockets owned by root may send packets on
499                                          * behalf of bound sockets (for instance, TCP ACK in SYN_RECV or
500                                          * TIME_WAIT).
501                                          */
502                                         xid = get_sk_xid(sk);
503                                 }
504                         }
505
506 #ifdef HAVE_FUNCTIONALITY_REQUIRED_BY_DEMUX
507                         if (need_to_free_sk) {
508                         /*
509                           if (sk->sk_state == TCP_TIME_WAIT)
510                             vnet_timewait_put(sk);
511                           else */
512                           sock_put(sk);
513                           need_to_free_sk = 0;
514                         }
515 #endif
516                         bind_put(key);
517
518                         if (verdict == NF_DROP)
519                                 goto done;
520                 } else {
521                         /* A new or established or untrackable connection from an unbound port */
522
523                         /* Reserved ports must be bound. Usually only root is capable of
524                          * CAP_NET_BIND_SERVICE.
525                          */
526                         if (xid &&
527                             (protocol == IPPROTO_TCP || protocol == IPPROTO_UDP) &&
528                             ntohs(port) < PROT_SOCK) {
529                                 assert(ct);
530                                 dbg("vnet_out:%d: %s port %u is reserved\n", xid,
531                                     print_protocol(protocol), ntohs(port));
532                                 return NF_DROP;
533                         }
534                 }
535
536                 if (ct) {
537                         /* Steal the connection */
538                         if (get_ct_xid(ct, dir) != xid) {
539                                 dbg("vnet_out:%d: stealing %sbound %s connection %u.%u.%u.%u:%u -> %u.%u.%u.%u:%u from context %d\n", xid,
540                                     key ? "" : "un", print_protocol(protocol),
541                                     NIPQUAD(ip), ntohs(port),
542                                     NIPQUAD(ct->tuplehash[dir].tuple.dst.ip), ntohs(ct->tuplehash[dir].tuple.dst.u.all),
543                                     get_ct_xid(ct, dir));
544                                 set_ct_xid(ct, dir, xid);
545                         }
546
547                         /* Classify traffic once per connection */
548                         if (ct->priority == (u_int32_t) -1) {
549                                 /* The POSTROUTING chain should classify packets into a minor subclass
550                                  * (1:1000, 1:2000, etc.) with -j CLASSIFY --set-class. Set the packet
551                                  * MARK early so that rules can take xid into account. */
552                                 set_skb_xid(*pskb, xid);
553                                 (*pskb)->priority = priority;
554                                 (void) ipt_do_table(pskb, NF_IP_POST_ROUTING, in, out, &vnet_table, NULL);
555                                 priority = (*pskb)->priority | xid;
556                                 dbg("vnet_out:%d: %u.%u.%u.%u:%u -> %u.%u.%u.%u:%u class %x:%x\n", xid,
557                                     NIPQUAD(ip), ntohs(port),
558                                     NIPQUAD(ct->tuplehash[dir].tuple.dst.ip), ntohs(ct->tuplehash[dir].tuple.dst.u.all),
559                                     TC_H_MAJ(priority) >> 16, TC_H_MIN(priority));
560                                 ct->priority = priority;
561                         } else
562                                 priority = ct->priority;
563                 } else {
564                         assert(xid == 0);
565                 }
566
567                 /* Set class */
568                 (*pskb)->priority = priority;
569
570                 break;
571
572         default:
573                 /* Huh? */
574                 assert(hook == NF_IP_LOCAL_IN || hook == NF_IP_LOCAL_OUT);
575                 break;
576         }
577
578         /* Mark packet */
579         set_skb_xid(*pskb, xid);
580
581 #ifdef VNET_DEBUG
582         if (vnet_verbose >= 3) {
583                 if (ct)
584                         print_conntrack(ct, ctinfo, hook);
585                 if (vnet_verbose >= 4)
586                         print_packet(*pskb);
587         }
588 #endif
589
590  get_verdict:
591         verdict = ipt_do_table(pskb, hook, in, out, &vnet_table, NULL);
592
593         /* Pass to network taps */
594         if (verdict == NF_ACCEPT)
595                 verdict = packet_hook(*pskb, hook);
596
597  done:
598         return verdict;
599 }
600
/* Hook registrations: vnet_hook() on LOCAL_IN and LOCAL_OUT at
 * NF_IP_PRI_LAST, so every other netfilter user (conntrack, iptables
 * chains) has already run by the time we see the packet. */
static struct nf_hook_ops vnet_ops[] = {
        {
                .hook           = vnet_hook,
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,0)
                .owner          = THIS_MODULE,
#endif
                .pf             = PF_INET,
                .hooknum        = NF_IP_LOCAL_IN,
                .priority       = NF_IP_PRI_LAST,
        },
        {
                .hook           = vnet_hook,
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,0)
                .owner          = THIS_MODULE,
#endif
                .pf             = PF_INET,
                .hooknum        = NF_IP_LOCAL_OUT,
                .priority       = NF_IP_PRI_LAST,
        },
};
621
/* Exported by net/ipv4/af_inet.c */
extern struct net_proto_family inet_family_ops;
extern struct proto_ops inet_stream_ops;
extern struct proto_ops inet_dgram_ops;
extern struct proto_ops inet_sockraw_ops;
extern int inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len);
extern int inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
                               int addr_len, int flags);
extern int inet_listen(struct socket *sock, int backlog);
extern int inet_dgram_connect(struct socket *sock, struct sockaddr * uaddr,
                              int addr_len, int flags);
extern int inet_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg,
                        size_t size);
extern int inet_release(struct socket *sock);

/* Exported by net/ipv4/tcp_ipv4.c */
extern struct proto tcp_prot;
extern int tcp_port_rover;
extern int sysctl_local_port_range[2];

/* Exported by net/ipv4/udp.c */
extern struct proto udp_prot;
extern int udp_port_rover;

/* Functions that are not exported by the kernel; these pointers are
 * presumably filled in at module init from the original proto ops
 * before the vnet_* overrides are installed — TODO confirm (the
 * initialization code is outside this chunk). */
static int (*inet_create)(struct socket *sock, int protocol);
static ssize_t (*inet_sendpage)(struct socket *sock, struct page *page, int offset, size_t size, int flags);
static void (*tcp_v4_hash)(struct sock *sk);
static void (*tcp_v4_unhash)(struct sock *sk);
static void (*udp_v4_hash)(struct sock *sk);
static void (*udp_v4_unhash)(struct sock *sk);
653
/*
 * Override of inet_create() for AF_INET sockets.
 *
 * For SOCK_RAW, temporarily grants CAP_NET_RAW to euid-0 callers (root
 * inside a VServer) around the real inet_create() call, then revokes
 * it; non-root callers get -EPERM.  The raise/create/lower ordering is
 * load-bearing: the capability must be live only for the duration of
 * the inet_create() call.
 *
 * Returns 0 on success or a negative errno from inet_create().
 */
static int
vnet_inet_create(struct socket *sock, int protocol)
{
        int ret;

        if (sock->type == SOCK_RAW) {
                /* Temporarily give CAP_NET_RAW to root VServer accounts */
                if (current->euid)
                        return -EPERM;
                cap_raise(current->cap_effective, CAP_NET_RAW);
        }
        ret = inet_create(sock, protocol);
        if (sock->type == SOCK_RAW)
                cap_lower(current->cap_effective, CAP_NET_RAW);
        if (ret)
                return ret;

        if (sock->type == SOCK_RAW) {
                struct sock *sk = sock->sk;
                struct inet_opt *inet = inet_sk(sk);
                /* Usually redundant and unused */
                assert(inet->sport == htons(inet->num));
                /* Clear the source port so we can track double raw binds
                 * in vnet_inet_bind() */
                inet->sport = 0;
        }

        return ret;
}
682
683 /* Make sure our bind table gets updated whenever the stack decides to
684  * unhash or rehash a socket.
685  */
686 static void
687 vnet_inet_unhash(struct sock *sk)
688 {
689         struct inet_opt *inet = inet_sk(sk);
690         struct bind_key *key;
691
692         key = bind_get(sk->sk_protocol, inet->saddr, inet->sport, sk);
693         if (key) {
694                 dbg("vnet_inet_unhash:%d: released %s port %u.%u.%u.%u:%u\n", get_sk_xid(sk),
695                     print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), ntohs(inet->sport));
696                 bind_del(key);
697                 bind_put(key);
698         }
699
700         if (sk->sk_protocol == IPPROTO_TCP)
701                 tcp_v4_unhash(sk);
702         else if (sk->sk_protocol == IPPROTO_UDP)
703                 udp_v4_unhash(sk);
704 }
705
706 static void
707 vnet_inet_hash(struct sock *sk)
708 {
709         struct inet_opt *inet = inet_sk(sk);
710
711         if (bind_add(sk->sk_protocol, inet->saddr, inet->sport, sk) == 0) {
712                 dbg("vnet_inet_hash:%d: bound %s port %u.%u.%u.%u:%u\n", get_sk_xid(sk),
713                     print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), ntohs(inet->sport));
714         }
715
716         if (sk->sk_protocol == IPPROTO_TCP)
717                 tcp_v4_hash(sk);
718         else if (sk->sk_protocol == IPPROTO_UDP)
719                 udp_v4_hash(sk);
720 }
721
/* Port reservation: bind() override installed on the stream, dgram, and
 * (2.6.10+) raw proto_ops. Chains to the stock inet_bind(), then records
 * the (protocol, addr, port) tuple in the vnet bind table so raw-socket
 * binds and TCP/UDP binds see each other. Returns 0 or a negative errno.
 */
static int
vnet_inet_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{
	struct sock *sk = sock->sk;
	struct inet_opt *inet = inet_sk(sk);
	struct sockaddr_in *sin = (struct sockaddr_in *) uaddr;
	struct bind_key *key;
	int ret;

	/* Bind socket first via the stock path; on failure nothing to undo.
	 * On success ret == 0, which the raw-socket checks below rely on. */
	if ((ret = inet_bind(sock, uaddr, addr_len)))
		return ret;

	lock_sock(sk);

	/* Backward compatibility with safe raw sockets: re-validate the
	 * address the way TCP/UDP bind would, since the stock raw bind
	 * does not fill in sport or enforce these rules. */
	if (sock->type == SOCK_RAW) {
		/* Runt sockaddr */
		if (addr_len < sizeof(struct sockaddr_in))
			ret = -EINVAL;
		/* Non-local bind */
		else if (sin->sin_addr.s_addr != __constant_htonl(INADDR_ANY) && inet_addr_type(sin->sin_addr.s_addr) != RTN_LOCAL)
			ret = -EINVAL;
		/* Unspecified port */
		else if (!sin->sin_port)
			ret = -EINVAL;
		/* Reserved port */
		else if ((sk->sk_protocol == IPPROTO_TCP || sk->sk_protocol == IPPROTO_UDP) &&
			 ntohs(sin->sin_port) < PROT_SOCK && !capable(CAP_NET_BIND_SERVICE))
			ret = -EACCES;
		/* Double bind */
		else if (inet->sport)
			ret = -EINVAL;
		if (ret)
			goto done;
		/* Mirror the requested address into the inet options so
		 * the bind-table entry below matches TCP/UDP sockets. */
		inet->saddr = sin->sin_addr.s_addr;
		inet->sport = sin->sin_port;
	}

	/* Is the port already claimed in the vnet bind table (by any sk)? */
	key = bind_get(sk->sk_protocol, inet->saddr, inet->sport, NULL);
	if (key) {
		/*
		 * If we are root or own the already bound socket, and
		 * SO_REUSEADDR has been set on both.
		 */
		if ((get_sk_xid(sk) == 0 || get_sk_xid(sk) == get_sk_xid(key->sk)) &&
		    key->sk->sk_reuse && sk->sk_reuse) {
			if (key->ip == __constant_htonl(INADDR_ANY)) {
				/* Keep the current bind key */
				bind_put(key);
				goto done;
			} else if (inet->saddr == __constant_htonl(INADDR_ANY)) {
				/* Consider the port to be bound to this socket now */
				bind_del(key);
			}
		}
		bind_put(key);
	}

	/* Claim the tuple for this socket; ret is the final return value */
	if ((ret = bind_add(sk->sk_protocol, inet->saddr, inet->sport, sk)) == 0) {
		dbg("vnet_inet_bind:%d: bound %s port %u.%u.%u.%u:%u\n", get_sk_xid(sk),
		    print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), ntohs(inet->sport));
	}

 done:
	release_sock(sk);
	return ret;
}
791
/* Override TCP and UDP port rovers since they do not know about raw
 * socket binds.
 *
 * Picks an ephemeral local port for an unbound socket by linear search
 * over sysctl_local_port_range, claiming it in the vnet bind table first
 * and then hashing it with the protocol's stock get_port(). Caller must
 * hold the socket lock. Returns 0 on success, -EAGAIN on failure.
 */
static int
vnet_autobind(struct sock *sk)
{
	/* Protocol-specific port allocator (tcp_prot/udp_prot .get_port) */
	int (*get_port)(struct sock *, unsigned short);
	int low = sysctl_local_port_range[0];
	int high = sysctl_local_port_range[1];
	int remaining = (high - low) + 1;
	int port;
	struct inet_opt *inet = inet_sk(sk);
	struct bind_key *key;

	/* Must be locked */
	assert(sock_owned_by_user(sk));

	/* Already bound to a port */
	if (inet->num)
		return 0;

	/* Choose the starting point of the search to mimic the stock
	 * allocator for each protocol/kernel version. */
	if (sk->sk_protocol == IPPROTO_TCP) {
		get_port = tcp_prot.get_port;
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,14)
		/* Approximate the tcp_v4_get_port() strategy */
		port = tcp_port_rover + 1;
#else
		/* Approximate the inet_csk_get_port() strategy */
		port = net_random() % (high - low) + low;
#endif
	} else if (sk->sk_protocol == IPPROTO_UDP) {
		get_port = udp_prot.get_port;
		port = udp_port_rover;
	} else if (sk->sk_prot->get_port) {
		/* Other protocols with an allocator: defer to it entirely;
		 * no vnet bind-table entry is made for these. */
		err("vnet_get_port:%d: %s unhandled\n", get_sk_xid(sk),
		    print_protocol(sk->sk_protocol));
		if (sk->sk_prot->get_port(sk, 0))
			return -EAGAIN;
		inet->sport = htons(inet->num);
		return 0;
	} else {
		/* No port allocator at all (e.g. raw): nothing to do */
		return 0;
	}

	dbg("vnet_autobind:%d: roving %s port range %u.%u.%u.%u:%u-%u\n", get_sk_xid(sk),
	    print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), low, high);

	/* Find a free port by linear search. Note that the standard
	 * udp_v4_get_port() function attempts to pick a port that
	 * keeps its hash tables balanced. If the UDP hash table keeps
	 * getting bombed, we should try implementing this strategy
	 * here.
	 */
	do {
		/* Wrap the search back to the bottom of the range */
		if (port < low || port > high)
			port = low;

		/* XXX We could probably try something more clever
		 * like checking to see if the bound socket is a
		 * regular TCP socket owned by the same context (or we
		 * are root) and, if so, letting tcp_v4_get_port()
		 * apply its fast reuse logic to determine if the port
		 * can be reused.
		 */
		if (bind_add(sk->sk_protocol, inet->saddr, htons(port), sk)) {
			dbg("vnet_get_port:%d: %s port %u.%u.%u.%u:%u already bound\n", get_sk_xid(sk),
			    print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), port);
			goto next;
		}

		if (get_port(sk, port)) {
			/* Can happen if we are unloaded when there are active sockets */
			dbg("vnet_get_port:%d: failed to hash unbound %s port %u.%u.%u.%u:%u\n", get_sk_xid(sk),
			    print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), port);
			/* Roll back the bind-table claim made just above */
			key = bind_get(sk->sk_protocol, inet->saddr, htons(port), sk);
			assert(key);
			bind_del(key);
			bind_put(key);
		} else {
			/* get_port() succeeded and set inet->num */
			assert(port == inet->num);
			inet->sport = htons(inet->num);
			break;
		}
	next:
		port++;
	} while (--remaining > 0);

	/* Remember where the search ended so the next autobind resumes
	 * from there, like the stock rovers do. */
	if (sk->sk_protocol == IPPROTO_UDP)
		udp_port_rover = port;
#if LINUX_VERSION_CODE < KERNEL_VERSION(2,6,14)
	else if (sk->sk_protocol == IPPROTO_TCP)
		tcp_port_rover = port;
#endif

	if (remaining <= 0) {
		err("vnet_get_port:%d: exhausted local %s port range %u.%u.%u.%u:%u-%u\n", get_sk_xid(sk),
		    print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), low, high);
		return -EAGAIN;
	} else {
		dbg("vnet_get_port:%d: autobound %s port %u.%u.%u.%u:%u\n", get_sk_xid(sk),
		    print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), port);
		return 0;
	}
}
896
897 static int
898 vnet_inet_stream_connect(struct socket *sock, struct sockaddr *uaddr,
899                          int addr_len, int flags)
900 {
901         struct sock *sk = sock->sk;
902
903         lock_sock(sk);
904
905         /* Duplicates checks in inet_stream_connect() */
906         if (uaddr->sa_family != AF_UNSPEC &&
907             sock->state == SS_UNCONNECTED &&
908             sk->sk_state == TCP_CLOSE) {
909                 /* We may need to bind the socket. */
910                 if (!inet_sk(sk)->num && vnet_autobind(sk)) {
911                         release_sock(sk);
912                         return -EAGAIN;
913                 }
914         }
915
916         release_sock(sk);
917
918         return inet_stream_connect(sock, uaddr, addr_len, flags);
919 }
920
921 static int 
922 vnet_inet_listen(struct socket *sock, int backlog)
923 {
924         struct sock *sk = sock->sk;
925
926         lock_sock(sk);
927
928         /* Duplicates checks in inet_listen() */
929         if (sock->type == SOCK_STREAM &&
930             sock->state == SS_UNCONNECTED &&
931             sk->sk_state == TCP_CLOSE) {
932                 /* We may need to bind the socket. */
933                 if (!inet_sk(sk)->num && vnet_autobind(sk)) {
934                         release_sock(sk);
935                         return -EAGAIN;
936                 }
937         }
938
939         release_sock(sk);
940
941         return inet_listen(sock, backlog);
942 }
943
944 static int
945 vnet_inet_dgram_connect(struct socket *sock, struct sockaddr * uaddr,
946                         int addr_len, int flags)
947 {
948         struct sock *sk = sock->sk;
949
950         lock_sock(sk);
951
952         /* Duplicates checks in inet_dgram_connect() */
953         if (uaddr->sa_family != AF_UNSPEC) {
954                 /* We may need to bind the socket. */
955                 if (!inet_sk(sk)->num && vnet_autobind(sk)) {
956                         release_sock(sk);
957                         return -EAGAIN;
958                 }
959         }
960
961         release_sock(sk);
962
963         return inet_dgram_connect(sock, uaddr, addr_len, flags);
964 }
965
966 static int
967 vnet_inet_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *msg,
968                   size_t size)
969 {
970         struct sock *sk = sock->sk;
971
972         lock_sock(sk);
973
974         /* We may need to bind the socket. */
975         if (!inet_sk(sk)->num && vnet_autobind(sk)) {
976                 release_sock(sk);
977                 return -EAGAIN;
978         }
979
980         release_sock(sk);
981
982         return inet_sendmsg(iocb, sock, msg, size);
983 }
984
985 static ssize_t
986 vnet_inet_sendpage(struct socket *sock, struct page *page, int offset, size_t size, int flags)
987 {
988         struct sock *sk = sock->sk;
989
990         lock_sock(sk);
991
992         /* We may need to bind the socket. */
993         if (!inet_sk(sk)->num && vnet_autobind(sk)) {
994                 release_sock(sk);
995                 return -EAGAIN;
996         }
997
998         release_sock(sk);
999
1000         return inet_sendpage(sock, page, offset, size, flags);
1001 }
1002
1003 static int
1004 vnet_inet_release(struct socket *sock)
1005 {
1006         struct sock *sk = sock->sk;
1007         struct inet_opt *inet = inet_sk(sk);
1008         struct bind_key *key;
1009
1010         /* Partial socket created by accept() */
1011         if (!sk)
1012                 goto done;
1013
1014         lock_sock(sk);
1015
1016         key = bind_get(sk->sk_protocol, inet->saddr, inet->sport, sk);
1017         if (key) {
1018                 dbg("vnet_inet_release:%d: released %s port %u.%u.%u.%u:%u\n", get_sk_xid(sk),
1019                     print_protocol(sk->sk_protocol), NIPQUAD(inet->saddr), ntohs(inet->sport));
1020                 bind_del(key);
1021                 bind_put(key);
1022         }
1023
1024         release_sock(sk);
1025
1026  done:
1027         return inet_release(sock);
1028 }
1029
/* Sanity check: atomically swap a function pointer, asserting it still
 * holds the value we expect (catches double-override or a kernel whose
 * ops tables don't match what this module saved at init). */
#define override_op(op, from, to) do { assert((op) == (from)); (op) = (to); } while (0)
1032
1033 static int __init
1034 vnet_init(void)
1035 {
1036         int ret;
1037
1038         /* Initialize bind table */
1039         ret = bind_init();
1040         if (ret < 0)
1041                 return ret;
1042
1043         /* Register /proc entries */
1044         ret = proc_init();
1045         if (ret < 0)
1046                 goto cleanup_bind;
1047
1048         /* Register dummy netdevice */
1049         ret = packet_init();
1050         if (ret < 0)
1051                 goto cleanup_proc;
1052
1053         /* Register tap netdevice */
1054         ret = tun_init();
1055         if (ret < 0)
1056                 goto cleanup_packet;
1057
1058         /* Get pointers to unexported functions */
1059         inet_create = inet_family_ops.create;
1060         inet_sendpage = inet_dgram_ops.sendpage;
1061         tcp_v4_hash = tcp_prot.hash;
1062         tcp_v4_unhash = tcp_prot.unhash;
1063         udp_v4_hash = udp_prot.hash;
1064         udp_v4_unhash = udp_prot.unhash;
1065
1066         /* Override PF_INET socket operations */
1067         override_op(inet_family_ops.create, inet_create, vnet_inet_create);
1068         override_op(inet_stream_ops.bind, inet_bind, vnet_inet_bind);
1069         override_op(inet_stream_ops.connect, inet_stream_connect, vnet_inet_stream_connect);
1070         override_op(inet_stream_ops.listen, inet_listen, vnet_inet_listen);
1071         override_op(inet_stream_ops.sendmsg, inet_sendmsg, vnet_inet_sendmsg);
1072         override_op(inet_stream_ops.release, inet_release, vnet_inet_release);
1073         override_op(inet_dgram_ops.bind, inet_bind, vnet_inet_bind);
1074         override_op(inet_dgram_ops.connect, inet_dgram_connect, vnet_inet_dgram_connect);
1075         override_op(inet_dgram_ops.sendmsg, inet_sendmsg, vnet_inet_sendmsg); 
1076         override_op(inet_dgram_ops.sendpage, inet_sendpage, vnet_inet_sendpage);
1077         override_op(inet_dgram_ops.release, inet_release, vnet_inet_release);
1078 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,10)
1079         override_op(inet_sockraw_ops.bind, inet_bind, vnet_inet_bind);
1080         override_op(inet_sockraw_ops.connect, inet_dgram_connect, vnet_inet_dgram_connect);
1081         override_op(inet_sockraw_ops.sendmsg, inet_sendmsg, vnet_inet_sendmsg);
1082         override_op(inet_sockraw_ops.sendpage, inet_sendpage, vnet_inet_sendpage); 
1083         override_op(inet_sockraw_ops.release, inet_release, vnet_inet_release);
1084 #endif
1085         override_op(tcp_prot.hash, tcp_v4_hash, vnet_inet_hash);
1086         override_op(tcp_prot.unhash, tcp_v4_unhash, vnet_inet_unhash);
1087         override_op(udp_prot.hash, udp_v4_hash, vnet_inet_hash);
1088         override_op(udp_prot.unhash, udp_v4_unhash, vnet_inet_unhash);
1089
1090         /* Register table */
1091 #if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,11)
1092         ret = ipt_register_table(&vnet_table, &initial_table.repl);
1093 #else
1094         ret = ipt_register_table(&vnet_table);
1095 #endif
1096         if (ret < 0)
1097                 goto cleanup_override;
1098
1099         /* Register hooks */
1100         ret = nf_register_hook(&vnet_ops[0]);
1101         if (ret < 0)
1102                 goto cleanup_table;
1103
1104         ret = nf_register_hook(&vnet_ops[1]);
1105         if (ret < 0)
1106                 goto cleanup_hook0;
1107
1108         /* Enables any runtime kernel support for VNET */
1109         vnet_active = 1;
1110
1111         /* Print banner */
1112         printk("VNET: version " VNET_VERSION " compiled on " __DATE__ " at " __TIME__ "\n");
1113
1114         return ret;
1115
1116  cleanup_hook0:
1117         nf_unregister_hook(&vnet_ops[0]);
1118  cleanup_table:
1119         ipt_unregister_table(&vnet_table);
1120  cleanup_override:
1121         inet_family_ops.create = inet_create;
1122         inet_stream_ops.bind = inet_bind;
1123         inet_stream_ops.connect = inet_stream_connect;
1124         inet_stream_ops.listen = inet_listen;
1125         inet_stream_ops.sendmsg = inet_sendmsg;
1126         inet_stream_ops.release = inet_release;
1127         inet_dgram_ops.bind = inet_bind;
1128         inet_dgram_ops.connect = inet_dgram_connect;
1129         inet_dgram_ops.sendmsg = inet_sendmsg;
1130         inet_dgram_ops.sendpage = inet_sendpage;
1131         inet_dgram_ops.release = inet_release;
1132         tun_cleanup();
1133  cleanup_packet:
1134         packet_cleanup();       
1135  cleanup_proc:
1136         proc_cleanup();
1137  cleanup_bind:
1138         bind_cleanup();
1139
1140         return ret;
1141 }
1142
/*
 * Module exit: tear down in the reverse order of vnet_init() — stop
 * packet processing first, restore every overridden socket/proto
 * operation, then release netdevices, /proc entries, and finally the
 * bind table.
 */
static void __exit
vnet_exit(void)
{
	unsigned int i;

	/* Print banner */
	printk("VNET: exiting\n");

	/* Disables any runtime kernel support for VNET */
	vnet_active = 0;

	/* Stop handling packets first */
	for (i = 0; i < sizeof(vnet_ops)/sizeof(struct nf_hook_ops); i++)
		nf_unregister_hook(&vnet_ops[i]);

	ipt_unregister_table(&vnet_table);

	/* Stop handling PF_INET socket operations; override_op() asserts
	 * each pointer still holds our hook before restoring the saved
	 * stock function. */
	override_op(inet_family_ops.create, vnet_inet_create, inet_create);
	override_op(inet_stream_ops.bind, vnet_inet_bind, inet_bind);
	override_op(inet_stream_ops.connect, vnet_inet_stream_connect, inet_stream_connect);
	override_op(inet_stream_ops.listen, vnet_inet_listen, inet_listen);
	override_op(inet_stream_ops.sendmsg, vnet_inet_sendmsg, inet_sendmsg);
	override_op(inet_stream_ops.release, vnet_inet_release, inet_release);
	override_op(inet_dgram_ops.bind, vnet_inet_bind, inet_bind);
	override_op(inet_dgram_ops.connect, vnet_inet_dgram_connect, inet_dgram_connect);
	override_op(inet_dgram_ops.sendmsg, vnet_inet_sendmsg, inet_sendmsg);
	override_op(inet_dgram_ops.sendpage, vnet_inet_sendpage, inet_sendpage);
	override_op(inet_dgram_ops.release, vnet_inet_release, inet_release);
#if LINUX_VERSION_CODE >= KERNEL_VERSION(2,6,10)
	override_op(inet_sockraw_ops.bind, vnet_inet_bind, inet_bind);
	override_op(inet_sockraw_ops.connect, vnet_inet_dgram_connect, inet_dgram_connect);
	override_op(inet_sockraw_ops.sendmsg, vnet_inet_sendmsg, inet_sendmsg);
	override_op(inet_sockraw_ops.sendpage, vnet_inet_sendpage, inet_sendpage);
	override_op(inet_sockraw_ops.release, vnet_inet_release, inet_release);
#endif
	override_op(tcp_prot.hash, vnet_inet_hash, tcp_v4_hash);
	override_op(tcp_prot.unhash, vnet_inet_unhash, tcp_v4_unhash);
	override_op(udp_prot.hash, vnet_inet_hash, udp_v4_hash);
	override_op(udp_prot.unhash, vnet_inet_unhash, udp_v4_unhash);

	/* Disable tap netdevice */
	tun_cleanup();

	/* Disable vnet netdevice and stop handling PF_PACKET sockets */
	packet_cleanup();

	/* Unregister /proc handlers */
	proc_cleanup();

	/* Cleanup bind table (must be after nf_unregister_hook()) */
	bind_cleanup();
}
1196
/* Module entry/exit points and metadata */
module_init(vnet_init);
module_exit(vnet_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Mark Huang <mlhuang@cs.princeton.edu>");
MODULE_DESCRIPTION("VServer IP isolation");