Skip to content

Commit 4bdd3df

Browse files
authored
Allow custom sizes, dim order and strides for tensor view. (#14944)
Summary: The `make_tensor_ptr(TensrPtr)` overload creates a view on an existing `Tensor`. Here we provide a way for users to customize the shape, etc. so that they can easily do squeeze/unsqueeze and other convenient operations. Differential Revision: D84259597
1 parent cf13b9a commit 4bdd3df

File tree

2 files changed

+258
-32
lines changed

2 files changed

+258
-32
lines changed

extension/tensor/tensor_ptr.h

Lines changed: 57 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -323,26 +323,54 @@ inline TensorPtr make_tensor_ptr(
323323
}
324324

325325
/**
326-
* Creates a TensorPtr to manage a new Tensor with the same properties
327-
* as the given Tensor, sharing the same data without owning it.
326+
* Creates a TensorPtr to manage a new Tensor that aliases the given Tensor's
327+
* storage, with optional metadata overrides. Shape dynamism is inherited from
328+
* the source tensor.
328329
*
329-
* @param tensor The Tensor whose properties are used to create a new TensorPtr.
330-
* @return A new TensorPtr managing a Tensor with the same properties as the
331-
* original.
330+
* If an override is provided (non-empty), it is passed as-is. If an override is
331+
* empty, the corresponding metadata is reused from the source tensor when it
332+
* fits; otherwise it is left empty for the core factory to derive a valid
333+
* configuration. If `dim_order` is empty but `strides` is provided, `dim_order`
334+
* is left empty so the core may infer it from the provided strides.
335+
*
336+
* @param tensor The source tensor to alias.
337+
* @param sizes Optional sizes override.
338+
* @param dim_order Optional dimension order override.
339+
* @param strides Optional strides override.
340+
* @return A TensorPtr aliasing the same storage with requested metadata.
332341
*/
333-
inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor& tensor) {
334-
return make_tensor_ptr(
335-
std::vector<executorch::aten::SizesType>(
336-
tensor.sizes().begin(), tensor.sizes().end()),
337-
tensor.mutable_data_ptr(),
342+
inline TensorPtr make_tensor_ptr(
343+
const executorch::aten::Tensor& tensor,
344+
std::vector<executorch::aten::SizesType> sizes = {},
345+
std::vector<executorch::aten::DimOrderType> dim_order = {},
346+
std::vector<executorch::aten::StridesType> strides = {}) {
347+
if (sizes.empty()) {
348+
sizes.assign(tensor.sizes().begin(), tensor.sizes().end());
349+
}
350+
const auto same_rank = sizes.size() == static_cast<size_t>(tensor.dim());
351+
const auto same_shape = same_rank &&
352+
std::equal(sizes.begin(), sizes.end(), tensor.sizes().begin());
353+
const auto element_count =
354+
executorch::aten::compute_numel(sizes.data(), sizes.size());
355+
const auto parent_element_count = tensor.numel();
356+
ET_CHECK_MSG(
357+
element_count <= parent_element_count,
358+
"Requested view has %zd elements, but source tensor only has %zd.",
359+
static_cast<ssize_t>(element_count),
360+
static_cast<ssize_t>(parent_element_count));
338361
#ifndef USE_ATEN_LIB
339-
std::vector<executorch::aten::DimOrderType>(
340-
tensor.dim_order().begin(), tensor.dim_order().end()),
341-
#else // USE_ATEN_LIB
342-
{},
362+
if (dim_order.empty() && strides.empty() && same_rank) {
363+
dim_order.assign(tensor.dim_order().begin(), tensor.dim_order().end());
364+
}
343365
#endif // USE_ATEN_LIB
344-
std::vector<executorch::aten::StridesType>(
345-
tensor.strides().begin(), tensor.strides().end()),
366+
if (strides.empty() && dim_order.empty() && same_shape) {
367+
strides.assign(tensor.strides().begin(), tensor.strides().end());
368+
}
369+
return make_tensor_ptr(
370+
std::move(sizes),
371+
tensor.mutable_data_ptr(),
372+
std::move(dim_order),
373+
std::move(strides),
346374
tensor.scalar_type()
347375
#ifndef USE_ATEN_LIB
348376
,
@@ -352,21 +380,21 @@ inline TensorPtr make_tensor_ptr(const executorch::aten::Tensor& tensor) {
352380
}
353381

354382
/**
355-
* Creates a TensorPtr to manage a new Tensor with the same properties
356-
* as the Tensor referenced by the given TensorPtr, sharing the same data
357-
* without owning it.
383+
* Convenience overload identical to make_tensor_ptr(*tensor_ptr, ...).
358384
*
359-
* This is a convenience overload equivalent to make_tensor_ptr(*tensor_ptr).
360-
* It does not extend the lifetime of the underlying buffer; if the original
361-
* owner releases the storage, all views aliasing it become dangling.
362-
*
363-
* @param tensor_ptr The TensorPtr whose underlying Tensor is used to initialize
364-
* the returned view.
365-
* @return A new TensorPtr managing a Tensor with the same properties as the
366-
* original.
385+
* @param tensor_ptr The source tensor pointer to alias.
386+
* @param sizes Optional sizes override.
387+
* @param dim_order Optional dimension order override.
388+
* @param strides Optional strides override.
389+
* @return A TensorPtr aliasing the same storage with requested metadata.
367390
*/
368-
inline TensorPtr make_tensor_ptr(const TensorPtr& tensor_ptr) {
369-
return make_tensor_ptr(*tensor_ptr);
391+
inline TensorPtr make_tensor_ptr(
392+
const TensorPtr& tensor_ptr,
393+
std::vector<executorch::aten::SizesType> sizes = {},
394+
std::vector<executorch::aten::DimOrderType> dim_order = {},
395+
std::vector<executorch::aten::StridesType> strides = {}) {
396+
return make_tensor_ptr(
397+
*tensor_ptr, std::move(sizes), std::move(dim_order), std::move(strides));
370398
}
371399

372400
/**

extension/tensor/test/tensor_ptr_test.cpp

Lines changed: 201 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,204 @@ TEST_F(TensorPtrTest, MakeTensorPtrFromExistingTensorInt32) {
357357
EXPECT_EQ(new_tensor->scalar_type(), executorch::aten::ScalarType::Int);
358358
}
359359

360+
TEST_F(TensorPtrTest, MakeViewOverrideSizesRankIncrease) {
361+
std::vector<float> data = {1, 2, 3, 4, 5, 6};
362+
auto tensor = make_tensor_ptr({2, 3}, std::move(data));
363+
auto view = make_tensor_ptr(tensor, {1, 2, 3});
364+
365+
EXPECT_EQ(view->dim(), 3);
366+
EXPECT_EQ(view->size(0), 1);
367+
EXPECT_EQ(view->size(1), 2);
368+
EXPECT_EQ(view->size(2), 3);
369+
EXPECT_EQ(view->const_data_ptr<float>(), tensor->const_data_ptr<float>());
370+
EXPECT_EQ(view->strides()[0], 6);
371+
EXPECT_EQ(view->strides()[1], 3);
372+
EXPECT_EQ(view->strides()[2], 1);
373+
}
374+
375+
TEST_F(TensorPtrTest, MakeViewOverrideSizesSameRankRecomputesStrides) {
376+
float data[12] = {0};
377+
auto tensor = make_tensor_ptr({3, 4}, data);
378+
auto view = make_tensor_ptr(tensor, {4, 3});
379+
380+
EXPECT_EQ(view->dim(), 2);
381+
EXPECT_EQ(view->size(0), 4);
382+
EXPECT_EQ(view->size(1), 3);
383+
EXPECT_EQ(view->strides()[0], 3);
384+
EXPECT_EQ(view->strides()[1], 1);
385+
}
386+
387+
TEST_F(TensorPtrTest, MakeViewOverrideDimOrderOnly) {
388+
float data[6] = {0};
389+
auto tensor = make_tensor_ptr({2, 3}, data);
390+
auto view = make_tensor_ptr(tensor, {}, {1, 0}, {});
391+
392+
EXPECT_EQ(view->dim(), 2);
393+
EXPECT_EQ(view->size(0), 2);
394+
EXPECT_EQ(view->size(1), 3);
395+
EXPECT_EQ(view->strides()[0], 1);
396+
EXPECT_EQ(view->strides()[1], 2);
397+
}
398+
399+
TEST_F(TensorPtrTest, MakeViewOverrideStridesOnlyInfersDimOrder) {
400+
float data[12] = {0};
401+
auto tensor = make_tensor_ptr({3, 4}, data);
402+
auto view = make_tensor_ptr(tensor, {}, {}, {1, 3});
403+
404+
EXPECT_EQ(view->dim(), 2);
405+
EXPECT_EQ(view->size(0), 3);
406+
EXPECT_EQ(view->size(1), 4);
407+
EXPECT_EQ(view->strides()[0], 1);
408+
EXPECT_EQ(view->strides()[1], 3);
409+
}
410+
411+
TEST_F(TensorPtrTest, MakeViewReuseMetadataWhenShapeSame) {
412+
float data[12] = {0};
413+
auto tensor = make_tensor_ptr({3, 4}, data, {1, 0}, {1, 3});
414+
auto view = make_tensor_ptr(tensor, {3, 4});
415+
416+
EXPECT_EQ(view->dim(), 2);
417+
EXPECT_EQ(view->size(0), 3);
418+
EXPECT_EQ(view->size(1), 4);
419+
EXPECT_EQ(view->strides()[0], 1);
420+
EXPECT_EQ(view->strides()[1], 3);
421+
}
422+
423+
TEST_F(TensorPtrTest, MakeViewShapeChangeWithExplicitOldStridesExpectDeath) {
424+
float data[12] = {0};
425+
auto tensor = make_tensor_ptr({3, 4}, data);
426+
std::vector<executorch::aten::StridesType> old_strides(
427+
tensor->strides().begin(), tensor->strides().end());
428+
429+
ET_EXPECT_DEATH(
430+
{ auto _ = make_tensor_ptr(tensor, {2, 6}, {}, old_strides); }, "");
431+
}
432+
433+
TEST_F(TensorPtrTest, MakeViewInvalidDimOrderExpectDeath) {
434+
float data[12] = {0};
435+
auto tensor = make_tensor_ptr({3, 4}, data);
436+
437+
ET_EXPECT_DEATH(
438+
{ auto _ = make_tensor_ptr(tensor, {3, 4}, {2, 1}, {1, 4}); }, "");
439+
}
440+
441+
TEST_F(TensorPtrTest, MakeViewFromTensorPtrConvenienceOverload) {
442+
float data[12] = {0};
443+
auto tensor = make_tensor_ptr({3, 4}, data);
444+
auto view = make_tensor_ptr(tensor, {}, {1, 0}, {});
445+
446+
EXPECT_EQ(view->dim(), 2);
447+
EXPECT_EQ(view->size(0), 3);
448+
EXPECT_EQ(view->size(1), 4);
449+
EXPECT_EQ(view->strides()[0], 1);
450+
EXPECT_EQ(view->strides()[1], 3);
451+
}
452+
453+
TEST_F(TensorPtrTest, MakeViewRankDecreaseFlatten) {
454+
float data[6] = {1, 2, 3, 4, 5, 6};
455+
auto tensor = make_tensor_ptr(
456+
{2, 3},
457+
data,
458+
{},
459+
{},
460+
executorch::aten::ScalarType::Float,
461+
executorch::aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
462+
auto view = make_tensor_ptr(tensor, {6});
463+
EXPECT_EQ(view->dim(), 1);
464+
EXPECT_EQ(view->size(0), 6);
465+
EXPECT_EQ(view->strides()[0], 1);
466+
EXPECT_NE(tensor->unsafeGetTensorImpl(), view->unsafeGetTensorImpl());
467+
EXPECT_EQ(resize_tensor_ptr(view, {3, 2}), Error::NotSupported);
468+
EXPECT_EQ(view->dim(), 1);
469+
EXPECT_EQ(view->size(0), 6);
470+
}
471+
472+
TEST_F(TensorPtrTest, MakeViewFromScalarAliasAnd1D) {
473+
float scalar_value = 7.f;
474+
auto tensor = make_tensor_ptr({}, &scalar_value);
475+
auto alias = make_tensor_ptr(tensor);
476+
EXPECT_EQ(alias->dim(), 0);
477+
EXPECT_EQ(alias->numel(), 1);
478+
auto reshaped = make_tensor_ptr(tensor, {1});
479+
EXPECT_EQ(reshaped->dim(), 1);
480+
EXPECT_EQ(reshaped->size(0), 1);
481+
EXPECT_EQ(reshaped->strides()[0], 1);
482+
ET_EXPECT_DEATH({ auto unused = make_tensor_ptr(tensor, {}, {0}, {}); }, "");
483+
ET_EXPECT_DEATH({ auto unused = make_tensor_ptr(tensor, {}, {}, {1}); }, "");
484+
}
485+
486+
TEST_F(TensorPtrTest, MakeViewExplicitDimOrderAndStridesShapeChange) {
487+
float data[6] = {0};
488+
auto tensor = make_tensor_ptr({2, 3}, data);
489+
auto view = make_tensor_ptr(tensor, {3, 2}, {1, 0}, {1, 3});
490+
EXPECT_EQ(view->dim(), 2);
491+
EXPECT_EQ(view->size(0), 3);
492+
EXPECT_EQ(view->size(1), 2);
493+
EXPECT_EQ(view->strides()[0], 1);
494+
EXPECT_EQ(view->strides()[1], 3);
495+
}
496+
497+
TEST_F(TensorPtrTest, TensorUint8dataInt16Type) {
498+
std::vector<int16_t> int16_values = {-1, 2, -3, 4};
499+
auto byte_pointer = reinterpret_cast<const uint8_t*>(int16_values.data());
500+
std::vector<uint8_t> byte_data(
501+
byte_pointer, byte_pointer + int16_values.size() * sizeof(int16_t));
502+
auto tensor = make_tensor_ptr(
503+
{4}, std::move(byte_data), executorch::aten::ScalarType::Short);
504+
EXPECT_EQ(tensor->dim(), 1);
505+
EXPECT_EQ(tensor->size(0), 4);
506+
auto int16_data = tensor->const_data_ptr<int16_t>();
507+
EXPECT_EQ(int16_data[0], -1);
508+
EXPECT_EQ(int16_data[1], 2);
509+
EXPECT_EQ(int16_data[2], -3);
510+
EXPECT_EQ(int16_data[3], 4);
511+
}
512+
513+
TEST_F(TensorPtrTest, MakeView3DDimOrderOnly) {
514+
float data[24] = {0};
515+
auto tensor = make_tensor_ptr({2, 3, 4}, data);
516+
auto view = make_tensor_ptr(tensor, {}, {2, 0, 1}, {});
517+
EXPECT_EQ(view->dim(), 3);
518+
EXPECT_EQ(view->size(0), 2);
519+
EXPECT_EQ(view->size(1), 3);
520+
EXPECT_EQ(view->size(2), 4);
521+
EXPECT_EQ(view->strides()[0], 3);
522+
EXPECT_EQ(view->strides()[1], 1);
523+
EXPECT_EQ(view->strides()[2], 6);
524+
}
525+
526+
#ifndef USE_ATEN_LIB
527+
TEST_F(TensorPtrTest, MakeViewDynamismPropagationResizeAlias) {
528+
float data[12] = {0};
529+
auto tensor = make_tensor_ptr(
530+
{3, 4},
531+
data,
532+
{},
533+
{},
534+
executorch::aten::ScalarType::Float,
535+
executorch::aten::TensorShapeDynamism::DYNAMIC_UNBOUND);
536+
auto alias = make_tensor_ptr(tensor);
537+
EXPECT_EQ(resize_tensor_ptr(alias, {2, 6}), Error::Ok);
538+
EXPECT_EQ(alias->size(0), 2);
539+
EXPECT_EQ(alias->size(1), 6);
540+
EXPECT_EQ(tensor->size(0), 3);
541+
EXPECT_EQ(tensor->size(1), 4);
542+
}
543+
544+
TEST_F(TensorPtrTest, MakeViewSameRankShapeChangeCopiesDimOrder) {
545+
float data[24] = {0};
546+
auto tensor = make_tensor_ptr({2, 3, 4}, data, {2, 0, 1}, {3, 1, 6});
547+
auto view = make_tensor_ptr(tensor, {4, 2, 3});
548+
EXPECT_EQ(view->dim(), 3);
549+
EXPECT_EQ(view->size(0), 4);
550+
EXPECT_EQ(view->size(1), 2);
551+
EXPECT_EQ(view->size(2), 3);
552+
EXPECT_EQ(view->strides()[0], 2);
553+
EXPECT_EQ(view->strides()[1], 1);
554+
EXPECT_EQ(view->strides()[2], 8);
555+
}
556+
#endif
557+
360558
TEST_F(TensorPtrTest, CloneTensorPtrFromExistingTensorInt32) {
361559
std::vector<int32_t> data = {1, 2, 3, 4};
362560
auto tensor = make_tensor_ptr({2, 2}, std::move(data));
@@ -803,7 +1001,7 @@ TEST_F(TensorPtrTest, TensorDeducedScalarType) {
8031001
EXPECT_EQ(tensor->const_data_ptr<double>()[3], 4.0);
8041002
}
8051003

806-
TEST_F(TensorPtrTest, TensorUint8BufferWithFloatScalarType) {
1004+
TEST_F(TensorPtrTest, TensorUint8dataWithFloatScalarType) {
8071005
std::vector<uint8_t> data(
8081006
4 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
8091007

@@ -827,14 +1025,14 @@ TEST_F(TensorPtrTest, TensorUint8BufferWithFloatScalarType) {
8271025
EXPECT_EQ(tensor->const_data_ptr<float>()[3], 4.0f);
8281026
}
8291027

830-
TEST_F(TensorPtrTest, TensorUint8BufferTooSmallExpectDeath) {
1028+
TEST_F(TensorPtrTest, TensorUint8dataTooSmallExpectDeath) {
8311029
std::vector<uint8_t> data(
8321030
2 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
8331031
ET_EXPECT_DEATH(
8341032
{ auto tensor = make_tensor_ptr({2, 2}, std::move(data)); }, "");
8351033
}
8361034

837-
TEST_F(TensorPtrTest, TensorUint8BufferTooLargeExpectDeath) {
1035+
TEST_F(TensorPtrTest, TensorUint8dataTooLargeExpectDeath) {
8381036
std::vector<uint8_t> data(
8391037
5 * executorch::aten::elementSize(executorch::aten::ScalarType::Float));
8401038
ET_EXPECT_DEATH({ auto _ = make_tensor_ptr({2, 2}, std::move(data)); }, "");

0 commit comments

Comments
 (0)