diff --git a/pysnooper/pysnooper.py b/pysnooper/pysnooper.py index bc17403..849d923 100644 --- a/pysnooper/pysnooper.py +++ b/pysnooper/pysnooper.py @@ -27,6 +27,9 @@ def get_write_and_truncate_functions(output): def truncate(): with open(output, 'w') as output_file: pass + elif callable(output): + write = output + truncate = None else: assert isinstance(output, utils.WritableStream) def write(s): diff --git a/tests/test_pysnooper.py b/tests/test_pysnooper.py index 03567c5..734296e 100644 --- a/tests/test_pysnooper.py +++ b/tests/test_pysnooper.py @@ -43,6 +43,39 @@ def test_string_io(): ) + +def test_callable(): + string_io = io.StringIO() + + def write(msg): + string_io.write(msg) + + @pysnooper.snoop(write) + def my_function(foo): + x = 7 + y = 8 + return y + x + + result = my_function('baba') + assert result == 15 + output = string_io.getvalue() + assert_output( + output, + ( + VariableEntry('foo', value_regex="u?'baba'"), + CallEntry('def my_function(foo):'), + LineEntry('x = 7'), + VariableEntry('x', '7'), + LineEntry('y = 8'), + VariableEntry('y', '8'), + LineEntry('return y + x'), + ReturnEntry('return y + x'), + ReturnValueEntry('15'), + ) + ) + + + def test_variables(): class Foo(object):