12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import contextvars
15
16
import inspect
16
17
import time
17
18
from functools import wraps
31
32
32
33
T = TypeVar ("T" , covariant = True )
33
34
tracer = trace .get_tracer (__name__ )
35
+ _current_span_context : contextvars .ContextVar = contextvars .ContextVar (
36
+ "current_span_context"
37
+ )
34
38
35
39
36
40
def get_remote_func (func ): # type: ignore
@@ -74,7 +78,12 @@ def task(
74
78
75
79
def task_wrapper (func ): # type: ignore
76
80
async def async_exec (* args : Any , ** kwargs : Any ) -> Any :
77
- with tracer .start_as_current_span (name = func .__qualname__ ) as span :
81
+ parent_ctx = _current_span_context .get (None )
82
+ with tracer .start_as_current_span (
83
+ name = func .__qualname__ , context = parent_ctx
84
+ ) as span :
85
+ _current_span_context .set (trace .set_span_in_context (span ))
86
+
78
87
input = _update_kwargs (args , kwargs , func )
79
88
try :
80
89
result = await (get_remote_func (func ) if distributed else func )(
@@ -98,7 +107,12 @@ async def async_exec(*args: Any, **kwargs: Any) -> Any:
98
107
raise e
99
108
100
109
def sync_exec (* args : Any , ** kwargs : Any ) -> Any :
101
- with tracer .start_as_current_span (name = func .__qualname__ ) as span :
110
+ parent_ctx = _current_span_context .get (None )
111
+ with tracer .start_as_current_span (
112
+ name = func .__qualname__ , context = parent_ctx
113
+ ) as span :
114
+ _current_span_context .set (trace .set_span_in_context (span ))
115
+
102
116
input = _update_kwargs (args , kwargs , func )
103
117
try :
104
118
result = func (* args , ** kwargs )
@@ -121,9 +135,13 @@ def sync_exec(*args: Any, **kwargs: Any) -> Any:
121
135
122
136
@wraps (func )
123
137
async def async_iter_task (* args : Any , ** kwargs : Any ) -> AsyncGenerator [T , None ]:
138
+ parent_ctx = _current_span_context .get (None )
124
139
span = tracer .start_span (
125
- name = func .__qualname__ + ".first_iter" , start_time = time .time_ns ()
140
+ name = func .__qualname__ + ".first_iter" ,
141
+ start_time = time .time_ns (),
142
+ context = parent_ctx ,
126
143
)
144
+ _current_span_context .set (trace .set_span_in_context (span ))
127
145
input = _update_kwargs (args , kwargs , func )
128
146
try :
129
147
async for i , resp in aenumerate (func (* args , ** kwargs )): # type: ignore
@@ -142,12 +160,17 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
142
160
custom_attributes = custom_attributes ,
143
161
)
144
162
span .end (end_time = time .time_ns ())
163
+ _current_span_context .set (parent_ctx )
145
164
yield resp
146
165
147
166
if trace_all :
167
+ parent_ctx = _current_span_context .get ()
148
168
span = tracer .start_span (
149
- name = func .__qualname__ , start_time = time .time_ns ()
169
+ name = func .__qualname__ ,
170
+ start_time = time .time_ns (),
171
+ context = parent_ctx ,
150
172
)
173
+ _current_span_context .set (trace .set_span_in_context (span ))
151
174
except Exception as e :
152
175
if not trace_all :
153
176
span = tracer .start_span (
@@ -160,7 +183,12 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
160
183
161
184
@wraps (func )
162
185
def iter_task (* args : Any , ** kwargs : Any ) -> Iterable [T ]:
163
- span = tracer .start_span (name = func .__qualname__ , start_time = time .time_ns ())
186
+ parent_ctx = _current_span_context .get (None )
187
+ span = tracer .start_span (
188
+ name = func .__qualname__ , start_time = time .time_ns (), context = parent_ctx
189
+ )
190
+ _current_span_context .set (trace .set_span_in_context (span ))
191
+
164
192
input = _update_kwargs (args , kwargs , func )
165
193
try :
166
194
for i , resp in enumerate (func (* args , ** kwargs )):
@@ -179,11 +207,16 @@ def iter_task(*args: Any, **kwargs: Any) -> Iterable[T]:
179
207
custom_attributes = custom_attributes ,
180
208
)
181
209
span .end (end_time = time .time_ns ())
210
+ _current_span_context .set (parent_ctx )
182
211
yield resp
183
212
if trace_all :
213
+ parent_ctx = _current_span_context .get ()
184
214
span = tracer .start_span (
185
- name = func .__qualname__ , start_time = time .time_ns ()
215
+ name = func .__qualname__ ,
216
+ start_time = time .time_ns (),
217
+ context = parent_ctx ,
186
218
)
219
+ _current_span_context .set (trace .set_span_in_context (span ))
187
220
except Exception as e :
188
221
if not trace_all :
189
222
span = tracer .start_span (
0 commit comments