Fedora kernel-2.6.17-1.2142_FC4 patched with stable patch-2.6.17.4-vs2.0.2-rc26.diff
[linux-2.6.git] / net / sunrpc / auth.c
index 71bd7a1..a3ea895 100644 (file)
@@ -11,7 +11,6 @@
 #include <linux/module.h>
 #include <linux/slab.h>
 #include <linux/errno.h>
-#include <linux/socket.h>
 #include <linux/sunrpc/clnt.h>
 #include <linux/spinlock.h>
 #include <linux/vserver/xid.h>
@@ -66,14 +65,26 @@ rpcauth_create(rpc_authflavor_t pseudoflavor, struct rpc_clnt *clnt)
        struct rpc_authops      *ops;
        u32                     flavor = pseudoflavor_to_flavor(pseudoflavor);
 
-       if (flavor >= RPC_AUTH_MAXFLAVOR || !(ops = auth_flavors[flavor]))
-               return NULL;
+       auth = ERR_PTR(-EINVAL);
+       if (flavor >= RPC_AUTH_MAXFLAVOR)
+               goto out;
+
+       /* FIXME - auth_flavors[] really needs an rw lock,
+        * and module refcounting. */
+#ifdef CONFIG_KMOD
+       if ((ops = auth_flavors[flavor]) == NULL)
+               request_module("rpc-auth-%u", flavor);
+#endif
+       if ((ops = auth_flavors[flavor]) == NULL)
+               goto out;
        auth = ops->create(clnt, pseudoflavor);
-       if (!auth)
-               return NULL;
+       if (IS_ERR(auth))
+               return auth;
        if (clnt->cl_auth)
                rpcauth_destroy(clnt->cl_auth);
        clnt->cl_auth = auth;
+
+out:
        return auth;
 }
 
@@ -96,7 +107,7 @@ rpcauth_init_credcache(struct rpc_auth *auth, unsigned long expire)
        struct rpc_cred_cache *new;
        int i;
 
-       new = (struct rpc_cred_cache *)kmalloc(sizeof(*new), GFP_KERNEL);
+       new = kmalloc(sizeof(*new), GFP_KERNEL);
        if (!new)
                return -ENOMEM;
        for (i = 0; i < RPC_CREDCACHE_NR; i++)
@@ -186,7 +197,7 @@ rpcauth_gc_credcache(struct rpc_auth *auth, struct hlist_head *free)
  */
 struct rpc_cred *
 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
-               int taskflags)
+               int flags)
 {
        struct rpc_cred_cache *cache = auth->au_credcache;
        HLIST_HEAD(free);
@@ -195,7 +206,7 @@ rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
                        *cred = NULL;
        int             nr = 0;
 
-       if (!(taskflags & RPC_TASK_ROOTCREDS))
+       if (!(flags & RPCAUTH_LOOKUP_ROOTCREDS))
                nr = acred->uid & RPC_CREDCACHE_MASK;
 retry:
        spin_lock(&rpc_credcache_lock);
@@ -204,7 +215,7 @@ retry:
        hlist_for_each_safe(pos, next, &cache->hashtable[nr]) {
                struct rpc_cred *entry;
                entry = hlist_entry(pos, struct rpc_cred, cr_hash);
-               if (entry->cr_ops->crmatch(acred, entry, taskflags)) {
+               if (entry->cr_ops->crmatch(acred, entry, flags)) {
                        hlist_del(&entry->cr_hash);
                        cred = entry;
                        break;
@@ -226,7 +237,7 @@ retry:
        rpcauth_destroy_credlist(&free);
 
        if (!cred) {
-               new = auth->au_ops->crcreate(auth, acred, taskflags);
+               new = auth->au_ops->crcreate(auth, acred, flags);
                if (!IS_ERR(new)) {
 #ifdef RPC_DEBUG
                        new->cr_magic = RPCAUTH_CRED_MAGIC;
@@ -234,13 +245,21 @@ retry:
                        goto retry;
                } else
                        cred = new;
+       } else if ((cred->cr_flags & RPCAUTH_CRED_NEW)
+                       && cred->cr_ops->cr_init != NULL
+                       && !(flags & RPCAUTH_LOOKUP_NEW)) {
+               int res = cred->cr_ops->cr_init(auth, cred);
+               if (res < 0) {
+                       put_rpccred(cred);
+                       cred = ERR_PTR(res);
+               }
        }
 
        return (struct rpc_cred *) cred;
 }
 
 struct rpc_cred *
-rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
+rpcauth_lookupcred(struct rpc_auth *auth, int flags)
 {
        struct auth_cred acred = {
                .uid = current->fsuid,
@@ -253,7 +272,7 @@ rpcauth_lookupcred(struct rpc_auth *auth, int taskflags)
        dprintk("RPC:     looking up %s cred\n",
                auth->au_ops->au_name);
        get_group_info(acred.group_info);
-       ret = auth->au_ops->lookup_cred(auth, &acred, taskflags);
+       ret = auth->au_ops->lookup_cred(auth, &acred, flags);
        put_group_info(acred.group_info);
        return ret;
 }
@@ -269,11 +288,14 @@ rpcauth_bindcred(struct rpc_task *task)
                .group_info = current->group_info,
        };
        struct rpc_cred *ret;
+       int flags = 0;
 
        dprintk("RPC: %4d looking up %s cred\n",
                task->tk_pid, task->tk_auth->au_ops->au_name);
        get_group_info(acred.group_info);
-       ret = auth->au_ops->lookup_cred(auth, &acred, task->tk_flags);
+       if (task->tk_flags & RPC_TASK_ROOTCREDS)
+               flags |= RPCAUTH_LOOKUP_ROOTCREDS;
+       ret = auth->au_ops->lookup_cred(auth, &acred, flags);
        if (!IS_ERR(ret))
                task->tk_msg.rpc_cred = ret;
        else
@@ -303,11 +325,10 @@ put_rpccred(struct rpc_cred *cred)
 void
 rpcauth_unbindcred(struct rpc_task *task)
 {
-       struct rpc_auth *auth = task->tk_auth;
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
        dprintk("RPC: %4d releasing %s cred %p\n",
-               task->tk_pid, auth->au_ops->au_name, cred);
+               task->tk_pid, task->tk_auth->au_ops->au_name, cred);
 
        put_rpccred(cred);
        task->tk_msg.rpc_cred = NULL;
@@ -316,22 +337,22 @@ rpcauth_unbindcred(struct rpc_task *task)
 u32 *
 rpcauth_marshcred(struct rpc_task *task, u32 *p)
 {
-       struct rpc_auth *auth = task->tk_auth;
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
        dprintk("RPC: %4d marshaling %s cred %p\n",
-               task->tk_pid, auth->au_ops->au_name, cred);
+               task->tk_pid, task->tk_auth->au_ops->au_name, cred);
+
        return cred->cr_ops->crmarshal(task, p);
 }
 
 u32 *
 rpcauth_checkverf(struct rpc_task *task, u32 *p)
 {
-       struct rpc_auth *auth = task->tk_auth;
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
        dprintk("RPC: %4d validating %s cred %p\n",
-               task->tk_pid, auth->au_ops->au_name, cred);
+               task->tk_pid, task->tk_auth->au_ops->au_name, cred);
+
        return cred->cr_ops->crvalidate(task, p);
 }
 
@@ -367,12 +388,12 @@ rpcauth_unwrap_resp(struct rpc_task *task, kxdrproc_t decode, void *rqstp,
 int
 rpcauth_refreshcred(struct rpc_task *task)
 {
-       struct rpc_auth *auth = task->tk_auth;
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
        int err;
 
        dprintk("RPC: %4d refreshing %s cred %p\n",
-               task->tk_pid, auth->au_ops->au_name, cred);
+               task->tk_pid, task->tk_auth->au_ops->au_name, cred);
+
        err = cred->cr_ops->crrefresh(task);
        if (err < 0)
                task->tk_status = err;