patch-2_6_7-vs1_9_1_12
[linux-2.6.git] / net / sunrpc / auth_gss / auth_gss.c
index 9711241..60bca99 100644 (file)
@@ -132,6 +132,8 @@ print_hexl(u32 *p, u_int length, u_int offset)
        }
 }
 
+EXPORT_SYMBOL(print_hexl);
+
 static inline struct gss_cl_ctx *
 gss_get_ctx(struct gss_cl_ctx *ctx)
 {
@@ -280,7 +282,7 @@ err_free_ctx:
        kfree(ctx);
 err:
        *gc = NULL;
-       dprintk("RPC: gss_parse_init_downcall returning %d\n", err);
+       dprintk("RPC:      gss_parse_init_downcall returning %d\n", err);
        return err;
 }
 
@@ -311,8 +313,10 @@ __gss_find_upcall(struct gss_auth *gss_auth, uid_t uid)
                if (pos->uid != uid)
                        continue;
                atomic_inc(&pos->count);
+               dprintk("RPC:      gss_find_upcall found msg %p\n", pos);
                return pos;
        }
+       dprintk("RPC:      gss_find_upcall found nothing\n");
        return NULL;
 }
 
@@ -350,6 +354,8 @@ gss_upcall(struct rpc_clnt *clnt, struct rpc_task *task, struct rpc_cred *cred)
        uid_t uid = cred->cr_uid;
        int res = 0;
 
+       dprintk("RPC: %4u gss_upcall for uid %u\n", task->tk_pid, uid);
+
 retry:
        spin_lock(&gss_auth->lock);
        gss_msg = __gss_find_upcall(gss_auth, uid);
@@ -358,8 +364,10 @@ retry:
        if (gss_new == NULL) {
                spin_unlock(&gss_auth->lock);
                gss_new = kmalloc(sizeof(*gss_new), GFP_KERNEL);
-               if (!gss_new)
+               if (!gss_new) {
+                       dprintk("RPC: %4u gss_upcall -ENOMEM\n", task->tk_pid);
                        return -ENOMEM;
+               }
                goto retry;
        }
        gss_msg = gss_new;
@@ -389,12 +397,14 @@ retry:
                spin_unlock(&gss_auth->lock);
        }
        gss_release_msg(gss_msg);
+       dprintk("RPC: %4u gss_upcall for uid %u result %d", task->tk_pid,
+                       uid, res);
        return res;
 out_sleep:
-       /* Sleep forever */
        task->tk_timeout = 0;
        rpc_sleep_on(&gss_msg->waitq, task, NULL, NULL);
        spin_unlock(&gss_auth->lock);
+       dprintk("RPC: %4u gss_upcall  sleeping\n", task->tk_pid);
        if (gss_new)
                kfree(gss_new);
        /* Note: we drop the reference here: we are automatically removed
@@ -426,13 +436,13 @@ gss_pipe_upcall(struct file *filp, struct rpc_pipe_msg *msg,
        return mlen;
 }
 
+#define MSG_BUF_MAXSIZE 1024
+
 static ssize_t
 gss_pipe_downcall(struct file *filp, const char *src, size_t mlen)
 {
-       char buf[1024];
        struct xdr_netobj obj = {
                .len    = mlen,
-               .data   = buf,
        };
        struct inode *inode = filp->f_dentry->d_inode;
        struct rpc_inode *rpci = RPC_I(inode);
@@ -448,11 +458,16 @@ gss_pipe_downcall(struct file *filp, const char *src, size_t mlen)
        int err;
        int gss_err;
 
-       if (mlen > sizeof(buf))
-               return -ENOSPC;
-       left = copy_from_user(buf, src, mlen);
-       if (left)
-               return -EFAULT;
+       if (mlen > MSG_BUF_MAXSIZE)
+               return -EFBIG;
+       obj.data = kmalloc(mlen, GFP_KERNEL);
+       if (!obj.data)
+               return -ENOMEM;
+       left = copy_from_user(obj.data, src, mlen);
+       if (left) {
+               err = -EFAULT;
+               goto out;
+       }
        clnt = rpci->private;
        atomic_inc(&clnt->cl_users);
        auth = clnt->cl_auth;
@@ -477,12 +492,16 @@ gss_pipe_downcall(struct file *filp, const char *src, size_t mlen)
        } else
                spin_unlock(&gss_auth->lock);
        rpc_release_client(clnt);
+       kfree(obj.data);
+       dprintk("RPC:      gss_pipe_downcall returning length %Zu\n", mlen);
        return mlen;
 err:
        if (ctx)
                gss_destroy_ctx(ctx);
        rpc_release_client(clnt);
-       dprintk("RPC: gss_pipe_downcall returning %d\n", err);
+out:
+       kfree(obj.data);
+       dprintk("RPC:      gss_pipe_downcall returning %d\n", err);
        return err;
 }
 
@@ -520,6 +539,8 @@ gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
        static unsigned long ratelimit;
 
        if (msg->errno < 0) {
+               dprintk("RPC:      gss_pipe_destroy_msg releasing msg %p\n",
+                               gss_msg);
                atomic_inc(&gss_msg->count);
                gss_unhash_msg(gss_msg);
                if (msg->errno == -ETIMEDOUT || msg->errno == -EPIPE) {
@@ -544,10 +565,11 @@ gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
        struct gss_auth *gss_auth;
        struct rpc_auth * auth;
 
-       dprintk("RPC: creating GSS authenticator for client %p\n",clnt);
+       dprintk("RPC:      creating GSS authenticator for client %p\n",clnt);
+
        if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
                goto out_dec;
-       gss_auth->mech = gss_pseudoflavor_to_mech(flavor);
+       gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
        if (!gss_auth->mech) {
                printk(KERN_WARNING "%s: Pseudoflavor %d not found!",
                                __FUNCTION__, flavor);
@@ -566,7 +588,7 @@ gss_create(struct rpc_clnt *clnt, rpc_authflavor_t flavor)
 
        snprintf(gss_auth->path, sizeof(gss_auth->path), "%s/%s",
                        clnt->cl_pathname,
-                       gss_auth->mech->gm_ops->name);
+                       gss_auth->mech->gm_name);
        gss_auth->dentry = rpc_mkpipe(gss_auth->path, clnt, &gss_upcall_ops, RPC_PIPE_WAIT_FOR_OPEN);
        if (IS_ERR(gss_auth->dentry))
                goto err_free;
@@ -582,7 +604,8 @@ static void
 gss_destroy(struct rpc_auth *auth)
 {
        struct gss_auth *gss_auth;
-       dprintk("RPC: destroying GSS authenticator %p flavor %d\n",
+
+       dprintk("RPC:      destroying GSS authenticator %p flavor %d\n",
                auth, auth->au_flavor);
 
        gss_auth = container_of(auth, struct gss_auth, rpc_auth);
@@ -597,8 +620,7 @@ gss_destroy(struct rpc_auth *auth)
 static void
 gss_destroy_ctx(struct gss_cl_ctx *ctx)
 {
-
-       dprintk("RPC: gss_destroy_ctx\n");
+       dprintk("RPC:      gss_destroy_ctx\n");
 
        if (ctx->gc_gss_ctx)
                gss_delete_sec_context(&ctx->gc_gss_ctx);
@@ -617,7 +639,7 @@ gss_destroy_cred(struct rpc_cred *rc)
 {
        struct gss_cred *cred = (struct gss_cred *)rc;
 
-       dprintk("RPC: gss_destroy_cred \n");
+       dprintk("RPC:      gss_destroy_cred \n");
 
        if (cred->gc_ctx)
                gss_put_ctx(cred->gc_ctx);
@@ -629,7 +651,7 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
 {
        struct gss_cred *cred = NULL;
 
-       dprintk("RPC: gss_create_cred for uid %d, flavor %d\n",
+       dprintk("RPC:      gss_create_cred for uid %d, flavor %d\n",
                acred->uid, auth->au_flavor);
 
        if (!(cred = kmalloc(sizeof(*cred), GFP_KERNEL)))
@@ -649,7 +671,7 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int taskflags)
        return (struct rpc_cred *) cred;
 
 out_err:
-       dprintk("RPC: gss_create_cred failed\n");
+       dprintk("RPC:      gss_create_cred failed\n");
        if (cred) gss_destroy_cred((struct rpc_cred *)cred);
        return NULL;
 }
@@ -679,15 +701,16 @@ gss_marshal(struct rpc_task *task, u32 *p, int ruid)
        struct xdr_buf  verf_buf;
        u32             service;
 
-       dprintk("RPC: gss_marshal\n");
+       dprintk("RPC: %4u gss_marshal\n", task->tk_pid);
 
        *p++ = htonl(RPC_AUTH_GSS);
        cred_len = p++;
 
-       service = gss_pseudoflavor_to_service(gss_cred->gc_flavor);
+       service = gss_pseudoflavor_to_service(ctx->gc_gss_ctx->mech_type,
+                                               gss_cred->gc_flavor);
        if (service == 0) {
-               dprintk("Bad pseudoflavor %d in gss_marshal\n",
-                       gss_cred->gc_flavor);
+               dprintk("RPC: %4u Bad pseudoflavor %d in gss_marshal\n",
+                       task->tk_pid, gss_cred->gc_flavor);
                goto out_put_ctx;
        }
        spin_lock(&ctx->gc_seq_lock);
@@ -736,10 +759,8 @@ static int
 gss_refresh(struct rpc_task *task)
 {
        struct rpc_clnt *clnt = task->tk_client;
-       struct rpc_xprt *xprt = task->tk_xprt;
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
 
-       task->tk_timeout = xprt->timeout.to_current;
        if (!gss_cred_is_uptodate_ctx(cred))
                return gss_upcall(clnt, task, cred);
        return 0;
@@ -759,7 +780,7 @@ gss_validate(struct rpc_task *task, u32 *p)
        u32             flav,len;
        u32             service;
 
-       dprintk("RPC: gss_validate\n");
+       dprintk("RPC: %4u gss_validate\n", task->tk_pid);
 
        flav = ntohl(*p++);
        if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
@@ -775,7 +796,8 @@ gss_validate(struct rpc_task *task, u32 *p)
 
        if (gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic, &qop_state))
                goto out_bad;
-       service = gss_pseudoflavor_to_service(gss_cred->gc_flavor);
+       service = gss_pseudoflavor_to_service(ctx->gc_gss_ctx->mech_type,
+                                       gss_cred->gc_flavor);
        switch (service) {
        case RPC_GSS_SVC_NONE:
               /* verifier data, flavor, length: */
@@ -789,33 +811,75 @@ gss_validate(struct rpc_task *task, u32 *p)
               goto out_bad;
        }
        gss_put_ctx(ctx);
+       dprintk("RPC: %4u GSS gss_validate: gss_verify_mic succeeded.\n",
+                       task->tk_pid);
        return p + XDR_QUADLEN(len);
 out_bad:
        gss_put_ctx(ctx);
+       dprintk("RPC: %4u gss_validate failed.\n", task->tk_pid);
        return NULL;
 }
 
+static inline int
+gss_wrap_req_integ(struct gss_cl_ctx *ctx,
+                       kxdrproc_t encode, void *rqstp, u32 *p, void *obj)
+{
+       struct rpc_rqst *req = (struct rpc_rqst *)rqstp;
+       struct xdr_buf  *snd_buf = &req->rq_snd_buf;
+       struct xdr_buf  integ_buf;
+       u32             *integ_len = NULL;
+       struct xdr_netobj mic;
+       u32             offset, *q;
+       struct iovec    *iov;
+       u32             maj_stat = 0;
+       int             status = -EIO;
+
+       integ_len = p++;
+       offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
+       *p++ = htonl(req->rq_seqno);
+
+       status = encode(rqstp, p, obj);
+       if (status)
+               return status;
+
+       if (xdr_buf_subsegment(snd_buf, &integ_buf,
+                               offset, snd_buf->len - offset))
+               return status;
+       *integ_len = htonl(integ_buf.len);
+
+       /* guess whether we're in the head or the tail: */
+       if (snd_buf->page_len || snd_buf->tail[0].iov_len) 
+               iov = snd_buf->tail;
+       else
+               iov = snd_buf->head;
+       p = iov->iov_base + iov->iov_len;
+       mic.data = (u8 *)(p + 1);
+
+       maj_stat = gss_get_mic(ctx->gc_gss_ctx,
+                       GSS_C_QOP_DEFAULT, &integ_buf, &mic);
+       status = -EIO; /* XXX? */
+       if (maj_stat)
+               return status;
+       q = xdr_encode_opaque(p, NULL, mic.len);
+
+       offset = (u8 *)q - (u8 *)p;
+       iov->iov_len += offset;
+       snd_buf->len += offset;
+       return 0;
+}
+
 static int
 gss_wrap_req(struct rpc_task *task,
             kxdrproc_t encode, void *rqstp, u32 *p, void *obj)
 {
-       struct rpc_rqst *req = (struct rpc_rqst *)rqstp;
-       struct xdr_buf  *snd_buf = &req->rq_snd_buf;
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
                        gc_base);
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
-       u32             *integ_len = NULL;
        int             status = -EIO;
-       u32             maj_stat = 0;
-       struct xdr_buf  integ_buf;
-       struct xdr_netobj mic;
        u32             service;
-       u32             offset, *q;
-       struct iovec    *iov;
 
-       dprintk("RPC: gss_wrap_body\n");
-       BUG_ON(!ctx);
+       dprintk("RPC: %4u gss_wrap_req\n", task->tk_pid);
        if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
                /* The spec seems a little ambiguous here, but I think that not
                 * wrapping context destruction requests makes the most sense.
@@ -823,103 +887,84 @@ gss_wrap_req(struct rpc_task *task,
                status = encode(rqstp, p, obj);
                goto out;
        }
-       service = gss_pseudoflavor_to_service(gss_cred->gc_flavor);
+       service = gss_pseudoflavor_to_service(ctx->gc_gss_ctx->mech_type,
+                                               gss_cred->gc_flavor);
        switch (service) {
                case RPC_GSS_SVC_NONE:
                        status = encode(rqstp, p, obj);
                        goto out;
                case RPC_GSS_SVC_INTEGRITY:
-
-                       integ_len = p++;
-                       offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
-                       *p++ = htonl(req->rq_seqno);
-
-                       status = encode(rqstp, p, obj);
-                       if (status)
-                               goto out;
-
-                       if (xdr_buf_subsegment(snd_buf, &integ_buf,
-                                               offset, snd_buf->len - offset))
-                               goto out;
-                       *integ_len = htonl(integ_buf.len);
-
-                       /* guess whether we're in the head or the tail: */
-                       if (snd_buf->page_len || snd_buf->tail[0].iov_len) 
-                               iov = snd_buf->tail;
-                       else
-                               iov = snd_buf->head;
-                       p = iov->iov_base + iov->iov_len;
-                       mic.data = (u8 *)(p + 1);
-
-                       maj_stat = gss_get_mic(ctx->gc_gss_ctx,
-                                       GSS_C_QOP_DEFAULT, &integ_buf, &mic);
-                       status = -EIO; /* XXX? */
-                       if (maj_stat)
-                               goto out;
-                       q = xdr_encode_opaque(p, NULL, mic.len);
-
-                       offset = (u8 *)q - (u8 *)p;
-                       iov->iov_len += offset;
-                       snd_buf->len += offset;
-                       break;
+                       status = gss_wrap_req_integ(ctx, encode, rqstp, p, obj);
+                       goto out;
                case RPC_GSS_SVC_PRIVACY:
                default:
                        goto out;
        }
-       status = 0;
 out:
        gss_put_ctx(ctx);
-       dprintk("RPC: gss_wrap_req returning %d\n", status);
+       dprintk("RPC: %4u gss_wrap_req returning %d\n", task->tk_pid, status);
        return status;
 }
 
+static inline int
+gss_unwrap_resp_integ(struct gss_cl_ctx *ctx,
+               kxdrproc_t decode, void *rqstp, u32 **p, void *obj)
+{
+       struct rpc_rqst *req = (struct rpc_rqst *)rqstp;
+       struct xdr_buf  *rcv_buf = &req->rq_rcv_buf;
+       struct xdr_buf integ_buf;
+       struct xdr_netobj mic;
+       u32 data_offset, mic_offset;
+       u32 integ_len;
+       u32 maj_stat;
+       int status = -EIO;
+
+       integ_len = ntohl(*(*p)++);
+       if (integ_len & 3)
+               return status;
+       data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
+       mic_offset = integ_len + data_offset;
+       if (mic_offset > rcv_buf->len)
+               return status;
+       if (ntohl(*(*p)++) != req->rq_seqno)
+               return status;
+
+       if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
+                               mic_offset - data_offset))
+               return status;
+
+       if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
+               return status;
+
+       maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf,
+                       &mic, NULL);
+       if (maj_stat != GSS_S_COMPLETE)
+               return status;
+       return 0;
+}
+
 static int
 gss_unwrap_resp(struct rpc_task *task,
                kxdrproc_t decode, void *rqstp, u32 *p, void *obj)
 {
-       struct rpc_rqst *req = (struct rpc_rqst *)rqstp;
-       struct xdr_buf  *rcv_buf = &req->rq_rcv_buf;
        struct rpc_cred *cred = task->tk_msg.rpc_cred;
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
                        gc_base);
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
-       struct xdr_buf  integ_buf;
-       struct xdr_netobj mic;
        int             status = -EIO;
-       u32             maj_stat = 0;
        u32             service;
-       u32             data_offset, mic_offset;
-       u32             integ_len;
-
-       BUG_ON(!ctx);
 
        if (ctx->gc_proc != RPC_GSS_PROC_DATA)
                goto out_decode;
-       service = gss_pseudoflavor_to_service(gss_cred->gc_flavor);
+       service = gss_pseudoflavor_to_service(ctx->gc_gss_ctx->mech_type,
+                                               gss_cred->gc_flavor);
        switch (service) {
                case RPC_GSS_SVC_NONE:
                        goto out_decode;
                case RPC_GSS_SVC_INTEGRITY:
-                       integ_len = ntohl(*p++);
-                       if (integ_len & 3)
-                               goto out;
-                       data_offset = (u8 *)p - (u8 *)rcv_buf->head[0].iov_base;
-                       mic_offset = integ_len + data_offset;
-                       if (mic_offset > rcv_buf->len)
-                               goto out;
-                       if (ntohl(*p++) != req->rq_seqno)
-                               goto out;
-
-                       if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
-                                               mic_offset - data_offset))
-                               goto out;
-
-                       if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
-                               goto out;
-
-                       maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf,
-                                       &mic, NULL);
-                       if (maj_stat != GSS_S_COMPLETE)
+                       status = gss_unwrap_resp_integ(ctx, decode, 
+                                                       rqstp, &p, obj);
+                       if (status)
                                goto out;
                        break;
                case RPC_GSS_SVC_PRIVACY:
@@ -930,7 +975,8 @@ out_decode:
        status = decode(rqstp, p, obj);
 out:
        gss_put_ctx(ctx);
-       dprintk("RPC: gss_unwrap_resp returning %d\n", status);
+       dprintk("RPC: %4u gss_unwrap_resp returning %d\n", task->tk_pid,
+                       status);
        return status;
 }
   
@@ -985,7 +1031,6 @@ out:
 static void __exit exit_rpcsec_gss(void)
 {
        gss_svc_shutdown();
-       gss_mech_unregister_all();
        rpcauth_unregister(&authgss_ops);
 }