Skip to content

Use custom collector to avoid intermediate array allocations #2168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Dapper.StrongName/Dapper.StrongName.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

<ItemGroup Condition=" '$(TargetFramework)' == 'net461'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
<PackageReference Include="System.Memory" />
</ItemGroup>

<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
<PackageReference Include="System.Reflection.Emit.Lightweight" />
<PackageReference Include="System.Memory" />
</ItemGroup>
</Project>
149 changes: 149 additions & 0 deletions Dapper/CollectorT.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.ComponentModel;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
using System.Runtime.CompilerServices;

namespace Dapper;

/// <summary>
/// Allows efficient collection of data into lists, arrays, etc.
/// </summary>
/// <remarks>This is a mutable struct; treat with caution.</remarks>
/// <typeparam name="T"></typeparam>
[DebuggerDisplay($"{{{nameof(ToString)}(),nq}}")]
[SuppressMessage("Usage", "CA2231:Overload operator equals on overriding value type Equals", Justification = "Equality not supported")]
public struct Collector<T>
{
/// <summary>
/// Create a new collector using a size hint for the number of elements expected.
/// </summary>
public Collector(int capacityHint)
{
oversized = capacityHint > 0 ? ArrayPool<T>.Shared.Rent(capacityHint) : [];
capacity = oversized.Length;
}

/// <inheritdoc/>
public readonly override string ToString() => $"Count: {count}";

/// <inheritdoc/>
[Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
public readonly override bool Equals([NotNullWhen(true)] object? obj) => throw new NotSupportedException();

/// <inheritdoc/>
[Browsable(false), EditorBrowsable(EditorBrowsableState.Never)]
public readonly override int GetHashCode() => throw new NotSupportedException();

private T[] oversized;
private int count, capacity;

/// <summary>
/// Gets the current capacity of the backing buffer of this instance.
/// </summary>
internal readonly int Capacity => capacity;

/// <summary>
/// Gets the number of elements represented by this instance.
/// </summary>
public readonly int Count => count;

/// <summary>
/// Gets the underlying elements represented by this instance.
/// </summary>
public readonly Span<T> Span => new(oversized, 0, count);

/// <summary>
/// Gets the underlying elements represented by this instance.
/// </summary>
public readonly ArraySegment<T> ArraySegment => new(oversized, 0, count);

/// <summary>
/// Gets the element at the specified index.
/// </summary>
public readonly ref T this[int index]
{
get
{
return ref index >= 0 & index < count ? ref oversized[index] : ref OutOfRange();

static ref T OutOfRange() => throw new ArgumentOutOfRangeException(nameof(index));
}
}

/// <summary>
/// Add an element to the collection.
/// </summary>
public void Add(T value)
{
if (capacity == count) Expand();
oversized[count++] = value;
}

/// <summary>
/// Add elements to the collection.
/// </summary>
public void AddRange(ReadOnlySpan<T> values)
{
EnsureCapacity(count + values.Length);
values.CopyTo(new(oversized, count, values.Length));
count += values.Length;
}

private void EnsureCapacity(int minCapacity)
{
if (capacity < minCapacity)
{
var newBuffer = ArrayPool<T>.Shared.Rent(minCapacity);
Span.CopyTo(newBuffer);
var oldBuffer = oversized;
oversized = newBuffer;
capacity = newBuffer.Length;

if (oldBuffer is not null)
{
ArrayPool<T>.Shared.Return(oldBuffer);
}
}
}

[MethodImpl(MethodImplOptions.NoInlining)]
private void Expand() => EnsureCapacity(Math.Max(capacity * 2, 16));

/// <summary>
/// Release any resources associated with this instance.
/// </summary>
public void Clear()
{
count = 0;
if (capacity != 0)
{
capacity = 0;
ArrayPool<T>.Shared.Return(oversized);
oversized = [];
}
}

/// <summary>
/// Create an array with the elements associated with this instance, and release any resources.
/// </summary>
public T[] ToArrayAndClear()
{
T[] result = [.. Span]; // let the compiler worry about the per-platform implementation
Clear();
return result;
}

/// <summary>
/// Create an array with the elements associated with this instance, and release any resources.
/// </summary>
public List<T> ToListAndClear()
{
List<T> result = [.. Span]; // let the compiler worry about the per-platform implementation (net8+ in particular)
Clear();
return result;
}
}
2 changes: 2 additions & 0 deletions Dapper/Dapper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@

<ItemGroup Condition=" '$(TargetFramework)' == 'net461'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
<PackageReference Include="System.Memory" />
</ItemGroup>

<ItemGroup Condition=" '$(TargetFramework)' == 'netstandard2.0'">
<PackageReference Include="Microsoft.Bcl.AsyncInterfaces" />
<PackageReference Include="System.Reflection.Emit.Lightweight" />
<PackageReference Include="System.Memory" />
</ItemGroup>
</Project>
6 changes: 3 additions & 3 deletions Dapper/DefaultTypeMap.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ internal static List<PropertyInfo> GetSettableProps(Type t)
return t
.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance)
.Where(p => GetPropertySetter(p, t) is not null)
.ToList();
.AsList();
}

internal static List<FieldInfo> GetSettableFields(Type t)
{
return t.GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance).ToList();
return t.GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance).AsList();
}

/// <summary>
Expand Down Expand Up @@ -115,7 +115,7 @@ internal static List<FieldInfo> GetSettableFields(Type t)
public ConstructorInfo? FindExplicitConstructor()
{
var constructors = _type.GetConstructors(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic);
var withAttr = constructors.Where(c => c.GetCustomAttributes(typeof(ExplicitConstructorAttribute), true).Length > 0).ToList();
var withAttr = constructors.Where(c => c.GetCustomAttributes(typeof(ExplicitConstructorAttribute), true).Length > 0).AsList();

if (withAttr.Count == 1)
{
Expand Down
19 changes: 18 additions & 1 deletion Dapper/PublicAPI.Shipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ static Dapper.SqlMapper.AddTypeHandlerImpl(System.Type! type, Dapper.SqlMapper.I
static Dapper.SqlMapper.AddTypeMap(System.Type! type, System.Data.DbType dbType) -> void
static Dapper.SqlMapper.AddTypeMap(System.Type! type, System.Data.DbType dbType, bool useGetFieldValue) -> void
static Dapper.SqlMapper.AsList<T>(this System.Collections.Generic.IEnumerable<T>? source) -> System.Collections.Generic.List<T>!
static Dapper.SqlMapper.AsListAsync<T>(this System.Collections.Generic.IAsyncEnumerable<T>? source, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<System.Collections.Generic.List<T>!>!
static Dapper.SqlMapper.AsTableValuedParameter(this System.Data.DataTable! table, string? typeName = null) -> Dapper.SqlMapper.ICustomQueryParameter!
static Dapper.SqlMapper.AsTableValuedParameter<T>(this System.Collections.Generic.IEnumerable<T>! list, string? typeName = null) -> Dapper.SqlMapper.ICustomQueryParameter!
static Dapper.SqlMapper.ConnectionStringComparer.get -> System.Collections.Generic.IEqualityComparer<string!>!
Expand Down Expand Up @@ -332,4 +333,20 @@ static Dapper.SqlMapper.ThrowDataException(System.Exception! ex, int index, Syst
static Dapper.SqlMapper.ThrowNullCustomQueryParameter(string! name) -> void
static Dapper.SqlMapper.TypeHandlerCache<T>.Parse(object! value) -> T?
static Dapper.SqlMapper.TypeHandlerCache<T>.SetValue(System.Data.IDbDataParameter! parameter, object! value) -> void
static Dapper.SqlMapper.TypeMapProvider -> System.Func<System.Type!, Dapper.SqlMapper.ITypeMap!>!
static Dapper.SqlMapper.TypeMapProvider -> System.Func<System.Type!, Dapper.SqlMapper.ITypeMap!>!

Dapper.Collector<T>
Dapper.Collector<T>.Collector() -> void
Dapper.Collector<T>.Collector(int capacityHint) -> void
Dapper.Collector<T>.Count.get -> int
Dapper.Collector<T>.Span.get -> System.Span<T>
Dapper.Collector<T>.ArraySegment.get -> System.ArraySegment<T>
Dapper.Collector<T>.Clear() -> void
Dapper.Collector<T>.Add(T value) -> void
Dapper.Collector<T>.AddRange(System.ReadOnlySpan<T> values) -> void
Dapper.Collector<T>.ToListAndClear() -> System.Collections.Generic.List<T>!
Dapper.Collector<T>.ToArrayAndClear() -> T[]!
Dapper.Collector<T>.this[int index].get -> T
override Dapper.Collector<T>.ToString() -> string!
override Dapper.Collector<T>.GetHashCode() -> int
override Dapper.Collector<T>.Equals(object? obj) -> bool
34 changes: 30 additions & 4 deletions Dapper/SqlMapper.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ private static async Task<IEnumerable<T>> QueryAsync<T>(this IDbConnection cnn,

if (command.Buffered)
{
var buffer = new List<T>();
var buffer = new Collector<T>();
var convertToType = Nullable.GetUnderlyingType(effectiveType) ?? effectiveType;
while (await reader.ReadAsync(cancel).ConfigureAwait(false))
{
Expand All @@ -456,7 +456,7 @@ private static async Task<IEnumerable<T>> QueryAsync<T>(this IDbConnection cnn,
}
while (await reader.NextResultAsync(cancel).ConfigureAwait(false)) { /* ignore subsequent result sets */ }
command.OnCompleted();
return buffer;
return buffer.ToListAndClear();
}
else
{
Expand Down Expand Up @@ -546,6 +546,32 @@ public static Task<int> ExecuteAsync(this IDbConnection cnn, CommandDefinition c
}
}

/// <summary>
/// Asynchronously collect a sequence of data into a list.
/// </summary>
/// <typeparam name="T">The type of element in the list.</typeparam>
/// <param name="source">The enumerable to return as a list.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> to observe while waiting for the task to complete.</param>
public static Task<List<T>> AsListAsync<T>(this IAsyncEnumerable<T>? source, CancellationToken cancellationToken = default)
{
if (source is null) return null!; // GIGO

return EnumerateAsync(source, cancellationToken);

static async Task<List<T>> EnumerateAsync(IAsyncEnumerable<T> source, CancellationToken cancellationToken)
{
var buffer = new Collector<T>(); // amortizes intermediate buffers
await using (var iterator = source.GetAsyncEnumerator(cancellationToken))
{
while (await iterator.MoveNextAsync().ConfigureAwait(false))
{
buffer.Add(iterator.Current);
}
}
return buffer.ToListAndClear();
}
}

private readonly struct AsyncExecState
{
public readonly DbCommand Command;
Expand Down Expand Up @@ -941,7 +967,7 @@ private static async Task<IEnumerable<TReturn>> MultiMapAsync<TFirst, TSecond, T
using var reader = await ExecuteReaderWithFlagsFallbackAsync(cmd, wasClosed, CommandBehavior.SequentialAccess | CommandBehavior.SingleResult, command.CancellationToken).ConfigureAwait(false);
if (!command.Buffered) wasClosed = false; // handing back open reader; rely on command-behavior
var results = MultiMapImpl<TFirst, TSecond, TThird, TFourth, TFifth, TSixth, TSeventh, TReturn>(null, CommandDefinition.ForCallback(command.Parameters, command.Flags), map, splitOn, reader, identity, true);
return command.Buffered ? results.ToList() : results;
return command.Buffered ? results.AsList() : results;
}
finally
{
Expand Down Expand Up @@ -989,7 +1015,7 @@ private static async Task<IEnumerable<TReturn>> MultiMapAsync<TReturn>(this IDbC
using var cmd = command.TrySetupAsyncCommand(cnn, info.ParamReader);
using var reader = await ExecuteReaderWithFlagsFallbackAsync(cmd, wasClosed, CommandBehavior.SequentialAccess | CommandBehavior.SingleResult, command.CancellationToken).ConfigureAwait(false);
var results = MultiMapImpl(null, default, types, map, splitOn, reader, identity, true);
return command.Buffered ? results.ToList() : results;
return command.Buffered ? results.AsList() : results;
}
finally
{
Expand Down
4 changes: 2 additions & 2 deletions Dapper/SqlMapper.GridReader.Async.cs
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,12 @@ private async Task<IEnumerable<T>> ReadBufferedAsync<T>(int index, Func<DbDataRe
{
try
{
var buffer = new List<T>();
var buffer = new Collector<T>();
while (index == ResultIndex && await reader!.ReadAsync(cancel).ConfigureAwait(false))
{
buffer.Add(ConvertTo<T>(deserializer(reader)));
}
return buffer;
return buffer.ToListAndClear();
}
finally // finally so that First etc progresses things even when multiple rows
{
Expand Down
16 changes: 8 additions & 8 deletions Dapper/SqlMapper.GridReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ private IEnumerable<T> ReadImpl<T>(Type type, bool buffered)
cache.Deserializer = deserializer;
}
var result = ReadDeferred<T>(index, deserializer.Func, type);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

private T ReadRow<T>(Type type, Row row)
Expand Down Expand Up @@ -283,7 +283,7 @@ private IEnumerable<TReturn> MultiReadInternal<TReturn>(Type[] types, Func<objec
public IEnumerable<TReturn> Read<TFirst, TSecond, TReturn>(Func<TFirst, TSecond, TReturn> func, string splitOn = "id", bool buffered = true)
{
var result = MultiReadInternal<TFirst, TSecond, DontMap, DontMap, DontMap, DontMap, DontMap, TReturn>(func, splitOn);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

/// <summary>
Expand All @@ -299,7 +299,7 @@ public IEnumerable<TReturn> Read<TFirst, TSecond, TReturn>(Func<TFirst, TSecond,
public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TReturn>(Func<TFirst, TSecond, TThird, TReturn> func, string splitOn = "id", bool buffered = true)
{
var result = MultiReadInternal<TFirst, TSecond, TThird, DontMap, DontMap, DontMap, DontMap, TReturn>(func, splitOn);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

/// <summary>
Expand All @@ -316,7 +316,7 @@ public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TReturn>(Func<TFirst,
public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TReturn>(Func<TFirst, TSecond, TThird, TFourth, TReturn> func, string splitOn = "id", bool buffered = true)
{
var result = MultiReadInternal<TFirst, TSecond, TThird, TFourth, DontMap, DontMap, DontMap, TReturn>(func, splitOn);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

/// <summary>
Expand All @@ -334,7 +334,7 @@ public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TReturn>(Func
public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TFifth, TReturn>(Func<TFirst, TSecond, TThird, TFourth, TFifth, TReturn> func, string splitOn = "id", bool buffered = true)
{
var result = MultiReadInternal<TFirst, TSecond, TThird, TFourth, TFifth, DontMap, DontMap, TReturn>(func, splitOn);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

/// <summary>
Expand All @@ -353,7 +353,7 @@ public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TFifth, TRetu
public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TFifth, TSixth, TReturn>(Func<TFirst, TSecond, TThird, TFourth, TFifth, TSixth, TReturn> func, string splitOn = "id", bool buffered = true)
{
var result = MultiReadInternal<TFirst, TSecond, TThird, TFourth, TFifth, TSixth, DontMap, TReturn>(func, splitOn);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

/// <summary>
Expand All @@ -373,7 +373,7 @@ public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TFifth, TSixt
public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TFifth, TSixth, TSeventh, TReturn>(Func<TFirst, TSecond, TThird, TFourth, TFifth, TSixth, TSeventh, TReturn> func, string splitOn = "id", bool buffered = true)
{
var result = MultiReadInternal<TFirst, TSecond, TThird, TFourth, TFifth, TSixth, TSeventh, TReturn>(func, splitOn);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

/// <summary>
Expand All @@ -387,7 +387,7 @@ public IEnumerable<TReturn> Read<TFirst, TSecond, TThird, TFourth, TFifth, TSixt
public IEnumerable<TReturn> Read<TReturn>(Type[] types, Func<object[], TReturn> map, string splitOn = "id", bool buffered = true)
{
var result = MultiReadInternal(types, map, splitOn);
return buffered ? result.ToList() : result;
return buffered ? result.AsList() : result;
}

private IEnumerable<T> ReadDeferred<T>(int index, Func<DbDataReader, object> deserializer, Type effectiveType)
Expand Down
Loading