Java - Brackets in Matrix Chain Multiplication

Java Code:

public class MatrixChainMultiplication {

    // Recursively prints the optimal parenthesization of matrices i..j (0-based),
    // using the split table filled by matrixChainMultiplication
    public static void printOptimalParenthesis(int[][] split, int i, int j) {
        if (i == j) {
            // A single matrix: print its 1-based name
            System.out.print("A" + (i + 1));
        } else {
            System.out.print("(");
            printOptimalParenthesis(split, i, split[i][j]);      // left part: i..k
            printOptimalParenthesis(split, split[i][j] + 1, j);  // right part: k+1..j
            System.out.print(")");
        }
    }

    // Computes the minimum number of scalar multiplications for the chain
    // described by dims and prints one optimal parenthesization
    public static int matrixChainMultiplication(int[] dims) {
        if (dims == null || dims.length < 2) {
            throw new IllegalArgumentException("dims must describe at least one matrix");
        }
        int n = dims.length; // number of matrices + 1

        // dp[i][j] stores the minimum number of multiplications needed to multiply matrices i to j
        int[][] dp = new int[n][n];

        // split[i][j] stores the index k of the last matrix in the left part of the optimal split
        int[][] split = new int[n][n];

        // length is the number of matrices in the subchain being solved
        for (int length = 2; length < n; length++) {
            for (int i = 0; i < n - length; i++) {
                int j = i + length - 1;
                dp[i][j] = Integer.MAX_VALUE; // placeholder until a split has been tried

                // Try every split point k: (matrices i..k) * (matrices k+1..j)
                for (int k = i; k < j; k++) {
                    int q = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                    if (q < dp[i][j]) {
                        dp[i][j] = q;
                        split[i][j] = k;
                    }
                }
            }
        }

        // Print the optimal parenthesization of the whole chain (matrices 0..n-2)
        System.out.print("Optimal Parenthesization: ");
        printOptimalParenthesis(split, 0, n - 2);
        System.out.println();

        return dp[0][n - 2];
    }

    public static void main(String[] args) {
        // Matrix A(i+1) has dimensions dims[i] x dims[i+1]
        int[] dims = {10, 20, 30, 40, 30}; // 4 matrices: A1 (10x20), A2 (20x30), A3 (30x40), A4 (40x30)

        int minMultiplications = matrixChainMultiplication(dims);
        System.out.println("Minimum number of multiplications: " + minMultiplications);
    }
}
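For small inputs, the DP result can be cross-checked against plain recursion that tries every split point without storing intermediate results. This is a testing sketch, not part of the original program: it runs in exponential time, and the class name MatrixChainBruteForce is an illustrative choice.

```java
public class MatrixChainBruteForce {

    // Minimum scalar multiplications for matrices i..j (0-based),
    // where matrix m has dimensions dims[m] x dims[m + 1].
    // Plain recursion with no memoization: exponential time, for testing only.
    static int minCost(int[] dims, int i, int j) {
        if (i == j) {
            return 0; // a single matrix needs no multiplication
        }
        int best = Integer.MAX_VALUE;
        for (int k = i; k < j; k++) {
            int cost = minCost(dims, i, k) + minCost(dims, k + 1, j)
                    + dims[i] * dims[k + 1] * dims[j + 1];
            best = Math.min(best, cost);
        }
        return best;
    }

    public static void main(String[] args) {
        int[] dims = {10, 20, 30, 40, 30};
        // Should agree with matrixChainMultiplication(dims)
        System.out.println(minCost(dims, 0, dims.length - 2)); // prints 30000
    }
}
```

Because both versions evaluate the same recurrence, any disagreement on a small random input points to an indexing bug in the table-filling loops.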

Explanation:

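dp[i][j] Table: The dp[i][j] table stores the minimum number of scalar multiplications required to compute the product of matrices i through j (0-based indices, so index i refers to matrix A(i+1)). Entries with i == j stay 0, because a single matrix needs no multiplication, and the answer for the whole chain is dp[0][n - 2], where n = dims.length.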
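Written with p standing for dims and the code's 0-based indices (matrix i has dimensions p_i x p_{i+1}), the table is filled by the standard matrix-chain recurrence:

```latex
\[
\mathit{dp}[i][j] =
\begin{cases}
0, & i = j,\\
\min\limits_{i \le k < j} \bigl(\mathit{dp}[i][k] + \mathit{dp}[k+1][j] + p_i\, p_{k+1}\, p_{j+1}\bigr), & i < j.
\end{cases}
\]
```

The term inside the minimum is the value q computed in the innermost loop: the cost of the left product, the cost of the right product, and the cost p_i * p_{k+1} * p_{j+1} of multiplying the resulting p_i x p_{k+1} and p_{k+1} x p_{j+1} matrices.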

split[i][j] Table: This table stores the index k at which the optimal split of matrices i through j occurs, that is, the k that produced the minimum in dp[i][j]. The product is then computed as (matrices i to k) times (matrices k+1 to j).
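To see what the split table holds, the sketch below repeats the same table-filling loops, returns split instead of printing, and lists every subchain with its optimal cut using 1-based matrix names. The class SplitTableDemo and its method splitTable are hypothetical names introduced for illustration.

```java
public class SplitTableDemo {

    // Fills dp/split exactly like matrixChainMultiplication and returns split.
    // split[i][j] = k means matrices i..j (0-based) are best multiplied as
    // (matrices i..k) times (matrices k+1..j).
    static int[][] splitTable(int[] dims) {
        int n = dims.length; // number of matrices + 1
        int[][] dp = new int[n][n];
        int[][] split = new int[n][n];
        for (int length = 2; length < n; length++) {
            for (int i = 0; i < n - length; i++) {
                int j = i + length - 1;
                dp[i][j] = Integer.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    int q = dp[i][k] + dp[k + 1][j] + dims[i] * dims[k + 1] * dims[j + 1];
                    if (q < dp[i][j]) {
                        dp[i][j] = q;
                        split[i][j] = k;
                    }
                }
            }
        }
        return split;
    }

    public static void main(String[] args) {
        int[] dims = {10, 20, 30, 40, 30};
        int[][] split = splitTable(dims);
        int last = dims.length - 2; // 0-based index of the last matrix
        // For every subchain with at least two matrices, show where it is cut
        for (int i = 0; i <= last; i++) {
            for (int j = i + 1; j <= last; j++) {
                int k = split[i][j];
                System.out.printf("A%d..A%d -> (A%d..A%d)(A%d..A%d)%n",
                        i + 1, j + 1, i + 1, k + 1, k + 2, j + 1);
            }
        }
    }
}
```

For the example dims, the entry for the full chain is split[0][3] = 2, which prints as A1..A4 -> (A1..A3)(A4..A4): the last multiplication combines the product of A1 to A3 with A4.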

Matrix Dimensions: The input dims is an array with one more entry than there are matrices: matrix A(i+1) has dimensions dims[i] x dims[i+1], so neighboring matrices share an inner dimension. For example, if dims = {10, 20, 30, 40, 30}, there are 4 matrices:

A1 with dimensions 10x20

A2 with dimensions 20x30

A3 with dimensions 30x40

A4 with dimensions 40x30
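The mapping from the dims array to individual matrix shapes can be decoded mechanically. The helper below is a hypothetical illustration (ChainDimensions and describe are not part of the original code) that turns the array into the list of shapes above.

```java
public class ChainDimensions {

    // Returns one "Ak: rows x cols" string per matrix in the chain.
    // Matrix A(i+1) has dimensions dims[i] x dims[i+1], so k matrices need k + 1 numbers.
    static String[] describe(int[] dims) {
        if (dims.length < 2) {
            throw new IllegalArgumentException("need at least two dimensions for one matrix");
        }
        String[] shapes = new String[dims.length - 1];
        for (int i = 0; i < shapes.length; i++) {
            shapes[i] = "A" + (i + 1) + ": " + dims[i] + "x" + dims[i + 1];
        }
        return shapes;
    }

    public static void main(String[] args) {
        // Prints A1: 10x20, A2: 20x30, A3: 30x40, A4: 40x30 on separate lines
        for (String shape : describe(new int[] {10, 20, 30, 40, 30})) {
            System.out.println(shape);
        }
    }
}
```

Storing shared dimensions once also guarantees that the chain is always multipliable, since the column count of each matrix is by construction the row count of the next.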

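Reconstruction of Optimal Parenthesization: After the DP table is filled, printOptimalParenthesis recursively rebuilds the optimal parenthesization from the split table. A single matrix is printed by name; a longer chain i..j is printed as an opening parenthesis, the left part i..split[i][j], the right part split[i][j]+1..j, and a closing parenthesis.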
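Printing directly to System.out makes the result hard to reuse or test. A minimal variant, assuming the caller already has the split table, appends to a StringBuilder and returns the string instead; ParenthesizationBuilder, build, and parenthesize are illustrative names. The split values in main are the ones the DP produces for dims = {10, 20, 30, 40, 30}.

```java
public class ParenthesizationBuilder {

    // Same recursion as printOptimalParenthesis, but appends to a StringBuilder
    static void build(int[][] split, int i, int j, StringBuilder out) {
        if (i == j) {
            out.append('A').append(i + 1); // single matrix, 1-based name
            return;
        }
        int k = split[i][j];
        out.append('(');
        build(split, i, k, out);      // left part: matrices i..k
        build(split, k + 1, j, out);  // right part: matrices k+1..j
        out.append(')');
    }

    // Returns the parenthesization of matrices 0..last as a string
    static String parenthesize(int[][] split, int last) {
        StringBuilder out = new StringBuilder();
        build(split, 0, last, out);
        return out.toString();
    }

    public static void main(String[] args) {
        // Split table for dims = {10, 20, 30, 40, 30}; only entries with i < j are read
        int[][] split = new int[5][5];
        split[0][1] = 0; split[1][2] = 1; split[2][3] = 2;
        split[0][2] = 1; split[1][3] = 2;
        split[0][3] = 2;
        System.out.println(parenthesize(split, 3)); // prints (((A1A2)A3)A4)
    }
}
```

Like the original method, this recursion also wraps the full chain in one outer pair of parentheses, so the string for the example starts with three opening parentheses.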

Example Output:

For the input dims = {10, 20, 30, 40, 30}, the output will be:

Optimal Parenthesization: (((A1A2)A3)A4)

Minimum number of multiplications: 30000
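The total can be checked by hand: multiplying a p x q matrix by a q x r matrix costs pqr scalar multiplications, so the cost of (((A1A2)A3)A4) is the sum of its three products:

```latex
\[
\begin{aligned}
A_1 A_2 &:\; 10 \cdot 20 \cdot 30 = 6000 \quad (\text{result } 10 \times 30)\\
(A_1 A_2) A_3 &:\; 10 \cdot 30 \cdot 40 = 12000 \quad (\text{result } 10 \times 40)\\
((A_1 A_2) A_3) A_4 &:\; 10 \cdot 40 \cdot 30 = 12000 \quad (\text{result } 10 \times 30)\\
\text{total} &= 6000 + 12000 + 12000 = 30000
\end{aligned}
\]
```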