Distributed Matrix Multiplication in Spark


Over the past year, I’ve been using Apache Spark for data pipelines, exploring large-scale machine learning applications, and tuning Spark clusters for peak performance for analysts. Spark is a huge draw to the data science community because of its DataFrame API, which feels familiar to pandas users, its low learning curve, and its accessibility from several languages.


Many Spark tutorials do not dive into the internals of Spark. The goal of this post is to walk you through the Spark source: we’ll start at Spark’s high-level Python API and end at the compiled linear algebra libraries LAPACK and BLAS.


In a future post I’ll walk step by step through installing an environment, but for now I’ll assume that you have one set up, so we can dive into the code.


Below is an example of using Alternating Least Squares (ALS) as a recommendation engine for movies. You don’t need to understand the intricate details of ALS: I’ll use it only as a vehicle for following Spark’s documentation and source code down to the underlying linear algebra library calls. If you are interested in ALS itself, this post explains it well. The code can be run interactively, line by line, in the pyspark shell or in a Jupyter notebook.

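from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.sql import Row, SparkSession

# The pyspark shell already provides a SparkSession named `spark`;
# getOrCreate() returns it there, or builds one in a plain Python/Jupyter session.
spark = SparkSession.builder.appName("ALSExample").getOrCreate()

# Each line has the form userId::movieId::rating::timestamp
lines = spark.read.text("sample_movielens_ratings.txt").rdd
parts = lines.map(lambda row: row.value.split("::"))
# Python 3 has no `long`; int is arbitrary precision.
ratingsRDD = parts.map(lambda p: Row(userId=int(p[0]), movieId=int(p[1]),
                                     rating=float(p[2]), timestamp=int(p[3])))
ratings = spark.createDataFrame(ratingsRDD)
(training, test) = ratings.randomSplit([0.8, 0.2])

# Build the recommendation model using ALS on the training data.
# coldStartStrategy="drop" (Spark 2.2+) drops test rows whose user or movie never
# appeared in training; otherwise their predictions are NaN and so is the RMSE.
als = ALS(maxIter=5, regParam=0.01, userCol="userId",
          itemCol="movieId", ratingCol="rating", coldStartStrategy="drop")
model = als.fit(training)

# Evaluate the model by computing the RMSE on the test data
predictions = model.transform(test)
evaluator = RegressionEvaluator(metricName="rmse", labelCol="rating",
                                predictionCol="prediction")
rmse = evaluator.evaluate(predictions)
print("Root-mean-square error = " + str(rmse))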
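Everything interesting in that example happens inside als.fit. Before following it into the source, here is a toy sketch of the alternating idea in plain Python, so it runs without Spark. It assumes a single latent factor (rank 1) and a tiny made-up ratings matrix in which None marks a missing rating. With one factor, each regularized least squares step has a closed form, u_i = Σ_j r_ij v_j / (Σ_j v_j² + λ), summed over the observed ratings of user i (and symmetrically for items). The rmse helper computes the same metric RegressionEvaluator reports, restricted to observed entries. Spark’s ALS handles higher ranks, sparse data and partitioning, so treat this as an illustration of the technique, not as Spark’s code.

```python
import math
import random


def als_rank1(ratings, reg=0.01, iterations=20, seed=0):
    """Toy rank-1 ALS: approximate ratings[i][j] by u[i] * v[j].

    `ratings` is a list of equal-length rows; None marks a missing rating.
    """
    rng = random.Random(seed)
    n_users, n_items = len(ratings), len(ratings[0])
    observed = [(i, j) for i in range(n_users) for j in range(n_items)
                if ratings[i][j] is not None]
    # Random positive starting factors.
    u = [rng.uniform(0.5, 1.5) for _ in range(n_users)]
    v = [rng.uniform(0.5, 1.5) for _ in range(n_items)]

    for _ in range(iterations):
        # Hold item factors fixed: each user factor is a 1-D ridge regression.
        for i in range(n_users):
            cols = [j for (r, j) in observed if r == i]
            num = sum(ratings[i][j] * v[j] for j in cols)
            u[i] = num / (sum(v[j] ** 2 for j in cols) + reg)
        # Hold user factors fixed: each item factor is a 1-D ridge regression.
        for j in range(n_items):
            rows = [i for (i, c) in observed if c == j]
            num = sum(ratings[i][j] * u[i] for i in rows)
            v[j] = num / (sum(u[i] ** 2 for i in rows) + reg)
    return u, v


def rmse(ratings, u, v):
    """Root-mean-square error of u[i] * v[j] over the observed ratings only."""
    errors = [(ratings[i][j] - u[i] * v[j]) ** 2
              for i in range(len(u)) for j in range(len(v))
              if ratings[i][j] is not None]
    return math.sqrt(sum(errors) / len(errors))


# Made-up ratings: the outer product of [1, 2, 3] and [2, 1, 4], one entry hidden.
ratings = [[2.0, 1.0, 4.0],
           [4.0, None, 8.0],
           [6.0, 3.0, 12.0]]
u, v = als_rank1(ratings)
print(rmse(ratings, u, v))  # small: the observed entries are fit closely
print(u[1] * v[1])          # prediction for the hidden entry, close to 2.0
```

Each half-step solves many small, independent regressions, which is what makes ALS easy to distribute: Spark can solve all user factors in parallel, then all item factors.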


If you are using an interactive environment like a Jupyter notebook or the pyspark shell with IPython, we can get some initial information about what the als.fit(training) line does. If you are familiar with the estimator-style API found in scikit-learn, this will look very familiar, but we will be looking at its actual implementation. To get a brief description, type ?als.fit (or als.fit?) in Jupyter or the pyspark shell (with IPython) and the documentation will be displayed in your environment.

The docstring states that, given a DataFrame, fit returns a fitted model. That is enough for someone who isn’t looking for implementation details, but we are, so we press on to the source code.

At this point, I move to the pyspark documentation. The docs are built with Sphinx, a very useful tool for building documentation that links to source code, so they are easy to navigate. The pyspark documentation is hosted here. I frequently use Sphinx’s quick search, because it works better than most search tools bundled with documentation. Searching for als.fit in quick search directs us to the documentation of the fit method, and from there we can click through to the source.

Unfortunately, we can’t navigate to the source of the fit method from ALS itself, because ALS doesn’t define fit: it inherits it, being a subclass of JavaEstimator. Here is where things in the Spark API get interesting. To reap the performance benefits of the JVM, Spark’s Python API typically wraps Java or Scala code in Python classes. I chose to start with Python because I wanted to spend some time investigating this wrapping.
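Python itself can tell you where an inherited method actually lives: walk the class’s method resolution order (__mro__) and return the first class whose own namespace defines the name. The classes below are hypothetical stand-ins named after the pyspark hierarchy, so the sketch runs without Spark. With pyspark installed you could pass pyspark.ml.recommendation.ALS instead (which ancestor comes back may depend on your pyspark version), and inspect.getsourcefile on that class would tell you which file to open.

```python
import inspect


def defining_class(cls, name):
    """Return the first class in cls.__mro__ whose own namespace defines `name`."""
    for klass in cls.__mro__:
        if name in vars(klass):
            return klass
    raise AttributeError(f"{cls.__name__} has no attribute {name!r}")


# Hypothetical stand-in hierarchy for illustration; these are not pyspark's classes.
class Estimator:
    def fit(self, dataset):
        """Fit a model to the input dataset."""
        return self._fit(dataset)


class JavaEstimator(Estimator):
    def _fit(self, dataset):
        return f"fitted on {dataset}"


class ALS(JavaEstimator):
    pass


print(defining_class(ALS, "fit").__name__)   # Estimator
print(defining_class(ALS, "_fit").__name__)  # JavaEstimator
print(inspect.getdoc(ALS.fit))               # Fit a model to the input dataset.
```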

Wrapping Java Objects

The Python class ALS that we used earlier sets up its Java wrapping in its constructor, starting with a call to its superclass’s constructor. After the JavaWrapper constructor runs, ALS creates the JVM object it wraps, an instance of org.apache.spark.ml.recommendation.ALS. This class can be found in the documentation here. It’s a Scala class that a Scala or Java developer would interact with directly when writing Spark code in either of those languages. Initial parameters from the constructor of the Python ALS object are passed to the underlying Scala object; for example, a regParam value given to the Python ALS class is in turn passed down to the Scala ALS implementation.

At this point we are interacting with the underlying Scala implementation through the wrapper that pyspark provides. When we invoke fit on our Python ALS object, pyspark calls down to _fit_java to run the work on the JVM, and _fit_java in turn calls the Scala object’s fit method.
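The wrapping pattern described above can be sketched without a JVM. In this simplified mock, FakeJavaALS is a hypothetical stand-in for the Scala object that pyspark would reach over Py4J, and PythonALS plays the role of the Python wrapper. The method names _fit_java and _transfer_params_to_java are borrowed from pyspark’s wrapper code, but the bodies here are written from scratch and omit most of what the real classes do (parameter maps, wrapping the returned model, copying values back).

```python
class FakeJavaALS:
    """Hypothetical stand-in for the JVM-side org.apache.spark.ml.recommendation.ALS."""

    def __init__(self):
        self.params = {}

    def set(self, name, value):
        self.params[name] = value

    def fit(self, dataset):
        # The real Scala fit trains a model; here we only record what it received.
        return {"trained_on": dataset, "params": dict(self.params)}


class PythonALS:
    """Simplified sketch of a Python wrapper around a JVM estimator."""

    def __init__(self, regParam=0.1, maxIter=10):
        # pyspark creates the JVM object in the constructor; we create the fake one.
        self._java_obj = FakeJavaALS()
        self._params = {"regParam": regParam, "maxIter": maxIter}

    def _transfer_params_to_java(self):
        # Copy every Python-side parameter onto the wrapped object.
        for name, value in self._params.items():
            self._java_obj.set(name, value)

    def _fit_java(self, dataset):
        self._transfer_params_to_java()
        return self._java_obj.fit(dataset)

    def fit(self, dataset):
        # Public entry point: delegate the real work to the wrapped object.
        return self._fit_java(dataset)


model = PythonALS(regParam=0.01, maxIter=5).fit("training")
print(model["params"])  # {'regParam': 0.01, 'maxIter': 5}
```

The design keeps the Python class thin: it holds configuration and forwards calls, so the algorithm itself exists in exactly one place, on the JVM.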

Java Virtual Machine Layer

The Java Virtual Machine (JVM) lets programs compiled to JVM bytecode, whether written in Java or Scala, run portably across different system architectures. The Python code that we authored now calls precompiled Scala code running on the JVM (pyspark reaches it through the Py4J bridge). The Scala ALS.fit method performs sanity checks on the input dataset, then converts it into an RDD of instances of the case class Rating. If you are not familiar with case classes or RDDs, the Scala and Spark documentation are good places to start.
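For readers coming only from Python, a Scala case class is roughly an immutable record type, and Rating carries a user, an item and a rating. A frozen dataclass is a close Python analogue. The sketch below (hypothetical names, plain Python, no Spark) parses lines in the userId::movieId::rating::timestamp format used earlier into such records and applies one simple sanity check; it does not reproduce the specific checks Spark performs.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Rating:
    """Immutable record, analogous to a Scala case class with three fields."""
    user: int
    item: int
    rating: float


def to_ratings(lines):
    """Parse 'user::item::rating::timestamp' lines, dropping the timestamp."""
    ratings = []
    for line in lines:
        user, item, rating, _timestamp = line.split("::")
        value = float(rating)
        # Hypothetical sanity check: reject NaN or infinite ratings.
        if not math.isfinite(value):
            raise ValueError(f"invalid rating in line {line!r}")
        ratings.append(Rating(int(user), int(item), value))
    return ratings


# Made-up example lines in the same format as the sample file.
sample = ["0::2::3::1424380312", "0::3::1::1424380312"]
print(to_ratings(sample))
# [Rating(user=0, item=2, rating=3.0), Rating(user=0, item=3, rating=1.0)]
```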

ALS Implementation

ALS.fit then calls the DeveloperApi method that implements the ALS factorization. After all of that unwrapping, we have finally made it to the actual implementation of the algorithm! The implementation also performs various sanity checks and then chooses a solver: the NNLSSolver if nonnegative factors were requested (the nonnegative parameter), otherwise the CholeskySolver.


The two solvers are found in the same file as the ALS Scala class.

Both of these classes extend the trait LeastSquaresNESolver, so either can be passed to the function computeFactors (if you come from a Python background, think of duck typing, except that the compiler checks the shared interface).

val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
computeFactors(..., solver)
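The role the trait plays can be sketched in Python with an abstract base class. All names below are hypothetical, and to keep the math trivial each problem has a single unknown: minimize Σ(a_i·x − b_i)² + reg·x², where ata = Σa_i² and atb = Σa_i·b_i, whose solution is x = atb / (ata + reg). For one variable, the nonnegative version is that value clipped at zero, because the objective is a convex parabola. Spark’s solvers work on k-dimensional systems instead.

```python
from abc import ABC, abstractmethod


class LeastSquaresSolver(ABC):
    """Hypothetical Python analogue of a solver trait: one required method."""

    @abstractmethod
    def solve(self, ata, atb, reg):
        """Return x minimizing the 1-D regularized least squares objective."""


class RidgeSolver(LeastSquaresSolver):
    def solve(self, ata, atb, reg):
        # Set the derivative 2*(ata*x - atb) + 2*reg*x to zero.
        return atb / (ata + reg)


class NonnegativeRidgeSolver(LeastSquaresSolver):
    def solve(self, ata, atb, reg):
        # A convex parabola restricted to x >= 0 has its minimum at max(0, x*).
        return max(0.0, atb / (ata + reg))


def compute_factor(ata, atb, reg, solver):
    # Only the shared interface is used, so either solver can be passed in.
    return solver.solve(ata, atb, reg)


for nonnegative in (False, True):
    solver = NonnegativeRidgeSolver() if nonnegative else RidgeSolver()
    print(nonnegative, compute_factor(ata=4.0, atb=-2.0, reg=0.0, solver=solver))
# False -0.5
# True 0.0
```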


The private function computeFactors is where our solvers are actually used. At this line the solvers are invoked to create dstFactors, the factors used to make recommendations. The CholeskySolver solves a least squares problem with L2 regularization, while the NNLSSolver solves a nonnegative least squares problem with L2 regularization, subject to x >= 0. Under the hood, CholeskyDecomposition.solve is used for the CholeskySolver and NNLS.solve for the NNLSSolver.
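For the Cholesky branch, the textbook route (standard linear algebra, not Spark’s code) goes like this: L2-regularized least squares, minimize ||Ax − b||² + λ||x||², has normal equations (AᵀA + λI)x = Aᵀb. For λ > 0 that matrix is symmetric positive definite, so it factors as LLᵀ, and two triangular solves finish the job. Below is a plain-Python sketch of that route; Spark may scale λ differently and hands the factorization to a compiled library rather than looping like this.

```python
import math


def cholesky(m):
    """Return lower-triangular L with m == L @ L^T (m symmetric positive definite)."""
    n = len(m)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = m[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                if s <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][i] = math.sqrt(s)
            else:
                L[i][j] = s / L[j][j]
    return L


def ridge_solve(A, b, lam):
    """Solve (A^T A + lam*I) x = A^T b via Cholesky and two triangular solves."""
    k = len(A[0])
    # Form the normal equations.
    ata = [[sum(row[i] * row[j] for row in A) for j in range(k)] for i in range(k)]
    atb = [sum(row[i] * bi for row, bi in zip(A, b)) for i in range(k)]
    for i in range(k):
        ata[i][i] += lam
    L = cholesky(ata)
    # Forward substitution: L y = A^T b.
    y = [0.0] * k
    for i in range(k):
        y[i] = (atb[i] - sum(L[i][j] * y[j] for j in range(i))) / L[i][i]
    # Back substitution: L^T x = y.
    x = [0.0] * k
    for i in reversed(range(k)):
        x[i] = (y[i] - sum(L[j][i] * x[j] for j in range(i + 1, k))) / L[i][i]
    return x


A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [1.0, 2.0, 3.0]
print(ridge_solve(A, b, lam=0.0))  # exact fit: approximately [1.0, 2.0]
print(ridge_solve(A, b, lam=1.0))  # shrunk toward zero: approximately [0.875, 1.375]
```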


As we covered in our lectures at USF, many linear algebra libraries use LAPACK and BLAS under the hood. Spark reaches LAPACK and BLAS through the netlib-java project, which it uses for CholeskyDecomposition.solve and NNLS.
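LAPACK and BLAS routines often take symmetric matrices in packed form: only the upper (or lower) triangle is stored, column by column, in a flat array of length k(k+1)/2. My reading of the Spark source is that ALS accumulates AᵀA in this packed layout one rating at a time, using a BLAS symmetric rank-1 update (dspr) plus daxpy for Aᵀb, before handing the result to LAPACK; treat that as an assumption to verify against your Spark version. The functions below are plain-Python imitations named after those BLAS routines, not bindings to them, and follow the upper-triangle packed convention (0-based: element (i, j) with i ≤ j lives at index i + j(j+1)/2).

```python
def packed_index(i, j):
    """0-based index of element (i, j) in upper-triangular packed storage."""
    if i > j:
        i, j = j, i  # symmetric: (i, j) and (j, i) share one slot
    return i + j * (j + 1) // 2


def spr(alpha, x, ap):
    """Imitation of BLAS dspr (upper): ap += alpha * x x^T, in place."""
    k = len(x)
    for j in range(k):
        for i in range(j + 1):
            ap[packed_index(i, j)] += alpha * x[i] * x[j]


def axpy(alpha, x, y):
    """Imitation of BLAS daxpy: y += alpha * x, in place."""
    for i in range(len(x)):
        y[i] += alpha * x[i]


def unpack(ap, k):
    """Expand packed storage back to a full symmetric k x k matrix."""
    return [[ap[packed_index(i, j)] for j in range(k)] for i in range(k)]


# Accumulate A^T A and A^T b one row (one observed rating) at a time.
A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b = [1.0, 2.0, 3.0]
k = len(A[0])
ata = [0.0] * (k * (k + 1) // 2)
atb = [0.0] * k
for row, bi in zip(A, b):
    spr(1.0, row, ata)
    axpy(bi, row, atb)

print(ata)             # [2.0, 1.0, 2.0]
print(unpack(ata, k))  # [[2.0, 1.0], [1.0, 2.0]]
print(atb)             # [4.0, 5.0]
```

Packed storage halves the memory for each k × k system, which adds up when there is one such system per user or item.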

This is where we end our journey. We’ve made it quite far into Spark’s internals, and the next stage would be exploring LAPACK and BLAS themselves. If you continue, let me know where your journey takes you.

You're awesome for taking time out of your day to read this!