Merge to Fedora kernel-2.6.18-1.2260_FC5 patched with stable patch-2.6.18.5-vs2.0...
[linux-2.6.git] / net / ipv4 / netfilter / ip_set_nethash.c
1 /* Copyright (C) 2003-2004 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.  
6  */
7
8 /* Kernel module implementing a cidr nethash set */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/skbuff.h>
13 #include <linux/netfilter_ipv4/ip_tables.h>
14 #include <linux/netfilter_ipv4/ip_set.h>
15 #include <linux/errno.h>
16 #include <asm/uaccess.h>
17 #include <asm/bitops.h>
18 #include <linux/spinlock.h>
19 #include <linux/vmalloc.h>
20 #include <linux/random.h>
21
22 #include <net/ip.h>
23
24 #include <linux/netfilter_ipv4/ip_set_malloc.h>
25 #include <linux/netfilter_ipv4/ip_set_nethash.h>
26 #include <linux/netfilter_ipv4/ip_set_jhash.h>
27
28 static int limit = MAX_RANGE;
29
30 static inline __u32
31 jhash_ip(const struct ip_set_nethash *map, uint16_t i, ip_set_ip_t ip)
32 {
33         return jhash_1word(ip, *(((uint32_t *) map->initval) + i));
34 }
35
36 static inline __u32
37 hash_id_cidr(struct ip_set_nethash *map,
38              ip_set_ip_t ip,
39              unsigned char cidr,
40              ip_set_ip_t *hash_ip)
41 {
42         __u32 id;
43         u_int16_t i;
44         ip_set_ip_t *elem;
45
46         *hash_ip = pack(ip, cidr);
47         
48         for (i = 0; i < map->probes; i++) {
49                 id = jhash_ip(map, i, *hash_ip) % map->hashsize;
50                 DP("hash key: %u", id);
51                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
52                 if (*elem == *hash_ip)
53                         return id;
54         }
55         return UINT_MAX;
56 }
57
58 static inline __u32
59 hash_id(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t *hash_ip)
60 {
61         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
62         __u32 id = UINT_MAX;
63         int i;
64
65         for (i = 0; i < 30 && map->cidr[i]; i++) {
66                 id = hash_id_cidr(map, ip, map->cidr[i], hash_ip);
67                 if (id != UINT_MAX)
68                         break;
69         }
70         return id;
71 }
72
73 static inline int
74 __testip_cidr(struct ip_set *set, ip_set_ip_t ip, unsigned char cidr,
75               ip_set_ip_t *hash_ip)
76 {
77         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
78
79         return (ip && hash_id_cidr(map, ip, cidr, hash_ip) != UINT_MAX);
80 }
81
82 static inline int
83 __testip(struct ip_set *set, ip_set_ip_t ip, ip_set_ip_t *hash_ip)
84 {
85         return (ip && hash_id(set, ip, hash_ip) != UINT_MAX);
86 }
87
88 static int
89 testip(struct ip_set *set, const void *data, size_t size,
90        ip_set_ip_t *hash_ip)
91 {
92         struct ip_set_req_nethash *req = 
93             (struct ip_set_req_nethash *) data;
94
95         if (size != sizeof(struct ip_set_req_nethash)) {
96                 ip_set_printk("data length wrong (want %zu, have %zu)",
97                               sizeof(struct ip_set_req_nethash),
98                               size);
99                 return -EINVAL;
100         }
101         return (req->cidr == 32 ? __testip(set, req->ip, hash_ip)
102                 : __testip_cidr(set, req->ip, req->cidr, hash_ip));
103 }
104
105 static int
106 testip_kernel(struct ip_set *set, 
107               const struct sk_buff *skb,
108               ip_set_ip_t *hash_ip,
109               const u_int32_t *flags,
110               unsigned char index)
111 {
112         return __testip(set,
113                         ntohl(flags[index] & IPSET_SRC 
114                                 ? skb->nh.iph->saddr 
115                                 : skb->nh.iph->daddr),
116                         hash_ip);
117 }
118
119 static inline int
120 __addip_base(struct ip_set_nethash *map, ip_set_ip_t ip)
121 {
122         __u32 probe;
123         u_int16_t i;
124         ip_set_ip_t *elem;
125         
126         for (i = 0; i < map->probes; i++) {
127                 probe = jhash_ip(map, i, ip) % map->hashsize;
128                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, probe);
129                 if (*elem == ip)
130                         return -EEXIST;
131                 if (!*elem) {
132                         *elem = ip;
133                         map->elements++;
134                         return 0;
135                 }
136         }
137         /* Trigger rehashing */
138         return -EAGAIN;
139 }
140
141 static inline int
142 __addip(struct ip_set_nethash *map, ip_set_ip_t ip, unsigned char cidr,
143         ip_set_ip_t *hash_ip)
144 {
145         if (!ip || map->elements > limit)
146                 return -ERANGE;
147         
148         *hash_ip = pack(ip, cidr);
149         DP("%u.%u.%u.%u/%u, %u.%u.%u.%u", HIPQUAD(ip), cidr, HIPQUAD(*hash_ip));
150         
151         return __addip_base(map, *hash_ip);
152 }
153
154 static void
155 update_cidr_sizes(struct ip_set_nethash *map, unsigned char cidr)
156 {
157         unsigned char next;
158         int i;
159         
160         for (i = 0; i < 30 && map->cidr[i]; i++) {
161                 if (map->cidr[i] == cidr) {
162                         return;
163                 } else if (map->cidr[i] < cidr) {
164                         next = map->cidr[i];
165                         map->cidr[i] = cidr;
166                         cidr = next;
167                 }
168         }
169         if (i < 30)
170                 map->cidr[i] = cidr;
171 }
172
173 static int
174 addip(struct ip_set *set, const void *data, size_t size,
175         ip_set_ip_t *hash_ip)
176 {
177         struct ip_set_req_nethash *req = 
178             (struct ip_set_req_nethash *) data;
179         int ret;
180
181         if (size != sizeof(struct ip_set_req_nethash)) {
182                 ip_set_printk("data length wrong (want %zu, have %zu)",
183                               sizeof(struct ip_set_req_nethash),
184                               size);
185                 return -EINVAL;
186         }
187         ret = __addip((struct ip_set_nethash *) set->data, 
188                       req->ip, req->cidr, hash_ip);
189         
190         if (ret == 0)
191                 update_cidr_sizes((struct ip_set_nethash *) set->data,
192                                   req->cidr);
193         
194         return ret;
195 }
196
197 static int
198 addip_kernel(struct ip_set *set, 
199              const struct sk_buff *skb,
200              ip_set_ip_t *hash_ip,
201              const u_int32_t *flags,
202              unsigned char index)
203 {
204         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
205         int ret = -ERANGE;
206         ip_set_ip_t ip = ntohl(flags[index] & IPSET_SRC 
207                                         ? skb->nh.iph->saddr
208                                         : skb->nh.iph->daddr);
209         
210         if (map->cidr[0])
211                 ret = __addip(map, ip, map->cidr[0], hash_ip);
212                 
213         return ret;
214 }
215
216 static int retry(struct ip_set *set)
217 {
218         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
219         ip_set_ip_t *elem;
220         void *members;
221         u_int32_t i, hashsize = map->hashsize;
222         int res;
223         struct ip_set_nethash *tmp;
224         
225         if (map->resize == 0)
226                 return -ERANGE;
227
228     again:
229         res = 0;
230         
231         /* Calculate new parameters */
232         hashsize += (hashsize * map->resize)/100;
233         if (hashsize == map->hashsize)
234                 hashsize++;
235         
236         ip_set_printk("rehashing of set %s triggered: "
237                       "hashsize grows from %u to %u",
238                       set->name, map->hashsize, hashsize);
239
240         tmp = kmalloc(sizeof(struct ip_set_nethash) 
241                       + map->probes * sizeof(uint32_t), GFP_ATOMIC);
242         if (!tmp) {
243                 DP("out of memory for %d bytes",
244                    sizeof(struct ip_set_nethash)
245                    + map->probes * sizeof(uint32_t));
246                 return -ENOMEM;
247         }
248         tmp->members = harray_malloc(hashsize, sizeof(ip_set_ip_t), GFP_ATOMIC);
249         if (!tmp->members) {
250                 DP("out of memory for %d bytes", hashsize * sizeof(ip_set_ip_t));
251                 kfree(tmp);
252                 return -ENOMEM;
253         }
254         tmp->hashsize = hashsize;
255         tmp->elements = 0;
256         tmp->probes = map->probes;
257         tmp->resize = map->resize;
258         memcpy(tmp->initval, map->initval, map->probes * sizeof(uint32_t));
259         memcpy(tmp->cidr, map->cidr, 30 * sizeof(unsigned char));
260         
261         write_lock_bh(&set->lock);
262         map = (struct ip_set_nethash *) set->data; /* Play safe */
263         for (i = 0; i < map->hashsize && res == 0; i++) {
264                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
265                 if (*elem)
266                         res = __addip_base(tmp, *elem);
267         }
268         if (res) {
269                 /* Failure, try again */
270                 write_unlock_bh(&set->lock);
271                 harray_free(tmp->members);
272                 kfree(tmp);
273                 goto again;
274         }
275         
276         /* Success at resizing! */
277         members = map->members;
278         
279         map->hashsize = tmp->hashsize;
280         map->members = tmp->members;
281         write_unlock_bh(&set->lock);
282
283         harray_free(members);
284         kfree(tmp);
285
286         return 0;
287 }
288
289 static inline int
290 __delip(struct ip_set_nethash *map, ip_set_ip_t ip, unsigned char cidr,
291         ip_set_ip_t *hash_ip)
292 {
293         ip_set_ip_t id, *elem;
294
295         if (!ip)
296                 return -ERANGE;
297         
298         id = hash_id_cidr(map, ip, cidr, hash_ip);
299         if (id == UINT_MAX)
300                 return -EEXIST;
301                 
302         elem = HARRAY_ELEM(map->members, ip_set_ip_t *, id);
303         *elem = 0;
304         map->elements--;
305         return 0;
306 }
307
308 static int
309 delip(struct ip_set *set, const void *data, size_t size,
310         ip_set_ip_t *hash_ip)
311 {
312         struct ip_set_req_nethash *req =
313             (struct ip_set_req_nethash *) data;
314
315         if (size != sizeof(struct ip_set_req_nethash)) {
316                 ip_set_printk("data length wrong (want %zu, have %zu)",
317                               sizeof(struct ip_set_req_nethash),
318                               size);
319                 return -EINVAL;
320         }
321         /* TODO: no garbage collection in map->cidr */          
322         return __delip((struct ip_set_nethash *) set->data, 
323                        req->ip, req->cidr, hash_ip);
324 }
325
326 static int
327 delip_kernel(struct ip_set *set, 
328              const struct sk_buff *skb,
329              ip_set_ip_t *hash_ip,
330              const u_int32_t *flags,
331              unsigned char index)
332 {
333         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
334         int ret = -ERANGE;
335         ip_set_ip_t ip = ntohl(flags[index] & IPSET_SRC 
336                                         ? skb->nh.iph->saddr
337                                         : skb->nh.iph->daddr);
338         
339         if (map->cidr[0])
340                 ret = __delip(map, ip, map->cidr[0], hash_ip);
341         
342         return ret;
343 }
344
345 static int create(struct ip_set *set, const void *data, size_t size)
346 {
347         struct ip_set_req_nethash_create *req =
348             (struct ip_set_req_nethash_create *) data;
349         struct ip_set_nethash *map;
350         uint16_t i;
351
352         if (size != sizeof(struct ip_set_req_nethash_create)) {
353                 ip_set_printk("data length wrong (want %zu, have %zu)",
354                                sizeof(struct ip_set_req_nethash_create),
355                                size);
356                 return -EINVAL;
357         }
358
359         if (req->hashsize < 1) {
360                 ip_set_printk("hashsize too small");
361                 return -ENOEXEC;
362         }
363         if (req->probes < 1) {
364                 ip_set_printk("probes too small");
365                 return -ENOEXEC;
366         }
367
368         map = kmalloc(sizeof(struct ip_set_nethash)
369                       + req->probes * sizeof(uint32_t), GFP_KERNEL);
370         if (!map) {
371                 DP("out of memory for %d bytes",
372                    sizeof(struct ip_set_nethash)
373                    + req->probes * sizeof(uint32_t));
374                 return -ENOMEM;
375         }
376         for (i = 0; i < req->probes; i++)
377                 get_random_bytes(((uint32_t *) map->initval)+i, 4);
378         map->elements = 0;
379         map->hashsize = req->hashsize;
380         map->probes = req->probes;
381         map->resize = req->resize;
382         memset(map->cidr, 0, 30 * sizeof(unsigned char));
383         map->members = harray_malloc(map->hashsize, sizeof(ip_set_ip_t), GFP_KERNEL);
384         if (!map->members) {
385                 DP("out of memory for %d bytes", map->hashsize * sizeof(ip_set_ip_t));
386                 kfree(map);
387                 return -ENOMEM;
388         }
389         
390         set->data = map;
391         return 0;
392 }
393
394 static void destroy(struct ip_set *set)
395 {
396         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
397
398         harray_free(map->members);
399         kfree(map);
400
401         set->data = NULL;
402 }
403
404 static void flush(struct ip_set *set)
405 {
406         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
407         harray_flush(map->members, map->hashsize, sizeof(ip_set_ip_t));
408         memset(map->cidr, 0, 30 * sizeof(unsigned char));
409         map->elements = 0;
410 }
411
412 static void list_header(const struct ip_set *set, void *data)
413 {
414         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
415         struct ip_set_req_nethash_create *header =
416             (struct ip_set_req_nethash_create *) data;
417
418         header->hashsize = map->hashsize;
419         header->probes = map->probes;
420         header->resize = map->resize;
421 }
422
423 static int list_members_size(const struct ip_set *set)
424 {
425         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
426
427         return (map->hashsize * sizeof(ip_set_ip_t));
428 }
429
430 static void list_members(const struct ip_set *set, void *data)
431 {
432         struct ip_set_nethash *map = (struct ip_set_nethash *) set->data;
433         ip_set_ip_t i, *elem;
434
435         for (i = 0; i < map->hashsize; i++) {
436                 elem = HARRAY_ELEM(map->members, ip_set_ip_t *, i);     
437                 ((ip_set_ip_t *)data)[i] = *elem;
438         }
439 }
440
441 static struct ip_set_type ip_set_nethash = {
442         .typename               = SETTYPE_NAME,
443         .features               = IPSET_TYPE_IP | IPSET_DATA_SINGLE,
444         .protocol_version       = IP_SET_PROTOCOL_VERSION,
445         .create                 = &create,
446         .destroy                = &destroy,
447         .flush                  = &flush,
448         .reqsize                = sizeof(struct ip_set_req_nethash),
449         .addip                  = &addip,
450         .addip_kernel           = &addip_kernel,
451         .retry                  = &retry,
452         .delip                  = &delip,
453         .delip_kernel           = &delip_kernel,
454         .testip                 = &testip,
455         .testip_kernel          = &testip_kernel,
456         .header_size            = sizeof(struct ip_set_req_nethash_create),
457         .list_header            = &list_header,
458         .list_members_size      = &list_members_size,
459         .list_members           = &list_members,
460         .me                     = THIS_MODULE,
461 };
462
463 MODULE_LICENSE("GPL");
464 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
465 MODULE_DESCRIPTION("nethash type of IP sets");
466 module_param(limit, int, 0600);
467 MODULE_PARM_DESC(limit, "maximal number of elements stored in the sets");
468
469 static int __init init(void)
470 {
471         return ip_set_register_set_type(&ip_set_nethash);
472 }
473
474 static void __exit fini(void)
475 {
476         /* FIXME: possible race with ip_set_create() */
477         ip_set_unregister_set_type(&ip_set_nethash);
478 }
479
480 module_init(init);
481 module_exit(fini);