aboutsummaryrefslogtreecommitdiffstats
path: root/3rdparty/pybind11/tests/test_sequences_and_iterators.py
blob: 062e3b3d303a3d96de33a4b18edbd8cbc572e8b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
import pytest
from pytest import approx

from pybind11_tests import ConstructorStats
from pybind11_tests import sequences_and_iterators as m


def test_slice_constructors():
    """Slices returned by the module's factory helpers equal the expected slices."""
    forward = m.make_forward_slice_size_t()
    assert forward == slice(0, -1, 1)

    backward = m.make_reversed_slice_object()
    assert backward == slice(None, None, -1)


@pytest.mark.skipif(not m.has_optional, reason="no <optional>")
def test_slice_constructors_explicit_optional():
    """Both optional-based factory helpers return a slice equal to ``[::-1]``."""
    fully_reversed = slice(None, None, -1)
    assert m.make_reversed_slice_size_t_optional() == fully_reversed
    assert m.make_reversed_slice_size_t_optional_verbose() == fully_reversed


def test_generalized_iterators():
    """Check the ``nonzero*`` iterators of ``IntPairs`` and their exhaustion behavior."""
    # Each row: (input pairs, expected items, expected keys, expected values)
    cases = [
        ([(1, 2), (3, 4), (0, 5)], [(1, 2), (3, 4)], [1, 3], [2, 4]),
        ([(1, 2), (2, 0), (0, 3), (4, 5)], [(1, 2)], [1], [2]),
        ([(0, 3), (1, 2), (3, 4)], [], [], []),
    ]

    for pairs, items, _, _ in cases:
        assert list(m.IntPairs(pairs).nonzero()) == items

    for pairs, _, keys, _ in cases:
        assert list(m.IntPairs(pairs).nonzero_keys()) == keys

    for pairs, _, _, values in cases:
        assert list(m.IntPairs(pairs).nonzero_values()) == values

    # __next__ must continue to raise StopIteration
    for method_name in ("nonzero", "nonzero_keys"):
        exhausted = getattr(m.IntPairs([(0, 0)]), method_name)()
        for _ in range(3):
            with pytest.raises(StopIteration):
                next(exhausted)


def test_nonref_iterators():
    """The ``nonref*`` iterators yield every pair, key and value in input order."""
    data = [(1, 2), (3, 4), (0, 5)]
    container = m.IntPairs(data)

    assert list(container.nonref()) == [(1, 2), (3, 4), (0, 5)]
    assert list(container.nonref_keys()) == [1, 3, 0]
    assert list(container.nonref_values()) == [2, 4, 5]


def test_generalized_iterators_simple():
    """The ``simple_*`` iterators return all pairs, keys and values of the input."""
    expected_items = [(1, 2), (3, 4), (0, 5)]
    expected_keys = [1, 3, 0]
    expected_values = [2, 4, 5]

    # A fresh container is built for every check.
    assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_iterator()) == expected_items
    assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_keys()) == expected_keys
    assert list(m.IntPairs([(1, 2), (3, 4), (0, 5)]).simple_values()) == expected_values


def test_iterator_referencing():
    """Test that iterators reference rather than copy their referents."""
    container = m.VectorNonCopyableInt()
    for number in (3, 5):
        container.append(number)
    assert list(map(int, container)) == [3, 5]
    # Mutate each element through the iterator; the change must be visible
    # when the container is iterated again.
    for item in container:
        item.set(int(item) + 1)
    assert list(map(int, container)) == [4, 6]

    container = m.VectorNonCopyableIntPair()
    for pair in ([3, 4], [5, 7]):
        container.append(pair)
    assert list(map(int, container.keys())) == [3, 5]
    assert list(map(int, container.values())) == [4, 7]
    # Keys and values are bumped by different amounts so the two views can be
    # told apart in the final assertions.
    for item in container.keys():
        item.set(int(item) + 1)
    for item in container.values():
        item.set(int(item) + 10)
    assert list(map(int, container.keys())) == [4, 6]
    assert list(map(int, container.values())) == [14, 17]


def test_sliceable():
    """Index ``Sliceable(100)`` with a range of slices and check each returned tuple."""
    subject = m.Sliceable(100)

    # Each row: (slice used as the subscript, tuple expected back)
    cases = [
        (slice(None, None, None), (0, 100, 1)),
        (slice(10, None, None), (10, 100, 1)),
        (slice(None, 10, None), (0, 10, 1)),
        (slice(None, None, 10), (0, 100, 10)),
        (slice(-10, None, None), (90, 100, 1)),
        (slice(None, -10, None), (0, 90, 1)),
        (slice(None, None, -10), (99, -1, -10)),
        (slice(50, 60, 1), (50, 60, 1)),
        (slice(50, 60, -1), (50, 60, -1)),
    ]
    for key, triple in cases:
        assert subject[key] == triple


def test_sequence():
    """Exercise ``m.Sequence`` while tracking its instances via ``ConstructorStats``.

    Covers ``repr``, ``len``, item get/set, membership, ``reversed()``, slicing,
    extended-slice assignment and iterator exhaustion. Statement order matters:
    the ``cstats.values()`` and ``cstats.alive()`` assertions depend on exactly
    which ``Sequence`` objects have been created and released so far.
    """
    cstats = ConstructorStats.get(m.Sequence)

    s = m.Sequence(5)
    # NOTE(review): the values() assertions in this test appear to rely on each
    # call reporting only what was recorded since the previous call -- confirm
    # against ConstructorStats.
    assert cstats.values() == ["of size", "5"]

    assert "Sequence" in repr(s)
    assert len(s) == 5
    # A freshly built sequence compares equal to zero at the sampled indices.
    assert s[0] == 0 and s[3] == 0
    assert 12.34 not in s
    s[0], s[3] = 12.34, 56.78
    assert 12.34 in s
    # Stored values are checked with a relative tolerance rather than exactly.
    assert s[0] == approx(12.34, rel=1e-05)
    assert s[3] == approx(56.78, rel=1e-05)

    # reversed() and a [::-1] slice are each expected to record a size-5 entry.
    rev = reversed(s)
    assert cstats.values() == ["of size", "5"]

    rev2 = s[::-1]
    assert cstats.values() == ["of size", "5"]

    # Iterator over an empty sequence.
    it = iter(m.Sequence(0))
    for _ in range(3):  # __next__ must continue to raise StopIteration
        with pytest.raises(StopIteration):
            next(it)
    assert cstats.values() == ["of size", "0"]

    # Both reversal results hold the elements of ``s`` in reverse order.
    expected = [0, 56.78, 0, 0, 12.34]
    assert rev == approx(expected, rel=1e-05)
    assert rev2 == approx(expected, rel=1e-05)
    assert rev == rev2

    # Extended-slice assignment: overwrite indices 0, 2 and 4 of ``rev`` from a
    # temporary Sequence built from a list.
    rev[0::2] = m.Sequence([2.0, 2.0, 2.0])
    assert cstats.values() == ["of size", "3", "from std::vector"]

    assert rev == approx([2, 56.78, 2, 0, 2], rel=1e-05)

    # Four tracked instances are expected to be alive here; each ``del`` below
    # must bring the count down by exactly one.
    assert cstats.alive() == 4
    del it
    assert cstats.alive() == 3
    del s
    assert cstats.alive() == 2
    del rev
    assert cstats.alive() == 1
    del rev2
    assert cstats.alive() == 0

    # Final bookkeeping: nothing left to report, no default/copy constructions
    # and no assignments; at least one move construction was recorded.
    assert cstats.values() == []
    assert cstats.default_constructions == 0
    assert cstats.copy_constructions == 0
    assert cstats.move_constructions >= 1
    assert cstats.copy_assignments == 0
    assert cstats.move_assignments == 0


def test_sequence_length():
    """#2076: Exception raised by len(arg) should be propagated"""

    class LenFailure(RuntimeError):
        """Marker exception raised from ``__len__``."""

    class BrokenSequence:
        """Indexable object whose ``__len__`` always fails."""

        def __len__(self):
            raise LenFailure()

        def __getitem__(self, index):
            return None

    broken = BrokenSequence()
    with pytest.raises(LenFailure):
        m.sequence_length(broken)

    # Well-behaved sequences still report their length.
    for sized, count in (([1, 2, 3], 3), ("hello", 5)):
        assert m.sequence_length(sized) == count


def test_map_iterator():
    """Check lookup, insertion and iteration on ``StringMap``, plus iterator exhaustion."""
    initial = {"hi": "bye", "black": "white"}
    string_map = m.StringMap(initial)
    assert string_map["hi"] == "bye"
    assert len(string_map) == 2
    assert string_map["black"] == "white"

    # A missing key raises until it has been assigned.
    with pytest.raises(KeyError):
        assert string_map["orange"]
    string_map["orange"] = "banana"
    assert string_map["orange"] == "banana"

    reference = dict(initial, orange="banana")
    for key in string_map:
        assert string_map[key] == reference[key]
    for key, value in string_map.items():
        assert value == reference[key]
    # values() must line up with the key iteration order.
    assert list(string_map.values()) == [reference[key] for key in string_map]

    empty_iter = iter(m.StringMap({}))
    for _ in range(3):  # __next__ must continue to raise StopIteration
        with pytest.raises(StopIteration):
            next(empty_iter)


def test_python_iterator_in_cpp():
    """Hand Python iterables and iterators to the module; check results and errors."""
    source = (1, 2, 3)
    as_list = [1, 2, 3]
    assert m.object_to_list(source) == as_list
    assert m.object_to_list(iter(source)) == as_list
    assert m.iterator_to_list(iter(source)) == as_list

    # A non-iterable argument is rejected with a TypeError by both helpers,
    # each with its own message.
    with pytest.raises(TypeError) as excinfo:
        m.object_to_list(1)
    assert "object is not iterable" in str(excinfo.value)

    with pytest.raises(TypeError) as excinfo:
        m.iterator_to_list(1)
    assert "incompatible function arguments" in str(excinfo.value)

    advance_error = "py::iterator::advance() should propagate errors"

    def raise_on_call():
        raise RuntimeError(advance_error)

    # iter(callable, sentinel) invokes the callable on every step, so the
    # error surfaces while the iterator is being consumed.
    with pytest.raises(RuntimeError) as excinfo:
        m.iterator_to_list(iter(raise_on_call, None))
    assert str(excinfo.value) == advance_error

    with_nones = [1, None, 0, None]
    assert m.count_none(with_nones) == 2
    assert m.find_none(with_nones) is True
    assert m.count_nonzeros({"a": 0, "b": 1, "c": 2}) == 2

    indices = range(5)
    assert all(m.tuple_iterator(tuple(indices)))
    assert all(m.list_iterator(list(indices)))
    assert all(m.sequence_iterator(indices))


def test_iterator_passthrough():
    """#181: iterator passthrough did not compile"""
    # Reach the helper through the module-level alias ``m``, as the other tests
    # in this file do, instead of re-importing the module inside the function.
    values = [3, 5, 7, 9, 11, 13, 15]
    # Passing an iterator through the module must yield the same items back.
    assert list(m.iterator_passthrough(iter(values))) == values


def test_iterator_rvp():
    """#388: Can't make iterators via make_iterator() with different r/v policies"""
    # The module-level alias ``m`` already refers to
    # pybind11_tests.sequences_and_iterators, so no function-scope import (which
    # would merely shadow it with the same module) is needed.

    # Both factories must produce the same items ...
    assert list(m.make_iterator_1()) == [1, 2, 3]
    assert list(m.make_iterator_2()) == [1, 2, 3]
    # ... while the iterator objects they return are of different types.
    assert not isinstance(m.make_iterator_1(), type(m.make_iterator_2()))