In this post I’m going to propose a new abstract operation on Spark RDDs – multiplexing – that makes some categories of operations on RDDs both easier to program and in many cases much faster.
My main working example will be the operation of splitting a collection of data elements into N randomly-selected subsamples. This operation is quite common in machine learning, for the purpose of dividing data into training and testing sets, or the related task of creating folds for cross-validation.
Consider the current standard RDD method for accomplishing this task, randomSplit(). This method takes a collection of N weights and returns N output RDDs, each of which contains a randomly-sampled subset of the input, proportional to the corresponding weight. The randomSplit() method generates the jth output by running a random number generator (RNG) for each input data element and accepting all elements that fall in the corresponding jth (normalized) weight range. As a diagram, the process looks like this at each RDD partition:
The observation I want to draw attention to is that producing the N output RDDs requires running a random sampling over every element in the input, once for each output. So if you are splitting into 10 outputs (e.g. for a 10-fold cross-validation), you are resampling your input 10 times, the only difference being that each output is created using a different acceptance range for the RNG output.
To see what this looks like in code, consider a simplified version of random splitting that just takes an integer n and always produces n equally-weighted outputs:
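A sketch of such a method (the name sampleNoMux and its exact shape are illustrative assumptions on my part, not a definitive implementation):

```scala
import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark.rdd.RDD

// Split 'rdd' into n equally-weighted random subsamples, the hard way:
// each of the n outputs re-samples every input element, keeping only
// the elements whose RNG draw lands in its own acceptance range.
def sampleNoMux[T: ClassTag](rdd: RDD[T], n: Int): Seq[RDD[T]] =
  Vector.tabulate(n) { j =>
    rdd.mapPartitionsWithIndex { (pid, elems) =>
      // Seed deterministically per partition, so that every one of the
      // n passes sees the same sequence of random draws.
      val rng = new Random(pid)
      elems.filter { _ => rng.nextInt(n) == j }
    }
  }
```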
(Note that for this method to operate correctly, the RNG seed must be set to the same value on each pass, or the data will not be correctly partitioned.)
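That requirement is just the standard property that a seeded RNG reproduces the same draw sequence, which is what keeps the n acceptance tests consistent from pass to pass:

```scala
import scala.util.Random

// Two RNGs seeded identically produce identical draw sequences, so each
// pass over a partition assigns every element to the same bucket.
val r1 = new Random(42L)
val r2 = new Random(42L)
val draws1 = Vector.fill(1000)(r1.nextInt(10))
val draws2 = Vector.fill(1000)(r2.nextInt(10))
assert(draws1 == draws2)
// Each draw selects exactly one of the 10 acceptance ranges, so the
// splits are disjoint and their union recovers the input exactly once.
assert(draws1.forall(d => d >= 0 && d < 10))
```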
While this approach to random splitting works fine, resampling the same data N times is somewhat wasteful. However, it is possible to reorganize the computation so that the input data is sampled only once. The idea is to run the RNG once per data element, and save the element into a randomly chosen collection. To make this work in the RDD compute model, all N output collections reside in a single row of an intermediate RDD – a “manifold” RDD. Each output RDD then takes its data from the corresponding collection in the manifold RDD, as in this diagram:
If you abstract the diagram above into a generalized operation, you end up with methods that might look like the following:
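Written as enrichment methods over an RDD, the general shape could look like the sketch below; the names muxPartitions and flatMuxPartitions follow the discussion, but the exact signatures are my guess at the shape, not the silex definitions:

```scala
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Enrichment that adds multiplexing methods to any RDD[T].
// (In real code an implicit class must be nested in an enclosing object.)
implicit class MuxRDDFunctions[T: ClassTag](self: RDD[T]) {

  // Compute n outputs in a single pass over each partition.  The
  // user-supplied f maps a partition's data to a sequence of n values;
  // each row of the internal "manifold" RDD holds one such sequence,
  // and the j-th output RDD projects out the j-th value of every row.
  def muxPartitions[U: ClassTag](
      n: Int,
      f: Iterator[T] => Seq[U],
      persist: StorageLevel = StorageLevel.MEMORY_ONLY): Seq[RDD[U]] = {
    val mux = self.mapPartitions(itr => Iterator.single(f(itr))).persist(persist)
    Vector.tabulate(n) { j =>
      mux.mapPartitions(itr => Iterator.single(itr.next()(j)))
    }
  }

  // As muxPartitions, but each of the n per-partition values is itself a
  // collection whose elements are flattened into the j-th output RDD.
  def flatMuxPartitions[U: ClassTag](
      n: Int,
      f: Iterator[T] => Seq[TraversableOnce[U]],
      persist: StorageLevel = StorageLevel.MEMORY_ONLY): Seq[RDD[U]] = {
    val mux = self.mapPartitions(itr => Iterator.single(f(itr))).persist(persist)
    Vector.tabulate(n) { j =>
      mux.mapPartitions(itr => itr.next()(j).toIterator)
    }
  }
}
```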
Here, the operation of sampling is generalized to any user-supplied function that maps RDD partition data into a sequence of objects that are computed in a single pass and then multiplexed to the final user-visible outputs. Note that these functions take a StorageLevel argument that can be used to control the caching level of the internal “manifold” RDD. This typically defaults to MEMORY_ONLY, so that the computation can be saved and reused for efficiency.
An efficient split-sampling method based on multiplexing, as described above, might be written using flatMuxPartitions as follows:
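For example (a sketch: I'm assuming flatMuxPartitions is available on RDD via an implicit enrichment, and the helper name sampleMux is mine):

```scala
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
import scala.util.Random

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Split 'rdd' into n equally-weighted random subsamples in one pass.
def sampleMux[T: ClassTag](rdd: RDD[T], n: Int): Seq[RDD[T]] =
  rdd.flatMuxPartitions(n, (elems: Iterator[T]) => {
    // Single pass: one RNG draw per element deposits it into one of the
    // n buckets; the buckets become one row of the manifold RDD.
    val rng = new Random()
    val buckets = Vector.fill(n) { ArrayBuffer.empty[T] }
    elems.foreach { e => buckets(rng.nextInt(n)) += e }
    buckets
  }, StorageLevel.MEMORY_ONLY)
```

Because the sampling pass happens only once and the manifold RDD is persisted, there is no need to synchronize RNG seeds across passes the way the non-multiplexed version requires.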
To test whether multiplexed RDDs actually improve compute efficiency, I collected runtime data at various split values of n (from 1 to 10), for both the non-multiplexed logic (equivalent to the standard randomSplit) and the multiplexed version:
As the timing data above show, the computation required to run the non-multiplexed version grows linearly with n, just as predicted. The multiplexed version, by computing all n outputs in a single pass, takes a nearly constant amount of time regardless of how many subsamples the input is split into.
There are other potential applications for multiplexed RDDs. Consider the following tuple-based versions of multiplexing:
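I'd expect the tuple flavors to have roughly the following shape, returning a pair of RDDs instead of a sequence (again a sketch of the general idea, not the exact silex signatures):

```scala
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Two-output multiplexing.  (As before, an implicit class must be
// nested in an enclosing object in real code.)
implicit class MuxRDDTupleFunctions[T: ClassTag](self: RDD[T]) {

  // f computes both outputs in one pass over a partition; the resulting
  // pair is one row of the manifold RDD, then projected into two RDDs.
  def mux2Partitions[U1: ClassTag, U2: ClassTag](
      f: Iterator[T] => (U1, U2),
      persist: StorageLevel = StorageLevel.MEMORY_ONLY): (RDD[U1], RDD[U2]) = {
    val mux = self.mapPartitions(itr => Iterator.single(f(itr))).persist(persist)
    (mux.mapPartitions(itr => Iterator.single(itr.next()._1)),
     mux.mapPartitions(itr => Iterator.single(itr.next()._2)))
  }

  // Flattening variant: each side of the pair is a collection whose
  // elements populate the corresponding output RDD.
  def flatMux2Partitions[U1: ClassTag, U2: ClassTag](
      f: Iterator[T] => (TraversableOnce[U1], TraversableOnce[U2]),
      persist: StorageLevel = StorageLevel.MEMORY_ONLY): (RDD[U1], RDD[U2]) = {
    val mux = self.mapPartitions(itr => Iterator.single(f(itr))).persist(persist)
    (mux.mapPartitions(itr => itr.next()._1.toIterator),
     mux.mapPartitions(itr => itr.next()._2.toIterator))
  }
}
```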
Suppose you wanted to run an input-validation filter on some data, sending the data that pass validation into one RDD, and the data that fail into a second RDD, paired with information about the error that occurred. Data validation is a potentially expensive operation. With multiplexing, you can easily write the filter to operate in a single efficient pass, obtaining both the stream of valid data and the stream of error data:
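A sketch of such a filter, assuming a two-output flatMux2Partitions enrichment in the spirit of the tuple-based methods above (the validator here is a hypothetical stand-in):

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.{Failure, Success, Try}

import org.apache.spark.rdd.RDD

// Stand-in for a real (possibly expensive) validation function.
def validate(record: String): Try[String] =
  Try { require(record.nonEmpty, "empty record"); record }

// One pass over each partition yields both the valid records and the
// rejected records paired with their errors.
def validateRDD(data: RDD[String]): (RDD[String], RDD[(String, Throwable)]) =
  data.flatMux2Partitions { elems =>
    val good = ArrayBuffer.empty[String]
    val bad  = ArrayBuffer.empty[(String, Throwable)]
    elems.foreach { rec =>
      validate(rec) match {
        case Success(r) => good += r
        case Failure(e) => bad += ((rec, e))
      }
    }
    (good, bad)
  }
```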
RDD multiplexing is currently a PR against the silex project. The code I used to run the timing experiments above is saved for posterity here.
Happy multiplexing!