diff --git a/README.md b/README.md index 97bfba66..94c62ceb 100644 --- a/README.md +++ b/README.md @@ -112,6 +112,7 @@ Checks for the existence of private keys. #### `double-quote-string-fixer` This hook replaces double quoted strings with single quoted strings. + - `--replace-single-quotes` - replaces single quoted strings with double quoted strings. #### `end-of-file-fixer` Makes sure files end in a newline and only a newline. diff --git a/pre_commit_hooks/string_fixer.py b/pre_commit_hooks/string_fixer.py index d1b1c4ae..3b683662 100644 --- a/pre_commit_hooks/string_fixer.py +++ b/pre_commit_hooks/string_fixer.py @@ -13,10 +13,10 @@ else: # pragma: <3.12 cover FSTRING_START = FSTRING_END = -1 -START_QUOTE_RE = re.compile('^[a-zA-Z]*"') +START_QUOTE_RE = re.compile("^[a-zA-Z]*['\"]") -def handle_match(token_text: str) -> str: +def handle_match(token_text: str, replace_single_quotes: bool = False) -> str: if '"""' in token_text or "'''" in token_text: return token_text @@ -25,6 +25,8 @@ def handle_match(token_text: str) -> str: meat = token_text[match.end():-1] if '"' in meat or "'" in meat: return token_text + elif replace_single_quotes: + return match.group().replace("'", '"') + meat + '"' else: return match.group().replace('"', "'") + meat + "'" else: @@ -39,7 +41,7 @@ def get_line_offsets_by_line_no(src: str) -> list[int]: return offsets -def fix_strings(filename: str) -> int: +def fix_strings(filename: str, replace_single_quotes: bool = False) -> int: with open(filename, encoding='UTF-8', newline='') as f: contents = f.read() line_offsets = get_line_offsets_by_line_no(contents) @@ -58,7 +60,9 @@ def fix_strings(filename: str) -> int: elif token_type == FSTRING_END: # pragma: >=3.12 cover fstring_depth -= 1 elif fstring_depth == 0 and token_type == tokenize.STRING: - new_text = handle_match(token_text) + new_text = handle_match( + token_text, replace_single_quotes=replace_single_quotes + ) splitcontents[ line_offsets[srow] + scol: line_offsets[erow] + ecol @@ -76,12 +80,20 @@ def fix_strings(filename: str) -> int: def main(argv: Sequence[str] | None = None) -> int: parser = argparse.ArgumentParser() parser.add_argument('filenames', nargs='*', help='Filenames to fix') + parser.add_argument( + '--replace-single-quotes', + action='store_true', + default=False, + help='Replace single quotes into double quotes', + ) args = parser.parse_args(argv) retv = 0 for filename in args.filenames: - return_value = fix_strings(filename) + return_value = fix_strings( + filename, replace_single_quotes=args.replace_single_quotes + ) if return_value != 0: print(f'Fixing strings in {filename}') retv |= return_value diff --git a/tests/string_fixer_test.py b/tests/string_fixer_test.py index 8eb164c5..ce5ce884 100644 --- a/tests/string_fixer_test.py +++ b/tests/string_fixer_test.py @@ -8,17 +8,20 @@ TESTS = ( # Base cases - ("''", "''", 0), - ('""', "''", 1), - (r'"\'"', r'"\'"', 0), - (r'"\""', r'"\""', 0), - (r"'\"\"'", r"'\"\"'", 0), + ("''", "''", False, 0), + ("''", '""', True, 1), + ('""', "''", False, 1), + ('""', '""', True, 0), + (r'"\'"', r'"\'"', False, 0), + (r'"\""', r'"\""', False, 0), + (r"'\"\"'", r"'\"\"'", False, 0), # String somewhere in the line - ('x = "foo"', "x = 'foo'", 1), + ('x = "foo"', "x = 'foo'", False, 1), + ("x = 'foo'", 'x = "foo"', True, 1), # Test escaped characters - (r'"\'"', r'"\'"', 0), + (r'"\'"', r'"\'"', False, 0), # Docstring - ('""" Foo """', '""" Foo """', 0), + ('""" Foo """', '""" Foo """', False, 0), ( textwrap.dedent( """ @@ -34,23 +37,49 @@ '\n """, ), + False, 1, ), - ('"foo""bar"', "'foo''bar'", 1), + ( + textwrap.dedent( + """ + x = ' \\ + foo \\ + '\n + """, + ), + textwrap.dedent( + """ + x = " \\ + foo \\ + "\n + """, + ), + True, + 1, + ), + ('"foo""bar"', "'foo''bar'", False, 1), + ("'foo''bar'", '"foo""bar"', True, 1), pytest.param( "f'hello{\"world\"}'", "f'hello{\"world\"}'", + False, 0, id='ignore nested fstrings', ), ) -@pytest.mark.parametrize(('input_s', 'output', 'expected_retval'), TESTS) -def test_rewrite(input_s, output, expected_retval, tmpdir): +@pytest.mark.parametrize(('input_s', 'output', 'reversed_case', 'expected_retval'), TESTS) +def test_rewrite(input_s, output, reversed_case, expected_retval, tmpdir): path = tmpdir.join('file.py') path.write(input_s) - retval = main([str(path)]) + + argv = [str(path)] + if reversed_case: + argv.append("--replace-single-quotes") + retval = main(argv) + assert path.read() == output assert retval == expected_retval