diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index 2e916729..35779057 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -29,10 +29,16 @@ axis(a::AbstractVector) = axes(a, 1) function eachblockaxis(a::AbstractVector) return map(axis, blocks(a)) end +function blockaxistype(a::AbstractVector) + return eltype(eachblockaxis(a)) +end # Take a collection of axes and mortar them # into a single blocked axis. function mortar_axis(axs) + return blockrange(axs) +end +function mortar_axis(axs::Vector{<:Base.OneTo{<:Integer}}) return blockedrange(length.(axs)) end diff --git a/src/abstractblocksparsearray/abstractblocksparsearray.jl b/src/abstractblocksparsearray/abstractblocksparsearray.jl index e90e97e0..4cb6b75a 100644 --- a/src/abstractblocksparsearray/abstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/abstractblocksparsearray.jl @@ -19,12 +19,12 @@ end # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.getindex(a::AbstractBlockSparseArray{<:Any,N}, I::Vararg{Int,N}) where {N} - return @interface BlockSparseArrayInterface() getindex(a, I...) + return @interface interface(a) getindex(a, I...) end # Specialized in order to fix ambiguity error with `BlockArrays`. function Base.getindex(a::AbstractBlockSparseArray{<:Any,0}) - return @interface BlockSparseArrayInterface() getindex(a) + return @interface interface(a) getindex(a) end ## # Fix ambiguity error with `BlockArrays`. @@ -53,7 +53,7 @@ end # Fix ambiguity error. function Base.setindex!(a::AbstractBlockSparseArray{<:Any,0}, value) - @interface BlockSparseArrayInterface() setindex!(a, value) + @interface interface(a) setindex!(a, value) return a end diff --git a/src/abstractblocksparsearray/broadcast.jl b/src/abstractblocksparsearray/broadcast.jl index 3527831c..1a9eafc1 100644 --- a/src/abstractblocksparsearray/broadcast.jl +++ b/src/abstractblocksparsearray/broadcast.jl @@ -2,7 +2,7 @@ using BlockArrays: AbstractBlockedUnitRange, BlockSlice using Base.Broadcast: Broadcast function Broadcast.BroadcastStyle(arraytype::Type{<:AnyAbstractBlockSparseArray}) - return BlockSparseArrayStyle{ndims(arraytype)}() + return BlockSparseArrayStyle(BroadcastStyle(blocktype(arraytype))) end # Fix ambiguity error with `BlockArrays`. diff --git a/src/abstractblocksparsearray/map.jl b/src/abstractblocksparsearray/map.jl index 4dcec66f..494503ac 100644 --- a/src/abstractblocksparsearray/map.jl +++ b/src/abstractblocksparsearray/map.jl @@ -111,9 +111,13 @@ function Base.isreal(a::AnyAbstractBlockSparseArray) return @interface interface(a) isreal(a) end +# Helps with specialization. function Base.:*(x::Number, a::AnyAbstractBlockSparseArray) return map(Base.Fix1(*, x), a) end function Base.:*(a::AnyAbstractBlockSparseArray, x::Number) return map(Base.Fix2(*, x), a) end +function Base.:/(a::AnyAbstractBlockSparseArray, x::Number) + return map(Base.Fix2(/, x), a) +end diff --git a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl index b72f1e41..73c709e2 100644 --- a/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl +++ b/src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl @@ -29,8 +29,8 @@ const AnyAbstractBlockSparseVecOrMat{T,N} = Union{ AnyAbstractBlockSparseVector{T},AnyAbstractBlockSparseMatrix{T} } -function DerivableInterfaces.interface(::Type{<:AnyAbstractBlockSparseArray}) - return BlockSparseArrayInterface() +function DerivableInterfaces.interface(arrayt::Type{<:AnyAbstractBlockSparseArray}) + return BlockSparseArrayInterface(interface(blocktype(arrayt))) end # a[1:2, 1:2] @@ -88,7 +88,7 @@ end # BlockArrays `AbstractBlockArray` interface function BlockArrays.blocks(a::AnyAbstractBlockSparseArray) - @interface BlockSparseArrayInterface() blocks(a) + @interface interface(a) blocks(a) end # Fix ambiguity error with `BlockArrays` @@ -230,10 +230,19 @@ function Base.similar( return similar(arraytype, eltype(arraytype), axes) end +# This circumvents some issues with `TypeParameterAccessors.similartype`. +# TODO: Fix this poperly in `TypeParameterAccessors.jl`. +function _similartype(arrayt::Type{<:AbstractArray}, elt::Type, axt::Type{<:Tuple}) + return Base.promote_op(similar, arrayt, elt, axt) +end +function _similartype(arrayt::Type{<:AbstractArray}, axt::Type{<:Tuple}) + return Base.promote_op(similar, arrayt, axt) +end + function blocksparse_similar(a, elt::Type, axes::Tuple) - return BlockSparseArray{elt,length(axes),similartype(blocktype(a), elt, axes)}( - undef, axes - ) + block_axt = Tuple{blockaxistype.(axes)...} + blockt = _similartype(blocktype(a), Type{elt}, block_axt) + return BlockSparseArray{elt,length(axes),blockt}(undef, axes) end @interface ::AbstractBlockSparseArrayInterface function Base.similar( a::AbstractArray, elt::Type, axes::Tuple{Vararg{Int}} @@ -275,7 +284,7 @@ function Base.similar( elt::Type, axes::Tuple{Vararg{AbstractUnitRange{<:Integer}}}, ) - return @interface BlockSparseArrayInterface() similar(arraytype, elt, axes) + return @interface interface(arraytype) similar(arraytype, elt, axes) end # TODO: Define a `@interface BlockSparseArrayInterface() similar` function. @@ -302,8 +311,7 @@ function Base.similar( AbstractBlockedUnitRange{<:Integer},Vararg{AbstractBlockedUnitRange{<:Integer}} }, ) - # TODO: Use `@interface interface(a) similar(...)`. - return @interface BlockSparseArrayInterface() similar(a, elt, axes) + return @interface interface(a) similar(a, elt, axes) end # Fixes ambiguity error with `OffsetArrays`. @@ -312,8 +320,7 @@ function Base.similar( elt::Type, axes::Tuple{AbstractUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) - # TODO: Use `@interface interface(a) similar(...)`. - return @interface BlockSparseArrayInterface() similar(a, elt, axes) + return @interface interface(a) similar(a, elt, axes) end # Fixes ambiguity error with `BlockArrays`. @@ -322,8 +329,7 @@ function Base.similar( elt::Type, axes::Tuple{AbstractBlockedUnitRange{<:Integer},Vararg{AbstractUnitRange{<:Integer}}}, ) - # TODO: Use `@interface interface(a) similar(...)`. - return @interface BlockSparseArrayInterface() similar(a, elt, axes) + return @interface interface(a) similar(a, elt, axes) end # Fixes ambiguity errors with BlockArrays. @@ -336,16 +342,14 @@ function Base.similar( Vararg{AbstractUnitRange{<:Integer}}, }, ) - # TODO: Use `@interface interface(a) similar(...)`. - return @interface BlockSparseArrayInterface() similar(a, elt, axes) + return @interface interface(a) similar(a, elt, axes) end # Fixes ambiguity error with `StaticArrays`. function Base.similar( a::AnyAbstractBlockSparseArray, elt::Type, axes::Tuple{Base.OneTo,Vararg{Base.OneTo}} ) - # TODO: Use `@interface interface(a) similar(...)`. - return @interface BlockSparseArrayInterface() similar(a, elt, axes) + return @interface interface(a) similar(a, elt, axes) end struct BlockType{T} end diff --git a/src/blocksparsearray/blocksparsearray.jl b/src/blocksparsearray/blocksparsearray.jl index d92580fe..75a5d304 100644 --- a/src/blocksparsearray/blocksparsearray.jl +++ b/src/blocksparsearray/blocksparsearray.jl @@ -173,7 +173,9 @@ end function BlockSparseArray{T,N}( ::UndefInitializer, axes::Tuple{Vararg{AbstractUnitRange{<:Integer},N}} ) where {T,N} - return BlockSparseArray{T,N,Array{T,N}}(undef, axes) + axt = Tuple{blockaxistype.(axes)...} + A = _similartype(Array{T}, axt) + return BlockSparseArray{T,N,A}(undef, axes) end function BlockSparseArray{T,N}( @@ -230,6 +232,20 @@ function BlockSparseArray{T}( return BlockSparseArray{T}(undef, axes) end +function blocksparsezeros(elt::Type, axes...) + return BlockSparseArray{elt}(undef, axes...) +end +function blocksparsezeros(::BlockType{A}, axes...) where {A<:AbstractArray} + return BlockSparseArray{eltype(A),ndims(A),A}(undef, axes...) +end +function blocksparse(d::Dict{<:Block,<:AbstractArray}, axes...) + a = blocksparsezeros(BlockType(valtype(d)), axes...) + for I in eachindex(d) + a[I] = d[I] + end + return a +end + # Base `AbstractArray` interface Base.axes(a::BlockSparseArray) = a.axes @@ -238,6 +254,10 @@ Base.axes(a::BlockSparseArray) = a.axes @interface ::AbstractBlockSparseArrayInterface BlockArrays.blocks(a::BlockSparseArray) = a.blocks +function blocktype(arraytype::Type{<:BlockSparseArray{<:Any,<:Any,A}}) where {A} + return A +end + # TODO: Use `TypeParameterAccessors`. function blockstype( arraytype::Type{<:BlockSparseArray{T,N,A,Blocks}} diff --git a/src/blocksparsearrayinterface/arraylayouts.jl b/src/blocksparsearrayinterface/arraylayouts.jl index f1e70c91..a8ac596c 100644 --- a/src/blocksparsearrayinterface/arraylayouts.jl +++ b/src/blocksparsearrayinterface/arraylayouts.jl @@ -1,6 +1,6 @@ using ArrayLayouts: ArrayLayouts, Dot, MatMulMatAdd, MatMulVecAdd, MulAdd using BlockArrays: BlockArrays, BlockLayout, muladd! -using DerivableInterfaces: @interface +using DerivableInterfaces: DerivableInterfaces, @interface, interface using SparseArraysBase: SparseLayout using LinearAlgebra: LinearAlgebra, dot, mul! @@ -11,6 +11,10 @@ using LinearAlgebra: LinearAlgebra, dot, mul! return a_dest end +function DerivableInterfaces.interface(m::MulAdd) + return interface(m.A, m.B, m.C) +end + function ArrayLayouts.materialize!( m::MatMulMatAdd{ <:BlockLayout{<:SparseLayout}, @@ -18,8 +22,7 @@ function ArrayLayouts.materialize!( <:BlockLayout{<:SparseLayout}, }, ) - α, a1, a2, β, a_dest = m.α, m.A, m.B, m.β, m.C - @interface BlockSparseArrayInterface() muladd!(m.α, m.A, m.B, m.β, m.C) + @interface interface(m) muladd!(m.α, m.A, m.B, m.β, m.C) return m.C end function ArrayLayouts.materialize!( @@ -29,7 +32,7 @@ function ArrayLayouts.materialize!( <:BlockLayout{<:SparseLayout}, }, ) - @interface BlockSparseArrayInterface() matmul!(m) + @interface interface(m) matmul!(m) return m.C end @@ -42,5 +45,5 @@ end end function Base.copy(d::Dot{<:BlockLayout{<:SparseLayout},<:BlockLayout{<:SparseLayout}}) - return @interface BlockSparseArrayInterface() dot(d.A, d.B) + return @interface interface(d.A, d.B) dot(d.A, d.B) end diff --git a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl index 5c1ed5be..1356f949 100644 --- a/src/blocksparsearrayinterface/blocksparsearrayinterface.jl +++ b/src/blocksparsearrayinterface/blocksparsearrayinterface.jl @@ -16,7 +16,8 @@ using BlockArrays: blocklength, blocks, findblockindex -using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface, zero! +using DerivableInterfaces: + DerivableInterfaces, @interface, AbstractArrayInterface, DefaultArrayInterface, zero! using LinearAlgebra: Adjoint, Transpose using SparseArraysBase: AbstractSparseArrayInterface, @@ -104,13 +105,29 @@ blocktype(a::BlockArray) = eltype(blocks(a)) abstract type AbstractBlockSparseArrayInterface{N} <: AbstractSparseArrayInterface{N} end # TODO: Also support specifying the `blocktype` along with the `eltype`. -function Base.similar(::AbstractBlockSparseArrayInterface, T::Type, ax::Tuple) - return similar(BlockSparseArray{T}, ax) +function Base.similar(interface::AbstractBlockSparseArrayInterface, T::Type, ax::Tuple) + N = length(ax) + block_axt = Tuple{blockaxistype.(ax)...} + B = _similartype(blockinterface(interface), Type{T}, block_axt) + return similar(BlockSparseArray{T,N,B}, ax) end -struct BlockSparseArrayInterface{N} <: AbstractBlockSparseArrayInterface{N} end +struct BlockSparseArrayInterface{N,B<:AbstractArrayInterface{N}} <: + AbstractBlockSparseArrayInterface{N} + blockinterface::B +end +blockinterface(interface::BlockSparseArrayInterface) = getfield(interface, :blockinterface) +function BlockSparseArrayInterface{N}(blockinterface::AbstractArrayInterface{N}) where {N} + return BlockSparseArrayInterface{N,typeof(blockinterface)}(blockinterface) +end +function BlockSparseArrayInterface{N}() where {N} + return BlockSparseArrayInterface{N}(DefaultArrayInterface{N}()) +end BlockSparseArrayInterface(::Val{N}) where {N} = BlockSparseArrayInterface{N}() BlockSparseArrayInterface{M}(::Val{N}) where {M,N} = BlockSparseArrayInterface{N}() +function BlockSparseArrayInterface{M,B}(::Val{N}) where {M,B<:AbstractArrayInterface{M},N} + return BlockSparseArrayInterface{N,B}(B(Val(N))) +end BlockSparseArrayInterface() = BlockSparseArrayInterface{Any}() @interface ::AbstractBlockSparseArrayInterface function BlockArrays.blocks(a::AbstractArray) diff --git a/src/blocksparsearrayinterface/broadcast.jl b/src/blocksparsearrayinterface/broadcast.jl index d8ab5ec8..f3cad907 100644 --- a/src/blocksparsearrayinterface/broadcast.jl +++ b/src/blocksparsearrayinterface/broadcast.jl @@ -3,19 +3,28 @@ using GPUArraysCore: @allowscalar using MapBroadcast: Mapped using DerivableInterfaces: DerivableInterfaces, @interface -abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end +abstract type AbstractBlockSparseArrayStyle{N,B} <: AbstractArrayStyle{N} end -function DerivableInterfaces.interface(::Type{<:AbstractBlockSparseArrayStyle}) - return BlockSparseArrayInterface() +function DerivableInterfaces.interface( + ::Type{<:AbstractBlockSparseArrayStyle{N,B}} +) where {N,B} + return BlockSparseArrayInterface(interface(B)) end -struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end +struct BlockSparseArrayStyle{N,B<:AbstractArrayStyle{N}} <: + AbstractBlockSparseArrayStyle{N,B} + blockstyle::B +end +function BlockSparseArrayStyle{N}(blockstyle::AbstractArrayStyle{N}) where {N} + return BlockSparseArrayStyle{N,typeof(blockstyle)}(blockstyle) +end # Define for new sparse array types. # function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray}) # return BlockSparseArrayStyle{ndims(arraytype)}() # end +BlockSparseArrayStyle{N}() where {N} = BlockSparseArrayStyle{N}(DefaultArrayStyle{N}()) BlockSparseArrayStyle(::Val{N}) where {N} = BlockSparseArrayStyle{N}() BlockSparseArrayStyle{M}(::Val{N}) where {M,N} = BlockSparseArrayStyle{N}() diff --git a/src/factorizations/svd.jl b/src/factorizations/svd.jl index 1f8f4a42..c4871099 100644 --- a/src/factorizations/svd.jl +++ b/src/factorizations/svd.jl @@ -20,14 +20,20 @@ function MatrixAlgebraKit.default_svd_algorithm( return BlockPermutedDiagonalAlgorithm(alg) end +# TODO: Put this in a common location or package, +# maybe `TypeParameterAccessors.jl`? +# Also define `imagtype`, `complextype`, etc. +realtype(a::AbstractArray) = realtype(typeof(a)) +function realtype(A::Type{<:AbstractArray}) + return Base.promote_op(real, A) +end + +using DiagonalArrays: diagonaltype function similar_output( ::typeof(svd_compact!), A, S_axes, alg::MatrixAlgebraKit.AbstractAlgorithm ) U = similar(A, axes(A, 1), S_axes[1]) - T = real(eltype(A)) - # TODO: this should be replaced with a more general similar function that can handle setting - # the blocktype and element type - something like S = similar(A, BlockType(...)) - S = BlockSparseMatrix{T,Diagonal{T,Vector{T}}}(undef, S_axes) + S = similar(A, BlockType(diagonaltype(realtype(blocktype(A)))), S_axes) Vt = similar(A, S_axes[2], axes(A, 2)) return U, S, Vt end @@ -49,9 +55,9 @@ function MatrixAlgebraKit.initialize_output( bcolIs = Int.(last.(Tuple.(bIs))) for bI in eachblockstoredindex(A) row, col = Int.(Tuple(bI)) - len = minimum(length, (brows[row], bcols[col])) - u_axes[col] = brows[row][Base.OneTo(len)] - v_axes[col] = bcols[col][Base.OneTo(len)] + b = argmin(length, (brows[row], bcols[col])) + u_axes[col] = b + v_axes[col] = b end # fill in values for blocks that aren't present, pairing them in order of occurence