15.2 动态规划-矩阵链乘法

原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: 15.2 动态规划-矩阵链乘法

package cn.com.liuxiaofei;
 
public class Chapter15_2 {
 
	/**
	 * @param args
	 */
	public static void main(String[] args) {
		testMatrixChainOrder();
 
	}
 
	public static void testMatrixChainOrder() {
		//A1      A2     A3   A4   A5     A6
		//900x350 350x25 25x5 5x10 10x400 400x950
		long[][] A1 = generateMatrix(900, 350);
		long[][] A2 = generateMatrix(350, 25);
		long[][] A3 = generateMatrix(25, 5);
		long[][] A4 = generateMatrix(5, 10);
		long[][] A5 = generateMatrix(10, 400);
		long[][] A6 = generateMatrix(400, 950);
		int[] p = new int[] { A1.length, A2.length, A3.length, A4.length, A5.length, A6.length, A6[0].length };
		Object[] mands = matrixChainOrder(p);
		int[][] m = (int[][]) mands[0];
		int[][] s = (int[][]) mands[1];
		System.out.print(" \t");
		for (int i = 1; i < m.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = m.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < m.length && j >= i; i++) {
				System.out.print(m[i][j] + "\t");
			}
			System.out.println();
		}
		System.out.print(" \t");
		for (int i = 1; i < s.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = s.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < s.length && j >= i; i++) {
				System.out.print(s[i][j] + "\t");
			}
			System.out.println();
		}
 
		printOptimalParens(s, 1, 6);
		System.out.println();
 
		System.out.println("Not use matrix chain order:");
		long time = System.currentTimeMillis();
		long[][] B1 = multipyMatrix(multipyMatrix(multipyMatrix(multipyMatrix(multipyMatrix(A1, A2), A3), A4), A5), A6);
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));
 
		System.out.println("Use matrix chain order:");
		time = System.currentTimeMillis();
		long[][] B2 = multipyMatrixByOptimalParens(s, 1, 6, new Object[] { A1, A2, A3, A4, A5, A6 });
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));
		System.out.println(compaireMatrix(B1, B2));
 
		System.out.println("--Use memoized matrix chain--");
		mands = matrixChainOrder(p);
		m = (int[][]) mands[0];
		s = (int[][]) mands[1];
		System.out.print(" \t");
		for (int i = 1; i < m.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = m.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < m.length && j >= i; i++) {
				System.out.print(m[i][j] + "\t");
			}
			System.out.println();
		}
		System.out.print(" \t");
		for (int i = 1; i < s.length; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		for (int j = s.length - 1; j >= 1; j--) {
			System.out.print(j + "\t");
			for (int i = 1; i < s.length && j >= i; i++) {
				System.out.print(s[i][j] + "\t");
			}
			System.out.println();
		}
 
	}
 
	public static void printOptimalParens(int[][] s, int i, int j) {
		if (i == j) {
			System.out.print("A" + i);
		} else {
			System.out.print("(");
			printOptimalParens(s, i, s[i][j]);
			printOptimalParens(s, s[i][j] + 1, j);
			System.out.print(")");
		}
	}
 
	public static Object[] matrixChainOrder(int[] p) {
		int n = p.length - 1;
		int[][] m = new int[n + 1][n + 1];
		int[][] s = new int[n + 1][n + 1];
		for (int i = 1; i <= n; i++) {
			m[i][i] = 0;
		}
		//len is the chain length
		for (int len = 2; len <= n; len++) {
			//start from i, calculate the multiply time of chain
			for (int i = 1; i <= n - len + 1; i++) {
				int j = i + len - 1;
				m[i][j] = Integer.MAX_VALUE;
				for (int k = i; k <= j - 1; k++) {
					int q = m[i][k] + m[k + 1][j] + p[i - 1] * p[k] * p[j];
					if (q < m[i][j]) {
						m[i][j] = q;
						s[i][j] = k;
					}
				}
			}
		}
		return new Object[] { m, s };
	}
 
	public static long[][] generateMatrix(int row, int col) {
		long[][] matrix = new long[row][col];
		for (int i = 0; i < row; i++) {
			for (int j = 0; j < col; j++) {
				matrix[i][j] = Double.valueOf(Math.random() * 10).intValue();
			}
		}
		return matrix;
	}
 
	public static boolean compaireMatrix(long[][] A, long[][] B) {
		int row = A.length;
		int col = A[0].length;
		for (int i = 0; i < row; i++) {
			for (int j = 0; j < col; j++) {
				if (A[i][j] != B[i][j]) {
					return false;
				}
			}
		}
		return true;
	}
 
	public static long[][] multipyMatrix(long[][] A, long[][] B) {
		if (A.length == 0 || B.length == 0) {
			return null;
		}
		if (A[0].length != B.length) {
			return null;
		}
 
		long[][] C = new long[A.length][B[0].length];
		for (int i = 0; i < A.length; i++) {
			for (int j = 0; j < B[0].length; j++) {
				C[i][j] = 0;
				for (int k = 0; k < A[0].length; k++) {
					C[i][j] = C[i][j] + A[i][k] * B[k][j];
				}
			}
		}
		return C;
	}
 
	public static long[][] multipyMatrixByOptimalParens(int[][] s, int i, int j, Object[] matrixs) {
		if (i == j) {
			return (long[][]) matrixs[i - 1];
		} else {
			long[][] B1 = multipyMatrixByOptimalParens(s, i, s[i][j], matrixs);
			long[][] B2 = multipyMatrixByOptimalParens(s, s[i][j] + 1, j, matrixs);
			long[][] B = multipyMatrix(B1, B2);
			return B;
		}
	}
 
	public static Object[] memoMatrixChain(int[] p) {
		int n = p.length - 1;
		int[][] m = new int[n + 1][n + 1];
		int[][] s = new int[n + 1][n + 1];
		for (int i = 1; i <= n; i++) {
			for (int j = 1; j <= n; j++) {
				m[i][j] = Integer.MAX_VALUE;
			}
		}
		lookupChain(m, s, p, 1, n);
 
		return new Object[] { m, s };
	}
 
	public static int lookupChain(int[][] m, int[][] s, int[] p, int i, int j) {
		if (m[i][j] < Integer.MAX_VALUE) {
			return m[i][j];
		}
		if (i == j) {
			m[i][j] = 0;
		} else {
			for (int k = i; k < j; k++) {
				int q = lookupChain(m, s, p, i, k) + lookupChain(m, s, p, k + 1, j) + p[i - 1] * p[k] * p[j];
				if (q < m[i][j]) {
					m[i][j] = q;
				}
			}
		}
		return m[i][j];
	}
 
}
 
输出结果
 	1	2	3	4	5	6	
6	7813750	3626250	2038750	1920000	3800000	0	
5	3438750	763750	70000	20000	0	
4	1663750	61250	1250	0	
3	1618750	43750	0	
2	7875000	0	
1	0	
 	1	2	3	4	5	6	
6	3	3	3	5	5	0	
5	3	3	3	4	0	
4	3	3	3	0	
3	1	2	0	
2	1	0	
1	0	
((A1(A2A3))((A4A5)A6))
Not use matrix chain order:
Time elapse:3698
Use matrix chain order:
Time elapse:61
true
--Use memoized matrix chain--
 	1	2	3	4	5	6	
6	7813750	3626250	2038750	1920000	3800000	0	
5	3438750	763750	70000	20000	0	
4	1663750	61250	1250	0	
3	1618750	43750	0	
2	7875000	0	
1	0	
 	1	2	3	4	5	6	
6	3	3	3	5	5	0	
5	3	3	3	4	0	
4	3	3	3	0	
3	1	2	0	
2	1	0	
1	0

本作品采用知识共享署名 4.0 国际许可协议进行许可。

15.1 动态规划-钢条切割

原创文章,转载请注明: 转载自慢慢的回味

本文链接地址: 15.1 动态规划-钢条切割

package cn.com.liuxiaofei;
 
public class Chapter15_1 {
	/**
	 * Entry point: runs the rod-cutting demo for a rod of length 30.
	 *
	 * @param args unused
	 */
	public static void main(String[] args) {
		testCutRod(30);
	}

	/**
	 * Compares the rod-cutting implementations on a rod of length {@code n}
	 * using a fixed price table.
	 * <p>
	 * Prints each maximum income with its elapsed time, then the first-cut and
	 * revenue tables. Finally prints one reconstructed cut sequence, for length
	 * 17 or for {@code n} if that is smaller.
	 *
	 * @param n the rod length; must be between 0 and {@code p.length - 1} of
	 *          the built-in price table
	 */
	public static void testCutRod(int n) {
		int[] p = new int[] { 0, 1, 5, 8, 9, 10, 17, 17, 20, 24, 30, 31, 35, 37, 40, 41, 42, 45, 49, 52, 55, 55, 56, 58, 59, 62, 62, 65, 68, 66, 64, 64, 62,
				64, 66, 68, 66, 61, 69, 70, 72 };
		if (n < 0) {
			System.out.println("n should not be negative");
			return;
		}
		if (n > p.length - 1) {
			System.out.println("The max value of n should be " + (p.length - 1));
			return;
		}
		System.out.println("Using cutRod");
		long time = System.currentTimeMillis();
		int income = cutRod(p, n);
		System.out.println("Max Income:" + income);
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));

		int[] r = new int[n + 1];
		System.out.println("\nUsing cutRodWithMemo");
		time = System.currentTimeMillis();
		income = cutRodWithMemo(p, n, r);
		System.out.println("Max Income:" + income);
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));

		System.out.println("\nUsing cutRodWithBottomUpMemo");
		time = System.currentTimeMillis();
		income = bottomUpCutRodWithMemo(p, n);
		System.out.println("Max Income:" + income);
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));

		System.out.println("\nUsing extendBottomUpCutRodWithMemo");
		time = System.currentTimeMillis();
		int[][] rands = extendBottomUpCutRodWithMemo(p, n);
		System.out.println("Time elapse:" + (System.currentTimeMillis() - time));
		System.out.print("i\t");
		for (int i = 1; i <= n; i++) {
			System.out.print(i + "\t");
		}
		System.out.println();
		System.out.print("s[i]\t");
		for (int i = 1; i <= n; i++) {
			System.out.print(rands[1][i] + "\t");
		}
		System.out.println();
		System.out.print("r[i]\t");
		for (int i = 1; i <= n; i++) {
			System.out.print(rands[0][i] + "\t");
		}
		System.out.println();

		// The first-cut table only has n + 1 entries, so the example length is capped at n.
		int exampleN = Math.min(17, n);
		System.out.print("Take n = " + exampleN + " for example:\nThe solution should be:");
		int remaining = exampleN;
		while (remaining > 0) {
			int cut = rands[1][remaining];
			if (cut <= 0) {
				// No recorded first cut; stop instead of looping forever.
				break;
			}
			System.out.print(cut + " ");
			remaining -= cut;
		}
		System.out.println();
	}

	/**
	 * Naive recursive rod cutting.
	 * <p>
	 * It tries every first cut and recurses on the remainder without caching,
	 * so subproblems are recomputed and the running time grows exponentially
	 * with {@code n}.
	 *
	 * @param p p[i] is the price of a piece of length i (p[0] is unused)
	 * @param n The total length of the rod
	 * @return The max income
	 */
	public static int cutRod(int[] p, int n) {
		if (n == 0) {
			return 0;
		}
		int q = Integer.MIN_VALUE;
		for (int i = 1; i <= n; i++) {
			int income = p[i] + cutRod(p, n - i);
			if (q < income) {
				q = income;
			}
		}
		return q;
	}

	/**
	 * Top-down memoized rod cutting.
	 * <p>
	 * {@code r} must have at least {@code n + 1} entries and start zero-filled.
	 * Only strictly positive entries are treated as cached, so zero or negative
	 * revenues are recomputed on each visit.
	 *
	 * @param p p[i] is the price of a piece of length i (p[0] is unused)
	 * @param n The total length of the rod
	 * @param r The income record array; r[k] receives the best income for length k
	 * @return The max income
	 */
	public static int cutRodWithMemo(int[] p, int n, int[] r) {
		if (n == 0) {
			return 0;
		}
		if (r[n] > 0) {
			return r[n];
		}
		int q = Integer.MIN_VALUE;
		for (int i = 1; i <= n; i++) {
			int income = p[i] + cutRodWithMemo(p, n - i, r);
			if (q < income) {
				q = income;
			}
		}
		r[n] = q;
		return q;
	}

	/**
	 * Bottom-up rod cutting.
	 *
	 * @param p p[i] is the price of a piece of length i (p[0] is unused)
	 * @param n The total length of the rod
	 * @return The max income
	 */
	public static int bottomUpCutRodWithMemo(int[] p, int n) {
		if (n == 0) {
			return 0;
		}
		// Shares the table-filling loop with the extended version.
		return extendBottomUpCutRodWithMemo(p, n)[0][n];
	}

	/**
	 * Bottom-up rod cutting that also records the first cut of an optimal
	 * solution for every length.
	 *
	 * @param p p[i] is the price of a piece of length i (p[0] is unused)
	 * @param n The total length of the rod
	 * @return {@code int[][] { r, s }}, both of length {@code n + 1}.
	 *         {@code r[j]} is the max income for length {@code j}.
	 *         {@code s[j]} is the size of the first piece in an optimal cut;
	 *         on ties the smallest size is kept.
	 */
	public static int[][] extendBottomUpCutRodWithMemo(int[] p, int n) {
		int[] r = new int[n + 1];
		int[] s = new int[n + 1];
		// r[0] = 0 by array initialisation: an empty rod earns nothing.
		for (int j = 1; j <= n; j++) {
			// Reset per length: carrying q over from j - 1 gave wrong r[j] when
			// r[j] < r[j - 1], and left s[j] == 0 when no strict improvement occurred.
			int q = Integer.MIN_VALUE;
			for (int i = 1; i <= j; i++) {
				int income = p[i] + r[j - i];
				if (q < income) {
					q = income;
					s[j] = i;
				}
			}
			r[j] = q;
		}
		return new int[][] { r, s };
	}

}
 
 
输出结果
Using cutRod
Max Income:90
Time elapse:4105
 
Using cutRodWithMemo
Max Income:90
Time elapse:0
 
Using cutRodWithBottomUpMemo
Max Income:90
Time elapse:0
 
Using extendBottomUpCutRodWithMemo
Time elapse:0
i	1	2	3	4	5	6	7	8	9	10	11	12	13	14	15	16	17	18	19	20	21	22	23	24	25	26	27	28	29	30	
s[i]	1	2	3	2	2	6	1	2	3	10	1	2	3	2	2	6	1	2	3	10	1	2	3	2	2	6	1	2	3	10	
r[i]	1	5	8	10	13	17	18	22	25	30	31	35	38	40	43	47	48	52	55	60	61	65	68	70	73	77	78	82	85	90	
Take n = 17 for example:
The solution should be:1 6 10

本作品采用知识共享署名 4.0 国际许可协议进行许可。