Skip to content

Commit 11059f6

Browse files
committed
program-data-separation demo
1 parent 3649e1f commit 11059f6

File tree

8 files changed

+288
-1
lines changed

8 files changed

+288
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ build/
88
*.bin
99
*.model
1010
*.pte
11+
*.ptd
1112

1213
# Xcode
1314
xcuserdata/

.gitmodules

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,8 @@
22
path = mv2/cpp/executorch
33
url = https://github.com/pytorch/executorch.git
44
branch = release/0.6
5+
6+
[submodule "program-data-separation/cpp/executorch"]
7+
path = program-data-separation/cpp/executorch
8+
url = https://github.com/pytorch/executorch.git
9+
branch = main

mv2/cpp/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
1212
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
1313
option(EXECUTORCH_BUILD_XNNPACK "" ON)
1414

15-
# Add ExecutorTorch subdirectory
15+
# Add ExecuTorch subdirectory
1616
add_subdirectory("executorch")
1717

1818
set(DEMO_SOURCES main.cpp)

program-data-separation/README.md

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Program Data Separation Examples
2+
3+
This directory contains two examples of how to use the Program Data Separation APIs in ExecuTorch.
4+
5+
## Virtual environment setup
6+
Create and activate a Python virtual environment:
7+
```bash
8+
python3 -m venv .venv && source .venv/bin/activate && pip install --upgrade pip
9+
```
10+
Or alternatively, [install conda on your machine](https://conda.io/projects/conda/en/latest/user-guide/install/index.html)
11+
```bash
12+
conda create -yn executorch-examples-mv2 python=3.10.0 && conda activate executorch-examples-mv2
13+
```
14+
15+
Install dependencies:
16+
```
17+
pip install -r requirements.txt
18+
```
19+
20+
## Export a model with program-data separation
21+
To export a non-delegated linear model:
22+
```python
23+
python export.py --outdir .
24+
```
25+
Expect the files 'linear.pte' and 'linear.ptd'.
26+
27+
To export a linear model delegated to XNNPACK:
28+
```python
29+
python export.py --outdir . --xnnpack
30+
```
31+
Expect the files 'linear_xnnpack.pte' and 'linear_xnnpack.ptd'.
32+
33+
Note:
34+
- PTE: contains the program execution plan.
35+
- PTD: contains the constant tensors used by the PTE.
36+
37+
For more information on the PTD data format, please see the [flat_tensor](https://github.com/pytorch/executorch/blob/main/extension/flat_tensor/README.md) directory.
38+
39+
## Runtime (cpp)
40+
The cpp/ directory contains the executorch submodule along with a main.cpp file that demonstrates how to load a PTE and PTD file and execute the program.
41+
42+
First, export your PTE, PTD files using the AoT instructions above.
43+
44+
**Build instructions**
45+
46+
Change to the cpp directory.
47+
```
48+
cd cpp
49+
```
50+
51+
Create build directory if it doesn't exist
52+
```
53+
mkdir -p build
54+
cd build
55+
```
56+
57+
Configure CMake
58+
```
59+
cmake -DCMAKE_BUILD_TYPE=Release ..
60+
```
61+
62+
Build the project
63+
```
64+
cmake --build . -j$(nproc)
65+
echo "Build complete! Executable located at: ./bin/executorch_program_data_separation"
66+
```
67+
68+
Run the executable
69+
```
70+
./bin/executorch_program_data_separation --model-path ../../linear.pte --data-path ../../linear.ptd
71+
72+
./bin/executorch_program_data_separation --model-path ../../linear_xnnpack.pte --data-path ../../linear_xnnpack.ptd
73+
```
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
3+
project(executorch_mv2_demo CXX)
4+
5+
set(CMAKE_CXX_STANDARD 17)
6+
set(CMAKE_CXX_STANDARD_REQUIRED ON)
7+
8+
# Set options for executorch build.
9+
option(EXECUTORCH_ENABLE_LOGGING "" ON)
10+
option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "" ON)
11+
option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR "" ON)
12+
option(EXECUTORCH_BUILD_EXTENSION_MODULE "" ON)
13+
option(EXECUTORCH_BUILD_EXTENSION_TENSOR "" ON)
14+
option(EXECUTORCH_BUILD_KERNELS_OPTIMIZED "" ON)
15+
option(EXECUTORCH_BUILD_XNNPACK "" ON)
16+
17+
# Add ExecuTorch subdirectory
18+
add_subdirectory("executorch")
19+
20+
set(DEMO_SOURCES main.cpp)
21+
22+
# Create executable
23+
add_executable(executorch_program_data_separation ${DEMO_SOURCES})
24+
25+
# Include directories
26+
target_include_directories(executorch_program_data_separation PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
27+
28+
# Link libraries
29+
target_link_libraries(
30+
executorch_program_data_separation
31+
PRIVATE executorch
32+
extension_module_static
33+
extension_flat_tensor
34+
extension_tensor
35+
xnnpack_backend
36+
portable_ops_lib
37+
portable_kernels
38+
gflags
39+
)
40+
41+
# Set output directory
42+
set_target_properties(executorch_program_data_separation
43+
PROPERTIES
44+
RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/bin"
45+
)
Submodule executorch added at 4456407

program-data-separation/cpp/main.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/module/module.h>
10+
#include <executorch/extension/tensor/tensor.h>
11+
#include <iostream>
12+
13+
#include <gflags/gflags.h>
14+
15+
DEFINE_string(model_path, "linear.pte",
16+
"Model serialized in flatbuffer format.");
17+
DEFINE_string(data_path, "linear.ptd", "Data serialized in flatbuffer format.");
18+
19+
using namespace ::executorch::extension;
20+
21+
int main(int argc, char *argv[]) {
22+
23+
std::cout << "Running program-data separation example" << std::endl;
24+
gflags::ParseCommandLineFlags(&argc, &argv, true);
25+
26+
const char *model_path = FLAGS_model_path.c_str();
27+
const char *data_path = FLAGS_data_path.c_str();
28+
29+
// Load the model.
30+
Module module(model_path, data_path);
31+
32+
float input[3];
33+
auto tensor = from_blob(input, {3});
34+
35+
// Perform an inference.
36+
const auto result = module.forward(tensor);
37+
38+
if (result.ok()) {
39+
const auto output = result->at(0).toTensor().const_data_ptr<float>();
40+
for (int i = 0; i < 3; i++) {
41+
std::cout << output[i] << std::endl;
42+
}
43+
std::cout << "Success" << std::endl;
44+
}
45+
46+
return 0;
47+
}

program-data-separation/export.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import argparse
10+
import os
11+
12+
from functools import partial
13+
from typing import Dict, final, Optional, Sequence, Type
14+
15+
import executorch.exir as exir
16+
import torch
17+
18+
from executorch.exir import (
19+
EdgeCompileConfig,
20+
ExecutorchBackendConfig,
21+
to_edge,
22+
to_edge_transform_and_lower,
23+
)
24+
from executorch.exir.passes.external_constants_pass import (
25+
delegate_external_constants_pass,
26+
)
27+
from executorch.exir.program import ExecutorchProgramManager
28+
from torch.export import export
29+
30+
31+
class ModuleLinear(torch.nn.Module):
32+
def __init__(self):
33+
super().__init__()
34+
self.linear = torch.nn.Linear(3, 3)
35+
36+
def forward(self, x: torch.Tensor):
37+
return self.linear(x)
38+
39+
def get_random_inputs(self):
40+
return (torch.randn(3),)
41+
42+
43+
def main() -> None:
44+
45+
parser = argparse.ArgumentParser(
46+
prog="export_program",
47+
description="Exports nn.Module models to ExecuTorch .pte and .ptd files",
48+
)
49+
parser.add_argument(
50+
"--outdir",
51+
type=str,
52+
required=True,
53+
help="Path to the directory to write <classname>.pte files and .ptd files to",
54+
)
55+
parser.add_argument(
56+
"--xnnpack",
57+
action="store_true",
58+
help="Export the model lowered to XNNPACK",
59+
)
60+
args = parser.parse_args()
61+
62+
if args.xnnpack:
63+
print("Exporting to ExecuTorch with XNNPACK")
64+
else:
65+
print("Exporting to ExecuTorch")
66+
67+
# Construct eager model.
68+
model = ModuleLinear()
69+
# Export model.
70+
exported_program = torch.export.export(model, model.get_random_inputs())
71+
model_name = "linear_xnnpack" if args.xnnpack else "linear"
72+
73+
# Lower to XNNPACK.
74+
if args.xnnpack:
75+
print("Lowering to XNNPACK...")
76+
from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
77+
XnnpackPartitioner,
78+
)
79+
80+
partial_function = partial(
81+
delegate_external_constants_pass,
82+
ep=exported_program,
83+
gen_tag_fn=lambda x: model_name,
84+
)
85+
executorch_program = to_edge_transform_and_lower(
86+
exported_program,
87+
transform_passes=[partial_function],
88+
compile_config=EdgeCompileConfig(_check_ir_validity=False),
89+
partitioner=[XnnpackPartitioner()],
90+
).to_executorch(config=ExecutorchBackendConfig())
91+
92+
# No backends.
93+
else:
94+
print("Lowering to ExecuTorch...")
95+
edge_program = to_edge(exported_program)
96+
executorch_program = edge_program.to_executorch(
97+
ExecutorchBackendConfig(external_constants=True)
98+
)
99+
100+
print("Saving PTE and PTD files")
101+
os.makedirs(args.outdir, exist_ok=True)
102+
pte_file = os.path.join(args.outdir, f"{model_name}.pte")
103+
with open(pte_file, "wb") as fp:
104+
executorch_program.write_to_file(fp)
105+
if executorch_program._tensor_data.get("_default_external_constant"):
106+
executorch_program._tensor_data[model_name] = (
107+
executorch_program._tensor_data.pop("_default_external_constant")
108+
)
109+
executorch_program.write_tensor_data_to_file(args.outdir)
110+
111+
print(f"Successfully exported {model_name}.pte and {model_name}.ptd")
112+
113+
114+
if __name__ == "__main__":
115+
main()

0 commit comments

Comments
 (0)