Tuesday, September 9, 2014

Type Classes, Implicit Parameters and Instance Equality

In Scala the concept of implicit parameters can be used to, among other things, emulate Haskell type classes. One of the key differences between the two approaches is that the Haskell compiler statically guarantees that there is at most one instance of a type class for a specific type, while in Scala there can be any number of instances, selected according to some quite complex scope search rules. This guarantee is sometimes brought up as an advantage of Haskell type classes: you can, for example, easily write a union function for two sets with the same element type and be certain that both sets use the same ordering instance. In Scala you would typically store the ordering instance in the set, but since two sets can refer to different instances with the same static type, the compiler can't statically check whether it's safe to call union on them. One option is of course to add a runtime equality check of the ordering instances, but a static check would obviously be better. In this post I'll describe a way to get the same static type safety in Scala as in Haskell.

Implicit Instance Construction

The first thing to notice is that there are two basic ways to create an implicit instance in Scala, either as an immutable value or as a function which can take other implicit instances as arguments, for example:

  // Very incomplete Ord type class as example
  case class Ord[T](v: String)

  implicit val intOrd = Ord[Int]("Int")
  
  implicit def listOrd[T](implicit ord: Ord[T]) = Ord[List[T]]("List[" + ord.v + "]")

In this example, when searching for an implicit Ord[Int] the Scala compiler will simply use the intOrd instance directly. However, when searching for an implicit of type Ord[List[Int]] things get a bit more complicated. In this case the Scala compiler uses the static types to figure out that it can create an instance by calling listOrd(intOrd), so it generates this call in the compiled code where the instance is requested. Further recursive calls are generated as needed, so when searching for an implicit Ord[List[List[Int]]] it will generate the call listOrd(listOrd(intOrd)), and so on.
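To make the call chain concrete, here is a small self-contained sketch. The wrapper object ImplicitChainDemo and the vals resolved and handWritten are hypothetical names, the implicits are given explicit result types, and Scala 2.13 is assumed. It compares what implicitly resolves with the hand-written listOrd(listOrd(intOrd)) call:

```scala
object ImplicitChainDemo {
  // Very incomplete Ord type class, as in the post
  case class Ord[T](v: String)

  // Explicit result types on the implicits are an addition for portability
  implicit val intOrd: Ord[Int] = Ord[Int]("Int")

  implicit def listOrd[T](implicit ord: Ord[T]): Ord[List[T]] =
    Ord[List[T]]("List[" + ord.v + "]")

  // Resolved by the compiler, roughly as listOrd(listOrd(intOrd))
  val resolved: Ord[List[List[Int]]] = implicitly[Ord[List[List[Int]]]]

  // The same call chain written out by hand
  val handWritten: Ord[List[List[Int]]] = listOrd(listOrd(intOrd))

  def main(args: Array[String]): Unit = {
    println(resolved.v)              // List[List[Int]]
    println(resolved == handWritten) // true
  }
}
```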

Note that these function calls are performed every time an implicit instance is needed at runtime; no memoization or caching is done. So, while the same Ord[Int] instance will always be used (as it's a constant value), multiple Ord[List[Int]] instances will be created at runtime.
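One way to observe the missing caching is to compare references with eq. This is a sketch under the same assumptions as before (NoCachingDemo and its vals are hypothetical names). Two lookups of Ord[Int] return the same object, whereas two lookups of Ord[List[Int]] return distinct, though structurally equal, case class instances:

```scala
object NoCachingDemo {
  case class Ord[T](v: String)

  implicit val intOrd: Ord[Int] = Ord[Int]("Int")

  implicit def listOrd[T](implicit ord: Ord[T]): Ord[List[T]] =
    Ord[List[T]]("List[" + ord.v + "]")

  // Both lookups resolve to the constant intOrd
  val int1: Ord[Int] = implicitly[Ord[Int]]
  val int2: Ord[Int] = implicitly[Ord[Int]]

  // Each lookup evaluates its own listOrd(intOrd) call at runtime
  val list1: Ord[List[Int]] = implicitly[Ord[List[Int]]]
  val list2: Ord[List[Int]] = implicitly[Ord[List[Int]]]

  def main(args: Array[String]): Unit = {
    println(int1 eq int2)   // true: the same constant value
    println(list1 eq list2) // false: two separate listOrd(intOrd) results
    println(list1 == list2) // true: case class equality compares v
  }
}
```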

Furthermore, implicit values and functions of the exact same types (but possibly with different implementations, as in the example below) can be defined in other modules (or objects, as they are called in Scala):

  object o1 {
    implicit val intOrd = Ord[Int]("o1.Int")

    implicit def listOrd[T](implicit ord: Ord[T]) = Ord[List[T]]("o1.List[" + ord.v + "]")
  }

  object o2 {
    implicit val intOrd = Ord[Int]("o2.Int")

    implicit def listOrd[T](implicit ord: Ord[T]) = Ord[List[T]]("o2.List[" + ord.v + "]")
  }
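The following sketch shows the problem this causes, using the hypothetical names SameStaticTypeDemo, fromO1, fromO2 and bothInts (the two modules are trimmed to their Int instances and given explicit result types). Which instance you get depends only on what is imported at the call site, and nothing in a signature like bothInts can tell the two Ord[Int] values apart:

```scala
object SameStaticTypeDemo {
  case class Ord[T](v: String)

  object o1 {
    implicit val intOrd: Ord[Int] = Ord[Int]("o1.Int")
  }

  object o2 {
    implicit val intOrd: Ord[Int] = Ord[Int]("o2.Int")
  }

  // The instance is chosen by whichever module is imported in this scope
  def fromO1: Ord[Int] = { import o1._; implicitly[Ord[Int]] }
  def fromO2: Ord[Int] = { import o2._; implicitly[Ord[Int]] }

  // Hypothetical helper: both parameters have type Ord[Int], so mixing
  // instances from different modules type-checks without complaint
  def bothInts(a: Ord[Int], b: Ord[Int]): String = a.v + " / " + b.v

  def main(args: Array[String]): Unit =
    println(bothInts(fromO1, fromO2)) // o1.Int / o2.Int
}
```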

Considering all this, it might seem a bit complicated to statically check that two implicit instances are indeed behaviorally equal and interchangeable. But the key insight here is that as long as no side effects are performed inside the implicit functions (and I strongly recommend not performing any!), two separate instances created through the same implicit call chain generated by the Scala compiler will behave identically (except for reference identity checks, of course). So, to statically check that two instances are equal (but not necessarily identical), all we need to do is track the implicit call chain in the type of the implicit instance. Let's get to it...


Phantom Types to the Rescue

Let's start by adding a phantom type parameter, P, to the Ord type:

  case class Ord[P, T](v: String)

This type is only used for static equality checks: if a: Ord[A, B] and b: Ord[A, B], then a and b are equal and can be used interchangeably (note that with the previous definition, Ord[T], this was not the case). The P type could also be written as a type member in Scala, but doing that gave me problems with type inference, so I won't explore that road in this article.

Now we can easily define a unique type for the implicit values in each module object:

  object o1 {
    implicit val intOrd = Ord[this.type, Int]("o1.Int")
  }

  object o2 {
    implicit val intOrd = Ord[this.type, Int]("o2.Int")
  }

The this.type expression used for the P type parameter denotes the singleton type of the module object containing the expression (i.e. o1.type and o2.type, respectively, in this case). This means that o1.intOrd and o2.intOrd no longer have the same static type (Ord[o1.type, Int] vs Ord[o2.type, Int]), which is exactly what we want. Note that using this.type only works as long as a module contains just one implicit instance for a given type T in Ord (in this case Int). This is usually not a problem, and otherwise this.type can be replaced with an arbitrary unique type defined in the module object.
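Here is a minimal sketch of what the phantom parameter buys us, assuming Scala 2.13 and explicit result types on the implicits. PhantomOrdDemo, sameOrd and the marker traits Asc and Desc in the hypothetical module o3 are illustrative names, not part of the original code. The ascriptions to Ord[o1.type, Int] and Ord[o2.type, Int] compile, and the commented-out calls are the ones the compiler should reject because Ord is invariant in P. The o3 instances are kept non-implicit so the example sidesteps having two implicits of the same shape in one module:

```scala
object PhantomOrdDemo {
  // Ord with a phantom type parameter P that is never used at runtime
  case class Ord[P, T](v: String)

  object o1 {
    implicit val intOrd: Ord[this.type, Int] = Ord[this.type, Int]("o1.Int")
  }

  object o2 {
    implicit val intOrd: Ord[this.type, Int] = Ord[this.type, Int]("o2.Int")
  }

  // Hypothetical module with two Int instances: this.type alone would give
  // both the type Ord[o3.type, Int], so each gets its own marker type
  object o3 {
    sealed trait Asc
    sealed trait Desc
    val intAsc: Ord[Asc, Int] = Ord[Asc, Int]("o3.Int.asc")
    val intDesc: Ord[Desc, Int] = Ord[Desc, Int]("o3.Int.desc")
  }

  // Hypothetical helper: only compiles when both instances share P and T
  def sameOrd[P, T](a: Ord[P, T], b: Ord[P, T]): Boolean = a == b

  // These ascriptions compile: the static type now records the module
  val fromO1: Ord[o1.type, Int] = o1.intOrd
  val fromO2: Ord[o2.type, Int] = o2.intOrd

  val ok1: Boolean = sameOrd(o1.intOrd, o1.intOrd)
  val ok2: Boolean = sameOrd(o3.intAsc, o3.intAsc)
  // sameOrd(o1.intOrd, o2.intOrd)  // does not compile: o1.type vs o2.type
  // sameOrd(o3.intAsc, o3.intDesc) // does not compile: o3.Asc vs o3.Desc

  def main(args: Array[String]): Unit =
    println(s"$ok1 $ok2") // true true
}
```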

Things get a bit trickier for implicit functions which have implicit parameters. Here we must construct a phantom type that contains both a unique module type and the phantom types of the implicit arguments. We can use a tuple type to accomplish this:

  object o1 {
    implicit val intOrd = Ord[this.type, Int]("o1.Int")

    implicit def ordList[P, T](implicit ord: Ord[P, T]) = Ord[(this.type, P), List[T]]("o1.List[" + ord.v + "]")
  }

  object o2 {
    implicit val intOrd = Ord[this.type, Int]("o2.Int")

    implicit def ordList[P, T](implicit ord: Ord[P, T]) = Ord[(this.type, P), List[T]]("o2.List[" + ord.v + "]")
  }

In this example the instance o1.ordList(o1.intOrd) would have type Ord[(o1.type, o1.type), List[Int]], o1.ordList(o1.ordList(o1.intOrd)) would have type Ord[(o1.type, (o1.type, o1.type)), List[List[Int]]] and so on. A combination of implicits from both modules o1.ordList(o2.intOrd) would have type Ord[(o1.type, o2.type), List[Int]].
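These types can be checked by ascription. In this sketch (CallChainTypeDemo is a hypothetical wrapper, result types are written out, ordList builds its string from ord.v, and Scala 2.13 is assumed) each val only compiles if the phantom type mirrors the call chain. The implicitly resolved instance is equal to, but not the same object as, the hand-built one, which is the "equal but not necessarily identical" property described earlier:

```scala
object CallChainTypeDemo {
  case class Ord[P, T](v: String)

  object o1 {
    implicit val intOrd: Ord[this.type, Int] = Ord[this.type, Int]("o1.Int")
    implicit def ordList[P, T](implicit ord: Ord[P, T]): Ord[(this.type, P), List[T]] =
      Ord[(this.type, P), List[T]]("o1.List[" + ord.v + "]")
  }

  object o2 {
    implicit val intOrd: Ord[this.type, Int] = Ord[this.type, Int]("o2.Int")
    implicit def ordList[P, T](implicit ord: Ord[P, T]): Ord[(this.type, P), List[T]] =
      Ord[(this.type, P), List[T]]("o2.List[" + ord.v + "]")
  }

  // Each ascription compiles only if the phantom type matches the call chain
  val single: Ord[(o1.type, o1.type), List[Int]] = o1.ordList(o1.intOrd)
  val nested: Ord[(o1.type, (o1.type, o1.type)), List[List[Int]]] =
    o1.ordList(o1.ordList(o1.intOrd))
  val mixed: Ord[(o1.type, o2.type), List[Int]] = o1.ordList(o2.intOrd)

  // The same chain found by implicit search with only o1 imported
  val resolved: Ord[(o1.type, (o1.type, o1.type)), List[List[Int]]] = {
    import o1._
    implicitly[Ord[(o1.type, (o1.type, o1.type)), List[List[Int]]]]
  }

  def main(args: Array[String]): Unit = {
    println(nested.v)           // o1.List[o1.List[o1.Int]]
    println(mixed.v)            // o1.List[o2.Int]
    println(resolved == nested) // true: equal values
    println(resolved eq nested) // false: separate runtime instances
  }
}
```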

Set Union Implementation

So, now that we have our phantom types, we can quite easily write the type safe set type and the union function. However, there are two possible implementations: either we get the Ord instance from an implicit parameter to each set function (like union):

  object set1 {
    class Set[P, T]

    def union[P, T](s1: Set[P, T], s2: Set[P, T])(implicit ord: Ord[P, T]) = new Set[P, T]

    // Dummy constructor
    def set[P, T](v: T)(implicit ord: Ord[P, T]) = new Set[P, T]
  }

or we can store it inside the Set object:

  object set2 {
    class Set[P, T](val ord: Ord[P, T])

    def union[P, T](s1: Set[P, T], s2: Set[P, T]) = new Set[P, T](s1.ord)

    // Dummy constructor
    def set[P, T](v: T)(implicit ord: Ord[P, T]) = new Set[P, T](ord)
  }

Either way, we can't call union on sets with different P types. The first set implementation saves some memory and works quite well, as the Scala compiler is pretty smart about where to look for implicit arguments based on the function's type arguments.
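Here is a hedged sketch of both variants side by side. SetUnionDemo, the set values and the explicit result types are additions for illustration, and Scala 2.13 is assumed. Instances are passed explicitly instead of via imports, so the example doesn't depend on implicit scope rules. Note that with set1 the instance given to union is itself checked: it must carry the same P as the two sets:

```scala
object SetUnionDemo {
  case class Ord[P, T](v: String)

  object o1 {
    implicit val intOrd: Ord[this.type, Int] = Ord[this.type, Int]("o1.Int")
  }

  object o2 {
    implicit val intOrd: Ord[this.type, Int] = Ord[this.type, Int]("o2.Int")
  }

  // set1 from the post: the Ord instance is supplied to every operation
  object set1 {
    class Set[P, T]
    def union[P, T](s1: Set[P, T], s2: Set[P, T])(implicit ord: Ord[P, T]): Set[P, T] =
      new Set[P, T]
    def set[P, T](v: T)(implicit ord: Ord[P, T]): Set[P, T] = new Set[P, T]
  }

  // set2 from the post: the Ord instance is stored in the set
  object set2 {
    class Set[P, T](val ord: Ord[P, T])
    def union[P, T](s1: Set[P, T], s2: Set[P, T]): Set[P, T] = new Set[P, T](s1.ord)
    def set[P, T](v: T)(implicit ord: Ord[P, T]): Set[P, T] = new Set[P, T](ord)
  }

  // set1: the instance passed to union must match the sets' P
  val a: set1.Set[o1.type, Int] = set1.set(1)(o1.intOrd)
  val b: set1.Set[o1.type, Int] = set1.set(2)(o1.intOrd)
  val ab: set1.Set[o1.type, Int] = set1.union(a, b)(o1.intOrd)
  // set1.union(a, b)(o2.intOrd) // does not compile: o2.type is not o1.type

  // set2: union needs no instance and reuses the one stored in s1
  val c: set2.Set[o2.type, Int] = set2.set(3)(o2.intOrd)
  val d: set2.Set[o2.type, Int] = set2.set(4)(o2.intOrd)
  val cd: set2.Set[o2.type, Int] = set2.union(c, d)
  val e: set2.Set[o1.type, Int] = set2.set(5)(o1.intOrd)
  // set2.union(c, e) // does not compile: o2.type vs o1.type

  def main(args: Array[String]): Unit =
    println(cd.ord.v) // o2.Int
}
```

With set2, union simply keeps s1.ord; since both sets share P, s2.ord would be an interchangeable choice.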

A Small Test

Here's a small test case to verify that we get a type error when trying to take the union of two sets with different Ord instances:

  import set1._

  object so1 {
    import o1._
    def listListIntSet() = set(List(List(1)))
  }

  object so2 {
    import o2._
    def listListIntSet() = set(List(List(1)))
  }

  val a1 = so1.listListIntSet()
  val b1 = so1.listListIntSet()
  val a2 = so2.listListIntSet()
  val b2 = so2.listListIntSet()
  union(a1, b1)
  union(a2, b2)
  // Compiler error: union(a1, a2)

Final Words

In this article I've shown that by using Scala's powerful type and module system it's possible to get the same guarantees for instance equality in Scala as you get in Haskell, but with the extra flexibility of being able to create multiple type class instances. IMHO implicit parameters are a strictly more powerful solution to ad hoc polymorphism than type classes, and they can also be used for more purposes. But feel free to prove me wrong. :-)

Interestingly enough, a similar technique of adding a phantom type parameter to, for example, the Ord a type class in Haskell should make it possible to create multiple Ord instances for the same type a and still get the same guarantees for, say, the set union function. Maybe this idea has been explored already?