// Computes the high-order half of the 64-bit product, unsigned. // Max line length is 57, to fit in hacker.book. (But not used there.) // Derived from Knuth's Algorithm M. // Subscript 0 denotes the least significant half (little endian). #include #include // To define "exit", req'd by XLC. // The program below takes 16 ops, 4 of which are multiplies, // which are of the type unsigned 16 x 16 ==> 32. // The statement "low = (w1 << 16) + (w0 & 0xFFFF);" placed just before // the return statement, computes the low-order part in 3 more ops. unsigned mulhu(unsigned u, unsigned v) { unsigned u0, u1, v0, v1, w0, w1, w2, t; u0 = u & 0xFFFF; u1 = u >> 16; v0 = v & 0xFFFF; v1 = v >> 16; w0 = u0*v0; t = u1*v0 + (w0 >> 16); w1 = t & 0xFFFF; w2 = t >> 16; w1 = u0*v1 + w1; return u1*v1 + w2 + (w1 >> 16); } /* The next version does it using only three multiplications. It is based on: Let u = a*2**16 + b, v = c*2**16 + d. Then calculate p = ac, q = bd, r = (-a + b)(c - d) Then uv = p*2**32 + (r + p + q)*2**16 + q. There is a difficulty in computing r, because it doesn't quite fit in a 32-bit word. But because 0 <= a, b, c, d < 2**16, it is easy to see that -2**32 < r < 2**32. Thus it can be represented as a 64-bit quantity with the high-order 32 bits being either 0 or all 1's. The low-order 32 bits, rlow, can be calculated directly from r = (-a + b)*(c - d), using 32-bit instructions. The high-order 32 bits will be all 1's if the product is negative, and 0 if it is nonnegative. The product is negative if (-a + b) and (c - d) have opposite signs. Thus, basically, rhigh = ((-a + b) xor (c - d)) >>s 31. However, if either a = b or c = d, we must ensure that rhigh = 0. It suffices to test rlow, i.e., follow the above assignment to rhigh with: if (rlow == 0) rhigh = 0. This is because if rlow = 0, it must be the case that either a = b or c = d, because the product cannot be >= 2**32. This leads to the function below. */ unsigned mulhu1(unsigned u, unsigned v) { unsigned a, b, c, d, p, q, rlow, rhigh; a = u >> 16; b = u & 0xFFFF; c = v >> 16; d = v & 0xFFFF; p = a*c; q = b*d; rlow = (-a + b)*(c - d); rhigh = (int)((-a + b)^(c - d)) >> 31; if (rlow == 0) rhigh = 0; // Correction. q = q + (q >> 16); // Overflow cannot occur here. rlow = rlow + p; if (rlow < p) rhigh = rhigh + 1; rlow = rlow + q; if (rlow < q) rhigh = rhigh + 1; return p + (rlow >> 16) + (rhigh << 16); } /* Branch-free version: */ unsigned mulhu2(unsigned u, unsigned v) { unsigned a, b, c, d, p, q, x, y, rlow, rhigh, t; a = u >> 16; b = u & 0xFFFF; c = v >> 16; d = v & 0xFFFF; p = a*c; q = b*d; x = -a + b; y = c - d; rlow = x*y; rhigh = (x ^ y) & (rlow | -rlow); rhigh = (int)rhigh >> 31; q = q + (q >> 16); // Overflow cannot occur here. t = (rlow & 0xFFFF) + (p & 0xFFFF) + (q & 0xFFFF); p += (t >> 16) + (rlow >> 16) + (p >> 16) + (q >> 16); p += (rhigh << 16); return p; } int errors; void error(unsigned u, unsigned v, unsigned rr, unsigned r) { errors = errors + 1; printf("Error for u = %08x, v = %08x, shd be %08x, got %08x\n", u, v, rr, r); } int main() { int i, j, n; unsigned r; unsigned long long rr; static unsigned test[] = {0, 1, 2, 3, 4, 0xE000, 0xF000, 0xF001, 0xFFFE, 0xFFFF, 0x10000, 0x10001, 0x10002, 0x20000, 0xE0000000, 0xE0001000, 0xF0000000, 0xF0001000, 0xF0001001, 0xF0010001, 0xF0010002, 0xF0020001, 0xF0020002, 0xFFFE0000, 0xFFFE0001, 0xFFFE0002, 0xFFFF0000, 0xFFFF0001, 0xFFFF0002, 0xFFFFFFFE, 0xFFFFFFFF}; n = sizeof(test)/4; printf("mulhu:\n"); for (i = 0; i < n; i += 1) { for (j = 0; j < n; j += 1) { r = mulhu(test[i], test[j]); rr = (unsigned long long)test[i]*(unsigned long long)test[j]; rr = rr >> 32; if (r != rr) error(test[i], test[j], (unsigned)rr, r); } } printf("mulhu1:\n"); for (i = 0; i < n; i += 1) { for (j = 0; j < n; j += 1) { r = mulhu1(test[i], test[j]); rr = (unsigned long long)test[i]*(unsigned long long)test[j]; rr = rr >> 32; if (r != rr) error(test[i], test[j], (unsigned)rr, r); } } printf("mulhu2:\n"); for (i = 0; i < n; i += 1) { for (j = 0; j < n; j += 1) { r = mulhu2(test[i], test[j]); rr = (unsigned long long)test[i]*(unsigned long long)test[j]; rr = rr >> 32; if (r != rr) error(test[i], test[j], (unsigned)rr, r); } } if (errors == 0) printf("Passed all %d cases.\n", n*n); return 0; }