vserver 1.9.5.x5
[linux-2.6.git] / net / sunrpc / auth.c
1 /*
2  * linux/net/sunrpc/auth.c
3  *
4  * Generic RPC client authentication API.
5  *
6  * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
7  */
8
9 #include <linux/types.h>
10 #include <linux/sched.h>
11 #include <linux/module.h>
12 #include <linux/slab.h>
13 #include <linux/errno.h>
14 #include <linux/socket.h>
15 #include <linux/sunrpc/clnt.h>
16 #include <linux/spinlock.h>
17 #include <linux/vserver/xid.h>
18
19 #ifdef RPC_DEBUG
20 # define RPCDBG_FACILITY        RPCDBG_AUTH
21 #endif
22
23 static struct rpc_authops *     auth_flavors[RPC_AUTH_MAXFLAVOR] = {
24         &authnull_ops,          /* AUTH_NULL */
25         &authunix_ops,          /* AUTH_UNIX */
26         NULL,                   /* others can be loadable modules */
27 };
28
29 static u32
30 pseudoflavor_to_flavor(u32 flavor) {
31         if (flavor >= RPC_AUTH_MAXFLAVOR)
32                 return RPC_AUTH_GSS;
33         return flavor;
34 }
35
36 int
37 rpcauth_register(struct rpc_authops *ops)
38 {
39         rpc_authflavor_t flavor;
40
41         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
42                 return -EINVAL;
43         if (auth_flavors[flavor] != NULL)
44                 return -EPERM;          /* what else? */
45         auth_flavors[flavor] = ops;
46         return 0;
47 }
48
49 int
50 rpcauth_unregister(struct rpc_authops *ops)
51 {
52         rpc_authflavor_t flavor;
53
54         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
55                 return -EINVAL;
56         if (auth_flavors[flavor] != ops)
57                 return -EPERM;          /* what else? */
58         auth_flavors[flavor] = NULL;
59         return 0;
60 }
61
62 struct rpc_auth *
63 rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt)
64 {
65         struct rpc_auth         *auth;
66         struct rpc_authops      *ops;
67         u32                     flavor = pseudoflavor_to_flavor(pseudoflavor);
68
69         if (flavor >= RPC_AUTH_MAXFLAVOR || !(ops = auth_flavors[flavor]))
70                 return NULL;
71         if (!try_module_get(ops->owner))
72                 return NULL;
73         auth = ops->create(clnt, pseudoflavor);
74         if (!auth)
75                 return NULL;
76         atomic_set(&auth->au_count, 1);
77         if (clnt->cl_auth)
78                 rpcauth_destroy(clnt->cl_auth);
79         clnt->cl_auth = auth;
80         return auth;
81 }
82
83 void
84 rpcauth_destroy(struct rpc_auth *auth)
85 {
86         if (!atomic_dec_and_test(&auth->au_count))
87                 return;
88         auth->au_ops->destroy(auth);
89         module_put(auth->au_ops->owner);
90         kfree(auth);
91 }
92
93 static DEFINE_SPINLOCK(rpc_credcache_lock);
94
95 /*
96  * Initialize RPC credential cache
97  */
98 void
99 rpcauth_init_credcache(struct rpc_auth *auth)
100 {
101         int i;
102         for (i = 0; i < RPC_CREDCACHE_NR; i++)
103                 INIT_LIST_HEAD(&auth->au_credcache[i]);
104         auth->au_nextgc = jiffies + (auth->au_expire >> 1);
105 }
106
107 /*
108  * Destroy an unreferenced credential
109  */
110 static inline void
111 rpcauth_crdestroy(struct rpc_cred *cred)
112 {
113 #ifdef RPC_DEBUG
114         BUG_ON(cred->cr_magic != RPCAUTH_CRED_MAGIC ||
115                         atomic_read(&cred->cr_count) ||
116                         !list_empty(&cred->cr_hash));
117         cred->cr_magic = 0;
118 #endif
119         cred->cr_ops->crdestroy(cred);
120 }
121
122 /*
123  * Destroy a list of credentials
124  */
125 static inline
126 void rpcauth_destroy_credlist(struct list_head *head)
127 {
128         struct rpc_cred *cred;
129
130         while (!list_empty(head)) {
131                 cred = list_entry(head->next, struct rpc_cred, cr_hash);
132                 list_del_init(&cred->cr_hash);
133                 rpcauth_crdestroy(cred);
134         }
135 }
136
137 /*
138  * Clear the RPC credential cache, and delete those credentials
139  * that are not referenced.
140  */
141 void
142 rpcauth_free_credcache(struct rpc_auth *auth)
143 {
144         LIST_HEAD(free);
145         struct list_head *pos, *next;
146         struct rpc_cred *cred;
147         int             i;
148
149         spin_lock(&rpc_credcache_lock);
150         for (i = 0; i < RPC_CREDCACHE_NR; i++) {
151                 list_for_each_safe(pos, next, &auth->au_credcache[i]) {
152                         cred = list_entry(pos, struct rpc_cred, cr_hash);
153                         cred->cr_auth = NULL;
154                         list_del_init(&cred->cr_hash);
155                         if (atomic_read(&cred->cr_count) == 0)
156                                 list_add(&cred->cr_hash, &free);
157                 }
158         }
159         spin_unlock(&rpc_credcache_lock);
160         rpcauth_destroy_credlist(&free);
161 }
162
163 static inline int
164 rpcauth_prune_expired(struct rpc_cred *cred, struct list_head *free)
165 {
166         if (atomic_read(&cred->cr_count) != 0)
167                return 0;
168         if (time_before(jiffies, cred->cr_expire))
169                 return 0;
170         cred->cr_auth = NULL;
171         list_del(&cred->cr_hash);
172         list_add(&cred->cr_hash, free);
173         return 1;
174 }
175
176 /*
177  * Remove stale credentials. Avoid sleeping inside the loop.
178  */
179 static void
180 rpcauth_gc_credcache(struct rpc_auth *auth, struct list_head *free)
181 {
182         struct list_head *pos, *next;
183         struct rpc_cred *cred;
184         int             i;
185
186         dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth);
187         for (i = 0; i < RPC_CREDCACHE_NR; i++) {
188                 list_for_each_safe(pos, next, &auth->au_credcache[i]) {
189                         cred = list_entry(pos, struct rpc_cred, cr_hash);
190                         rpcauth_prune_expired(cred, free);
191                 }
192         }
193         auth->au_nextgc = jiffies + auth->au_expire;
194 }
195
196 /*
197  * Look up a process' credentials in the authentication cache
198  */
199 struct rpc_cred *
200 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
201                 int taskflags)
202 {
203         LIST_HEAD(free);
204         struct list_head *pos, *next;
205         struct rpc_cred *new = NULL,
206                         *cred = NULL;
207         int             nr = 0;
208
209         if (!(taskflags & RPC_TASK_ROOTCREDS))
210                 nr = acred->uid & RPC_CREDCACHE_MASK;
211 retry:
212         spin_lock(&rpc_credcache_lock);
213         if (time_before(auth->au_nextgc, jiffies))
214                 rpcauth_gc_credcache(auth, &free);
215         list_for_each_safe(pos, next, &auth->au_credcache[nr]) {
216                 struct rpc_cred *entry;
217                 entry = list_entry(pos, struct rpc_cred, cr_hash);
218                 if (rpcauth_prune_expired(entry, &free))
219                         continue;
220                 if (entry->cr_ops->crmatch(acred, entry, taskflags)) {
221                         list_del(&entry->cr_hash);
222                         cred = entry;
223                         break;
224                 }
225         }
226         if (new) {
227                 if (cred)
228                         list_add(&new->cr_hash, &free);
229                 else
230                         cred = new;
231         }
232         if (cred) {
233                 list_add(&cred->cr_hash, &auth->au_credcache[nr]);
234                 cred->cr_auth = auth;
235                 get_rpccred(cred);
236         }
237         spin_unlock(&rpc_credcache_lock);
238
239         rpcauth_destroy_credlist(&free);
240
241         if (!cred) {
242                 new = auth->au_ops->crcreate(auth, acred, taskflags);
243                 if (new) {
244 #ifdef RPC_DEBUG
245                         new->cr_magic = RPCAUTH_CRED_MAGIC;
246 #endif
247                         goto retry;
248                 }
249         }
250
251         return (struct rpc_cred *) cred;
252 }
253
254 struct rpc_cred *
255 rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
256 {
257         struct auth_cred acred;
258         struct rpc_cred *ret;
259
260         get_group_info(current->group_info);
261         acred.uid = current->fsuid;
262         acred.gid = current->fsgid;
263         acred.xid = vx_current_xid();
264         acred.group_info = current->group_info;
265
266         dprintk("RPC:     looking up %s cred\n",
267                 auth->au_ops->au_name);
268         ret = rpcauth_lookup_credcache(auth, &acred, taskflags);
269         put_group_info(current->group_info);
270         return ret;
271 }
272
273 struct rpc_cred *
274 rpcauth_bindcred(struct rpc_task *task)
275 {
276         struct rpc_auth *auth = task->tk_auth;
277         struct auth_cred acred;
278         struct rpc_cred *ret;
279
280         get_group_info(current->group_info);
281         acred.uid = current->fsuid;
282         acred.gid = current->fsgid;
283         acred.xid = vx_current_xid();
284         acred.group_info = current->group_info;
285
286         dprintk("RPC: %4d looking up %s cred\n",
287                 task->tk_pid, task->tk_auth->au_ops->au_name);
288         task->tk_msg.rpc_cred = rpcauth_lookup_credcache(auth, &acred, task->tk_flags);
289         if (task->tk_msg.rpc_cred == 0)
290                 task->tk_status = -ENOMEM;
291         ret = task->tk_msg.rpc_cred;
292         put_group_info(current->group_info);
293         return ret;
294 }
295
296 void
297 rpcauth_holdcred(struct rpc_task *task)
298 {
299         dprintk("RPC: %4d holding %s cred %p\n",
300                 task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred);
301         if (task->tk_msg.rpc_cred)
302                 get_rpccred(task->tk_msg.rpc_cred);
303 }
304
305 void
306 put_rpccred(struct rpc_cred *cred)
307 {
308         if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
309                 return;
310
311         if (list_empty(&cred->cr_hash)) {
312                 spin_unlock(&rpc_credcache_lock);
313                 rpcauth_crdestroy(cred);
314                 return;
315         }
316         cred->cr_expire = jiffies + cred->cr_auth->au_expire;
317         spin_unlock(&rpc_credcache_lock);
318 }
319
320 void
321 rpcauth_unbindcred(struct rpc_task *task)
322 {
323         struct rpc_auth *auth = task->tk_auth;
324         struct rpc_cred *cred = task->tk_msg.rpc_cred;
325
326         dprintk("RPC: %4d releasing %s cred %p\n",
327                 task->tk_pid, auth->au_ops->au_name, cred);
328
329         put_rpccred(cred);
330         task->tk_msg.rpc_cred = NULL;
331 }
332
333 u32 *
334 rpcauth_marshcred(struct rpc_task *task, u32 *p)
335 {
336         struct rpc_auth *auth = task->tk_auth;
337         struct rpc_cred *cred = task->tk_msg.rpc_cred;
338
339         dprintk("RPC: %4d marshaling %s cred %p\n",
340                 task->tk_pid, auth->au_ops->au_name, cred);
341         return cred->cr_ops->crmarshal(task, p,
342                                 task->tk_flags & RPC_CALL_REALUID);
343 }
344
345 u32 *
346 rpcauth_checkverf(struct rpc_task *task, u32 *p)
347 {
348         struct rpc_auth *auth = task->tk_auth;
349         struct rpc_cred *cred = task->tk_msg.rpc_cred;
350
351         dprintk("RPC: %4d validating %s cred %p\n",
352                 task->tk_pid, auth->au_ops->au_name, cred);
353         return cred->cr_ops->crvalidate(task, p);
354 }
355
356 int
357 rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp,
358                 u32 *data, void *obj)
359 {
360         struct rpc_cred *cred = task->tk_msg.rpc_cred;
361
362         dprintk("RPC: %4d using %s cred %p to wrap rpc data\n",
363                         task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
364         if (cred->cr_ops->crwrap_req)
365                 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
366         /* By default, we encode the arguments normally. */
367         return encode(rqstp, data, obj);
368 }
369
370 int
371 rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp,
372                 u32 *data, void *obj)
373 {
374         struct rpc_cred *cred = task->tk_msg.rpc_cred;
375
376         dprintk("RPC: %4d using %s cred %p to unwrap rpc data\n",
377                         task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
378         if (cred->cr_ops->crunwrap_resp)
379                 return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
380                                                    data, obj);
381         /* By default, we decode the arguments normally. */
382         return decode(rqstp, data, obj);
383 }
384
385 int
386 rpcauth_refreshcred(struct rpc_task *task)
387 {
388         struct rpc_auth *auth = task->tk_auth;
389         struct rpc_cred *cred = task->tk_msg.rpc_cred;
390
391         dprintk("RPC: %4d refreshing %s cred %p\n",
392                 task->tk_pid, auth->au_ops->au_name, cred);
393         task->tk_status = cred->cr_ops->crrefresh(task);
394         return task->tk_status;
395 }
396
397 void
398 rpcauth_invalcred(struct rpc_task *task)
399 {
400         dprintk("RPC: %4d invalidating %s cred %p\n",
401                 task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred);
402         spin_lock(&rpc_credcache_lock);
403         if (task->tk_msg.rpc_cred)
404                 task->tk_msg.rpc_cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
405         spin_unlock(&rpc_credcache_lock);
406 }
407
408 int
409 rpcauth_uptodatecred(struct rpc_task *task)
410 {
411         return !(task->tk_msg.rpc_cred) ||
412                 (task->tk_msg.rpc_cred->cr_flags & RPCAUTH_CRED_UPTODATE);
413 }