August 2009 Archives

Writing An Effective Merge

Recently I've been implementing simple algorithms from scratch, working from memory. As an exercise, this has the advantage of being short, well-defined and (for me at least) challenging. For example, I found it surprisingly hard to write an effective merge of the kind mergesort needs.

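The spec is void merge(int[] a, int lo, int mid, int hi): merge two adjacent parts of the array a, each of which is assumed to be already sorted. The first part is the items from lo to mid - 1; the second is the items from mid to hi - 1. In other words, the ranges are half-open: [lo, mid) and [mid, hi). After the call, the items from lo to hi - 1 should be in sorted order, and the rest of the array should be untouched.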
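The precondition in that spec can be written down as a small check. This is only a sketch to make the index convention concrete; the class MergeSpec and the helpers isSorted and isValidInput are made-up names, not part of the code in this post.

```java
class MergeSpec {
    // Hypothetical helper: true if a[from..to) is in non-decreasing order.
    static boolean isSorted(int[] a, int from, int to) {
        for (int k = from + 1; k < to; k++) {
            if (a[k - 1] > a[k]) {
                return false;
            }
        }
        return true;
    }

    // Hypothetical helper: true if (a, lo, mid, hi) is a legal input to merge,
    // i.e. the indices are in order and both half-open runs are sorted.
    static boolean isValidInput(int[] a, int lo, int mid, int hi) {
        boolean inBounds = 0 <= lo && lo <= mid && mid <= hi && hi <= a.length;
        return inBounds && isSorted(a, lo, mid) && isSorted(a, mid, hi);
    }

    public static void main(String[] args) {
        int[] a = { 1, 3, 2, 4, 6, 5 };
        // Runs {1, 3} and {2, 4}: both sorted, so this prints true.
        System.out.println(isValidInput(a, 0, 2, 4));
        // Second run would be {2, 4, 6, 5}, which is unsorted: prints false.
        System.out.println(isValidInput(a, 0, 2, 6));
    }
}
```

Note that an empty range (lo == mid == hi) and a single item (lo == mid, hi == lo + 1) both count as valid input under this reading.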

Here's my first attempt, rushed out late at night.

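    // Needs java.util.ArrayList and java.util.List.
    private void merge(int[] a, int lo, int mid, int hi) {
        assert mid == lo + (hi - lo) / 2;
        List<Integer> result = new ArrayList<Integer>();
        int i = lo;   // next item in the first part
        int j = mid;  // next item in the second part

        // Take the smaller head item until one part runs out.
        while (i < mid && j < hi) {
            int m = a[i];
            int n = a[j];
            if (m > n) {
                result.add(n);
                j++;
            } else {
                result.add(m);
                i++;
            }
        }
        // Append whatever is left over.
        while (i < mid) {
            result.add(a[i++]);
        }
        while (j < hi) {
            result.add(a[j++]);
        }
        // Copy the merged result back into place.
        for (i = lo, j = 0; j < result.size(); i++, j++) {
            a[i] = result.get(j);
        }
    }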

It did the job, but oh my, how I didn't like it. It's verbose and it creates this temporary list of O(n) size. After some more work and several blind alleys, I came up with the following. It's better, but still, it seems harder than it should be. Oddly enough, I found it easier to implement quicksort from scratch than just the merge portion of mergesort.

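    // No temporary storage, but the shifting makes this quadratic in the worst case.
    private void merge(int[] a, int lo, int mid, int hi) {
        assert mid == lo + (hi - lo) / 2;
        int i = lo;   // start of what remains of the first part
        int j = mid;  // start of what remains of the second part
        while (i < j && j < hi) {
            if (a[j] < a[i]) {
                // Walk a[j] down to position i with adjacent swaps, so the
                // rest of the first part shifts right and stays sorted.
                for (int k = j; k > i; k--) {
                    swap(a, k, k - 1);
                }
                j++;
            }
            i++;
        }
    }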

    private void swap(int[] a, int i, int j) {
        int t = a[i];
        a[i] = a[j];
        a[j] = t;
    }
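For comparison, the temporary storage of the first attempt can be trimmed without going fully in-place. The sketch below is the common textbook approach rather than anything from the attempts above: it copies only the first part into a primitive int[] buffer and merges back into a, which avoids boxing and needs at most half the extra space. The class name BufferedMerge is made up for illustration.

```java
import java.util.Arrays;

class BufferedMerge {
    // Merges sorted runs a[lo..mid) and a[mid..hi) using a copy of the left run only.
    static void merge(int[] a, int lo, int mid, int hi) {
        int[] left = Arrays.copyOfRange(a, lo, mid); // primitive buffer, no boxing
        int i = 0;    // next unread item in left
        int j = mid;  // next unread item in the right run, still inside a
        int k = lo;   // next slot to write in a
        while (i < left.length && j < hi) {
            // Taking from the left on ties keeps the merge stable.
            a[k++] = (a[j] < left[i]) ? a[j++] : left[i++];
        }
        // Right-run leftovers are already in place; copy back left leftovers.
        while (i < left.length) {
            a[k++] = left[i++];
        }
    }

    public static void main(String[] args) {
        int[] a = { 5, 6, 7, 1, 2, 3, 4 };
        merge(a, 0, 3, 7);
        System.out.println(Arrays.toString(a)); // [1, 2, 3, 4, 5, 6, 7]
    }
}
```

The write index k can never overtake the read index j while items remain in left, so no unread item of the right run is overwritten.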

Finally, the tests I wrote while working this out:

public void testEmpty() {
    int [] a = new int[0];
    merge(a, 0, 0, 0);
    assertEquals(0, a.length);
}

public void testSingleValue() {
    int [] a = new int[] { 1 };
    merge(a, 0, 0, 1);
    assertArraysEqual(new int[] { 1 }, a);
}

public void testTwoValuesRequiringNoSwap() {
    int [] a = new int[] { 1, 2 };
    merge(a, 0, 1, 2);
    assertArraysEqual(new int[] { 1, 2 }, a);
}

public void testTwoValuesRequiringSwap() {
    int [] a = new int[] { 2, 1 };
    merge(a, 0, 1, 2);
    assertArraysEqual(new int[] { 1, 2 }, a);
}

public void testSimpleInterleavedMerge() {
    int [] a = new int[] { 1, 3, 2, 4 };
    merge(a, 0, 2, 4);
    assertArraysEqual(new int[] { 1, 2, 3, 4 }, a);
}

public void testMergeOfSubset() {
    int [] a = new int[] { 1, 3, 2, 4, 6, 5 };
    merge(a, 0, 2, 4);
    assertArraysEqual(new int[] { 1, 2, 3, 4, 6, 5 }, a);
}

public void testMergeOfAllEqual() {
    int [] a = new int[] { 1, 1, 1, 1 };
    merge(a, 0, 2, 4);
    assertArraysEqual(new int[] { 1, 1, 1, 1 }, a);
}

public void testMergeAllFromRight() {
    int [] a = new int[] { 3, 4, 1, 2 };
    merge(a, 0, 2, 4);
    assertArraysEqual(new int[] { 1, 2, 3, 4 }, a);
}

public void testMergeAllFromLeft() {
    int [] a = new int[] { 1, 2, 3, 4 };
    merge(a, 0, 2, 4);
    assertArraysEqual(new int[] { 1, 2, 3, 4 }, a);
}

public void testMergeUnevenNumber() {
    int [] a = new int[] { 3, 1, 2 };
    merge(a, 0, 1, 3);
    assertArraysEqual(new int[] { 1, 2, 3 }, a);
}

public void testMergeLargerList() {
    int [] a = new int[] { 5, 6, 7, 1, 2, 3, 4 };
    merge(a, 0, 3, 7);
    assertArraysEqual(new int[] { 1, 2, 3, 4, 5, 6, 7 }, a);
}
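Hand-picked cases like these are easy to read, but they only cover the inputs someone thought to write down. One complementary option is a randomized check that builds two sorted runs, merges them, and compares the result against Arrays.sort. This is a sketch and not part of the original test suite; MergeRandomCheck and findCounterexample are made-up names, and the merge inside is only a placeholder for whichever version is under test.

```java
import java.util.Arrays;
import java.util.Random;

class MergeRandomCheck {
    // Hypothetical stand-in for the merge under test; swap in the real one.
    static void merge(int[] a, int lo, int mid, int hi) {
        int[] out = new int[hi - lo];
        int i = lo, j = mid, k = 0;
        while (k < out.length) {
            boolean takeLeft = j >= hi || (i < mid && a[i] <= a[j]);
            out[k++] = takeLeft ? a[i++] : a[j++];
        }
        System.arraycopy(out, 0, a, lo, out.length);
    }

    // Returns the first failing input found, or null if all trials pass.
    static int[] findCounterexample(int trials, long seed) {
        Random random = new Random(seed);
        for (int t = 0; t < trials; t++) {
            int n = random.nextInt(9);          // lengths 0..8
            int mid = n / 2;                    // the midpoint the post's assert expects
            int[] input = new int[n];
            for (int k = 0; k < n; k++) {
                input[k] = random.nextInt(10);  // small values force duplicates
            }
            Arrays.sort(input, 0, mid);         // establish the precondition:
            Arrays.sort(input, mid, n);         // both runs sorted

            int[] expected = input.clone();
            Arrays.sort(expected);
            int[] actual = input.clone();
            merge(actual, 0, mid, n);
            if (!Arrays.equals(expected, actual)) {
                return input;
            }
        }
        return null;
    }

    public static void main(String[] args) {
        int[] failing = findCounterexample(10_000, 42L);
        System.out.println(failing == null
                ? "all trials passed"
                : "failed on " + Arrays.toString(failing));
    }
}
```

Short arrays with a small value range are a deliberate choice here: any failing input that comes back is small enough to trace by hand and to turn into a named test.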

Counting Unique Digits

Here's a way to count the unique digits in a number using a mostly functional style. There's an API method, count_unique_digits:

def count_unique_digits(n)
  raise ArgumentError, "can only handle numbers >= 0" if n < 0
  count_unique_digits_iter([], 0, n)
end

This calls the method that does the real work, count_unique_digits_iter. It is recursive: it peels off the last digit of n at each step and passes the digits seen so far and the running count to itself as parameters.

def count_unique_digits_iter(num_array, count, n)
  new_digit = n % 10                                 # last digit of n
  new_digit_unique = !num_array.include?(new_digit)  # not seen before?
  increment = new_digit_unique ? 1 : 0
  return count + increment if n < 10                 # that was the final digit
  num_array.push(new_digit) if new_digit_unique      # remember it for later steps
  count_unique_digits_iter(num_array, count + increment, n / 10)
end

Given these methods, here's one that tallies the integers in a range by how many unique digits each has. The result maps a unique-digit count to the number of integers with that count:

def build_count_map(range)
  count_map = Hash.new(0)   # missing keys start at 0
  range.each do |i|
    count_map[count_unique_digits(i)] += 1
  end
  count_map
end

This tallies all the four-digit numbers:

count_map = build_count_map(1000 .. 9999)
for count in count_map.keys.sort
  print "#{count}: #{count_map[count]}\n"
end

And here's the output:

1: 9
2: 567
3: 3888
4: 4536
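As a sanity check on these figures, the counts can also be derived by hand. The derivation below is an independent sketch, not part of the program: for each number of unique digits, choose which digits appear, count the four-digit strings that use all of them, and leave out the strings with a leading zero. Here N_k stands for the number of integers in 1000 .. 9999 with exactly k unique digits.

```latex
\begin{align*}
N_1 &= 9 \\
N_2 &= \binom{9}{2}\,(2^4 - 2) + 9\,(2^3 - 1) = 504 + 63 = 567 \\
N_3 &= \binom{9}{3} \cdot 36 + \binom{9}{2} \cdot 24 = 3024 + 864 = 3888 \\
N_4 &= 9 \cdot 9 \cdot 8 \cdot 7 = 4536 \\
N_1 + N_2 + N_3 + N_4 &= 9000
\end{align*}
```

In each of N_2 and N_3 the first term covers digit sets without 0 and the second covers sets that include 0. The factor 36 = 3^4 - 3 * 2^4 + 3 is the number of four-digit strings that use all of three given digits, and 24 is the two thirds of those that do not start with 0. The total of 9000 matches the size of the range.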

(In the first version of this post, I omitted the unit tests.) Instead of a test framework, I used a homebrew checking method:

def check_method(method, input, expected)
  actual = send(method, input)   # call once and reuse the result
  if actual != expected
    raise "#{method}(#{input}):\n" +
      "expected " + expected.inspect +
      " but got " + actual.inspect + "\n"
  end
end

This made it easy to plug in tests for count_unique_digits:

def check_digits(input, expected)
  check_method(:count_unique_digits, input, expected)
end

check_digits(0, 1)
check_digits(1, 1)
check_digits(10, 2)
check_digits(11, 1)
check_digits(123, 3)
check_digits(999, 1)
check_digits(1000, 2)

and tests for build_count_map:

def check_count_map(input, expected)
  check_method(:build_count_map, input, expected)
end

check_count_map(0 .. 0, {1 => 1})
check_count_map(0 .. 1, {1 => 2})
check_count_map(10 .. 20, {1 => 1, 2 => 10}) # only 11 has a repeated digit
check_count_map(10 .. 99, {1 => 9, 2 => 81}) # 11, 22, 33, ..., 99 are the doubles