diff --git a/examples/sharding/sharding.ipynb b/examples/sharding/sharding.ipynb index b61321a54..a0678a3fc 100644 --- a/examples/sharding/sharding.ipynb +++ b/examples/sharding/sharding.ipynb @@ -16,18 +16,36 @@ }, { "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "BB2K68OYUJ_t", + "output": { + "id": 1226213035567159, + "loadingStatus": "loaded" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rm: cannot remove 'Miniconda3-py37_4.9.2-Linux-x86_64.sh': No such file or directory\r\n", + "rm: cannot remove 'Miniconda3-py37_4.9.2-Linux-x86_64.sh.*': No such file or directory\r\n", + "--2025-05-15 11:42:12-- https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh\r\n", + "Resolving repo.anaconda.com (repo.anaconda.com)... failed: Name or service not known.\r\n", + "wget: unable to resolve host address ‘repo.anaconda.com’\r\n", + "chmod: cannot access 'Miniconda3-py37_4.9.2-Linux-x86_64.sh': No such file or directory\r\n", + "bash: ./Miniconda3-py37_4.9.2-Linux-x86_64.sh: No such file or directory\r\n" + ] + } + ], "source": [ - "# install conda to make installying pytorch with cudatoolkit 11.3 easier. \n", + "# install conda to make installying pytorch with cudatoolkit 11.3 easier.\n", "!sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*\n", "!sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh\n", "!sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh\n", "!sudo bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local" - ], - "metadata": { - "id": "BB2K68OYUJ_t" - }, - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "code", @@ -43,12 +61,12 @@ }, { "cell_type": "markdown", - "source": [ - "Installing torchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run " - ], "metadata": { "id": "7iY7Uv11mJYK" - } + }, + "source": [ + "Installing torchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run " + ] }, { "cell_type": "code", @@ -64,23 +82,23 @@ }, { "cell_type": "markdown", - "source": [ - "Install multiprocess which works with ipython to for multi-processing programming within colab" - ], "metadata": { "id": "0wLX94Lw_Lml" - } + }, + "source": [ + "Install multiprocess which works with ipython to for multi-processing programming within colab" + ] }, { "cell_type": "code", - "source": [ - "!pip3 install multiprocess" - ], + "execution_count": null, "metadata": { "id": "HKoKRP-QzRCF" }, - "execution_count": null, - "outputs": [] + "outputs": [], + "source": [ + "!pip3 install multiprocess" + ] }, { "cell_type": "markdown", @@ -125,27 +143,32 @@ }, { "cell_type": "markdown", + "metadata": { + "id": "HWBOrwVSnrNE" + }, "source": [ "## **Overview**\n", "This tutorial will mainly cover the sharding schemes of embedding tables via `EmbeddingPlanner` and `DistributedModelParallel` API and explore the benefits of different sharding schemes for the embedding tables by explicitly configuring them." - ], - "metadata": { - "id": "HWBOrwVSnrNE" - } + ] }, { "cell_type": "markdown", + "metadata": { + "id": "udsN6PlUo1zF" + }, "source": [ "### Distributed Setup\n", "Due to the notebook enviroment, we cannot run [`SPMD`](https://en.wikipedia.org/wiki/SPMD) program here but we can do multiprocessing inside the notebook to mimic the setup. Users should be responsible for setting up their own [`SPMD`](https://en.wikipedia.org/wiki/SPMD) launcher when using Torchrec. \n", "We setup our environment so that torch distributed based communication backend can work." - ], - "metadata": { - "id": "udsN6PlUo1zF" - } + ] }, { "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "4-v17rxkopQw" + }, + "outputs": [], "source": [ "import os\n", "import torch\n", @@ -153,15 +176,13 @@ "\n", "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n", "os.environ[\"MASTER_PORT\"] = \"29500\"" - ], - "metadata": { - "id": "4-v17rxkopQw" - }, - "execution_count": 18, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "ZdSUWBRxoP8R" + }, "source": [ "### Constructing our embedding model\n", "Here we use TorchRec offering of [`EmbeddingBagCollection`](https://github.com/facebookresearch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L59) to construct our embedding bag model with embedding tables.\n", @@ -177,17 +198,19 @@ "* `data_parallel`: replicate the tables for every device;\n", "\n", "Note how we initially allocate the EBC on device \"meta\". This will tell EBC to not allocate memory yet." - ], - "metadata": { - "id": "ZdSUWBRxoP8R" - } + ] }, { "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "e7UQBuG09hbj" + }, + "outputs": [], "source": [ "from torchrec.distributed.planner.types import ParameterConstraints\n", "from torchrec.distributed.embedding_types import EmbeddingComputeKernel\n", - "from torchrec.distributed.types import ShardingType\n", + "from torchrec.distributed.types import ShardingType, ShardingPlan\n", "from typing import Dict\n", "\n", "large_table_cnt = 2\n", @@ -224,41 +247,41 @@ " }\n", " constraints = {**large_table_constraints, **small_table_constraints}\n", " return constraints" - ], - "metadata": { - "id": "e7UQBuG09hbj" - }, - "execution_count": 19, - "outputs": [] + ] }, { "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "Iz_GZDp_oQ19" + }, + "outputs": [], "source": [ "ebc = torchrec.EmbeddingBagCollection(\n", " device=\"cuda\",\n", " tables=large_tables + small_tables\n", ")" - ], - "metadata": { - "id": "Iz_GZDp_oQ19" - }, - "execution_count": 20, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "7m0_ssVLFQEH" + }, "source": [ "### DistributedModelParallel in multiprocessing\n", "Now, we have a single process execution function for mimicking one rank's work during [`SPMD`](https://en.wikipedia.org/wiki/SPMD) execution.\n", "\n", "This code will shard the model collectively with other processes and allocate memories accordingly. It first sets up process groups and do embedding table placement using planner and generate sharded model using `DistributedModelParallel`.\n" - ], - "metadata": { - "id": "7m0_ssVLFQEH" - } + ] }, { "cell_type": "code", + "execution_count": 21, + "metadata": { + "id": "PztCaGmLA85u" + }, + "outputs": [], "source": [ "def single_rank_execution(\n", " rank: int,\n", @@ -310,29 +333,29 @@ " )\n", " print(f\"rank:{rank},sharding plan: {plan}\")\n", " return sharded_model\n" - ], - "metadata": { - "id": "PztCaGmLA85u" - }, - "execution_count": 21, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "3YvDnV_wz_An" + }, "source": [ "### Multiprocessing Execution\n", "Now let's execute the code in multi-processes representing multiple GPU ranks.\n", "\n" - ], - "metadata": { - "id": "3YvDnV_wz_An" - } + ] }, { "cell_type": "code", + "execution_count": 22, + "metadata": { + "id": "arW0Jf6qEl-h" + }, + "outputs": [], "source": [ "import multiprocess\n", - " \n", + "\n", "def spmd_sharing_simulation(\n", " sharding_type: ShardingType = ShardingType.TABLE_WISE,\n", " world_size = 2,\n", @@ -356,28 +379,21 @@ " for p in processes:\n", " p.join()\n", " assert 0 == p.exitcode" - ], - "metadata": { - "id": "arW0Jf6qEl-h" - }, - "execution_count": 22, - "outputs": [] + ] }, { "cell_type": "markdown", + "metadata": { + "id": "31UWMaymj7Pu" + }, "source": [ "### Table Wise Sharding\n", "Now let's execute the code in two processes for 2 GPUs. We can see in the plan print that how our tables are sharded across GPUs. Each node will have one large table and one small which shows our planner tries for load balance for the embedding tables. Table-wise is the de-factor go-to sharding schemes for many small-medium size tables for load balancing over the devices." - ], - "metadata": { - "id": "31UWMaymj7Pu" - } + ] }, { "cell_type": "code", - "source": [ - "spmd_sharing_simulation(ShardingType.TABLE_WISE)" - ], + "execution_count": 23, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -385,19 +401,18 @@ "id": "Yb4v1HA3IJzU", "outputId": "b8f08b10-eb85-48f3-8705-b67efd4eba2c" }, - "execution_count": 23, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}\n", "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n", @@ -405,43 +420,43 @@ " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n" ] } + ], + "source": [ + "spmd_sharing_simulation(ShardingType.TABLE_WISE)" ] }, { "cell_type": "markdown", + "metadata": { + "id": "5HkwxEwm4O8u" + }, "source": [ "### Explore other sharding modes\n", "We have initially explored what table-wise sharding would look like and how it balances the tables placement. Now we explore sharding modes with finer focus on load balance: row-wise. Row-wise is specifically addressing large tables which a single device cannot hold due to the memory size increase from large embedding row numbers. It can address the placement of the super large tables in your models. Users can see that in the `shard_sizes` section in the printed plan log, the tables are halved by row dimension to be distributed onto two GPUs.\n" - ], - "metadata": { - "id": "5HkwxEwm4O8u" - } + ] }, { "cell_type": "code", - "source": [ - "spmd_sharing_simulation(ShardingType.ROW_WISE)" - ], + "execution_count": 24, "metadata": { - "id": "pGBgReGx5VrB", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "pGBgReGx5VrB", "outputId": "6e22a2f0-7373-4dcc-ee69-67f3e95d78a7" }, - "execution_count": 24, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}\n", "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n", @@ -449,42 +464,42 @@ " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n" ] } + ], + "source": [ + "spmd_sharing_simulation(ShardingType.ROW_WISE)" ] }, { "cell_type": "markdown", - "source": [ - "Column-wise on the other hand, address the load imbalance problems for tables with large embedding dimensions. We will split the table vertically. Users can see that in the `shard_sizes` section in the printed plan log, the tables are halved by embedding dimension to be distributed onto two GPUs.\n" - ], "metadata": { "id": "mqnInw_uEjjY" - } + }, + "source": [ + "Column-wise on the other hand, address the load imbalance problems for tables with large embedding dimensions. We will split the table vertically. Users can see that in the `shard_sizes` section in the printed plan log, the tables are halved by embedding dimension to be distributed onto two GPUs.\n" + ] }, { "cell_type": "code", - "source": [ - "spmd_sharing_simulation(ShardingType.COLUMN_WISE)" - ], + "execution_count": 25, "metadata": { - "id": "DWTyuV9I5afU", "colab": { "base_uri": "https://localhost:8080/" }, + "id": "DWTyuV9I5afU", "outputId": "daaa95cd-f653-47fe-809f-5d1d63cc05d7" }, - "execution_count": 25, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}\n", "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])), 'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n", @@ -492,32 +507,33 @@ " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n" ] } + ], + "source": [ + "spmd_sharing_simulation(ShardingType.COLUMN_WISE)" ] }, { "cell_type": "markdown", - "source": [ - "For `table-row-wise`, unfortuately we cannot simulate it due to its nature of operating under multi-host setup. We will present a python [`SPMD`](https://en.wikipedia.org/wiki/SPMD) example in the future to train models with `table-row-wise`." - ], "metadata": { "id": "711VBygVHGJ6" - } + }, + "source": [ + "For `table-row-wise`, unfortuately we cannot simulate it due to its nature of operating under multi-host setup. We will present a python [`SPMD`](https://en.wikipedia.org/wiki/SPMD) example in the future to train models with `table-row-wise`." + ] }, { "cell_type": "markdown", + "metadata": { + "id": "1G8aUfmeMA7m" + }, "source": [ "\n", "With data parallel, we will repeat the tables for all devices.\n" - ], - "metadata": { - "id": "1G8aUfmeMA7m" - } + ] }, { "cell_type": "code", - "source": [ - "spmd_sharing_simulation(ShardingType.DATA_PARALLEL)" - ], + "execution_count": 26, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -525,19 +541,18 @@ "id": "WFk-QLlRL-ST", "outputId": "662a6d6e-cb1b-440d-ff1b-4619076117a3" }, - "execution_count": 26, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}\n", "rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None), 'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}\n" ] }, { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py:1403: UserWarning: positional arguments and argument \"destination\" are deprecated. nn.Module.state_dict will not accept them in the future. Refer to https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.state_dict for details.\n", " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n", @@ -545,6 +560,9 @@ " \" and \".join(warn_msg) + \" are deprecated. nn.Module.state_dict will not accept them in the future. \"\n" ] } + ], + "source": [ + "spmd_sharing_simulation(ShardingType.DATA_PARALLEL)" ] } ], @@ -557,14 +575,18 @@ "name": "Torchrec Sharding Introduction.ipynb", "provenance": [] }, + "fileHeader": "", + "fileUid": "7751c451-b188-42e5-8c09-4564538f4e06", + "isAdHoc": false, "kernelspec": { "display_name": "Python 3", - "name": "python3" + "language": "python", + "name": "bento_kernel_default" }, "language_info": { "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + }, + "notebookId": "1976711006070535", + "notebookNumber": "N7182976" + } }