aboutsummaryrefslogtreecommitdiffstats
path: root/3rdparty/pybind11/tests/test_pickling.cpp
blob: e154bc483c641e6ddc0b28cbdf520780b3c55f37 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
/*
    tests/test_pickling.cpp -- pickle support

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
    Copyright (c) 2021 The Pybind Development Team.

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#include "pybind11_tests.h"

#include <memory>
#include <stdexcept>
#include <utility>

namespace exercise_trampoline {

// Polymorphic base class exposed to Python; its only C++ state is `num`.
struct SimpleBase {
    // The special members are spelled out explicitly for compatibility with old clang versions.
    SimpleBase() = default;
    SimpleBase(const SimpleBase &) = default;
    virtual ~SimpleBase() = default;

    int num = 0;
};

// Registered below as the trampoline (alias) type of SimpleBase; it adds no members.
struct SimpleBaseTrampoline : SimpleBase {};

// C++-only subclass (not bound with py::class_ here). It is handed to Python as a
// SimpleBase and later checked with dynamic_cast.
struct SimpleCppDerived : SimpleBase {};

// Registers SimpleBase (with py::pickle support) and two helper functions on `m`.
void wrap(py::module m) {
    // Pickle state = (num, instance __dict__). The dict is empty when the object has no __dict__.
    auto get_state = [](const py::object &self) {
        py::dict instance_dict;
        if (py::hasattr(self, "__dict__")) {
            instance_dict = self.attr("__dict__");
        }
        return py::make_tuple(self.attr("num"), instance_dict);
    };

    // Rebuilds the C++ part as a trampoline instance. The Python-side dict is
    // returned next to it so pybind11 can restore it as well.
    auto set_state = [](const py::tuple &state) {
        if (state.size() != 2) {
            throw std::runtime_error("Invalid state!");
        }
        std::unique_ptr<SimpleBase> restored(new SimpleBaseTrampoline);
        restored->num = state[0].cast<int>();
        auto instance_dict = state[1].cast<py::dict>();
        return std::make_pair(std::move(restored), instance_dict);
    };

    py::class_<SimpleBase, SimpleBaseTrampoline>(m, "SimpleBase")
        .def(py::init<>())
        .def_readwrite("num", &SimpleBase::num)
        .def(py::pickle(std::move(get_state), std::move(set_state)));

    // Returns a SimpleCppDerived owned through a base-class pointer.
    m.def("make_SimpleCppDerivedAsBase", []() -> std::unique_ptr<SimpleBase> {
        std::unique_ptr<SimpleBase> created(new SimpleCppDerived);
        return created;
    });

    // True iff the object behind `base_ptr` is (still) a SimpleCppDerived.
    m.def("check_dynamic_cast_SimpleCppDerived", [](const SimpleBase *base_ptr) -> bool {
        const auto *as_derived = dynamic_cast<const SimpleCppDerived *>(base_ptr);
        return as_derived != nullptr;
    });
}

} // namespace exercise_trampoline

TEST_SUBMODULE(pickling, m) {
    // Module-level function returning a fixed integer.
    m.def("simple_callable", []() { return 20220426; });

    // test_roundtrip
    // Plain C++ type: a string fixed at construction plus two ints set afterwards.
    class Pickleable {
    public:
        explicit Pickleable(const std::string &value) : m_value(value) {}
        const std::string &value() const { return m_value; }

        void setExtra1(int extra1) { m_extra1 = extra1; }
        void setExtra2(int extra2) { m_extra2 = extra2; }
        int extra1() const { return m_extra1; }
        int extra2() const { return m_extra2; }

    private:
        std::string m_value;
        int m_extra1 = 0;
        int m_extra2 = 0;
    };

    // Same type, but bound with the py::pickle() factory instead of hand-written
    // __getstate__/__setstate__ methods.
    class PickleableNew : public Pickleable {
    public:
        using Pickleable::Pickleable;
    };

    // Old-style pickling: explicit __getstate__ / __setstate__ methods.
    py::class_<Pickleable> pyPickleable(m, "Pickleable");
    pyPickleable.def(py::init<std::string>())
        .def("value", &Pickleable::value)
        .def("extra1", &Pickleable::extra1)
        .def("extra2", &Pickleable::extra2)
        .def("setExtra1", &Pickleable::setExtra1)
        .def("setExtra2", &Pickleable::setExtra2)
        // For details on the methods below, refer to
        // http://docs.python.org/3/library/pickle.html#pickling-class-instances
        .def("__getstate__", [](const Pickleable &p) {
            /* Return a tuple that fully encodes the state of the object */
            return py::make_tuple(p.value(), p.extra1(), p.extra2());
        });
    // NOTE(review): going by its name, ignoreOldStyleInitWarnings suppresses warnings
    // about the old-style __setstate__ registered inside it. Confirm in pybind11_tests.h.
    ignoreOldStyleInitWarnings([&pyPickleable]() {
        pyPickleable.def("__setstate__", [](Pickleable &p, const py::tuple &t) {
            if (t.size() != 3) {
                throw std::runtime_error("Invalid state!");
            }
            /* Invoke the constructor (need to use in-place version) */
            new (&p) Pickleable(t[0].cast<std::string>());

            /* Assign any additional state */
            p.setExtra1(t[1].cast<int>());
            p.setExtra2(t[2].cast<int>());
        });
    });

    // New-style pickling: py::pickle(get_state, set_state). Here set_state returns a
    // fresh instance by value instead of constructing in place.
    py::class_<PickleableNew, Pickleable>(m, "PickleableNew")
        .def(py::init<std::string>())
        .def(py::pickle(
            [](const PickleableNew &p) {
                return py::make_tuple(p.value(), p.extra1(), p.extra2());
            },
            [](const py::tuple &t) {
                if (t.size() != 3) {
                    throw std::runtime_error("Invalid state!");
                }
                auto p = PickleableNew(t[0].cast<std::string>());

                p.setExtra1(t[1].cast<int>());
                p.setExtra2(t[2].cast<int>());
                return p;
            }));

#if !defined(PYPY_VERSION)
    // test_roundtrip_with_dict
    // Type with dynamic attributes whose pickled state also carries the instance __dict__.
    class PickleableWithDict {
    public:
        explicit PickleableWithDict(const std::string &value) : value(value) {}

        std::string value;
        // Default-initialized: `extra` is readable from Python (def_readwrite) and is read
        // by __getstate__. Without an initializer, touching it before assignment would
        // read an indeterminate value.
        int extra = 0;
    };

    class PickleableWithDictNew : public PickleableWithDict {
    public:
        using PickleableWithDict::PickleableWithDict;
    };

    // Old-style pickling that restores both the C++ fields and the Python __dict__.
    py::class_<PickleableWithDict> pyPickleableWithDict(
        m, "PickleableWithDict", py::dynamic_attr());
    pyPickleableWithDict.def(py::init<std::string>())
        .def_readwrite("value", &PickleableWithDict::value)
        .def_readwrite("extra", &PickleableWithDict::extra)
        .def("__getstate__", [](const py::object &self) {
            /* Also include __dict__ in state */
            return py::make_tuple(self.attr("value"), self.attr("extra"), self.attr("__dict__"));
        });
    ignoreOldStyleInitWarnings([&pyPickleableWithDict]() {
        pyPickleableWithDict.def("__setstate__", [](const py::object &self, const py::tuple &t) {
            if (t.size() != 3) {
                throw std::runtime_error("Invalid state!");
            }
            /* Cast and construct */
            auto &p = self.cast<PickleableWithDict &>();
            new (&p) PickleableWithDict(t[0].cast<std::string>());

            /* Assign C++ state */
            p.extra = t[1].cast<int>();

            /* Assign Python state */
            self.attr("__dict__") = t[2];
        });
    });

    // New-style equivalent: set_state returns a (C++ object, Python dict) pair.
    py::class_<PickleableWithDictNew, PickleableWithDict>(m, "PickleableWithDictNew")
        .def(py::init<std::string>())
        .def(py::pickle(
            [](const py::object &self) {
                return py::make_tuple(
                    self.attr("value"), self.attr("extra"), self.attr("__dict__"));
            },
            [](const py::tuple &t) {
                if (t.size() != 3) {
                    throw std::runtime_error("Invalid state!");
                }

                auto cpp_state = PickleableWithDictNew(t[0].cast<std::string>());
                cpp_state.extra = t[1].cast<int>();

                auto py_state = t[2].cast<py::dict>();
                // Move the local C++ state into the pair instead of copying it
                // (matches the trampoline test's set_state).
                return std::make_pair(std::move(cpp_state), py_state);
            }));
#endif

    exercise_trampoline::wrap(m);
}