Handle overwriting in FileWriter instead of Tracer

This commit is contained in:
Alex Hall 2019-05-11 17:56:41 +02:00 committed by Ram Rachum
parent 57283caf47
commit d200457d63

View file

@ -100,8 +100,11 @@ def get_source_from_frame(frame):
return source
def get_write_and_truncate_functions(output):
truncate = None
def get_write_function(output, overwrite):
is_path = isinstance(output, (pycompat.PathLike, str))
if overwrite and not is_path:
raise Exception('`overwrite=True` can only be used when writing '
'content to file.')
if output is None:
def write(s):
stderr = sys.stderr
@ -110,14 +113,8 @@ def get_write_and_truncate_functions(output):
except UnicodeEncodeError:
# God damn Python 2
stderr.write(utils.shitcode(s))
elif isinstance(output, (pycompat.PathLike, str)):
def write(s):
with open(six.text_type(output), 'a') as output_file:
output_file.write(s)
def truncate():
with open(six.text_type(output), 'w'):
pass
elif is_path:
return FileWriter(output, overwrite).write
elif callable(output):
write = output
else:
@ -125,7 +122,18 @@ def get_write_and_truncate_functions(output):
def write(s):
output.write(s)
return write, truncate
return write
class FileWriter(object):
def __init__(self, path, overwrite):
self.path = six.text_type(path)
self.overwrite = overwrite
def write(self, s):
with open(self.path, 'w' if self.overwrite else 'a') as output_file:
output_file.write(s)
self.overwrite = False
thread_global = threading.local()
@ -180,11 +188,7 @@ class Tracer:
overwrite=False,
thread_info=False,
):
self._write, self.truncate = get_write_and_truncate_functions(output)
if self.truncate is None and overwrite:
raise Exception("`overwrite=True` can only be used when writing "
"content to file.")
self._write = get_write_function(output, overwrite)
self.watch = [
v if isinstance(v, BaseVariable) else CommonVariable(v)
@ -196,8 +200,6 @@ class Tracer:
self.frame_to_local_reprs = {}
self.depth = depth
self.prefix = prefix
self.overwrite = overwrite
self._did_overwrite = False
self.thread_info = thread_info
self.thread_info_padding = 0
assert self.depth >= 1
@ -237,9 +239,6 @@ class Tracer:
return simple_wrapper
def write(self, s):
if self.overwrite and not self._did_overwrite:
self.truncate()
self._did_overwrite = True
s = u'{self.prefix}{s}\n'.format(**locals())
self._write(s)