Skip to content

Support dict for jit trace #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
cdaae0b
support python dict in jit.trace
tangleintel Jun 22, 2022
bca2afd
Add comments for key parts
tangleintel Jun 22, 2022
937d85d
add missing file
tangleintel Jun 22, 2022
110cf02
add UT for this feature
tangleintel Jun 26, 2022
a878835
Modify failed UT to obey my solution
tangleintel Jul 15, 2022
663102d
modify code format
tangleintel Jul 18, 2022
f13c658
Complete UT
tangleintel Sep 5, 2022
5860d78
modify the python internal API _create_method_from_trace_with_tuple()…
tangleintel Oct 4, 2022
eda9a22
add an option to trace() and trace_module() to extend the meaning of …
tangleintel Oct 5, 2022
0deda00
revert the UT
tangleintel Oct 5, 2022
d37e252
add warning msg and function doc of the adding arguments
tangleintel Oct 6, 2022
b5539e4
fix the lint error
tangleintel Oct 6, 2022
34d585d
didn't change the debug name's order, just compact when there is miss…
tangleintel Oct 7, 2022
6f37531
lint error clang-format
tangleintel Oct 7, 2022
bd8ea4c
fix build error for some compilers on other platform
tangleintel Oct 7, 2022
f8a4718
clang-format
tangleintel Oct 7, 2022
89497ff
add argument example_kwarg_inputs to unpack dict
tangleintel Oct 11, 2022
ead480a
refine the docstring
tangleintel Oct 12, 2022
54c3f66
support this feature to pure python function in jit.trace() and modif…
tangleintel Oct 13, 2022
8dba25f
modify the interface of trace_module() and add UT for it & update doc…
tangleintel Oct 13, 2022
5beeb63
fix lint error
tangleintel Oct 14, 2022
073861c
doc format
tangleintel Oct 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3027,6 +3027,46 @@ def forward(self, x):
checker.check("def forward")
checker.run(str(cm.exception))

def test_dictionary_as_example_inputs_for_jit_trace(self):
class TestModule_v1(torch.nn.Module):
def __init__(self):
super(TestModule_v1, self).__init__()

def forward(self, key2=None, key3=None, key4=None, key5=None, key1=None, key6=None):
return key1 + key2 + key3

class TestModule_v2(torch.nn.Module):
def __init__(self):
super(TestModule_v2, self).__init__()

def forward(self, x, y):
return x + y

def test_func(x, y):
return x + y
model_1 = TestModule_v1()
model_2 = TestModule_v2()
value1 = torch.ones(1)
value2 = torch.ones(1)
value3 = torch.ones(1)
example_input_dict = {'key1': value1, 'key2': value2, 'key3': value3}
example_input_dict_func = {'x': value1, 'y': value2}
traced_model_1 = torch.jit.trace(model_1, example_kwarg_inputs=example_input_dict, strict=False)
traced_model_1_m = torch.jit.trace_module(
model_1, {'forward': example_input_dict}, example_inputs_is_kwarg=True, strict=False)
traced_model_2 = torch.jit.trace(model_2, example_kwarg_inputs={'x': torch.rand([2]), 'y': torch.rand([2])})
traced_func = torch.jit.trace(test_func, example_kwarg_inputs=example_input_dict_func, strict=False)
res_1 = traced_model_1(**example_input_dict)
res_1_m = traced_model_1_m(**example_input_dict)
self.assertEqual(res_1, 3 * torch.ones(1))
self.assertEqual(res_1_m, 3 * torch.ones(1))
res_func = traced_func(**example_input_dict_func)
self.assertEqual(res_func, 2 * torch.ones(1))
with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'x'."):
res_2 = traced_model_2(**{'z': torch.rand([2]), 'y': torch.rand([2])})
with self.assertRaisesRegex(RuntimeError, r"forward\(\) is missing value for argument 'y'."):
res_2 = traced_model_2(**{'x': torch.rand([2]), 'z': torch.rand([2])})


class TestScript(JitTestCase):

Expand Down
9 changes: 9 additions & 0 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,15 @@ def _create_function_from_trace(
force_outplace: _bool,
argument_names: List[str]
) -> Tuple[Graph, Stack]: ...
def _create_function_from_trace_with_dict(
qualname: str,
func: Callable[..., Any],
input_dict: Dict[str, Any],
var_lookup_fn: Callable[[Tensor], str],
strict: _bool,
force_outplace: _bool,
argument_names: List[str]
) -> Tuple[Graph, Stack]: ...
def _jit_is_script_object(obj: Any) -> _bool: ...
def _last_executed_optimized_graph() -> Graph: ...
def parse_type_comment(comment: str) -> Decl: ...
Expand Down
11 changes: 11 additions & 0 deletions torch/csrc/jit/python/pybind_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,17 @@ inline Stack toTraceableStack(const py::tuple& inputs) {
return info.toTupleRef().elements().vec();
}

// Serialize the python dictionary into a traceable stack.
inline Stack toTraceableStack(const py::dict& inputs) {
Stack res;
for (auto it = inputs.begin(); it != inputs.end(); it++) {
if (THPVariable_Check(it->second.ptr())) {
res.push_back(toIValue(it->second, tryToInferType(it->second).type()));
}
}
return res;
}

inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
auto elems = c10::impl::GenericList(elem_type);
for (auto elem : obj) {
Expand Down
63 changes: 63 additions & 0 deletions torch/csrc/jit/python/python_tracer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,69 @@ SourceRange getPythonInterpreterSourceRange() {
return SourceRange(source, 0, stack_trace_text.size());
}

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
const py::function& func,
const py::dict& inputs_dict,
Stack trace_inputs,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
Module* self,
const std::vector<std::string>& argument_names) {
C10_LOG_API_USAGE_ONCE("torch.tracer");

auto lookup_fn_adapter =
[var_name_lookup_fn](const Variable& var) -> std::string {
pybind11::gil_scoped_acquire ag;
return py::cast<std::string>(var_name_lookup_fn(var));
};

// The argument_names parameter is parsed in python and its order
// is the same as the arguments' decalaration order in forward() method.
// These name shall be added to the graph as debug name and the order
// should align with the traceable stack we generated by the python dict.
std::vector<std::string> compact_argument_names;
Stack compact_trace_inputs;
for (std::vector<std::string>::size_type i = 0; i < argument_names.size();
i++) {
if (inputs_dict.contains(argument_names[i])) {
compact_argument_names.push_back(argument_names[i]);
}
}
for (std::vector<std::string>::size_type i = 0;
i < compact_argument_names.size();
i++) {
for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) {
if (py::cast<std::string>(it->first) == compact_argument_names[i]) {
if (THPVariable_Check(it->second.ptr())) {
compact_trace_inputs.push_back(
toIValue(it->second, tryToInferType(it->second).type()));
}
}
}
}

auto outs = tracer::trace(
std::move(compact_trace_inputs),
[&](Stack inputs) -> Stack {
// We just leave the inputs_dict as it was and pass it to forward
// method.
auto out = func(**inputs_dict);
if (out.ptr() == Py_None) {
AT_ERROR(
"The traced function didn't return any values! Side-effects are not "
"captured in traces, so it would be a no-op.");
}
return {toTypeInferredIValue(out)};
},
lookup_fn_adapter,
strict,
force_outplace,
self,
compact_argument_names);
return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs));
}

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
const py::function& func,
Stack trace_inputs,
Expand Down
10 changes: 10 additions & 0 deletions torch/csrc/jit/python/python_tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ Node* preRecordPythonTrace(
at::ArrayRef<autograd::Variable> inputs,
std::vector<THPObjectPtr> scalar_args);

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracingWithDict(
const py::function& func,
const py::dict& inputs_dict,
Stack inputs,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
Module* self = nullptr,
const std::vector<std::string>& argument_names = {});

std::pair<std::shared_ptr<Graph>, Stack> createGraphByTracing(
const py::function& func,
Stack inputs,
Expand Down
74 changes: 74 additions & 0 deletions torch/csrc/jit/python/script_init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,6 +1218,43 @@ void initJitScriptBindings(PyObject* module) {
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>())
.def(
"_create_method_from_trace_with_dict",
[](Module& self,
const std::string& name,
const py::function& func,
const py::dict& input_dict,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
const std::vector<std::string>& argument_names) {
// prereq: Module's buffers and parameters are unique
// this was ensured in python before calling this function
auto typed_inputs = toTraceableStack(input_dict);

std::shared_ptr<Graph> graph =
std::get<0>(tracer::createGraphByTracingWithDict(
func,
input_dict,
typed_inputs,
var_name_lookup_fn,
strict,
force_outplace,
&self,
argument_names));
const auto method_name = QualifiedName(*self.type()->name(), name);
auto fn = self._ivalue()->compilation_unit()->create_function(
method_name, graph);
self.type()->addMethod(fn);
didFinishEmitModule(self);
},
py::arg("name"),
py::arg("func"),
py::arg("input_dict"),
py::arg("var_name_lookup_fn"),
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>())
.def(
"_get_forward_hooks",
[](const Module& m) {
Expand Down Expand Up @@ -1668,6 +1705,43 @@ void initJitScriptBindings(PyObject* module) {
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>());

m.def(
"_create_function_from_trace_with_dict",
[](const std::string& qualname,
const py::function& func,
const py::dict& input_dict,
const py::function& var_name_lookup_fn,
bool strict,
bool force_outplace,
const std::vector<std::string>& argument_names) {
auto typed_inputs = toTraceableStack(input_dict);
std::shared_ptr<Graph> graph =
std::get<0>(tracer::createGraphByTracingWithDict(
func,
input_dict,
typed_inputs,
var_name_lookup_fn,
strict,
force_outplace,
/*self=*/nullptr,
argument_names));

auto cu = get_python_cu();
auto name = c10::QualifiedName(qualname);
auto result = cu->create_function(
std::move(name), std::move(graph), /*shouldMangle=*/true);
StrongFunctionPtr ret(std::move(cu), result);
didFinishEmitFunction(ret);
return ret;
},
py::arg("name"),
py::arg("func"),
py::arg("input_dict"),
py::arg("var_name_lookup_fn"),
py::arg("strict"),
py::arg("force_outplace"),
py::arg("argument_names") = std::vector<std::string>());

m.def(
"_jit_script_class_compile",
[](const std::string& qualifiedName,
Expand Down
Loading