vserver 1.9.3
[linux-2.6.git] / net / sunrpc / auth_gss / svcauth_gss.c
index dae18e9..ee2e9ce 100644 (file)
@@ -333,6 +333,7 @@ rsc_init(struct rsc *new, struct rsc *tmp)
        new->handle.data = tmp->handle.data;
        tmp->handle.data = NULL;
        new->mechctx = NULL;
+       new->cred.cr_group_info = NULL;
 }
 
 static inline void
@@ -453,8 +454,11 @@ gss_svc_searchbyctx(struct xdr_netobj *handle)
        struct rsc rsci;
        struct rsc *found;
 
-       rsci.handle = *handle;
+       memset(&rsci, 0, sizeof(rsci));
+       if (dup_to_netobj(&rsci.handle, handle->data, handle->len))
+               return NULL;
        found = rsc_lookup(&rsci, 0);
+       rsc_free(&rsci);
        if (!found)
                return NULL;
        if (cache_check(&rsc_cache, &found->h, NULL))
@@ -499,7 +503,7 @@ static inline u32 round_up_to_quad(u32 i)
 }
 
 static inline int
-svc_safe_getnetobj(struct iovec *argv, struct xdr_netobj *o)
+svc_safe_getnetobj(struct kvec *argv, struct xdr_netobj *o)
 {
        int l;
 
@@ -516,7 +520,7 @@ svc_safe_getnetobj(struct iovec *argv, struct xdr_netobj *o)
 }
 
 static inline int
-svc_safe_putnetobj(struct iovec *resv, struct xdr_netobj *o)
+svc_safe_putnetobj(struct kvec *resv, struct xdr_netobj *o)
 {
        u32 *p;
 
@@ -544,8 +548,8 @@ gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
        struct xdr_buf          rpchdr;
        struct xdr_netobj       checksum;
        u32                     flavor = 0;
-       struct iovec            *argv = &rqstp->rq_arg.head[0];
-       struct iovec            iov;
+       struct kvec             *argv = &rqstp->rq_arg.head[0];
+       struct kvec             iov;
 
        /* data to compute the checksum over: */
        iov.iov_base = rpcstart;
@@ -570,14 +574,14 @@ gss_verify_header(struct svc_rqst *rqstp, struct rsc *rsci,
        }
 
        if (gc->gc_seq > MAXSEQ) {
-               dprintk("svcauth_gss: discarding request with large"
-                       " sequence number %d\n", gc->gc_seq);
+               dprintk("RPC:      svcauth_gss: discarding request with large sequence number %d\n",
+                               gc->gc_seq);
                *authp = rpcsec_gsserr_ctxproblem;
                return SVC_DENIED;
        }
        if (!gss_check_seq_num(rsci, gc->gc_seq)) {
-               dprintk("svcauth_gss: discarding request with old"
-                               " sequence number %d\n", gc->gc_seq);
+               dprintk("RPC:      svcauth_gss: discarding request with old sequence number %d\n",
+                               gc->gc_seq);
                return SVC_DROP;
        }
        return SVC_OK;
@@ -591,7 +595,7 @@ gss_write_verf(struct svc_rqst *rqstp, struct gss_ctx *ctx_id, u32 seq)
        struct xdr_buf          verf_data;
        struct xdr_netobj       mic;
        u32                     *p;
-       struct iovec            iov;
+       struct kvec             iov;
 
        svc_putu32(rqstp->rq_res.head, htonl(RPC_AUTH_GSS));
        xdr_seq = htonl(seq);
@@ -617,19 +621,15 @@ struct gss_domain {
        u32                     pseudoflavor;
 };
 
-/* XXX this should be done in gss_pseudoflavors, and shouldn't be hardcoded: */
 static struct auth_domain *
 find_gss_auth_domain(struct gss_ctx *ctx, u32 svc)
 {
-       switch(gss_get_pseudoflavor(ctx, 0, svc)) {
-               case RPC_AUTH_GSS_KRB5:
-                       return auth_domain_find("gss/krb5");
-               case RPC_AUTH_GSS_KRB5I:
-                       return auth_domain_find("gss/krb5i");
-               case RPC_AUTH_GSS_KRB5P:
-                       return auth_domain_find("gss/krb5p");
-       }
-       return NULL;
+       char *name;
+
+       name = gss_service_to_auth_domain_name(ctx->mech_type, svc);
+       if (!name)
+               return NULL;
+       return auth_domain_find(name);
 }
 
 int
@@ -637,23 +637,19 @@ svcauth_gss_register_pseudoflavor(u32 pseudoflavor, char * name)
 {
        struct gss_domain       *new;
        struct auth_domain      *test;
-       static char             *prefix = "gss/";
-       int                     stat = -1;
+       int                     stat = -ENOMEM;
 
        new = kmalloc(sizeof(*new), GFP_KERNEL);
        if (!new)
                goto out;
        cache_init(&new->h.h);
-       atomic_inc(&new->h.h.refcnt);
-       new->h.name = kmalloc(strlen(name) + strlen(prefix) + 1, GFP_KERNEL);
+       new->h.name = kmalloc(strlen(name) + 1, GFP_KERNEL);
        if (!new->h.name)
                goto out_free_dom;
-       strcpy(new->h.name, prefix);
-       strcat(new->h.name, name);
+       strcpy(new->h.name, name);
        new->h.flavour = RPC_AUTH_GSS;
        new->pseudoflavor = pseudoflavor;
        new->h.h.expiry_time = NEVER;
-       new->h.h.flags = 0;
 
        test = auth_domain_lookup(&new->h, 1);
        if (test == &new->h) {
@@ -670,6 +666,8 @@ out:
        return stat;
 }
 
+EXPORT_SYMBOL(svcauth_gss_register_pseudoflavor);
+
 static inline int
 read_u32_from_xdr_buf(struct xdr_buf *buf, int base, u32 *obj)
 {
@@ -743,8 +741,8 @@ struct gss_svc_data {
 static int
 svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
 {
-       struct iovec    *argv = &rqstp->rq_arg.head[0];
-       struct iovec    *resv = &rqstp->rq_res.head[0];
+       struct kvec     *argv = &rqstp->rq_arg.head[0];
+       struct kvec     *resv = &rqstp->rq_res.head[0];
        u32             crlen;
        struct xdr_netobj tmpobj;
        struct gss_svc_data *svcdata = rqstp->rq_auth_data;
@@ -755,7 +753,7 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
        u32             *reject_stat = resv->iov_base + resv->iov_len;
        int             ret;
 
-       dprintk("RPC: svcauth_gss: argv->iov_len = %zd\n",argv->iov_len);
+       dprintk("RPC:      svcauth_gss: argv->iov_len = %zd\n",argv->iov_len);
 
        *authp = rpc_autherr_badcred;
        if (!svcdata)
@@ -763,7 +761,7 @@ svcauth_gss_accept(struct svc_rqst *rqstp, u32 *authp)
        if (!svcdata)
                goto auth_err;
        rqstp->rq_auth_data = svcdata;
-       svcdata->body_start = 0;
+       svcdata->body_start = NULL;
        svcdata->rsci = NULL;
        gc = &svcdata->clcred;
 
@@ -952,7 +950,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
        struct xdr_buf *resbuf = &rqstp->rq_res;
        struct xdr_buf integ_buf;
        struct xdr_netobj mic;
-       struct iovec *resv;
+       struct kvec *resv;
        u32 *p;
        int integ_offset, integ_len;
        int stat = -EINVAL;
@@ -970,7 +968,7 @@ svcauth_gss_release(struct svc_rqst *rqstp)
                break;
        case RPC_GSS_SVC_INTEGRITY:
                p = gsd->body_start;
-               gsd->body_start = 0;
+               gsd->body_start = NULL;
                /* move accept_stat to right place: */
                memcpy(p, p + 2, 4);
                /* don't wrap in failure case: */
@@ -1049,6 +1047,7 @@ svcauth_gss_domain_release(struct auth_domain *dom)
 
 struct auth_ops svcauthops_gss = {
        .name           = "rpcsec_gss",
+       .owner          = THIS_MODULE,
        .flavour        = RPC_AUTH_GSS,
        .accept         = svcauth_gss_accept,
        .release        = svcauth_gss_release,
@@ -1058,10 +1057,12 @@ struct auth_ops svcauthops_gss = {
 int
 gss_svc_init(void)
 {
-       cache_register(&rsc_cache);
-       cache_register(&rsi_cache);
-       svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss);
-       return 0;
+       int rv = svc_auth_register(RPC_AUTH_GSS, &svcauthops_gss);
+       if (rv == 0) {
+               cache_register(&rsc_cache);
+               cache_register(&rsi_cache);
+       }
+       return rv;
 }
 
 void
@@ -1069,4 +1070,5 @@ gss_svc_shutdown(void)
 {
        cache_unregister(&rsc_cache);
        cache_unregister(&rsi_cache);
+       svc_auth_unregister(RPC_AUTH_GSS);
 }