diff --git a/utils/round-trip-syntax-test b/utils/round-trip-syntax-test index fc09165211488..18b374544c5a8 100755 --- a/utils/round-trip-syntax-test +++ b/utils/round-trip-syntax-test @@ -9,6 +9,7 @@ import os import subprocess import sys import tempfile +from functools import reduce logging.basicConfig(format='%(message)s', level=logging.INFO) @@ -17,7 +18,8 @@ class RoundTripTask(object): def __init__(self, input_filename, action, swift_syntax_test, skip_bad_syntax): assert action == '-round-trip-parse' or action == '-round-trip-lex' - assert type(input_filename) == unicode + if sys.version_info[0] < 3: + assert type(input_filename) == unicode assert type(swift_syntax_test) == str assert os.path.isfile(input_filename), \ @@ -51,9 +53,9 @@ class RoundTripTask(object): self.output_file.close() self.stderr_file.close() - with open(self.output_file.name, 'r') as stdout_in: + with open(self.output_file.name, 'rb') as stdout_in: self.stdout = stdout_in.read() - with open(self.stderr_file.name, 'r') as stderr_in: + with open(self.stderr_file.name, 'rb') as stderr_in: self.stderr = stderr_in.read() os.remove(self.output_file.name) @@ -75,7 +77,7 @@ class RoundTripTask(object): raise RuntimeError() contents = ''.join(map(lambda l: l.decode('utf-8', errors='replace'), - open(self.input_filename).readlines())) + open(self.input_filename, 'rb').readlines())) stdout_contents = self.stdout.decode('utf-8', errors='replace') if contents == stdout_contents: @@ -92,7 +94,7 @@ def swift_files_in_dir(d): swift_files = [] for root, dirs, files in os.walk(d): for basename in files: - if not basename.decode('utf-8').endswith('.swift'): + if not basename.endswith('.swift'): continue abs_file = os.path.abspath(os.path.join(root, basename)) swift_files.append(abs_file) @@ -149,7 +151,8 @@ This driver invokes swift-syntax-test using -round-trip-lex and all_input_files = [filename for dir_listing in dir_listings for filename in dir_listing] all_input_files += args.individual_input_files - all_input_files = [f.decode('utf-8') for f in all_input_files] + if sys.version_info[0] < 3: + all_input_files = [f.decode('utf-8') for f in all_input_files] if len(all_input_files) == 0: logging.error('No input files!')