Make tests compatible with Python 3.10

This commit is contained in:
Lumir Balhar 2021-05-18 15:05:38 +02:00 committed by Ram Rachum
parent c539cbc520
commit 1c94b1af52
2 changed files with 60 additions and 13 deletions

View file

@ -140,6 +140,7 @@ def test_relative_time():
# In with in recursive call
LineEntry('qux()'),
LineEntry(source_regex="with snoop:", min_python_version=(3, 10)),
ElapsedTimeEntry(0.4),
# Call to bar3 from after with
@ -168,6 +169,7 @@ def test_relative_time():
# In with in first call
LineEntry('qux()'),
LineEntry(source_regex="with snoop:", min_python_version=(3, 10)),
ElapsedTimeEntry(0.7),
# Call to bar3 from after with
@ -1086,6 +1088,7 @@ def test_with_block(normalize):
# In with in recursive call
LineEntry('qux()'),
LineEntry(source_regex="with snoop:", min_python_version=(3, 10)),
ElapsedTimeEntry(),
# Call to bar3 from after with
@ -1114,6 +1117,7 @@ def test_with_block(normalize):
# In with in first call
LineEntry('qux()'),
LineEntry(source_regex="with snoop:", min_python_version=(3, 10)),
ElapsedTimeEntry(),
# Call to bar3 from after with
@ -1183,6 +1187,8 @@ def test_with_block_depth(normalize):
LineEntry(),
ReturnEntry(),
ReturnValueEntry('20'),
VariableEntry(min_python_version=(3, 10)),
LineEntry(source_regex="with pysnooper.snoop.*", min_python_version=(3, 10)),
ElapsedTimeEntry(),
),
normalize=normalize,
@ -1250,6 +1256,8 @@ def test_cellvars(normalize):
ReturnValueEntry(),
ReturnEntry(),
ReturnValueEntry(),
VariableEntry(min_python_version=(3, 10)),
LineEntry(source_regex="with pysnooper.snoop.*", min_python_version=(3, 10)),
ElapsedTimeEntry(),
),
normalize=normalize,
@ -1298,6 +1306,8 @@ def test_var_order(normalize):
VariableEntry("seven", "7"),
ReturnEntry(),
ReturnValueEntry(),
VariableEntry("result", "None", min_python_version=(3, 10)),
LineEntry(source_regex="with pysnooper.snoop.*", min_python_version=(3, 10)),
ElapsedTimeEntry(),
),
normalize=normalize,

View file

@ -4,6 +4,7 @@ import os
import re
import abc
import inspect
import sys
from pysnooper.utils import DEFAULT_REPR_RE
@ -30,13 +31,24 @@ def get_function_arguments(function, exclude=()):
class _BaseEntry(pysnooper.pycompat.ABC):
def __init__(self, prefix=''):
def __init__(self, prefix='', min_python_version=None, max_python_version=None):
self.prefix = prefix
self.min_python_version = min_python_version
self.max_python_version = max_python_version
@abc.abstractmethod
def check(self, s):
pass
def is_compatible_with_current_python_version(self):
compatible = True
if self.min_python_version and self.min_python_version > sys.version_info:
compatible = False
if self.max_python_version and self.max_python_version < sys.version_info:
compatible = False
return compatible
def __repr__(self):
init_arguments = get_function_arguments(self.__init__,
exclude=('self',))
@ -53,8 +65,11 @@ class _BaseEntry(pysnooper.pycompat.ABC):
class _BaseValueEntry(_BaseEntry):
def __init__(self, prefix=''):
_BaseEntry.__init__(self, prefix=prefix)
def __init__(self, prefix='', min_python_version=None,
max_python_version=None):
_BaseEntry.__init__(self, prefix=prefix,
min_python_version=min_python_version,
max_python_version=max_python_version)
self.line_pattern = re.compile(
r"""^%s(?P<indent>(?: {4})*)(?P<preamble>[^:]*):"""
r"""\.{2,7} (?P<content>.*)$""" % (re.escape(self.prefix),)
@ -78,8 +93,11 @@ class _BaseValueEntry(_BaseEntry):
class ElapsedTimeEntry(_BaseEntry):
def __init__(self, elapsed_time_value=None, tolerance=0.2, prefix=''):
_BaseEntry.__init__(self, prefix=prefix)
def __init__(self, elapsed_time_value=None, tolerance=0.2, prefix='',
min_python_version=None, max_python_version=None):
_BaseEntry.__init__(self, prefix=prefix,
min_python_version=min_python_version,
max_python_version=max_python_version)
self.line_pattern = re.compile(
r"""^%s(?P<indent>(?: {4})*)Elapsed time: (?P<time>.*)""" % (
re.escape(self.prefix),
@ -116,8 +134,11 @@ class CallEndedByExceptionEntry(_BaseEntry):
class VariableEntry(_BaseValueEntry):
def __init__(self, name=None, value=None, stage=None, prefix='',
name_regex=None, value_regex=None):
_BaseValueEntry.__init__(self, prefix=prefix)
name_regex=None, value_regex=None, min_python_version=None,
max_python_version=None):
_BaseValueEntry.__init__(self, prefix=prefix,
min_python_version=min_python_version,
max_python_version=max_python_version)
if name is not None:
assert name_regex is None
if value is not None:
@ -179,8 +200,11 @@ class VariableEntry(_BaseValueEntry):
class _BaseSimpleValueEntry(_BaseValueEntry):
def __init__(self, value=None, value_regex=None, prefix=''):
_BaseValueEntry.__init__(self, prefix=prefix)
def __init__(self, value=None, value_regex=None, prefix='',
min_python_version=None, max_python_version=None):
_BaseValueEntry.__init__(self, prefix=prefix,
min_python_version=min_python_version,
max_python_version=max_python_version)
if value is not None:
assert value_regex is None
@ -240,8 +264,11 @@ class SourcePathEntry(_BaseValueEntry):
class _BaseEventEntry(_BaseEntry):
def __init__(self, source=None, source_regex=None, thread_info=None,
thread_info_regex=None, prefix=''):
_BaseEntry.__init__(self, prefix=prefix)
thread_info_regex=None, prefix='', min_python_version=None,
max_python_version=None):
_BaseEntry.__init__(self, prefix=prefix,
min_python_version=min_python_version,
max_python_version=max_python_version)
if type(self) is _BaseEventEntry:
raise TypeError
if source is not None:
@ -345,16 +372,26 @@ def assert_output(output, expected_entries, prefix=None, normalize=False):
if normalize:
verify_normalize(lines, prefix)
# Filter only entries compatible with the current Python
filtered_expected_entries = []
for expected_entry in expected_entries:
if isinstance(expected_entry, _BaseEntry):
if expected_entry.is_compatible_with_current_python_version():
filtered_expected_entries.append(expected_entry)
else:
filtered_expected_entries.append(expected_entry)
expected_entries_count = len(filtered_expected_entries)
any_mismatch = False
result = ''
template = u'\n{line!s:%s} {expected_entry} {arrow}' % max(map(len, lines))
for expected_entry, line in zip_longest(expected_entries, lines, fillvalue=""):
for expected_entry, line in zip_longest(filtered_expected_entries, lines, fillvalue=""):
mismatch = not (expected_entry and expected_entry.check(line))
any_mismatch |= mismatch
arrow = '<===' * mismatch
result += template.format(**locals())
if len(lines) != len(expected_entries):
if len(lines) != expected_entries_count:
result += '\nOutput has {} lines, while we expect {} lines.'.format(
len(lines), len(expected_entries))