Skip to content

Commit 29c216d

Browse files
Merge pull request #99 from yinchengfeng-bytedance/fix/trace
fix(arkitect): trace span add context support
2 parents 79b3321 + 420fa3e commit 29c216d

File tree

2 files changed

+42
-9
lines changed

2 files changed

+42
-9
lines changed

arkitect/telemetry/trace/wrapper.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import contextvars
1516
import inspect
1617
import time
1718
from functools import wraps
@@ -31,6 +32,9 @@
3132

3233
T = TypeVar("T", covariant=True)
3334
tracer = trace.get_tracer(__name__)
35+
_current_span_context: contextvars.ContextVar = contextvars.ContextVar(
36+
"current_span_context"
37+
)
3438

3539

3640
def get_remote_func(func): # type: ignore
@@ -74,7 +78,12 @@ def task(
7478

7579
def task_wrapper(func): # type: ignore
7680
async def async_exec(*args: Any, **kwargs: Any) -> Any:
77-
with tracer.start_as_current_span(name=func.__qualname__) as span:
81+
parent_ctx = _current_span_context.get(None)
82+
with tracer.start_as_current_span(
83+
name=func.__qualname__, context=parent_ctx
84+
) as span:
85+
_current_span_context.set(trace.set_span_in_context(span))
86+
7887
input = _update_kwargs(args, kwargs, func)
7988
try:
8089
result = await (get_remote_func(func) if distributed else func)(
@@ -98,7 +107,12 @@ async def async_exec(*args: Any, **kwargs: Any) -> Any:
98107
raise e
99108

100109
def sync_exec(*args: Any, **kwargs: Any) -> Any:
101-
with tracer.start_as_current_span(name=func.__qualname__) as span:
110+
parent_ctx = _current_span_context.get(None)
111+
with tracer.start_as_current_span(
112+
name=func.__qualname__, context=parent_ctx
113+
) as span:
114+
_current_span_context.set(trace.set_span_in_context(span))
115+
102116
input = _update_kwargs(args, kwargs, func)
103117
try:
104118
result = func(*args, **kwargs)
@@ -121,9 +135,13 @@ def sync_exec(*args: Any, **kwargs: Any) -> Any:
121135

122136
@wraps(func)
123137
async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
138+
parent_ctx = _current_span_context.get(None)
124139
span = tracer.start_span(
125-
name=func.__qualname__ + ".first_iter", start_time=time.time_ns()
140+
name=func.__qualname__ + ".first_iter",
141+
start_time=time.time_ns(),
142+
context=parent_ctx,
126143
)
144+
_current_span_context.set(trace.set_span_in_context(span))
127145
input = _update_kwargs(args, kwargs, func)
128146
try:
129147
async for i, resp in aenumerate(func(*args, **kwargs)): # type: ignore
@@ -142,12 +160,17 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
142160
custom_attributes=custom_attributes,
143161
)
144162
span.end(end_time=time.time_ns())
163+
_current_span_context.set(parent_ctx)
145164
yield resp
146165

147166
if trace_all:
167+
parent_ctx = _current_span_context.get()
148168
span = tracer.start_span(
149-
name=func.__qualname__, start_time=time.time_ns()
169+
name=func.__qualname__,
170+
start_time=time.time_ns(),
171+
context=parent_ctx,
150172
)
173+
_current_span_context.set(trace.set_span_in_context(span))
151174
except Exception as e:
152175
if not trace_all:
153176
span = tracer.start_span(
@@ -160,7 +183,12 @@ async def async_iter_task(*args: Any, **kwargs: Any) -> AsyncGenerator[T, None]:
160183

161184
@wraps(func)
162185
def iter_task(*args: Any, **kwargs: Any) -> Iterable[T]:
163-
span = tracer.start_span(name=func.__qualname__, start_time=time.time_ns())
186+
parent_ctx = _current_span_context.get(None)
187+
span = tracer.start_span(
188+
name=func.__qualname__, start_time=time.time_ns(), context=parent_ctx
189+
)
190+
_current_span_context.set(trace.set_span_in_context(span))
191+
164192
input = _update_kwargs(args, kwargs, func)
165193
try:
166194
for i, resp in enumerate(func(*args, **kwargs)):
@@ -179,11 +207,16 @@ def iter_task(*args: Any, **kwargs: Any) -> Iterable[T]:
179207
custom_attributes=custom_attributes,
180208
)
181209
span.end(end_time=time.time_ns())
210+
_current_span_context.set(parent_ctx)
182211
yield resp
183212
if trace_all:
213+
parent_ctx = _current_span_context.get()
184214
span = tracer.start_span(
185-
name=func.__qualname__, start_time=time.time_ns()
215+
name=func.__qualname__,
216+
start_time=time.time_ns(),
217+
context=parent_ctx,
186218
)
219+
_current_span_context.set(trace.set_span_in_context(span))
187220
except Exception as e:
188221
if not trace_all:
189222
span = tracer.start_span(

tests/ut/test_task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_sync_nested(self):
9898
span = json.loads(line)
9999
spans.append(span)
100100
self.assertEqual(len(spans), 3)
101-
self.assertEqual(spans[2]["parent_id"], None)
101+
self.assertNotEqual(spans[2]["parent_id"], None)
102102
self.assertEqual(spans[1]["parent_id"], spans[2]["context"]["span_id"])
103103
self.assertEqual(spans[0]["parent_id"], spans[1]["context"]["span_id"])
104104

@@ -122,7 +122,7 @@ def test_async_sync_mixed(self):
122122
span = json.loads(line)
123123
spans.append(span)
124124
self.assertEqual(len(spans), 4)
125-
self.assertEqual(spans[3]["parent_id"], None)
125+
self.assertEqual(spans[3]["parent_id"], spans[0]["context"]["span_id"])
126126
self.assertEqual(spans[2]["parent_id"], None)
127127
self.assertEqual(spans[1]["parent_id"], spans[2]["context"]["span_id"])
128128
self.assertEqual(spans[0]["parent_id"], spans[1]["context"]["span_id"])
@@ -147,7 +147,7 @@ async def assistant_chat():
147147
span = json.loads(line)
148148
spans.append(span)
149149
self.assertEqual(len(spans), 3)
150-
self.assertEqual(spans[2]["parent_id"], None)
150+
self.assertNotEqual(spans[2]["parent_id"], None)
151151
self.assertEqual(spans[1]["parent_id"], spans[2]["context"]["span_id"])
152152
self.assertEqual(spans[0]["parent_id"], spans[1]["context"]["span_id"])
153153

0 commit comments

Comments
 (0)