This commit was manufactured by cvs2svn to create branch 'vserver'.
[linux-2.6.git] / net / ipv4 / netfilter / ip_set_portmap.c
1 /* Copyright (C) 2003-2004 Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>
2  *
3  * This program is free software; you can redistribute it and/or modify
4  * it under the terms of the GNU General Public License version 2 as
5  * published by the Free Software Foundation.  
6  */
7
8 /* Kernel module implementing a port set type as a bitmap */
9
10 #include <linux/module.h>
11 #include <linux/ip.h>
12 #include <linux/tcp.h>
13 #include <linux/udp.h>
14 #include <linux/skbuff.h>
15 #include <linux/netfilter_ipv4/ip_tables.h>
16 #include <linux/netfilter_ipv4/ip_set.h>
17 #include <linux/errno.h>
18 #include <asm/uaccess.h>
19 #include <asm/bitops.h>
20 #include <linux/spinlock.h>
21
22 #include <net/ip.h>
23
24 #include <linux/netfilter_ipv4/ip_set_portmap.h>
25
26 /* We must handle non-linear skbs */
27 static inline ip_set_ip_t
28 get_port(const struct sk_buff *skb, u_int32_t flags)
29 {
30         struct iphdr *iph = skb->nh.iph;
31         u_int16_t offset = ntohs(iph->frag_off) & IP_OFFSET;
32
33         switch (iph->protocol) {
34         case IPPROTO_TCP: {
35                 struct tcphdr tcph;
36                 
37                 /* See comments at tcp_match in ip_tables.c */
38                 if (offset)
39                         return INVALID_PORT;
40
41                 if (skb_copy_bits(skb, skb->nh.iph->ihl*4, &tcph, sizeof(tcph)) < 0)
42                         /* No choice either */
43                         return INVALID_PORT;
44                 
45                 return ntohs(flags & IPSET_SRC ?
46                              tcph.source : tcph.dest);
47             }
48         case IPPROTO_UDP: {
49                 struct udphdr udph;
50
51                 if (offset)
52                         return INVALID_PORT;
53
54                 if (skb_copy_bits(skb, skb->nh.iph->ihl*4, &udph, sizeof(udph)) < 0)
55                         /* No choice either */
56                         return INVALID_PORT;
57                 
58                 return ntohs(flags & IPSET_SRC ?
59                              udph.source : udph.dest);
60             }
61         default:
62                 return INVALID_PORT;
63         }
64 }
65
66 static inline int
67 __testport(struct ip_set *set, ip_set_ip_t port, ip_set_ip_t *hash_port)
68 {
69         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
70
71         if (port < map->first_port || port > map->last_port)
72                 return -ERANGE;
73                 
74         *hash_port = port;
75         DP("set: %s, port:%u, %u", set->name, port, *hash_port);
76         return !!test_bit(port - map->first_port, map->members);
77 }
78
79 static int
80 testport(struct ip_set *set, const void *data, size_t size,
81          ip_set_ip_t *hash_port)
82 {
83         struct ip_set_req_portmap *req = 
84             (struct ip_set_req_portmap *) data;
85
86         if (size != sizeof(struct ip_set_req_portmap)) {
87                 ip_set_printk("data length wrong (want %zu, have %zu)",
88                               sizeof(struct ip_set_req_portmap),
89                               size);
90                 return -EINVAL;
91         }
92         return __testport(set, req->port, hash_port);
93 }
94
95 static int
96 testport_kernel(struct ip_set *set, 
97                 const struct sk_buff *skb,
98                 ip_set_ip_t *hash_port,
99                 const u_int32_t *flags,
100                 unsigned char index)
101 {
102         int res;
103         ip_set_ip_t port = get_port(skb, flags[index]);
104
105         DP("flag %s port %u", flags[index] & IPSET_SRC ? "SRC" : "DST", port);  
106         if (port == INVALID_PORT)
107                 return 0;       
108
109         res =  __testport(set, port, hash_port);
110         
111         return (res < 0 ? 0 : res);
112 }
113
114 static inline int
115 __addport(struct ip_set *set, ip_set_ip_t port, ip_set_ip_t *hash_port)
116 {
117         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
118
119         if (port < map->first_port || port > map->last_port)
120                 return -ERANGE;
121         if (test_and_set_bit(port - map->first_port, map->members))
122                 return -EEXIST;
123                 
124         *hash_port = port;
125         DP("port %u", port);
126         return 0;
127 }
128
129 static int
130 addport(struct ip_set *set, const void *data, size_t size,
131         ip_set_ip_t *hash_port)
132 {
133         struct ip_set_req_portmap *req = 
134             (struct ip_set_req_portmap *) data;
135
136         if (size != sizeof(struct ip_set_req_portmap)) {
137                 ip_set_printk("data length wrong (want %zu, have %zu)",
138                               sizeof(struct ip_set_req_portmap),
139                               size);
140                 return -EINVAL;
141         }
142         return __addport(set, req->port, hash_port);
143 }
144
145 static int
146 addport_kernel(struct ip_set *set, 
147                const struct sk_buff *skb,
148                ip_set_ip_t *hash_port,
149                const u_int32_t *flags,
150                unsigned char index)
151 {
152         ip_set_ip_t port = get_port(skb, flags[index]);
153         
154         if (port == INVALID_PORT)
155                 return -EINVAL;
156
157         return __addport(set, port, hash_port);
158 }
159
160 static inline int
161 __delport(struct ip_set *set, ip_set_ip_t port, ip_set_ip_t *hash_port)
162 {
163         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
164
165         if (port < map->first_port || port > map->last_port)
166                 return -ERANGE;
167         if (!test_and_clear_bit(port - map->first_port, map->members))
168                 return -EEXIST;
169                 
170         *hash_port = port;
171         DP("port %u", port);
172         return 0;
173 }
174
175 static int
176 delport(struct ip_set *set, const void *data, size_t size,
177         ip_set_ip_t *hash_port)
178 {
179         struct ip_set_req_portmap *req =
180             (struct ip_set_req_portmap *) data;
181
182         if (size != sizeof(struct ip_set_req_portmap)) {
183                 ip_set_printk("data length wrong (want %zu, have %zu)",
184                               sizeof(struct ip_set_req_portmap),
185                               size);
186                 return -EINVAL;
187         }
188         return __delport(set, req->port, hash_port);
189 }
190
191 static int
192 delport_kernel(struct ip_set *set, 
193                const struct sk_buff *skb,
194                ip_set_ip_t *hash_port,
195                const u_int32_t *flags,
196                unsigned char index)
197 {
198         ip_set_ip_t port = get_port(skb, flags[index]);
199         
200         if (port == INVALID_PORT)
201                 return -EINVAL;
202
203         return __delport(set, port, hash_port);
204 }
205
206 static int create(struct ip_set *set, const void *data, size_t size)
207 {
208         int newbytes;
209         struct ip_set_req_portmap_create *req =
210             (struct ip_set_req_portmap_create *) data;
211         struct ip_set_portmap *map;
212
213         if (size != sizeof(struct ip_set_req_portmap_create)) {
214                 ip_set_printk("data length wrong (want %zu, have %zu)",
215                                sizeof(struct ip_set_req_portmap_create),
216                                size);
217                 return -EINVAL;
218         }
219
220         DP("from %u to %u", req->from, req->to);
221
222         if (req->from > req->to) {
223                 DP("bad port range");
224                 return -ENOEXEC;
225         }
226
227         if (req->to - req->from > MAX_RANGE) {
228                 ip_set_printk("range too big (max %d ports)",
229                                MAX_RANGE+1);
230                 return -ENOEXEC;
231         }
232
233         map = kmalloc(sizeof(struct ip_set_portmap), GFP_KERNEL);
234         if (!map) {
235                 DP("out of memory for %d bytes",
236                    sizeof(struct ip_set_portmap));
237                 return -ENOMEM;
238         }
239         map->first_port = req->from;
240         map->last_port = req->to;
241         newbytes = bitmap_bytes(req->from, req->to);
242         map->members = kmalloc(newbytes, GFP_KERNEL);
243         if (!map->members) {
244                 DP("out of memory for %d bytes", newbytes);
245                 kfree(map);
246                 return -ENOMEM;
247         }
248         memset(map->members, 0, newbytes);
249
250         set->data = map;
251         return 0;
252 }
253
254 static void destroy(struct ip_set *set)
255 {
256         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
257
258         kfree(map->members);
259         kfree(map);
260
261         set->data = NULL;
262 }
263
264 static void flush(struct ip_set *set)
265 {
266         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
267         memset(map->members, 0, bitmap_bytes(map->first_port, map->last_port));
268 }
269
270 static void list_header(const struct ip_set *set, void *data)
271 {
272         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
273         struct ip_set_req_portmap_create *header =
274             (struct ip_set_req_portmap_create *) data;
275
276         DP("list_header %u %u", map->first_port, map->last_port);
277
278         header->from = map->first_port;
279         header->to = map->last_port;
280 }
281
282 static int list_members_size(const struct ip_set *set)
283 {
284         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
285
286         return bitmap_bytes(map->first_port, map->last_port);
287 }
288
289 static void list_members(const struct ip_set *set, void *data)
290 {
291         struct ip_set_portmap *map = (struct ip_set_portmap *) set->data;
292         int bytes = bitmap_bytes(map->first_port, map->last_port);
293
294         memcpy(data, map->members, bytes);
295 }
296
297 static struct ip_set_type ip_set_portmap = {
298         .typename               = SETTYPE_NAME,
299         .features               = IPSET_TYPE_PORT | IPSET_DATA_SINGLE,
300         .protocol_version       = IP_SET_PROTOCOL_VERSION,
301         .create                 = &create,
302         .destroy                = &destroy,
303         .flush                  = &flush,
304         .reqsize                = sizeof(struct ip_set_req_portmap),
305         .addip                  = &addport,
306         .addip_kernel           = &addport_kernel,
307         .delip                  = &delport,
308         .delip_kernel           = &delport_kernel,
309         .testip                 = &testport,
310         .testip_kernel          = &testport_kernel,
311         .header_size            = sizeof(struct ip_set_req_portmap_create),
312         .list_header            = &list_header,
313         .list_members_size      = &list_members_size,
314         .list_members           = &list_members,
315         .me                     = THIS_MODULE,
316 };
317
318 MODULE_LICENSE("GPL");
319 MODULE_AUTHOR("Jozsef Kadlecsik <kadlec@blackhole.kfki.hu>");
320 MODULE_DESCRIPTION("portmap type of IP sets");
321
322 static int __init init(void)
323 {
324         return ip_set_register_set_type(&ip_set_portmap);
325 }
326
327 static void __exit fini(void)
328 {
329         /* FIXME: possible race with ip_set_create() */
330         ip_set_unregister_set_type(&ip_set_portmap);
331 }
332
333 module_init(init);
334 module_exit(fini);