X-Git-Url: http://git.onelab.eu/?a=blobdiff_plain;f=net%2Fsunrpc%2Fauth.c;h=f4b8344046f84e7f8e0a97d40c1c2472ec491f38;hb=97bf2856c6014879bd04983a3e9dfcdac1e7fe85;hp=71bd7a1ea51b5917b4e46d8c84e2b1e8f1746f69;hpb=f7f1b0f1e2fbadeab12d24236000e778aa9b1ead;p=linux-2.6.git diff --git a/net/sunrpc/auth.c b/net/sunrpc/auth.c index 71bd7a1ea..f4b834404 100644 --- a/net/sunrpc/auth.c +++ b/net/sunrpc/auth.c @@ -11,10 +11,9 @@ #include #include #include -#include #include #include -#include +#include #ifdef RPC_DEBUG # define RPCDBG_FACILITY RPCDBG_AUTH @@ -66,14 +65,26 @@ rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt) struct rpc_authops *ops; u32 flavor = pseudoflavor_to_flavor(pseudoflavor); - if (flavor >= RPC_AUTH_MAXFLAVOR || !(ops = auth_flavors[flavor])) - return NULL; + auth = ERR_PTR(-EINVAL); + if (flavor >= RPC_AUTH_MAXFLAVOR) + goto out; + + /* FIXME - auth_flavors[] really needs an rw lock, + * and module refcounting. */ +#ifdef CONFIG_KMOD + if ((ops = auth_flavors[flavor]) == NULL) + request_module("rpc-auth-%u", flavor); +#endif + if ((ops = auth_flavors[flavor]) == NULL) + goto out; auth = ops->create(clnt, pseudoflavor); - if (!auth) - return NULL; + if (IS_ERR(auth)) + return auth; if (clnt->cl_auth) rpcauth_destroy(clnt->cl_auth); clnt->cl_auth = auth; + +out: return auth; } @@ -96,7 +107,7 @@ rpcauth_init_credcache(struct rpc_auth *auth, unsigned long expire) struct rpc_cred_cache *new; int i; - new = (struct rpc_cred_cache *)kmalloc(sizeof(*new), GFP_KERNEL); + new = kmalloc(sizeof(*new), GFP_KERNEL); if (!new) return -ENOMEM; for (i = 0; i < RPC_CREDCACHE_NR; i++) @@ -186,7 +197,7 @@ rpcauth_gc_credcache(struct rpc_auth *auth, struct hlist_head *free) */ struct rpc_cred * rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, - int taskflags) + int flags) { struct rpc_cred_cache *cache = auth->au_credcache; HLIST_HEAD(free); @@ -195,7 +206,7 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred, *cred = NULL; int nr = 0; - if (!(taskflags & RPC_TASK_ROOTCREDS)) + if (!(flags & RPCAUTH_LOOKUP_ROOTCREDS)) nr = acred->uid & RPC_CREDCACHE_MASK; retry: spin_lock(&rpc_credcache_lock); @@ -204,7 +215,7 @@ retry: hlist_for_each_safe(pos, next, &cache->hashtable[nr]) { struct rpc_cred *entry; entry = hlist_entry(pos, struct rpc_cred, cr_hash); - if (entry->cr_ops->crmatch(acred, entry, taskflags)) { + if (entry->cr_ops->crmatch(acred, entry, flags)) { hlist_del(&entry->cr_hash); cred = entry; break; @@ -226,7 +237,7 @@ retry: rpcauth_destroy_credlist(&free); if (!cred) { - new = auth->au_ops->crcreate(auth, acred, taskflags); + new = auth->au_ops->crcreate(auth, acred, flags); if (!IS_ERR(new)) { #ifdef RPC_DEBUG new->cr_magic = RPCAUTH_CRED_MAGIC; @@ -234,18 +245,26 @@ retry: goto retry; } else cred = new; + } else if ((cred->cr_flags & RPCAUTH_CRED_NEW) + && cred->cr_ops->cr_init != NULL + && !(flags & RPCAUTH_LOOKUP_NEW)) { + int res = cred->cr_ops->cr_init(auth, cred); + if (res < 0) { + put_rpccred(cred); + cred = ERR_PTR(res); + } } return (struct rpc_cred *) cred; } struct rpc_cred * -rpcauth_lookupcred(struct rpc_auth *auth, int taskflags) +rpcauth_lookupcred(struct rpc_auth *auth, int flags) { struct auth_cred acred = { .uid = current->fsuid, .gid = current->fsgid, - .xid = vx_current_xid(), + .tag = dx_current_tag(), .group_info = current->group_info, }; struct rpc_cred *ret; @@ -253,7 +272,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int taskflags) dprintk("RPC: looking up %s cred\n", auth->au_ops->au_name); get_group_info(acred.group_info); - ret = auth->au_ops->lookup_cred(auth, &acred, taskflags); + ret = auth->au_ops->lookup_cred(auth, &acred, flags); put_group_info(acred.group_info); return ret; } @@ -265,15 +284,18 @@ rpcauth_bindcred(struct rpc_task *task) struct auth_cred acred = { .uid = current->fsuid, .gid = current->fsgid, - .xid = vx_current_xid(), + .tag = dx_current_tag(), .group_info = current->group_info, }; struct rpc_cred *ret; + int flags = 0; dprintk("RPC: %4d looking up %s cred\n", task->tk_pid, task->tk_auth->au_ops->au_name); get_group_info(acred.group_info); - ret = auth->au_ops->lookup_cred(auth, &acred, task->tk_flags); + if (task->tk_flags & RPC_TASK_ROOTCREDS) + flags |= RPCAUTH_LOOKUP_ROOTCREDS; + ret = auth->au_ops->lookup_cred(auth, &acred, flags); if (!IS_ERR(ret)) task->tk_msg.rpc_cred = ret; else @@ -303,41 +325,40 @@ put_rpccred(struct rpc_cred *cred) void rpcauth_unbindcred(struct rpc_task *task) { - struct rpc_auth *auth = task->tk_auth; struct rpc_cred *cred = task->tk_msg.rpc_cred; dprintk("RPC: %4d releasing %s cred %p\n", - task->tk_pid, auth->au_ops->au_name, cred); + task->tk_pid, task->tk_auth->au_ops->au_name, cred); put_rpccred(cred); task->tk_msg.rpc_cred = NULL; } -u32 * -rpcauth_marshcred(struct rpc_task *task, u32 *p) +__be32 * +rpcauth_marshcred(struct rpc_task *task, __be32 *p) { - struct rpc_auth *auth = task->tk_auth; struct rpc_cred *cred = task->tk_msg.rpc_cred; dprintk("RPC: %4d marshaling %s cred %p\n", - task->tk_pid, auth->au_ops->au_name, cred); + task->tk_pid, task->tk_auth->au_ops->au_name, cred); + return cred->cr_ops->crmarshal(task, p); } -u32 * -rpcauth_checkverf(struct rpc_task *task, u32 *p) +__be32 * +rpcauth_checkverf(struct rpc_task *task, __be32 *p) { - struct rpc_auth *auth = task->tk_auth; struct rpc_cred *cred = task->tk_msg.rpc_cred; dprintk("RPC: %4d validating %s cred %p\n", - task->tk_pid, auth->au_ops->au_name, cred); + task->tk_pid, task->tk_auth->au_ops->au_name, cred); + return cred->cr_ops->crvalidate(task, p); } int rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp, - u32 *data, void *obj) + __be32 *data, void *obj) { struct rpc_cred *cred = task->tk_msg.rpc_cred; @@ -351,7 +372,7 @@ rpcauth_wrap_req(struct rpc_task *task, kxdrproc_t encode, void *rqstp, int rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, - u32 *data, void *obj) + __be32 *data, void *obj) { struct rpc_cred *cred = task->tk_msg.rpc_cred; @@ -367,12 +388,12 @@ rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp, int rpcauth_refreshcred(struct rpc_task *task) { - struct rpc_auth *auth = task->tk_auth; struct rpc_cred *cred = task->tk_msg.rpc_cred; int err; dprintk("RPC: %4d refreshing %s cred %p\n", - task->tk_pid, auth->au_ops->au_name, cred); + task->tk_pid, task->tk_auth->au_ops->au_name, cred); + err = cred->cr_ops->crrefresh(task); if (err < 0) task->tk_status = err;