Commit 09a1ff7

docs: new sharding docs
1 parent e28e663 commit 09a1ff7

File tree

5 files changed: +159 -1 lines changed

.github/workflows/Documenter.yaml

Lines changed: 2 additions & 0 deletions

@@ -53,6 +53,8 @@ jobs:
              using Reactant
              DocMeta.setdocmeta!(Reactant, :DocTestSetup, :(using Reactant); recursive=true)
              doctest(Reactant)'
+       env:
+         XLA_FLAGS: --xla_force_host_platform_device_count=8
      - name: Build documentation
        run: julia --color=yes --project=docs docs/make.jl
        env:

docs/Project.toml

Lines changed: 2 additions & 0 deletions

@@ -1,6 +1,8 @@
  [deps]
  Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
  DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
+ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
  Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
  ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
  Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

docs/src/.vitepress/config.mts

Lines changed: 2 additions & 0 deletions

@@ -87,6 +87,7 @@ export default defineConfig({
            {text: "Profiling", link: "/tutorials/profiling"},
            {text: "Distributed", link: "/tutorials/multihost"},
            {text: "Local build", link: "/tutorials/local-build"},
+           {text: "Sharding", link: "/tutorials/sharding"},
        ],
    },
    {
@@ -154,6 +155,7 @@ export default defineConfig({
            { text: "Profiling", link: "/tutorials/profiling" },
            { text: "Distributed", link: "/tutorials/multihost" },
            { text: "Local build", link: "/tutorials/local-build" },
+           { text: "Sharding", link: "/tutorials/sharding" },
        ],
    }
],

docs/src/tutorials/sharding.md

Lines changed: 152 additions & 0 deletions

@@ -0,0 +1,152 @@
# [Automatic Sharding-based Distributed Parallelism](@id sharding)

!!! tip "Use XLA IFRT Runtime"

    While PJRT does support some minimal sharding capabilities on CUDA GPUs, sharding
    support in Reactant is primarily provided via IFRT. Before loading Reactant, set the
    "xla_runtime" preference to "IFRT". This can be done with:

    ```julia
    using Preferences, UUIDs

    Preferences.set_preferences!(
        UUID("3c362404-f566-11ee-1572-e11a4b42c853"),  # Reactant.jl's UUID
        "xla_runtime" => "IFRT",
    )
    ```
Sharding is one mechanism supported within Reactant that tries to make it easy to program for multiple devices (including [multiple nodes](@ref distributed)).

```@example sharding_tutorial
using Reactant

@assert length(Reactant.devices()) > 1 # hide
Reactant.devices()
```
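If you do not have multiple accelerators available, one option (and what the documentation CI in this commit uses via `XLA_FLAGS`) is to ask XLA to expose several virtual CPU devices. A minimal sketch; the flag must be set before Reactant initializes XLA:

```julia
# Ask XLA to present 8 host (CPU) devices; set this before `using Reactant`.
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
```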
Sharding provides Reactant users a [PGAS (partitioned global address space)](https://en.wikipedia.org/wiki/Partitioned_global_address_space) programming model. Let's understand what this means through an example.

Suppose we have a function that takes a large input array and applies `sin` to every element of the array.
```@example sharding_tutorial
function big_sin(data)
    data .= sin.(data)
    return nothing
end

N = 100
x = Reactant.to_rarray(collect(Float64, 1:N))

compiled_big_sin = @compile big_sin(x)

compiled_big_sin(x)
```
This successfully allocates the array `x` on one device and executes the computation on that same device. However, suppose we want to execute this computation on multiple devices. Perhaps this is because the size of our inputs (`N`) is too large to fit on a single device, or because the function we execute is computationally expensive and we want to leverage the computing power of multiple devices.

Unlike more explicit communication libraries like MPI, the sharding model used by Reactant aims to let you execute a program on multiple devices without significant modifications to the single-device program. In particular, you do not need to write explicit communication calls (e.g. `MPI.Send` or `MPI.Recv`). Instead, you write your program as if it were executing on a single very large node, and Reactant will automatically determine how to subdivide the data, computation, and required communication.
# TODO describe how arrays are "global" data arrays, even though the data itself is only stored on the relevant devices and computation is performed only on devices holding the required data (effectively showing under the hood how execution occurs)

# TODO simple case that demonstrates send/recv within a computation (e.g. a simple neighbor add)

# TODO make a simple Conway's game of life, or a heat equation simulation example using sharding, to show how a "typical MPI" simulation can be written using sharding.
## Simple 1-Dimensional Heat Equation

The two tabs below solve the same 1-D heat equation: first with explicit MPI halo exchange, then with Reactant sharding, where the communication is inferred automatically.

::: code-group
```julia [MPI Based Parallelism]
using MPI

MPI.Init()

function one_dim_heat_equation_time_step_mpi!(data)
    id = MPI.Comm_rank(MPI.COMM_WORLD)
    last_id = MPI.Comm_size(MPI.COMM_WORLD) - 1

    # Send the rightmost interior point to the right neighbor...
    if id != last_id
        MPI.Send(@view(data[end-1:end-1]), MPI.COMM_WORLD; dest=id + 1)
    end

    # ...and receive it into the left halo cell from the left neighbor
    if id != 0
        MPI.Recv!(@view(data[1:1]), MPI.COMM_WORLD; source=id - 1)
    end

    # Send the leftmost interior point to the left neighbor...
    if id != 0
        MPI.Send(@view(data[2:2]), MPI.COMM_WORLD; dest=id - 1)
    end

    # ...and receive it into the right halo cell from the right neighbor
    if id != last_id
        MPI.Recv!(@view(data[end:end]), MPI.COMM_WORLD; source=id + 1)
    end

    # 1-D heat equation: x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    data[2:end-1] .= 0.5 .* data[2:end-1] .+ 0.25 .* data[1:end-2] .+ 0.25 .* data[3:end]

    return nothing
end

# Total size of the grid we want to simulate
N = 100

# Local size of the grid (total size divided by the number of MPI ranks)
_local = N ÷ MPI.Comm_size(MPI.COMM_WORLD)

# We add two cells of padding (one on each side) to store the boundary values
# received from the neighboring ranks
data = rand(_local + 2)

function simulate(data, time_steps)
    for i in 1:time_steps
        one_dim_heat_equation_time_step_mpi!(data)
    end
end

simulate(data, 100)
```
```julia [Sharded Parallelism]
using Reactant

function one_dim_heat_equation_time_step_sharded!(data)
    # No explicit send/recv calls are required

    # 1-D heat equation: x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    # Reactant will automatically insert the necessary sends and recvs
    data[2:end-1] .= 0.5 .* data[2:end-1] .+ 0.25 .* data[1:end-2] .+ 0.25 .* data[3:end]

    return nothing
end

# Total size of the grid we want to simulate
N = 100

# Reactant's sharding handles distributing the data amongst devices, with each device
# getting a corresponding fraction of the data
data = Reactant.to_rarray(
    rand(N + 2);
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(Reactant.devices(), (:x,)),
        (:x,)
    )
)

function simulate(data, time_steps)
    @trace for i in 1:time_steps
        one_dim_heat_equation_time_step_sharded!(data)
    end
end

@jit simulate(data, 100)
```

:::
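To convince yourself that the communication really is inserted for you, it can be instructive to look at the IR Reactant generates for one sharded time step. A minimal sketch, assuming Reactant's `@code_hlo` macro:

```julia
# Print the generated (Stable)HLO for a single sharded time step; any
# collectives/communication inserted by sharding propagation appear here.
@code_hlo one_dim_heat_equation_time_step_sharded!(data)
```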
# TODO describe generation of a distributed array by concatenating local-worker data

# TODO more complex tutorial describing replicated sharding

## Sharding in Neural Networks

### 8-way Batch Parallelism
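A minimal, hypothetical sketch of what 8-way batch parallelism could look like with Lux (this assumes 8 visible devices, a mesh axis named `:batch`, and that a `nothing` entry in the partition spec leaves that dimension replicated; it is a sketch, not a definitive recipe):

```julia
using Lux, Random, Reactant

# Hypothetical: lay 8 devices out along a single :batch mesh axis.
mesh = Sharding.Mesh(Reactant.devices()[1:8], (:batch,))

model = Chain(Dense(32 => 64, tanh), Dense(64 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

# Parameters and state are left replicated (no sharding specified).
ps_ra = Reactant.to_rarray(ps)
st_ra = Reactant.to_rarray(st)

# The input batch is split across the 8 devices along its second (batch) dimension.
x = Reactant.to_rarray(
    rand(Float32, 32, 256);
    sharding=Sharding.NamedSharding(mesh, (nothing, :batch)),
)

y, _ = @jit model(x, ps_ra, st_ra)
```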
### 4-way Batch & 2-way Model Parallelism

## Related links

<!-- shardy? https://openxla.org/shardy -->
<!-- https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html -->
<!-- https://colab.research.google.com/drive/1UobcFjfwDI3N2EXvH3KbRS5ZxY9Riy4y#scrollTo=IiR7-0nDLPKK -->

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 1 addition & 1 deletion

@@ -115,7 +115,7 @@ end

  ### Certain Symbols are Reserved

- Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
+ Symbols like `$(SPECIAL_SYMBOLS)` are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
  example, the following will not work:

  ```julia
