This commit was manufactured by cvs2svn to create branch 'vserver'.
[linux-2.6.git] / net / ipv4 / netfilter / ip_set_ipporthash.c
1 /* Copyright (C) 2003-2004 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.  
6  */
7
8 /* Kernel module implementing an ip+port hash set */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/tcp.h>
13 #include <linux/udp.h>
14 #include <linux/skbuff.h>
15 #include <linux/netfilter_ipv4/ip_tables.h>
16 #include <linux/netfilter_ipv4/ip_set.h>
17 #include <linux/errno.h>
18 #include <asm/uaccess.h>
19 #include <asm/bitops.h>
20 #include <linux/spinlock.h>
21 #include <linux/vmalloc.h>
22 #include <linux/random.h>
23
24 #include <net/ip.h>
25
26 #include <linux/netfilter_ipv4/ip_set_malloc.h>
27 #include <linux/netfilter_ipv4/ip_set_ipporthash.h>
28 #include <linux/netfilter_ipv4/ip_set_jhash.h>
29
30 /* We must handle non-linear skbs */
31 static inline ip_set_ip_t
32 get_port(const struct sk_buff *skb, u_int32_t flags)
33 {
34         struct iphdr *iph = skb->nh.iph;
35         u_int16_t offset = ntohs(iph->frag_off) & IP_OFFSET;
36
37         switch (iph->protocol) {
38         case IPPROTO_TCP: {
39                 struct tcphdr tcph;
40                 
41                 /* See comments at tcp_match in ip_tables.c */
42                 if (offset)
43                         return INVALID_PORT;
44
45                 if (skb_copy_bits(skb, skb->nh.iph->ihl*4, &tcph, sizeof(tcph)) < 0)
46                         /* No choice either */
47                         return INVALID_PORT;
48                 
49                 return ntohs(flags & IPSET_SRC ?
50                              tcph.source : tcph.dest);
51             }
52         case IPPROTO_UDP: {
53                 struct udphdr udph;
54
55                 if (offset)
56                         return INVALID_PORT;
57
58                 if (skb_copy_bits(skb, skb->nh.iph->ihl*4, &udph, sizeof(udph)) < 0)
59                         /* No choice either */
60                         return INVALID_PORT;
61                 
62                 return ntohs(flags & IPSET_SRC ?
63                              udph.source : udph.dest);
64             }
65         default:
66                 return INVALID_PORT;
67         }
68 }
69
70 static inline __u32
71 jhash_ip(const struct ip_set_ipporthash *map, uint16_t i, ip_set_ip_t ip)
72 {
73         return jhash_1word(ip, *(((uint32_t *) map->initval) + i));
74 }
75
76 #define HASH_IP(map, ip, port) (port + ((ip - ((map)->first_ip)) << 16))
77
78 static inline __u32
79 hash_id(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t port,
80         ip_set_ip_t *hash_ip)
81 {
82         struct ip_set_ipporthash *map = 
83                 (struct ip_set_ipporthash *) set->data;
84         __u32 id;
85         u_int16_t i;
86         ip_set_ip_t *elem;
87
88         *hash_ip = HASH_IP(map, ip, port);
89         DP("set: %s, ipport:%u.%u.%u.%u:%u, %u.%u.%u.%u",
90            set->name, HIPQUAD(ip), port, HIPQUAD(*hash_ip));
91         
92         for (i = 0; i < map->probes; i++) {
93                 id = jhash_ip(map, i, *hash_ip) % map->hashsize;
94                 DP("hash key: %u", id);
95                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
96                 if (*elem == *hash_ip)
97                         return id;
98                 /* No shortcut at testing - there can be deleted
99                  * entries. */
100         }
101         return UINT_MAX;
102 }
103
104 static inline int
105 __testip(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t port,
106          ip_set_ip_t *hash_ip)
107 {
108         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
109         
110         if (ip < map->first_ip || ip > map->last_ip)
111                 return -ERANGE;
112
113         return (hash_id(set, ip, port, hash_ip) != UINT_MAX);
114 }
115
116 static int
117 testip(struct ip_set *set, const void *data, size_t size,
118        ip_set_ip_t *hash_ip)
119 {
120         struct ip_set_req_ipporthash *req = 
121             (struct ip_set_req_ipporthash *) data;
122
123         if (size != sizeof(struct ip_set_req_ipporthash)) {
124                 ip_set_printk("data length wrong (want %zu, have %zu)",
125                               sizeof(struct ip_set_req_ipporthash),
126                               size);
127                 return -EINVAL;
128         }
129         return __testip(set, req->ip, req->port, hash_ip);
130 }
131
132 static int
133 testip_kernel(struct ip_set *set, 
134               const struct sk_buff *skb,
135               ip_set_ip_t *hash_ip,
136               const u_int32_t *flags,
137               unsigned char index)
138 {
139         ip_set_ip_t port;
140
141         if (flags[index+1] == 0)
142                 return -EINVAL;
143                 
144         port = get_port(skb, flags[index+1]);
145
146         DP("flag: %s src: %u.%u.%u.%u dst: %u.%u.%u.%u",
147            flags[index] & IPSET_SRC ? "SRC" : "DST",
148            NIPQUAD(skb->nh.iph->saddr),
149            NIPQUAD(skb->nh.iph->daddr));
150         DP("flag %s port %u",
151            flags[index+1] & IPSET_SRC ? "SRC" : "DST", 
152            port);       
153         if (port == INVALID_PORT)
154                 return 0;       
155
156         return __testip(set,
157                         ntohl(flags[index] & IPSET_SRC 
158                                         ? skb->nh.iph->saddr 
159                                         : skb->nh.iph->daddr),
160                         port,
161                         hash_ip);
162 }
163
164 static inline int
165 __add_haship(struct ip_set_ipporthash *map, ip_set_ip_t hash_ip)
166 {
167         __u32 probe;
168         u_int16_t i;
169         ip_set_ip_t *elem;
170
171         for (i = 0; i < map->probes; i++) {
172                 probe = jhash_ip(map, i, hash_ip) % map->hashsize;
173                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, probe);
174                 if (*elem == hash_ip)
175                         return -EEXIST;
176                 if (!*elem) {
177                         *elem = hash_ip;
178                         return 0;
179                 }
180         }
181         /* Trigger rehashing */
182         return -EAGAIN;
183 }
184
185 static inline int
186 __addip(struct ip_set_ipporthash *map, ip_set_ip_t ip, ip_set_ip_t port,
187         ip_set_ip_t *hash_ip)
188 {
189         if (ip < map->first_ip || ip > map->last_ip)
190                 return -ERANGE;
191
192         *hash_ip = HASH_IP(map, ip, port);
193         
194         return __add_haship(map, *hash_ip);
195 }
196
197 static int
198 addip(struct ip_set *set, const void *data, size_t size,
199         ip_set_ip_t *hash_ip)
200 {
201         struct ip_set_req_ipporthash *req = 
202             (struct ip_set_req_ipporthash *) data;
203
204         if (size != sizeof(struct ip_set_req_ipporthash)) {
205                 ip_set_printk("data length wrong (want %zu, have %zu)",
206                               sizeof(struct ip_set_req_ipporthash),
207                               size);
208                 return -EINVAL;
209         }
210         return __addip((struct ip_set_ipporthash *) set->data, 
211                         req->ip, req->port, hash_ip);
212 }
213
214 static int
215 addip_kernel(struct ip_set *set, 
216              const struct sk_buff *skb,
217              ip_set_ip_t *hash_ip,
218              const u_int32_t *flags,
219              unsigned char index)
220 {
221         ip_set_ip_t port;
222
223         if (flags[index+1] == 0)
224                 return -EINVAL;
225                 
226         port = get_port(skb, flags[index+1]);
227
228         DP("flag: %s src: %u.%u.%u.%u dst: %u.%u.%u.%u",
229            flags[index] & IPSET_SRC ? "SRC" : "DST",
230            NIPQUAD(skb->nh.iph->saddr),
231            NIPQUAD(skb->nh.iph->daddr));
232         DP("flag %s port %u", 
233            flags[index+1] & IPSET_SRC ? "SRC" : "DST", 
234            port);       
235         if (port == INVALID_PORT)
236                 return -EINVAL; 
237
238         return __addip((struct ip_set_ipporthash *) set->data,
239                        ntohl(flags[index] & IPSET_SRC 
240                                 ? skb->nh.iph->saddr 
241                                 : skb->nh.iph->daddr),
242                        port,
243                        hash_ip);
244 }
245
246 static int retry(struct ip_set *set)
247 {
248         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
249         ip_set_ip_t *elem;
250         void *members;
251         u_int32_t i, hashsize = map->hashsize;
252         int res;
253         struct ip_set_ipporthash *tmp;
254         
255         if (map->resize == 0)
256                 return -ERANGE;
257
258     again:
259         res = 0;
260         
261         /* Calculate new hash size */
262         hashsize += (hashsize * map->resize)/100;
263         if (hashsize == map->hashsize)
264                 hashsize++;
265         
266         ip_set_printk("rehashing of set %s triggered: "
267                       "hashsize grows from %u to %u",
268                       set->name, map->hashsize, hashsize);
269
270         tmp = kmalloc(sizeof(struct ip_set_ipporthash) 
271                       + map->probes * sizeof(uint32_t), GFP_ATOMIC);
272         if (!tmp) {
273                 DP("out of memory for %d bytes",
274                    sizeof(struct ip_set_ipporthash)
275                    + map->probes * sizeof(uint32_t));
276                 return -ENOMEM;
277         }
278         tmp->members = harray_malloc(hashsize, sizeof(ip_set_ip_t), GFP_ATOMIC);
279         if (!tmp->members) {
280                 DP("out of memory for %d bytes", hashsize * sizeof(ip_set_ip_t));
281                 kfree(tmp);
282                 return -ENOMEM;
283         }
284         tmp->hashsize = hashsize;
285         tmp->probes = map->probes;
286         tmp->resize = map->resize;
287         tmp->first_ip = map->first_ip;
288         tmp->last_ip = map->last_ip;
289         memcpy(tmp->initval, map->initval, map->probes * sizeof(uint32_t));
290         
291         write_lock_bh(&set->lock);
292         map = (struct ip_set_ipporthash *) set->data; /* Play safe */
293         for (i = 0; i < map->hashsize && res == 0; i++) {
294                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
295                 if (*elem)
296                         res = __add_haship(tmp, *elem);
297         }
298         if (res) {
299                 /* Failure, try again */
300                 write_unlock_bh(&set->lock);
301                 harray_free(tmp->members);
302                 kfree(tmp);
303                 goto again;
304         }
305         
306         /* Success at resizing! */
307         members = map->members;
308
309         map->hashsize = tmp->hashsize;
310         map->members = tmp->members;
311         write_unlock_bh(&set->lock);
312
313         harray_free(members);
314         kfree(tmp);
315
316         return 0;
317 }
318
319 static inline int
320 __delip(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t port,
321         ip_set_ip_t *hash_ip)
322 {
323         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
324         ip_set_ip_t id;
325         ip_set_ip_t *elem;
326
327         if (ip < map->first_ip || ip > map->last_ip)
328                 return -ERANGE;
329
330         id = hash_id(set, ip, port, hash_ip);
331
332         if (id == UINT_MAX)
333                 return -EEXIST;
334                 
335         elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
336         *elem = 0;
337
338         return 0;
339 }
340
341 static int
342 delip(struct ip_set *set, const void *data, size_t size,
343         ip_set_ip_t *hash_ip)
344 {
345         struct ip_set_req_ipporthash *req =
346             (struct ip_set_req_ipporthash *) data;
347
348         if (size != sizeof(struct ip_set_req_ipporthash)) {
349                 ip_set_printk("data length wrong (want %zu, have %zu)",
350                               sizeof(struct ip_set_req_ipporthash),
351                               size);
352                 return -EINVAL;
353         }
354         return __delip(set, req->ip, req->port, hash_ip);
355 }
356
357 static int
358 delip_kernel(struct ip_set *set, 
359              const struct sk_buff *skb,
360              ip_set_ip_t *hash_ip,
361              const u_int32_t *flags,
362              unsigned char index)
363 {
364         ip_set_ip_t port;
365
366         if (flags[index+1] == 0)
367                 return -EINVAL;
368                 
369         port = get_port(skb, flags[index+1]);
370
371         DP("flag: %s src: %u.%u.%u.%u dst: %u.%u.%u.%u",
372            flags[index] & IPSET_SRC ? "SRC" : "DST",
373            NIPQUAD(skb->nh.iph->saddr),
374            NIPQUAD(skb->nh.iph->daddr));
375         DP("flag %s port %u",
376            flags[index+1] & IPSET_SRC ? "SRC" : "DST", 
377            port);       
378         if (port == INVALID_PORT)
379                 return -EINVAL; 
380
381         return __delip(set,
382                        ntohl(flags[index] & IPSET_SRC 
383                                 ? skb->nh.iph->saddr 
384                                 : skb->nh.iph->daddr),
385                        port,
386                        hash_ip);
387 }
388
389 static int create(struct ip_set *set, const void *data, size_t size)
390 {
391         struct ip_set_req_ipporthash_create *req =
392             (struct ip_set_req_ipporthash_create *) data;
393         struct ip_set_ipporthash *map;
394         uint16_t i;
395
396         if (size != sizeof(struct ip_set_req_ipporthash_create)) {
397                 ip_set_printk("data length wrong (want %zu, have %zu)",
398                                sizeof(struct ip_set_req_ipporthash_create),
399                                size);
400                 return -EINVAL;
401         }
402
403         if (req->hashsize < 1) {
404                 ip_set_printk("hashsize too small");
405                 return -ENOEXEC;
406         }
407
408         if (req->probes < 1) {
409                 ip_set_printk("probes too small");
410                 return -ENOEXEC;
411         }
412
413         map = kmalloc(sizeof(struct ip_set_ipporthash) 
414                       + req->probes * sizeof(uint32_t), GFP_KERNEL);
415         if (!map) {
416                 DP("out of memory for %d bytes",
417                    sizeof(struct ip_set_ipporthash)
418                    + req->probes * sizeof(uint32_t));
419                 return -ENOMEM;
420         }
421         for (i = 0; i < req->probes; i++)
422                 get_random_bytes(((uint32_t *) map->initval)+i, 4);
423         map->hashsize = req->hashsize;
424         map->probes = req->probes;
425         map->resize = req->resize;
426         map->first_ip = req->from;
427         map->last_ip = req->to;
428         map->members = harray_malloc(map->hashsize, sizeof(ip_set_ip_t), GFP_KERNEL);
429         if (!map->members) {
430                 DP("out of memory for %d bytes", map->hashsize * sizeof(ip_set_ip_t));
431                 kfree(map);
432                 return -ENOMEM;
433         }
434
435         set->data = map;
436         return 0;
437 }
438
439 static void destroy(struct ip_set *set)
440 {
441         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
442
443         harray_free(map->members);
444         kfree(map);
445
446         set->data = NULL;
447 }
448
449 static void flush(struct ip_set *set)
450 {
451         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
452         harray_flush(map->members, map->hashsize, sizeof(ip_set_ip_t));
453 }
454
455 static void list_header(const struct ip_set *set, void *data)
456 {
457         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
458         struct ip_set_req_ipporthash_create *header =
459             (struct ip_set_req_ipporthash_create *) data;
460
461         header->hashsize = map->hashsize;
462         header->probes = map->probes;
463         header->resize = map->resize;
464         header->from = map->first_ip;
465         header->to = map->last_ip;
466 }
467
468 static int list_members_size(const struct ip_set *set)
469 {
470         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
471
472         return (map->hashsize * sizeof(ip_set_ip_t));
473 }
474
475 static void list_members(const struct ip_set *set, void *data)
476 {
477         struct ip_set_ipporthash *map = (struct ip_set_ipporthash *) set->data;
478         ip_set_ip_t i, *elem;
479
480         for (i = 0; i < map->hashsize; i++) {
481                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
482                 ((ip_set_ip_t *)data)[i] = *elem;
483         }
484 }
485
486 static struct ip_set_type ip_set_ipporthash = {
487         .typename               = SETTYPE_NAME,
488         .features               = IPSET_TYPE_IP | IPSET_TYPE_PORT | IPSET_DATA_DOUBLE,
489         .protocol_version       = IP_SET_PROTOCOL_VERSION,
490         .create                 = &create,
491         .destroy                = &destroy,
492         .flush                  = &flush,
493         .reqsize                = sizeof(struct ip_set_req_ipporthash),
494         .addip                  = &addip,
495         .addip_kernel           = &addip_kernel,
496         .retry                  = &retry,
497         .delip                  = &delip,
498         .delip_kernel           = &delip_kernel,
499         .testip                 = &testip,
500         .testip_kernel          = &testip_kernel,
501         .header_size            = sizeof(struct ip_set_req_ipporthash_create),
502         .list_header            = &list_header,
503         .list_members_size      = &list_members_size,
504         .list_members           = &list_members,
505         .me                     = THIS_MODULE,
506 };
507
508 MODULE_LICENSE("GPL");
509 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
510 MODULE_DESCRIPTION("ipporthash type of IP sets");
511
512 static int __init init(void)
513 {
514         return ip_set_register_set_type(&ip_set_ipporthash);
515 }
516
517 static void __exit fini(void)
518 {
519         /* FIXME: possible race with ip_set_create() */
520         ip_set_unregister_set_type(&ip_set_ipporthash);
521 }
522
523 module_init(init);
524 module_exit(fini);