
Conversation


@christiangnrd christiangnrd commented Mar 17, 2025

This is ready for review once #565 is merged.

This adds the basic functionality needed for future extension of the MPSGraph framework, along with what is needed to replace the MPS matrix multiplication as a workaround for #381.

Support for contiguous views (MtlArrays with an offset) should be added in a future PR; for now, those fall back to the old method.

Batched matmul is at the same level of support as the current code, but I think I can use broadcastTensor to enable batched matmul between 1 and N (or N and 1) matrices. That would potentially unblock FluxML/NNlib.jl#614, but it can also be left for a future PR.
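
The broadcastTensor idea above could look roughly like this (a hedged sketch: the wrapper names come from this PR's diff, but the module paths, the shape ordering, and the helper function itself are assumptions, not the shipped code):

```julia
# Hedged sketch: expand a single matrix to a batch shape so it can be
# multiplied against an (n_batch, n, n) tensor in one batched matmul.
using Metal
using Metal.MPSGraphs: MPSGraph, placeholderTensor, broadcastTensor,
                       matrixMultiplicationWithPrimaryTensor
using Metal.MPS: MPSShape

function batched_matmul_sketch(n_batch::Int, n::Int)
    graph = MPSGraph()
    a = placeholderTensor(graph, (n_batch, n, n), Float32)  # batched operand
    b = placeholderTensor(graph, (n, n), Float32)           # single matrix
    # Broadcast b up to a's batch shape, then issue one batched matmul.
    b3 = broadcastTensor(graph, b, convert(MPSShape, [n_batch, n, n]))
    return matrixMultiplicationWithPrimaryTensor(graph, a, b3, "batched matmul")
end
```

The feeds/results plumbing (MPSGraphTensorData, encode!) would follow the same pattern as _matmul! in lib/mpsgraphs/matmul.jl.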


codecov bot commented Mar 17, 2025

Codecov Report

Attention: Patch coverage is 60.98901% with 71 lines in your changes missing coverage. Please review.

Project coverage is 80.47%. Comparing base (369e292) to head (faa8b6a).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
lib/mpsgraphs/tensor.jl 27.41% 45 Missing ⚠️
lib/mpsgraphs/execution.jl 23.07% 10 Missing ⚠️
lib/mpsgraphs/operations.jl 70.00% 9 Missing ⚠️
lib/mpsgraphs/matmul.jl 86.00% 7 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #566      +/-   ##
==========================================
- Coverage   81.97%   80.47%   -1.50%     
==========================================
  Files          54       61       +7     
  Lines        2485     2658     +173     
==========================================
+ Hits         2037     2139     +102     
- Misses        448      519      +71     


@christiangnrd christiangnrd marked this pull request as ready for review March 17, 2025 03:55

github-actions bot commented Mar 17, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.
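
For contributors unfamiliar with Runic, applying the suggested formatting locally looks roughly like this (a sketch assuming Runic.jl is installed in the global environment; `git runic` is a convenience alias described in the Runic.jl README, and the `--inplace` flag comes from its CLI documentation):

```shell
# Format every .jl file changed relative to main, in place.
# `git runic main` wraps approximately this invocation.
git diff --name-only main -- '*.jl' |
    xargs julia -e 'using Runic; exit(Runic.main(ARGS))' -- --inplace
```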

diff --git a/examples/flopscomp.jl b/examples/flopscomp.jl
index 4436cedd..e8ca4fed 100644
--- a/examples/flopscomp.jl
+++ b/examples/flopscomp.jl
@@ -8,14 +8,14 @@ testing = (@isdefined TESTING) && TESTING
     using Plots.Measures
 end
 
-const Ts=[
-        (Int8, Float16),
-        (Int8, Float32),
-        (Int16, Float32),
-        (Float16, Float16),
-        (Float16, Float32),
-        (Float32, Float32),
-    ]
+const Ts = [
+    (Int8, Float16),
+    (Int8, Float32),
+    (Int16, Float32),
+    (Float16, Float16),
+    (Float16, Float32),
+    (Float32, Float32),
+]
 
 n_gpu_cores = "??"
 # Comment this out if scary. Please mention number of cores in your comment when uploading the figure
@@ -24,17 +24,19 @@ n_gpu_cores = only(match(r"Total Number of Cores:\s*(\d+)", system_prof).capture
 
 PLOT_TITLE = "Matmul peakflops for $(device().name) ($n_gpu_cores GPU cores)"
 
-function cpupeakflops(; n::Integer=4096,
-                        n_batch::Integer=1,
-                        inT::DataType=Float32,
-                        outT::DataType=inT,
-                        ntrials::Integer=4,
-                        verify=true)
+function cpupeakflops(;
+        n::Integer = 4096,
+        n_batch::Integer = 1,
+        inT::DataType = Float32,
+        outT::DataType = inT,
+        ntrials::Integer = 4,
+        verify = true
+    )
     t = Base.zeros(Float64, ntrials)
     n_batch == 1 || @warn "n_batch > 1 not supported for `mul!`, running with n_batch=1"
     n_batch = 1
     shape = (n, n)
-    for i=1:ntrials
+    for i in 1:ntrials
         c = zeros(outT, shape...)
         a = ones(inT, shape...)
         b = ones(inT, shape...)
@@ -42,12 +44,12 @@ function cpupeakflops(; n::Integer=4096,
         verify && @assert only(unique(Array(c))) == n
     end
 
-    return n_batch*2*Float64(n)^3 / minimum(t)
+    return n_batch * 2 * Float64(n)^3 / minimum(t)
 end
-function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)
+function _peakflops(f, n, n_batch, inT, outT, ntrials; verify = true)
     t = Base.zeros(Float64, ntrials)
     shape = n_batch == 1 ? (n, n) : (n, n, n_batch)
-    for i=1:ntrials
+    for i in 1:ntrials
         c = mtl(zeros(outT, shape...))
         a = mtl(ones(inT, shape...))
         b = mtl(ones(inT, shape...))
@@ -55,34 +57,40 @@ function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)
         verify && @assert only(unique(Array(c))) == n
     end
 
-    return n_batch*2*Float64(n)^3 / minimum(t)
+    return n_batch * 2 * Float64(n)^3 / minimum(t)
 end
-function gpuarrpeakflops(; n::Integer=4096,
-                           n_batch::Integer=1,
-                           inT::DataType=Float32,
-                           outT::DataType=inT,
-                           ntrials::Integer=3,
-                           verify=true)
+function gpuarrpeakflops(;
+        n::Integer = 4096,
+        n_batch::Integer = 1,
+        inT::DataType = Float32,
+        outT::DataType = inT,
+        ntrials::Integer = 3,
+        verify = true
+    )
     n_batch == 1 || @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1"
-    _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
+    return _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
         GPUArrays.generic_matmatmul!(c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0)
     end
 end
-function mpspeakflops(; n::Integer=4096,
-                        n_batch::Integer=1,
-                        inT::DataType=Float32,
-                        outT::DataType=inT,
-                        ntrials::Integer=3,
-                        verify=true)
-    _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify)
+function mpspeakflops(;
+        n::Integer = 4096,
+        n_batch::Integer = 1,
+        inT::DataType = Float32,
+        outT::DataType = inT,
+        ntrials::Integer = 3,
+        verify = true
+    )
+    return _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify)
 end
-function graphpeakflops(; n::Integer=4096,
-                          n_batch::Integer=1,
-                          inT::DataType=Float32,
-                          outT::DataType=inT,
-                          ntrials::Integer=3,
-                          verify=true)
-    _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify)
+function graphpeakflops(;
+        n::Integer = 4096,
+        n_batch::Integer = 1,
+        inT::DataType = Float32,
+        outT::DataType = inT,
+        ntrials::Integer = 3,
+        verify = true
+    )
+    return _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify)
 end
 function anepeakflops(; kwargs...)
     # VERY HACKY
@@ -99,13 +107,13 @@ function anepeakflops(; kwargs...)
     return res
 end
 
-function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
+function compare(Ns, Fs, inT, outT = inT; n_batch = 1, ntrials)
     results = Dict()
 
     newFs = if (outT == Float16 || (outT == Float32 && inT == Float16))
         Fs
     else
-        filter(x -> !occursin("ANE", x[2]),Fs)
+        filter(x -> !occursin("ANE", x[2]), Fs)
     end
 
     for (_, info_str) in newFs
@@ -128,34 +136,36 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
     return results
 end
 
-function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
-                Fs=[
-                    (mpspeakflops, "MPS"),
-                    (graphpeakflops, "MPSGraph"),
-                    (anepeakflops, "MPSGraph (ANE)"),
-                    # (gpuarrpeakflops, "GPUArrays"),
-                    # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
-                   ],
-                n_batch=1,
-                ntrials=5)
+function runcomparison(;
+        Ns = [50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192], #, 10000],
+        Fs = [
+            (mpspeakflops, "MPS"),
+            (graphpeakflops, "MPSGraph"),
+            (anepeakflops, "MPSGraph (ANE)"),
+            # (gpuarrpeakflops, "GPUArrays"),
+            # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
+        ],
+        n_batch = 1,
+        ntrials = 5
+    )
     res = Dict()
 
     for (inT, outT) in Ts
-        res[(inT,outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
+        res[(inT, outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
     end
     return res
 end
 
-function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
-    ylim_upper = 9e12
+function plot_results(res, Fs = ["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath = nothing, outtype = "svg", plt_title = PLOT_TITLE)
+    ylim_upper = 9.0e12
     resplts = []
 
     n_batches = []
 
     for (inT, outT) in Ts
-        n_batch, Ns, tmpres = res[(inT,outT)]
+        n_batch, Ns, tmpres = res[(inT, outT)]
 
-        plt = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)")
+        plt = plot(xlabel = "N, n_batch=$(n_batch)", legendtitle = "($inT, $outT)")
         for info_str in Fs
             haskey(tmpres, info_str) || continue
 
@@ -164,24 +174,26 @@ function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=not
             if maximum(flops) > ylim_upper
                 ylim_upper = maximum(flops) * 1.02
             end
-            plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str")
+            plot!(plt, Ns, tmpres[info_str]; linewidth = 1.5, label = "$(peakf) peak: $info_str")
         end
         push!(resplts, plt)
         push!(n_batches, n_batch)
     end
 
-    finalplot = plot(resplts...; layout=(2,3),
-                     ylim=(0,ylim_upper),
-                     plot_title=plt_title,
-                     tickfonthalign=:left,
-                     bottommargin=15pt,
-                     size=(2000,1200))
+    finalplot = plot(
+        resplts...; layout = (2, 3),
+        ylim = (0, ylim_upper),
+        plot_title = plt_title,
+        tickfonthalign = :left,
+        bottommargin = 15pt,
+        size = (2000, 1200)
+    )
     if !isnothing(outpath)
-        savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
+        savefig(plot(finalplot, dpi = 500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
     end
     return finalplot
 end
 
 if testing
-    runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
+    runcomparison(Ns = [50, 64, 100, 128, 250, 256, 500, 512])
 end
diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl
index 0ef1f2d0..f6efd2de 100644
--- a/lib/mps/MPS.jl
+++ b/lib/mps/MPS.jl
@@ -21,7 +21,7 @@ using BFloat16s
 const MtlFloat = Union{Float32, Float16}
 
 const MPSShape = NSArray#{NSNumber}
-Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple)))
+Base.convert(::Type{MPSShape}, tuple::Union{Vector{T}, NTuple{T, <:Integer}}) where {T} = NSArray(NSNumber.(collect(tuple)))
 
 # Valid combination of input (A and B matrices) and output (C) types
 const MPS_VALID_MATMUL_TYPES =
diff --git a/lib/mpsgraphs/MPSGraphs.jl b/lib/mpsgraphs/MPSGraphs.jl
index 2a06ae14..c6e0d373 100644
--- a/lib/mpsgraphs/MPSGraphs.jl
+++ b/lib/mpsgraphs/MPSGraphs.jl
@@ -20,23 +20,23 @@ using ObjectiveC, .Foundation, .Dispatch
 #   The commented type combinations work but are slower than with MPSMatrixMultiplicatiom
 const MPSGRAPH_VALID_MATMUL_TYPES =
     [
-     (Int8, Float16),
-     (Int8, Float32),
-     (Int16, Float32),
-     (Float16, Float16),
-     (Float16, Float32),
-     (Float32, Float32),
-    ]
+    (Int8, Float16),
+    (Int8, Float32),
+    (Int16, Float32),
+    (Float16, Float16),
+    (Float16, Float32),
+    (Float32, Float32),
+]
 
 const MPSGRAPH_VALID_MATVECMUL_TYPES =
     [
-     (Int8, Float16),
-     (Int8, Float32),
-     (Int16, Float32),
-     (Float16, Float16),
-     (Float16, Float32),
-     (Float32, Float32),
-    ]
+    (Int8, Float16),
+    (Int8, Float32),
+    (Int16, Float32),
+    (Float16, Float16),
+    (Float16, Float32),
+    (Float32, Float32),
+]
 
 include("libmpsgraph.jl")
 
diff --git a/lib/mpsgraphs/core.jl b/lib/mpsgraphs/core.jl
index 2f2e868e..f71051d6 100644
--- a/lib/mpsgraphs/core.jl
+++ b/lib/mpsgraphs/core.jl
@@ -6,7 +6,7 @@
 
 # @objcwrapper MPSGraph <: MPSGraphObject
 function MPSGraph()
-    MPSGraph(@objc [MPSGraph new]::id{MPSGraph})
+    return MPSGraph(@objc [MPSGraph new]::id{MPSGraph})
 end
 
 # @objcwrapper immutable=true MPSGraphShapedType <: MPSGraphType
@@ -26,17 +26,17 @@ end
 
 function MPSGraphDevice(device::MTLDevice)
     obj = @objc [MPSGraphDevice deviceWithMTLDevice:device::id{MTLDevice}]::id{MPSGraphDevice}
-    MPSGraphDevice(obj)
+    return MPSGraphDevice(obj)
 end
 
 # @objcwrapper MPSGraphExecutionDescriptor <: MPSGraphObject
 
 function MPSGraphExecutionDescriptor()
-    MPSGraphExecutionDescriptor(@objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor})
+    return MPSGraphExecutionDescriptor(@objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor})
 end
 
 # @objcwrapper MPSGraphCompilationDescriptor <: MPSGraphObject
 
 function MPSGraphCompilationDescriptor()
-    MPSGraphCompilationDescriptor(@objc [MPSGraphCompilationDescriptor new]::id{MPSGraphCompilationDescriptor})
+    return MPSGraphCompilationDescriptor(@objc [MPSGraphCompilationDescriptor new]::id{MPSGraphCompilationDescriptor})
 end
diff --git a/lib/mpsgraphs/execution.jl b/lib/mpsgraphs/execution.jl
index 9eaf82df..791e2cc7 100644
--- a/lib/mpsgraphs/execution.jl
+++ b/lib/mpsgraphs/execution.jl
@@ -1,34 +1,42 @@
 
 MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor())
 function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor)
-    @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
-                                                  feeds:feeds::id{MPSGraphTensorDataDictionary}
-                                       targetOperations:targetOperations::id{Object}
-                                      resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary}
-                                    executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::Nothing
+    @objc [
+        graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
+        feeds:feeds::id{MPSGraphTensorDataDictionary}
+        targetOperations:targetOperations::id{Object}
+        resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary}
+        executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}
+    ]::Nothing
     return resultsDictionary
 end
 
-function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor())
-    obj = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
-                                                        feeds:feeds::id{MPSGraphTensorDataDictionary}
-                                                targetTensors:targetTensors::id{NSArray}
-                                             targetOperations:targetOperations::id{Object}
-                                          executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::id{MPSGraphTensorDataDictionary}
-    MPSGraphTensorDataDictionary(obj)
+function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations = nil, executionDescriptor = MPSGraphExecutionDescriptor())
+    obj = @objc [
+        graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
+        feeds:feeds::id{MPSGraphTensorDataDictionary}
+        targetTensors:targetTensors::id{NSArray}
+        targetOperations:targetOperations::id{Object}
+        executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}
+    ]::id{MPSGraphTensorDataDictionary}
+    return MPSGraphTensorDataDictionary(obj)
 end
 
-function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil)
-    obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
-                                            targetTensors:targetTensors::id{NSArray}
-                                         targetOperations:targetOperations::id{Object}]::id{MPSGraphTensorDataDictionary}
-    MPSGraphTensorDataDictionary(obj)
+function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations = nil)
+    obj = @objc [
+        graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
+        targetTensors:targetTensors::id{NSArray}
+        targetOperations:targetOperations::id{Object}
+    ]::id{MPSGraphTensorDataDictionary}
+    return MPSGraphTensorDataDictionary(obj)
 end
 
 function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
-    obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
-                                                    feeds:feeds::id{MPSGraphTensorDataDictionary}
-                                            targetTensors:targetTensors::id{NSArray}
-                                         targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
-    MPSGraphTensorDataDictionary(obj)
+    obj = @objc [
+        graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
+        feeds:feeds::id{MPSGraphTensorDataDictionary}
+        targetTensors:targetTensors::id{NSArray}
+        targetOperations:nil::id{Object}
+    ]::id{MPSGraphTensorDataDictionary}
+    return MPSGraphTensorDataDictionary(obj)
 end
diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl
index 0cd3fe08..ef5c3e30 100644
--- a/lib/mpsgraphs/matmul.jl
+++ b/lib/mpsgraphs/matmul.jl
@@ -30,9 +30,11 @@ else
 end
 
 
-@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
-                                   alpha::Number, beta::Number,
-                                   transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
+@autoreleasepool function _matmul!(
+        c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
+        alpha::Number, beta::Number,
+        transpose_a, transpose_b
+    ) where {Tc, Tab, Na, Nb}
     graph = MPSGraph()
 
     placeA = placeholderTensor(graph, size(a), Tab)
@@ -50,8 +52,8 @@ end
     castA = castTensor(graph, placeA, castT, "castA")
     castB = castTensor(graph, placeB, castT, "castB")
 
-    transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA
-    transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB
+    transA = transpose_a ? transposeTensor(graph, castA, Na - 2, Na - 1, "transpose_a") : castA
+    transB = transpose_b ? transposeTensor(graph, castB, Nb - 2, Nb - 1, "transpose_b") : castB
 
     nBatchA = Na == 2 ? 1 : size(transA)[1]
     nBatchB = Nb == 2 ? 1 : size(transB)[1]
@@ -94,9 +96,9 @@ end
 end
 
 function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N}
-    _matmul!(c, a, b, alpha, beta, transpose_a, transpose_b)
+    return _matmul!(c, a, b, alpha, beta, transpose_a, transpose_b)
 end
 
 function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab}
-    _matmul!(c, a, b, alpha, beta, transpose, false)
+    return _matmul!(c, a, b, alpha, beta, transpose, false)
 end
diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl
index 107c9ae3..47631733 100644
--- a/lib/mpsgraphs/operations.jl
+++ b/lib/mpsgraphs/operations.jl
@@ -1,68 +1,88 @@
 
-function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name="broadcast")
-    obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
-                                toShape:shape::id{MPSShape}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name = "broadcast")
+    obj = @objc [
+        graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
+        toShape:shape::id{MPSShape}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
-function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name="broadcast")
-    obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
-                                toShapeTensor:shapeTensor::id{MPSGraphTensor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name = "broadcast")
+    obj = @objc [
+        graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
+        toShapeTensor:shapeTensor::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name = "cast")
-    obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
-                                toType:toType::MPSDataType
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
+        toType:toType::MPSDataType
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function constantWithScalar(graph::MPSGraph, scalar::Number, dataType)
-    obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64
-                                dataType:dataType::MPSDataType]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} constantWithScalar:scalar::Float64
+        dataType:dataType::MPSDataType
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul")
-    obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
-                                secondaryTensor:secondary::id{MPSGraphTensor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+        secondaryTensor:secondary::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function multiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "mul")
-    obj = @objc [graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
-                                secondaryTensor:secondary::id{MPSGraphTensor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+        secondaryTensor:secondary::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 function additionWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "add")
-    obj = @objc [graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
-                                secondaryTensor:secondary::id{MPSGraphTensor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
+        secondaryTensor:secondary::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, withDimension, name = "transpose")
-    obj = @objc [graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
-                                dimension:dimension::NSUInteger
-                                withDimension:withDimension::NSUInteger
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
+        dimension:dimension::NSUInteger
+        withDimension:withDimension::NSUInteger
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function shapeOfTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "shapeOfTensor")
-    obj = @objc [graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity")
-    obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
-                                name:name::id{NSString}]::id{MPSGraphTensor}
-    MPSGraphTensor(obj)
+    obj = @objc [
+        graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
+    return MPSGraphTensor(obj)
 end
 
 """
diff --git a/lib/mpsgraphs/random.jl b/lib/mpsgraphs/random.jl
index b984dcae..026fe8c8 100644
--- a/lib/mpsgraphs/random.jl
+++ b/lib/mpsgraphs/random.jl
@@ -1,8 +1,10 @@
 # @objcwrapper immutable=false MPSGraphRandomOpDescriptor <: MPSGraphObject
 
 function MPSGraphRandomOpDescriptor(distribution::MPSGraphRandomDistribution, dataType)
-    desc = @objc [MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution
-                    dataType:dataType::MPSDataType]::id{MPSGraphRandomOpDescriptor}
+    desc = @objc [
+        MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution
+        dataType:dataType::MPSDataType
+    ]::id{MPSGraphRandomOpDescriptor}
     obj = MPSGraphRandomOpDescriptor(desc)
     return obj
 end
diff --git a/lib/mpsgraphs/tensor.jl b/lib/mpsgraphs/tensor.jl
index 67518f5b..d09c87fe 100644
--- a/lib/mpsgraphs/tensor.jl
+++ b/lib/mpsgraphs/tensor.jl
@@ -10,7 +10,7 @@ function Base.size(td::MPSGraphTensor)
     temp = map(td.shape) do nsnum
         NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int
     end
-    Tuple(temp)
+    return Tuple(temp)
 end
 
 function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...)
@@ -18,9 +18,11 @@ function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...
     return placeholderTensor(graph, mpsshape, args...)
 end
 function placeholderTensor(graph::MPSGraph, shape::MPSShape, dataType::Type, name = "placeholder tensor")
-    obj = @objc [graph::id{MPSGraph} placeholderWithShape:shape::id{MPSShape}
-                                dataType:dataType::MPSDataType
-                                name:name::id{NSString}]::id{MPSGraphTensor}
+    obj = @objc [
+        graph::id{MPSGraph} placeholderWithShape:shape::id{MPSShape}
+        dataType:dataType::MPSDataType
+        name:name::id{NSString}
+    ]::id{MPSGraphTensor}
     return MPSGraphTensor(obj)
 end
 
@@ -31,29 +33,33 @@ function Base.size(td::MPSGraphTensorData)
     temp = map(td.shape) do nsnum
         NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int
     end
-    Tuple(temp)
+    return Tuple(temp)
 end
 
 function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType)
     obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
     tensor = MPSGraphTensorData(obj)
     finalizer(release, tensor)
-    @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
-                                    shape:shape::id{MPSShape}
-                                    dataType:dataType::MPSDataType]::id{MPSGraphTensorData}
+    @objc [
+        tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
+        shape:shape::id{MPSShape}
+        dataType:dataType::MPSDataType
+    ]::id{MPSGraphTensorData}
     return tensor
 end
 function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType, rowBytes)
     obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
     tensor = MPSGraphTensorData(obj)
     finalizer(release, tensor)
-    @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
-                                    shape:shape::id{MPSShape}
-                                    dataType:dataType::MPSDataType
-                                    rowBytes:rowBytes::NSUInteger]::id{MPSGraphTensorData}
+    @objc [
+        tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
+        shape:shape::id{MPSShape}
+        dataType:dataType::MPSDataType
+        rowBytes:rowBytes::NSUInteger
+    ]::id{MPSGraphTensorData}
     return tensor
 end
-MPSGraphTensorData(matrix::MtlArray{T}) where T = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T)
+MPSGraphTensorData(matrix::MtlArray{T}) where {T} = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T)
 
 function MPSGraphTensorData(matrix::MPSMatrix)
     obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
@@ -69,8 +75,10 @@ function MPSGraphTensorData(matrix::MPSMatrix, rank)
     obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
     tensor = MPSGraphTensorData(obj)
     finalizer(release, tensor)
-    @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}
-                              rank:rank::NSUInteger]::id{MPSGraphTensorData}
+    @objc [
+        tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}
+        rank:rank::NSUInteger
+    ]::id{MPSGraphTensorData}
     return tensor
 end
 
@@ -88,8 +96,10 @@ function MPSGraphTensorData(vector::MPSVector, rank)
     obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
     tensor = MPSGraphTensorData(obj)
     finalizer(release, tensor)
-    @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:vector::id{MPSVector}
-                                          rank:rank::NSUInteger]::id{MPSGraphTensorData}
+    @objc [
+        tensor::id{MPSGraphTensorData} initWithMPSMatrix:vector::id{MPSVector}
+        rank:rank::NSUInteger
+    ]::id{MPSGraphTensorData}
     return tensor
 end
 
@@ -118,5 +128,5 @@ Will copy contents if the contents are not stored in an MPS ndarray.
 """
 function MPS.MPSNDArray(tensor::MPSGraphTensorData)
     arr = @objc [tensor::id{MPSNDArray} mpsndarray]::id{MPSNDArray}
-    MPSNDArray(arr)
+    return MPSNDArray(arr)
 end
diff --git a/src/linalg.jl b/src/linalg.jl
index 0a6ac707..2f2fa00a 100644
--- a/src/linalg.jl
+++ b/src/linalg.jl
@@ -3,16 +3,16 @@ using LinearAlgebra: MulAddMul, wrap
 using .MPS
 using .MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
 using .MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
-                  graph_matmul!, graph_matvecmul!
+    graph_matmul!, graph_matvecmul!
 
 @inline function supports_mps_matmul(A, B, C, valid_types)
-    MPS.is_supported(device(A)) &&
+    return MPS.is_supported(device(A)) &&
         eltype(A) == eltype(B) &&
         (eltype(A), eltype(C)) in valid_types
 end
 
 @inline function supports_mpsgraph_matmul(A, B, C, valid_types)
-    MPS.is_supported(device(A)) &&
+    return MPS.is_supported(device(A)) &&
         eltype(A) == eltype(B) &&
         (eltype(A), eltype(C)) in valid_types &&
         # TODO: remove this limitation
diff --git a/test/examples.jl b/test/examples.jl
index d0f9017d..77fc65d6 100644
--- a/test/examples.jl
+++ b/test/examples.jl
@@ -19,7 +19,7 @@ cd(examples_dir) do
     @testset for example in examples
         mod = @eval module $(gensym()) end
         @eval mod begin
-            const TESTING=true
+                const TESTING = true
             redirect_stdout(devnull) do
                 include($example)
             end
diff --git a/test/linalg.jl b/test/linalg.jl
index 5d05d0dd..8f703169 100644
--- a/test/linalg.jl
+++ b/test/linalg.jl
@@ -25,7 +25,7 @@ if MPS.is_supported(device())
     mtl_view_c = mtl_view_a * mtl_view_b
     view_c = view_a * view_b
 
-    @test Array(mtl_view_c) ≈ view_c
+        @test Array(mtl_view_c) ≈ view_c
 end
 
 using Metal: storagemode
diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl
index 90c6de9c..1d7c73f7 100644
--- a/test/mps/linalg.jl
+++ b/test/mps/linalg.jl
@@ -34,19 +34,19 @@ if MPS.is_supported(device())
 end
 
 @testset "batched matrix matrix multiplication" begin
-    M = 8
-    N = 7
-    P = 9
+        M = 8
+        N = 7
+        P = 9
     batch_size = 3
 
-    rows_a = M
+        rows_a = M
     cols_a = N
 
     rows_b = N
-    cols_b = P
+        cols_b = P
 
-    rows_c = M
-    cols_c = P
+        rows_c = M
+        cols_c = P
 
     alpha = Float64(1)
     beta = Float64(1)
diff --git a/test/mpsgraphs/core.jl b/test/mpsgraphs/core.jl
index ce716190..d538570f 100644
--- a/test/mpsgraphs/core.jl
+++ b/test/mpsgraphs/core.jl
@@ -1,19 +1,19 @@
 
 if MPS.is_supported(device())
 
-using .MPS: MPSShape
-using .MPSGraphs: MPSGraph, MPSGraphDevice
-@testset "Core" begin
+    using .MPS: MPSShape
+    using .MPSGraphs: MPSGraph, MPSGraphDevice
+    @testset "Core" begin
 
-graph = MPSGraph()
-@test graph isa MPSGraph
+        graph = MPSGraph()
+        @test graph isa MPSGraph
 
-dev = device()
-graphdev = MPSGraphDevice(dev)
-@test graphdev isa MPSGraphDevice
-@test graphdev.type == MPSGraphs.MPSGraphDeviceTypeMetal
-@test graphdev.metalDevice == dev
+        dev = device()
+        graphdev = MPSGraphDevice(dev)
+        @test graphdev isa MPSGraphDevice
+        @test graphdev.type == MPSGraphs.MPSGraphDeviceTypeMetal
+        @test graphdev.metalDevice == dev
 
-end # @testset "Core"
+    end # @testset "Core"
 
 end # MPS.is_supported(device())
diff --git a/test/mpsgraphs/linalg.jl b/test/mpsgraphs/linalg.jl
index 62438980..8711eb94 100644
--- a/test/mpsgraphs/linalg.jl
+++ b/test/mpsgraphs/linalg.jl
@@ -3,98 +3,98 @@ using LinearAlgebra
 
 if MPS.is_supported(device())
 
-@testset "mixed-precision matrix matrix multiplication" begin
-    N = 10
-    rows_a = N
-    cols_a = N
+    @testset "mixed-precision matrix matrix multiplication" begin
+        N = 10
+        rows_a = N
+        cols_a = N
 
-    rows_b = N
-    cols_b = N
+        rows_b = N
+        cols_b = N
 
-    rows_c = rows_a
-    cols_c = cols_b
+        rows_c = rows_a
+        cols_c = cols_b
 
-    alpha = Float64(1)
-    beta  = Float64(1)
+        alpha = Float64(1)
+        beta = Float64(1)
 
-    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
-        arr_a = rand(input_jl_type, (rows_a, cols_a))
-        arr_b = rand(input_jl_type, (rows_b, cols_b))
-        arr_c = zeros(accum_jl_type, (rows_c, cols_c))
+        @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+            arr_a = rand(input_jl_type, (rows_a, cols_a))
+            arr_b = rand(input_jl_type, (rows_b, cols_b))
+            arr_c = zeros(accum_jl_type, (rows_c, cols_c))
 
-        buf_a = MtlArray{input_jl_type}(arr_a)
-        buf_b = MtlArray{input_jl_type}(arr_b)
-        buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c))
+            buf_a = MtlArray{input_jl_type}(arr_a)
+            buf_b = MtlArray{input_jl_type}(arr_b)
+            buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c))
 
-        truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
+            truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
 
-        MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+            MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
 
-        @test all(Array(buf_c) .≈ truth_c)
+            @test all(Array(buf_c) .≈ truth_c)
+        end
     end
-end
 
-@testset "batched matrix matrix multiplication" begin
-    M = 8
-    N = 7
-    P = 9
-    batch_size = 3
+    @testset "batched matrix matrix multiplication" begin
+        M = 8
+        N = 7
+        P = 9
+        batch_size = 3
 
-    rows_a = M
-    cols_a = N
+        rows_a = M
+        cols_a = N
 
-    rows_b = N
-    cols_b = P
+        rows_b = N
+        cols_b = P
 
-    rows_c = M
-    cols_c = P
+        rows_c = M
+        cols_c = P
 
-    alpha = Float64(1)
-    beta = Float64(1)
+        alpha = Float64(1)
+        beta = Float64(1)
 
-    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
-        arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size))
-        arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size))
-        arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
+        @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+            arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size))
+            arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size))
+            arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
 
-        buf_a = MtlArray{input_jl_type}(arr_a)
-        buf_b = MtlArray{input_jl_type}(arr_b)
-        buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+            buf_a = MtlArray{input_jl_type}(arr_a)
+            buf_b = MtlArray{input_jl_type}(arr_b)
+            buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
 
-        truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
-        for i in 1:batch_size
-            @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
-        end
+            truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+            for i in 1:batch_size
+                @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
+            end
 
-        MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+            MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
 
-        @test all(Array(buf_c) .≈ truth_c)
+            @test all(Array(buf_c) .≈ truth_c)
+        end
     end
-end
 
-@testset "mixed-precision matrix vector multiplication" begin
-    N = 10
-    rows = N
-    cols = N
+    @testset "mixed-precision matrix vector multiplication" begin
+        N = 10
+        rows = N
+        cols = N
 
-    alpha = Float64(1)
-    beta  = Float64(0)
+        alpha = Float64(1)
+        beta = Float64(0)
 
-    @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES
-        arr_a = rand(input_jl_type, (rows, cols))
-        arr_b = rand(input_jl_type, (rows))
-        arr_c = zeros(accum_jl_type, (rows))
+        @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES
+            arr_a = rand(input_jl_type, (rows, cols))
+            arr_b = rand(input_jl_type, (rows))
+            arr_c = zeros(accum_jl_type, (rows))
 
-        buf_a = MtlArray{input_jl_type}(arr_a)
-        buf_b = MtlArray{input_jl_type}(arr_b)
-        buf_c = MtlArray{accum_jl_type}(undef, (rows))
+            buf_a = MtlArray{input_jl_type}(arr_a)
+            buf_b = MtlArray{input_jl_type}(arr_b)
+            buf_c = MtlArray{accum_jl_type}(undef, (rows))
 
-        truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c)
+            truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c)
 
-        MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
+            MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
 
-        @test all(Array(buf_c) .≈ truth_c)
+            @test all(Array(buf_c) .≈ truth_c)
+        end
     end
-end
 
 end # MPS.is_supported(device())
diff --git a/test/mpsgraphs/random.jl b/test/mpsgraphs/random.jl
index 4303ee83..35584c6f 100644
--- a/test/mpsgraphs/random.jl
+++ b/test/mpsgraphs/random.jl
@@ -2,24 +2,25 @@ using BFloat16s
 
 if MPS.is_supported(device())
 
-using .MPSGraphs: MPSGraphRandomOpDescriptor, MPSGraphRandomDistributionNormal, MPSGraphRandomDistributionTruncatedNormal, MPSGraphRandomDistributionUniform
-@testset "MPSGraph random" begin
-    # determined by looking at the error message when trying to construct
-    #  an invalid distribution/type combination
-    for (dist, T) in [(MPSGraphRandomDistributionNormal, Float32),
-                      (MPSGraphRandomDistributionNormal, Float16),
-                      (MPSGraphRandomDistributionNormal, BFloat16),
-                      (MPSGraphRandomDistributionTruncatedNormal, Float32),
-                      (MPSGraphRandomDistributionTruncatedNormal, Float16),
-                      (MPSGraphRandomDistributionTruncatedNormal, BFloat16),
-                      (MPSGraphRandomDistributionUniform, Int64),
-                      (MPSGraphRandomDistributionUniform, Int32),
-                      (MPSGraphRandomDistributionUniform, Float32),
-                      (MPSGraphRandomDistributionUniform, Float16),
-                      (MPSGraphRandomDistributionUniform, BFloat16),
-                      ]
-        @test MPSGraphRandomOpDescriptor(MPSGraphRandomDistributionNormal, Float32) isa MPSGraphRandomOpDescriptor
+    using .MPSGraphs: MPSGraphRandomOpDescriptor, MPSGraphRandomDistributionNormal, MPSGraphRandomDistributionTruncatedNormal, MPSGraphRandomDistributionUniform
+    @testset "MPSGraph random" begin
+        # determined by looking at the error message when trying to construct
+        #  an invalid distribution/type combination
+        for (dist, T) in [
+                (MPSGraphRandomDistributionNormal, Float32),
+                (MPSGraphRandomDistributionNormal, Float16),
+                (MPSGraphRandomDistributionNormal, BFloat16),
+                (MPSGraphRandomDistributionTruncatedNormal, Float32),
+                (MPSGraphRandomDistributionTruncatedNormal, Float16),
+                (MPSGraphRandomDistributionTruncatedNormal, BFloat16),
+                (MPSGraphRandomDistributionUniform, Int64),
+                (MPSGraphRandomDistributionUniform, Int32),
+                (MPSGraphRandomDistributionUniform, Float32),
+                (MPSGraphRandomDistributionUniform, Float16),
+                (MPSGraphRandomDistributionUniform, BFloat16),
+            ]
+            @test MPSGraphRandomOpDescriptor(MPSGraphRandomDistributionNormal, Float32) isa MPSGraphRandomOpDescriptor
+        end
     end
-end
 
 end # MPS.is_supported(device())

Contributor

@github-actions github-actions bot left a comment
Metal Benchmarks

| Benchmark suite | Current: faa8b6a | Previous: 369e292 | Ratio |
|---|---|---|---|
| private array/construct | 28361.083333333336 ns | 25757 ns | 1.10 |
| private array/broadcast | 464250 ns | 458125 ns | 1.01 |
| private array/random/randn/Float32 | 819958 ns | 761125 ns | 1.08 |
| private array/random/randn!/Float32 | 605083 ns | 612208 ns | 0.99 |
| private array/random/rand!/Int64 | 575541 ns | 558500 ns | 1.03 |
| private array/random/rand!/Float32 | 610645.5 ns | 592625 ns | 1.03 |
| private array/random/rand/Int64 | 784291 ns | 796000 ns | 0.99 |
| private array/random/rand/Float32 | 604979 ns | 623500 ns | 0.97 |
| private array/copyto!/gpu_to_gpu | 674791 ns | 655250 ns | 1.03 |
| private array/copyto!/cpu_to_gpu | 627479 ns | 804167 ns | 0.78 |
| private array/copyto!/gpu_to_cpu | 820479 ns | 804083 ns | 1.02 |
| private array/accumulate/1d | 1339541.5 ns | 1349500 ns | 0.99 |
| private array/accumulate/2d | 1403958 ns | 1395209 ns | 1.01 |
| private array/iteration/findall/int | 2105875 ns | 2038041.5 ns | 1.03 |
| private array/iteration/findall/bool | 1817125 ns | 1845313 ns | 0.98 |
| private array/iteration/findfirst/int | 1711000 ns | 1717625 ns | 1.00 |
| private array/iteration/findfirst/bool | 1663875 ns | 1673458.5 ns | 0.99 |
| private array/iteration/scalar | 3952354 ns | 3856833 ns | 1.02 |
| private array/iteration/logical | 3213437.5 ns | 3196542 ns | 1.01 |
| private array/iteration/findmin/1d | 1762166.5 ns | 1770187 ns | 1.00 |
| private array/iteration/findmin/2d | 1368917 ns | 1362333 ns | 1.00 |
| private array/reductions/reduce/1d | 1052500 ns | 1022584 ns | 1.03 |
| private array/reductions/reduce/2d | 662250 ns | 665709 ns | 0.99 |
| private array/reductions/mapreduce/1d | 1053792 ns | 1045166 ns | 1.01 |
| private array/reductions/mapreduce/2d | 666958 ns | 657625 ns | 1.01 |
| private array/permutedims/4d | 2494103.5 ns | 2547458.5 ns | 0.98 |
| private array/permutedims/2d | 1030875 ns | 1028750 ns | 1.00 |
| private array/permutedims/3d | 1580291 ns | 1582146 ns | 1.00 |
| private array/copy | 600708 ns | 621854 ns | 0.97 |
| latency/precompile | 9647484834 ns | 9111380625 ns | 1.06 |
| latency/ttfp | 3730773459 ns | 3718960708 ns | 1.00 |
| latency/import | 1254998396 ns | 1244311500 ns | 1.01 |
| integration/metaldevrt | 707042 ns | 714084 ns | 0.99 |
| integration/byval/slices=1 | 1560520.5 ns | 1562604 ns | 1.00 |
| integration/byval/slices=3 | 10160271 ns | 10396459 ns | 0.98 |
| integration/byval/reference | 1512917 ns | 1559083 ns | 0.97 |
| integration/byval/slices=2 | 2614916.5 ns | 2742208.5 ns | 0.95 |
| kernel/indexing | 463792 ns | 464021 ns | 1.00 |
| kernel/indexing_checked | 464020.5 ns | 461542 ns | 1.01 |
| kernel/launch | 9319.333333333334 ns | 8000 ns | 1.16 |
| metal/synchronization/stream | 14833 ns | 14959 ns | 0.99 |
| metal/synchronization/context | 15167 ns | 15208 ns | 1.00 |
| shared array/construct | 24565.916666666664 ns | 24638.833333333332 ns | 1.00 |
| shared array/broadcast | 462125 ns | 463854.5 ns | 1.00 |
| shared array/random/randn/Float32 | 835729.5 ns | 793500 ns | 1.05 |
| shared array/random/randn!/Float32 | 637479.5 ns | 631500 ns | 1.01 |
| shared array/random/rand!/Int64 | 587000 ns | 555208 ns | 1.06 |
| shared array/random/rand!/Float32 | 615042 ns | 585625 ns | 1.05 |
| shared array/random/rand/Int64 | 762292 ns | 791250 ns | 0.96 |
| shared array/random/rand/Float32 | 631520.5 ns | 608229 ns | 1.04 |
| shared array/copyto!/gpu_to_gpu | 84209 ns | 84459 ns | 1.00 |
| shared array/copyto!/cpu_to_gpu | 82875 ns | 89208.5 ns | 0.93 |
| shared array/copyto!/gpu_to_cpu | 83500 ns | 84084 ns | 0.99 |
| shared array/accumulate/1d | 1362584 ns | 1372541.5 ns | 0.99 |
| shared array/accumulate/2d | 1408521 ns | 1411292 ns | 1.00 |
| shared array/iteration/findall/int | 1844500 ns | 1806208 ns | 1.02 |
| shared array/iteration/findall/bool | 1584417 ns | 1616209 ns | 0.98 |
| shared array/iteration/findfirst/int | 1409437.5 ns | 1408145.5 ns | 1.00 |
| shared array/iteration/findfirst/bool | 1377000 ns | 1370770.5 ns | 1.00 |
| shared array/iteration/scalar | 157417 ns | 159000 ns | 0.99 |
| shared array/iteration/logical | 2993541.5 ns | 3018917 ns | 0.99 |
| shared array/iteration/findmin/1d | 1475208.5 ns | 1480458 ns | 1.00 |
| shared array/iteration/findmin/2d | 1371959 ns | 1385541 ns | 0.99 |
| shared array/reductions/reduce/1d | 743000 ns | 731625 ns | 1.02 |
| shared array/reductions/reduce/2d | 667000 ns | 669875 ns | 1.00 |
| shared array/reductions/mapreduce/1d | 748667 ns | 740125 ns | 1.01 |
| shared array/reductions/mapreduce/2d | 668354 ns | 668917 ns | 1.00 |
| shared array/permutedims/4d | 2529584 ns | 2498521 ns | 1.01 |
| shared array/permutedims/2d | 1032667 ns | 1037291 ns | 1.00 |
| shared array/permutedims/3d | 1611729 ns | 1601042 ns | 1.01 |
| shared array/copy | 244562.5 ns | 252291 ns | 0.97 |

This comment was automatically generated by workflow using github-action-benchmark.

@maleadt
Member

maleadt commented Mar 17, 2025

This is ready for review once #565 is merged.

FYI, you can have PRs target other PRs by setting the base branch, which should make it possible to review (only showing the relevant diff) while auto-updating to main when the target PR is merged.

Close #381

I'd keep that issue open, given that there's MWE's using MPS directly.

What's the performance of MPSGraph like compared to using MPS directly? No difference, I would assume?

@christiangnrd
Member Author

Very nice about the PR target thing that’ll be very useful for future PRs

I unlinked #381 for automatic closing.

As for performance, I want to make some graphs, but from some quick benchmarks while testing, it seems like it’s slower for smaller matrices, and similar performance for larger ones. That’s why I only made MPSGraphs the default for the types affected by #381.

@christiangnrd christiangnrd force-pushed the MPSGraph branch 2 times, most recently from 3b9f66a to 391bfb8 Compare March 17, 2025 18:50
@christiangnrd
Member Author

christiangnrd commented Mar 18, 2025

Okay. I wrote a little script to compare MPS and MPSGraph matmul performance. It runs all the supported type combinations and saves the results to a plot.

Script is in dev/flopscomp.jl of this branch.

The script activates a temporary environment, installs this branch of Metal, and then compares matmul performance and saves the results to a plot.

To run it, run the following:

using Pkg
Pkg.activate(temp=true)
Pkg.add(url="https://github.com/christiangnrd/Metal.jl/", rev="MPSGraph")
Pkg.add(["GPUArrays", "Plots"])

using Metal

scriptpath = joinpath(dirname(pathof(Metal)), "..", "dev", "flopscomp.jl")

include(scriptpath)

# Save the figure `bench_all_1.svg` in current working
#  directory or specify your own outpath
main(outpath="")

It'll download this branch, run the benchmarks, and spit out a figure.

I ran it with GPUArrays.generic_matmatmul! but that's not needed and makes it take MUCH longer. My results were:
bench_all_1

These results are a bit unfortunate? I don't know if we have a choice though because the combinations where performance is worse are the cases affected by #381...
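For context on how throughput plots like these are usually computed: a dense (M×N)·(N×P) matmul performs 2·M·N·P floating-point operations, so GFLOPS is that count divided by runtime. A minimal sketch (the helper name `matmul_gflops` is hypothetical, not a function from dev/flopscomp.jl):

```julia
# Hypothetical helper: throughput in GFLOPS of an (M×N) * (N×P) matmul.
# A dense matmul performs M*N*P multiply-add pairs, i.e. 2*M*N*P floating-point ops.
matmul_gflops(M, N, P, seconds) = 2 * M * N * P / seconds / 1e9

# Example: a 4096^3 matmul finishing in 25 ms is roughly 5.5 TFLOPS.
matmul_gflops(4096, 4096, 4096, 0.025)  # ≈ 5497.6 GFLOPS
```

When timing GPU code for such a measurement, remember to synchronize the device before stopping the clock, otherwise only the kernel launch is measured.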

@christiangnrd
Member Author

christiangnrd commented Mar 18, 2025

I also ran it on an M1 to get an idea of how long it takes on the slowest Apple silicon. In total, from installing and precompiling to running, it took 391 seconds.

bench_all_1_M1

@christiangnrd
Member Author

christiangnrd commented Mar 19, 2025

a4e8216 makes the operation actually in-place, while a86a55d makes every operation use Float32 under the hood. This makes memory usage quite a bit higher, but greatly improves performance. I'd love opinions on whether we keep it that way or not.
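For readers following along, the Float32 path can be pictured roughly like this. This is a hand-written sketch, not the code from a86a55d; `matmul_via_f32!` is a made-up name, and the exact module path of `graph_matmul!` is an assumption (it is the function this PR adds):

```julia
using Metal
using Metal.MPSGraphs: graph_matmul!  # assumed module path

# Sketch: run a Float16 matmul through MPSGraph in Float32, then cast back.
# Better accuracy/performance, at the cost of Float32 copies of all operands.
function matmul_via_f32!(c::MtlArray{Float16}, a::MtlArray{Float16}, b::MtlArray{Float16},
                         alpha = true, beta = false)
    a32 = Float32.(a)             # Float32 copies materialized on the GPU
    b32 = Float32.(b)
    c32 = Float32.(c)             # keep existing contents for the beta term
    graph_matmul!(c32, a32, b32, alpha, beta)
    copyto!(c, Float16.(c32))     # cast the result back down to Float16
    return c
end
```

This doubles the operands' memory footprint, which matches the "quite a bit higher" memory usage noted above.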

Batched multiplication where one Array has only one matrix and the other has more is also now working.

Latest performance on my M2 Max (compare with first screenshot):
bench_all_1

@maleadt
Member

maleadt commented Mar 19, 2025

a86a55d makes every operation use Float32 under the hood

At the expense of accuracy?

I'm surprised by the large performance differences, but given the issues with MPS matmul we don't really have a choice. That's said, it's good that you kept the generic_matmatmul! dispatcher decoupled from the back-end as much as possible, that way we could add some heuristics (small inputs -> MPS, large inputs -> MPS graphs) if we ever figure out the NaN issue.
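Such a heuristic could live in the generic_matmatmul! dispatcher as a small size predicate; a sketch under assumed names (both `prefer_mpsgraph` and the cutoff value are hypothetical and would need benchmarking to pin down):

```julia
# Hypothetical size-based backend choice: for small inputs, MPSGraph's extra
# graph-construction/launch overhead dominates, so plain MPS wins; for large
# inputs the MPSGraph kernels come out ahead in the plots above.
const MPSGRAPH_MIN_ELEMENTS = 64 * 64  # made-up cutoff

prefer_mpsgraph(szA::Dims, szB::Dims) =
    prod(szA) >= MPSGRAPH_MIN_ELEMENTS || prod(szB) >= MPSGRAPH_MIN_ELEMENTS

prefer_mpsgraph((16, 16), (16, 16))      # false -> dispatch to MPS matmul
prefer_mpsgraph((512, 512), (512, 512))  # true  -> dispatch to graph_matmul!
```

Keeping this predicate separate from the backends themselves means the NaN-affected type combinations from #381 can still be forced onto MPSGraph unconditionally.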

@christiangnrd
Member Author

christiangnrd commented Mar 19, 2025

I was trying to figure out why Float16 performance was worse with MPSGraph, and this one line, added right after initializing the graph, gave me some very useful insights.

function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
    graph = MPSGraph()

    graph.options = MPSGraphs.MPSGraphOptionsVerbose # <- This line

    placeA = placeholderTensor(graph, size(a), Tab)
    placeB = placeholderTensor(graph, size(b), Tab)
...

It seems like when the output eltype is Float16, MPSGraph runs the operation on the Neural Engine (ANE), which MPS does not do.

The latest commit disables this by default for graph_matmul!, matching the performance of the Float32 cast without the extra memory usage. It still casts Int8 and Int16 input arrays to the output eltype. Does that affect precision? If so, I'll make MPS the default for those types, as I was unable to reproduce #381 with integer input arrays.

New performance:
bench_all_1

I suspect ANE may be better for other M-series processors with weaker GPUs but the same ANE as my 30-core M2 Max. I'll ask on Slack if people want to benchmark on their device and report performance here.

macOS 13

Seems like ANE is not used in macOS 13. Also MPSGraph performance in macOS 15 seems noticeably better!
bench_all_1

macOS 14

ANE is used in macOS 14. Also MPSGraph performance in macOS 15 seems noticeably better!

bench_all_1

@haakon-e

Results from your script above. I have quite a few programs open (slack, vscode, browsers, etc.) but didn't use anything actively while the test ran.

bench_all_1

@sinhtrung

I ran the code posted on a M2 Macbook Pro
bench_all_1

@haakon-e

Same as above with a MacBook Air:
bench_all_1

@christiangnrd
Member Author

Thanks @haakon-e @sinhtrung! This is very helpful. I'm glad to see the M4 Pro results: I was wondering whether its apparently much faster Neural Engine would beat the GPU, but it seems the TOPS figures being quoted may not be Float16 TOPS (NPUs typically report Int8 OPS).

Seems like we should leave the ANE unused as at best it's as good as the GPU.

@AntoineBut

This comment was marked as resolved.

@christiangnrd

This comment was marked as resolved.

@AntoineBut

Here is what I get on M3 Pro (MacBook Pro)

bench_all_1

@chengchingwen
Contributor

Results with M2 Max

bench_all_1

@christiangnrd
Member Author

If no one objects, I'll merge this tomorrow morning.

I'm wondering if we should use MPS instead of MPSGraph in situations where it's faster and not broken. I'll open a PR for feedback soon.

@christiangnrd christiangnrd merged commit fab6fc2 into JuliaGPU:main Apr 11, 2025
7 checks passed
@christiangnrd christiangnrd deleted the MPSGraph branch April 11, 2025 13:17
7 participants