This is much cleaner and might actually get all the coverage without a bunch of work.
diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py
index 3e0ab32..1058e21 100644
--- a/pre_commit_hooks/pretty_format_json.py
+++ b/pre_commit_hooks/pretty_format_json.py
@@ -6,95 +6,31 @@
 
 import simplejson
 
-class SortableOrderedDict(OrderedDict):
-    """Performs an in-place sort of the keys if you want."""
-    def sort(*args, **kwds):
-        self = args[0]
-        args = args[1:]
-        if 'key' not in kwds:
-            kwds['key'] = lambda x: x[0]
-        if len(args):
-            raise TypeError('expected no positional arguments got {0}'.format(len(args)))
-        sorted_od = sorted([x for x in self.items()], **kwds)
-        self.clear()
-        self.update(sorted_od)
-
-class TrackedSod(SortableOrderedDict):
-    """Tracks instances of the SortableOrderedDict."""
-    _instances = []
-    def __init__(self, *args, **kwds):
-        super(TrackedSod, self).__init__(*args, **kwds)
-        self.__track(self)
-
-    @classmethod
-    def __track(cls, obj):
-        cls._instances.append(obj)
 
 
-def _get_pretty_format(contents, indent, sort_keys=True, top_keys=[]):
-    class KeyToCmp(object):
-        def __init__(self, obj, *args):
-            self.obj = obj[0]
-        def __lt__(self, other):
-            if self.obj in top_keys and other.obj in top_keys:
-                return top_keys.index(self.obj) < top_keys.index(other.obj)
-            elif self.obj in top_keys and other.obj not in top_keys:
-                return True
-            elif self.obj not in top_keys and other.obj in top_keys:
-                return False
-            else: 
-                return self.obj < other.obj
-        def __gt__(self, other):
-            if self.obj in top_keys and other.obj in top_keys:
-                return top_keys.index(self.obj) > top_keys.index(other.obj)
-            elif self.obj in top_keys and other.obj not in top_keys:
-                return False
-            elif self.obj not in top_keys and other.obj in top_keys:
-                return True
-            else: 
-                return self.obj > other.obj
-        def __eq__(self, other):
-            if self.obj in top_keys and other.obj in top_keys:
-                return top_keys.index(self.obj) == top_keys.index(other.obj)
-            elif self.obj in top_keys and other.obj not in top_keys:
-                return False
-            elif self.obj not in top_keys and other.obj in top_keys:
-                return False
-            else: 
-                return self.obj == other.obj
-        def __le__(self, other):
-            if self.obj in top_keys and other.obj in top_keys:
-                return top_keys.index(self.obj) <= top_keys.index(other.obj)
-            elif self.obj in top_keys and other.obj not in top_keys:
-                return True
-            elif self.obj not in top_keys and other.obj in top_keys:
-                return False
-            else: 
-                return self.obj <= other.obj
-        def __ge__(self, other):
-            if self.obj in top_keys and other.obj in top_keys:
-                return top_keys.index(self.obj) >= top_keys.index(other.obj)
-            elif self.obj in top_keys and other.obj not in top_keys:
-                return False
-            elif self.obj not in top_keys and other.obj in top_keys:
-                return True
-            else: 
-                return self.obj >= other.obj
-        def __ne__(self, other):
-            if self.obj in top_keys and other.obj in top_keys:
-                return top_keys.index(self.obj) != top_keys.index(other.obj)
-            elif self.obj in top_keys and other.obj not in top_keys:
-                return False
-            elif self.obj not in top_keys and other.obj in top_keys:
-                return False
-            else: 
-                return self.obj != other.obj
-    py_obj = simplejson.loads(contents, object_pairs_hook=TrackedSod)
-    if sort_keys:
-        for tsod in TrackedSod._instances:
-            tsod.sort(key=KeyToCmp)
-    # dumps don't end with a newline
-    return simplejson.dumps(py_obj, indent=indent) + "\n"
+def _get_pretty_format(contents, indent, sort_keys=True, top_keys=()):
+    """Return *contents* re-serialized as pretty JSON plus a final newline.
+
+    In every JSON object, keys listed in *top_keys* are emitted first, in
+    *top_keys* order; the other keys follow, sorted when *sort_keys* is
+    true and in their original document order otherwise.
+    """
+    # Rank of each top key; first occurrence wins, matching list.index().
+    top_rank = {}
+    for position, key in enumerate(top_keys):
+        top_rank.setdefault(key, position)
+
+    def pairs_first(pairs):
+        before = [pair for pair in pairs if pair[0] in top_rank]
+        before = sorted(before, key=lambda pair: top_rank[pair[0]])
+        after = [pair for pair in pairs if pair[0] not in top_rank]
+        if sort_keys:
+            after = sorted(after, key=lambda pair: pair[0])
+        return OrderedDict(before + after)
+
+    py_obj = simplejson.loads(contents, object_pairs_hook=pairs_first)
+    # dumps doesn't end with a newline
+    return simplejson.dumps(py_obj, indent=indent) + "\n"
 
 def _autofix(filename, new_contents):
     print("Fixing file {0}".format(filename))