diff --git a/README.md b/README.md index 9fec3df..0f3693e 100644 --- a/README.md +++ b/README.md @@ -239,11 +239,13 @@ Instead of using the field names, the export will use the labels as they are def Filters can automatically be added to the header row by setting `xlsx_auto_filter = True`. The filter will include all header columns in the worksheet. -### Ignore fields +### Specify or ignore fields -By default, all fields are exported, but you might want to exclude some fields from your export. To do so, you can set an array with fields you want to exclude: `xlsx_ignore_headers = []`. +By default, all fields are exported. However, this behavior can be changed. -This also works with nested fields, separated with a dot (i.e. `icon.url`). +To include only a specified list of fields, provide them with: `xlsx_specify_headers = []`. Conversely, to exclude certain fields from your export, provide them with: `xlsx_ignore_headers = []`. + +These both work with nested fields, separated with a dot (i.e. `icon.url`). ### Date/time and number formatting Formatting for cells follows [openpyxl formats](https://openpyxl.readthedocs.io/en/stable/_modules/openpyxl/styles/numbers.html). diff --git a/drf_excel/renderers.py b/drf_excel/renderers.py index ec67fea..879f5b9 100644 --- a/drf_excel/renderers.py +++ b/drf_excel/renderers.py @@ -42,6 +42,7 @@ class XLSXRenderer(BaseRenderer): format = "xlsx" # Reserved word, but required by BaseRenderer combined_header_dict = {} fields_dict = {} + specify_headers = None ignore_headers = [] boolean_display = None column_data_styles = None @@ -102,7 +103,8 @@ def render(self, data, accepted_media_type=None, renderer_context=None): # Set `xlsx_use_labels = True` inside the API View to enable labels. use_labels = getattr(drf_view, "xlsx_use_labels", False) - # A list of header keys to ignore in our export + # A list of header keys to use or ignore in our export + self.specify_headers = getattr(drf_view, "xlsx_specify_headers", None) self.ignore_headers = getattr(drf_view, "xlsx_ignore_headers", []) # Create a mapping dict named `xlsx_boolean_labels` inside the API View. @@ -284,8 +286,13 @@ def _get_label(parent_label, label_sep, obj): _fields = serializer.fields for k, v in _fields.items(): new_key = f"{parent_key}{key_sep}{k}" if parent_key else k - # Skip headers we want to ignore - if new_key in self.ignore_headers or getattr(v, "write_only", False): + # Skip headers that weren't in the list (if present) or were specifically ignored + if ( + self.specify_headers is not None + and new_key not in self.specify_headers + or new_key in self.ignore_headers + or getattr(v, "write_only", False) + ): continue # Iterate through fields if field is a serializer. Check for labels and # append if `use_labels` is True. Fallback to using keys. diff --git a/tests/test_viewset_mixin.py b/tests/test_viewset_mixin.py index e7dfe5c..c36680a 100644 --- a/tests/test_viewset_mixin.py +++ b/tests/test_viewset_mixin.py @@ -139,3 +139,19 @@ def test_auto_filter_viewset(api_client, workbook_reader): sheet = wb.worksheets[0] assert sheet.auto_filter.ref == "A1:B2" + + +def test_specify_headers(api_client, workbook_reader): + AllFieldsModel.objects.create(title="Hello", age=36) + + response = api_client.get("/specify-headers/") + assert response.status_code == 200 + + wb = workbook_reader(response.content) + sheet = wb.worksheets[0] + + header, data = list(sheet.rows) + + assert len(header) == 1 + assert len(data) == 1 + assert header[0].value == "title" diff --git a/tests/testapp/views.py b/tests/testapp/views.py index d34efbb..86feff8 100644 --- a/tests/testapp/views.py +++ b/tests/testapp/views.py @@ -59,3 +59,11 @@ class AutoFilterViewSet(XLSXFileMixin, ReadOnlyModelViewSet): renderer_classes = (XLSXRenderer,) xlsx_auto_filter = True + + +class SpecifyHeadersViewSet(XLSXFileMixin, ReadOnlyModelViewSet): + queryset = AllFieldsModel.objects.all() + serializer_class = AllFieldsSerializer + renderer_classes = (XLSXRenderer,) + + xlsx_specify_headers = ["title"] diff --git a/tests/urls.py b/tests/urls.py index 2cf3d99..1d7be92 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -6,6 +6,7 @@ DynamicFieldViewSet, ExampleViewSet, SecretFieldViewSet, + SpecifyHeadersViewSet, ) router = routers.SimpleRouter() @@ -14,5 +15,6 @@ router.register(r"secret-field", SecretFieldViewSet) router.register(r"dynamic-field", DynamicFieldViewSet, basename="dynamic-field") router.register(r"auto-filter", AutoFilterViewSet, basename="auto-filter") +router.register(r"specify-headers", SpecifyHeadersViewSet, basename="specify-headers") urlpatterns = router.urls