diff --git a/pysnooper/tracer.py b/pysnooper/tracer.py index 0914820..5929c87 100644 --- a/pysnooper/tracer.py +++ b/pysnooper/tracer.py @@ -99,6 +99,34 @@ def get_source_from_frame(frame): return source +def get_write_and_truncate_functions(output): + truncate = None + if output is None: + def write(s): + stderr = sys.stderr + try: + stderr.write(s) + except UnicodeEncodeError: + # God damn Python 2 + stderr.write(utils.shitcode(s)) + elif isinstance(output, (pycompat.PathLike, str)): + def write(s): + with open(six.text_type(output), 'a') as output_file: + output_file.write(s) + + def truncate(): + with open(six.text_type(output), 'w'): + pass + elif callable(output): + write = output + else: + assert isinstance(output, utils.WritableStream) + + def write(s): + output.write(s) + return write, truncate + + class Tracer: def __init__( self, @@ -142,38 +170,12 @@ class Tracer: @pysnooper.snoop(prefix='ZZZ ') ''' - self.truncate = None - if output is None: - def write(s): - stderr = sys.stderr - try: - stderr.write(s) - except UnicodeEncodeError: - # God damn Python 2 - stderr.write(utils.shitcode(s)) - elif isinstance(output, (pycompat.PathLike, str)): - def write(s): - with open(six.text_type(output), 'a') as output_file: - output_file.write(s) - - def truncate(): - with open(six.text_type(output), 'w'): - pass - - self.truncate = truncate - elif callable(output): - write = output - else: - assert isinstance(output, utils.WritableStream) - - def write(s): - output.write(s) + self._write, self.truncate = get_write_and_truncate_functions(output) if self.truncate is None and overwrite: raise Exception("`overwrite=True` can only be used when writing " "content to file.") - self._write = write self.watch = [ v if isinstance(v, BaseVariable) else CommonVariable(v) for v in utils.ensure_tuple(watch)