diff --git a/pysnooper/tracer.py b/pysnooper/tracer.py index b05b0f9..3245e8a 100644 --- a/pysnooper/tracer.py +++ b/pysnooper/tracer.py @@ -100,8 +100,11 @@ def get_source_from_frame(frame): return source -def get_write_and_truncate_functions(output): - truncate = None +def get_write_function(output, overwrite): + is_path = isinstance(output, (pycompat.PathLike, str)) + if overwrite and not is_path: + raise Exception('`overwrite=True` can only be used when writing ' + 'content to file.') if output is None: def write(s): stderr = sys.stderr @@ -110,14 +113,8 @@ def get_write_and_truncate_functions(output): except UnicodeEncodeError: # God damn Python 2 stderr.write(utils.shitcode(s)) - elif isinstance(output, (pycompat.PathLike, str)): - def write(s): - with open(six.text_type(output), 'a') as output_file: - output_file.write(s) - - def truncate(): - with open(six.text_type(output), 'w'): - pass + elif is_path: + return FileWriter(output, overwrite).write elif callable(output): write = output else: @@ -125,7 +122,18 @@ def get_write_and_truncate_functions(output): def write(s): output.write(s) - return write, truncate + return write + + +class FileWriter(object): + def __init__(self, path, overwrite): + self.path = six.text_type(path) + self.overwrite = overwrite + + def write(self, s): + with open(self.path, 'w' if self.overwrite else 'a') as output_file: + output_file.write(s) + self.overwrite = False thread_global = threading.local() @@ -180,11 +188,7 @@ class Tracer: overwrite=False, thread_info=False, ): - self._write, self.truncate = get_write_and_truncate_functions(output) - - if self.truncate is None and overwrite: - raise Exception("`overwrite=True` can only be used when writing " - "content to file.") + self._write = get_write_function(output, overwrite) self.watch = [ v if isinstance(v, BaseVariable) else CommonVariable(v) @@ -196,8 +200,6 @@ class Tracer: self.frame_to_local_reprs = {} self.depth = depth self.prefix = prefix - self.overwrite = overwrite - self._did_overwrite = False self.thread_info = thread_info self.thread_info_padding = 0 assert self.depth >= 1 @@ -237,9 +239,6 @@ class Tracer: return simple_wrapper def write(self, s): - if self.overwrite and not self._did_overwrite: - self.truncate() - self._did_overwrite = True s = u'{self.prefix}{s}\n'.format(**locals()) self._write(s)