-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathparametrize.py
132 lines (114 loc) · 4.66 KB
/
parametrize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import ast
import itertools
import sys
from typing import Optional

from flake8_plugin_utils import Visitor, check_equivalent_nodes

from ..config import (
    Config,
    ParametrizeNamesType,
    ParametrizeValuesRowType,
    ParametrizeValuesType,
)
from ..errors import (
    DuplicateParametrizeTestCases,
    ParametrizeNamesWrongType,
    ParametrizeValuesWrongType,
)
from ..utils import extract_parametrize_call_args, is_parametrize_call


class ParametrizeVisitor(Visitor[Config]):
    """Checks ``pytest.mark.parametrize`` calls for PT006, PT007 and PT014."""

    @staticmethod
    def _get_str_literal(node: ast.AST) -> Optional[str]:
        """
        Returns the value of ``node`` if it is a ``str`` literal, else None.

        Since Python 3.8 the parser emits ``ast.Constant`` for every literal.
        ``ast.Str`` has been deprecated since then and is removed in Python
        3.14, so it is only consulted on older interpreters; the version
        check short-circuits before ``ast.Str`` is ever looked up.
        """
        if isinstance(node, ast.Constant):
            return node.value if isinstance(node.value, str) else None
        if sys.version_info < (3, 8) and isinstance(node, ast.Str):
            return node.s
        return None

    def _check_parametrize_names(
        self, node: ast.Call, names: ast.AST
    ) -> Optional[bool]:
        """
        Handles names in parametrize, checks for PT006.

        Returns a flag indicating whether parametrize has multiple names,
        or None if we can't tell (names are neither a string literal nor a
        list/tuple literal, e.g. a variable).
        """
        multiple_names: Optional[bool] = None
        found_type: Optional[ParametrizeNamesType] = None
        names_str = self._get_str_literal(names)
        if names_str is not None:
            # A comma-separated string (e.g. 'a,b') declares several names.
            if ',' in names_str:
                found_type = ParametrizeNamesType.CSV
                multiple_names = True
            else:
                multiple_names = False
        elif isinstance(names, (ast.List, ast.Tuple)):
            multiple_names = len(names.elts) > 1
            if not multiple_names:
                # A list/tuple with at most one name should be a plain string.
                self.error_from_node(
                    ParametrizeNamesWrongType, node, expected_type='string'
                )
            elif isinstance(names, ast.Tuple):
                found_type = ParametrizeNamesType.TUPLE
            else:
                found_type = ParametrizeNamesType.LIST
        # Only multi-name declarations are compared with the configured style.
        if multiple_names and found_type != self.config.parametrize_names_type:
            self.error_from_node(
                ParametrizeNamesWrongType,
                node,
                expected_type=self.config.parametrize_names_type.value,
            )
        return multiple_names

    def _get_expected_values_type_str(self, multiple_names: Optional[bool]) -> str:
        """
        Builds the human-readable expected values type for PT007 messages.

        With multiple names the rows matter too, giving
        '<values type> of <row type>s'; otherwise only the values type.
        """
        if multiple_names:
            return (
                f'{self.config.parametrize_values_type.value}'
                f' of {self.config.parametrize_values_row_type.value}s'
            )
        return self.config.parametrize_values_type.value

    def _check_parametrize_values(
        self, node: ast.Call, values: Optional[ast.AST], multiple_names: Optional[bool]
    ) -> None:
        """
        Checks for PT007.

        Only list/tuple literals are inspected; anything else (including
        None) is skipped. At most one error is reported per call: either for
        the top-level container type or for the first mistyped row.
        """
        expected_type_str = self._get_expected_values_type_str(multiple_names)
        if isinstance(values, ast.List):
            top_level_type = ParametrizeValuesType.LIST
        elif isinstance(values, ast.Tuple):
            top_level_type = ParametrizeValuesType.TUPLE
        else:
            return
        if top_level_type != self.config.parametrize_values_type:
            self.error_from_node(
                ParametrizeValuesWrongType, node, expected_type=expected_type_str
            )
            return
        # Row types only matter when each value row supplies several names.
        if multiple_names:
            for element in values.elts:
                found_row_type: Optional[ParametrizeValuesRowType] = None
                if isinstance(element, ast.List):
                    found_row_type = ParametrizeValuesRowType.LIST
                elif isinstance(element, ast.Tuple):
                    found_row_type = ParametrizeValuesRowType.TUPLE
                # Rows that are not list/tuple literals (e.g. calls) are ignored.
                if (
                    found_row_type
                    and found_row_type != self.config.parametrize_values_row_type
                ):
                    self.error_from_node(
                        ParametrizeValuesWrongType,
                        node,
                        expected_type=expected_type_str,
                    )
                    break

    def _check_parametrize_duplicates(
        self, node: ast.AST, values: Optional[ast.AST]
    ) -> None:
        """
        Checks for PT014.

        Compares every pair of elements of a list/tuple/set literal and
        reports each equivalent pair with its 1-based positions.
        """
        if not isinstance(values, (ast.List, ast.Tuple, ast.Set)):
            return
        for (i, element1), (j, element2) in itertools.combinations(
            enumerate(values.elts, start=1), 2
        ):
            if check_equivalent_nodes(element1, element2):
                self.error_from_node(
                    DuplicateParametrizeTestCases, node, indexes=(i, j)
                )

    def _check_parametrize_call(self, node: ast.Call) -> None:
        """Checks for all violations regarding `pytest.mark.parametrize` calls."""
        args = extract_parametrize_call_args(node)
        if not args:
            return
        multiple_names = self._check_parametrize_names(node, args.names)
        self._check_parametrize_values(node, args.values, multiple_names)
        self._check_parametrize_duplicates(node, args.values)

    def visit_Call(self, node: ast.Call) -> None:
        """Checks parametrize calls, then continues into the call's children."""
        if is_parametrize_call(node):
            self._check_parametrize_call(node)
        # ast.NodeVisitor does not descend on its own once visit_Call is
        # defined; without this, parametrize calls nested inside another call
        # (e.g. `pytest.mark.parametrize('a', [1, 1])(impl)`) are never checked.
        self.generic_visit(node)