15.2 动态规划-矩阵链乘法

原创文章，转载请注明：转载自慢慢的回味

本文链接地址: 15.2 动态规划-矩阵链乘法

package cn.com.liuxiaofei;

import java.util.Arrays;
 
public class Chapter15_2 {
 
	/**
	 * @param args
	 */
	public static void main(String[] args) {
		testMatrixChainOrder();
 
	}
 
	public static void testMatrixChainOrder() {
		//A1      A2     A3   A4   A5     A6
		//900x350 350x25 25x5 5x10 10x400 400x950
		long[][] A1 = generateMatrix(900, 350);
		long[][] A2 = generateMatrix(350, 25);
		long[][] A3 = generateMatrix(25, 5);
		long[][] A4 = generateMatrix(5, 10);
		long[][] A5 = generateMatrix(10, 400);
		long[][] A6 = generateMatrix(400, 950);
		int[] p = new int[] { A1.length, A2.length, A3.length, A4.length, A5.length, A6.length, A6[0].length };
		Object[] mands = matrixChainOrder(p);
		int[][] m = (int[][]) mands[0];
		int[][] s = (int[][]) mands[1];
		System.out.print(" \t");
		for (int i = 1; i < m.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = m.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < m.length && j >= i; i++) {
				System.out.print(m[i][j] + "\t");
			}
			System.out.println();
		}
		System.out.print(" \t");
		for (int i = 1; i < s.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = s.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < s.length && j >= i; i++) {
				System.out.print(s[i][j] + "\t");
			}
			System.out.println();
		}
 
		printOptimalParens(s, 1, 6);
		System.out.println();
 
		System.out.println("Not use matrix chain order:");
		long time = System.currentTimeMillis();
		long[][] B1 = multipyMatrix(multipyMatrix(multipyMatrix(multipyMatrix(multipyMatrix(A1, A2), A3), A4), A5), A6);
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));
 
		System.out.println("Use matrix chain order:");
		time = System.currentTimeMillis();
		long[][] B2 = multipyMatrixByOptimalParens(s, 1, 6, new Object[] { A1, A2, A3, A4, A5, A6 });
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));
		System.out.println(compaireMatrix(B1, B2));
 
		System.out.println("--Use memoized matrix chain--");
		mands = matrixChainOrder(p);
		m = (int[][]) mands[0];
		s = (int[][]) mands[1];
		System.out.print(" \t");
		for (int i = 1; i < m.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = m.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < m.length && j >= i; i++) {
				System.out.print(m[i][j] + "\t");
			}
			System.out.println();
		}
		System.out.print(" \t");
		for (int i = 1; i < s.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = s.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < s.length && j >= i; i++) {
				System.out.print(s[i][j] + "\t");
			}
			System.out.println();
		}
 
	}
 
	public static void printOptimalParens(int[][] s, int i, int j) {
		if (i == j) {
			System.out.print("A" + i);
		} else {
			System.out.print("(");
			printOptimalParens(s, i, s[i][j]);
			printOptimalParens(s, s[i][j] + 1, j);
			System.out.print(")");
		}
	}
 
	public static Object[] matrixChainOrder(int[] p) {
		int n = p.length - 1;
		int[][] m = new int[n + 1][n + 1];
		int[][] s = new int[n + 1][n + 1];
		for (int i = 1; i <= n; i++) {
			m[i][i] = 0;
		}
		//len is the chain length
		for (int len = 2; len <= n; len++) {
			//start from i, calculate the multiply time of chain
			for (int i = 1; i <= n - len + 1; i++) {
				int j = i + len - 1;
				m[i][j] = Integer.MAX_VALUE;
				for (int k = i; k <= j - 1; k++) {
					int q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
					if (q < m[i][j]) {
						m[i][j] = q;
						s[i][j] = k;
					}
				}
			}
		}
		return new Object[] { m, s };
	}
 
	public static long[][] generateMatrix(int row, int col) {
		long[][] matrix = new long[row][col];
		for (int i = 0; i < row; i++) {
			for (int j = 0; j < col; j++) {
				matrix[i][j] = Double.valueOf(Math.random() * 10).intValue();
			}
		}
		return matrix;
	}
 
	public static boolean compaireMatrix(long[][] A, long[][] B) {
		int row = A.length;
		int col = A[0].length;
		for (int i = 0; i < row; i++) {
			for (int j = 0; j < col; j++) {
				if (A[i][j] != B[i][j]) {
					return false;
				}
			}
		}
		return true;
	}
 
	public static long[][] multipyMatrix(long[][] A, long[][] B) {
		if (A.length == 0 || B.length == 0) {
			return null;
		}
		if (A[0].length != B.length) {
			return null;
		}
 
		long[][] C = new long[A.length][B[0].length];
		for (int i = 0; i < A.length; i++) {
			for (int j = 0; j < B[0].length; j++) {
				C[i][j] = 0;
				for (int k = 0; k < A[0].length; k++) {
					C[i][j] = C[i][j] + A[i][k] * B[k][j];
				}
			}
		}
		return C;
	}
 
	public static long[][] multipyMatrixByOptimalParens(int[][] s, int i, int j, Object[] matrixs) {
		if (i == j) {
			return (long[][]) matrixs[i - 1];
		} else {
			long[][] B1 = multipyMatrixByOptimalParens(s, i, s[i][j], matrixs);
			long[][] B2 = multipyMatrixByOptimalParens(s, s[i][j] + 1, j, matrixs);
			long[][] B = multipyMatrix(B1, B2);
			return B;
		}
	}
 
	public static Object[] memoMatrixChain(int[] p) {
		int n = p.length - 1;
		int[][] m = new int[n + 1][n + 1];
		int[][] s = new int[n + 1][n + 1];
		for (int i = 1; i <= n; i++) {
			for (int j = 1; j <= n; j++) {
				m[i][j] = Integer.MAX_VALUE;
			}
		}
		lookupChain(m, s, p, 1, n);
 
		return new Object[] { m, s };
	}
 
	public static int lookupChain(int[][] m, int[][] s, int[] p, int i, int j) {
		if (m[i][j] < Integer.MAX_VALUE) {
			return m[i][j];
		}
		if (i == j) {
			m[i][j] = 0;
		} else {
			for (int k = i; k < j; k++) {
				int q = lookupChain(m, s, p, i, k) + lookupChain(m, s, p, k + 1, j) + p[i - 1] * p[k] * p[j];
				if (q < m[i][j]) {
					m[i][j] = q;
				}
			}
		}
		return m[i][j];
	}
 
}
 
输出结果
 	1	2	3	4	5	6	
6	7813750	3626250	2038750	1920000	3800000	0	
5	3438750	763750	70000	20000	0	
4	1663750	61250	1250	0	
3	1618750	43750	0	
2	7875000	0	
1	0	
 	1	2	3	4	5	6	
6	3	3	3	5	5	0	
5	3	3	3	4	0	
4	3	3	3	0	
3	1	2	0	
2	1	0	
1	0	
((A1(A2A3))((A4A5)A6))
Not use matrix chain order:
Time elapse:3698
Use matrix chain order:
Time elapse:61
true
--Use memoized matrix chain--
 	1	2	3	4	5	6	
6	7813750	3626250	2038750	1920000	3800000	0	
5	3438750	763750	70000	20000	0	
4	1663750	61250	1250	0	
3	1618750	43750	0	
2	7875000	0	
1	0	
 	1	2	3	4	5	6	
6	3	3	3	5	5	0	
5	3	3	3	4	0	
4	3	3	3	0	
3	1	2	0	
2	1	0	
1	0

本作品采用知识共享署名 4.0 国际许可协议进行许可。

发表回复