In this post I’m going to propose a new abstract operation on Spark RDDs – multiplexing – that makes some categories of operations on RDDs both easier to program and in many cases much faster.
My main working example will be the operation of splitting a collection of data elements into N randomly-selected subsamples. This operation is quite common in machine learning, for the purpose of dividing data into a training and testing set (or the related task of creating folds for cross-validation).
Consider the current standard RDD method for accomplishing this task, randomSplit(). This method takes a collection of N weights, and returns N output RDDs, each of which contains a randomly-sampled subset of the input, proportional to the corresponding weight. The randomSplit() method generates the jth output by running a random number generator (RNG) for each input data element and accepting all elements which are in the corresponding jth (normalized) weight range. As a diagram, the process looks like this at each RDD partition:
The observation I want to draw attention to is that to produce the N output RDDs, it has to run a random sampling over every element in the input for each output. So if you are splitting into 10 outputs (e.g. for a 10-fold cross-validation), you are re-sampling your input 10 times, the only difference being that each output is created using a different acceptance range for the RNG output.
To see what this looks like in code, consider a simplified version of random splitting that just takes an integer n and always produces n equally-weighted outputs:
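As a stand-in that runs without a Spark cluster, the sketch below models an RDD as a sequence of partitions (`Seq[Seq[T]]`). The name `splitSampleNaive`, the partition model, and seeding each partition with `seed + partitionIndex` are assumptions for illustration. Output `j` keeps the elements whose draw lands in `[j/n, (j+1)/n)`, so producing all `n` outputs re-runs the RNG over every element `n` times.

```scala
import scala.util.Random

// Local model of an RDD: a sequence of partitions. This is not Spark's API.
// Output j re-scans every partition and keeps elements whose draw is in [j/n, (j+1)/n).
def splitSampleNaive[T](partitions: Seq[Seq[T]], n: Int, seed: Long): Seq[Seq[Seq[T]]] =
  (0 until n).map { j =>
    partitions.zipWithIndex.map { case (part, pIdx) =>
      // Re-seed identically for every output j, so all n passes see the same draws
      // and each element lands in exactly one output.
      val rng = new Random(seed + pIdx)
      part.filter { _ =>
        val x = rng.nextDouble()
        x >= j.toDouble / n && x < (j + 1).toDouble / n
      }
    }
  }
```

The cost is n full passes over the input, one per output, which is the inefficiency discussed next.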
(Note that for this method to operate correctly, the RNG seed must be set to the same value each time, or the data will not be correctly partitioned.)
While this approach to random splitting works fine, resampling the same data N times is somewhat wasteful. However, it is possible to re-organize the computation so that the input data is sampled only once. The idea is to run the RNG once per data element, and save the element into a randomly-chosen collection. To make this work in the RDD compute model, all N output collections reside in a single row of an intermediate RDD – a “manifold” RDD. Each output RDD then takes its data from the corresponding collection in the manifold RDD, as in this diagram:
If you abstract the diagram above into a generalized operation, you end up with methods that might look like the following:
Here, the operation of sampling is generalized to any user-supplied function that maps RDD partition data into a sequence of objects that are computed in a single pass, and then multiplexed to the final user-visible outputs. Note that these functions take a StorageLevel argument that can be used to control the caching level of the internal “manifold” RDD. This typically defaults to MEMORY_ONLY, so that the computation can be saved and re-used for efficiency.
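The shape of such an operation can be sketched locally, again modelling an RDD as a sequence of partitions. `flatMuxPartitionsLocal` is a hypothetical name. Its `manifold` value plays the role of the cached intermediate RDD, computed once, and a storage level has no local equivalent here.

```scala
// Local model of a multiplexing operation: f runs once per partition and returns
// n collections; output j gathers collection j from every partition.
def flatMuxPartitionsLocal[T, U](partitions: Seq[Seq[T]], n: Int)(
    f: Iterator[T] => Seq[Seq[U]]): Seq[Seq[Seq[U]]] = {
  // The "manifold": one row per partition holding all n outputs, computed once.
  val manifold: Seq[Seq[Seq[U]]] = partitions.map { part =>
    val row = f(part.iterator)
    require(row.size == n, s"expected $n outputs per partition, got ${row.size}")
    row
  }
  // Demultiplex: output j reads slot j of every manifold row.
  (0 until n).map(j => manifold.map(row => row(j)))
}
```

The user function is called exactly once per partition, however many outputs are requested.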
An efficient split-sampling method based on multiplexing, as described above, might be written using flatMuxPartitions as follows:
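The per-partition function that such a split would hand to a multiplexing operation can be sketched as follows. It assumes equal weights, and the name `splitPartition` is hypothetical. Each element costs one RNG draw and is appended to one of `n` buffers, so the work per element does not depend on `n`.

```scala
import scala.util.Random
import scala.collection.mutable.ArrayBuffer

// Single-pass split of one partition into n roughly equal random subsets.
// Each element costs exactly one RNG draw, regardless of n.
def splitPartition[T](part: Iterator[T], n: Int, rng: Random): Seq[Seq[T]] = {
  require(n > 0, "n must be positive")
  val buffers = Vector.fill(n)(ArrayBuffer.empty[T])
  part.foreach(x => buffers(rng.nextInt(n)) += x)
  buffers.map(_.toVector)
}
```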
To test whether multiplexed RDDs actually improve compute efficiency, I collected run-time data at various split values of n (from 1 to 10), for both the non-multiplexing logic (equivalent to the standard randomSplit) and the multiplexed version:
As the timing data above show, the computation required to run a non-multiplexed version grows linearly with n, just as predicted. The multiplexed version, by computing the n outputs in a single pass, takes a nearly constant amount of time regardless of how many samples the input is split into.
There are other potential applications for multiplexed RDDs. Consider the following tuple-based versions of multiplexing:
Suppose you wanted to run an input-validation filter on some data, sending the data that pass validation into one RDD, and data that failed into a second RDD, paired with information about the error that occurred. Data validation is a potentially expensive operation. With multiplexing, you can easily write the filter to operate in a single efficient pass to obtain both the valid stream and the stream of error-data:
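Here is a local sketch of that single-pass filter, with partitions modelled as `Seq[Seq[T]]`. It assumes validation is expressed as a function returning `Either[String, V]`, with an error message on the left and a validated value on the right. The name `validateSplit` is hypothetical.

```scala
// One pass over each partition: valid values go to the first output,
// failed inputs (paired with their error message) go to the second.
def validateSplit[T, V](partitions: Seq[Seq[T]])(
    validate: T => Either[String, V]): (Seq[Seq[V]], Seq[Seq[(T, String)]]) = {
  val manifold = partitions.map { part =>
    val good = Vector.newBuilder[V]
    val bad = Vector.newBuilder[(T, String)]
    part.foreach { x =>
      validate(x) match {
        case Right(v)  => good += v
        case Left(err) => bad += ((x, err))
      }
    }
    (good.result(), bad.result())
  }
  (manifold.map(_._1), manifold.map(_._2))
}
```

The potentially expensive `validate` runs once per element, and both output streams come from the same pass.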
I want to make an argument that the Algebird Aggregator design, in particular its use of the prepare operation in a map-reduce context, has substantial inefficiencies, compared to an equivalent formulation that is more directly suited to taking advantage of Scala’s aggregate method on collections.
Consider the definition of aggregation in the Aggregator class:
You can see that it is a standard map/reduce operation, where reduce is defined as a monoidal (or semigroup – more on this later) operation. Under the hood, it boils down to an invocation of Scala’s reduceLeft method. The key thing to notice is that the role of prepare is to map a collection of data elements into the required monoids, which are then aggregated using that monoid’s plus operation. In other words, prepare converts data elements into “singleton” monoids each representing a data element.
Now, if the monoid in question is simple, say some numeric type, this conversion is free, or nearly so. For example, the conversion of an integer into the “integer monoid” is a no-op. However, there are other kinds of “non-trivial” monoids, for which the conversion of a data element into its corresponding monoid may be costly. In this post, I will be using the monoid defined by Scala Set[Int], where the monoid plus operation is set union, and of course the zero element is the empty set.
Consider the process of defining an Algebird aggregator for the task of generating the set of unique elements in a data set. The corresponding prepare operation is: prepare(e: Int) = Set(e). A monoid trait that encodes this idea might look like the following (the code I used in this post can be found here):
// an algebird-like monoid with the 'prepare' operation
trait PreparedMonoid[M, E] {
  val zero: M
  def plus(m1: M, m2: M): M
  def prepare(e: E): M
}

// a PreparedMonoid for a set of integers. monoid operator is set union.
object intSetPrepared extends PreparedMonoid[Set[Int], Int] {
  val zero = Set.empty[Int]
  def plus(m1: Set[Int], m2: Set[Int]) = m1 ++ m2
  def prepare(e: Int) = Set(e)
}

implicit class SeqWithMapReduce[E](seq: Seq[E]) {
  // algebird map/reduce Aggregator model
  def mrPrepared[M](mon: PreparedMonoid[M, E]): M = {
    seq.map(mon.prepare).reduceLeft(mon.plus)
  }
}
If we unpack the above code, as applied to intSetPrepared, we are instantiating a new Set object, containing a single value, for every single input data element.
But there is a potentially better model of aggregation, exemplified by the Scala aggregate method. This method does not use a prepare operation. It uses a zero value and a monoidal operator, which the Scala docs refer to as combop, but it also uses an “update” operation that defines how to update the monoid object directly with a single element, referred to as seqop in Scala’s documentation. This idea can also be encoded as a flavor of monoid, enhanced with an update method:
// an algebird-like monoid with 'update' operation
trait UpdatedMonoid[M, E] {
  val zero: M
  def plus(m1: M, m2: M): M
  def update(m: M, e: E): M
}

// an equivalent UpdatedMonoid for a set of integers
object intSetUpdated extends UpdatedMonoid[Set[Int], Int] {
  val zero = Set.empty[Int]
  def plus(m1: Set[Int], m2: Set[Int]) = m1 ++ m2
  def update(m: Set[Int], e: Int) = m + e
}

implicit class SeqWithMapReduceUpdated[E](seq: Seq[E]) {
  // map/reduce logic, taking advantage of scala 'aggregate'
  def mrUpdatedAggregate[M](mon: UpdatedMonoid[M, E]): M = {
    seq.aggregate(mon.zero)(mon.update, mon.plus)
  }
}
This arrangement promises more efficiency when aggregating w.r.t. nontrivial monoids, by avoiding the construction of “singleton” monoids for each data element. The following demo confirms that for the Set-based monoid, it is over 10 times faster:
scala> :load /home/eje/scala/prepare.scala
Loading /home/eje/scala/prepare.scala...
defined module prepare

scala> import prepare._
import prepare._

scala> val data = Vector.fill(1000000) { scala.util.Random.nextInt(10) }
data: scala.collection.immutable.Vector[Int] = Vector(7, 9, 4, 2, 7, ...

// Verify that output is the same for both implementations:
scala> data.mrPrepared(intSetPrepared)
res0: Set[Int] = Set(0, 5, 1, 6, 9, 2, 7, 3, 8, 4)

// results are the same
scala> data.mrUpdatedAggregate(intSetUpdated)
res1: Set[Int] = Set(0, 5, 1, 6, 9, 2, 7, 3, 8, 4)

// Compare timings of prepare-based versus update-based aggregation
// (benchmark values are returned in seconds)
scala> benchmark(10) { data.mrPrepared(intSetPrepared) }
res2: Double = 0.2957673056

// update-based aggregation is 10 times faster
scala> benchmark(10) { data.mrUpdatedAggregate(intSetUpdated) }
res3: Double = 0.027041249300000004
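The `benchmark` helper used in the session lives in the loaded file and isn't listed. A minimal hypothetical version that returns mean wall-clock seconds per run might look like this:

```scala
// Hypothetical helper: evaluate `body` n times and return the mean
// wall-clock time per run, in seconds.
def benchmark[T](n: Int)(body: => T): Double = {
  require(n > 0, "n must be positive")
  var totalNanos = 0L
  var i = 0
  while (i < n) {
    val t0 = System.nanoTime()
    body
    totalNanos += System.nanoTime() - t0
    i += 1
  }
  totalNanos.toDouble / n / 1e9
}
```

This does no JVM warm-up, so early runs include JIT compilation time. A dedicated harness such as JMH gives more trustworthy numbers.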
It is also possible to apply Scala’s aggregate to a monoid enhanced with prepare:
implicit class SeqWithMapReducePrepared[E](seq: Seq[E]) {
  // using 'aggregate' with prepared op
  def mrPreparedAggregate[M](mon: PreparedMonoid[M, E]): M = {
    seq.aggregate(mon.zero)((m, e) => mon.plus(m, mon.prepare(e)), mon.plus)
  }
}
Although this turns out to be measurably faster than the literal map-reduce implementation, it is still not nearly as fast as the variation using update:
Readers familiar with Algebird may be wondering about my use of monoids above, when the Aggregator interface is actually based on semigroups. This is important, since building on Scala’s aggregate function requires a zero element that semigroups do not have. Although I believe it might be worth considering changing Aggregator to use monoids, another sensible option is to change the internal logic for the subclass AggregatorMonoid, which does require a monoid, or possibly just define a new AggregatorMonoidUpdated subclass.
A final note on compatibility: any monoid enhanced with prepare can be converted into an equivalent monoid enhanced with update, as demonstrated by this factory function:
object UpdatedMonoid {
  // create an UpdatedMonoid from a PreparedMonoid
  def apply[M, E](mon: PreparedMonoid[M, E]) = new UpdatedMonoid[M, E] {
    val zero = mon.zero
    def plus(m1: M, m2: M) = mon.plus(m1, m2)
    def update(m: M, e: E) = mon.plus(m, mon.prepare(e))
  }
}
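The self-contained check below uses minimal copies of the two monoid flavors. It writes the same conversion as a plain function, `toUpdated`, a hypothetical name chosen to avoid relying on companion objects. The check confirms that the converted monoid reduces to the same result as the prepare-based model. It uses `foldLeft` because `Seq.aggregate` is deprecated since Scala 2.13, and for a sequential collection, folding with the update op is the equivalent.

```scala
// Minimal copies of the two monoid flavors from this post.
trait PreparedMonoid[M, E] { val zero: M; def plus(m1: M, m2: M): M; def prepare(e: E): M }
trait UpdatedMonoid[M, E] { val zero: M; def plus(m1: M, m2: M): M; def update(m: M, e: E): M }

object intSetPrepared extends PreparedMonoid[Set[Int], Int] {
  val zero = Set.empty[Int]
  def plus(m1: Set[Int], m2: Set[Int]) = m1 ++ m2
  def prepare(e: Int) = Set(e)
}

// Same conversion as the factory above, written as a plain function.
def toUpdated[M, E](mon: PreparedMonoid[M, E]): UpdatedMonoid[M, E] = new UpdatedMonoid[M, E] {
  val zero = mon.zero
  def plus(m1: M, m2: M) = mon.plus(m1, m2)
  def update(m: M, e: E) = mon.plus(m, mon.prepare(e))
}

// Sequential update-based reduction (foldLeft instead of the deprecated aggregate).
def reduceUpdated[M, E](data: Seq[E], mon: UpdatedMonoid[M, E]): M =
  data.foldLeft(mon.zero)(mon.update)
```

A monoid obtained this way still builds a singleton per element inside `update`. It buys compatibility, not speed. The speedup in the timings above comes from a native update such as `m + e`.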
In this post I will demonstrate how to do reservoir sampling orders of magnitude faster than the traditional “naive” reservoir sampling algorithm, using a fast high-fidelity approximation to the reservoir sampling-gap distribution.
The code I used to collect the data for this post can be viewed here. I generated the plots using the quantifind WISP project.
Update (April 4, 2016): my colleague RJ Nowling ran across a paper by J.S. Vitter that shows Vitter developed the trick of accelerating sampling with a sampling-gap distribution in 1987 – I re-invented Vitter’s wheel 30 years after the fact! I’m surprised it never caught on, as it is not much harder to implement than the naive version.
In a previous post, I showed that random Bernoulli and Poisson sampling could be made much faster by modeling the sampling gap distribution for the corresponding sampling distributions. More recently, I began exploring whether reservoir sampling might also be optimized using the gap sampling technique, by deriving the reservoir sampling gap distribution. For a sampling reservoir of size R, starting at data element j, the probability distribution of the sampling gap is:
Modeling a sampling gap distribution is a powerful tool for optimizing a sampling algorithm, but it presupposes that you can actually draw values from that distribution substantially faster than just making a random draw for each data element. I was unable to come up with a “direct” algorithm for drawing samples from P(k) above (I suspect none exists). However, I also know the CDF F(k), so it is possible to apply inversion sampling, which runs in logarithmic time w.r.t. the desired accuracy. Although its logarithmic cost effectively guarantees that it will be a net efficiency win for sufficiently large j, it still involves a substantial number of computations to yield its samples, and it seems unlikely to be competitive with straight “naive” reservoir sampling over many real-world data sizes, where j may never grow very large.
Well, if exact computations are too expensive, we can always look for a fast approximation. Consider the original “first principles” formula for the sampling gap P(k):
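This formula can be reconstructed under two assumptions. First, element number i (counting from 1) enters the reservoir with probability R/i. Second, the gap k counts the elements skipped, starting at element j, before the next one is kept. Under those assumptions it can be written as follows. The approximation shown beside it replaces every j+i with j, and the exact notation may differ from the original:

```latex
P(k) = \frac{R}{j+k} \prod_{i=0}^{k-1} \left( 1 - \frac{R}{j+i} \right)
\;\approx\; \frac{R}{j} \left( 1 - \frac{R}{j} \right)^{k}
```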
As the figure above alludes to, if j is relatively large compared to k, then values j+1, j+2, …, j+k are all going to be effectively “close” to j, and so we can replace them all with j as an approximation. Note that the resulting approximation is just the PMF of the geometric distribution, with probability of success p = R/j, and we already saw how to efficiently draw values from a geometric distribution from our experience with Bernoulli sampling.
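A geometric draw by inversion takes one uniform variate and a couple of logarithms. If u is uniform on (0, 1], then g = floor(log(u) / log(1 - p)) satisfies P(g >= m) = (1 - p)^m, which is exactly the number of failures before the first success. The helper name `geometricGap` is hypothetical:

```scala
import scala.util.Random

// Number of failures before the first success of a Bernoulli(p) process,
// drawn by inversion: g = floor(log(u) / log(1 - p)), with u uniform on (0, 1].
def geometricGap(p: Double, rng: Random): Long = {
  require(p > 0.0 && p < 1.0, "p must be in (0, 1)")
  val u = 1.0 - rng.nextDouble() // nextDouble is in [0, 1), so u is in (0, 1]
  math.floor(math.log(u) / math.log1p(-p)).toLong
}
```

The mean of this distribution is (1 - p) / p, for example 9 when p = 0.1.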
Do we have any reason to hope that this approximation will be useful? For reasons similar to those for Bernoulli gap sampling, it will only be efficient to employ gap sampling when the probability R/j becomes small enough. From our experience with Bernoulli sampling, that means at least j >= 2R. So we have some assurance that j itself will never be very small. What about k? Note that a geometric distribution “favors” smaller values of k; that is, small values of k have the highest probabilities. In fact, the smaller j is, the larger the probability R/j is, and so the more likely it is that values of k that are small relative to j will be the frequent ones. It is also promising that the true distribution P(k) also favors smaller values of k (in fact it favors them even a bit more strongly than the approximation).
Although it is encouraging, it is also clear that my argument above is limited to heuristic hand-waving. What does this approximation really look like, compared to the true distribution? Fortunately, it is easy to plot both distributions numerically, since we now know the formulas for both:
The plot above shows that, in fact, the geometric approximation is a surprisingly good approximation to the true distribution! Furthermore, the approximation remains good as both (j) and (k) grow larger.
Our numeric eye-balling looks quite promising. Is there an effective way to measure how good this approximation is? One useful measure is the Kolmogorov-Smirnov D statistic, which is just the maximum absolute error between two cumulative distributions. Here is a plot of the D statistic for reservoir size R=10, as j varies across several magnitudes:
This plot is also good news: we can see that deviation, as measured by D, remains bounded at a small value (less than 0.0262). As this is for the specific value R=10, we also want to know how things change as reservoir size changes:
The news is still good! As reservoir size grows, the approximation only gets better: the D values get smaller as R increases, and remain asymptotically bounded as j increases.
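D can be computed numerically from the two survival functions. The sketch assumes that element number i is kept with probability R/i. Under that assumption, the exact survival is P(gap > k) = ∏ over i = 0..k of (1 - R/(j+i)). The geometric approximation gives (1 - R/j)^(k+1). Since F = 1 - S for both, the largest CDF gap equals the largest survival gap. The function name `gapD` and the truncation parameter `maxK` are assumptions:

```scala
// Kolmogorov-Smirnov D between the exact reservoir gap distribution and its
// geometric approximation, for reservoir size r starting at element j.
// maxK truncates the scan; choose it large enough that both tails are negligible.
def gapD(r: Int, j: Int, maxK: Int): Double = {
  require(r > 0 && j > r && maxK >= 0, "need r > 0, j > r and maxK >= 0")
  val p = r.toDouble / j
  var sExact = 1.0  // P(gap > k), exact
  var sApprox = 1.0 // P(gap > k), geometric approximation
  var d = 0.0
  var k = 0
  while (k <= maxK) {
    sExact *= 1.0 - r.toDouble / (j + k)
    sApprox *= 1.0 - p
    d = math.max(d, math.abs(sExact - sApprox))
    k += 1
  }
  d
}
```

For R = 10 the value settles a little above 0.026 as j grows, in line with the bound quoted above.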
Now we have some numeric assurance that the geometric approximation is a good one, and stays good as reservoir size grows and sampling runs get longer. However, we should also verify that an actual implementation of the approximation works as expected.
Here is pseudocode for an implementation of reservoir sampling using the fast geometric approximation:
// data is array to sample from
// R is the reservoir size
function reservoirFast(data: Array, R: Int) {
  n = data.length
  // Initialize reservoir with first R elements of data:
  res = data[0 until R]
  // Until this threshold, use traditional sampling. This value may
  // depend on performance characteristics of random number generation and/or
  // numeric libraries:
  t = 4 * R
  // j = number of elements seen so far = index of the next element to consider
  j = R
  while (j < n && j <= t) {
    k = randomInt(j + 1) // random integer >= 0 and <= j
    if (k < R) res[k] = data[j]
    j = j + 1
  }
  // Once gaps become significant, it pays to do gap sampling
  while (j < n) {
    // draw gap size (g) from geometric distribution with probability p = R/j
    p = R / j // floating-point division
    u = randomFloat() // random float > 0 and <= 1
    g = floor(log(u) / log(1-p))
    j = j + g
    if (j < n) {
      k = randomInt(R)
      res[k] = data[j]
    }
    j = j + 1
  }
  // return the reservoir
  return res
}
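Below is a runnable Scala sketch of this procedure, next to the classic per-element algorithm for comparison. It uses 0-based indexing, where `j` is both the number of elements seen so far and the index of the next candidate. The switch-over threshold `4 * r` is carried over from the pseudocode. Taking an `IndexedSeq` and a caller-supplied `scala.util.Random` are assumptions.

```scala
import scala.collection.mutable
import scala.util.Random

// Classic ("naive") reservoir sampling: one RNG draw per element after the first r.
def reservoirNaive[T](data: IndexedSeq[T], r: Int, rng: Random): Vector[T] = {
  require(r > 0, "reservoir size must be positive")
  val res: mutable.Buffer[T] = data.take(r).toBuffer
  var j = r
  while (j < data.length) {
    val k = rng.nextInt(j + 1) // data(j) is the (j+1)th element: keep with probability r/(j+1)
    if (k < r) res(k) = data(j)
    j += 1
  }
  res.toVector
}

// Classic steps until j passes the threshold, then geometric gap sampling with p = r/j.
def reservoirFast[T](data: IndexedSeq[T], r: Int, rng: Random): Vector[T] = {
  require(r > 0, "reservoir size must be positive")
  val n = data.length
  val res: mutable.Buffer[T] = data.take(r).toBuffer
  val t = 4L * r   // switch-over threshold
  var j = r.toLong // elements seen so far == index of next candidate
  while (j < n && j <= t) {
    val k = rng.nextInt(j.toInt + 1)
    if (k < r) res(k) = data(j.toInt)
    j += 1
  }
  while (j < n) {
    val p = r.toDouble / j                 // approximate acceptance probability
    val u = 1.0 - rng.nextDouble()         // uniform on (0, 1]
    val g = math.floor(math.log(u) / math.log1p(-p)).toLong // skipped elements
    j += g
    if (j < n) res(rng.nextInt(r)) = data(j.toInt) // accepted element replaces a random slot
    j += 1
  }
  res.toVector
}
```

Once past the threshold, the fast version makes two RNG draws per accepted element instead of one draw per element. That is where the speedup measured below comes from.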
Following is a plot that shows two-sample D statistics, comparing the distribution in sample gaps between runs of the exact “naive” reservoir sampling with the fast geometric approximation:
As expected, the measured differences in sampling characteristics between the naive algorithm and the fast approximation are small, confirming the numeric predictions.
Since the point of this exercise was to achieve faster random sampling, it remains to measure what kind of speed improvements the fast approximation provides. As a point of reference, here is a plot of run times for reservoir sampling over 10^8 integers:
As expected, sample time remains constant at around 1.5 seconds, regardless of reservoir size, since the naive algorithm always draws from its RNG once per data element.
Compare this to the corresponding plot for the fast geometric approximation:
Firstly, we see that the sampling times are much faster, as originally anticipated in my previous post – in the neighborhood of 3 orders of magnitude faster. Secondly, we see that the sampling times do increase as a linear function of reservoir size. Based on our experience with Bernoulli gap sampling, this is expected; the sampling probabilities are given by R/j, and therefore the amount of sampling is proportional to R.
Another property anticipated in my previous post was that the efficiency of gap sampling should continue to increase as the amount of data sampled grows; the sampling probability being R/j, the probability of sampling decreases as j gets larger, and so the corresponding gap sizes grow. The following plot verifies this property, holding reservoir size R constant, and increasing the data size:
The sampling time (per million elements) decreases as the data size grows, as predicted by the formula.
In conclusion, I have demonstrated that a geometric distribution can be used as a high quality approximation to the true sampling gap distribution for reservoir sampling, which allows reservoir sampling to be performed much faster than the naive algorithm while still retaining sampling quality.
In this post I am going to describe some work I’ve done recently on a system of Scala traits that support the tree-based collection algorithms of prefix sums, nearest-key queries and value increments in a mixable format, all backed by Red-Black balanced tree logic, which is also a fully inheritable trait.
(update) Since I wrote this post, the code has evolved into a library on the isarn project. The original source files, containing the exact code fragments discussed in the remainder of this post, are preserved for posterity here.
This post eventually became a bit more sprawling and “tl;dr” than I was expecting, so by way of apology, here is a table of contents with links:
The skeptical programmer may be wondering what the point of Yet Another Map Collection really is, much less an entire class hierarchy. The use case that inspired this work was my project of implementing the t-digest algorithm. Discussion of t-digest is beyond the scope of this post, but suffice it to say that constructing a t-digest requires the maintenance of a collection of “cluster” objects that needs to satisfy the following properties:
an entry contains one or more cluster objects at a given numeric location
entries are maintained in a numeric key order
entries will be frequently inserted and deleted, in arbitrary order
given a numeric key value, be able to find the entry nearest to that value
given a numeric key value, be able to compute the prefix sum of entry values up to that key
all of the above should be bounded by logarithmic time complexity
Properties 2, 3 and 6 are commonly satisfied by a map structure backed by some variety of balanced tree representation, of which the best-known is the Red-Black tree.
Properties 1, 4 and 5 are more interesting. Property 1, representing a collection of multiple objects at each entry, can be accomplished in a generalizable way by noting that a collection is representable as a monoid. Supporting values that can be incremented with respect to a user-supplied monoid satisfies property 1, and can also support many other kinds of update, including but not limited to classical numeric increments.
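The monoid idea can be sketched independently of the library. With a hypothetical `Monoid` type class, a single `increment` operation over an immutable `TreeMap` covers both classic numeric counting and accumulating a collection of objects at each key. The names here are illustrative, not the library's API:

```scala
import scala.collection.immutable.TreeMap

// Minimal monoid type class (hypothetical, for illustration only).
trait Monoid[V] {
  def zero: V
  def plus(a: V, b: V): V
}

// Increment the value stored at key k by dv with respect to the monoid;
// an absent key behaves as if it held the monoid's zero.
def increment[K, V](m: TreeMap[K, V], k: K, dv: V)(implicit mon: Monoid[V]): TreeMap[K, V] =
  m.updated(k, mon.plus(m.getOrElse(k, mon.zero), dv))

// Classic numeric increment: integer addition.
implicit val intSum: Monoid[Int] = new Monoid[Int] {
  def zero = 0
  def plus(a: Int, b: Int) = a + b
}

// A collection of objects per key: increment is concatenation.
implicit def vectorConcat[A]: Monoid[Vector[A]] = new Monoid[Vector[A]] {
  def zero = Vector.empty[A]
  def plus(a: Vector[A], b: Vector[A]) = a ++ b
}
```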
Properties 4 and 5 – nearest-entry queries and prefix-sum queries – are also both supportable in logarithmic time using a tree data structure, provided that tree is balanced. Again, the details of the algorithms are out of the current scope, however they are not extremely complex, and their implementations are available in the code.
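This hypothetical sketch shows why both queries need only one root-to-leaf walk when each node caches the sum of its subtree. It uses an unbalanced binary search tree with `Double` keys and `Int` values, with addition as the monoid, so its cost is O(height). The library described here relies on Red-Black balancing to make that O(log n):

```scala
// Unbalanced BST whose nodes cache the sum of the values in their subtree.
sealed trait PSTree { def sum: Int }
case object PSLeaf extends PSTree { val sum = 0 }
case class PSNode(key: Double, value: Int, left: PSTree, right: PSTree) extends PSTree {
  val sum: Int = left.sum + value + right.sum // cached subtree sum
}

object PSTreeOps {
  // Insert a key, or increment the value of an existing key.
  def insert(t: PSTree, k: Double, v: Int): PSTree = t match {
    case PSLeaf => PSNode(k, v, PSLeaf, PSLeaf)
    case n @ PSNode(nk, nv, l, r) =>
      if (k < nk) n.copy(left = insert(l, k, v))
      else if (k > nk) n.copy(right = insert(r, k, v))
      else n.copy(value = nv + v)
  }

  // Prefix sum: total of values for all keys <= k.
  def prefixSum(t: PSTree, k: Double): Int = t match {
    case PSLeaf => 0
    case PSNode(nk, nv, l, r) =>
      if (k < nk) prefixSum(l, k)
      else l.sum + nv + prefixSum(r, k) // whole left subtree and this node are <= k
  }

  // Nearest-entry query: the closest key is k's predecessor or successor,
  // and both lie on the search path for k.
  def nearest(t: PSTree, k: Double): Option[Double] = {
    @annotation.tailrec
    def go(t: PSTree, best: Option[Double]): Option[Double] = t match {
      case PSLeaf => best
      case PSNode(nk, _, l, r) =>
        val b = best match {
          case Some(bk) if math.abs(bk - k) <= math.abs(nk - k) => best
          case _ => Some(nk)
        }
        if (k < nk) go(l, b) else if (k > nk) go(r, b) else Some(nk)
    }
    go(t, None)
  }
}
```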
A reader with their software engineering hat on will notice that these properties are orthogonal. A programmer might be interested in a data structure supporting any one of them, or in some mixed combination. This kind of situation fairly shouts “Scala traits” (or, alternatively, interfaces in Java, etc.). With that idea in mind, I designed a system of Scala collection traits that support all of the above properties, in a pure trait form that is fully “mixable” by the programmer, so that one can use exactly the properties needed, but not pay for anything else.
Library Overview
The library consists broadly of 3 kinds of traits:
tree node traits – implement core tree support for some functionality
collection traits – provide additional collection API methods for the user
collections – instantiate a usable incarnation of a collection
For the programmer who wishes to either create a trait mixture, or add new mixable traits, the collections also function as reference implementations.
The three tables that follow summarize the currently available traits of each kind listed above. They are (at the time of this posting) all under the package namespace com.redhat.et.silex.maps:
Tree Node Traits
| trait | sub-package | description |
|---|---|---|
| Node[K] | redblack.tree | Fundamental Red-Black tree functionality |
| NodeMap[K,V] | ordered.tree | Support a mapping from keys to values |
| NodeNear[K] | nearest.tree | Nearest-entry query (key-only) |
| NodeNearMap[K,V] | nearest.tree | Nearest-entry query for key/value maps |
| NodeInc[K,V] | increment.tree | Increment values w.r.t. a monoid |
| NodePS[K,V,P] | prefixsum.tree | Prefix sum queries by key (w.r.t. a monoid) |
Collection Traits
| trait | sub-package | description |
|---|---|---|
| OrderedSetLike[K,IN,M] | ordered | ordered set of keys |
| OrderedMapLike[K,V,IN,M] | ordered | ordered key/value map |
| NearestSetLike[K,IN,M] | nearest | nearest entry query on keys |
| NearestMapLike[K,V,IN,M] | nearest | nearest entry query on key/value map |
| IncrementMapLike[K,V,IN,M] | increment | increment values w.r.t. a monoid |
| PrefixSumMapLike[K,V,P,IN,M] | prefixsum | prefix sum queries w.r.t. a monoid |
Concrete Collections
| trait | sub-package | description |
|---|---|---|
| OrderedSet[K] | ordered | ordered set |
| OrderedMap[K,V] | ordered | ordered key/value map |
| NearestSet[K] | nearest | ordered set with nearest-entry query |
| NearestMap[K,V] | nearest | ordered map with nearest-entry query |
| IncrementMap[K,V] | increment | ordered map with value increment w.r.t. a monoid |
| PrefixSumMap[K,V,P] | prefixsum | ordered map with prefix sum query w.r.t. a monoid |
The following diagram summarizes the organization and inheritance relationships of the classes.
A Red/Black Tree Base Class
The most fundamental trait in this hierarchy is the trait that embodies Red-Black balancing; a “red-black-ness” trait, as it were. This trait supplies the axiomatic tree operations of insertion, deletion and key lookup, where the Red-Black balancing operations are encapsulated for insertion (due to Chris Okasaki) and deletion (due to Stefan Kahrs). Note that Red-Black trees do not assume a separate value, as in a map, but require only keys (thus implementing an ordered set over the key type):
object tree {
  /** The color (red or black) of a node in a Red/Black tree */
  sealed trait Color
  case object R extends Color
  case object B extends Color

  /** Defines the data payload of a tree node */
  trait Data[K] {
    /** The axiomatic unit of data for R/B trees is a key */
    val key: K
  }

  /** Base class of a Red/Black tree node
    * @tparam K The key type
    */
  trait Node[K] {
    /** The ordering that is applied to key values */
    val keyOrdering: Ordering[K]

    /** Instantiate an internal node. */
    protected def iNode(color: Color, d: Data[K], lsub: Node[K], rsub: Node[K]): INode[K]

    // ... declarations for insertion, deletion and key lookup ...

    // ... red-black balancing rules ...
  }

  /** Represents a leaf node in the Red Black tree system */
  trait LNode[K] extends Node[K] {
    // ... basis case insertion, deletion, lookup ...
  }

  /** Represents an internal node (Red or Black) in the Red Black tree system */
  trait INode[K] extends Node[K] {
    /** The Red/Black color of this node */
    val color: Color
    /** Including, but not limited to, the key */
    val data: Data[K]
    /** The left sub-tree */
    val lsub: Node[K]
    /** The right sub-tree */
    val rsub: Node[K]

    // ... implementations for insertion, deletion, lookup ...
  }
}
I will assume most readers are familiar with basic binary tree operations, and the Red-Black rules are described elsewhere (I adapted them from the Scala red-black implementation). For the purposes of this discussion, the most interesting feature is that this is a pure Scala trait. All val declarations are abstract. This trait, by itself, cannot function without a subclass to eventually perform dependency injection. However, this abstraction allows the trait to be inherited freely – any programmer can inherit from this trait and get a basic Red-Black balanced tree for (nearly) free, as long as a few basic principles are adhered to for proper dependency injection.
Another detail to call out is the abstraction of the usual key with a Data element. This element represents any node payload that is moved around as a unit during tree structure manipulations, such as balancing pivots. In the case of a map-like subclass, Data is extended to include a value field as well as a key field.
The other noteworthy detail is the abstract definition def iNode(color: Color, d: Data[K], lsub: Node[K], rsub: Node[K]): INode[K] - this is the function called to create any new tree node. In fact, this function, when eventually instantiated, is what performs dependency injection of other tree node fields.
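The injection idea can be made concrete with a toy version of the same pattern. It assumes an unbalanced tree of `Int` keys with no Red-Black logic, and all `Toy*` names are hypothetical. The abstract traits only ever create nodes through `iNode`. The one concrete class does not itself extend the node trait. It supplies `iNode` by returning another instance of itself mixed with the internal-node trait, and its concrete method implements the abstract one when the two are mixed:

```scala
trait ToyNode {
  // Abstract factory: the only way new internal nodes get created.
  protected def iNode(key: Int, lsub: ToyNode, rsub: ToyNode): ToyINode
  def insert(k: Int): ToyNode
  def contains(k: Int): Boolean
}

trait ToyLeaf extends ToyNode {
  def insert(k: Int): ToyNode = iNode(k, this, this)
  def contains(k: Int): Boolean = false
}

trait ToyINode extends ToyNode {
  // Abstract vals: only the injecting class gives them values.
  val key: Int
  val lsub: ToyNode
  val rsub: ToyNode

  def insert(k: Int): ToyNode =
    if (k < key) iNode(key, lsub.insert(k), rsub)
    else if (k > key) iNode(key, lsub, rsub.insert(k))
    else this
  def contains(k: Int): Boolean =
    if (k < key) lsub.contains(k) else if (k > key) rsub.contains(k) else true
}

// The bona fide class: every node it creates is again a ToyInject,
// so later insertions keep going through this same iNode.
class ToyInject {
  def iNode(key0: Int, lsub0: ToyNode, rsub0: ToyNode): ToyINode =
    new ToyInject with ToyINode {
      val key = key0
      val lsub = lsub0
      val rsub = rsub0
    }
}

object ToyTree {
  def empty: ToyNode = new ToyInject with ToyLeaf
}
```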
Node Inheritance Example: NodeMap[K,V]
A relatively simple example of node inheritance is hopefully instructive. Here is the definition for tree nodes supporting a key/value map:
object tree {
  /** Trees that back a map-like object have a value as well as a key */
  trait DataMap[K, V] extends Data[K] {
    val value: V
  }

  /** Base class of ordered K/V tree node
    * @tparam K The key type
    * @tparam V The value type
    */
  trait NodeMap[K, V] extends Node[K]

  trait LNodeMap[K, V] extends NodeMap[K, V] with LNode[K]

  trait INodeMap[K, V] extends NodeMap[K, V] with INode[K] {
    val data: DataMap[K, V]
  }
}
Note that in this case very little is added to the red/black functionality already provided by Node[K]. A DataMap[K,V] trait is defined to add a value field in addition to the key, and the internal node INodeMap[K,V] refines the type of its data field to be DataMap[K,V]. The semantics is little more than “tree nodes now carry a value in addition to a key.”
A tree node trait inherits from its own parent class and the corresponding traits for any mixed-in functionality. So for example INodeMap[K,V] inherits from NodeMap[K,V] but also INode[K].
Continuing with the ordered map example, here is the definition of the collection trait for an ordered map:
trait OrderedMapLike[K, V, IN <: INodeMap[K, V], M <: OrderedMapLike[K, V, IN, M]]
    extends NodeMap[K, V] with OrderedLike[K, IN, M] {

  /** Obtain a new map with a (key, val) pair inserted */
  def +(kv: (K, V)) = this.insert(
    new DataMap[K, V] {
      val key = kv._1
      val value = kv._2
    }).asInstanceOf[M]

  /** Get the value stored at a key, or None if key is not present */
  def get(k: K) = this.getNode(k).map(_.data.value)

  /** Iterator over (key,val) pairs, in key order */
  def iterator = nodesIterator.map(n => ((n.data.key, n.data.value)))

  /** Container of values, in key order */
  def values = valuesIterator.toIterable

  /** Iterator over values, in key order */
  def valuesIterator = nodesIterator.map(_.data.value)
}
You can see that this trait supplies collection API methods that a Scala programmer will recognize as being standard for any map-like collection. Note that this trait also inherits other standard methods from OrderedLike[K,IN,M] (common to both sets and maps) and also inherits from NodeMap[K,V]: In other words, a collection is effectively yet another kind of tree node, with additional collection API methods mixed in. Note also the use of “self types” (the type parameter M), which allows the collection to return objects of its own kind. This is crucial for allowing operations like data insertion to return an object that also supports node insertion, and to maintain consistency of type across operations. The collection type is properly “closed” with respect to its own operations.
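The role of the self-type parameter M can be seen in isolation in this hypothetical sketch. The trait above casts with `asInstanceOf[M]`. This version instead asks the concrete type for an abstract factory method, an assumption made to keep the example small and cast-free:

```scala
// A trait parameterized by its own concrete type M, so operations return M.
trait CountLike[M <: CountLike[M]] {
  def count: Int
  protected def make(count: Int): M
  def bump: M = make(count + 1) // returns the concrete type, not CountLike
}

case class Counter(count: Int) extends CountLike[Counter] {
  protected def make(count: Int): Counter = Counter(count)
  def describe: String = s"Counter($count)" // only available on Counter
}
```

Because `bump` returns `Counter`, calls such as `Counter(0).bump.describe` type-check. The result stays inside the concrete type, just as inserting into an `OrderedMapLike` yields another `M`.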
Collection Example: OrderedMap[K,V]
To conclude the ordered map example, consider the task of defining a concrete instantiation of an ordered map:
You can see that (aside from a convenience override of toString) the trait OrderedMap[K,V] is nothing more than a vehicle for instantiating a particular concrete OrderedMapLike[K,V,IN,M] subtype, with particular concrete types for internal node (INodeMap[K,V]) and its own self-type.
Things become a little more interesting inside the companion object OrderedMap:
Note that the object returned by the factory method is upcast to OrderedMap[K,V], but in fact has the more complicated type: InjectMap[K,V] with LNodeMap[K,V] with OrderedMap[K,V]. There are a couple things going on here. The trait LNodeMap[K,V] ensures that the new object is in particular a leaf node, which embodies a new empty tree in the Red-Black tree system.
The type InjectMap[K,V] has an even more interesting purpose. Here is its definition:
Firstly, note that it is a bona fide class, as opposed to a trait. This class is where, finally, all things abstract are made real – “dependency injection” in the parlance of Scala idioms. You can see that it defines the implementation of abstract method iNode, and that it does this by returning yet another InjectMap[K,V] object, mixed with both INodeMap[K,V] and OrderedMap[K,V], thus maintaining closure with respect to all three slices of functionality: dependency injection, the proper type of internal node, and map collection methods.
The various abstract val fields color, data, lsub and rsub are all given concrete values inside of iNode. Here is where the value of concrete “reference” implementations manifests. Any fields in the relevant internal-node type must be instantiated here, and the logic of instantiation cannot be inherited while still preserving the ability to mix abstract traits. Therefore, any programmer wishing to create a new concrete sub-class must replicate the logic for instantiating all inherited fields in an internal node.
Another example makes the implications more clear. Here is the definition of injection for a collection that mixes in all three traits for incrementable values, nearest-key queries, and prefix-sum queries:
Here you can see that all logic for both “basic” internal nodes and also for maintaining prefix sums, and key min/max information for nearest-entry queries, must be supplied. If there is a singularity in this design, here is where it is. The saving grace is that it is localized into a single well-defined place, and any logic can be transcribed from a proper reference implementation of whatever traits are being mixed.
Finale: Trait Mixing
I will conclude by showing the code for mixing tree node traits and collection traits, which is elegant. Here are type definitions for tree nodes and collection traits that inherit from incrementable values, nearest-key queries, and prefix-sum queries, and there is almost no code except the proper inheritances:
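The same mixing idea in miniature: each capability lives in its own trait over a shared base, and a concrete type picks exactly the ones it needs. This toy uses hypothetical names and linear-time scans over a sorted `Vector` in place of the tree algorithms:

```scala
// Shared base: entries sorted by key.
trait KVBase {
  def entries: Vector[(Double, Int)]
}

// Capability: nearest-entry query.
trait NearestOps extends KVBase {
  def nearestKey(k: Double): Option[Double] =
    if (entries.isEmpty) None
    else Some(entries.minBy { case (key, _) => math.abs(key - k) }._1)
}

// Capability: prefix-sum query.
trait PrefixSumOps extends KVBase {
  def prefixSum(k: Double): Int = entries.takeWhile(_._1 <= k).map(_._2).sum
}

// Concrete types mix exactly the capabilities they need.
case class NearPSMap(entries: Vector[(Double, Int)]) extends NearestOps with PrefixSumOps
case class NearOnlyMap(entries: Vector[(Double, Int)]) extends NearestOps
```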
In this post I want to discuss several advantages of defining lightweight non-negative numeric types in Scala, whose primary benefit is that they allow improved type signatures for Scala functions and methods. I’ll first describe the simple class definition, and then demonstrate how it can be used in function signatures and the benefits of doing so.
If the following ideas interest you at all, I highly recommend looking at the ‘refined’ project authored by Frank S. Thomas, which generalizes the ideas below and supports additional static checking functionality via macros.
A Non-Negative Integer Type
As a working example, I’ll discuss a non-negative integer type NonNegInt. My proposed definition is sufficiently lightweight to view as a single code block:
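A minimal sketch of such a definition follows, consistent with the properties listed next. It assumes an enclosing object named nonneg (the usage example imports nonneg._), and the conversion names intToNonNegInt and nonNegIntToInt are hypothetical:

```scala
import scala.language.implicitConversions

object nonneg {
  // Value class around Int: the private constructor means the only ways to
  // obtain a NonNegInt are the checked paths defined below.
  final class NonNegInt private (val value: Int) extends AnyVal {
    override def toString: String = value.toString
  }

  object NonNegInt {
    // Factory method: the companion may call the private constructor
    def apply(v: Int): NonNegInt = {
      require(v >= 0, s"NonNegInt requires v >= 0, got $v")
      new NonNegInt(v)
    }
  }

  // Int => NonNegInt, checked by routing through the factory method
  implicit def intToNonNegInt(v: Int): NonNegInt = NonNegInt(v)

  // NonNegInt => Int, always safe
  implicit def nonNegIntToInt(n: NonNegInt): Int = n.value
}
```

Routing the implicit conversion through the factory method keeps the negativity check in exactly one place.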
The notable properties and features of NonNegInt are:
NonNegInt is a value class around an Int, and so invokes no actual object construction or allocation
Its constructor is private, and so is safe from directly constructing around a negative integer
It supplies factory method NonNegInt(v) to construct a non-negative integer value
It supplies implicit conversion from Int values to NonNegInt
Both factory method and implicit conversion check for negative values. There is no way to construct a NonNegInt that contains a negative integer value.
It also supplies implicit conversion from NonNegInt back to Int. Moving back and forth between Int and NonNegInt is effectively transparent.
The above properties work to make NonNegInt very lightweight with respect to size and runtime properties, and semantically safe in the sense that it is impossible to construct one with a negative value inside it.
Application of NonNegInt
I primarily envision NonNegInt as an easy and informative way to declare function parameters that are only well defined for non-negative values, without the need to write any explicit checking code, and yet allowing the programmer to call the function with normal Int values, due to the implicit conversions:
object example {
import nonneg._
def element[T](seq: Seq[T], j: NonNegInt) = seq(j)
// call element function with a regular Int index
val e = element(Vector(1, 2, 3), 1)
// e is set to 2
}
This short example demonstrates some appealing properties of NonNegInt. Firstly, the constraint that index j >= 0 is enforced via the type definition, and so the programmer does not have to write the usual require(j >= 0, ...) check (or worry about forgetting it). Secondly, the implicit conversion from Int to NonNegInt means the programmer can just provide a regular integer value for parameter j, instead of having to explicitly say NonNegInt(1). Third, the implicit conversion from NonNegInt to Int means that j can easily be used anywhere a regular Int is used. Last, and very definitely not least, the fact that function element requires a non-negative integer is obvious right in the function signature. There is no need for a programmer to guess whether j can be negative, and no need for the author of element to document that j cannot be negative. Its type makes that completely clear.
Conclusions
In this post I’ve laid out some advantages of defining lightweight non-negative numeric types, in particular using NonNegInt as a working example. Clearly, if you want to apply this idea, you’d want to also define NonNegLong, NonNegDouble, NonNegFloat and for that matter PosInt, PosLong, etc. Happy computing!
In a previous post, I showed that random Bernoulli and Poisson sampling could be made much faster by modeling the sampling gap distribution: that is, directly drawing random samples from the distribution of how many elements would be skipped over between actual samples taken.
Another popular sampling algorithm is Reservoir Sampling. Its sampling logic is a bit more complicated than Bernoulli or Poisson sampling, in the sense that the probability of sampling any given (jth) element changes. For a sampling reservoir of size R, and all j>R, the probability of choosing element (j) is R/j. You can see that the potential payoff for gap-sampling is big, particularly as data size becomes large; as (j) approaches infinity, the probability R/j goes to zero, and the corresponding gaps between samples grow without bound.
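For reference, the per-element logic being described is the classic reservoir sampling loop (often called "Algorithm R"). This is a minimal sketch; the object and method names are hypothetical. Note that it calls the RNG once for every element past the first R, which is exactly the cost gap sampling would try to avoid:

```scala
import scala.util.Random
import scala.collection.mutable.ArrayBuffer

object ReservoirSampling {
  // Keep the first R elements; afterwards accept the j-th element (1-based,
  // j > R) with probability R/j, overwriting a uniformly chosen slot.
  def sample[A](data: Iterator[A], R: Int, rng: Random = new Random()): Vector[A] = {
    require(R > 0, "reservoir size must be positive")
    val reservoir = ArrayBuffer.empty[A]
    var j = 0
    for (e <- data) {
      j += 1
      if (j <= R) reservoir += e
      else {
        val slot = rng.nextInt(j)          // uniform on [0, j)
        if (slot < R) reservoir(slot) = e  // happens with probability R/j
      }
    }
    reservoir.toVector
  }
}
```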
Modeling a sampling gap distribution is a powerful tool for optimizing a sampling algorithm, but it requires that (1) you actually know the sampling distribution, and (2) that you can effectively draw values from that distribution faster than just applying a random process to drawing each data element.
With that goal in mind, I derived the probability mass function (pmf) and cumulative distribution function (cdf) for the sampling gap distribution of reservoir sampling. In this post I will show the derivations.
The Sampling Gap Distribution
In the interest of making it easy to get at the actual answers, here are the pmf and cdf for the Reservoir Sampling Gap Distribution. For a sampling reservoir of size (R), starting at data element (j), the probability distribution of the sampling gap is:
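Working from the conventions and the derivation steps in this post, the closed forms can be written as follows (a sketch of the results; the notation of the original figures may differ, for example by using binomial coefficients):

```latex
P(k) = \frac{R}{j+k}\prod_{i=j}^{j+k-1}\frac{i-R}{i}
     = \frac{R\,(j-1)!\,(j+k-1-R)!}{(j-1-R)!\,(j+k)!}, \qquad k \ge 0

F(k) = \sum_{t=0}^{k} P(t)
     = 1 - \frac{(j-1)!\,(j+k-R)!}{(j-1-R)!\,(j+k)!}
```

As a quick check, both give P(0) = F(0) = R/j, the probability of sampling element j itself.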
Conventions
In the derivations that follow, I will keep to some conventions:
R = the sampling reservoir size. R > 0.
j = the index of a data element being considered for sampling. j > R.
k = the size of a gap between samples. k >= 0.
P(k) is the probability that the gap between one sample and the next is of size k. The support for P(k) is over all k>=0. I will generally assume that j>R, as the first R samples are always loaded into the reservoir and the actual random sampling logic starts at j=R+1. The constraint j>R will also be relevant to many binomial coefficient expressions, where it ensures the coefficient is well defined.
Deriving the Probability Mass Function, P(k)
Suppose we just chose (randomly) to sample data element (j-1). Now we are interested in the probability distribution of the next sampling gap. That is, the probability P(k) that we will not sample the next (k) elements {j,j+1,…j+k-1}, and sample element (j+k):
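Each skipped element i (with i > R) is passed over with probability 1 - R/i = (i-R)/i, and element j+k is then sampled with probability R/(j+k). Multiplying these, with the product terms written in descending order (for k = 0 the product is empty and equals 1):

```latex
P(k) = \frac{(j+k-1-R)(j+k-2-R)\cdots(j-R)}{(j+k-1)(j+k-2)\cdots j}\cdot\frac{R}{j+k}
```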
By arranging the product terms in descending order as above, you can see that they can be written as factorial quotients:
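Each descending run of consecutive factors, such as (j+k-1-R)(j+k-2-R)···(j-R), is a ratio of two factorials, so P(k) can be written as:

```latex
P(k) = \frac{(j+k-1-R)!\,/\,(j-1-R)!}{(j+k-1)!\,/\,(j-1)!}\cdot\frac{R}{j+k}
```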
Now we apply Lemma A. The 2nd case (a<=b) of the Lemma applies, since (j-1-R)<=j, so we have:
And so we have now derived a compact, closed-form expression for P(k).
Deriving the Cumulative Distribution Function, F(k)
Now that we have a derivation for the pmf P(k), we can tackle a derivation for the cdf. First I will make note of this useful identity that I scraped off of Wikipedia (I substituted (x) => (a) and (k) => (b)):
The cumulative distribution function for the sampling gap, F(k), is of course just the sum over P(t), for (t) from 0 up to (k):
This is a closed-form solution, but we can apply a bit more simplification:
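An independent sanity check on the simplified cdf (a different route from the summation above, not the derivation used in this post): the gap is at most k exactly when at least one of the elements j, ..., j+k is sampled, so F(k) is one minus the probability of skipping all of them:

```latex
F(k) = 1 - \prod_{i=j}^{j+k}\frac{i-R}{i}
     = 1 - \frac{(j-1)!\,(j+k-R)!}{(j-1-R)!\,(j+k)!}
     = 1 - \binom{j-1}{R}\Big/\binom{j+k}{R}
```

The condition j > R guarantees j-1 >= R, so the binomial coefficient in the numerator is well defined.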
Conclusions
We have derived closed-form expressions for the pmf and cdf of the Reservoir Sampling gap distribution:
In order to apply these results to a practical gap-sampling implementation of Reservoir Sampling, we would next need a way to efficiently sample from P(k), to obtain gap sizes to skip over. How to accomplish this is an open question, but knowing a formula for P(k) and F(k) is a start.
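One simple, though not efficient, way to draw from P(k) is sequential inversion of the cdf: find the smallest k with F(k) >= u for a uniform u. This sketch tracks the survival probability S(k) = 1 - F(k) multiplicatively, using S(k) = S(k-1)(j+k-R)/(j+k), which follows from the product form of P(k). All names are hypothetical. Its cost is O(k) per draw, so by itself it does not deliver the speedup that gap sampling is after:

```scala
import scala.util.Random

object ReservoirGap {
  // S(k) = prod_{i=j}^{j+k} (i-R)/i: probability that none of j..j+k is sampled
  def survival(R: Int, j: Long, k: Long): Double = {
    var s = 1.0
    var i = j
    while (i <= j + k) { s *= (i - R).toDouble / i; i += 1 }
    s
  }

  // pmf: skip elements j..j+k-1 (probability S(k-1), or 1 if k == 0), then sample j+k
  def pmf(R: Int, j: Long, k: Long): Double =
    (if (k == 0) 1.0 else survival(R, j, k - 1)) * R / (j + k)

  // cdf: F(k) = 1 - S(k)
  def cdf(R: Int, j: Long, k: Long): Double = 1.0 - survival(R, j, k)

  // Smallest k with F(k) >= u, i.e. S(k) <= 1 - u
  def sampleGap(R: Int, j: Long, rng: Random): Long = {
    require(R > 0 && j > R)
    val v = 1.0 - rng.nextDouble()  // uniform on (0, 1]
    var k = 0L
    var s = (j - R).toDouble / j    // S(0)
    while (s > v) {
      k += 1
      s *= (j + k - R).toDouble / (j + k)  // S(k) from S(k-1)
    }
    k
  }
}
```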
Acknowledgements
Many thanks to RJ Nowling and Will Benton for proof reading and moral support! Any remaining errors are my own fault.
Recently I have been applying Kendall’s Tau as an evaluation metric to assess how well a regression model ranks input samples, with respect to a known correct ranking.
The process of implementing the Kendall’s Tau statistic, with my software engineer’s hat on, caused me to reflect a bit on how it could be generalized beyond the traditional application of ranking numeric pairs. In this post I’ll discuss the generalization of Kendall’s Tau to non-numeric data, and also generalizing from totally ordered data to partial orderings.
A Review of Kendall’s Tau
I’ll start with a brief review of Kendall’s Tau. For more depth, a good place to start is the Wikipedia article at the link above.
Consider a sequence of (n) observations where each observation is a pair (x,y), where we wish to measure how well a ranking by x-values agrees with a ranking by the y-values. Informally, Kendall’s Tau (aka the Kendall Rank Correlation Coefficient) is the difference between the number of observation-pairs (pairs of pairs, if you will) whose ordering agrees (“concordant” pairs) and the number of such pairs whose ordering disagrees (“discordant” pairs). This difference is divided by the total number of observation pairs.
The commonly-used formulation of Kendall’s Tau is the “Tau-B” statistic, which accounts for observed pairs having tied values in either x or y as being neither concordant nor discordant:
Figure 1: Kendall’s Tau-B
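In the standard notation (which may differ from the original figure), with t_i the size of the i-th group of tied x-values and u_j the size of the j-th group of tied y-values, Tau-B is:

```latex
\tau_B = \frac{n_c - n_d}{\sqrt{(n_0 - n_1)(n_0 - n_2)}}, \qquad
n_0 = \frac{n(n-1)}{2}, \quad
n_1 = \sum_i \frac{t_i(t_i-1)}{2}, \quad
n_2 = \sum_j \frac{u_j(u_j-1)}{2}
```

Here n_c and n_d count concordant and discordant pairs, while n_1 and n_2 count the pairs tied in x and in y respectively.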
The formulation above has quadratic complexity, with respect to data size (n). It is possible to rearrange this computation in a way that can be computed in (n)log(n) time[1]:
Figure 2: An (n)log(n) formulation of Kendall’s Tau-B
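The standard form of this rearrangement, as given in the Wikipedia article following Knight [1] (the original figure's notation may differ), is:

```latex
\tau_B = \frac{n_0 - n_1 - n_2 + n_3 - 2\,n_s}{\sqrt{(n_0 - n_1)(n_0 - n_2)}}
```

Here n_3 is the number of pairs tied in both x and y, and n_s is the number of swaps a merge sort performs when sorting the y-values after the data has been sorted by x (ties broken by y). It rests on the identity n_c - n_d = n_0 - n_1 - n_2 + n_3 - 2 n_s: every pair not tied in x or y is either concordant or discordant, and each discordant pair costs exactly one swap.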
The details of performing this computation can be found at [1] or on the Wikipedia entry. For my purposes, I’ll note that it requires two (n)log(n) sorts of the data, which becomes relevant below.
Generalizing to Non-Numeric Values
Generalizing Kendall’s Tau to non-numeric values is mostly just making the observation that the definition of “concordant” and “discordant” pairs is purely based on comparing x-values and y-values (and, in the (n)log(n) formulation, performing sorts on the data). From the software engineer’s perspective this means that the computations are well defined on any data type with an ordering relation, which includes numeric types but also chars, strings, sequences of any element supporting an ordering, etc. Significantly, most programming languages support the concept of defining ordering relations on arbitrary data types, which means that Kendall’s Tau can, in principle, be computed on literally any kind of data structure, provided you supply it with a well defined ordering. Furthermore, an examination of the algorithms shows that values of x and y need not even be of the same type, nor do they require the same ordering.
Generalizing to Partial Orderings
When I brought this observation up, my colleague Will Benton asked the very interesting question of whether it’s also possible to compute Kendall’s Tau on objects that have only a partial ordering. It turns out that you can define Kendall’s Tau on partially ordered data, by defining the case of two non-comparable x-values, or y-values, as another kind of tie.
The big caveat with this definition is that the (n)log(n) optimization does not apply. Firstly, the optimized algorithm relies heavily on (n)log(n) sorting, and there is no unique full sorting of elements that are only partially ordered. Secondly, the formula’s definition of the quantities n1, n2 and n3 is founded on the assumption that element equality is transitive; this is why you can count a number of tied values, t, and use t(t-1)/2 as the corresponding number of tied pairs. But in a partial ordering, this assumption is violated. Consider the case where (a) < (b), but (a) is non-comparable to (c) and (b) is also non-comparable to (c). By our definition, (a) is tied with (c), and (c) is tied with (b), but transitivity is violated, as (a) < (b).
So how can we compute Tau in this case? Consider (n1) and (n2), in Figure-1. These values represent the number of pairs that were tied with respect to (x) and (y), respectively. We can’t use the shortcut formulas for (n1) and (n2), but we can count them directly, pair by pair, simply by conducting the traditional quadratic iteration over pairs, incrementing (n1) whenever two x-values are non-comparable and incrementing (n2) whenever two y-values are non-comparable, just as we increment (nc) and (nd) to count concordant and discordant pairs. With this modification, we can apply the formula in Figure-1 as-is.
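A sketch of this quadratic counting in Scala, leaning on the standard library's scala.math.PartialOrdering, whose tryCompare returns None for non-comparable elements. Since every Ordering is also a PartialOrdering, the same function covers totally ordered, non-numeric and mixed-type data. The names KendallTau, tauB and subsetOrdering are hypothetical:

```scala
object KendallTau {
  // Sign of a comparison; "tied" and "non-comparable" are both mapped to 0
  private def sign[A](po: PartialOrdering[A], a: A, b: A): Int =
    po.tryCompare(a, b).map(c => Integer.signum(c)).getOrElse(0)

  // Quadratic Tau-B: n1 and n2 are counted pair by pair, not via t(t-1)/2
  def tauB[X, Y](data: IndexedSeq[(X, Y)], ox: PartialOrdering[X], oy: PartialOrdering[Y]): Double = {
    val n = data.length
    var nc, nd, n1, n2 = 0L
    for (a <- 0 until n; b <- (a + 1) until n) {
      val cx = sign(ox, data(a)._1, data(b)._1)
      val cy = sign(oy, data(a)._2, data(b)._2)
      if (cx == 0) n1 += 1
      if (cy == 0) n2 += 1
      if (cx != 0 && cy != 0) { if (cx == cy) nc += 1 else nd += 1 }
    }
    val n0 = n.toLong * (n - 1) / 2
    (nc - nd) / math.sqrt((n0 - n1).toDouble * (n0 - n2))
  }

  // Example partial ordering: subset inclusion on sets of Int
  val subsetOrdering: PartialOrdering[Set[Int]] = new PartialOrdering[Set[Int]] {
    def tryCompare(a: Set[Int], b: Set[Int]): Option[Int] =
      if (a == b) Some(0)
      else if (a.subsetOf(b)) Some(-1)
      else if (b.subsetOf(a)) Some(1)
      else None
    def lteq(a: Set[Int], b: Set[Int]): Boolean = a.subsetOf(b)
  }
}
```

For example, with x-values Set(1), Set(1,2), Set(3) and y-values 1, 2, 3, one pair is concordant and two pairs are non-comparable in x, giving 1/sqrt(3).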
Conclusions
I made these observations without any particular application in mind. However, my instincts as a software engineer tell me that making generalizations in this way often paves the way for new ideas, once the generalized concept is made available. With luck, it will inspire either me or somebody else to apply Kendall’s Tau in interesting new ways.
References
[1] Knight, W. (1966). “A Computer Method for Calculating Kendall’s Tau with Ungrouped Data”. Journal of the American Statistical Association 61 (314): 436–439. doi:10.2307/2282833. JSTOR 2282833.
Scala supplies a parallel collections library that was designed to make it easy for a programmer to add parallel computing over the elements in a collection. In this post, I will describe a case study of applying Scala’s parallel collections to cleanly implement multithreading support for training a K-Medoids clustering model.
Motivation
K-Medoids clustering is a relative of K-Means clustering that does not require an algebra over input data elements. That is, K-Medoids requires only a distance metric defined on elements in the data space, and can cluster objects which do not have a well-defined concept of addition or division that is necessary for computing the centroids required by K-Means. For example, K-Medoids can cluster character strings, which have a notion of distance, but no notion of summation that could be used to compute a geometric centroid.
This additional generality comes at a cost. The medoid of a collection of elements is the member of the collection that minimizes some function F of the distances from that element to all the other elements in the collection. For example, F might be the sum of distances from one element to all the elements, or perhaps the maximum distance, etc. It is not hard to see that the cost of computing a medoid of (n) elements is quadratic in (n): Evaluating F is linear in (n) and F in turn must be evaluated with respect to each element. Furthermore, unlike centroid-based computations used in K-Means, computing a medoid does not naturally lend itself to common scale-out computing formalisms such as Spark RDDs, due to the full-cross-product nature of the computation.
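A brute-force medoid makes the quadratic cost concrete. This sketch assumes F is the sum of distances and uses a hand-written Levenshtein edit distance (hypothetical helper names) on strings, which have a distance but no centroid:

```scala
object MedoidSketch {
  // Levenshtein edit distance: a metric on strings with no notion of "sum"
  def levenshtein(a: String, b: String): Double = {
    // d(i)(j) = distance between the first i chars of a and the first j chars of b
    val d = Array.tabulate(a.length + 1, b.length + 1) { (i, j) =>
      if (i == 0) j else if (j == 0) i else 0
    }
    for (i <- 1 to a.length; j <- 1 to b.length) {
      val cost = if (a(i - 1) == b(j - 1)) 0 else 1
      d(i)(j) = math.min(math.min(d(i - 1)(j) + 1, d(i)(j - 1) + 1), d(i - 1)(j - 1) + cost)
    }
    d(a.length)(b.length).toDouble
  }

  // Each cost evaluation is O(n), and it is evaluated for all n elements: O(n^2)
  def medoid[T](data: Seq[T], metric: (T, T) => Double): T =
    data.minBy(e => data.iterator.map(metric(e, _)).sum)
}
```

For Seq("cat", "cot", "cut", "dog"), the summed distances are 5, 4, 5 and 8, so the medoid is "cot".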
With this in mind, a more traditional multithreading approach is a good candidate to achieve some practical parallelism on modern multi-core hardware. I’ll demonstrate that this is easy to implement in Scala with parallel sequences.
Non-Parallel Code
Consider a baseline non-parallel implementation of K-Medoids, as in the following example skeleton code. (A working version of this code, under review at the time of this post, can be viewed here)
class KMedoids[T](k: Int, metric: (T, T) => Double) {
// Train a K-Medoids cluster on some input data
def train(data: Seq[T]): Seq[T] = {
// randomly select k data elements as the initial medoids
var current: Seq[T] = scala.util.Random.shuffle(data).take(k)
var model_converged = false
while (!model_converged) {
// assign each element to its closest medoid
val clusters = data.groupBy(medoidIdx(_, current)).map(_._2).toSeq
// recompute the medoid from the latest cluster elements
val next = benchmark("medoids") { clusters.map(medoid) }
// test for model convergence (simplified here: the medoids stopped changing)
model_converged = next.toSet == current.toSet
current = next
}
current
}
// Return the medoid of some collection of elements
def medoid(data: Seq[T]): T = {
benchmark(s"medoid: n= ${data.length}") {
data.minBy(medoidCost(_, data))
}
}
// The sum of an element's distance to all the elements in its cluster
def medoidCost(e: T, data: Seq[T]): Double = data.iterator.map(metric(e, _)).sum
// Index of the closest medoid to an element
def medoidIdx(e: T, mv: Seq[T]): Int = mv.iterator.map(metric(e, _)).zipWithIndex.min._2
// Output a benchmark timing of some expression
def benchmark[A](label: String)(blk: => A): A = {
val t0 = System.nanoTime
val t = blk
val sec = (System.nanoTime - t0) / 1e9
println(f"Run time for $label = $sec%.1f")
System.out.flush()
t
}
}
If we run the code above (de-skeletonized), then we might see something like this output from our benchmarking, where I clustered a dataset of 40,000 randomly-generated (x,y,z) points by Gaussian sampling around 5 chosen centers. (This data is numeric, but I provide only a distance metric on the points. K-Medoids has no knowledge of the data except that it can run the given metric function on it):
One iteration of a clustering run (k = 5)
Run time for medoid: n= 8299 = 7.7
Run time for medoid: n= 3428 = 1.2
Run time for medoid: n= 12581 = 17.0
Run time for medoid: n= 5731 = 3.3
Run time for medoid: n= 9961 = 10.2
Run time for medoids = 39.8
Observe that cluster sizes are generally not the same, and we can see the time per cluster varying quadratically with respect to cluster size.
A First Take On Parallel K-Medoids
Studying our non-parallel code above, we can see that the computation of each new medoid is independent, which makes it a likely place to inject some parallelism. A Scala sequence can be transformed into a corresponding parallel sequence using the par method, and so parallelizing our code is literally this simple:
Parallelizing a collection with .par
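// recompute the medoid from the latest cluster elements
// (on Scala 2.13+, .par requires the scala-parallel-collections module and
// import scala.collection.parallel.CollectionConverters._)
val next = benchmark("medoids") {
clusters.par.map(medoid).seq
}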
In this block, I also apply .seq at the end, which is not always necessary but can avoid type mismatches between Seq[T] and ParSeq[T] under some circumstances.
In my case I also wish to exercise some control over the threading used by the parallelism, and so I explicitly assign a ForkJoinPool thread pool to the sequence:
Set the threading used by a Scala ParSeq
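import java.util.concurrent.ForkJoinPool // Scala 2.11 used scala.concurrent.forkjoin.ForkJoinPool
import scala.collection.parallel.ForkJoinTaskSupport
// establish a thread pool for use by K-Medoids
val threadPool = new ForkJoinPool(numThreads)
// ...
// recompute the medoid from the latest cluster elements
val next = benchmark("medoids") {
val pseq = clusters.par
pseq.tasksupport = new ForkJoinTaskSupport(threadPool)
pseq.map(medoid).seq
}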
Minor grievance: it would be nice if Scala supported some ‘in-line’ methods, like seq.par(n)... and seq.par(threadPool)..., instead of requiring the programmer to break the flow of the code to invoke tasksupport =, which returns Unit.
Now that we’ve parallelized our K-Medoids training, we should see how well it responds to additional threads. I ran the above parallelized version using {1, 2, 4, 8, 16, 32} threads, on a machine with 40 cores, so that my benchmarking would not be impacted by attempting to run more threads than there are cores to support them. I also ran two versions of test data. The first I generated with clusters of equal size (5 clusters of ~8000 elements), and the second with one cluster being twice as large (1 cluster of ~13300 and 4 clusters of ~6700). Following is a plot of throughput (iterations / second) versus threads:
In the best of all possible worlds, our throughput would increase linearly with the number of threads; double the threads, double our iterations per second. Instead, our throughput starts to increase nicely as we add threads, but hits a hard ceiling at 8 threads. It is not hard to see why: our parallelism is limited by the number of elements in our collection of clusters. In our case that is k = 5, and so we reach our ceiling at 8 threads, the first thread number >= 5. Furthermore, we see that when the size of clusters is unequal, the throughput suffers even more. The time required to complete the clustering is dominated by the most expensive element. In our case, the cluster that is twice the size of other clusters:
Run time is dominated by largest cluster
Run time for medoid: n= 6695 = 5.1
Run time for medoid: n= 6686 = 5.2
Run time for medoid: n= 6776 = 5.3
Run time for medoid: n= 6682 = 5.4
Run time for medoid: n= 13161 = 19.9
Run time for medoids = 19.9
Take 2: Improving The Use Of Threads
Fortunately it is not hard to improve on this situation. If parallelizing by cluster is too coarse, we can try pushing our parallelism down one level of granularity. In our case, that means parallelizing the outer loop of our medoid function, and it is just as easy as before:
Parallelize the outer loop of medoid computation
// Return the medoid of some collection of elements
def medoid(data: Seq[T]) = {
benchmark(s"medoid: n= ${data.length}") {
val pseq = data.par
pseq.tasksupport = new ForkJoinTaskSupport(threadPool)
pseq.minBy(medoidCost(_, data))
}
}
Note that I retained the previous parallelism at the cluster level; otherwise the algorithm would execute parallel medoids, but one cluster at a time. Also observe that we are applying the same thread pool we supplied to the ParSeq at the cluster level. Scala’s parallel logic can utilize the same thread pool at multiple granularities without blocking. This makes it very clean to control the total number of threads used by some computation, by simply re-using the same thread pool across all points of parallelism.
Now, when we re-run our experiment, we see that our throughput continues to increase as we add threads. The following plot illustrates the throughput increasing in comparison to the previous ceiling, and also that throughput is less sensitive to the cluster size, as threads can be allocated flexibly across clusters as they are available:
I hope this short case study has demonstrated how easy it is to add multithreading to computations with Scala parallel sequences, and some considerations for making the best use of available threads. Happy Parallel Programming!
In most use cases of Scala closures, what you see is what you get, but there are exceptions where looks can be deceiving and this can have a big impact on closure serialization. Closure serialization is of more than academic interest. Tools like Apache Spark cannot operate without serializing functions over the network. In this post I’ll describe some scenarios where closures include more than what is evident in the code, and then a technique for preventing unwanted inclusions.
To establish a bit of context, consider this simple example that obtains a function and serializes it to disk, and which does behave as expected:
object Demo extends App {
def write[A](obj: A, fname: String): Unit = {
import java.io._
val out = new ObjectOutputStream(new FileOutputStream(fname))
try out.writeObject(obj) finally out.close()
}
object foo {
val v = 42
// The returned function includes 'v' in its closure
def f() = (x: Int) => v * x
}
// The function 'f' will serialize as expected
val f = foo.f
write(f, "/tmp/demo.f")
}
When this app is compiled and run, it will serialize f to “/tmp/demo.f”, which of course includes the value of v as part of the closure for f.
Now, imagine you wanted to make a straightforward change, where object foo becomes class foo:
object Demo extends App {
def write[A](obj: A, fname: String): Unit = {
import java.io._
val out = new ObjectOutputStream(new FileOutputStream(fname))
try out.writeObject(obj) finally out.close()
}
// foo is a class instead of an object
class foo() {
val v = 42
// The returned function includes 'v' in its closure, but also a secret surprise
def f() = (x: Int) => v * x
}
// This will throw an exception!
val f = new foo().f
write(f, "/tmp/demo.f")
}
It would be reasonable to expect that this minor variation behaves exactly as the previous one, but instead it throws an exception!
If we look at the exception message, we see that it’s complaining about not knowing how to serialize objects of class foo. But we weren’t including any values of foo in the closure for f, only a particular member ‘v’! What gives? Scala is not very helpful with diagnosing this problem, but when a class member value shows up in a closure that is defined inside the class body, the entire instance, including any and all other member values, is included in the closure. This is because, inside the class body, v is shorthand for this.v: the closure reads v through a reference to the enclosing instance, so that instance is what gets captured.
One straightforward way to fix this is to simply make class foo serializable:
class foo() extends Serializable {
// ...
}
If you make this change to the above code, the example with class foo now works correctly, but it is working by serializing the entire foo instance, not just the value of v.
In many cases, this is not a problem and will work fine. Serializing a few additional members may be inexpensive. In other cases, however, it can be an impractical or impossible option. For example, foo might include other very large members, which will be expensive or outright impossible to serialize:
class foo() extends Serializable {
val v = 42 // easy to serialize
val w = 4.5 // easy to serialize
val data = (1 to 1000000000).toList // serialization landmine hiding in your closure
// The returned function includes all of 'foo' instance in its closure
def f() = (x: Int) => v * x
}
A variation on the above problem is class members that are small or moderate in size, but serialized many times. In this case, the serialization cost can become intractable via repetition of unwanted inclusions.
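To see the cost concretely, a small sketch can measure how many bytes Java serialization produces for closures from two Serializable classes that differ only in one large member. The helper serializedSize and the classes Small and Big are hypothetical, and exact byte counts depend on the Scala version (2.12+ serializes lambdas through java.lang.invoke.SerializedLambda):

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}

object ClosureSize {
  // Number of bytes Java serialization produces for an object
  def serializedSize(obj: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    try out.writeObject(obj) finally out.close()
    bytes.size
  }

  class Small extends Serializable {
    val v = 42
    def f() = (x: Int) => v * x // captures the whole (small) instance
  }

  class Big extends Serializable {
    val v = 42
    val data = Array.fill(1000000)(1) // roughly 4 MB of Ints
    def f() = (x: Int) => v * x // captures the whole instance, data included
  }
}
```

Both closures compute the same function, but the one from Big drags roughly 4 MB of unused data along with it every time it is serialized.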
Another potential problem is class members that are not serializable, and perhaps not under your control:
class foo() extends Serializable {
import some.library.NotSerializable // placeholder: a class outside your control
val v = 42 // easy to serialize
val x = new NotSerializable // I'll hide in your closure and fail to serialize
// The returned function includes all of 'foo' instance in its closure
def f() = (x: Int) => v * x
}
There is a relatively painless way to decouple values from their parent instance, so that only desired values are included in a closure. Passing desired values as parameters to a shim function whose job is to assemble the closure will prevent the parent instance from being pulled into the closure. In the following example, a shim function named closureFunction is defined for this purpose:
object Demo extends App {
def write[A](obj: A, fname: String): Unit = {
import java.io._
val out = new ObjectOutputStream(new FileOutputStream(fname))
try out.writeObject(obj) finally out.close()
}
// apply a generator to create a function with safe decoupled closures
def closureFunction[E,D,R](enclosed: E)(gen: E => (D => R)) = gen(enclosed)
class NotSerializable {}
class foo() {
val v1 = 42
val v2 = 73
val n = new NotSerializable
// use shim function to enclose *only* the values of 'v1' and 'v2'
def f() = closureFunction((v1, v2)) { enclosed =>
val (v1, v2) = enclosed
(x: Int) => (v1 + v2) * x // Desired function, with 'v1' and 'v2' enclosed
}
}
// This will work!
val f = new foo().f
write(f, "/tmp/demo.f")
}
Being aware of the scenarios where parent instances are pulled into closures, and how to keep your closures clean, can save some frustration and wasted time. Happy programming!
Author’s note 0: I have come up with better, more correct designs for monadic objects that implement break and continue in Scala for-comprehensions. I’m leaving this blog post up for posterity, but I recommend using the ‘breakable’ project if you are interested in break and continue in a Scala framework.
Author’s note: I’ve since received some excellent feedback from the Scala community, which I included in some end notes.
Author’s note the 2nd: I later realized I could apply an implicit conversion and mediator class to preserve the traditional ordering: the code has been updated with that approach.
Author’s note the 3rd: This concept has been submitted to the Scala project as JIRA SI-9120 (PR #4275)
Scala sequence comprehensions are an excellent functional programming idiom for looping in Scala. However, sequence comprehensions encompass much more than just looping – they represent a powerful syntax for manipulating all monadic structures[1].
The break and continue looping constructs are a popular framework for cleanly representing multiple loop halting and continuation conditions at differing stages in the execution flow. Although there is no native support for break or continue in Scala control constructs, it is possible to implement them in a clean and idiomatic way for sequence comprehensions.
In this post I will describe a lightweight and easy-to-use implementation of break and continue for use in Scala sequence comprehensions (aka for statements). The entire implementation is as follows:
object BreakableGenerators {
import scala.language.implicitConversions
type Generator[+A] = Iterator[A]
type BreakableGenerator[+A] = BreakableIterator[A]
// Generates a new breakable generator from any traversable object.
def breakable[A](t1: TraversableOnce[A]): Generator[BreakableGenerator[A]] =
List(new BreakableIterator(t1.toIterator)).iterator
// Mediates boolean expression with 'break' and 'continue' invocations
case class BreakableGuardCondition(cond: Boolean) {
// Break the looping over one or more breakable generators, if 'cond'
// evaluates to true.
def break(b: BreakableGenerator[_], bRest: BreakableGenerator[_]*): Boolean = {
if (cond) {
b.break
for (x <- bRest) { x.break }
}
!cond
}
// Continue to next iteration of enclosing generator if 'cond'
// evaluates to true.
def continue: Boolean = !cond
}
// implicit conversion of boolean values to breakable guard condition mediary
implicit def toBreakableGuardCondition(cond: Boolean) =
BreakableGuardCondition(cond)
// An iterator that can be halted via its 'break' method. Not invoked directly
class BreakableIterator[+A](itr: Iterator[A]) extends Iterator[A] {
private var broken = false
private[BreakableGenerators] def break { broken = true }
def hasNext = !broken && itr.hasNext
def next = itr.next
}
}
The approach is based on a simple subclass of Iterator – BreakableIterator – that can be halted by ‘breaking’ it. The function breakable(<traversable-object>) returns an Iterator over a single BreakableIterator object. Iterators are monad-like structures in that they implement map and flatMap, and so its output can be used with <- at the start of a for construct in the usual way. Note that this means the result of the for statement will also be an Iterator.
Whenever the boolean expression for an if guard is followed by either break or continue, it is implicitly converted to a “breakable guard condition” that supports those methods. The function break accepts one or more instances of BreakableIterator. If the guard condition evaluates to true, the loops embodied by the given iterators are halted immediately: the associated if guard rejects the current element, and each iterator is marked as finished via its break method. The continue function is mostly syntactic sugar for a standard if guard, simply with the condition inverted.
Here is a simple example of break and continue in use:
object Main {
import BreakableGenerators._
def main(args: Array[String]) {
val r = for (
// generate a breakable sequence from some sequential input
loop <- breakable(1 to 1000);
// iterate over the breakable sequence
j <- loop;
// print out at each iteration
_ = { println(s"iteration j= $j") };
// continue to next iteration when 'j' is even
if { j % 2 == 0 } continue;
// break out of the loop when 'j' exceeds 5
if { j > 5 } break(loop)
) yield {
j
}
println(s"result= ${r.toList}")
}
}
We can see from the resulting output that break and continue function in the usual way. The continue clause ignores all subsequent code when j is even. The break clause halts the loop when it sees its first value > 5, which is 7. Only odd values <= 5 are output from the yield statement:
Breakable iterators can be nested in the way one would expect. The following example shows an inner breakable loop nested inside an outer one:
object Main {
import BreakableGenerators._
def main(args: Array[String]) {
val r = for (
outer <- breakable(1 to 7);
j <- outer;
_ = { println(s"outer j= $j") };
if { j % 2 == 0 } continue;
inner <- breakable(List("a", "b", "c", "d", "e"));
k <- inner;
_ = { println(s" inner j= $j k= $k") };
if { k == "d" } break(inner);
if { j == 5 && k == "c" } break(inner, outer)
) yield {
(j, k)
}
println(s"result= ${r.toList}")
}
}
The output demonstrates that the inner loop breaks whenever k=="d", and so "e" is never present in the yield result. When j==5 and k=="c", both the inner and outer loops are broken, and so we see that there is no (5,"c") pair in the result, nor does the outer loop ever iterate over 6 or 7:
$ scalac -d /home/eje/class monadic_break.scala
$ scala -classpath /home/eje/class Main
outer j= 1
inner j= 1 k= a
inner j= 1 k= b
inner j= 1 k= c
inner j= 1 k= d
outer j= 2
outer j= 3
inner j= 3 k= a
inner j= 3 k= b
inner j= 3 k= c
inner j= 3 k= d
outer j= 4
outer j= 5
inner j= 5 k= a
inner j= 5 k= b
inner j= 5 k= c
result= List((1,a), (1,b), (1,c), (3,a), (3,b), (3,c), (5,a), (5,b))
Using break and continue with BreakableIterator for sequence comprehensions is that easy. Enjoy!
Notes
The helpful community on freenode #scala made some excellent observations:
1: Iterators in Scala are not strictly monadic – it would be more accurate to say they’re “things with a flatMap and map method, also they can use filter or withFilter sometimes.” However, I personally still prefer to think of them as “monadic in spirit if not law.”
2: The break function, as described in this post, is not truly functional in the sense of referential transparency, as the invocation if { condition } break(loop) involves a side effect on the variable loop. I would say that it does maintain “scoped functionality”: the departure from referential transparency is scoped by the variables in question. The for statement containing them is referentially transparent with respect to its inputs (provided no other code is breaking referential transparency, of course).