# [Automatic Sharding-based Distributed Parallelism](@id sharding)

!!! tip "Use XLA IFRT Runtime"

    While PJRT does support some minimal sharding capabilities on CUDA GPUs, sharding
    support in Reactant is primarily provided via IFRT. Before loading Reactant, set the
    "xla_runtime" preference to "IFRT". This can be done with:

    ```julia
    using Preferences, UUIDs

    Preferences.set_preferences!(
        UUID("3c362404-f566-11ee-1572-e11a4b42c853"),
        "xla_runtime" => "IFRT"
    )
    ```

Sharding is one mechanism supported within Reactant that tries to make it easy to program for multiple devices (including [multiple nodes](@ref distributed)).

```@example sharding_tutorial
using Reactant

@assert length(Reactant.devices()) > 1 # hide
Reactant.devices()
```

Sharding provides Reactant users a [PGAS (partitioned global address space)](https://en.wikipedia.org/wiki/Partitioned_global_address_space) programming model. Let's understand what this means through an example.

Suppose we have a function that takes a large input array and computes the sine of every element of the array.

```@example sharding_tutorial
function big_sin(data)
    data .= sin.(data)
    return nothing
end

N = 100
x = Reactant.to_rarray(collect(Float64, 1:N))

compiled_big_sin = @compile big_sin(x)

compiled_big_sin(x)
```

This successfully allocates the array `x` on one device and executes the computation on that same device. However, suppose we want to execute this computation on multiple devices, perhaps because the size of our input (`N`) is too large to fit on a single device, or because the function we execute is computationally expensive and we want to leverage the computing power of multiple devices.

Unlike more explicit communication libraries like MPI, the sharding model used by Reactant aims to let you execute a program on multiple devices without significant modifications to the single-device program. In particular, you do not need to write explicit communication calls (e.g. `MPI.Send` or `MPI.Recv`). Instead, you write your program as if it executes on one very large device, and Reactant automatically determines how to subdivide the data, the computation, and the required communication.
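
Before the more involved example below, here is a minimal sketch of what this looks like for the `big_sin` function above. It reuses the `Sharding.Mesh` and `Sharding.NamedSharding` constructs that appear later on this page; the mesh layout and the axis name `:x` are purely illustrative choices, and the sketch assumes more than one device is available.

```julia
# A minimal sketch: shard the input of `big_sin` (defined above) across all available
# devices. The mesh construction and axis name (:x) are illustrative choices.
using Reactant

N = 100
mesh = Sharding.Mesh(Reactant.devices(), (:x,))    # 1-D mesh over all devices
x_sharded = Reactant.to_rarray(
    collect(Float64, 1:N);
    sharding=Sharding.NamedSharding(mesh, (:x,)),  # split the array's only axis along :x
)

# `big_sin` itself is unchanged; Reactant partitions the computation (and inserts any
# needed communication) to match the sharding of its input.
compiled_big_sin_sharded = @compile big_sin(x_sharded)
compiled_big_sin_sharded(x_sharded)
```

The function body is untouched; only the construction of the input changes.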

<!-- TODO: describe how arrays are "global" data arrays, even though the data itself is only stored on the relevant devices and computation is performed only on the devices that hold the required data (effectively showing, under the hood, how execution occurs). -->

<!-- TODO: a simple case that demonstrates send/recv within a computation (e.g. a simple neighbor add). -->

<!-- TODO: make a simple Conway's game of life, or heat equation, sharding simulation example to show how a "typical MPI" simulation can be written using sharding. -->

## Simple 1-Dimensional Heat Equation

::: code-group

```julia [MPI Based Parallelism]
using MPI

MPI.Init()

function one_dim_heat_equation_time_step_mpi!(data)
    id = MPI.Comm_rank(MPI.COMM_WORLD)
    last_id = MPI.Comm_size(MPI.COMM_WORLD) - 1

    # Send the last interior element to the right neighbor
    if id != last_id
        MPI.Send(@view(data[end-1:end-1]), MPI.COMM_WORLD; dest=id + 1)
    end

    # Receive into the left boundary from the left neighbor
    if id != 0
        MPI.Recv!(@view(data[1:1]), MPI.COMM_WORLD; source=id - 1)
    end

    # (A complete halo exchange would also send data to the left and receive into
    # data[end] from the right neighbor; omitted here for brevity.)

    # 1-D heat equation: x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    data[2:end-1] .= 0.5 .* data[2:end-1] .+ 0.25 .* data[1:end-2] .+ 0.25 .* data[3:end]

    return nothing
end


# Total size of the grid we want to simulate
N = 100

# Local size of the grid (total size divided by the number of MPI ranks)
n_local = N ÷ MPI.Comm_size(MPI.COMM_WORLD)

# We add two extra elements for the left and right padding, necessary for storing the
# boundary values received from neighboring ranks
data = rand(n_local + 2)

function simulate(data, time_steps)
    for i in 1:time_steps
        one_dim_heat_equation_time_step_mpi!(data)
    end
end

simulate(data, 100)
```

```julia [Sharded Parallelism]
using Reactant

function one_dim_heat_equation_time_step_sharded!(data)
    # No explicit send/recv's are required!

    # 1-D heat equation: x[i, t] = 0.5 * x[i, t-1] + 0.25 * x[i-1, t-1] + 0.25 * x[i+1, t-1]
    # Reactant will automatically insert the required communication between devices.
    data[2:end-1] .= 0.5 .* data[2:end-1] .+ 0.25 .* data[1:end-2] .+ 0.25 .* data[3:end]

    return nothing
end


# Total size of the grid we want to simulate
N = 100

# Reactant's sharding handles distributing the data amongst devices, with each device
# getting a corresponding fraction of the data
data = Reactant.to_rarray(
    rand(N + 2);
    sharding=Sharding.NamedSharding(
        Sharding.Mesh(Reactant.devices(), (:x,)),
        (:x,)
    )
)

function simulate(data, time_steps)
    @trace for i in 1:time_steps
        one_dim_heat_equation_time_step_sharded!(data)
    end
end

@jit simulate(data, 100)
```

:::
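
In both versions the stencil and the time stepping are identical; only the setup differs. One usage note: the sharded `data` still behaves as one global array, so (assuming that converting a sharded Reactant array to a plain `Array` gathers the shards, as conversion does for single-device arrays) the final state can be inspected on the host after running the simulation:

```julia
# Assumes `data` was created and `simulate` was run as in the sharded example above.
host_data = Array(data)  # expected to gather every shard into one ordinary Julia Array
@show size(host_data)    # the full global size (N + 2), not the per-device shard size
```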

<!-- TODO: describe generation of a distributed array by concatenating local-worker data. -->

<!-- TODO: more complex tutorial describing replicated sharding. -->

## Sharding in Neural Networks

### 8-way Batch Parallelism

### 4-way Batch & 2-way Model Parallelism

## Related links

<!-- shardy? https://openxla.org/shardy -->
<!-- https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html -->
<!-- https://colab.research.google.com/drive/1UobcFjfwDI3N2EXvH3KbRS5ZxY9Riy4y#scrollTo=IiR7-0nDLPKK -->