Tuesday, April 21, 2009

Equality, Mutability and Products

The thoughts in this post were sparked by a discussion on the scala-user mailing list about the Object.equals() method and what its semantics should be. There are many different types of equality, but IMO the most usable one is where the expression a.equals(b) returns the same value regardless of when it's evaluated (provided the references a and b are constant), i.e. it's time invariant. If this were not the case, it would be quite impossible to implement Object.hashCode() in a useful way, and you wouldn't be able to use objects as keys in a hash map or as members of a set.

The Java standard library breaks this practice in a number of places. For example, java.util.ArrayList implements equals() as element equality, which of course is not time invariant as the list can change over time (unless it's read-only). I think the rationale behind this might be that there are no immutable collection classes in the Java libraries, and thus ArrayList is used for both mutable and immutable lists.
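To make the ArrayList point concrete, here is a minimal sketch (the object and value names are made up for illustration) that compares two lists before and after one of them is mutated:

```scala
object ArrayListEquality {
  val xs = new java.util.ArrayList[Int]()
  val ys = new java.util.ArrayList[Int]()
  xs.add(1)
  ys.add(1)

  // ArrayList.equals() compares the elements, so the lists are equal here
  val equalBefore = xs == ys
  val hashBefore = xs.hashCode

  // Mutating xs changes the outcome of both equals() and hashCode()
  xs.add(2)
  val equalAfter = xs == ys
  val hashAfter = xs.hashCode
}
```

The same two references give a different answer depending on when the comparison is made, which is exactly the time variance described above.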

Equality for Mutable Case Classes

When using Scala case classes you automatically get equals() and hashCode() implementations which call the corresponding methods on the class's fields. The implementations are the same regardless of whether the fields are mutable or not, so if you use vars they will not be time invariant. For example:

scala> case class A(var i : Int)
defined class A

scala> val a1 = A(1)
a1: A = A(1)

scala> val a2 = A(1)
a2: A = A(1)

scala> a1 == a2
res57: Boolean = true

scala> a1.hashCode
res58: Int = -1085252538

scala> a1.i = 2

scala> a1 == a2
res59: Boolean = false

scala> a1.hashCode
res60: Int = -1085252537
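The practical consequence shows up as soon as such an object is put in a hash-based collection. The following is a rough sketch with hypothetical names (Point, HashSetLookup), not code from the mailing list discussion:

```scala
import scala.collection.mutable

object HashSetLookup {
  // Hypothetical mutable case class with compiler-generated equals/hashCode
  case class Point(var x: Int)

  val p = Point(1)
  val set = mutable.HashSet(p)

  // The element is found while its hash code is unchanged
  val foundBefore = set.contains(p)

  // Changing the field changes the hash code of the stored element
  p.x = 2

  // The lookup now uses the new hash code and will usually miss the element,
  // although the exact behaviour depends on the set implementation
  val foundAfter = set.contains(p)

  // The object is still stored in the set, it just can't be looked up reliably
  val stillStored = set.exists(_ eq p)
}
```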

How can we fix this? The Scala compiler only generates equals() and hashCode() implementations if we have not already implemented them in a super type (excluding java.lang.Object). So we can simply define a super trait for all mutable case classes:

trait Mutable {
  override def equals(v : Any) = super.equals(v)
  override def hashCode = super.hashCode
}

Now we can easily define a mutable case class with the default, time-invariant equals() and hashCode() methods inherited from java.lang.Object:

scala> case class A(var i : Int) extends Mutable
defined class A

scala> val a1 = A(1)
a1: A = A(1)

scala> val a2 = A(1)
a2: A = A(1)

scala> a1 == a2
res61: Boolean = false

scala> a1.hashCode
res62: Int = 22213838

scala> a1.i = 2

scala> a1 == a2
res63: Boolean = false

scala> a1.hashCode
res64: Int = 22213838

A more elegant solution is to define a Var trait and use this instead of vars in case classes, but this is not always a practical solution and it has some performance drawbacks.
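One way such a Var type could look is sketched below. This is only my guess at a possible shape: the class Var, its apply/update methods and the Counter case class are all assumptions made for illustration. The idea is that the case class field itself is a val, and Var does not override equals() or hashCode(), so the generated case class methods end up comparing cell identities rather than cell contents.

```scala
object VarSketch {
  // Hypothetical mutable cell. equals/hashCode are deliberately not overridden,
  // so two Var instances are equal only if they are the same object.
  final class Var[T](private var value: T) {
    def apply(): T = value
    def update(v: T): Unit = { value = v } // enables the `cell() = x` syntax
    override def toString = "Var(" + value + ")"
  }

  // The field is immutable; only the content of the cell changes
  case class Counter(count: Var[Int])

  val c1 = Counter(new Var(1))
  val c2 = Counter(new Var(1))

  // Different cells, so the counters are not equal even with the same content
  val equalBefore = c1 == c2
  val hashBefore = c1.hashCode

  c1.count() = 2 // desugars to c1.count.update(2)

  // The hash code depends on the cell's identity, not on its content
  val hashAfter = c1.hashCode
}
```

Each field now needs an extra wrapper object, and primitive values are likely to be boxed inside the generic cell, which is presumably where the performance drawbacks come from.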

Time Variant Equality for Products

So, what if I want to compare my case classes for element equality at a given moment in time? Turns out it's quite easy to implement such a function in Scala, as all case classes implement the Product trait. Here's an example implementation (not thoroughly tested):

def equalElements(v1 : Any, v2 : Any) : Boolean = (v1, v2) match {
  case (p1 : Product, p2 : Product) =>
    p1.getClass == p2.getClass &&
      (0 until p1.productArity).forall(i => equalElements(p1.productElement(i), p2.productElement(i)))
  case _ => v1 == v2
}

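This function has some limitations: it won't work correctly when product and non-product objects are nested, for example when a case class is stored inside a container that is not itself a Product, because such a value falls through to the == case. In this case a more type-specialized solution is required.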

Example usage:

scala> val a1 = A(1)
a1: A = A(1)

scala> val a2 = A(1)
a2: A = A(1)

scala> equalElements(a1, a2)
res68: Boolean = true

scala> a1.i = 2

scala> equalElements(a1, a2)
res69: Boolean = false
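As for the nested case mentioned earlier, a more type-specialized variant could add explicit cases for the containers you care about. The sketch below is one possible direction, not a complete solution: deepEqualElements, IdentityEq and Cell are hypothetical names, and only arrays and sequences are handled in addition to products.

```scala
object DeepEquality {
  // Stand-in for the Mutable trait: keeps the reference-based equals/hashCode
  trait IdentityEq {
    override def equals(v: Any) = super.equals(v)
    override def hashCode = super.hashCode
  }

  case class Cell(var n: Int) extends IdentityEq

  def deepEqualElements(v1: Any, v2: Any): Boolean = (v1, v2) match {
    // Arrays are not Products and == on arrays is reference equality
    case (a1: Array[_], a2: Array[_]) =>
      a1.length == a2.length &&
        (0 until a1.length).forall(i => deepEqualElements(a1(i), a2(i)))
    // Sequences are compared pairwise, recursing into each element
    case (s1: Seq[_], s2: Seq[_]) =>
      s1.length == s2.length &&
        s1.zip(s2).forall { case (x, y) => deepEqualElements(x, y) }
    // Products of the same class are compared field by field
    case (p1: Product, p2: Product) =>
      p1.getClass == p2.getClass &&
        (0 until p1.productArity).forall(i =>
          deepEqualElements(p1.productElement(i), p2.productElement(i)))
    case _ => v1 == v2
  }
}
```

With these definitions, two vectors holding distinct Cell(1) instances are not equal according to ==, but deepEqualElements treats them as equal until one of the cells is changed. Sets, maps and other containers would need cases of their own.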