From 94824d68a037d99253b92a5b260bb04907c42355 Mon Sep 17 00:00:00 2001 From: "woody.chow" Date: Thu, 6 Sep 2018 10:54:54 +0900 Subject: [PATCH] Add get method to py::dict. It works just like Python except that the default argument is not optional --- include/pybind11/pytypes.h | 8 ++++++++ tests/test_pytypes.cpp | 4 ++++ tests/test_pytypes.py | 1 + 3 files changed, 13 insertions(+) diff --git a/include/pybind11/pytypes.h b/include/pybind11/pytypes.h index 976abf86e3..737795ac51 100644 --- a/include/pybind11/pytypes.h +++ b/include/pybind11/pytypes.h @@ -1173,6 +1173,14 @@ class dict : public object { void clear() const { PyDict_Clear(ptr()); } bool contains(handle key) const { return PyDict_Contains(ptr(), key.ptr()) == 1; } bool contains(const char *key) const { return PyDict_Contains(ptr(), pybind11::str(key).ptr()) == 1; } + template object get(handle key, T &&defaultv) const { + PyObject* ret = PyDict_GetItem(ptr(), key.ptr()); + return reinterpret_borrow(ret ? handle(ret) : detail::object_or_cast(std::forward(defaultv))); + } + template object get(const char *key, T &&defaultv) const { + PyObject* ret = PyDict_GetItemString(ptr(), key); + return reinterpret_borrow(ret ? handle(ret) : detail::object_or_cast(std::forward(defaultv))); + } private: /// Call the `dict` Python type -- always returns a new reference diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index a962f0cccf..3c6b2cdf13 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -49,6 +49,10 @@ TEST_SUBMODULE(pytypes, m) { auto d2 = py::dict("z"_a=3, **d1); return d2; }); + m.def("dict_get_test", []() { + py::dict d("key"_a=1); + return py::make_tuple(d.get("key", 2), d.get("key2", 3)); + }); // test_str m.def("str_from_string", []() { return py::str(std::string("baz")); }); diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 992e7fc8e1..35494fb213 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -57,6 +57,7 @@ def test_dict(capture, doc): assert m.dict_keyword_constructor() == {"x": 1, "y": 2, "z": 3} + assert m.dict_get_test() == (1, 3) def test_str(doc): assert m.str_from_string().encode().decode() == "baz"