patch-2_6_7-vs1_9_1_12
[linux-2.6.git] / net / sunrpc / auth.c
1 /*
2  * linux/net/sunrpc/auth.c
3  *
4  * Generic RPC client authentication API.
5  *
6  * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
7  */
8
9 #include <linux/types.h>
10 #include <linux/sched.h>
11 #include <linux/module.h>
12 #include <linux/slab.h>
13 #include <linux/errno.h>
14 #include <linux/socket.h>
15 #include <linux/sunrpc/clnt.h>
16 #include <linux/spinlock.h>
17 #include <linux/vserver/xid.h>
18
19 #ifdef RPC_DEBUG
20 # define RPCDBG_FACILITY        RPCDBG_AUTH
21 #endif
22
23 static struct rpc_authops *     auth_flavors[RPC_AUTH_MAXFLAVOR] = {
24         &authnull_ops,          /* AUTH_NULL */
25         &authunix_ops,          /* AUTH_UNIX */
26         NULL,                   /* others can be loadable modules */
27 };
28
29 u32
30 pseudoflavor_to_flavor(u32 flavor) {
31         if (flavor >= RPC_AUTH_MAXFLAVOR)
32                 return RPC_AUTH_GSS;
33         return flavor;
34 }
35
36 int
37 rpcauth_register(struct rpc_authops *ops)
38 {
39         rpc_authflavor_t flavor;
40
41         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
42                 return -EINVAL;
43         if (auth_flavors[flavor] != NULL)
44                 return -EPERM;          /* what else? */
45         auth_flavors[flavor] = ops;
46         return 0;
47 }
48
49 int
50 rpcauth_unregister(struct rpc_authops *ops)
51 {
52         rpc_authflavor_t flavor;
53
54         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
55                 return -EINVAL;
56         if (auth_flavors[flavor] != ops)
57                 return -EPERM;          /* what else? */
58         auth_flavors[flavor] = NULL;
59         return 0;
60 }
61
62 struct rpc_auth *
63 rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt)
64 {
65         struct rpc_auth         *auth;
66         struct rpc_authops      *ops;
67         u32                     flavor = pseudoflavor_to_flavor(pseudoflavor);
68
69         if (flavor >= RPC_AUTH_MAXFLAVOR || !(ops = auth_flavors[flavor]))
70                 return NULL;
71         if (!try_module_get(ops->owner))
72                 return NULL;
73         auth = ops->create(clnt, pseudoflavor);
74         if (!auth)
75                 return NULL;
76         atomic_set(&auth->au_count, 1);
77         if (clnt->cl_auth)
78                 rpcauth_destroy(clnt->cl_auth);
79         clnt->cl_auth = auth;
80         return auth;
81 }
82
83 void
84 rpcauth_destroy(struct rpc_auth *auth)
85 {
86         if (!atomic_dec_and_test(&auth->au_count))
87                 return;
88         auth->au_ops->destroy(auth);
89         module_put(auth->au_ops->owner);
90         kfree(auth);
91 }
92
93 static spinlock_t rpc_credcache_lock = SPIN_LOCK_UNLOCKED;
94
95 /*
96  * Initialize RPC credential cache
97  */
98 void
99 rpcauth_init_credcache(struct rpc_auth *auth)
100 {
101         int i;
102         for (i = 0; i < RPC_CREDCACHE_NR; i++)
103                 INIT_LIST_HEAD(&auth->au_credcache[i]);
104         auth->au_nextgc = jiffies + (auth->au_expire >> 1);
105 }
106
107 /*
108  * Destroy an unreferenced credential
109  */
110 static inline void
111 rpcauth_crdestroy(struct rpc_cred *cred)
112 {
113 #ifdef RPC_DEBUG
114         BUG_ON(cred->cr_magic != RPCAUTH_CRED_MAGIC ||
115                         atomic_read(&cred->cr_count) ||
116                         !list_empty(&cred->cr_hash));
117         cred->cr_magic = 0;
118 #endif
119         cred->cr_ops->crdestroy(cred);
120 }
121
122 /*
123  * Destroy a list of credentials
124  */
125 static inline
126 void rpcauth_destroy_credlist(struct list_head *head)
127 {
128         struct rpc_cred *cred;
129
130         while (!list_empty(head)) {
131                 cred = list_entry(head->next, struct rpc_cred, cr_hash);
132                 list_del_init(&cred->cr_hash);
133                 rpcauth_crdestroy(cred);
134         }
135 }
136
137 /*
138  * Clear the RPC credential cache, and delete those credentials
139  * that are not referenced.
140  */
141 void
142 rpcauth_free_credcache(struct rpc_auth *auth)
143 {
144         LIST_HEAD(free);
145         struct list_head *pos, *next;
146         struct rpc_cred *cred;
147         int             i;
148
149         spin_lock(&rpc_credcache_lock);
150         for (i = 0; i < RPC_CREDCACHE_NR; i++) {
151                 list_for_each_safe(pos, next, &auth->au_credcache[i]) {
152                         cred = list_entry(pos, struct rpc_cred, cr_hash);
153                         cred->cr_auth = NULL;
154                         list_del_init(&cred->cr_hash);
155                         if (atomic_read(&cred->cr_count) == 0)
156                                 list_add(&cred->cr_hash, &free);
157                 }
158         }
159         spin_unlock(&rpc_credcache_lock);
160         rpcauth_destroy_credlist(&free);
161 }
162
163 static inline int
164 rpcauth_prune_expired(struct rpc_cred *cred, struct list_head *free)
165 {
166         if (atomic_read(&cred->cr_count) != 0)
167                return 0;
168         if (time_before(jiffies, cred->cr_expire))
169                 return 0;
170         cred->cr_auth = NULL;
171         list_del(&cred->cr_hash);
172         list_add(&cred->cr_hash, free);
173         return 1;
174 }
175
176 /*
177  * Remove stale credentials. Avoid sleeping inside the loop.
178  */
179 static void
180 rpcauth_gc_credcache(struct rpc_auth *auth, struct list_head *free)
181 {
182         struct list_head *pos, *next;
183         struct rpc_cred *cred;
184         int             i;
185
186         dprintk("RPC: gc'ing RPC credentials for auth %p\n", auth);
187         for (i = 0; i < RPC_CREDCACHE_NR; i++) {
188                 list_for_each_safe(pos, next, &auth->au_credcache[i]) {
189                         cred = list_entry(pos, struct rpc_cred, cr_hash);
190                         rpcauth_prune_expired(cred, free);
191                 }
192         }
193         auth->au_nextgc = jiffies + auth->au_expire;
194 }
195
196 /*
197  * Look up a process' credentials in the authentication cache
198  */
199 struct rpc_cred *
200 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
201                 int taskflags)
202 {
203         LIST_HEAD(free);
204         struct list_head *pos, *next;
205         struct rpc_cred *new = NULL,
206                         *cred = NULL;
207         int             nr = 0;
208
209         if (!(taskflags & RPC_TASK_ROOTCREDS))
210                 nr = acred->uid & RPC_CREDCACHE_MASK;
211 retry:
212         spin_lock(&rpc_credcache_lock);
213         if (time_before(auth->au_nextgc, jiffies))
214                 rpcauth_gc_credcache(auth, &free);
215         list_for_each_safe(pos, next, &auth->au_credcache[nr]) {
216                 struct rpc_cred *entry;
217                 entry = list_entry(pos, struct rpc_cred, cr_hash);
218                 if (entry->cr_flags & RPCAUTH_CRED_DEAD)
219                         continue;
220                 if (rpcauth_prune_expired(entry, &free))
221                         continue;
222                 if (entry->cr_ops->crmatch(acred, entry, taskflags)) {
223                         list_del(&entry->cr_hash);
224                         cred = entry;
225                         break;
226                 }
227         }
228         if (new) {
229                 if (cred)
230                         list_add(&new->cr_hash, &free);
231                 else
232                         cred = new;
233         }
234         if (cred) {
235                 list_add(&cred->cr_hash, &auth->au_credcache[nr]);
236                 cred->cr_auth = auth;
237                 get_rpccred(cred);
238         }
239         spin_unlock(&rpc_credcache_lock);
240
241         rpcauth_destroy_credlist(&free);
242
243         if (!cred) {
244                 new = auth->au_ops->crcreate(auth, acred, taskflags);
245                 if (new) {
246 #ifdef RPC_DEBUG
247                         new->cr_magic = RPCAUTH_CRED_MAGIC;
248 #endif
249                         goto retry;
250                 }
251         }
252
253         return (struct rpc_cred *) cred;
254 }
255
256 struct rpc_cred *
257 rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
258 {
259         struct auth_cred acred;
260         struct rpc_cred *ret;
261
262         get_group_info(current->group_info);
263         acred.uid = XIDINO_UID(current->fsuid, current->xid);
264         acred.gid = XIDINO_GID(current->fsgid, current->xid);
265         acred.group_info = current->group_info;
266
267         dprintk("RPC:     looking up %s cred\n",
268                 auth->au_ops->au_name);
269         ret = rpcauth_lookup_credcache(auth, &acred, taskflags);
270         put_group_info(current->group_info);
271         return ret;
272 }
273
274 struct rpc_cred *
275 rpcauth_bindcred(struct rpc_task *task)
276 {
277         struct rpc_auth *auth = task->tk_auth;
278         struct auth_cred acred;
279         struct rpc_cred *ret;
280
281         get_group_info(current->group_info);
282         acred.uid = XIDINO_UID(current->fsuid, current->xid);
283         acred.gid = XIDINO_GID(current->fsgid, current->xid);
284         acred.group_info = current->group_info;
285
286         dprintk("RPC: %4d looking up %s cred\n",
287                 task->tk_pid, task->tk_auth->au_ops->au_name);
288         task->tk_msg.rpc_cred = rpcauth_lookup_credcache(auth, &acred, task->tk_flags);
289         if (task->tk_msg.rpc_cred == 0)
290                 task->tk_status = -ENOMEM;
291         ret = task->tk_msg.rpc_cred;
292         put_group_info(current->group_info);
293         return ret;
294 }
295
296 void
297 rpcauth_holdcred(struct rpc_task *task)
298 {
299         dprintk("RPC: %4d holding %s cred %p\n",
300                 task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred);
301         if (task->tk_msg.rpc_cred)
302                 get_rpccred(task->tk_msg.rpc_cred);
303 }
304
305 void
306 put_rpccred(struct rpc_cred *cred)
307 {
308         if (!atomic_dec_and_lock(&cred->cr_count, &rpc_credcache_lock))
309                 return;
310
311         if ((cred->cr_flags & RPCAUTH_CRED_DEAD) && !list_empty(&cred->cr_hash))
312                 list_del_init(&cred->cr_hash);
313
314         if (list_empty(&cred->cr_hash)) {
315                 spin_unlock(&rpc_credcache_lock);
316                 rpcauth_crdestroy(cred);
317                 return;
318         }
319         cred->cr_expire = jiffies + cred->cr_auth->au_expire;
320         spin_unlock(&rpc_credcache_lock);
321 }
322
323 void
324 rpcauth_unbindcred(struct rpc_task *task)
325 {
326         struct rpc_auth *auth = task->tk_auth;
327         struct rpc_cred *cred = task->tk_msg.rpc_cred;
328
329         dprintk("RPC: %4d releasing %s cred %p\n",
330                 task->tk_pid, auth->au_ops->au_name, cred);
331
332         put_rpccred(cred);
333         task->tk_msg.rpc_cred = NULL;
334 }
335
336 u32 *
337 rpcauth_marshcred(struct rpc_task *task, u32 *p)
338 {
339         struct rpc_auth *auth = task->tk_auth;
340         struct rpc_cred *cred = task->tk_msg.rpc_cred;
341
342         dprintk("RPC: %4d marshaling %s cred %p\n",
343                 task->tk_pid, auth->au_ops->au_name, cred);
344         return cred->cr_ops->crmarshal(task, p,
345                                 task->tk_flags & RPC_CALL_REALUID);
346 }
347
348 u32 *
349 rpcauth_checkverf(struct rpc_task *task, u32 *p)
350 {
351         struct rpc_auth *auth = task->tk_auth;
352         struct rpc_cred *cred = task->tk_msg.rpc_cred;
353
354         dprintk("RPC: %4d validating %s cred %p\n",
355                 task->tk_pid, auth->au_ops->au_name, cred);
356         return cred->cr_ops->crvalidate(task, p);
357 }
358
359 int
360 rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp,
361                 u32 *data, void *obj)
362 {
363         struct rpc_cred *cred = task->tk_msg.rpc_cred;
364
365         dprintk("RPC: %4d using %s cred %p to wrap rpc data\n",
366                         task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
367         if (cred->cr_ops->crwrap_req)
368                 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
369         /* By default, we encode the arguments normally. */
370         return encode(rqstp, data, obj);
371 }
372
373 int
374 rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp,
375                 u32 *data, void *obj)
376 {
377         struct rpc_cred *cred = task->tk_msg.rpc_cred;
378
379         dprintk("RPC: %4d using %s cred %p to unwrap rpc data\n",
380                         task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
381         if (cred->cr_ops->crunwrap_resp)
382                 return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
383                                                    data, obj);
384         /* By default, we decode the arguments normally. */
385         return decode(rqstp, data, obj);
386 }
387
388 int
389 rpcauth_refreshcred(struct rpc_task *task)
390 {
391         struct rpc_auth *auth = task->tk_auth;
392         struct rpc_cred *cred = task->tk_msg.rpc_cred;
393
394         dprintk("RPC: %4d refreshing %s cred %p\n",
395                 task->tk_pid, auth->au_ops->au_name, cred);
396         task->tk_status = cred->cr_ops->crrefresh(task);
397         return task->tk_status;
398 }
399
400 void
401 rpcauth_invalcred(struct rpc_task *task)
402 {
403         dprintk("RPC: %4d invalidating %s cred %p\n",
404                 task->tk_pid, task->tk_auth->au_ops->au_name, task->tk_msg.rpc_cred);
405         spin_lock(&rpc_credcache_lock);
406         if (task->tk_msg.rpc_cred)
407                 task->tk_msg.rpc_cred->cr_flags &= ~RPCAUTH_CRED_UPTODATE;
408         spin_unlock(&rpc_credcache_lock);
409 }
410
411 int
412 rpcauth_uptodatecred(struct rpc_task *task)
413 {
414         return !(task->tk_msg.rpc_cred) ||
415                 (task->tk_msg.rpc_cred->cr_flags & RPCAUTH_CRED_UPTODATE);
416 }
417
418 int
419 rpcauth_deadcred(struct rpc_task *task)
420 {
421         return !(task->tk_msg.rpc_cred) ||
422                 (task->tk_msg.rpc_cred->cr_flags & RPCAUTH_CRED_DEAD);
423 }