原创文章，转载请注明：转载自慢慢的回味
本文链接地址：15.2 动态规划-矩阵链乘法
package cn.com.liuxiaofei; public class Chapter15_2 { /** * @param args */ public static void main(String[] args) { testMatrixChainOrder(); } public static void testMatrixChainOrder() { //A1 A2 A3 A4 A5 A6 //900x350 350x25 25x5 5x10 10x400 400x950 long[][] A1 = generateMatrix(900, 350); long[][] A2 = generateMatrix(350, 25); long[][] A3 = generateMatrix(25, 5); long[][] A4 = generateMatrix(5, 10); long[][] A5 = generateMatrix(10, 400); long[][] A6 = generateMatrix(400, 950); int[] p = new int[] { A1.length, A2.length, A3.length, A4.length, A5.length, A6.length, A6[0].length }; Object[] mands = matrixChainOrder(p); int[][] m = (int[][]) mands[0]; int[][] s = (int[][]) mands[1]; System.out.print(" \t"); for (int i = 1; i < m.length; i++) { System.out.print(i + "\t"); } System.out.println(); for (int j = m.length - 1; j >= 1; j--) { System.out.print(j + "\t"); for (int i = 1; i < m.length && j >= i; i++) { System.out.print(m[i][j] + "\t"); } System.out.println(); } System.out.print(" \t"); for (int i = 1; i < s.length; i++) { System.out.print(i + "\t"); } System.out.println(); for (int j = s.length - 1; j >= 1; j--) { System.out.print(j + "\t"); for (int i = 1; i < s.length && j >= i; i++) { System.out.print(s[i][j] + "\t"); } System.out.println(); } printOptimalParens(s, 1, 6); System.out.println(); System.out.println("Not use matrix chain order:"); long time = System.currentTimeMillis(); long[][] B1 = multipyMatrix(multipyMatrix(multipyMatrix(multipyMatrix(multipyMatrix(A1, A2), A3), A4), A5), A6); System.out.println("Time elapse:" + (System.currentTimeMillis() - time)); System.out.println("Use matrix chain order:"); time = System.currentTimeMillis(); long[][] B2 = multipyMatrixByOptimalParens(s, 1, 6, new Object[] { A1, A2, A3, A4, A5, A6 }); System.out.println("Time elapse:" + (System.currentTimeMillis() - time)); System.out.println(compaireMatrix(B1, B2)); System.out.println("--Use memoized matrix chain--"); mands = matrixChainOrder(p); m = (int[][]) mands[0]; s 
= (int[][]) mands[1]; System.out.print(" \t"); for (int i = 1; i < m.length; i++) { System.out.print(i + "\t"); } System.out.println(); for (int j = m.length - 1; j >= 1; j--) { System.out.print(j + "\t"); for (int i = 1; i < m.length && j >= i; i++) { System.out.print(m[i][j] + "\t"); } System.out.println(); } System.out.print(" \t"); for (int i = 1; i < s.length; i++) { System.out.print(i + "\t"); } System.out.println(); for (int j = s.length - 1; j >= 1; j--) { System.out.print(j + "\t"); for (int i = 1; i < s.length && j >= i; i++) { System.out.print(s[i][j] + "\t"); } System.out.println(); } } public static void printOptimalParens(int[][] s, int i, int j) { if (i == j) { System.out.print("A" + i); } else { System.out.print("("); printOptimalParens(s, i, s[i][j]); printOptimalParens(s, s[i][j] + 1, j); System.out.print(")"); } } public static Object[] matrixChainOrder(int[] p) { int n = p.length - 1; int[][] m = new int[n + 1][n + 1]; int[][] s = new int[n + 1][n + 1]; for (int i = 1; i <= n; i++) { m[i][i] = 0; } //len is the chain length for (int len = 2; len <= n; len++) { //start from i, calculate the multiply time of chain for (int i = 1; i <= n - len + 1; i++) { int j = i + len - 1; m[i][j] = Integer.MAX_VALUE; for (int k = i; k <= j - 1; k++) { int q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j]; if (q < m[i][j]) { m[i][j] = q; s[i][j] = k; } } } } return new Object[] { m, s }; } public static long[][] generateMatrix(int row, int col) { long[][] matrix = new long[row][col]; for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { matrix[i][j] = Double.valueOf(Math.random() * 10).intValue(); } } return matrix; } public static boolean compaireMatrix(long[][] A, long[][] B) { int row = A.length; int col = A[0].length; for (int i = 0; i < row; i++) { for (int j = 0; j < col; j++) { if (A[i][j] != B[i][j]) { return false; } } } return true; } public static long[][] multipyMatrix(long[][] A, long[][] B) { if (A.length == 0 || B.length == 0) { return 
null; } if (A[0].length != B.length) { return null; } long[][] C = new long[A.length][B[0].length]; for (int i = 0; i < A.length; i++) { for (int j = 0; j < B[0].length; j++) { C[i][j] = 0; for (int k = 0; k < A[0].length; k++) { C[i][j] = C[i][j] + A[i][k] * B[k][j]; } } } return C; } public static long[][] multipyMatrixByOptimalParens(int[][] s, int i, int j, Object[] matrixs) { if (i == j) { return (long[][]) matrixs[i - 1]; } else { long[][] B1 = multipyMatrixByOptimalParens(s, i, s[i][j], matrixs); long[][] B2 = multipyMatrixByOptimalParens(s, s[i][j] + 1, j, matrixs); long[][] B = multipyMatrix(B1, B2); return B; } } public static Object[] memoMatrixChain(int[] p) { int n = p.length - 1; int[][] m = new int[n + 1][n + 1]; int[][] s = new int[n + 1][n + 1]; for (int i = 1; i <= n; i++) { for (int j = 1; j <= n; j++) { m[i][j] = Integer.MAX_VALUE; } } lookupChain(m, s, p, 1, n); return new Object[] { m, s }; } public static int lookupChain(int[][] m, int[][] s, int[] p, int i, int j) { if (m[i][j] < Integer.MAX_VALUE) { return m[i][j]; } if (i == j) { m[i][j] = 0; } else { for (int k = i; k < j; k++) { int q = lookupChain(m, s, p, i, k) + lookupChain(m, s, p, k + 1, j) + p[i - 1] * p[k] * p[j]; if (q < m[i][j]) { m[i][j] = q; } } } return m[i][j]; } } 输出结果 1 2 3 4 5 6 6 7813750 3626250 2038750 1920000 3800000 0 5 3438750 763750 70000 20000 0 4 1663750 61250 1250 0 3 1618750 43750 0 2 7875000 0 1 0 1 2 3 4 5 6 6 3 3 3 5 5 0 5 3 3 3 4 0 4 3 3 3 0 3 1 2 0 2 1 0 1 0 ((A1(A2A3))((A4A5)A6)) Not use matrix chain order: Time elapse:3698 Use matrix chain order: Time elapse:61 true --Use memoized matrix chain-- 1 2 3 4 5 6 6 7813750 3626250 2038750 1920000 3800000 0 5 3438750 763750 70000 20000 0 4 1663750 61250 1250 0 3 1618750 43750 0 2 7875000 0 1 0 1 2 3 4 5 6 6 3 3 3 5 5 0 5 3 3 3 4 0 4 3 3 3 0 3 1 2 0 2 1 0 1 0 |
本作品采用知识共享署名 4.0 国际许可协议进行许可。