|
1628 | 1628 | "\n",
|
1629 | 1629 | " def forward(self, x):\n",
|
1630 | 1630 | " b, num_tokens, d_in = x.shape # New batch dimension b\n",
|
| 1631 | + " # For inputs where `num_tokens` exceeds `context_length`, this will result in errors\n", |
| 1632 | + " # in the mask creation further below.\n", |
| 1633 | + " # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs \n", |
| 1634 | + " # do not exceed `context_length` before reaching this forward method. \n", |
1631 | 1635 | " keys = self.W_key(x)\n",
|
1632 | 1636 | " queries = self.W_query(x)\n",
|
1633 | 1637 | " values = self.W_value(x)\n",
|
|
1837 | 1841 | "\n",
|
1838 | 1842 | " def forward(self, x):\n",
|
1839 | 1843 | " b, num_tokens, d_in = x.shape\n",
|
| 1844 | + " # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, \n", |
| 1845 | + " # this will result in errors in the mask creation further below. \n", |
| 1846 | + " # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs \n", |
| 1847 | + " # do not exceed `context_length` before reaching this forwar\n", |
1840 | 1848 | "\n",
|
1841 | 1849 | " keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
|
1842 | 1850 | " queries = self.W_query(x)\n",
|
|
2029 | 2037 | "name": "python",
|
2030 | 2038 | "nbconvert_exporter": "python",
|
2031 | 2039 | "pygments_lexer": "ipython3",
|
2032 |
| - "version": "3.11.4" |
| 2040 | + "version": "3.10.16" |
2033 | 2041 | }
|
2034 | 2042 | },
|
2035 | 2043 | "nbformat": 4,
|
|
0 commit comments