| from __future__ import annotations |
| |
| import argparse |
| import sys |
| from typing import IO |
| from typing import NamedTuple |
| from typing import Sequence |
| |
# Encoding pragma line (without a line ending) used when the caller does not
# supply one; it is the default for both ``fix_encoding_pragma`` and the
# ``--pragma`` command-line option.
DEFAULT_PRAGMA = b'# -*- coding: utf-8 -*-'
| |
| |
def has_coding(line: bytes) -> bool:
    """Return whether ``line`` looks like a comment carrying an encoding hint.

    A line qualifies when its first non-whitespace byte is ``#`` and it
    contains any of ``unicode``, ``encoding``, ``coding:`` or ``coding=``
    anywhere in the line. Blank and whitespace-only lines never qualify.
    """
    # Only comment lines can hold a pragma; this also rejects blank lines.
    if not line.lstrip().startswith(b'#'):
        return False

    markers = (b'unicode', b'encoding', b'coding:', b'coding=')
    return any(marker in line for marker in markers)
| |
| |
class ExpectedContents(NamedTuple):
    """Parsed view of a file's head used to decide how to rewrite it."""

    # The shebang line (including its line ending), or b'' when absent.
    shebang: bytes
    # Everything in the file that is neither the shebang nor a pragma line.
    rest: bytes
    # True: has exactly the coding pragma expected
    # False: missing coding pragma entirely
    # None: has a coding pragma, but it does not match
    pragma_status: bool | None
    # Line ending to use when a pragma line has to be written.
    ending: bytes

    @property
    def has_any_pragma(self) -> bool:
        """True when some pragma is present, matching or not."""
        pragma_missing = self.pragma_status is False
        return not pragma_missing

    def is_expected_pragma(self, remove: bool) -> bool:
        """Return whether the file already is in the desired state.

        When ``remove`` is true the desired state is "no pragma at all";
        otherwise it is "exactly the expected pragma". A non-matching
        pragma (status ``None``) satisfies neither.
        """
        if remove:
            return self.pragma_status is False
        return self.pragma_status is True
| |
| |
def _get_expected_contents(
        first_line: bytes,
        second_line: bytes,
        rest: bytes,
        expected_pragma: bytes,
) -> ExpectedContents:
    """Classify the first two lines of a file into shebang / pragma / body.

    ``first_line`` and ``second_line`` are the raw leading lines (with their
    line endings), ``rest`` is everything after them, and
    ``expected_pragma`` is the pragma to compare against (no line ending).

    The returned ``rest`` never contains the shebang or a detected pragma
    line; any line that turned out to be ordinary content is put back in
    front of it.
    """
    # The line ending is taken from the first line only.
    newline = b'\n'
    if first_line.endswith(b'\r\n'):
        newline = b'\r\n'

    # A pragma may sit on line 1, or on line 2 when line 1 is a shebang.
    if first_line[:2] == b'#!':
        shebang_line, candidate, body = first_line, second_line, rest
    else:
        shebang_line, candidate, body = b'', first_line, second_line + rest

    status: bool | None
    if candidate.rstrip(b'\r\n') == expected_pragma:
        status = True
    elif has_coding(candidate):
        status = None
    else:
        status = False
        # Not a pragma at all: the candidate line is regular file content.
        body = candidate + body

    return ExpectedContents(
        shebang=shebang_line,
        rest=body,
        pragma_status=status,
        ending=newline,
    )
| |
| |
def fix_encoding_pragma(
        f: IO[bytes],
        remove: bool = False,
        expected_pragma: bytes = DEFAULT_PRAGMA,
) -> int:
    """Add, replace or remove the encoding pragma of an open binary file.

    ``f`` is read from its current position and, when a change is needed,
    rewritten in place from offset 0. With ``remove`` false the file ends up
    with ``expected_pragma`` after any shebang; with ``remove`` true it ends
    up with no pragma line.

    Returns 1 if the file was rewritten and 0 if it was left untouched.
    """
    first_line = f.readline()
    second_line = f.readline()
    remainder = f.read()
    contents = _get_expected_contents(
        first_line, second_line, remainder, expected_pragma,
    )

    # Special cases for files with no real content.
    if not contents.rest.strip():
        if not (contents.has_any_pragma or contents.shebang):
            return 0
        # Only a shebang and/or a coding pragma is present: empty the file.
        f.seek(0)
        f.truncate()
        f.write(b'')
        return 1

    # Already in the requested state: nothing to do.
    if contents.is_expected_pragma(remove):
        return 0

    # Otherwise rebuild the file: shebang, optional pragma, then the body.
    chunks = [contents.shebang]
    if not remove:
        chunks.append(expected_pragma + contents.ending)
    chunks.append(contents.rest)

    f.seek(0)
    f.truncate()
    for chunk in chunks:
        f.write(chunk)

    return 1
| |
| |
def _normalize_pragma(pragma: str) -> bytes:
    """Encode a pragma given as text and drop its trailing whitespace.

    The stripping is done on the encoded bytes, so only ASCII whitespace
    at the end is removed.
    """
    encoded = pragma.encode()
    return encoded.rstrip()
| |
| |
def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point: fix the encoding pragma of each given file.

    ``argv`` is the argument list to parse; ``None`` lets argparse fall back
    to ``sys.argv``. A deprecation warning is always printed to stderr, and
    one line is printed to stdout for every file that gets rewritten.

    Returns 1 if at least one file was modified, otherwise 0.
    """
    print(
        'warning: this hook is deprecated and will be removed in a future '
        'release because py2 is EOL. instead, use '
        'https://github.com/asottile/pyupgrade',
        file=sys.stderr,
    )

    # The text must be passed as ``description=``: the first positional
    # parameter of ArgumentParser is ``prog``, so passing it positionally
    # would show this sentence as the program name in the usage line and
    # leave the parser without a description.
    parser = argparse.ArgumentParser(
        description='Fixes the encoding pragma of python files',
    )
    parser.add_argument('filenames', nargs='*', help='Filenames to fix')
    parser.add_argument(
        '--pragma', default=DEFAULT_PRAGMA, type=_normalize_pragma,
        help=(
            'The encoding pragma to use. '
            f'Default: {DEFAULT_PRAGMA.decode()}'
        ),
    )
    parser.add_argument(
        '--remove', action='store_true',
        help='Remove the encoding pragma (Useful in a python3-only codebase)',
    )
    args = parser.parse_args(argv)

    retv = 0

    if args.remove:
        fmt = 'Removed encoding pragma from {filename}'
    else:
        fmt = 'Added `{pragma}` to {filename}'

    for filename in args.filenames:
        # 'r+b' so the file can be read and then rewritten in place.
        with open(filename, 'r+b') as f:
            file_ret = fix_encoding_pragma(
                f, remove=args.remove, expected_pragma=args.pragma,
            )
            # Any single modified file makes the overall result non-zero.
            retv |= file_ret
            if file_ret:
                print(
                    fmt.format(pragma=args.pragma.decode(), filename=filename),
                )

    return retv
| |
| |
# When executed as a script, run main() and use its return value as the
# process exit status.
if __name__ == '__main__':
    raise SystemExit(main())