Merge to Fedora kernel-2.6.18-1.2260_FC5 patched with stable patch-2.6.18.5-vs2.0...
[linux-2.6.git] / net / ipv4 / netfilter / ip_set_ipporthash.c
1 /* Copyright (C) 2003-2004 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.  
6  */
7
8 /* Kernel module implementing an ip+port hash set */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/tcp.h>
13 #include <linux/udp.h>
14 #include <linux/skbuff.h>
15 #include <linux/netfilter_ipv4/ip_tables.h>
16 #include <linux/netfilter_ipv4/ip_set.h>
17 #include <linux/errno.h>
18 #include <asm/uaccess.h>
19 #include <asm/bitops.h>
20 #include <linux/spinlock.h>
21 #include <linux/vmalloc.h>
22 #include <linux/random.h>
23
24 #include <net/ip.h>
25
26 #include <linux/netfilter_ipv4/ip_set_malloc.h>
27 #include <linux/netfilter_ipv4/ip_set_ipporthash.h>
28 #include <linux/netfilter_ipv4/ip_set_jhash.h>
29
30 static int limit = MAX_RANGE;
31
32 /* We must handle non-linear skbs */
33 static inline ip_set_ip_t
34 get_port(const struct sk_buff *skb, u_int32_t flags)
35 {
36         struct iphdr *iph = skb->nh.iph;
37         u_int16_t offset = ntohs(iph->frag_off) & IP_OFFSET;
38
39         switch (iph->protocol) {
40         case IPPROTO_TCP: {
41                 struct tcphdr tcph;
42                 
43                 /* See comments at tcp_match in ip_tables.c */
44                 if (offset)
45                         return INVALID_PORT;
46
47                 if (skb_copy_bits(skb, skb->nh.iph->ihl*4, &tcph, sizeof(tcph)) < 0)
48                         /* No choice either */
49                         return INVALID_PORT;
50                 
51                 return ntohs(flags & IPSET_SRC ?
52                              tcph.source : tcph.dest);
53             }
54         case IPPROTO_UDP: {
55                 struct udphdr udph;
56
57                 if (offset)
58                         return INVALID_PORT;
59
60                 if (skb_copy_bits(skb, skb->nh.iph->ihl*4, &udph, sizeof(udph)) < 0)
61                         /* No choice either */
62                         return INVALID_PORT;
63                 
64                 return ntohs(flags & IPSET_SRC ?
65                              udph.source : udph.dest);
66             }
67         default:
68                 return INVALID_PORT;
69         }
70 }
71
72 static inline __u32
73 jhash_ip(const struct ip_set_ipporthash *map, uint16_t i, ip_set_ip_t ip)
74 {
75         return jhash_1word(ip, *(((uint32_t *) map->initval) + i));
76 }
77
78 #define HASH_IP(map, ip, port) (port + ((ip - ((map)->first_ip)) << 16))
79
80 static inline __u32
81 hash_id(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t port,
82         ip_set_ip_t *hash_ip)
83 {
84         struct ip_set_ipporthash *map = 
85                 (struct ip_set_ipporthash *) set->data;
86         __u32 id;
87         u_int16_t i;
88         ip_set_ip_t *elem;
89
90         *hash_ip = HASH_IP(map, ip, port);
91         DP("set: %s, ipport:%u.%u.%u.%u:%u, %u.%u.%u.%u",
92            set->name, HIPQUAD(ip), port, HIPQUAD(*hash_ip));
93         
94         for (i = 0; i < map->probes; i++) {
95                 id = jhash_ip(map, i, *hash_ip) % map->hashsize;
96                 DP("hash key: %u", id);
97                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
98                 if (*elem == *hash_ip)
99                         return id;
100                 /* No shortcut at testing - there can be deleted
101                  * entries. */
102         }
103         return UINT_MAX;
104 }
105
106 static inline int
107 __testip(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t port,
108          ip_set_ip_t *hash_ip)
109 {
110         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
111         
112         if (ip < map->first_ip || ip > map->last_ip)
113                 return -ERANGE;
114
115         return (hash_id(set, ip, port, hash_ip) != UINT_MAX);
116 }
117
118 static int
119 testip(struct ip_set *set, const void *data, size_t size,
120        ip_set_ip_t *hash_ip)
121 {
122         struct ip_set_req_ipporthash *req = 
123             (struct ip_set_req_ipporthash *) data;
124
125         if (size != sizeof(struct ip_set_req_ipporthash)) {
126                 ip_set_printk("data length wrong (want %zu, have %zu)",
127                               sizeof(struct ip_set_req_ipporthash),
128                               size);
129                 return -EINVAL;
130         }
131         return __testip(set, req->ip, req->port, hash_ip);
132 }
133
134 static int
135 testip_kernel(struct ip_set *set, 
136               const struct sk_buff *skb,
137               ip_set_ip_t *hash_ip,
138               const u_int32_t *flags,
139               unsigned char index)
140 {
141         ip_set_ip_t port;
142
143         if (flags[index+1] == 0)
144                 return -EINVAL;
145                 
146         port = get_port(skb, flags[index+1]);
147
148         DP("flag: %s src: %u.%u.%u.%u dst: %u.%u.%u.%u",
149            flags[index] & IPSET_SRC ? "SRC" : "DST",
150            NIPQUAD(skb->nh.iph->saddr),
151            NIPQUAD(skb->nh.iph->daddr));
152         DP("flag %s port %u",
153            flags[index+1] & IPSET_SRC ? "SRC" : "DST", 
154            port);       
155         if (port == INVALID_PORT)
156                 return 0;       
157
158         return __testip(set,
159                         ntohl(flags[index] & IPSET_SRC 
160                                         ? skb->nh.iph->saddr 
161                                         : skb->nh.iph->daddr),
162                         port,
163                         hash_ip);
164 }
165
166 static inline int
167 __add_haship(struct ip_set_ipporthash *map, ip_set_ip_t hash_ip)
168 {
169         __u32 probe;
170         u_int16_t i;
171         ip_set_ip_t *elem;
172
173         for (i = 0; i < map->probes; i++) {
174                 probe = jhash_ip(map, i, hash_ip) % map->hashsize;
175                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, probe);
176                 if (*elem == hash_ip)
177                         return -EEXIST;
178                 if (!*elem) {
179                         *elem = hash_ip;
180                         map->elements++;
181                         return 0;
182                 }
183         }
184         /* Trigger rehashing */
185         return -EAGAIN;
186 }
187
188 static inline int
189 __addip(struct ip_set_ipporthash *map, ip_set_ip_t ip, ip_set_ip_t port,
190         ip_set_ip_t *hash_ip)
191 {
192         if (map->elements > limit)
193                 return -ERANGE;
194         if (ip < map->first_ip || ip > map->last_ip)
195                 return -ERANGE;
196
197         *hash_ip = HASH_IP(map, ip, port);
198         
199         return __add_haship(map, *hash_ip);
200 }
201
202 static int
203 addip(struct ip_set *set, const void *data, size_t size,
204         ip_set_ip_t *hash_ip)
205 {
206         struct ip_set_req_ipporthash *req = 
207             (struct ip_set_req_ipporthash *) data;
208
209         if (size != sizeof(struct ip_set_req_ipporthash)) {
210                 ip_set_printk("data length wrong (want %zu, have %zu)",
211                               sizeof(struct ip_set_req_ipporthash),
212                               size);
213                 return -EINVAL;
214         }
215         return __addip((struct ip_set_ipporthash *) set->data, 
216                         req->ip, req->port, hash_ip);
217 }
218
219 static int
220 addip_kernel(struct ip_set *set, 
221              const struct sk_buff *skb,
222              ip_set_ip_t *hash_ip,
223              const u_int32_t *flags,
224              unsigned char index)
225 {
226         ip_set_ip_t port;
227
228         if (flags[index+1] == 0)
229                 return -EINVAL;
230                 
231         port = get_port(skb, flags[index+1]);
232
233         DP("flag: %s src: %u.%u.%u.%u dst: %u.%u.%u.%u",
234            flags[index] & IPSET_SRC ? "SRC" : "DST",
235            NIPQUAD(skb->nh.iph->saddr),
236            NIPQUAD(skb->nh.iph->daddr));
237         DP("flag %s port %u", 
238            flags[index+1] & IPSET_SRC ? "SRC" : "DST", 
239            port);       
240         if (port == INVALID_PORT)
241                 return -EINVAL; 
242
243         return __addip((struct ip_set_ipporthash *) set->data,
244                        ntohl(flags[index] & IPSET_SRC 
245                                 ? skb->nh.iph->saddr 
246                                 : skb->nh.iph->daddr),
247                        port,
248                        hash_ip);
249 }
250
251 static int retry(struct ip_set *set)
252 {
253         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
254         ip_set_ip_t *elem;
255         void *members;
256         u_int32_t i, hashsize = map->hashsize;
257         int res;
258         struct ip_set_ipporthash *tmp;
259         
260         if (map->resize == 0)
261                 return -ERANGE;
262
263     again:
264         res = 0;
265         
266         /* Calculate new hash size */
267         hashsize += (hashsize * map->resize)/100;
268         if (hashsize == map->hashsize)
269                 hashsize++;
270         
271         ip_set_printk("rehashing of set %s triggered: "
272                       "hashsize grows from %u to %u",
273                       set->name, map->hashsize, hashsize);
274
275         tmp = kmalloc(sizeof(struct ip_set_ipporthash) 
276                       + map->probes * sizeof(uint32_t), GFP_ATOMIC);
277         if (!tmp) {
278                 DP("out of memory for %d bytes",
279                    sizeof(struct ip_set_ipporthash)
280                    + map->probes * sizeof(uint32_t));
281                 return -ENOMEM;
282         }
283         tmp->members = harray_malloc(hashsize, sizeof(ip_set_ip_t), GFP_ATOMIC);
284         if (!tmp->members) {
285                 DP("out of memory for %d bytes", hashsize * sizeof(ip_set_ip_t));
286                 kfree(tmp);
287                 return -ENOMEM;
288         }
289         tmp->hashsize = hashsize;
290         tmp->elements = 0;
291         tmp->probes = map->probes;
292         tmp->resize = map->resize;
293         tmp->first_ip = map->first_ip;
294         tmp->last_ip = map->last_ip;
295         memcpy(tmp->initval, map->initval, map->probes * sizeof(uint32_t));
296         
297         write_lock_bh(&set->lock);
298         map = (struct ip_set_ipporthash *) set->data; /* Play safe */
299         for (i = 0; i < map->hashsize && res == 0; i++) {
300                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
301                 if (*elem)
302                         res = __add_haship(tmp, *elem);
303         }
304         if (res) {
305                 /* Failure, try again */
306                 write_unlock_bh(&set->lock);
307                 harray_free(tmp->members);
308                 kfree(tmp);
309                 goto again;
310         }
311         
312         /* Success at resizing! */
313         members = map->members;
314
315         map->hashsize = tmp->hashsize;
316         map->members = tmp->members;
317         write_unlock_bh(&set->lock);
318
319         harray_free(members);
320         kfree(tmp);
321
322         return 0;
323 }
324
325 static inline int
326 __delip(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t port,
327         ip_set_ip_t *hash_ip)
328 {
329         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
330         ip_set_ip_t id;
331         ip_set_ip_t *elem;
332
333         if (ip < map->first_ip || ip > map->last_ip)
334                 return -ERANGE;
335
336         id = hash_id(set, ip, port, hash_ip);
337
338         if (id == UINT_MAX)
339                 return -EEXIST;
340                 
341         elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
342         *elem = 0;
343         map->elements--;
344
345         return 0;
346 }
347
348 static int
349 delip(struct ip_set *set, const void *data, size_t size,
350         ip_set_ip_t *hash_ip)
351 {
352         struct ip_set_req_ipporthash *req =
353             (struct ip_set_req_ipporthash *) data;
354
355         if (size != sizeof(struct ip_set_req_ipporthash)) {
356                 ip_set_printk("data length wrong (want %zu, have %zu)",
357                               sizeof(struct ip_set_req_ipporthash),
358                               size);
359                 return -EINVAL;
360         }
361         return __delip(set, req->ip, req->port, hash_ip);
362 }
363
364 static int
365 delip_kernel(struct ip_set *set, 
366              const struct sk_buff *skb,
367              ip_set_ip_t *hash_ip,
368              const u_int32_t *flags,
369              unsigned char index)
370 {
371         ip_set_ip_t port;
372
373         if (flags[index+1] == 0)
374                 return -EINVAL;
375                 
376         port = get_port(skb, flags[index+1]);
377
378         DP("flag: %s src: %u.%u.%u.%u dst: %u.%u.%u.%u",
379            flags[index] & IPSET_SRC ? "SRC" : "DST",
380            NIPQUAD(skb->nh.iph->saddr),
381            NIPQUAD(skb->nh.iph->daddr));
382         DP("flag %s port %u",
383            flags[index+1] & IPSET_SRC ? "SRC" : "DST", 
384            port);       
385         if (port == INVALID_PORT)
386                 return -EINVAL; 
387
388         return __delip(set,
389                        ntohl(flags[index] & IPSET_SRC 
390                                 ? skb->nh.iph->saddr 
391                                 : skb->nh.iph->daddr),
392                        port,
393                        hash_ip);
394 }
395
396 static int create(struct ip_set *set, const void *data, size_t size)
397 {
398         struct ip_set_req_ipporthash_create *req =
399             (struct ip_set_req_ipporthash_create *) data;
400         struct ip_set_ipporthash *map;
401         uint16_t i;
402
403         if (size != sizeof(struct ip_set_req_ipporthash_create)) {
404                 ip_set_printk("data length wrong (want %zu, have %zu)",
405                                sizeof(struct ip_set_req_ipporthash_create),
406                                size);
407                 return -EINVAL;
408         }
409
410         if (req->hashsize < 1) {
411                 ip_set_printk("hashsize too small");
412                 return -ENOEXEC;
413         }
414
415         if (req->probes < 1) {
416                 ip_set_printk("probes too small");
417                 return -ENOEXEC;
418         }
419
420         map = kmalloc(sizeof(struct ip_set_ipporthash) 
421                       + req->probes * sizeof(uint32_t), GFP_KERNEL);
422         if (!map) {
423                 DP("out of memory for %d bytes",
424                    sizeof(struct ip_set_ipporthash)
425                    + req->probes * sizeof(uint32_t));
426                 return -ENOMEM;
427         }
428         for (i = 0; i < req->probes; i++)
429                 get_random_bytes(((uint32_t *) map->initval)+i, 4);
430         map->elements = 0;
431         map->hashsize = req->hashsize;
432         map->probes = req->probes;
433         map->resize = req->resize;
434         map->first_ip = req->from;
435         map->last_ip = req->to;
436         map->members = harray_malloc(map->hashsize, sizeof(ip_set_ip_t), GFP_KERNEL);
437         if (!map->members) {
438                 DP("out of memory for %d bytes", map->hashsize * sizeof(ip_set_ip_t));
439                 kfree(map);
440                 return -ENOMEM;
441         }
442
443         set->data = map;
444         return 0;
445 }
446
447 static void destroy(struct ip_set *set)
448 {
449         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
450
451         harray_free(map->members);
452         kfree(map);
453
454         set->data = NULL;
455 }
456
457 static void flush(struct ip_set *set)
458 {
459         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
460         harray_flush(map->members, map->hashsize, sizeof(ip_set_ip_t));
461         map->elements = 0;
462 }
463
464 static void list_header(const struct ip_set *set, void *data)
465 {
466         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
467         struct ip_set_req_ipporthash_create *header =
468             (struct ip_set_req_ipporthash_create *) data;
469
470         header->hashsize = map->hashsize;
471         header->probes = map->probes;
472         header->resize = map->resize;
473         header->from = map->first_ip;
474         header->to = map->last_ip;
475 }
476
477 static int list_members_size(const struct ip_set *set)
478 {
479         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
480
481         return (map->hashsize * sizeof(ip_set_ip_t));
482 }
483
484 static void list_members(const struct ip_set *set, void *data)
485 {
486         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
487         ip_set_ip_t i, *elem;
488
489         for (i = 0; i < map->hashsize; i++) {
490                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
491                 ((ip_set_ip_t *)data)[i] = *elem;
492         }
493 }
494
495 static struct ip_set_type ip_set_ipporthash = {
496         .typename               = SETTYPE_NAME,
497         .features               = IPSET_TYPE_IP | IPSET_TYPE_PORT | IPSET_DATA_DOUBLE,
498         .protocol_version       = IP_SET_PROTOCOL_VERSION,
499         .create                 = &create,
500         .destroy                = &destroy,
501         .flush                  = &flush,
502         .reqsize                = sizeof(struct ip_set_req_ipporthash),
503         .addip                  = &addip,
504         .addip_kernel           = &addip_kernel,
505         .retry                  = &retry,
506         .delip                  = &delip,
507         .delip_kernel           = &delip_kernel,
508         .testip                 = &testip,
509         .testip_kernel          = &testip_kernel,
510         .header_size            = sizeof(struct ip_set_req_ipporthash_create),
511         .list_header            = &list_header,
512         .list_members_size      = &list_members_size,
513         .list_members           = &list_members,
514         .me                     = THIS_MODULE,
515 };
516
517 MODULE_LICENSE("GPL");
518 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
519 MODULE_DESCRIPTION("ipporthash type of IP sets");
520 module_param(limit, int, 0600);
521 MODULE_PARM_DESC(limit, "maximal number of elements stored in the sets");
522
523 static int __init init(void)
524 {
525         return ip_set_register_set_type(&ip_set_ipporthash);
526 }
527
528 static void __exit fini(void)
529 {
530         /* FIXME: possible race with ip_set_create() */
531         ip_set_unregister_set_type(&ip_set_ipporthash);
532 }
533
534 module_init(init);
535 module_exit(fini);