vserver 2.0 rc7
[linux-2.6.git] / net / sunrpc / auth_gss / gss_spkm3_mech.c
index fd213dc..dad0599 100644 (file)
 # define RPCDBG_FACILITY       RPCDBG_AUTH
 #endif
 
-static inline int
-get_bytes(char **ptr, const char *end, void *res, int len)
+static const void *
+simple_get_bytes(const void *p, const void *end, void *res, int len)
 {
-       char *p, *q;
-       p = *ptr;
-       q = p + len;
-       if (q > end || q < p)
-               return -1;
+       const void *q = (const void *)((const char *)p + len);
+       if (unlikely(q > end || q < p))
+               return ERR_PTR(-EFAULT);
        memcpy(res, p, len);
-       *ptr = q;
-       return 0;
+       return q;
 }
 
-static inline int
-get_netobj(char **ptr, const char *end, struct xdr_netobj *res)
+static const void *
+simple_get_netobj(const void *p, const void *end, struct xdr_netobj *res)
 {
-       char *p, *q;
-       p = *ptr;
-       if (get_bytes(&p, end, &res->len, sizeof(res->len)))
-               return -1;
-       q = p + res->len;
-       if(res->len == 0)
-               goto out_nocopy;
-       if (q > end || q < p)
-               return -1;
-       if (!(res->data = kmalloc(res->len, GFP_KERNEL)))
-               return -1;
-       memcpy(res->data, p, res->len);
-out_nocopy:
-       *ptr = q;
-       return 0;
+       const void *q;
+       unsigned int len;
+       p = simple_get_bytes(p, end, &len, sizeof(len));
+       if (IS_ERR(p))
+               return p;
+       res->len = len;
+       if (len == 0) {
+               res->data = NULL;
+               return p;
+       }
+       q = (const void *)((const char *)p + len);
+       if (unlikely(q > end || q < p))
+               return ERR_PTR(-EFAULT);
+       res->data = kmalloc(len, GFP_KERNEL);
+       if (unlikely(res->data == NULL))
+               return ERR_PTR(-ENOMEM);
+       memcpy(res->data, p, len);
+       return q;
 }
 
-static inline int
-get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
+static inline const void *
+get_key(const void *p, const void *end, struct crypto_tfm **res, int *resalg)
 {
-       struct xdr_netobj       key = {
-               .len = 0,
-               .data = NULL,
-       };
+       struct xdr_netobj       key = { 0 };
        int                     alg_mode,setkey = 0;
        char                    *alg_name;
 
-       if (get_bytes(p, end, resalg, sizeof(int)))
+       p = simple_get_bytes(p, end, resalg, sizeof(*resalg));
+       if (IS_ERR(p))
                goto out_err;
-       if ((get_netobj(p, end, &key)))
+       p = simple_get_netobj(p, end, &key);
+       if (IS_ERR(p))
                goto out_err;
 
        switch (*resalg) {
@@ -111,10 +110,6 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
                        alg_mode = 0;
                        setkey = 0;
                        break;
-               case NID_cast5_cbc:
-                       dprintk("RPC: SPKM3 get_key: case cast5_cbc, UNSUPPORTED \n");
-                       goto out_err;
-                       break;
                default:
                        dprintk("RPC: SPKM3 get_key: unsupported algorithm %d", *resalg);
                        goto out_err_free_key;
@@ -128,69 +123,81 @@ get_key(char **p, char *end, struct crypto_tfm **res, int *resalg)
 
        if(key.len > 0)
                kfree(key.data);
-       return 0;
+       return p;
 
 out_err_free_tfm:
        crypto_free_tfm(*res);
 out_err_free_key:
        if(key.len > 0)
                kfree(key.data);
+       p = ERR_PTR(-EINVAL);
 out_err:
-       return -1;
+       return p;
 }
 
-static u32
-gss_import_sec_context_spkm3(struct xdr_netobj *inbuf,
+static int
+gss_import_sec_context_spkm3(const void *p, size_t len,
                                struct gss_ctx *ctx_id)
 {
-       char    *p = inbuf->data;
-       char    *end = inbuf->data + inbuf->len;
+       const void *end = (const void *)((const char *)p + len);
        struct  spkm3_ctx *ctx;
 
        if (!(ctx = kmalloc(sizeof(*ctx), GFP_KERNEL)))
                goto out_err;
        memset(ctx, 0, sizeof(*ctx));
 
-       if (get_netobj(&p, end, &ctx->ctx_id))
+       p = simple_get_netobj(p, end, &ctx->ctx_id);
+       if (IS_ERR(p))
                goto out_err_free_ctx;
 
-       if (get_bytes(&p, end, &ctx->qop, sizeof(ctx->qop)))
+       p = simple_get_bytes(p, end, &ctx->qop, sizeof(ctx->qop));
+       if (IS_ERR(p))
                goto out_err_free_ctx_id;
 
-       if (get_netobj(&p, end, &ctx->mech_used))
+       p = simple_get_netobj(p, end, &ctx->mech_used);
+       if (IS_ERR(p))
                goto out_err_free_mech;
 
-       if (get_bytes(&p, end, &ctx->ret_flags, sizeof(ctx->ret_flags)))
+       p = simple_get_bytes(p, end, &ctx->ret_flags, sizeof(ctx->ret_flags));
+       if (IS_ERR(p))
                goto out_err_free_mech;
 
-       if (get_bytes(&p, end, &ctx->req_flags, sizeof(ctx->req_flags)))
+       p = simple_get_bytes(p, end, &ctx->req_flags, sizeof(ctx->req_flags));
+       if (IS_ERR(p))
                goto out_err_free_mech;
 
-       if (get_netobj(&p, end, &ctx->share_key))
+       p = simple_get_netobj(p, end, &ctx->share_key);
+       if (IS_ERR(p))
                goto out_err_free_s_key;
 
-       if (get_key(&p, end, &ctx->derived_conf_key, &ctx->conf_alg)) {
-               dprintk("RPC: SPKM3 confidentiality key will be NULL\n");
-       }
+       p = get_key(p, end, &ctx->derived_conf_key, &ctx->conf_alg);
+       if (IS_ERR(p))
+               goto out_err_free_s_key;
 
-       if (get_key(&p, end, &ctx->derived_integ_key, &ctx->intg_alg)) {
-               dprintk("RPC: SPKM3 integrity key will be NULL\n");
-       }
+       p = get_key(p, end, &ctx->derived_integ_key, &ctx->intg_alg);
+       if (IS_ERR(p))
+               goto out_err_free_key1;
 
-       if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg)))
-               goto out_err_free_s_key;
+       p = simple_get_bytes(p, end, &ctx->keyestb_alg, sizeof(ctx->keyestb_alg));
+       if (IS_ERR(p))
+               goto out_err_free_key2;
 
-       if (get_bytes(&p, end, &ctx->owf_alg, sizeof(ctx->owf_alg)))
-               goto out_err_free_s_key;
+       p = simple_get_bytes(p, end, &ctx->owf_alg, sizeof(ctx->owf_alg));
+       if (IS_ERR(p))
+               goto out_err_free_key2;
 
        if (p != end)
-               goto out_err_free_s_key;
+               goto out_err_free_key2;
 
        ctx_id->internal_ctx_id = ctx;
 
        dprintk("Succesfully imported new spkm context.\n");
        return 0;
 
+out_err_free_key2:
+       crypto_free_tfm(ctx->derived_integ_key);
+out_err_free_key1:
+       crypto_free_tfm(ctx->derived_conf_key);
 out_err_free_s_key:
        kfree(ctx->share_key.data);
 out_err_free_mech:
@@ -200,7 +207,7 @@ out_err_free_ctx_id:
 out_err_free_ctx:
        kfree(ctx);
 out_err:
-       return GSS_S_FAILURE;
+       return PTR_ERR(p);
 }
 
 static void