Naive Matrix Multiplication in Clojure

I recently started reading Algorithms, 4th Edition by Robert Sedgewick and Kevin Wayne. The main goal is to refresh my knowledge of common algorithms and data structures. I came across an interesting problem at the end of the first section.

The goal of the exercise is to implement the following API for dealing with matrices:

static double dot(double[] x, double[] y)              // vector dot product
static double[][] transpose(double[][] a)              // transpose
static double[][] mult(double[][] a, double[][] b)     // matrix-matrix product
static double[] mult(double[][] a, double[] x)         // matrix-vector product
static double[] mult(double[] y, double[][] a)         // vector-matrix product

Implementing this API in an imperative language like Java would be easy enough, and there is a plethora of examples on the web showing how. I wanted to challenge myself and take a functional approach to this problem. My language of choice was Clojure. I’ve dabbled with it before and have always enjoyed it.

Warning: I don’t claim/pretend to be a Clojure expert. I’m sure there are more elegant ways to solve these problems.

Let’s tackle these methods one by one.


Dot Product

The first method, dot, computes the dot product of two vectors.

(defn dot
  "Vector dot product"
  [x y]
  (reduce + (map * x y)))

We use the map function to multiply the vectors element by element. This will produce a sequence of products. We then use reduce to sum up that sequence.
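To see the two steps separately, here is a small worked example with arbitrarily chosen vectors, showing the intermediate products next to the final sum. dot is repeated so the snippet runs on its own.

```clojure
(defn dot
  "Vector dot product"
  [x y]
  (reduce + (map * x y)))

;; Step 1: map multiplies the vectors element by element.
(map * [1 2 3] [4 5 6])   ; => (4 10 18)

;; Step 2: reduce sums the products: 4 + 10 + 18.
(reduce + [4 10 18])      ; => 32

;; Both steps together.
(dot [1 2 3] [4 5 6])     ; => 32
```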


Transpose

transpose flips a matrix over its diagonal.

(defn transpose
  "Transposes matrix"
  [a]
  (apply map vector a))

We use apply to call map with vector as the function and each row of a as a separate collection argument.

Explanation

Let’s say we had a 2 X 3 matrix:

[[1 2 3] [3 4 5]]

Let’s pull those rows out into two separate vectors and call (map vector) on them:

(map vector [1 2 3] [3 4 5])

The result would be:

([1 3] [2 4] [3 5])

We now have 3 rows and 2 columns.

Wait, what?

vector is just a function that creates a new vector containing the arguments passed to it.

We call map to apply the vector function across the two vectors. When given more than one collection, map walks them in lockstep: it takes the first item of each, then the second item of each, and so on.

Essentially we’re doing the following when we call (map vector) on the two vectors:

(vector 1 3) (vector 2 4) (vector 3 5)
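The lockstep behaviour is easy to check at the REPL. The vectors below are arbitrary examples. Note that map stops as soon as the shortest collection runs out, which also means dot quietly ignores extra elements when its two inputs differ in length.

```clojure
;; map with several collections takes one element from each per step.
(map vector [1 2 3] [3 4 5])                     ; => ([1 3] [2 4] [3 5])

;; The same thing written out by hand, one vector call per position.
(list (vector 1 3) (vector 2 4) (vector 3 5))    ; => ([1 3] [2 4] [3 5])

;; map stops at the end of the shortest collection.
(map vector [1 2 3] [3 4])                       ; => ([1 3] [2 4])
```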

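The issue is that our matrix isn’t two standalone vectors; it’s a single vector whose elements are the rows. This is why we need apply: it calls a function with the elements of a collection as individual arguments, so (apply map vector a) becomes (map vector [1 2 3] [3 4 5]).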

(def a [[1 2 3] [3 4 5]])

(apply map vector a)

Result: ([1 3] [2 4] [3 5]) (a lazy sequence of vectors)
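Here is a quick check that the apply form and the hand-written form are the same call. Because map returns a lazy sequence, transpose gives back a sequence of vectors rather than a vector of vectors. If a true vector of vectors is wanted, mapv (which also accepts several collections) can be swapped in. transpose-v is a hypothetical name used only for this sketch.

```clojure
(def a [[1 2 3] [3 4 5]])

;; apply passes each row of a to map as a separate argument...
(apply map vector a)           ; => ([1 3] [2 4] [3 5])
;; ...so it is equivalent to listing the rows by hand.
(map vector [1 2 3] [3 4 5])   ; => ([1 3] [2 4] [3 5])

;; Hypothetical variant: mapv returns a vector of vectors instead of a lazy seq.
(defn transpose-v [m] (apply mapv vector m))

(transpose-v a)                ; => [[1 3] [2 4] [3 5]]
(vector? (transpose-v a))      ; => true
(vector? (apply map vector a)) ; => false
```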


Mult (Matrix-Matrix)

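mat-mult is for multiplying two matrices. The Java API overloads mult on argument types, but defn can only overload on the number of arguments, and all three variants take two, so each one gets its own name here.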

In order to multiply two matrices, the number of columns of the first matrix must be equal to the number of rows of the second matrix. We check for this with a :pre condition.

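(defn mat-mult
  "Matrix-Matrix product"
  [a b]
  ;; The number of columns of a must equal the number of rows of b.
  {:pre [(= (count (nth a 0)) (count b))]}
  ;; Transpose b once, so its columns can be walked as rows.
  (let [bt (transpose b)]
    (vec
     (->> (for [x a      ; each row of a...
                y bt]    ; ...paired with each column of b
            (dot x y))
          ;; regroup the flat seq into rows, one entry per column of b
          (partition (count bt))
          (map vec)))))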

Explanation

Let’s ignore the thread-last macro (->>) for the moment.

The for macro lets us build a list comprehension. You may be familiar with this concept if you’ve programmed in Python.

In this case we’re writing a nested comprehension: for every row of a, it steps through every row of the transposed b (that is, every column of b) and computes the dot product of the pair. The rows of a are visited in order, so the products come out row by row.

For example, the first element is the first row of a dotted with the first column of b:

(dot (first a) (first (transpose b)))
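Here is what the comprehension produces before any regrouping, using two arbitrary 2 x 2 matrices. dot and transpose are repeated from earlier so the snippet runs on its own.

```clojure
(defn dot [x y] (reduce + (map * x y)))
(defn transpose [a] (apply map vector a))

(def a [[1 2] [3 4]])
(def b [[5 6] [7 8]])

;; The columns of b, obtained as the rows of its transpose.
(transpose b)   ; => ([5 7] [6 8])

;; Every row of a paired with every column of b, in row-major order:
;; row0.col0, row0.col1, row1.col0, row1.col1
(for [x a
      y (transpose b)]
  (dot x y))    ; => (19 22 43 50)
```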

The result is a single flat sequence, so we use partition to split it back into rows. Each row has one entry per column of b, which is why the partition size is the number of columns in b (equivalently, the number of rows in its transpose).

Finally, we map vec over these rows to convert each one from a sequence to a vector.
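Taking a flat sequence such as (19 22 43 50), the four dot products of a 2 x 2 product, partition regroups it into rows whose length is the number of columns of b (2 here), and map vec turns each row into a vector. The numbers are only an illustration.

```clojure
;; Flat, row-major output of the comprehension for a 2 x 2 result.
(def flat '(19 22 43 50))

;; Split into rows of 2, the number of columns in b.
(partition 2 flat)             ; => ((19 22) (43 50))

;; Convert each row from a seq to a vector.
(map vec (partition 2 flat))   ; => ([19 22] [43 50])
```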

Going back to the macro: ->> threads an expression through multiple forms, inserting it as the last argument of each one. Here we thread the flat sequence from the list comprehension through partition and then map.
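One way to see what ->> does is to ask Clojure to expand it. The pipeline below is a cut-down version of the one in mat-mult, with a placeholder symbol xs standing in for the comprehension’s output. clojure.walk/macroexpand-all is used so the expansion is fully nested regardless of how ->> is defined internally.

```clojure
(require '[clojure.walk :as walk])

;; ->> rewrites the pipeline into nested calls, innermost first.
(walk/macroexpand-all '(->> xs (partition 2) (map vec)))
;; => (map vec (partition 2 xs))

;; Both spellings compute the same value.
(->> [19 22 43 50] (partition 2) (map vec))   ; => ([19 22] [43 50])
(map vec (partition 2 [19 22 43 50]))         ; => ([19 22] [43 50])
```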

The outermost vec call just converts the list of vectors into a vector of vectors.
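Putting the pieces together, the sketch below multiplies an arbitrary 2 x 3 matrix by a 3 x 2 matrix, then shows the :pre condition rejecting a pair whose inner dimensions don’t match. dot, transpose and mat-mult are repeated so the snippet runs on its own. A failed :pre condition throws an AssertionError, assuming assertions are enabled (the default, controlled by *assert*).

```clojure
(defn dot [x y] (reduce + (map * x y)))
(defn transpose [a] (apply map vector a))

(defn mat-mult
  "Matrix-Matrix product"
  [a b]
  {:pre [(= (count (nth a 0)) (count b))]}
  (vec
   (->> (for [x a
              y (transpose b)]
          (dot x y))
        (partition (count (transpose b)))
        (map vec))))

;; (2 x 3) times (3 x 2) gives a 2 x 2 result.
(mat-mult [[1 2 3] [4 5 6]]
          [[7 8] [9 10] [11 12]])
;; => [[58 64] [139 154]]

;; (2 x 3) times (2 x 3): 3 columns vs 2 rows, so the :pre check fails.
(try
  (mat-mult [[1 2 3] [4 5 6]] [[1 2 3] [4 5 6]])
  (catch AssertionError _ :dimension-mismatch))
;; => :dimension-mismatch
```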


Mult (Matrix-Vector)

mat-vec-mult is for getting the product of a matrix times a vector.

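(defn mat-vec-mult
  "Matrix-Vector product"
  [m v]
  ;; Treat v as an n x 1 column matrix ([[v1] [v2] ...]), multiply,
  ;; then flatten the n x 1 result back into a plain vector.
  (mapv first (mat-mult m (mapv vector v))))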
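Another way to look at the matrix-vector product: entry i of the result is the dot product of row i of m with v, so it can be computed without building a column matrix at all. mat-vec-mult-rows is a hypothetical name for this sketch; it returns a plain vector, matching the double[] return type in the API.

```clojure
(defn dot [x y] (reduce + (map * x y)))

;; Hypothetical alternative: one dot product per row of m.
(defn mat-vec-mult-rows
  "Matrix-Vector product, computed row by row"
  [m v]
  {:pre [(= (count (first m)) (count v))]}   ; columns of m must equal length of v
  (mapv #(dot % v) m))

;; (2 x 3) matrix times a length-3 vector gives a length-2 vector.
(mat-vec-mult-rows [[1 2 3] [4 5 6]] [1 0 2])   ; => [7 16]
```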

Mult (Vector-Matrix)

vec-mat-mult is for getting the product of a vector times a matrix.

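(defn vec-mat-mult
  "Vector-Matrix product"
  [v m]
  ;; Treat v as a 1 x n row matrix, multiply, then take the single
  ;; row of the result so a plain vector comes back.
  (first (mat-mult (vector v) m)))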
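Symmetrically, entry j of a vector-matrix product is the dot product of v with column j of m, so a column-by-column version is a one-liner on top of transpose. vec-mat-mult-cols is a hypothetical name for this sketch, and the matrix below is arbitrary.

```clojure
(defn dot [x y] (reduce + (map * x y)))
(defn transpose [a] (apply map vector a))

;; Hypothetical alternative: one dot product per column of m.
(defn vec-mat-mult-cols
  "Vector-Matrix product, computed column by column"
  [v m]
  {:pre [(= (count v) (count m))]}   ; length of v must equal rows of m
  (mapv #(dot v %) (transpose m)))

;; Length-2 vector times a (2 x 3) matrix gives a length-3 vector.
(vec-mat-mult-cols [1 2] [[1 2 3] [4 5 6]])   ; => [9 12 15]
```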