Add --top-keys: a list of keys to put at the top of each JSON hash

This adds custom sorting that preferentially places a list of top keys
at the start of every JSON hash (object) in the JSON document.
diff --git a/pre_commit_hooks/pretty_format_json.py b/pre_commit_hooks/pretty_format_json.py
index 36a90eb..3e0ab32 100644
--- a/pre_commit_hooks/pretty_format_json.py
+++ b/pre_commit_hooks/pretty_format_json.py
@@ -6,17 +6,53 @@
 
 import simplejson
 
 
+def _top_keys_sort_key(top_keys, sort_keys=True):
+    """Return a ``sorted`` key for ``(key, value)`` pairs of a JSON object.
+
+    Keys listed in ``top_keys`` come first, in ``top_keys`` order (a key
+    listed twice keeps its first position).  All other keys follow, sorted
+    when ``sort_keys`` is true; otherwise they share one sort value, so a
+    stable sort leaves them in their original order.
+    """
+    rank = {}
+    for position, key in enumerate(top_keys):
+        rank.setdefault(key, position)
+
+    def sort_key(pair):
+        key = pair[0]
+        if key in rank:
+            return (0, rank[key])
+        # The leading 0/1 separates the two groups, so an int rank is
+        # never compared against a string key.
+        return (1, key) if sort_keys else (1,)
+
+    return sort_key
+
+
-def _get_pretty_format(contents, indent, sort_keys=True):
-    return simplejson.dumps(
-        simplejson.loads(
-            contents,
-            object_pairs_hook=None if sort_keys else OrderedDict,
-        ),
-        sort_keys=sort_keys,
-        indent=indent
-    ) + "\n"  # dumps don't end with a newline
+def _get_pretty_format(contents, indent, sort_keys=True, top_keys=()):
+    """Return ``contents`` re-serialized as pretty-printed JSON.
+
+    :param contents: JSON document text.
+    :param indent: indentation passed to ``simplejson.dumps``.
+    :param sort_keys: sort the keys of every JSON object when true;
+        otherwise keep their original order.
+    :param top_keys: keys moved, in this order, to the top of every JSON
+        object containing them (with or without ``sort_keys``).
+    :returns: the serialized document, ending with a newline.
+    """
+    sort_key = _top_keys_sort_key(top_keys, sort_keys=sort_keys)
+
+    def pairs_hook(pairs):
+        # Invoked once per JSON object while parsing, so nested objects are
+        # ordered as they are built; no module-level registry of dicts.
+        return OrderedDict(sorted(pairs, key=sort_key))
+
+    return simplejson.dumps(
+        simplejson.loads(contents, object_pairs_hook=pairs_hook),
+        indent=indent,
+    ) + "\n"  # dumps don't end with a newline
 
 
 def _autofix(filename, new_contents):
     print("Fixing file {0}".format(filename))
@@ -43,6 +79,16 @@
                 'Negative integer supplied to construct JSON indentation delimiter. ',
             )
 
+
+def parse_topkeys(s):
+    # type: (str) -> list
+    """Split a comma-separated ``--top-keys`` value into a list of keys.
+
+    Whitespace around each key is stripped and empty entries are dropped,
+    so ``"a, b,"`` gives ``['a', 'b']`` and ``""`` gives ``[]``.
+    """
+    return [key.strip() for key in s.split(',') if key.strip()]
+
 
 def pretty_format_json(argv=None):
     parser = argparse.ArgumentParser()
@@ -65,6 +111,13 @@
         default=False,
         help='Keep JSON nodes in the same order',
     )
+    parser.add_argument(
+        '--top-keys',
+        type=parse_topkeys,
+        dest='top_keys',
+        default=[],
+        help='Ordered, comma-separated keys to put first in JSON objects',
+    )
 
     parser.add_argument('filenames', nargs='*', help='Filenames to fix')
     args = parser.parse_args(argv)
@@ -78,6 +131,7 @@
         try:
             pretty_contents = _get_pretty_format(
                 contents, args.indent, sort_keys=not args.no_sort_keys,
+                top_keys=args.top_keys,
             )
 
             if contents != pretty_contents: