linux 2.6.16.38 w/ vs2.0.3-rc1
[linux-2.6.git] / net / sunrpc / auth_gss / svcauth_gss.c
index 55b2fd1..23632d8 100644 (file)
@@ -250,6 +250,7 @@ out:
 }
 
 static struct cache_detail rsi_cache = {
+       .owner          = THIS_MODULE,
        .hash_size      = RSI_HASHMAX,
        .hash_table     = rsi_table,
        .name           = "auth.rpcsec.init",
@@ -381,7 +382,6 @@ static int rsc_parse(struct cache_detail *cd,
        else {
                int N, i;
                struct gss_api_mech *gm;
-               struct xdr_netobj tmp_buf;
 
                /* gid */
                if (get_int(&mesg, &rsci.cred.cr_gid))
@@ -420,9 +420,8 @@ static int rsc_parse(struct cache_detail *cd,
                        gss_mech_put(gm);
                        goto out;
                }
-               tmp_buf.len = len;
-               tmp_buf.data = buf;
-               if (gss_import_sec_context(&tmp_buf, gm, &rsci.mechctx)) {
+               status = gss_import_sec_context(buf, len, gm, &rsci.mechctx);
+               if (status) {
                        gss_mech_put(gm);
                        goto out;
                }
@@ -439,6 +438,7 @@ out:
 }
 
 static struct cache_detail rsc_cache = {
+       .owner          = THIS_MODULE,
        .hash_size      = RSC_HASHMAX,
        .hash_table     = rsc_table,
        .name           = "auth.rpcsec.context",
@@ -448,7 +448,7 @@ static struct cache_detail rsc_cache = {
 
 static DefineSimpleCacheLookup(rsc, 0);
 
-struct rsc *
+static struct rsc *
 gss_svc_searchbyctx(struct xdr_netobj *handle)
 {
        struct rsc rsci;
@@ -503,7 +503,7 @@ static inline u32 round_up_to_quad(u32 i)
 }
 
 static inline int
-svc_safe_getnetobj(struct iovec *argv, struct xdr_netobj *o)
+svc_safe_getnetobj(struct kvec *argv, struct xdr_netobj *o)
 {
        int l;
 
@@ -520,7 +520,7 @@ svc_safe_getnetobj(struct iovec *argv, struct xdr_netobj *o)
 }
 
 static inline int
-svc_safe_putnetobj(struct iovec *resv, struct xdr_netobj *o)
+svc_safe_putnetobj(struct kvec *resv, struct xdr_netobj *o)
 {
        u32 *p;
 
@@ -548,8 +548,8 @@ gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
        struct xdr_buf          rpchdr;
        struct xdr_netobj       checksum;
        u32                     flavor = 0;
-       struct iovec            *argv = &rqstp->rq_arg.head[0];
-       struct iovec            iov;
+       struct kvec             *argv = &rqstp->rq_arg.head[0];
+       struct kvec             iov;
 
        /* data to compute the checksum over: */
        iov.iov_base = rpcstart;
@@ -567,8 +567,7 @@ gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
 
        if (rqstp->rq_deferred) /* skip verification of revisited request */
                return SVC_OK;
-       if (gss_verify_mic(ctx_id, &rpchdr, &checksum, NULL)
-                                                       != GSS_S_COMPLETE) {
+       if (gss_verify_mic(ctx_id, &rpchdr, &checksum) != GSS_S_COMPLETE) {
                *authp = rpcsec_gsserr_credproblem;
                return SVC_DENIED;
        }
@@ -587,6 +586,20 @@ gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
        return SVC_OK;
 }
 
+static int
+gss_write_null_verf(struct svc_rqst *rqstp)
+{
+       u32     *p;
+
+       svc_putu32(rqstp->rq_res.head, htonl(RPC_AUTH_NULL));
+       p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len;
+       /* don't really need to check if head->iov_len > PAGE_SIZE ... */
+       *p++ = 0;
+       if (!xdr_ressize_check(rqstp, p))
+               return -1;
+       return 0;
+}
+
 static int
 gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq)
 {
@@ -595,7 +608,7 @@ gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq)
        struct xdr_buf          verf_data;
        struct xdr_netobj       mic;
        u32                     *p;
-       struct iovec            iov;
+       struct kvec             iov;
 
        svc_putu32(rqstp->rq_res.head, htonl(RPC_AUTH_GSS));
        xdr_seq = htonl(seq);
@@ -605,7 +618,7 @@ gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq)
        xdr_buf_from_iov(&iov, &verf_data);
        p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len;
        mic.data = (u8 *)(p + 1);
-       maj_stat = gss_get_mic(ctx_id, 0, &verf_data, &mic);
+       maj_stat = gss_get_mic(ctx_id, &verf_data, &mic);
        if (maj_stat != GSS_S_COMPLETE)
                return -1;
        *p++ = htonl(mic.len);
@@ -643,7 +656,6 @@ svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name)
        if (!new)
                goto out;
        cache_init(&new->h.h);
-       atomic_inc(&new->h.h.refcnt);
        new->h.name = kmalloc(strlen(name) + 1, GFP_KERNEL);
        if (!new->h.name)
                goto out_free_dom;
@@ -651,7 +663,6 @@ svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name)
        new->h.flavour = RPC_AUTH_GSS;
        new->pseudoflavor = pseudoflavor;
        new->h.h.expiry_time = NEVER;
-       new->h.h.flags = 0;
 
        test = auth_domain_lookup(&new->h, 1);
        if (test == &new->h) {
@@ -713,7 +724,7 @@ unwrap_integ_data(struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx)
                goto out;
        if (read_bytes_from_xdr_buf(buf, integ_len + 4, mic.data, mic.len))
                goto out;
-       maj_stat = gss_verify_mic(ctx, &integ_buf, &mic, NULL);
+       maj_stat = gss_verify_mic(ctx, &integ_buf, &mic);
        if (maj_stat != GSS_S_COMPLETE)
                goto out;
        if (ntohl(svc_getu32(&buf->head[0])) != seq)
@@ -732,6 +743,34 @@ struct gss_svc_data {
        struct rsc                      *rsci;
 };
 
+static int
+svcauth_gss_set_client(struct svc_rqst *rqstp)
+{
+       struct gss_svc_data *svcdata = rqstp->rq_auth_data;
+       struct rsc *rsci = svcdata->rsci;
+       struct rpc_gss_wire_cred *gc = &svcdata->clcred;
+
+       rqstp->rq_client = find_gss_auth_domain(rsci->mechctx, gc->gc_svc);
+       if (rqstp->rq_client == NULL)
+               return SVC_DENIED;
+       return SVC_OK;
+}
+
+static inline int
+gss_write_init_verf(struct svc_rqst *rqstp, struct rsi *rsip)
+{
+       struct rsc *rsci;
+
+       if (rsip->major_status != GSS_S_COMPLETE)
+               return gss_write_null_verf(rqstp);
+       rsci = gss_svc_searchbyctx(&rsip->out_handle);
+       if (rsci == NULL) {
+               rsip->major_status = GSS_S_NO_CONTEXT;
+               return gss_write_null_verf(rqstp);
+       }
+       return gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN);
+}
+
 /*
  * Accept an rpcsec packet.
  * If context establishment, punt to user space
@@ -743,8 +782,8 @@ struct gss_svc_data {
 static int
 svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
 {
-       struct iovec    *argv = &rqstp->rq_arg.head[0];
-       struct iovec    *resv = &rqstp->rq_res.head[0];
+       struct kvec     *argv = &rqstp->rq_arg.head[0];
+       struct kvec     *resv = &rqstp->rq_res.head[0];
        u32             crlen;
        struct xdr_netobj tmpobj;
        struct gss_svc_data *svcdata = rqstp->rq_auth_data;
@@ -763,7 +802,7 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
        if (!svcdata)
                goto auth_err;
        rqstp->rq_auth_data = svcdata;
-       svcdata->body_start = 0;
+       svcdata->body_start = NULL;
        svcdata->rsci = NULL;
        gc = &svcdata->clcred;
 
@@ -867,11 +906,7 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                case -ENOENT:
                        goto drop;
                case 0:
-                       rsci = gss_svc_searchbyctx(&rsip->out_handle);
-                       if (!rsci) {
-                               goto drop;
-                       }
-                       if (gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN))
+                       if (gss_write_init_verf(rqstp, rsip))
                                goto drop;
                        if (resv->iov_len + 4 > PAGE_SIZE)
                                goto drop;
@@ -895,11 +930,6 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                svc_putu32(resv, rpc_success);
                goto complete;
        case RPC_GSS_PROC_DATA:
-               *authp = rpc_autherr_badcred;
-               rqstp->rq_client =
-                       find_gss_auth_domain(rsci->mechctx, gc->gc_svc);
-               if (rqstp->rq_client == NULL)
-                       goto auth_err;
                *authp = rpcsec_gsserr_ctxproblem;
                if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq))
                        goto auth_err;
@@ -913,8 +943,6 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                        if (unwrap_integ_data(&rqstp->rq_arg,
                                        gc->gc_seq, rsci->mechctx))
                                goto auth_err;
-                       svcdata->rsci = rsci;
-                       cache_get(&rsci->h);
                        /* placeholders for length and seq. number: */
                        svcdata->body_start = resv->iov_base + resv->iov_len;
                        svc_putu32(resv, 0);
@@ -925,6 +953,8 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                default:
                        goto auth_err;
                }
+               svcdata->rsci = rsci;
+               cache_get(&rsci->h);
                ret = SVC_OK;
                goto out;
        }
@@ -952,7 +982,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
        struct xdr_buf *resbuf = &rqstp->rq_res;
        struct xdr_buf integ_buf;
        struct xdr_netobj mic;
-       struct iovec *resv;
+       struct kvec *resv;
        u32 *p;
        int integ_offset, integ_len;
        int stat = -EINVAL;
@@ -960,7 +990,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
        if (gc->gc_proc != RPC_GSS_PROC_DATA)
                goto out;
        /* Release can be called twice, but we only wrap once. */
-       if (gsd->body_start == 0)
+       if (gsd->body_start == NULL)
                goto out;
        /* normally not set till svc_send, but we need it here: */
        resbuf->len = resbuf->head[0].iov_len
@@ -970,7 +1000,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
                break;
        case RPC_GSS_SVC_INTEGRITY:
                p = gsd->body_start;
-               gsd->body_start = 0;
+               gsd->body_start = NULL;
                /* move accept_stat to right place: */
                memcpy(p, p + 2, 4);
                /* don't wrap in failure case: */
@@ -1007,7 +1037,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
                        resv = &resbuf->tail[0];
                }
                mic.data = (u8 *)resv->iov_base + resv->iov_len + 4;
-               if (gss_get_mic(gsd->rsci->mechctx, 0, &integ_buf, &mic))
+               if (gss_get_mic(gsd->rsci->mechctx, &integ_buf, &mic))
                        goto out_err;
                svc_putu32(resv, htonl(mic.len));
                memset(mic.data + mic.len, 0,
@@ -1047,13 +1077,14 @@ svcauth_gss_domain_release(struct auth_domain *dom)
        kfree(gd);
 }
 
-struct auth_ops svcauthops_gss = {
+static struct auth_ops svcauthops_gss = {
        .name           = "rpcsec_gss",
        .owner          = THIS_MODULE,
        .flavour        = RPC_AUTH_GSS,
        .accept         = svcauth_gss_accept,
        .release        = svcauth_gss_release,
        .domain_release = svcauth_gss_domain_release,
+       .set_client     = svcauth_gss_set_client,
 };
 
 int
@@ -1070,7 +1101,9 @@ gss_svc_init(void)
 void
 gss_svc_shutdown(void)
 {
-       cache_unregister(&rsc_cache);
-       cache_unregister(&rsi_cache);
+       if (cache_unregister(&rsc_cache))
+               printk(KERN_ERR "auth_rpcgss: failed to unregister rsc cache\n");
+       if (cache_unregister(&rsi_cache))
+               printk(KERN_ERR "auth_rpcgss: failed to unregister rsi cache\n");
        svc_auth_unregister(RPC_AUTH_GSS);
 }