linux 2.6.16.38 w/ vs2.0.3-rc1
[linux-2.6.git] / net / sunrpc / auth_gss / svcauth_gss.c
index ee2e9ce..23632d8 100644 (file)
@@ -250,6 +250,7 @@ out:
 }
 
 static struct cache_detail rsi_cache = {
+       .owner          = THIS_MODULE,
        .hash_size      = RSI_HASHMAX,
        .hash_table     = rsi_table,
        .name           = "auth.rpcsec.init",
@@ -381,7 +382,6 @@ static int rsc_parse(struct cache_detail *cd,
        else {
                int N, i;
                struct gss_api_mech *gm;
-               struct xdr_netobj tmp_buf;
 
                /* gid */
                if (get_int(&mesg, &rsci.cred.cr_gid))
@@ -420,9 +420,8 @@ static int rsc_parse(struct cache_detail *cd,
                        gss_mech_put(gm);
                        goto out;
                }
-               tmp_buf.len = len;
-               tmp_buf.data = buf;
-               if (gss_import_sec_context(&tmp_buf, gm, &rsci.mechctx)) {
+               status = gss_import_sec_context(buf, len, gm, &rsci.mechctx);
+               if (status) {
                        gss_mech_put(gm);
                        goto out;
                }
@@ -439,6 +438,7 @@ out:
 }
 
 static struct cache_detail rsc_cache = {
+       .owner          = THIS_MODULE,
        .hash_size      = RSC_HASHMAX,
        .hash_table     = rsc_table,
        .name           = "auth.rpcsec.context",
@@ -448,7 +448,7 @@ static struct cache_detail rsc_cache = {
 
 static DefineSimpleCacheLookup(rsc, 0);
 
-struct rsc *
+static struct rsc *
 gss_svc_searchbyctx(struct xdr_netobj *handle)
 {
        struct rsc rsci;
@@ -567,8 +567,7 @@ gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
 
        if (rqstp->rq_deferred) /* skip verification of revisited request */
                return SVC_OK;
-       if (gss_verify_mic(ctx_id, &rpchdr, &checksum, NULL)
-                                                       != GSS_S_COMPLETE) {
+       if (gss_verify_mic(ctx_id, &rpchdr, &checksum) != GSS_S_COMPLETE) {
                *authp = rpcsec_gsserr_credproblem;
                return SVC_DENIED;
        }
@@ -587,6 +586,20 @@ gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
        return SVC_OK;
 }
 
+static int
+gss_write_null_verf(struct svc_rqst *rqstp)
+{
+       u32     *p;
+
+       svc_putu32(rqstp->rq_res.head, htonl(RPC_AUTH_NULL));
+       p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len;
+       /* don't really need to check if head->iov_len > PAGE_SIZE ... */
+       *p++ = 0;
+       if (!xdr_ressize_check(rqstp, p))
+               return -1;
+       return 0;
+}
+
 static int
 gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq)
 {
@@ -605,7 +618,7 @@ gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq)
        xdr_buf_from_iov(&iov, &verf_data);
        p = rqstp->rq_res.head->iov_base + rqstp->rq_res.head->iov_len;
        mic.data = (u8 *)(p + 1);
-       maj_stat = gss_get_mic(ctx_id, 0, &verf_data, &mic);
+       maj_stat = gss_get_mic(ctx_id, &verf_data, &mic);
        if (maj_stat != GSS_S_COMPLETE)
                return -1;
        *p++ = htonl(mic.len);
@@ -711,7 +724,7 @@ unwrap_integ_data(struct xdr_buf *buf, u32 seq, struct gss_ctx *ctx)
                goto out;
        if (read_bytes_from_xdr_buf(buf, integ_len + 4, mic.data, mic.len))
                goto out;
-       maj_stat = gss_verify_mic(ctx, &integ_buf, &mic, NULL);
+       maj_stat = gss_verify_mic(ctx, &integ_buf, &mic);
        if (maj_stat != GSS_S_COMPLETE)
                goto out;
        if (ntohl(svc_getu32(&buf->head[0])) != seq)
@@ -730,6 +743,34 @@ struct gss_svc_data {
        struct rsc                      *rsci;
 };
 
+static int
+svcauth_gss_set_client(struct svc_rqst *rqstp)
+{
+       struct gss_svc_data *svcdata = rqstp->rq_auth_data;
+       struct rsc *rsci = svcdata->rsci;
+       struct rpc_gss_wire_cred *gc = &svcdata->clcred;
+
+       rqstp->rq_client = find_gss_auth_domain(rsci->mechctx, gc->gc_svc);
+       if (rqstp->rq_client == NULL)
+               return SVC_DENIED;
+       return SVC_OK;
+}
+
+static inline int
+gss_write_init_verf(struct svc_rqst *rqstp, struct rsi *rsip)
+{
+       struct rsc *rsci;
+
+       if (rsip->major_status != GSS_S_COMPLETE)
+               return gss_write_null_verf(rqstp);
+       rsci = gss_svc_searchbyctx(&rsip->out_handle);
+       if (rsci == NULL) {
+               rsip->major_status = GSS_S_NO_CONTEXT;
+               return gss_write_null_verf(rqstp);
+       }
+       return gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN);
+}
+
 /*
  * Accept an rpcsec packet.
  * If context establishment, punt to user space
@@ -865,11 +906,7 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                case -ENOENT:
                        goto drop;
                case 0:
-                       rsci = gss_svc_searchbyctx(&rsip->out_handle);
-                       if (!rsci) {
-                               goto drop;
-                       }
-                       if (gss_write_verf(rqstp, rsci->mechctx, GSS_SEQ_WIN))
+                       if (gss_write_init_verf(rqstp, rsip))
                                goto drop;
                        if (resv->iov_len + 4 > PAGE_SIZE)
                                goto drop;
@@ -893,11 +930,6 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                svc_putu32(resv, rpc_success);
                goto complete;
        case RPC_GSS_PROC_DATA:
-               *authp = rpc_autherr_badcred;
-               rqstp->rq_client =
-                       find_gss_auth_domain(rsci->mechctx, gc->gc_svc);
-               if (rqstp->rq_client == NULL)
-                       goto auth_err;
                *authp = rpcsec_gsserr_ctxproblem;
                if (gss_write_verf(rqstp, rsci->mechctx, gc->gc_seq))
                        goto auth_err;
@@ -911,8 +943,6 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                        if (unwrap_integ_data(&rqstp->rq_arg,
                                        gc->gc_seq, rsci->mechctx))
                                goto auth_err;
-                       svcdata->rsci = rsci;
-                       cache_get(&rsci->h);
                        /* placeholders for length and seq. number: */
                        svcdata->body_start = resv->iov_base + resv->iov_len;
                        svc_putu32(resv, 0);
@@ -923,6 +953,8 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
                default:
                        goto auth_err;
                }
+               svcdata->rsci = rsci;
+               cache_get(&rsci->h);
                ret = SVC_OK;
                goto out;
        }
@@ -958,7 +990,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
        if (gc->gc_proc != RPC_GSS_PROC_DATA)
                goto out;
        /* Release can be called twice, but we only wrap once. */
-       if (gsd->body_start == 0)
+       if (gsd->body_start == NULL)
                goto out;
        /* normally not set till svc_send, but we need it here: */
        resbuf->len = resbuf->head[0].iov_len
@@ -1005,7 +1037,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
                        resv = &resbuf->tail[0];
                }
                mic.data = (u8 *)resv->iov_base + resv->iov_len + 4;
-               if (gss_get_mic(gsd->rsci->mechctx, 0, &integ_buf, &mic))
+               if (gss_get_mic(gsd->rsci->mechctx, &integ_buf, &mic))
                        goto out_err;
                svc_putu32(resv, htonl(mic.len));
                memset(mic.data + mic.len, 0,
@@ -1045,13 +1077,14 @@ svcauth_gss_domain_release(struct auth_domain *dom)
        kfree(gd);
 }
 
-struct auth_ops svcauthops_gss = {
+static struct auth_ops svcauthops_gss = {
        .name           = "rpcsec_gss",
        .owner          = THIS_MODULE,
        .flavour        = RPC_AUTH_GSS,
        .accept         = svcauth_gss_accept,
        .release        = svcauth_gss_release,
        .domain_release = svcauth_gss_domain_release,
+       .set_client     = svcauth_gss_set_client,
 };
 
 int
@@ -1068,7 +1101,9 @@ gss_svc_init(void)
 void
 gss_svc_shutdown(void)
 {
-       cache_unregister(&rsc_cache);
-       cache_unregister(&rsi_cache);
+       if (cache_unregister(&rsc_cache))
+               printk(KERN_ERR "auth_rpcgss: failed to unregister rsc cache\n");
+       if (cache_unregister(&rsi_cache))
+               printk(KERN_ERR "auth_rpcgss: failed to unregister rsi cache\n");
        svc_auth_unregister(RPC_AUTH_GSS);
 }