1
1
import unittest
2
+ import base64
2
3
import json
3
4
from unittest .mock import patch , MagicMock
4
5
5
6
from datadog_lambda .dsm import (
6
7
set_dsm_context ,
7
8
_dsm_set_sqs_context ,
9
+ _dsm_set_sns_context ,
8
10
_get_dsm_context_from_lambda ,
9
11
)
10
12
from datadog_lambda .trigger import EventTypes , _EventSource
@@ -16,6 +18,10 @@ def setUp(self):
16
18
self .mock_dsm_set_sqs_context = patcher .start ()
17
19
self .addCleanup (patcher .stop )
18
20
21
+ patcher = patch ("datadog_lambda.dsm._dsm_set_sns_context" )
22
+ self .mock_dsm_set_sns_context = patcher .start ()
23
+ self .addCleanup (patcher .stop )
24
+
19
25
patcher = patch ("ddtrace.internal.datastreams.data_streams_processor" )
20
26
self .mock_data_streams_processor = patcher .start ()
21
27
self .addCleanup (patcher .stop )
@@ -32,6 +38,13 @@ def setUp(self):
32
38
self .mock_calculate_sqs_payload_size .return_value = 100
33
39
self .addCleanup (patcher .stop )
34
40
41
+ patcher = patch (
42
+ "ddtrace.internal.datastreams.botocore.calculate_sns_payload_size"
43
+ )
44
+ self .mock_calculate_sns_payload_size = patcher .start ()
45
+ self .mock_calculate_sns_payload_size .return_value = 150
46
+ self .addCleanup (patcher .stop )
47
+
35
48
patcher = patch ("ddtrace.internal.datastreams.processor.DsmPathwayCodec.decode" )
36
49
self .mock_dsm_pathway_codec_decode = patcher .start ()
37
50
self .addCleanup (patcher .stop )
@@ -116,6 +129,84 @@ def test_sqs_multiple_records_process_each_record(self):
116
129
self .assertIn ("type:sqs" , tags )
117
130
self .assertEqual (kwargs ["payload_size" ], 100 )
118
131
132
+ def test_sns_event_with_no_records_does_nothing (self ):
133
+ """Test that events where Records is None don't trigger DSM processing"""
134
+ events_with_no_records = [
135
+ {},
136
+ {"Records" : None },
137
+ {"someOtherField" : "value" },
138
+ ]
139
+
140
+ for event in events_with_no_records :
141
+ _dsm_set_sns_context (event )
142
+ self .mock_data_streams_processor .assert_not_called ()
143
+
144
+ def test_sns_event_triggers_dsm_sns_context (self ):
145
+ """Test that SNS event sources trigger the SNS-specific DSM context function"""
146
+ sns_event = {
147
+ "Records" : [
148
+ {
149
+ "EventSource" : "aws:sns" ,
150
+ "Sns" : {
151
+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:my-topic" ,
152
+ "Message" : "Hello from SNS!" ,
153
+ },
154
+ }
155
+ ]
156
+ }
157
+
158
+ event_source = _EventSource (EventTypes .SNS )
159
+ set_dsm_context (sns_event , event_source )
160
+
161
+ self .mock_dsm_set_sns_context .assert_called_once_with (sns_event )
162
+
163
+ def test_sns_multiple_records_process_each_record (self ):
164
+ """Test that each record in an SNS event gets processed individually"""
165
+ multi_record_event = {
166
+ "Records" : [
167
+ {
168
+ "Sns" : {
169
+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic1" ,
170
+ "Message" : "Message 1" ,
171
+ }
172
+ },
173
+ {
174
+ "Sns" : {
175
+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic2" ,
176
+ "Message" : "Message 2" ,
177
+ }
178
+ },
179
+ {
180
+ "Sns" : {
181
+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:topic3" ,
182
+ "Message" : "Message 3" ,
183
+ }
184
+ },
185
+ ]
186
+ }
187
+
188
+ mock_context = MagicMock ()
189
+ self .mock_dsm_pathway_codec_decode .return_value = mock_context
190
+
191
+ _dsm_set_sns_context (multi_record_event )
192
+
193
+ self .assertEqual (mock_context .set_checkpoint .call_count , 3 )
194
+
195
+ calls = mock_context .set_checkpoint .call_args_list
196
+ expected_arns = [
197
+ "arn:aws:sns:us-east-1:123456789012:topic1" ,
198
+ "arn:aws:sns:us-east-1:123456789012:topic2" ,
199
+ "arn:aws:sns:us-east-1:123456789012:topic3" ,
200
+ ]
201
+
202
+ for i , call in enumerate (calls ):
203
+ args , kwargs = call
204
+ tags = args [0 ]
205
+ self .assertIn ("direction:in" , tags )
206
+ self .assertIn (f"topic:{ expected_arns [i ]} " , tags )
207
+ self .assertIn ("type:sns" , tags )
208
+ self .assertEqual (kwargs ["payload_size" ], 150 )
209
+
119
210
120
211
class TestGetDSMContext (unittest .TestCase ):
121
212
def test_sqs_to_lambda_string_value_format (self ):
@@ -164,6 +255,43 @@ def test_sqs_to_lambda_string_value_format(self):
164
255
assert result ["x-datadog-parent-id" ] == "321987654"
165
256
assert result ["dd-pathway-ctx" ] == "test-pathway-ctx"
166
257
258
+ def test_sns_to_lambda_format (self ):
259
+ """Test format: message.Sns.MessageAttributes._datadog.Value.decode() (SNS -> lambda)"""
260
+ trace_context = {
261
+ "x-datadog-trace-id" : "111111111" ,
262
+ "x-datadog-parent-id" : "222222222" ,
263
+ "dd-pathway-ctx" : "test-pathway-ctx" ,
264
+ }
265
+ binary_data = base64 .b64encode (
266
+ json .dumps (trace_context ).encode ("utf-8" )
267
+ ).decode ("utf-8" )
268
+
269
+ sns_lambda_record = {
270
+ "EventSource" : "aws:sns" ,
271
+ "EventSubscriptionArn" : (
272
+ "arn:aws:sns:us-east-1:123456789012:sns-topic:12345678-1234-1234-1234-123456789012"
273
+ ),
274
+ "Sns" : {
275
+ "Type" : "Notification" ,
276
+ "MessageId" : "95df01b4-ee98-5cb9-9903-4c221d41eb5e" ,
277
+ "TopicArn" : "arn:aws:sns:us-east-1:123456789012:sns-topic" ,
278
+ "Subject" : "Test Subject" ,
279
+ "Message" : "Hello from SNS!" ,
280
+ "Timestamp" : "2023-01-01T12:00:00.000Z" ,
281
+ "MessageAttributes" : {
282
+ "_datadog" : {"Type" : "Binary" , "Value" : binary_data }
283
+ },
284
+ },
285
+ }
286
+
287
+ result = _get_dsm_context_from_lambda (sns_lambda_record )
288
+
289
+ assert result is not None
290
+ assert result == trace_context
291
+ assert result ["x-datadog-trace-id" ] == "111111111"
292
+ assert result ["x-datadog-parent-id" ] == "222222222"
293
+ assert result ["dd-pathway-ctx" ] == "test-pathway-ctx"
294
+
167
295
def test_no_message_attributes (self ):
168
296
"""Test message without MessageAttributes returns None."""
169
297
message = {
0 commit comments