about summary refs log tree commit diff stats
path: root/3rdparty/pybind11/tests/test_numpy_vectorize.cpp
blob: 274b7558a9e80cab3cc390b03dd303eb37533f5d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
/*
    tests/test_numpy_vectorize.cpp -- auto-vectorize functions over NumPy array
    arguments

    Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>

    All rights reserved. Use of this source code is governed by a
    BSD-style license that can be found in the LICENSE file.
*/

#include "pybind11_tests.h"
#include <pybind11/numpy.h>

// Scalar kernel shared by the vectorization tests.
//
// Echoes its arguments through Python's print() (y and z formatted with zero
// decimal places) and returns their product. x is converted to float and
// multiplied by y in float precision; multiplying by the double z then
// promotes the result to double.
double my_func(int x, float y, double z) {
    const auto message = "my_func(x:int={}, y:float={:.0f}, z:float={:.0f})"_s.format(x, y, z);
    py::print(message);
    const double product = static_cast<float>(x) * y * z;
    return product;
}

/// Registers the `numpy_vectorize` test submodule.
///
/// Every binding below relies on NumPy (py::vectorize / py::array_t), so the
/// submodule is left empty when NumPy cannot be imported.
TEST_SUBMODULE(numpy_vectorize, m) {
    // pybind11 reports a failed Python import as py::error_already_set. Catch
    // only that, so unrelated C++ failures (e.g. std::bad_alloc) propagate
    // instead of being silently swallowed by a catch-all.
    try {
        py::module_::import("numpy");
    } catch (const py::error_already_set &) {
        return;
    }

    // test_vectorize, test_docs, test_array_collapse
    // Vectorize all arguments of a function (though non-vector arguments are also allowed)
    m.def("vectorized_func", py::vectorize(my_func));

    // Vectorize a lambda function with a capture object (e.g. to exclude some arguments from the vectorization)
    m.def("vectorized_func2",
          [](py::array_t<int> x, py::array_t<float> y, float z) {
              // Only x and y go through py::vectorize; z is captured as a plain scalar.
              return py::vectorize([z](int x, float y) { return my_func(x, y, z); })(x, y);
          });

    // Vectorize a complex-valued function
    m.def("vectorized_func3",
          py::vectorize([](std::complex<double> c) { return c * std::complex<double>(2.f); }));

    // test_type_selection
    // NumPy function which only accepts specific data types; overload resolution
    // picks the branch whose array_t element type matches the input.
    m.def("selective_func", [](py::array_t<int, py::array::c_style>) { return "Int branch taken."; });
    m.def("selective_func", [](py::array_t<float, py::array::c_style>) { return "Float branch taken."; });
    m.def("selective_func", [](py::array_t<std::complex<float>, py::array::c_style>) { return "Complex float branch taken."; });

    // test_passthrough_arguments
    // Passthrough test: references and non-pod types should be automatically passed through (in the
    // function definition below, only `b`, `d`, and `g` are vectorized):
    struct NonPODClass {
        // explicit: instances are only built through py::init<int>(), never by implicit int conversion.
        explicit NonPODClass(int v) : value{v} {}
        int value;
    };
    py::class_<NonPODClass>(m, "NonPODClass")
        .def(py::init<int>())
        .def_readwrite("value", &NonPODClass::value);
    m.def("vec_passthrough",
          py::vectorize([](double *a,
                           double b,
                           py::array_t<double> c,
                           const int &d,
                           int &e,
                           NonPODClass f,
                           const double g) { return *a + b + c.at(0) + d + e + f.value + g; }));

    // test_method_vectorization
    struct VectorizeTestClass {
        explicit VectorizeTestClass(int v) : value{v} {}
        float method(int x, float y) { return y + static_cast<float>(x + value); }
        int value = 0;
    };
    py::class_<VectorizeTestClass> vtc(m, "VectorizeTestClass");
    vtc.def(py::init<int>())
        .def_readwrite("value", &VectorizeTestClass::value);

    // Automatic vectorizing of methods
    vtc.def("method", py::vectorize(&VectorizeTestClass::method));

    // test_trivial_broadcasting
    // Internal optimization test for whether the input is trivially broadcastable:
    py::enum_<py::detail::broadcast_trivial>(m, "trivial")
        .value("f_trivial", py::detail::broadcast_trivial::f_trivial)
        .value("c_trivial", py::detail::broadcast_trivial::c_trivial)
        .value("non_trivial", py::detail::broadcast_trivial::non_trivial);
    m.def("vectorized_is_trivial",
          [](py::array_t<int, py::array::forcecast> arg1,
             py::array_t<float, py::array::forcecast> arg2,
             py::array_t<double, py::array::forcecast> arg3) {
              // ndim and shape are scratch outputs handed to broadcast(); only its
              // returned classification is exposed to Python. ndim is initialized so
              // it never holds an indeterminate value.
              py::ssize_t ndim = 0;
              std::vector<py::ssize_t> shape;
              std::array<py::buffer_info, 3> buffers{{arg1.request(), arg2.request(), arg3.request()}};
              return py::detail::broadcast(buffers, ndim, shape);
          });

    // The NonPODClass& argument is passed through unvectorized, so the lambda
    // mutates the caller's object in place while `a` is vectorized.
    m.def("add_to", py::vectorize([](NonPODClass &x, int a) { x.value += a; }));
}
pan>.c_str(),i); std::string res; do { PyObject* py_result; PyObject* dum; py_result = Py_CompileString(command.c_str(), "<stdin>", Py_single_input); dum = PyEval_EvalCode (py_result, glb, loc); Py_XDECREF (dum); Py_XDECREF (py_result); res = GetResultString( m_threadState ); GetResultString( m_threadState ) = ""; ++i; command = string_format("sys.completer.complete('%s', %d)\n", hint.c_str(),i); if (res.size()) { // throw away the newline res = res.substr(1, res.size() - 3); m_suggestions.push_back(res); } } while (res.size()); PyEval_ReleaseThread( m_threadState ); return m_suggestions; } void Interpreter::Initialize( ) { PyImport_AppendInittab("redirector", Interpreter::PyInit_redirector); Py_Initialize( ); PyEval_InitThreads( ); MainThreadState = PyEval_SaveThread( ); } void Interpreter::Finalize( ) { PyEval_RestoreThread( MainThreadState ); Py_Finalize( ); } std::string& Interpreter::GetResultString( PyThreadState* threadState ) { static std::map< PyThreadState*, std::string > ResultStrings; if ( !ResultStrings.count( threadState ) ) { ResultStrings[ threadState ] = ""; } return ResultStrings[ threadState ]; } PyObject* Interpreter::RedirectorInit(PyObject *, PyObject *) { Py_INCREF(Py_None); return Py_None; } PyObject* Interpreter::RedirectorWrite(PyObject *, PyObject *args) { char* output; PyObject *selfi; if (!PyArg_ParseTuple(args,"Os",&selfi,&output)) { return NULL; } std::string outputString( output ); PyThreadState* currentThread = PyThreadState_Get( ); std::string& resultString = GetResultString( currentThread ); resultString = resultString + outputString; Py_INCREF(Py_None); return Py_None; } PyMethodDef Interpreter::RedirectorMethods[] = { {"__init__", Interpreter::RedirectorInit, METH_VARARGS, "initialize the stdout/err redirector"}, {"write", Interpreter::RedirectorWrite, METH_VARARGS, "implement the write method to redirect stdout/err"}, {NULL,NULL,0,NULL}, }; PyObject *createClassObject(const char *name, PyMethodDef methods[]) { PyObject 
*pClassName = PyUnicode_FromString(name); PyObject *pClassBases = PyTuple_New(0); // An empty tuple for bases is equivalent to `(object,)` PyObject *pClassDic = PyDict_New(); PyMethodDef *def; // add methods to class for (def = methods; def->ml_name != NULL; def++) { PyObject *func = PyCFunction_New(def, NULL); PyObject *method = PyInstanceMethod_New(func); PyDict_SetItemString(pClassDic, def->ml_name, method); Py_DECREF(func); Py_DECREF(method); } // pClass = type(pClassName, pClassBases, pClassDic) PyObject *pClass = PyObject_CallFunctionObjArgs((PyObject *)&PyType_Type, pClassName, pClassBases, pClassDic, NULL); Py_DECREF(pClassName); Py_DECREF(pClassBases); Py_DECREF(pClassDic); return pClass; } PyMODINIT_FUNC Interpreter::PyInit_redirector(void) { static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "redirector", 0, -1, 0 }; PyObject *m = PyModule_Create(&moduledef); if (m) { PyObject *fooClass = createClassObject("redirector", RedirectorMethods); PyModule_AddObject(m, "redirector", fooClass); Py_DECREF(fooClass); } return m; }