diff --git a/apps/baseline/models.py b/apps/baseline/models.py index 6467c7f..9cd2387 100644 --- a/apps/baseline/models.py +++ b/apps/baseline/models.py @@ -1123,7 +1123,7 @@ class LivelihoodActivity(common_models.Model): quantity_sold = models.PositiveIntegerField(blank=True, null=True, verbose_name=_("Quantity Sold/Exchanged")) quantity_other_uses = models.PositiveIntegerField(blank=True, null=True, verbose_name=_("Quantity Other Uses")) # Can normally be calculated / validated as `quantity_produced + quantity_purchased - quantity_sold - quantity_other_uses` # NOQA: E501 - # but there are exceptions, such as MilkProduction, where there is also an amount used for ButterProduction, is this captured quantity_other_uses? # NOQA: E501 + # but there are exceptions, such as MilkProduction which also stores MilkProduction.quantity_butter_production quantity_consumed = models.PositiveIntegerField(blank=True, null=True, verbose_name=_("Quantity Consumed")) price = models.FloatField( diff --git a/apps/baseline/serializers.py b/apps/baseline/serializers.py index 338eed0..f094a27 100644 --- a/apps/baseline/serializers.py +++ b/apps/baseline/serializers.py @@ -1,11 +1,10 @@ -from django.db.models import Sum +from django.db.models import F, FloatField, Sum from django.utils import translation -from rest_framework import fields as rest_framework_fields from rest_framework import serializers from rest_framework_gis.serializers import GeoFeatureModelSerializer from common.fields import translation_fields -from metadata.models import LivelihoodStrategyType +from common.serializers import AggregatingSerializer from .models import ( BaselineLivelihoodActivity, @@ -1472,28 +1471,78 @@ def get_wealth_group_label(self, obj): return str(obj.wealth_group) -class DictQuerySetField(rest_framework_fields.SerializerMethodField): - def __init__(self, field_name=None, **kwargs): - self.field_name = field_name - super().__init__(**kwargs) +class LivelihoodZoneBaselineReportSerializer(AggregatingSerializer): + """ + There are two ‘levels’ of filter needed on this endpoint. The standard ones which are already on the LZB endpoint + filter the LZBs that are returned (eg, population range and wealth group). Let’s call them ‘global’ filters. + Everything needs filtering by wealth group or population, if those filters are active. - def to_representation(self, obj): - return self.parent.get_field(obj, self.field_name) + The data ‘slice’ strategy type and product filters do not remove LZBs from the results by themselves; they + only exclude values from the calculated slice statistics. + If a user selects Sorghum, that filters the kcals income for our slice. The kcals income for the slice is then + divided by the kcals income on the global set for the kcals income percent. + + The global filters are identical to those already on the LZB endpoint (and will always be - it is sharing the + code). These are applied to the LZB, row and slice totals. + + The slice filters are: + + - slice_by_product (for multiple, repeat the parameter, eg, slice_by_product=R0&slice_by_product=B01). These + match any CPC code that starts with the value. (The client needs to convert the selected product to CPC.) + + - slice_by_strategy_type - you can specify multiple, and you need to pass the code not the label (which could be + translated). (These are case-insensitive but otherwise must be an exact match.) + + The slice is defined by matching any of the products, AND any of the strategy types (as opposed to OR). + + Translated fields, eg, name, description, are rendered in the currently selected locale if possible. (Except + Country, which has different translations following ISO.) This can be selected in the UI or set using eg, + &language=pt which overrides the UI selection. + + You select the fields you want using the &fields= parameter in the usual way. If you omit the fields parameter all + fields are returned. These are currently the same field list as the normal LZB endpoint, plus the aggregations, + called slice_sum_kcals_consumed, sum_kcals_consumed, kcals_consumed_percent, plus product CPC and product common + name translated. If you omit a field, the statistics for that field will be aggregated together. + + The ordering code is also shared with the normal LZB endpoint, which uses the standard + &ordering= parameter. If none are specified, the results are sorted by the aggregations descending, ie, + biggest percentage first. + + The strategy type codes are: + MilkProduction + ButterProduction + MeatProduction + LivestockSale + CropProduction + FoodPurchase + PaymentInKind + ReliefGiftOther + Hunting + Fishing + WildFoodGathering + OtherCashIncome + OtherPurchase + + The product hierarchy can be retrieved from the classified product endpoint /api/classifiedproduct/. + + You can then filter by any of the calculated fields. To do so, prefix the field name with min_ or max_. + """ -class LivelihoodZoneBaselineReportSerializer(serializers.ModelSerializer): class Meta: model = LivelihoodZoneBaseline fields = ( - "id", - "name", - "description", "source_organization", "source_organization_name", - "livelihood_zone", - "livelihood_zone_name", "country_pk", "country_iso_en_name", + "livelihoodzone_pk", + "livelihood_zone", + "livelihood_zone_name", + "id", + "name", + "description", + "wealth_group_category_code", "main_livelihood_category", "bss", "currency", @@ -1503,133 +1552,33 @@ class Meta: "valid_to_date", # to display "is latest" / "is historic" in the UI for each ref yr "population_source", "population_estimate", - "livelihoodzone_pk", - "livelihood_strategy_pk", "strategy_type", + "livelihood_strategy_pk", "livelihood_activity_pk", - "wealth_group_category_code", - "population_estimate", "product_cpc", "product_common_name", - "slice_sum_kcals_consumed", - "sum_kcals_consumed", - "kcals_consumed_percent", - "sum_income", - "slice_sum_income", - "income_percent", - "sum_expenditure", - "slice_sum_expenditure", - "expenditure_percent", ) - # For each of these aggregates the following calculation columns are added: - # (a) Total at the LZB level (filtered by population, wealth group, etc), eg, sum_kcals_consumed. - # (b) Total for the selected product/strategy type slice, eg, slice_sum_kcals_consumed. - # (c) The percentage the slice represents of the whole, eg, kcals_consumed_percent. - # Filters are automatically created, eg, min_kcals_consumed_percent and max_kcals_consumed_percent. - # If no ordering is specified by the FilterSet, the results are ordered by percent descending in the order here. aggregates = { "kcals_consumed": Sum, "income": Sum, "expenditure": Sum, + "percentage_kcals": Sum, + "kcal_income_sum": Sum( + ( + F("livelihood_strategies__livelihoodactivity__quantity_purchased") + + F("livelihood_strategies__livelihoodactivity__quantity_produced") + ) + * F("livelihood_strategies__product__kcals_per_unit"), + output_field=FloatField(), + ), } - # For each of these pairs, a URL parameter is created "slice_{field}", eg, ?slice_product= - # They can appear zero, one or multiple times in the URL, and define a sub-slice of the row-level data. - # A slice includes activities with ANY of the products, AND, ANY of the strategy types. - # For example: (product=R0 OR product=L0) AND (strategy_type=MilkProd OR strategy_type=CropProd) slice_fields = { "product": "livelihood_strategies__product__cpc__istartswith", - # this parameter must be set to one of values (not labels) from LivelihoodStrategyType, eg, MilkProduction "strategy_type": "livelihood_strategies__strategy_type__iexact", - # TODO: Support filter expressions on the right here, so we can slice on, for example, a - # WealthGroupCharacteristicValue where WealthGroupCharacteristic is some hard-coded value, - # eg, the slice on WGCV where WGC=PhoneOwnership, or on WGCV > 3 where WGC=HouseholdSize, eg: - # {"phone_ownership": lambda val: Q(wgcv__path=val, wgc__path__code="PhoneOwnership")} } - livelihood_zone_name = DictQuerySetField("livelihood_zone_name") - source_organization_name = DictQuerySetField("source_organization_pk") - country_pk = DictQuerySetField("country_pk") - country_iso_en_name = DictQuerySetField("country_iso_en_name") - livelihoodzone_pk = DictQuerySetField("livelihoodzone_pk") - livelihood_strategy_pk = DictQuerySetField("livelihood_strategy_pk") - livelihood_activity_pk = DictQuerySetField("livelihood_activity_pk") - wealth_group_category_code = DictQuerySetField("wealth_group_category_code") - id = DictQuerySetField("id") - name = DictQuerySetField("name") - description = DictQuerySetField("description") - source_organization = DictQuerySetField("source_organization") - livelihood_zone = DictQuerySetField("livelihood_zone") - main_livelihood_category = DictQuerySetField("main_livelihood_category") - bss = DictQuerySetField("bss") - currency = DictQuerySetField("currency") - reference_year_start_date = DictQuerySetField("reference_year_start_date") - reference_year_end_date = DictQuerySetField("reference_year_end_date") - valid_from_date = DictQuerySetField("valid_from_date") - valid_to_date = DictQuerySetField("valid_to_date") - population_source = DictQuerySetField("population_source") - population_estimate = DictQuerySetField("population_estimate") - product_cpc = DictQuerySetField("product_cpc") - product_common_name = DictQuerySetField("product_common_name") - strategy_type = DictQuerySetField("strategy_type") - - slice_sum_kcals_consumed = DictQuerySetField("slice_sum_kcals_consumed") - sum_kcals_consumed = DictQuerySetField("sum_kcals_consumed") - kcals_consumed_percent = DictQuerySetField("kcals_consumed_percent") - - slice_sum_income = DictQuerySetField("slice_sum_income") - sum_income = DictQuerySetField("sum_income") - income_percent = DictQuerySetField("income_percent") - - slice_sum_expenditure = DictQuerySetField("slice_sum_expenditure") - sum_expenditure = DictQuerySetField("sum_expenditure") - expenditure_percent = DictQuerySetField("expenditure_percent") - - def get_fields(self): - """ - User can specify fields= parameter to specify a field list, comma-delimited. - - If the fields parameter is not passed or does not match fields, defaults to self.Meta.fields. - - The aggregated fields self.aggregates are added regardless of user field selection. - """ - field_list = "request" in self.context and self.context["request"].query_params.get("fields", None) - if not field_list: - return super().get_fields() - - # User-provided list of fields - field_names = set(field_list.split(",")) - - # Add the aggregates that are always returned - for field_name, aggregate in self.aggregates.items(): - field_names |= { - field_name, - self.aggregate_field_name(field_name, aggregate), - self.slice_aggregate_field_name(field_name, aggregate), - self.slice_percent_field_name(field_name, aggregate), - } - - # Add the ordering field if specified - ordering = self.context["request"].query_params.get("ordering") - if ordering: - field_names.add(ordering) - - # Remove any that don't match a field as a dict - return {k: v for k, v in super().get_fields().items() if k in field_names} - - def get_field(self, obj, field_name): - """ - Aggregated querysets are a list of dicts. - This is called by AggregatedQuerysetField to get the value from the row dict. - """ - db_field = self.field_to_database_path(field_name) - value = obj.get(db_field, "") - # Get the readable, translated string from the choice key. - if field_name == "strategy_type" and value: - return dict(LivelihoodStrategyType.choices).get(value, value) - return value - @staticmethod def field_to_database_path(field_name): language_code = translation.get_language() @@ -1654,15 +1603,3 @@ def field_to_database_path(field_name): "strategy_type": "livelihood_strategies__strategy_type", "product_common_name": f"livelihood_strategies__product__common_name_{language_code}", }.get(field_name, field_name) - - @staticmethod - def aggregate_field_name(field_name, aggregate): - return f"{aggregate.name.lower()}_{field_name}" # eg, sum_kcals_consumed - - @staticmethod - def slice_aggregate_field_name(field_name, aggregate): - return f"slice_{aggregate.name.lower()}_{field_name}" # eg, slice_sum_kcals_consumed - - @staticmethod - def slice_percent_field_name(field_name, aggregate): - return f"{field_name}_percent" # eg, kcals_consumed_percent diff --git a/apps/baseline/viewsets.py b/apps/baseline/viewsets.py index eccb70a..d4dfed3 100644 --- a/apps/baseline/viewsets.py +++ b/apps/baseline/viewsets.py @@ -1,8 +1,7 @@ from django.apps import apps from django.conf import settings from django.db import models -from django.db.models import F, FloatField, Q, Subquery -from django.db.models.functions import Coalesce, NullIf +from django.db.models import Q, Subquery from django.utils.translation import override from django_filters import rest_framework as filters from django_filters.filters import CharFilter @@ -10,11 +9,10 @@ from rest_framework.renderers import JSONRenderer from rest_framework.response import Response from rest_framework.views import APIView -from rest_framework.viewsets import ModelViewSet from common.fields import translation_fields from common.filters import MultiFieldFilter, UpperCaseFilter -from common.viewsets import BaseModelViewSet +from common.viewsets import AggregatingViewSet, BaseModelViewSet from .models import ( BaselineLivelihoodActivity, @@ -1628,7 +1626,7 @@ class CopingStrategyViewSet(BaseModelViewSet): ] -class LivelihoodZoneBaselineReportViewSet(ModelViewSet): +class LivelihoodZoneBaselineReportViewSet(AggregatingViewSet): """ There are two ‘levels’ of filter needed on this endpoint. The standard ones which are already on the LZB endpoint filter the LZBs that are returned (eg, population range and wealth group). Let’s call them ‘global’ filters. @@ -1708,118 +1706,6 @@ class LivelihoodZoneBaselineReportViewSet(ModelViewSet): serializer_class = LivelihoodZoneBaselineReportSerializer filterset_class = LivelihoodZoneBaselineFilterSet - def get_queryset(self): - """ - Aggregates the values specified in the serializer.aggregates property, grouping and aggregating by any - fields not requested by the user. - """ - - # Add the global filters, eg, wealth group, population range, that apply to global search results AND slices: - queryset = self.filter_queryset(super().get_queryset()) - - # Add the global aggregations, eg, total consumption filtered by wealth group but not by prod/strategy slice: - queryset = queryset.annotate(**self.global_aggregates()) - - # Work out the slice aggregates, eg, slice_sum_kcals_consumed for product/strategy slice: - slice_aggregates = self.get_slice_aggregates() - # Work out the calculations on aggregates, eg, - # kcals_consumed_percent = slice_sum_kcals_consumed * 100 / sum_kcals_consumed - calcs_on_aggregates = self.get_calculations_on_aggregates() - - # Extract the model fields from the combined list of model and calculated fields: - model_fields = self.get_serializer().get_fields().keys() - slice_aggregates.keys() - calcs_on_aggregates.keys() - - # Convert user-friendly field name (eg, livelihood_strategy_pk) into db field path (livelihood_strategies__pk). - obj_field_paths = [self.serializer_class.field_to_database_path(field) for field in model_fields] - - # Get them from the query. The ORM converts this qs.values() call into a SQL `GROUP BY *field_paths` clause. - queryset = queryset.values(*obj_field_paths) - - # The ORM converts these annotations into grouped SELECT ..., SUM(lzb.population), SUM(la.kcals_consumed), etc. - queryset = queryset.annotate(**slice_aggregates, **calcs_on_aggregates) - - # Add the filters on aggregates, eg, kcals_consumed_percent > 50% - queryset = queryset.filter(self.get_filters_on_aggregates()) - - # If no ordering has been specified by the FilterSet, order by final value fields descending: - if not self.request.query_params.get("ordering"): - order_by_value_desc = [ - f"-{self.serializer_class.slice_percent_field_name(field_name, aggregate)}" - for field_name, aggregate in self.serializer_class.aggregates.items() - ] - queryset = queryset.order_by(*order_by_value_desc) - - return queryset - - def get_filters_on_aggregates(self): - # Add filters on aggregates, eg, .filter(kcals_consumed_percent__gte=params.get("min_kcals_consumed_percent")) - filters_on_aggregates = Q() - for url_param_prefix, orm_expr in (("min", "gte"), ("max", "lte")): - for field in self.serializer_class.aggregates.keys(): - url_param_name = f"{url_param_prefix}_{field}_percent" - limit = self.request.query_params.get(url_param_name) - if limit is not None: - filters_on_aggregates &= Q(**{f"{field}_percent__{orm_expr}": float(limit)}) - return filters_on_aggregates - - def global_aggregates(self): - """ - Produced a subquery per LZB-wide statistic that we need, eg, kcals_consumed for selected wealth groups for all - products and strategies. The kcals_consumed for a specific set of products and strategy types is divided by - this figure to obtain a percentage. - """ - global_aggregates = {} - for field_name, aggregate in self.serializer_class.aggregates.items(): - aggregate_field_name = self.serializer_class.aggregate_field_name(field_name, aggregate) - field_path = self.serializer_class.field_to_database_path(field_name) - global_aggregates[aggregate_field_name] = aggregate(field_path, default=0, output_field=FloatField()) - return global_aggregates - - def get_slice_aggregates(self): - # Construct the filters for the slice, for example specific products & strategy types, to apply to each measure - slice_filter = self.get_slice_filters() - # Remove the aggregated fields from the obj field list, and instead add them as sliced aggregate annotations: - slice_aggregates = {} - required_fields = set(self.get_serializer().get_fields().keys()) - required_fields.add(self.request.query_params.get("ordering", "")) - for field_name, aggregate in self.serializer_class.aggregates.items(): - aggregate_field_name = self.serializer_class.slice_aggregate_field_name(field_name, aggregate) - if aggregate_field_name in required_fields: - # Annotate the queryset with the aggregate, eg, slice_sum_kcals_consumed, applying the slice filters. - # This is then divided by, eg, sum_kcals_consumed for the percentage of the slice. - field_path = self.serializer_class.field_to_database_path(field_name) - slice_aggregates[aggregate_field_name] = aggregate( - field_path, filter=slice_filter, default=0, output_field=FloatField() - ) - return slice_aggregates - - def get_slice_filters(self): - # Filters to slice the aggregations, to obtain, eg, the kcals for the selected products/strategy types. - # This is then divided by the total for the LZB for the slice percentage. - slice_filters = Q() - for slice_field, slice_expr in self.serializer_class.slice_fields.items(): - slice_filter = Q() - for item in self.request.query_params.getlist(f"slice_{slice_field}"): - slice_filter |= Q(**{slice_expr: item}) - # Slice must match any of the products AND any of the strategy types (if selected) - slice_filters &= slice_filter - return slice_filters - - def get_calculations_on_aggregates(self): - # Aggregate slice percentages - # TODO: Add complex kcal income calculations from LIAS - calcs_on_aggregates = {} - for field_name, aggregate in self.serializer_class.aggregates.items(): - slice_total = F(self.serializer_class.slice_aggregate_field_name(field_name, aggregate)) - overall_total = F(self.serializer_class.aggregate_field_name(field_name, aggregate)) - # Protect against divide by zero (divide by null returns null without error) - expr = slice_total * 100 / NullIf(overall_total, 0) - # Zero if no LivActivities found for prod/strategy slice, rather than null: - expr = Coalesce(expr, 0, output_field=FloatField()) - slice_percent_field_name = self.serializer_class.slice_percent_field_name(field_name, aggregate) - calcs_on_aggregates[slice_percent_field_name] = expr - return calcs_on_aggregates - MODELS_TO_SEARCH = [ { diff --git a/apps/common/enums.py b/apps/common/enums.py new file mode 100644 index 0000000..207679d --- /dev/null +++ b/apps/common/enums.py @@ -0,0 +1,6 @@ +from enum import StrEnum + + +class AggregationScope(StrEnum): + ROW = "row" + SLICE = "slice" diff --git a/apps/common/serializers.py b/apps/common/serializers.py index 5516b0c..4aad847 100644 --- a/apps/common/serializers.py +++ b/apps/common/serializers.py @@ -1,6 +1,12 @@ +from collections import OrderedDict +from inspect import isclass + from django.contrib.auth.models import User from rest_framework import serializers +from rest_framework.fields import Field +from rest_framework.settings import api_settings +from .enums import AggregationScope from .models import ClassifiedProduct, Country, Currency, UnitOfMeasure, UserProfile @@ -102,3 +108,159 @@ class UserProfileSerializer(serializers.ModelSerializer): class Meta: model = UserProfile fields = ("user", "profile_data") + + +class AggregatingSerializer(serializers.ModelSerializer): + """ + A serializer that works with the AggregatingViewSet to provide aggregating functionality on a viewset. + + See the AggregatingViewSet docstring for a description of usage. The viewset must inherit from AggregatingViewSet, + and specify a serializer that sub-classes this AggregatingSerializer. + + All aggregation configuration is on the following serializer properties: + + * Meta.model: As standard, the base model of the endpoint + * Meta.fields: The maximum list of field names the endpoint can return. These are user-friendly field names, + converted to Django field paths by the field_to_database_path method if necessary. These can span model + joins in the usual way using double underscore. Do not include the auto-generated calculated fields, and + do not add Field class attributes on the serializer class. + * aggregates: A dictionary of {field name: expression or aggregate} pairs. These field names are also converted + to Django field paths by the field_to_database_path method if necessary. These can span model joins using __. + * slice_fields: A dict of {field name: database filter expression} pairs, eg, + {"kcals_consumed": "path__product__cpc__istartswith"}. This implements the parameter slice_by_product. + * field_to_database_path: A method that converts a user-friendly field name used in the results and parameter + names into a Django field path, eg, product_cpc into path__product__cpc. + + Field class attributes are not necessary. The values are rendered as returned by the database query, and this + endpoint is read-only. + """ + + # A dict of {field name: aggregate class} pairs, eg, {"kcals_consumed": Sum}. + # For each of these aggregates the following calculation columns are added: + # (a) Total at the row level (filtered and drilled down), eg, kcals_consumed_sum_row. + # (b) Total for the selected product/strategy type slice, eg, kcals_consumed_sum_slice. + # (c) The percentage the slice represents of the whole, eg, kcals_consumed_sum_slice_percentage_of_row. + # Filters are automatically created for all three, by prefixing min_ or max_ to any calculated field name. + # If no ordering is specified by the FilterSet, the results are ordered by percent descending in the order here. + # If specifying a custom expression, include how the field is aggregated, eg, the 'sum' in kcal_income_sum. + aggregates = {} + + # A dict of {field name: database filter expression} pairs, eg, + # {"product": "livelihood_strategies__product__cpc__istartswith"} + # For each of these pairs, a URL parameter is created "slice_by_{field}", eg, ?slice_by_product= + # They can appear zero, one or multiple times in the URL, and define a sub-slice of the row-level data. + # A slice includes activities with ANY of the products, AND, ANY of the strategy types. + # For example: (product=R0 OR product=L0) AND (strategy_type=MilkProd OR strategy_type=CropProd) + slice_fields = {} + + @staticmethod + def field_to_database_path(field_name): + """ + Convert user-friendly field name specified in Meta.fields, eg, strategy_name to database field path, eg, + join_path__strategies__name. + """ + return field_name + + def get_fields(self): + """ + User can specify a ?fields= URL parameter to specify a field list, comma-delimited. This also + determines the level of aggregation drill-down. + + The user need not specify the aggregate field names - these are all always included. + + The fields in the returned data will be in the same order as specified in this ?fields= parameter. + + Ignores any fields requested not found in self.Meta.fields. + + If the ?fields= URL parameter is not specified, defaults to self.Meta.fields. + """ + field_list = "request" in self.context and self.context["request"].query_params.get("fields", None) + if not field_list: + return {f: Field() for f in self.Meta.fields} + + # User-provided list of fields + field_names = field_list.split(",") + + # Add the ordering field if specified + ordering = self.context["request"].query_params.get(api_settings.ORDERING_PARAM) + if ordering: + field_names.append(ordering) + + # Remove any that don't match one of self.Meta.fields + # Return Field() to save sub-classes having to specify Field class attributes for model and aggregate fields. + return {f: Field() for f in field_names if f in self.Meta.fields} + + def get_aggregate_field_names(self): + """ + The order of the fields here determines the order in which they are returned by the endpoint + (see self.to_representation()). + """ + aggregate_fields = [] + for field_name, aggregate in self.aggregates.items(): + aggregate_fields.extend( + [ + self.get_aggregate_field_name(field_name, aggregate, AggregationScope.ROW), + # eg, kcals_consumed_sum_row + self.get_aggregate_field_name(field_name, aggregate, AggregationScope.SLICE), + # eg, kcals_consumed_sum_slice + self.get_aggregate_field_name(field_name, aggregate, AggregationScope.SLICE, AggregationScope.ROW), + # eg, kcals_consumed_sum_slice_percentage_of_row + ] + ) + return aggregate_fields + + def to_representation(self, instance): + """ + Order the fields in the order they are specified in self.Meta.fields, followed by the aggregates in the order + they are specified in self.aggregates. + + Raises an exception if a field in Meta.fields is not returned by the queryset (after conversion by + field_to_database_path). + + Ignores any aggregate fields that are not returned by the queryset, as this depends on the request. + Slice and percentage fields are only calculated when a slice is specified, for example. + """ + ret = OrderedDict() + for field_name in self.get_fields(): + field_path = self.field_to_database_path(field_name) + ret[field_name] = instance[field_path] + for field_name in self.get_aggregate_field_names(): + if field_name in instance: + ret[field_name] = instance[field_name] + return ret + + @classmethod + def get_aggregate_field_name(cls, field_name, aggregate, scope, percentage_of=None): + """ + Returns the field name for the auto-generated aggregate fields. + + field_name is the original column name, or the name provided in serializer.aggregates. + aggregate is the aggregation class, eg, Sum, or the fully formed expression. It is ignored in the latter case. + scope is the level of aggregation, and can be "row" or "slice". + percentage_of can also be "row" or "slice", for fields that are the scope-level aggregate as a + percentage of the percentage_of field, for example, the percentage a "slice" represents of its "row". + + Returns the field name of the aggregate output column as: + + f"{field_name}_{aggregate}_{scope}[_percentage_of_{percentage_of}]" + + scope and percentage_of may be "row" or "slice". percentage_of is optional. + + Examples: + * expenditure_avg_row + * kcals_consumed_sum_slice + * kcals_consumed_sum_slice_percentage_of_row + """ + + assert scope in AggregationScope + assert percentage_of in AggregationScope or percentage_of is None + + if isclass(aggregate): + field_name += "_" + aggregate.name.lower() # eg, kcals_consumed_sum + + field_name += "_" + scope # eg, kcals_consumed_sum_slice + + if percentage_of: + field_name += "_percentage_of_" + percentage_of # eg, kcals_consumed_sum_slice_percentage_of_row + + return field_name diff --git a/apps/common/viewsets.py b/apps/common/viewsets.py index e36f454..21814df 100644 --- a/apps/common/viewsets.py +++ b/apps/common/viewsets.py @@ -1,4 +1,8 @@ +from inspect import isclass + from django.contrib.auth.models import User +from django.db.models import ExpressionWrapper, F, FloatField, Q +from django.db.models.functions import Coalesce, NullIf from django.utils.text import format_lazy from django.utils.translation import gettext_lazy as _ from django_filters import rest_framework as filters @@ -7,7 +11,11 @@ from rest_framework.exceptions import NotAcceptable from rest_framework.pagination import PageNumberPagination from rest_framework.permissions import BasePermission, IsAuthenticated +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.viewsets import GenericViewSet +from .enums import AggregationScope from .fields import translation_fields from .filters import MultiFieldFilter from .models import ClassifiedProduct, Country, Currency, UnitOfMeasure, UserProfile @@ -390,3 +398,265 @@ def get_queryset(self): return queryset.filter(user=pk) else: return queryset + + +class AggregatingViewSet(GenericViewSet): + """ + A viewset parent class that adds aggregation functionality to a viewset. + + Supports the following URL parameters: + + * fields: Some or all of the model fields listed in the serializer Meta.fields property, comma delimited. + This field list controls the fields returned, and the level of aggregation / drill-down. + + * slice_by_[field]: Produces aggregates of a slice of the data within each row. The serializer defines some + common slice_by properties for easy use, eg, slice_by_product=R0 slices by product__cpc__istartswith. + + * slice_by_[field]__[Django lookup type]: Slices can be calculated on any model field, but the user + must specify a Django ORM lookup type that is suitable for the field type. This allows expert users to define + custom slices that the developer hasn't pre-configured on serializer.slice_fields. + + * min_ and/or max_[field]: Removes any rows where the value is outside the specified min/max/range. + The specific nature of these range parameters are configured on the serializer. + + * filters and ordering parameters provided by the FilterSet + + To use, inherit from this class, and ensure the linked serializer inherits from AggregatingSerializer and includes + the Meta.model, Meta.fields, aggregates, slice_fields properties, and the field_to_database_path method if + necessary. Field class attributes are not necessary. + + The results will include all valid fields specified in the fields parameter, plus values aggregated to the model, + row and slice levels, plus the percentage the slice represents of the row and model. + + The FilterSet filters are applied to all calculations, including at the model-level. + + The fields parameter determines the drill-down of each row from the base model. For example if a child model field + is included, then rows will be disaggregated to show figures broken down by that child model field. Any fields not + included are aggregated together. The FilterSet filters are applied to the rows too. + + The slice_by parameter defines a slice of data within each row. The FilterSet filters are applied to the slice + totals too. + + Calculated field examples: + * expenditure_avg_row # this is the average expenditure for all database records within the row. + * kcals_consumed_sum_slice # this is the sum of kcals consumed for a slice within each row. + * kcals_consumed_sum_slice_percentage_of_row # this is the percentage the slice represents of the row total. + + Filters can be applied on any of these calculated fields, by passing a parameter that prefixes min_ or max_ to + the field name, eg, min_kcals_consumed_sum=1000 or max_kcals_consumed_sum_slice_percentage_of_row=99. + + All configuration is set in the AggregatingSerializer sub-class. The AggregatingViewSet sub-class need only + implement, for example: + + queryset = LivelihoodZoneBaseline.objects.all() + serializer_class = LivelihoodZoneBaselineReportSerializer # must be a sub-class of AggregatingSerializer. + filterset_class = LivelihoodZoneBaselineFilterSet # a standard FilterSet, supports filter and order options. + """ + + pagination_class = ApiOnlyPagination + + def list(self, request, *args, **kwargs): + """ + Aggregates the values specified in the serializer.aggregates property, grouping and aggregating by any + fields not requested by the user. + """ + + # These filters are applied to global, row and slice totals. + queryset = super().get_queryset() + queryset = self.filter_queryset(queryset) + + # Get the field list to group/disaggregate the results by: + # TODO: Should get_fields be on the viewset? This is prematurely instantiating it before we've a queryset. + group_by_fields = self.get_serializer().get_fields().keys() + + # Convert user-friendly field name (eg, livelihood_strategy_pk) into db field path (livelihood_strategies__pk). + group_by_field_paths = [self.serializer_class.field_to_database_path(field) for field in group_by_fields] + + # Get them from the query. The ORM converts this qs.values() call into a SQL `GROUP BY *field_paths` clause. + queryset = queryset.values(*group_by_field_paths) + + # Add the row aggregations, eg, total consumption filtered by wealth group and row but not prd/strtgy slice: + row_aggregates = self.get_aggregates(AggregationScope.ROW) + queryset = queryset.annotate(**row_aggregates) + + # Add the slice aggregates, eg, slice_sum_kcals_consumed for product/strategy slice: + slice_aggregates = self.get_aggregates(AggregationScope.SLICE) + if slice_aggregates: + queryset = queryset.annotate(**slice_aggregates) + + # Add the calculations on aggregates, eg, + # kcals_consumed_sum_slice_percentage_of_row = slice_sum_kcals_consumed * 100 / sum_kcals_consumed + percentage_expressions = self.get_percentage_expressions() + queryset = queryset.annotate(**percentage_expressions) + + # Add the filters on aggregates, eg, kcals_consumed_percent > 50% + queryset = queryset.filter(self.get_filters_by_calculated_fields()) + + # If no ordering has been specified by the FilterSet, order by value descending: + if not self.request.query_params.get(api_settings.ORDERING_PARAM): + if slice_aggregates: + # If a slice has been specified, order by slice_percentage_of_row desc + order_by_value_desc = [ + f"-{self.serializer_class.get_aggregate_field_name(field_name, aggregate, AggregationScope.SLICE, AggregationScope.ROW,)}" # NOQA: E501 + for field_name, aggregate in self.serializer_class.aggregates.items() + ] + else: + # If no slice specified, order by row value desc + order_by_value_desc = [ + f"-{self.serializer_class.get_aggregate_field_name(field_name, aggregate, AggregationScope.ROW,)}" + for field_name, aggregate in self.serializer_class.aggregates.items() + ] + queryset = queryset.order_by(*order_by_value_desc) + + page = self.paginate_queryset(queryset) + if page is not None: + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = self.get_serializer(queryset, many=True) + return Response(serializer.data) + + def get_aggregates(self, scope): + """ + Produces aggregate expressions for scopes row or slice. + """ + assert isinstance(scope, AggregationScope) + + if scope == AggregationScope.SLICE: + slice_filters = self.get_slice_filters() + # Add slice and slice_percentage_of columns only if the user has specified a slice. + if not slice_filters: + return {} + + aggregates = {} + for field_name, aggregate in self.serializer_class.aggregates.items(): + aggregate_field_name = self.serializer_class.get_aggregate_field_name(field_name, aggregate, scope) + + if isclass(aggregate): + field_path = self.serializer_class.field_to_database_path(field_name) + aggregate_args = {"default": 0, "output_field": FloatField()} + if scope == AggregationScope.SLICE: + aggregate_args["filter"] = slice_filters + scoped_aggregate = aggregate(field_path, **aggregate_args) + + else: + scoped_aggregate = aggregate.copy() + scoped_aggregate.default = 0 + if scope == AggregationScope.SLICE: + scoped_aggregate.filter = slice_filters + + aggregates[aggregate_field_name] = scoped_aggregate + return aggregates + + def get_slice_filters(self): + """ + Filters to slice the aggregations, to obtain, eg, the kcals for the selected products/strategy types. + This is then divided by the total for the LZB and row for the slice percentages. + """ + slice_filters = Q() + for slice_field, slice_expr in self.serializer_class.slice_fields.items(): + slice_filter = Q() + for item in self.request.query_params.getlist(f"slice_by_{slice_field}"): + slice_filter |= Q(**{slice_expr: item}) + # Slice must match any of the products AND any of the strategy types (if selected) + slice_filters &= slice_filter + + # Also support slices on any field, but user must specify ORM lookup type in URL parameter name, and prefix + # the parameter with 'slice_by_', for example, ?slice_by_product_cpc__startswith=botswana + # An error is returned if the user uses an inappropriate lookup type for a field. Note that string lookup + # types cannot be used on Foreign Key fields, even if they are string codes - field_to_database_path must + # reach the corresponding primary key, eg, livelihood_strategies__product__cpc not + # livelihood_strategies__product. + for field_name in set(self.serializer_class.Meta.fields) - self.get_serializer().get_fields().keys(): + # fmt: off + # Copied from https://docs.djangoproject.com/en/5.1/ref/models/querysets/#field-lookups + lookup_types = ("exact", "iexact", "contains", "icontains", "in", "gt", "gte", "lt", "lte", "startswith", "istartswith", "endswith", "iendswith", "range", "date", "year", "iso_year", "month", "day", "week", "week_day", "iso_week_day", "quarter", "time", "hour", "minute", "second", "isnull", "regex", "iregex", ) # NOQA: E501 + # fmt: on + for lookup_type in lookup_types: + slice_filter = Q() + field_path = self.serializer_class.field_to_database_path(field_name) + slice_expr = f"{field_path}__{lookup_type}" + for item in self.request.query_params.getlist(f"slice_by_{field_name}__{lookup_type}"): + slice_filter |= Q(**{slice_expr: item}) + # Slice must match ANY of the products AND any of the strategy types AND any of the custom slices + slice_filters &= slice_filter + + return slice_filters + + def get_percentage_expressions(self): + """ + Aggregate slice percentages. + Nb. Cannot calculate in Python as need db-wide filtering and ordering on them. + Could add row percentage of model, eg, lzb. + """ + # TODO: Add complex kcal income calculations from LIAS + percentage_expressions = {} + for field_name, aggregate in self.serializer_class.aggregates.items(): + # Eg, kcals_consumed_sum_slice + slice_field_name = self.serializer_class.get_aggregate_field_name( + field_name, + aggregate, + AggregationScope.SLICE, + ) + slice_total = F(slice_field_name) + + # The denominator, eg, kcals_consumed_sum_row or kcals_consumed_sum_row: + denominator_field_name = self.serializer_class.get_aggregate_field_name( + field_name, + aggregate, + AggregationScope.ROW, + ) + total = F(denominator_field_name) + + # Multiply slice by 100 for percentage, divide by denominator. NullIf protects against divide by zero. + expr = ExpressionWrapper( + ExpressionWrapper(slice_total * 100.0, output_field=FloatField()) + / NullIf(total, 0.0, output_field=FloatField()), + output_field=FloatField(), + ) + + # Zero if no there are no value records retrieved for the slice. + expr = Coalesce(expr, 0.0, output_field=FloatField()) + + # Output field, eg, kcals_consumed_slice_percentage_of_row or kcals_consumed_slice_percentage_of_row: + pct_field_name = self.serializer_class.get_aggregate_field_name( + field_name, + aggregate, + AggregationScope.SLICE, + AggregationScope.ROW, + ) + percentage_expressions[pct_field_name] = ExpressionWrapper(expr, output_field=FloatField()) + return percentage_expressions + + def get_filters_by_calculated_fields(self): + """ + Add min/max range filters. Filters are available for any aggregate or percentage field, by prefixing + min_ or max_ to the aggregate field name. + eg, .filter( + kcals_consumed_sum_slice_percentage_of_row__gte= + params.get("min_kcals_consumed_sum_slice_percentage_of_row"), + kcals_consumed_sum_slice__gte= + params.get("min_kcals_consumed_sum_slice"), + ) + """ + filters_on_aggregates = Q() + for field_name, aggregate in self.serializer_class.aggregates.items(): + for url_param_prefix, orm_expr in (("min", "gte"), ("max", "lte")): + for agg_field_name in ( + self.serializer_class.get_aggregate_field_name( + field_name, + aggregate, + AggregationScope.ROW, + ), + self.serializer_class.get_aggregate_field_name( + field_name, + aggregate, + AggregationScope.SLICE, + AggregationScope.ROW, + ), + ): + url_param_name = f"{url_param_prefix}_{agg_field_name}" + limit = self.request.query_params.get(url_param_name) + if limit is not None: + filters_on_aggregates &= Q(**{f"{agg_field_name}__{orm_expr}": float(limit)}) + return filters_on_aggregates