The Algorithms logo
The Algorithms
AboutDonate

Matrix Chain Multiply

J
// matrix_chain_multiply finds the minimum number of scalar multiplications needed to evaluate a
// chain of matrix multiplications. The input `matrices` holds the dimensions of the matrices. For
// example, [1,2,3,4] represents matrices of dimension (1x2), (2x3), and (3x4).
//
// Let's say we are given [4, 3, 2, 1]. If we naively multiply left to right, we get:
//
// (4*3*2) + (4*2*1) = 32
//
// We can reduce the multiplications by reordering the matrix multiplications:
//
// (3*2*1) + (4*3*1) = 18
//
// We solve this problem with dynamic programming and tabulation. table[i][j] holds the optimal
// number of multiplications in range matrices[i..j] (inclusive). Note this means that table[i][i]
// and table[i][i+1] are always zero, since those represent a single vector/matrix and do not
// require any multiplications.
//
// For any i, j, and k = i+1, i+2, ..., j-1:
//
// table[i][j] = min(table[i][k] + table[k][j] + matrices[i] * matrices[k] * matrices[j])
//
// table[i][k] holds the optimal solution to matrices[i..k]
//
// table[k][j] holds the optimal solution to matrices[k..j]
//
// matrices[i] * matrices[k] * matrices[j] computes the number of multiplications to join the two
// matrices together.
//
// Runs in O(n^3) time and O(n^2) space.

/// Returns the minimum number of scalar multiplications needed to evaluate a chain of matrices.
///
/// `matrices` lists the chain's dimensions: `[d0, d1, ..., dk]` describes `k` matrices, where the
/// `m`-th matrix has shape `d(m) x d(m+1)`. Inputs with fewer than three dimensions describe at
/// most one matrix and therefore cost `0`.
///
/// Costs are accumulated in `u64` with saturating arithmetic, so an expensive candidate
/// parenthesization cannot overflow and corrupt the search for the cheapest one. If even the
/// optimal cost does not fit in a `u32`, `u32::MAX` is returned.
///
/// Runs in O(n^3) time and O(n^2) space, where `n` is `matrices.len()`.
pub fn matrix_chain_multiply(matrices: Vec<u32>) -> u32 {
    let n = matrices.len();
    if n <= 2 {
        // No multiplications required.
        return 0;
    }

    // table[i][j] is the cheapest cost of multiplying the matrices spanned by dimensions i..=j.
    // Entries with j - i < 2 cover at most one matrix and stay zero.
    let mut table = vec![vec![0u64; n]; n];

    for length in 2..n {
        for i in 0..n - length {
            let j = i + length;
            // Try every split point k strictly between i and j and keep the cheapest. The range
            // is never empty because length >= 2, so the u64::MAX seed is always replaced.
            let best = (i + 1..j)
                .map(|k| {
                    // Cost of joining the (d_i x d_k) and (d_k x d_j) partial products.
                    let join = u64::from(matrices[i])
                        .saturating_mul(u64::from(matrices[k]))
                        .saturating_mul(u64::from(matrices[j]));
                    table[i][k]
                        .saturating_add(table[k][j])
                        .saturating_add(join)
                })
                .fold(u64::MAX, u64::min);
            table[i][j] = best;
        }
    }

    // Clamp before narrowing so the cast below can never truncate.
    table[0][n - 1].min(u64::from(u32::MAX)) as u32
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Chains of two or more matrices, checked against known optimal costs.
    #[test]
    fn basic() {
        // Exactly two matrices: the only possible product costs 10 * 20 * 30.
        assert_eq!(matrix_chain_multiply(vec![10, 20, 30]), 6000);
        assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4]), 18);
        assert_eq!(matrix_chain_multiply(vec![4, 3, 2, 1]), 18);
        assert_eq!(matrix_chain_multiply(vec![40, 20, 30, 10, 30]), 26000);
        assert_eq!(matrix_chain_multiply(vec![1, 2, 3, 4, 3]), 30);
        assert_eq!(matrix_chain_multiply(vec![4, 10, 3, 12, 20, 7]), 1344);
    }

    /// Inputs describing zero or one matrix need no multiplications at all.
    #[test]
    fn zero() {
        assert_eq!(matrix_chain_multiply(vec![]), 0);
        assert_eq!(matrix_chain_multiply(vec![10]), 0);
        assert_eq!(matrix_chain_multiply(vec![10, 20]), 0);
    }
}