This commit was manufactured by cvs2svn to create branch 'vserver'.
[linux-2.6.git] / crypto / mpi / mpih-mul.c
1 /* mpihelp-mul.c  -  MPI helper functions
2  * Copyright (C) 1994, 1996, 1998, 1999,
3  *               2000 Free Software Foundation, Inc.
4  *
5  * This file is part of GnuPG.
6  *
7  * GnuPG is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License as published by
9  * the Free Software Foundation; either version 2 of the License, or
10  * (at your option) any later version.
11  *
12  * GnuPG is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with this program; if not, write to the Free Software
19  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA
20  *
21  * Note: This code is heavily based on the GNU MP Library.
22  *       Actually it's the same code with only minor changes in the
23  *       way the data is stored; this is to support the abstraction
24  *       of an optional secure memory allocation which may be used
25  *       to avoid revealing of sensitive data due to paging etc.
26  *       The GNU MP Library itself is published under the LGPL;
27  *       however I decided to publish this code under the plain GPL.
28  */
29
30 #include <linux/string.h>
31 #include "mpi-internal.h"
32 #include "longlong.h"
33
34
35
36 #define MPN_MUL_N_RECURSE(prodp, up, vp, size, tspace) \
37     do {                                                \
38         if( (size) < KARATSUBA_THRESHOLD )              \
39             mul_n_basecase (prodp, up, vp, size);       \
40         else                                            \
41             mul_n (prodp, up, vp, size, tspace);        \
42     } while (0);
43
44 #define MPN_SQR_N_RECURSE(prodp, up, size, tspace) \
45     do {                                            \
46         if ((size) < KARATSUBA_THRESHOLD)           \
47             mpih_sqr_n_basecase (prodp, up, size);       \
48         else                                        \
49             mpih_sqr_n (prodp, up, size, tspace);        \
50     } while (0);
51
52
53
54
55 /* Multiply the natural numbers u (pointed to by UP) and v (pointed to by VP),
56  * both with SIZE limbs, and store the result at PRODP.  2 * SIZE limbs are
57  * always stored.  Return the most significant limb.
58  *
59  * Argument constraints:
60  * 1. PRODP != UP and PRODP != VP, i.e. the destination
61  *    must be distinct from the multiplier and the multiplicand.
62  *
63  *
64  * Handle simple cases with traditional multiplication.
65  *
66  * This is the most critical code of multiplication.  All multiplies rely
67  * on this, both small and huge.  Small ones arrive here immediately.  Huge
68  * ones arrive here as this is the base case for Karatsuba's recursive
69  * algorithm below.
70  */
71
72 static mpi_limb_t
73 mul_n_basecase( mpi_ptr_t prodp, mpi_ptr_t up,
74                                  mpi_ptr_t vp, mpi_size_t size)
75 {
76     mpi_size_t i;
77     mpi_limb_t cy;
78     mpi_limb_t v_limb;
79
80     /* Multiply by the first limb in V separately, as the result can be
81      * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
82     v_limb = vp[0];
83     if( v_limb <= 1 ) {
84         if( v_limb == 1 )
85             MPN_COPY( prodp, up, size );
86         else
87             MPN_ZERO( prodp, size );
88         cy = 0;
89     }
90     else
91         cy = mpihelp_mul_1( prodp, up, size, v_limb );
92
93     prodp[size] = cy;
94     prodp++;
95
96     /* For each iteration in the outer loop, multiply one limb from
97      * U with one limb from V, and add it to PROD.  */
98     for( i = 1; i < size; i++ ) {
99         v_limb = vp[i];
100         if( v_limb <= 1 ) {
101             cy = 0;
102             if( v_limb == 1 )
103                cy = mpihelp_add_n(prodp, prodp, up, size);
104         }
105         else
106             cy = mpihelp_addmul_1(prodp, up, size, v_limb);
107
108         prodp[size] = cy;
109         prodp++;
110     }
111
112     return cy;
113 }
114
115
116 static void
117 mul_n( mpi_ptr_t prodp, mpi_ptr_t up, mpi_ptr_t vp,
118                         mpi_size_t size, mpi_ptr_t tspace )
119 {
120     if( size & 1 ) {
121       /* The size is odd, and the code below doesn't handle that.
122        * Multiply the least significant (size - 1) limbs with a recursive
123        * call, and handle the most significant limb of S1 and S2
124        * separately.
125        * A slightly faster way to do this would be to make the Karatsuba
126        * code below behave as if the size were even, and let it check for
127        * odd size in the end.  I.e., in essence move this code to the end.
128        * Doing so would save us a recursive call, and potentially make the
129        * stack grow a lot less.
130        */
131       mpi_size_t esize = size - 1;       /* even size */
132       mpi_limb_t cy_limb;
133
134       MPN_MUL_N_RECURSE( prodp, up, vp, esize, tspace );
135       cy_limb = mpihelp_addmul_1( prodp + esize, up, esize, vp[esize] );
136       prodp[esize + esize] = cy_limb;
137       cy_limb = mpihelp_addmul_1( prodp + esize, vp, size, up[esize] );
138       prodp[esize + size] = cy_limb;
139     }
140     else {
141         /* Anatolij Alekseevich Karatsuba's divide-and-conquer algorithm.
142          *
143          * Split U in two pieces, U1 and U0, such that
144          * U = U0 + U1*(B**n),
145          * and V in V1 and V0, such that
146          * V = V0 + V1*(B**n).
147          *
148          * UV is then computed recursively using the identity
149          *
150          *        2n   n          n                     n
151          * UV = (B  + B )U V  +  B (U -U )(V -V )  +  (B + 1)U V
152          *                1 1        1  0   0  1              0 0
153          *
154          * Where B = 2**BITS_PER_MP_LIMB.
155          */
156         mpi_size_t hsize = size >> 1;
157         mpi_limb_t cy;
158         int negflg;
159
160         /* Product H.      ________________  ________________
161          *                |_____U1 x V1____||____U0 x V0_____|
162          * Put result in upper part of PROD and pass low part of TSPACE
163          * as new TSPACE.
164          */
165         MPN_MUL_N_RECURSE(prodp + size, up + hsize, vp + hsize, hsize, tspace);
166
167         /* Product M.      ________________
168          *                |_(U1-U0)(V0-V1)_|
169          */
170         if( mpihelp_cmp(up + hsize, up, hsize) >= 0 ) {
171             mpihelp_sub_n(prodp, up + hsize, up, hsize);
172             negflg = 0;
173         }
174         else {
175             mpihelp_sub_n(prodp, up, up + hsize, hsize);
176             negflg = 1;
177         }
178         if( mpihelp_cmp(vp + hsize, vp, hsize) >= 0 ) {
179             mpihelp_sub_n(prodp + hsize, vp + hsize, vp, hsize);
180             negflg ^= 1;
181         }
182         else {
183             mpihelp_sub_n(prodp + hsize, vp, vp + hsize, hsize);
184             /* No change of NEGFLG.  */
185         }
186         /* Read temporary operands from low part of PROD.
187          * Put result in low part of TSPACE using upper part of TSPACE
188          * as new TSPACE.
189          */
190         MPN_MUL_N_RECURSE(tspace, prodp, prodp + hsize, hsize, tspace + size);
191
192         /* Add/copy product H. */
193         MPN_COPY (prodp + hsize, prodp + size, hsize);
194         cy = mpihelp_add_n( prodp + size, prodp + size,
195                             prodp + size + hsize, hsize);
196
197         /* Add product M (if NEGFLG M is a negative number) */
198         if(negflg)
199             cy -= mpihelp_sub_n(prodp + hsize, prodp + hsize, tspace, size);
200         else
201             cy += mpihelp_add_n(prodp + hsize, prodp + hsize, tspace, size);
202
203         /* Product L.      ________________  ________________
204          *                |________________||____U0 x V0_____|
205          * Read temporary operands from low part of PROD.
206          * Put result in low part of TSPACE using upper part of TSPACE
207          * as new TSPACE.
208          */
209         MPN_MUL_N_RECURSE(tspace, up, vp, hsize, tspace + size);
210
211         /* Add/copy Product L (twice) */
212
213         cy += mpihelp_add_n(prodp + hsize, prodp + hsize, tspace, size);
214         if( cy )
215           mpihelp_add_1(prodp + hsize + size, prodp + hsize + size, hsize, cy);
216
217         MPN_COPY(prodp, tspace, hsize);
218         cy = mpihelp_add_n(prodp + hsize, prodp + hsize, tspace + hsize, hsize);
219         if( cy )
220             mpihelp_add_1(prodp + size, prodp + size, size, 1);
221     }
222 }
223
224
225 void
226 mpih_sqr_n_basecase( mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t size )
227 {
228     mpi_size_t i;
229     mpi_limb_t cy_limb;
230     mpi_limb_t v_limb;
231
232     /* Multiply by the first limb in V separately, as the result can be
233      * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
234     v_limb = up[0];
235     if( v_limb <= 1 ) {
236         if( v_limb == 1 )
237             MPN_COPY( prodp, up, size );
238         else
239             MPN_ZERO(prodp, size);
240         cy_limb = 0;
241     }
242     else
243         cy_limb = mpihelp_mul_1( prodp, up, size, v_limb );
244
245     prodp[size] = cy_limb;
246     prodp++;
247
248     /* For each iteration in the outer loop, multiply one limb from
249      * U with one limb from V, and add it to PROD.  */
250     for( i=1; i < size; i++) {
251         v_limb = up[i];
252         if( v_limb <= 1 ) {
253             cy_limb = 0;
254             if( v_limb == 1 )
255                 cy_limb = mpihelp_add_n(prodp, prodp, up, size);
256         }
257         else
258             cy_limb = mpihelp_addmul_1(prodp, up, size, v_limb);
259
260         prodp[size] = cy_limb;
261         prodp++;
262     }
263 }
264
265
266 void
267 mpih_sqr_n( mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t size, mpi_ptr_t tspace)
268 {
269     if( size & 1 ) {
270         /* The size is odd, and the code below doesn't handle that.
271          * Multiply the least significant (size - 1) limbs with a recursive
272          * call, and handle the most significant limb of S1 and S2
273          * separately.
274          * A slightly faster way to do this would be to make the Karatsuba
275          * code below behave as if the size were even, and let it check for
276          * odd size in the end.  I.e., in essence move this code to the end.
277          * Doing so would save us a recursive call, and potentially make the
278          * stack grow a lot less.
279          */
280         mpi_size_t esize = size - 1;       /* even size */
281         mpi_limb_t cy_limb;
282
283         MPN_SQR_N_RECURSE( prodp, up, esize, tspace );
284         cy_limb = mpihelp_addmul_1( prodp + esize, up, esize, up[esize] );
285         prodp[esize + esize] = cy_limb;
286         cy_limb = mpihelp_addmul_1( prodp + esize, up, size, up[esize] );
287
288         prodp[esize + size] = cy_limb;
289     }
290     else {
291         mpi_size_t hsize = size >> 1;
292         mpi_limb_t cy;
293
294         /* Product H.      ________________  ________________
295          *                |_____U1 x U1____||____U0 x U0_____|
296          * Put result in upper part of PROD and pass low part of TSPACE
297          * as new TSPACE.
298          */
299         MPN_SQR_N_RECURSE(prodp + size, up + hsize, hsize, tspace);
300
301         /* Product M.      ________________
302          *                |_(U1-U0)(U0-U1)_|
303          */
304         if( mpihelp_cmp( up + hsize, up, hsize) >= 0 )
305             mpihelp_sub_n( prodp, up + hsize, up, hsize);
306         else
307             mpihelp_sub_n (prodp, up, up + hsize, hsize);
308
309         /* Read temporary operands from low part of PROD.
310          * Put result in low part of TSPACE using upper part of TSPACE
311          * as new TSPACE.  */
312         MPN_SQR_N_RECURSE(tspace, prodp, hsize, tspace + size);
313
314         /* Add/copy product H  */
315         MPN_COPY(prodp + hsize, prodp + size, hsize);
316         cy = mpihelp_add_n(prodp + size, prodp + size,
317                            prodp + size + hsize, hsize);
318
319         /* Add product M (if NEGFLG M is a negative number).  */
320         cy -= mpihelp_sub_n (prodp + hsize, prodp + hsize, tspace, size);
321
322         /* Product L.      ________________  ________________
323          *                |________________||____U0 x U0_____|
324          * Read temporary operands from low part of PROD.
325          * Put result in low part of TSPACE using upper part of TSPACE
326          * as new TSPACE.  */
327         MPN_SQR_N_RECURSE (tspace, up, hsize, tspace + size);
328
329         /* Add/copy Product L (twice).  */
330         cy += mpihelp_add_n (prodp + hsize, prodp + hsize, tspace, size);
331         if( cy )
332             mpihelp_add_1(prodp + hsize + size, prodp + hsize + size,
333                                                             hsize, cy);
334
335         MPN_COPY(prodp, tspace, hsize);
336         cy = mpihelp_add_n (prodp + hsize, prodp + hsize, tspace + hsize, hsize);
337         if( cy )
338             mpihelp_add_1 (prodp + size, prodp + size, size, 1);
339     }
340 }
341
342
343 /* This should be made into an inline function in gmp.h.  */
344 int
345 mpihelp_mul_n( mpi_ptr_t prodp, mpi_ptr_t up, mpi_ptr_t vp, mpi_size_t size)
346 {
347     if( up == vp ) {
348         if( size < KARATSUBA_THRESHOLD )
349             mpih_sqr_n_basecase( prodp, up, size );
350         else {
351             mpi_ptr_t tspace;
352             tspace = mpi_alloc_limb_space( 2 * size );
353             if (!tspace)
354                     return -ENOMEM;
355             mpih_sqr_n( prodp, up, size, tspace );
356             mpi_free_limb_space( tspace );
357         }
358     }
359     else {
360         if( size < KARATSUBA_THRESHOLD )
361             mul_n_basecase( prodp, up, vp, size );
362         else {
363             mpi_ptr_t tspace;
364             tspace = mpi_alloc_limb_space( 2 * size );
365             if (!tspace)
366                     return -ENOMEM;
367             mul_n (prodp, up, vp, size, tspace);
368             mpi_free_limb_space( tspace );
369         }
370     }
371
372     return 0;
373 }
374
375
376
377 int
378 mpihelp_mul_karatsuba_case( mpi_ptr_t prodp,
379                             mpi_ptr_t up, mpi_size_t usize,
380                             mpi_ptr_t vp, mpi_size_t vsize,
381                             struct karatsuba_ctx *ctx )
382 {
383     mpi_limb_t cy;
384
385     if( !ctx->tspace || ctx->tspace_size < vsize ) {
386         if( ctx->tspace )
387             mpi_free_limb_space( ctx->tspace );
388         ctx->tspace = mpi_alloc_limb_space( 2 * vsize);
389         if (!ctx->tspace)
390                 return -ENOMEM;
391         ctx->tspace_size = vsize;
392     }
393
394     MPN_MUL_N_RECURSE( prodp, up, vp, vsize, ctx->tspace );
395
396     prodp += vsize;
397     up += vsize;
398     usize -= vsize;
399     if( usize >= vsize ) {
400         if( !ctx->tp || ctx->tp_size < vsize ) {
401             if( ctx->tp )
402                 mpi_free_limb_space( ctx->tp );
403             ctx->tp = mpi_alloc_limb_space( 2 * vsize );
404             if (!ctx->tp) {
405                     if( ctx->tspace )
406                             mpi_free_limb_space( ctx->tspace );
407                     ctx->tspace = NULL;
408                     return -ENOMEM;
409             }
410             ctx->tp_size = vsize;
411         }
412
413         do {
414             MPN_MUL_N_RECURSE( ctx->tp, up, vp, vsize, ctx->tspace );
415             cy = mpihelp_add_n( prodp, prodp, ctx->tp, vsize );
416             mpihelp_add_1( prodp + vsize, ctx->tp + vsize, vsize, cy );
417             prodp += vsize;
418             up += vsize;
419             usize -= vsize;
420         } while( usize >= vsize );
421     }
422
423     if( usize ) {
424         if( usize < KARATSUBA_THRESHOLD ) {
425                 mpi_limb_t tmp;
426                 if (mpihelp_mul( ctx->tspace, vp, vsize, up, usize, &tmp) < 0)
427                         return -ENOMEM;
428         }
429         else {
430             if( !ctx->next ) {
431                 ctx->next = kmalloc( sizeof *ctx, GFP_KERNEL );
432                 if (!ctx->next)
433                         return -ENOMEM;
434                 memset(ctx->next, 0, sizeof(ctx));
435             }
436             if (mpihelp_mul_karatsuba_case( ctx->tspace,
437                                             vp, vsize,
438                                             up, usize,
439                                             ctx->next ) < 0)
440                     return -ENOMEM;
441         }
442
443         cy = mpihelp_add_n( prodp, prodp, ctx->tspace, vsize);
444         mpihelp_add_1( prodp + vsize, ctx->tspace + vsize, usize, cy );
445     }
446
447     return 0;
448 }
449
450
451 void
452 mpihelp_release_karatsuba_ctx( struct karatsuba_ctx *ctx )
453 {
454     struct karatsuba_ctx *ctx2;
455
456     if( ctx->tp )
457         mpi_free_limb_space( ctx->tp );
458     if( ctx->tspace )
459         mpi_free_limb_space( ctx->tspace );
460     for( ctx=ctx->next; ctx; ctx = ctx2 ) {
461         ctx2 = ctx->next;
462         if( ctx->tp )
463             mpi_free_limb_space( ctx->tp );
464         if( ctx->tspace )
465             mpi_free_limb_space( ctx->tspace );
466         kfree( ctx );
467     }
468 }
469
470 /* Multiply the natural numbers u (pointed to by UP, with USIZE limbs)
471  * and v (pointed to by VP, with VSIZE limbs), and store the result at
472  * PRODP.  USIZE + VSIZE limbs are always stored, but if the input
473  * operands are normalized.  Return the most significant limb of the
474  * result.
475  *
476  * NOTE: The space pointed to by PRODP is overwritten before finished
477  * with U and V, so overlap is an error.
478  *
479  * Argument constraints:
480  * 1. USIZE >= VSIZE.
481  * 2. PRODP != UP and PRODP != VP, i.e. the destination
482  *    must be distinct from the multiplier and the multiplicand.
483  */
484
485 int
486 mpihelp_mul( mpi_ptr_t prodp, mpi_ptr_t up, mpi_size_t usize,
487              mpi_ptr_t vp, mpi_size_t vsize,
488              mpi_limb_t *_result)
489 {
490     mpi_ptr_t prod_endp = prodp + usize + vsize - 1;
491     mpi_limb_t cy;
492     struct karatsuba_ctx ctx;
493
494     if( vsize < KARATSUBA_THRESHOLD ) {
495         mpi_size_t i;
496         mpi_limb_t v_limb;
497
498         if( !vsize ) {
499                 *_result = 0;
500                 return 0;
501         }
502
503         /* Multiply by the first limb in V separately, as the result can be
504          * stored (not added) to PROD.  We also avoid a loop for zeroing.  */
505         v_limb = vp[0];
506         if( v_limb <= 1 ) {
507             if( v_limb == 1 )
508                 MPN_COPY( prodp, up, usize );
509             else
510                 MPN_ZERO( prodp, usize );
511             cy = 0;
512         }
513         else
514             cy = mpihelp_mul_1( prodp, up, usize, v_limb );
515
516         prodp[usize] = cy;
517         prodp++;
518
519         /* For each iteration in the outer loop, multiply one limb from
520          * U with one limb from V, and add it to PROD.  */
521         for( i = 1; i < vsize; i++ ) {
522             v_limb = vp[i];
523             if( v_limb <= 1 ) {
524                 cy = 0;
525                 if( v_limb == 1 )
526                    cy = mpihelp_add_n(prodp, prodp, up, usize);
527             }
528             else
529                 cy = mpihelp_addmul_1(prodp, up, usize, v_limb);
530
531             prodp[usize] = cy;
532             prodp++;
533         }
534
535         *_result = cy;
536         return 0;
537     }
538
539     memset( &ctx, 0, sizeof ctx );
540     if (mpihelp_mul_karatsuba_case( prodp, up, usize, vp, vsize, &ctx ) < 0)
541             return -ENOMEM;
542     mpihelp_release_karatsuba_ctx( &ctx );
543     *_result = *prod_endp;
544     return 0;
545 }
546
547