This commit was manufactured by cvs2svn to create branch 'vserver'.
[linux-2.6.git] / net / ipv4 / netfilter / ip_set_nethash.c
1 /* Copyright (C) 2003-2004 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.  
6  */
7
8 /* Kernel module implementing a cidr nethash set */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/skbuff.h>
13 #include <linux/netfilter_ipv4/ip_tables.h>
14 #include <linux/netfilter_ipv4/ip_set.h>
15 #include <linux/errno.h>
16 #include <asm/uaccess.h>
17 #include <asm/bitops.h>
18 #include <linux/spinlock.h>
19 #include <linux/vmalloc.h>
20 #include <linux/random.h>
21
22 #include <net/ip.h>
23
24 #include <linux/netfilter_ipv4/ip_set_malloc.h>
25 #include <linux/netfilter_ipv4/ip_set_nethash.h>
26 #include <linux/netfilter_ipv4/ip_set_jhash.h>
27
28 static inline __u32
29 jhash_ip(const struct ip_set_nethash *map, uint16_t i, ip_set_ip_t ip)
30 {
31         return jhash_1word(ip, *(((uint32_t *) map->initval) + i));
32 }
33
34 static inline __u32
35 hash_id_cidr(struct ip_set_nethash *map,
36              ip_set_ip_t ip,
37              unsigned char cidr,
38              ip_set_ip_t *hash_ip)
39 {
40         __u32 id;
41         u_int16_t i;
42         ip_set_ip_t *elem;
43
44         *hash_ip = pack(ip, cidr);
45         
46         for (i = 0; i < map->probes; i++) {
47                 id = jhash_ip(map, i, *hash_ip) % map->hashsize;
48                 DP("hash key: %u", id);
49                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
50                 if (*elem == *hash_ip)
51                         return id;
52         }
53         return UINT_MAX;
54 }
55
56 static inline __u32
57 hash_id(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t *hash_ip)
58 {
59         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
60         __u32 id = UINT_MAX;
61         int i;
62
63         for (i = 0; i < 30 && map->cidr[i]; i++) {
64                 id = hash_id_cidr(map, ip, map->cidr[i], hash_ip);
65                 if (id != UINT_MAX)
66                         break;
67         }
68         return id;
69 }
70
71 static inline int
72 __testip_cidr(struct ip_set *set, ip_set_ip_t ip, unsigned char cidr,
73               ip_set_ip_t *hash_ip)
74 {
75         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
76
77         return (hash_id_cidr(map, ip, cidr, hash_ip) != UINT_MAX);
78 }
79
80 static inline int
81 __testip(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t *hash_ip)
82 {
83         return (hash_id(set, ip, hash_ip) != UINT_MAX);
84 }
85
86 static int
87 testip(struct ip_set *set, const void *data, size_t size,
88        ip_set_ip_t *hash_ip)
89 {
90         struct ip_set_req_nethash *req = 
91             (struct ip_set_req_nethash *) data;
92
93         if (size != sizeof(struct ip_set_req_nethash)) {
94                 ip_set_printk("data length wrong (want %zu, have %zu)",
95                               sizeof(struct ip_set_req_nethash),
96                               size);
97                 return -EINVAL;
98         }
99         return (req->cidr == 32 ? __testip(set, req->ip, hash_ip)
100                 : __testip_cidr(set, req->ip, req->cidr, hash_ip));
101 }
102
103 static int
104 testip_kernel(struct ip_set *set, 
105               const struct sk_buff *skb,
106               ip_set_ip_t *hash_ip,
107               const u_int32_t *flags,
108               unsigned char index)
109 {
110         return __testip(set,
111                         ntohl(flags[index] & IPSET_SRC 
112                                 ? skb->nh.iph->saddr 
113                                 : skb->nh.iph->daddr),
114                         hash_ip);
115 }
116
117 static inline int
118 __addip_base(struct ip_set_nethash *map, ip_set_ip_t ip)
119 {
120         __u32 probe;
121         u_int16_t i;
122         ip_set_ip_t *elem;
123         
124         for (i = 0; i < map->probes; i++) {
125                 probe = jhash_ip(map, i, ip) % map->hashsize;
126                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, probe);
127                 if (*elem == ip)
128                         return -EEXIST;
129                 if (!*elem) {
130                         *elem = ip;
131                         return 0;
132                 }
133         }
134         /* Trigger rehashing */
135         return -EAGAIN;
136 }
137
138 static inline int
139 __addip(struct ip_set_nethash *map, ip_set_ip_t ip, unsigned char cidr,
140         ip_set_ip_t *hash_ip)
141 {
142         *hash_ip = pack(ip, cidr);
143         DP("%u.%u.%u.%u/%u, %u.%u.%u.%u", HIPQUAD(ip), cidr, HIPQUAD(*hash_ip));
144         
145         return __addip_base(map, *hash_ip);
146 }
147
148 static void
149 update_cidr_sizes(struct ip_set_nethash *map, unsigned char cidr)
150 {
151         unsigned char next;
152         int i;
153         
154         for (i = 0; i < 30 && map->cidr[i]; i++) {
155                 if (map->cidr[i] == cidr) {
156                         return;
157                 } else if (map->cidr[i] < cidr) {
158                         next = map->cidr[i];
159                         map->cidr[i] = cidr;
160                         cidr = next;
161                 }
162         }
163         if (i < 30)
164                 map->cidr[i] = cidr;
165 }
166
167 static int
168 addip(struct ip_set *set, const void *data, size_t size,
169         ip_set_ip_t *hash_ip)
170 {
171         struct ip_set_req_nethash *req = 
172             (struct ip_set_req_nethash *) data;
173         int ret;
174
175         if (size != sizeof(struct ip_set_req_nethash)) {
176                 ip_set_printk("data length wrong (want %zu, have %zu)",
177                               sizeof(struct ip_set_req_nethash),
178                               size);
179                 return -EINVAL;
180         }
181         ret = __addip((struct ip_set_nethash *) set->data, 
182                       req->ip, req->cidr, hash_ip);
183         
184         if (ret == 0)
185                 update_cidr_sizes((struct ip_set_nethash *) set->data,
186                                   req->cidr);
187         
188         return ret;
189 }
190
191 static int
192 addip_kernel(struct ip_set *set, 
193              const struct sk_buff *skb,
194              ip_set_ip_t *hash_ip,
195              const u_int32_t *flags,
196              unsigned char index)
197 {
198         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
199         int ret = -ERANGE;
200         ip_set_ip_t ip = ntohl(flags[index] & IPSET_SRC 
201                                         ? skb->nh.iph->saddr
202                                         : skb->nh.iph->daddr);
203         
204         if (map->cidr[0])
205                 ret = __addip(map, ip, map->cidr[0], hash_ip);
206                 
207         return ret;
208 }
209
210 static int retry(struct ip_set *set)
211 {
212         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
213         ip_set_ip_t *elem;
214         void *members;
215         u_int32_t i, hashsize = map->hashsize;
216         int res;
217         struct ip_set_nethash *tmp;
218         
219         if (map->resize == 0)
220                 return -ERANGE;
221
222     again:
223         res = 0;
224         
225         /* Calculate new parameters */
226         hashsize += (hashsize * map->resize)/100;
227         if (hashsize == map->hashsize)
228                 hashsize++;
229         
230         ip_set_printk("rehashing of set %s triggered: "
231                       "hashsize grows from %u to %u",
232                       set->name, map->hashsize, hashsize);
233
234         tmp = kmalloc(sizeof(struct ip_set_nethash) 
235                       + map->probes * sizeof(uint32_t), GFP_ATOMIC);
236         if (!tmp) {
237                 DP("out of memory for %d bytes",
238                    sizeof(struct ip_set_nethash)
239                    + map->probes * sizeof(uint32_t));
240                 return -ENOMEM;
241         }
242         tmp->members = harray_malloc(hashsize, sizeof(ip_set_ip_t), GFP_ATOMIC);
243         if (!tmp->members) {
244                 DP("out of memory for %d bytes", hashsize * sizeof(ip_set_ip_t));
245                 kfree(tmp);
246                 return -ENOMEM;
247         }
248         tmp->hashsize = hashsize;
249         tmp->probes = map->probes;
250         tmp->resize = map->resize;
251         memcpy(tmp->initval, map->initval, map->probes * sizeof(uint32_t));
252         memcpy(tmp->cidr, map->cidr, 30 * sizeof(unsigned char));
253         
254         write_lock_bh(&set->lock);
255         map = (struct ip_set_nethash *) set->data; /* Play safe */
256         for (i = 0; i < map->hashsize && res == 0; i++) {
257                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
258                 if (*elem)
259                         res = __addip_base(tmp, *elem);
260         }
261         if (res) {
262                 /* Failure, try again */
263                 write_unlock_bh(&set->lock);
264                 harray_free(tmp->members);
265                 kfree(tmp);
266                 goto again;
267         }
268         
269         /* Success at resizing! */
270         members = map->members;
271         
272         map->hashsize = tmp->hashsize;
273         map->members = tmp->members;
274         write_unlock_bh(&set->lock);
275
276         harray_free(members);
277         kfree(tmp);
278
279         return 0;
280 }
281
282 static inline int
283 __delip(struct ip_set_nethash *map, ip_set_ip_t ip, unsigned char cidr,
284         ip_set_ip_t *hash_ip)
285 {
286         ip_set_ip_t id = hash_id_cidr(map, ip, cidr, hash_ip);
287         ip_set_ip_t *elem;
288
289         if (id == UINT_MAX)
290                 return -EEXIST;
291                 
292         elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
293         *elem = 0;
294         return 0;
295 }
296
297 static int
298 delip(struct ip_set *set, const void *data, size_t size,
299         ip_set_ip_t *hash_ip)
300 {
301         struct ip_set_req_nethash *req =
302             (struct ip_set_req_nethash *) data;
303
304         if (size != sizeof(struct ip_set_req_nethash)) {
305                 ip_set_printk("data length wrong (want %zu, have %zu)",
306                               sizeof(struct ip_set_req_nethash),
307                               size);
308                 return -EINVAL;
309         }
310         /* TODO: no garbage collection in map->cidr */          
311         return __delip((struct ip_set_nethash *) set->data, 
312                        req->ip, req->cidr, hash_ip);
313 }
314
315 static int
316 delip_kernel(struct ip_set *set, 
317              const struct sk_buff *skb,
318              ip_set_ip_t *hash_ip,
319              const u_int32_t *flags,
320              unsigned char index)
321 {
322         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
323         int ret = -ERANGE;
324         ip_set_ip_t ip = ntohl(flags[index] & IPSET_SRC 
325                                         ? skb->nh.iph->saddr
326                                         : skb->nh.iph->daddr);
327         
328         if (map->cidr[0])
329                 ret = __delip(map, ip, map->cidr[0], hash_ip);
330         
331         return ret;
332 }
333
334 static int create(struct ip_set *set, const void *data, size_t size)
335 {
336         struct ip_set_req_nethash_create *req =
337             (struct ip_set_req_nethash_create *) data;
338         struct ip_set_nethash *map;
339         uint16_t i;
340
341         if (size != sizeof(struct ip_set_req_nethash_create)) {
342                 ip_set_printk("data length wrong (want %zu, have %zu)",
343                                sizeof(struct ip_set_req_nethash_create),
344                                size);
345                 return -EINVAL;
346         }
347
348         if (req->hashsize < 1) {
349                 ip_set_printk("hashsize too small");
350                 return -ENOEXEC;
351         }
352         if (req->probes < 1) {
353                 ip_set_printk("probes too small");
354                 return -ENOEXEC;
355         }
356
357         map = kmalloc(sizeof(struct ip_set_nethash)
358                       + req->probes * sizeof(uint32_t), GFP_KERNEL);
359         if (!map) {
360                 DP("out of memory for %d bytes",
361                    sizeof(struct ip_set_nethash)
362                    + req->probes * sizeof(uint32_t));
363                 return -ENOMEM;
364         }
365         for (i = 0; i < req->probes; i++)
366                 get_random_bytes(((uint32_t *) map->initval)+i, 4);
367         map->hashsize = req->hashsize;
368         map->probes = req->probes;
369         map->resize = req->resize;
370         memset(map->cidr, 0, 30 * sizeof(unsigned char));
371         map->members = harray_malloc(map->hashsize, sizeof(ip_set_ip_t), GFP_KERNEL);
372         if (!map->members) {
373                 DP("out of memory for %d bytes", map->hashsize * sizeof(ip_set_ip_t));
374                 kfree(map);
375                 return -ENOMEM;
376         }
377         
378         set->data = map;
379         return 0;
380 }
381
382 static void destroy(struct ip_set *set)
383 {
384         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
385
386         harray_free(map->members);
387         kfree(map);
388
389         set->data = NULL;
390 }
391
392 static void flush(struct ip_set *set)
393 {
394         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
395         harray_flush(map->members, map->hashsize, sizeof(ip_set_ip_t));
396         memset(map->cidr, 0, 30 * sizeof(unsigned char));
397 }
398
399 static void list_header(const struct ip_set *set, void *data)
400 {
401         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
402         struct ip_set_req_nethash_create *header =
403             (struct ip_set_req_nethash_create *) data;
404
405         header->hashsize = map->hashsize;
406         header->probes = map->probes;
407         header->resize = map->resize;
408 }
409
410 static int list_members_size(const struct ip_set *set)
411 {
412         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
413
414         return (map->hashsize * sizeof(ip_set_ip_t));
415 }
416
417 static void list_members(const struct ip_set *set, void *data)
418 {
419         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
420         ip_set_ip_t i, *elem;
421
422         for (i = 0; i < map->hashsize; i++) {
423                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
424                 ((ip_set_ip_t *)data)[i] = *elem;
425         }
426 }
427
428 static struct ip_set_type ip_set_nethash = {
429         .typename               = SETTYPE_NAME,
430         .features               = IPSET_TYPE_IP | IPSET_DATA_SINGLE,
431         .protocol_version       = IP_SET_PROTOCOL_VERSION,
432         .create                 = &create,
433         .destroy                = &destroy,
434         .flush                  = &flush,
435         .reqsize                = sizeof(struct ip_set_req_nethash),
436         .addip                  = &addip,
437         .addip_kernel           = &addip_kernel,
438         .retry                  = &retry,
439         .delip                  = &delip,
440         .delip_kernel           = &delip_kernel,
441         .testip                 = &testip,
442         .testip_kernel          = &testip_kernel,
443         .header_size            = sizeof(struct ip_set_req_nethash_create),
444         .list_header            = &list_header,
445         .list_members_size      = &list_members_size,
446         .list_members           = &list_members,
447         .me                     = THIS_MODULE,
448 };
449
450 MODULE_LICENSE("GPL");
451 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
452 MODULE_DESCRIPTION("nethash type of IP sets");
453
454 static int __init init(void)
455 {
456         return ip_set_register_set_type(&ip_set_nethash);
457 }
458
459 static void __exit fini(void)
460 {
461         /* FIXME: possible race with ip_set_create() */
462         ip_set_unregister_set_type(&ip_set_nethash);
463 }
464
465 module_init(init);
466 module_exit(fini);