Initial MPSGraph support #566
Codecov Report

Attention: Patch coverage is

@@ Coverage Diff @@
##             main     #566      +/-   ##
==========================================
- Coverage   81.97%   80.47%   -1.50%
==========================================
  Files          54       61       +7
  Lines        2485     2658     +173
==========================================
+ Hits         2037     2139     +102
- Misses        448      519      +71
Your PR requires formatting changes to meet the project's style guidelines. Suggested changes:

diff --git a/examples/flopscomp.jl b/examples/flopscomp.jl
index 4436cedd..e8ca4fed 100644
--- a/examples/flopscomp.jl
+++ b/examples/flopscomp.jl
@@ -8,14 +8,14 @@ testing = (@isdefined TESTING) && TESTING
using Plots.Measures
end
-const Ts=[
- (Int8, Float16),
- (Int8, Float32),
- (Int16, Float32),
- (Float16, Float16),
- (Float16, Float32),
- (Float32, Float32),
- ]
+const Ts = [
+ (Int8, Float16),
+ (Int8, Float32),
+ (Int16, Float32),
+ (Float16, Float16),
+ (Float16, Float32),
+ (Float32, Float32),
+]
n_gpu_cores = "??"
# Comment this out if scary. Please mention number of cores in your comment when uploading the figure
@@ -24,17 +24,19 @@ n_gpu_cores = only(match(r"Total Number of Cores:\s*(\d+)", system_prof).capture
PLOT_TITLE = "Matmul peakflops for $(device().name) ($n_gpu_cores GPU cores)"
-function cpupeakflops(; n::Integer=4096,
- n_batch::Integer=1,
- inT::DataType=Float32,
- outT::DataType=inT,
- ntrials::Integer=4,
- verify=true)
+function cpupeakflops(;
+ n::Integer = 4096,
+ n_batch::Integer = 1,
+ inT::DataType = Float32,
+ outT::DataType = inT,
+ ntrials::Integer = 4,
+ verify = true
+ )
t = Base.zeros(Float64, ntrials)
n_batch == 1 || @warn "n_batch > 1 not supported for `mul!`, running with n_batch=1"
n_batch = 1
shape = (n, n)
- for i=1:ntrials
+ for i in 1:ntrials
c = zeros(outT, shape...)
a = ones(inT, shape...)
b = ones(inT, shape...)
@@ -42,12 +44,12 @@ function cpupeakflops(; n::Integer=4096,
verify && @assert only(unique(Array(c))) == n
end
- return n_batch*2*Float64(n)^3 / minimum(t)
+ return n_batch * 2 * Float64(n)^3 / minimum(t)
end
-function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)
+function _peakflops(f, n, n_batch, inT, outT, ntrials; verify = true)
t = Base.zeros(Float64, ntrials)
shape = n_batch == 1 ? (n, n) : (n, n, n_batch)
- for i=1:ntrials
+ for i in 1:ntrials
c = mtl(zeros(outT, shape...))
a = mtl(ones(inT, shape...))
b = mtl(ones(inT, shape...))
@@ -55,34 +57,40 @@ function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)
verify && @assert only(unique(Array(c))) == n
end
- return n_batch*2*Float64(n)^3 / minimum(t)
+ return n_batch * 2 * Float64(n)^3 / minimum(t)
end
-function gpuarrpeakflops(; n::Integer=4096,
- n_batch::Integer=1,
- inT::DataType=Float32,
- outT::DataType=inT,
- ntrials::Integer=3,
- verify=true)
+function gpuarrpeakflops(;
+ n::Integer = 4096,
+ n_batch::Integer = 1,
+ inT::DataType = Float32,
+ outT::DataType = inT,
+ ntrials::Integer = 3,
+ verify = true
+ )
n_batch == 1 || @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1"
- _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
+ return _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
GPUArrays.generic_matmatmul!(c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0)
end
end
-function mpspeakflops(; n::Integer=4096,
- n_batch::Integer=1,
- inT::DataType=Float32,
- outT::DataType=inT,
- ntrials::Integer=3,
- verify=true)
- _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify)
+function mpspeakflops(;
+ n::Integer = 4096,
+ n_batch::Integer = 1,
+ inT::DataType = Float32,
+ outT::DataType = inT,
+ ntrials::Integer = 3,
+ verify = true
+ )
+ return _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify)
end
-function graphpeakflops(; n::Integer=4096,
- n_batch::Integer=1,
- inT::DataType=Float32,
- outT::DataType=inT,
- ntrials::Integer=3,
- verify=true)
- _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify)
+function graphpeakflops(;
+ n::Integer = 4096,
+ n_batch::Integer = 1,
+ inT::DataType = Float32,
+ outT::DataType = inT,
+ ntrials::Integer = 3,
+ verify = true
+ )
+ return _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify)
end
function anepeakflops(; kwargs...)
# VERY HACKY
@@ -99,13 +107,13 @@ function anepeakflops(; kwargs...)
return res
end
-function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
+function compare(Ns, Fs, inT, outT = inT; n_batch = 1, ntrials)
results = Dict()
newFs = if (outT == Float16 || (outT == Float32 && inT == Float16))
Fs
else
- filter(x -> !occursin("ANE", x[2]),Fs)
+ filter(x -> !occursin("ANE", x[2]), Fs)
end
for (_, info_str) in newFs
@@ -128,34 +136,36 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
return results
end
-function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
- Fs=[
- (mpspeakflops, "MPS"),
- (graphpeakflops, "MPSGraph"),
- (anepeakflops, "MPSGraph (ANE)"),
- # (gpuarrpeakflops, "GPUArrays"),
- # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
- ],
- n_batch=1,
- ntrials=5)
+function runcomparison(;
+ Ns = [50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192], #, 10000],
+ Fs = [
+ (mpspeakflops, "MPS"),
+ (graphpeakflops, "MPSGraph"),
+ (anepeakflops, "MPSGraph (ANE)"),
+ # (gpuarrpeakflops, "GPUArrays"),
+ # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
+ ],
+ n_batch = 1,
+ ntrials = 5
+ )
res = Dict()
for (inT, outT) in Ts
- res[(inT,outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
+ res[(inT, outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
end
return res
end
-function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
- ylim_upper = 9e12
+function plot_results(res, Fs = ["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath = nothing, outtype = "svg", plt_title = PLOT_TITLE)
+ ylim_upper = 9.0e12
resplts = []
n_batches = []
for (inT, outT) in Ts
- n_batch, Ns, tmpres = res[(inT,outT)]
+ n_batch, Ns, tmpres = res[(inT, outT)]
- plt = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)")
+ plt = plot(xlabel = "N, n_batch=$(n_batch)", legendtitle = "($inT, $outT)")
for info_str in Fs
haskey(tmpres, info_str) || continue
@@ -164,24 +174,26 @@ function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=not
if maximum(flops) > ylim_upper
ylim_upper = maximum(flops) * 1.02
end
- plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str")
+ plot!(plt, Ns, tmpres[info_str]; linewidth = 1.5, label = "$(peakf) peak: $info_str")
end
push!(resplts, plt)
push!(n_batches, n_batch)
end
- finalplot = plot(resplts...; layout=(2,3),
- ylim=(0,ylim_upper),
- plot_title=plt_title,
- tickfonthalign=:left,
- bottommargin=15pt,
- size=(2000,1200))
+ finalplot = plot(
+ resplts...; layout = (2, 3),
+ ylim = (0, ylim_upper),
+ plot_title = plt_title,
+ tickfonthalign = :left,
+ bottommargin = 15pt,
+ size = (2000, 1200)
+ )
if !isnothing(outpath)
- savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
+ savefig(plot(finalplot, dpi = 500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
end
return finalplot
end
if testing
- runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
+ runcomparison(Ns = [50, 64, 100, 128, 250, 256, 500, 512])
end
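For context on what the benchmark helpers in this file measure: each trial multiplies two n×n matrices, which costs roughly 2·n³ floating-point operations (n multiplies and n−1 adds for each of the n² output entries), and the fastest of `ntrials` runs gives the peak estimate, scaled by `n_batch` for batched multiplies. A minimal sketch of that same estimate (Python for illustration; the function name is mine, not part of the package):

```python
def peakflops_estimate(n: int, trial_times: list[float], n_batch: int = 1) -> float:
    """Estimate achieved FLOPS for an n x n matmul timed over several trials.

    An n x n matrix multiply performs ~2*n^3 floating-point operations;
    the best (minimum) trial time yields the peak estimate.
    """
    return n_batch * 2 * float(n) ** 3 / min(trial_times)

# Example: a 4096 x 4096 matmul whose best trial takes 25 ms -> ~5.5 TFLOPS
flops = peakflops_estimate(4096, [0.031, 0.025, 0.027])
```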
diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl
index 0ef1f2d0..f6efd2de 100644
--- a/lib/mps/MPS.jl
+++ b/lib/mps/MPS.jl
@@ -21,7 +21,7 @@ using BFloat16s
const MtlFloat = Union{Float32, Float16}
const MPSShape = NSArray#{NSNumber}
-Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple)))
+Base.convert(::Type{MPSShape}, tuple::Union{Vector{T}, NTuple{T, <:Integer}}) where {T} = NSArray(NSNumber.(collect(tuple)))
# Valid combination of input (A and B matrices) and output (C) types
const MPS_VALID_MATMUL_TYPES =
diff --git a/lib/mpsgraphs/MPSGraphs.jl b/lib/mpsgraphs/MPSGraphs.jl
index 2a06ae14..c6e0d373 100644
--- a/lib/mpsgraphs/MPSGraphs.jl
+++ b/lib/mpsgraphs/MPSGraphs.jl
@@ -20,23 +20,23 @@ using ObjectiveC, .Foundation, .Dispatch
# The commented type combinations work but are slower than with MPSMatrixMultiplication
const MPSGRAPH_VALID_MATMUL_TYPES =
[
- (Int8, Float16),
- (Int8, Float32),
- (Int16, Float32),
- (Float16, Float16),
- (Float16, Float32),
- (Float32, Float32),
- ]
+ (Int8, Float16),
+ (Int8, Float32),
+ (Int16, Float32),
+ (Float16, Float16),
+ (Float16, Float32),
+ (Float32, Float32),
+]
const MPSGRAPH_VALID_MATVECMUL_TYPES =
[
- (Int8, Float16),
- (Int8, Float32),
- (Int16, Float32),
- (Float16, Float16),
- (Float16, Float32),
- (Float32, Float32),
- ]
+ (Int8, Float16),
+ (Int8, Float32),
+ (Int16, Float32),
+ (Float16, Float16),
+ (Float16, Float32),
+ (Float32, Float32),
+]
include("libmpsgraph.jl")
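The type tables above pair an input element type with the type used for accumulation/output, and dispatch elsewhere in the PR checks membership in these tables. A sketch of that check in Python (strings stand in for the Julia type objects; the helper name is mine):

```python
# Mirrors the (input, accumulate/output) pairs listed above.
MPSGRAPH_VALID_MATMUL_TYPES = {
    ("Int8", "Float16"),
    ("Int8", "Float32"),
    ("Int16", "Float32"),
    ("Float16", "Float16"),
    ("Float16", "Float32"),
    ("Float32", "Float32"),
}

def supports_graph_matmul(in_type: str, out_type: str) -> bool:
    """True when the (input, output) element-type pair is in the table."""
    return (in_type, out_type) in MPSGRAPH_VALID_MATMUL_TYPES

assert supports_graph_matmul("Float16", "Float32")
assert not supports_graph_matmul("Float32", "Float16")
```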
diff --git a/lib/mpsgraphs/core.jl b/lib/mpsgraphs/core.jl
index 2f2e868e..f71051d6 100644
--- a/lib/mpsgraphs/core.jl
+++ b/lib/mpsgraphs/core.jl
@@ -6,7 +6,7 @@
# @objcwrapper MPSGraph <: MPSGraphObject
function MPSGraph()
- MPSGraph(@objc [MPSGraph new]::id{MPSGraph})
+ return MPSGraph(@objc [MPSGraph new]::id{MPSGraph})
end
# @objcwrapper immutable=true MPSGraphShapedType <: MPSGraphType
@@ -26,17 +26,17 @@ end
function MPSGraphDevice(device::MTLDevice)
obj = @objc [MPSGraphDevice deviceWithMTLDevice:device::id{MTLDevice}]::id{MPSGraphDevice}
- MPSGraphDevice(obj)
+ return MPSGraphDevice(obj)
end
# @objcwrapper MPSGraphExecutionDescriptor <: MPSGraphObject
function MPSGraphExecutionDescriptor()
- MPSGraphExecutionDescriptor(@objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor})
+ return MPSGraphExecutionDescriptor(@objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor})
end
# @objcwrapper MPSGraphCompilationDescriptor <: MPSGraphObject
function MPSGraphCompilationDescriptor()
- MPSGraphCompilationDescriptor(@objc [MPSGraphCompilationDescriptor new]::id{MPSGraphCompilationDescriptor})
+ return MPSGraphCompilationDescriptor(@objc [MPSGraphCompilationDescriptor new]::id{MPSGraphCompilationDescriptor})
end
diff --git a/lib/mpsgraphs/execution.jl b/lib/mpsgraphs/execution.jl
index 9eaf82df..791e2cc7 100644
--- a/lib/mpsgraphs/execution.jl
+++ b/lib/mpsgraphs/execution.jl
@@ -1,34 +1,42 @@
MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor())
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor)
- @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
- feeds:feeds::id{MPSGraphTensorDataDictionary}
- targetOperations:targetOperations::id{Object}
- resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary}
- executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::Nothing
+ @objc [
+ graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
+ feeds:feeds::id{MPSGraphTensorDataDictionary}
+ targetOperations:targetOperations::id{Object}
+ resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary}
+ executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}
+ ]::Nothing
return resultsDictionary
end
-function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor())
- obj = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
- feeds:feeds::id{MPSGraphTensorDataDictionary}
- targetTensors:targetTensors::id{NSArray}
- targetOperations:targetOperations::id{Object}
- executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::id{MPSGraphTensorDataDictionary}
- MPSGraphTensorDataDictionary(obj)
+function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations = nil, executionDescriptor = MPSGraphExecutionDescriptor())
+ obj = @objc [
+ graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
+ feeds:feeds::id{MPSGraphTensorDataDictionary}
+ targetTensors:targetTensors::id{NSArray}
+ targetOperations:targetOperations::id{Object}
+ executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}
+ ]::id{MPSGraphTensorDataDictionary}
+ return MPSGraphTensorDataDictionary(obj)
end
-function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil)
- obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
- targetTensors:targetTensors::id{NSArray}
- targetOperations:targetOperations::id{Object}]::id{MPSGraphTensorDataDictionary}
- MPSGraphTensorDataDictionary(obj)
+function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations = nil)
+ obj = @objc [
+ graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
+ targetTensors:targetTensors::id{NSArray}
+ targetOperations:targetOperations::id{Object}
+ ]::id{MPSGraphTensorDataDictionary}
+ return MPSGraphTensorDataDictionary(obj)
end
function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
- obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
- feeds:feeds::id{MPSGraphTensorDataDictionary}
- targetTensors:targetTensors::id{NSArray}
- targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
- MPSGraphTensorDataDictionary(obj)
+ obj = @objc [
+ graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
+ feeds:feeds::id{MPSGraphTensorDataDictionary}
+ targetTensors:targetTensors::id{NSArray}
+ targetOperations:nil::id{Object}
+ ]::id{MPSGraphTensorDataDictionary}
+ return MPSGraphTensorDataDictionary(obj)
end
diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl
index 0cd3fe08..ef5c3e30 100644
--- a/lib/mpsgraphs/matmul.jl
+++ b/lib/mpsgraphs/matmul.jl
@@ -30,9 +30,11 @@ else
end
-@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
- alpha::Number, beta::Number,
- transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
+@autoreleasepool function _matmul!(
+ c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb},
+ alpha::Number, beta::Number,
+ transpose_a, transpose_b
+ ) where {Tc, Tab, Na, Nb}
graph = MPSGraph()
placeA = placeholderTensor(graph, size(a), Tab)
@@ -50,8 +52,8 @@ end
castA = castTensor(graph, placeA, castT, "castA")
castB = castTensor(graph, placeB, castT, "castB")
- transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA
- transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB
+ transA = transpose_a ? transposeTensor(graph, castA, Na - 2, Na - 1, "transpose_a") : castA
+ transB = transpose_b ? transposeTensor(graph, castB, Nb - 2, Nb - 1, "transpose_b") : castB
nBatchA = Na == 2 ? 1 : size(transA)[1]
nBatchB = Nb == 2 ? 1 : size(transB)[1]
@@ -94,9 +96,9 @@ end
end
function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N}
- _matmul!(c, a, b, alpha, beta, transpose_a, transpose_b)
+ return _matmul!(c, a, b, alpha, beta, transpose_a, transpose_b)
end
function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab}
- _matmul!(c, a, b, alpha, beta, transpose, false)
+ return _matmul!(c, a, b, alpha, beta, transpose, false)
end
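The `graph_matmul!` and `graph_matvecmul!` entry points above implement the usual GEMM update C = alpha·A·B + beta·C, with inputs optionally cast to a wider accumulation type before multiplying. A pure-Python reference for that semantics, which is also what the PR's tests compute as ground truth (names here are illustrative, not part of the package):

```python
def gemm_reference(a, b, c, alpha=1.0, beta=0.0):
    """Compute alpha*A*B + beta*C for matrices given as lists of rows."""
    m, k = len(a), len(a[0])
    n = len(b[0])
    assert len(b) == k and len(c) == m and len(c[0]) == n
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            # Accumulate the dot product in float, then apply alpha/beta.
            acc = sum(a[i][p] * b[p][j] for p in range(k))
            out[i][j] = alpha * acc + beta * c[i][j]
    return out

# Identity times B with beta = 0 returns B unchanged.
b = [[1.0, 2.0], [3.0, 4.0]]
ident = [[1.0, 0.0], [0.0, 1.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
assert gemm_reference(ident, b, zeros) == b
```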
diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl
index 107c9ae3..47631733 100644
--- a/lib/mpsgraphs/operations.jl
+++ b/lib/mpsgraphs/operations.jl
@@ -1,68 +1,88 @@
-function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name="broadcast")
- obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
- toShape:shape::id{MPSShape}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name = "broadcast")
+ obj = @objc [
+ graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
+ toShape:shape::id{MPSShape}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
-function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name="broadcast")
- obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
- toShapeTensor:shapeTensor::id{MPSGraphTensor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name = "broadcast")
+ obj = @objc [
+ graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
+ toShapeTensor:shapeTensor::id{MPSGraphTensor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name = "cast")
- obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
- toType:toType::MPSDataType
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
+ toType:toType::MPSDataType
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function constantWithScalar(graph::MPSGraph, scalar::Number, dataType)
- obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64
- dataType:dataType::MPSDataType]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} constantWithScalar:scalar::Float64
+ dataType:dataType::MPSDataType
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul")
- obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
- secondaryTensor:secondary::id{MPSGraphTensor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+ secondaryTensor:secondary::id{MPSGraphTensor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function multiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "mul")
- obj = @objc [graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
- secondaryTensor:secondary::id{MPSGraphTensor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
+ secondaryTensor:secondary::id{MPSGraphTensor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function additionWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "add")
- obj = @objc [graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
- secondaryTensor:secondary::id{MPSGraphTensor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor}
+ secondaryTensor:secondary::id{MPSGraphTensor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, withDimension, name = "transpose")
- obj = @objc [graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
- dimension:dimension::NSUInteger
- withDimension:withDimension::NSUInteger
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor}
+ dimension:dimension::NSUInteger
+ withDimension:withDimension::NSUInteger
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function shapeOfTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "shapeOfTensor")
- obj = @objc [graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity")
- obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
- name:name::id{NSString}]::id{MPSGraphTensor}
- MPSGraphTensor(obj)
+ obj = @objc [
+ graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
+ return MPSGraphTensor(obj)
end
"""
diff --git a/lib/mpsgraphs/random.jl b/lib/mpsgraphs/random.jl
index b984dcae..026fe8c8 100644
--- a/lib/mpsgraphs/random.jl
+++ b/lib/mpsgraphs/random.jl
@@ -1,8 +1,10 @@
# @objcwrapper immutable=false MPSGraphRandomOpDescriptor <: MPSGraphObject
function MPSGraphRandomOpDescriptor(distribution::MPSGraphRandomDistribution, dataType)
- desc = @objc [MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution
- dataType:dataType::MPSDataType]::id{MPSGraphRandomOpDescriptor}
+ desc = @objc [
+ MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution
+ dataType:dataType::MPSDataType
+ ]::id{MPSGraphRandomOpDescriptor}
obj = MPSGraphRandomOpDescriptor(desc)
return obj
end
diff --git a/lib/mpsgraphs/tensor.jl b/lib/mpsgraphs/tensor.jl
index 67518f5b..d09c87fe 100644
--- a/lib/mpsgraphs/tensor.jl
+++ b/lib/mpsgraphs/tensor.jl
@@ -10,7 +10,7 @@ function Base.size(td::MPSGraphTensor)
temp = map(td.shape) do nsnum
NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int
end
- Tuple(temp)
+ return Tuple(temp)
end
function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...)
@@ -18,9 +18,11 @@ function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...
return placeholderTensor(graph, mpsshape, args...)
end
function placeholderTensor(graph::MPSGraph, shape::MPSShape, dataType::Type, name = "placeholder tensor")
- obj = @objc [graph::id{MPSGraph} placeholderWithShape:shape::id{MPSShape}
- dataType:dataType::MPSDataType
- name:name::id{NSString}]::id{MPSGraphTensor}
+ obj = @objc [
+ graph::id{MPSGraph} placeholderWithShape:shape::id{MPSShape}
+ dataType:dataType::MPSDataType
+ name:name::id{NSString}
+ ]::id{MPSGraphTensor}
return MPSGraphTensor(obj)
end
@@ -31,29 +33,33 @@ function Base.size(td::MPSGraphTensorData)
temp = map(td.shape) do nsnum
NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int
end
- Tuple(temp)
+ return Tuple(temp)
end
function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType)
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
tensor = MPSGraphTensorData(obj)
finalizer(release, tensor)
- @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
- shape:shape::id{MPSShape}
- dataType:dataType::MPSDataType]::id{MPSGraphTensorData}
+ @objc [
+ tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
+ shape:shape::id{MPSShape}
+ dataType:dataType::MPSDataType
+ ]::id{MPSGraphTensorData}
return tensor
end
function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType, rowBytes)
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
tensor = MPSGraphTensorData(obj)
finalizer(release, tensor)
- @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
- shape:shape::id{MPSShape}
- dataType:dataType::MPSDataType
- rowBytes:rowBytes::NSUInteger]::id{MPSGraphTensorData}
+ @objc [
+ tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
+ shape:shape::id{MPSShape}
+ dataType:dataType::MPSDataType
+ rowBytes:rowBytes::NSUInteger
+ ]::id{MPSGraphTensorData}
return tensor
end
-MPSGraphTensorData(matrix::MtlArray{T}) where T = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T)
+MPSGraphTensorData(matrix::MtlArray{T}) where {T} = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T)
function MPSGraphTensorData(matrix::MPSMatrix)
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
@@ -69,8 +75,10 @@ function MPSGraphTensorData(matrix::MPSMatrix, rank)
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
tensor = MPSGraphTensorData(obj)
finalizer(release, tensor)
- @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}
- rank:rank::NSUInteger]::id{MPSGraphTensorData}
+ @objc [
+ tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}
+ rank:rank::NSUInteger
+ ]::id{MPSGraphTensorData}
return tensor
end
@@ -88,8 +96,10 @@ function MPSGraphTensorData(vector::MPSVector, rank)
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
tensor = MPSGraphTensorData(obj)
finalizer(release, tensor)
- @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:vector::id{MPSVector}
- rank:rank::NSUInteger]::id{MPSGraphTensorData}
+ @objc [
+ tensor::id{MPSGraphTensorData} initWithMPSMatrix:vector::id{MPSVector}
+ rank:rank::NSUInteger
+ ]::id{MPSGraphTensorData}
return tensor
end
@@ -118,5 +128,5 @@ Will copy contents if the contents are not stored in an MPS ndarray.
"""
function MPS.MPSNDArray(tensor::MPSGraphTensorData)
arr = @objc [tensor::id{MPSNDArray} mpsndarray]::id{MPSNDArray}
- MPSNDArray(arr)
+ return MPSNDArray(arr)
end
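One detail worth noting in the `MPSGraphTensorData(matrix::MtlArray{T})` constructor above: the array's size is passed through `reverse(...)` before conversion to an `MPSShape`. My reading is that this translates Julia's column-major dimension order into the row-major shape convention MPSGraph expects for the same underlying buffer; a trivial sketch of the conversion (helper name is mine):

```python
def julia_shape_to_mps_shape(shape: tuple) -> tuple:
    """Reverse the dimension order when describing a column-major (Julia)
    buffer to a row-major consumer, so both see the same memory layout."""
    return tuple(reversed(shape))

assert julia_shape_to_mps_shape((3, 4)) == (4, 3)
```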
diff --git a/src/linalg.jl b/src/linalg.jl
index 0a6ac707..2f2fa00a 100644
--- a/src/linalg.jl
+++ b/src/linalg.jl
@@ -3,16 +3,16 @@ using LinearAlgebra: MulAddMul, wrap
using .MPS
using .MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat
using .MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
- graph_matmul!, graph_matvecmul!
+ graph_matmul!, graph_matvecmul!
@inline function supports_mps_matmul(A, B, C, valid_types)
- MPS.is_supported(device(A)) &&
+ return MPS.is_supported(device(A)) &&
eltype(A) == eltype(B) &&
(eltype(A), eltype(C)) in valid_types
end
@inline function supports_mpsgraph_matmul(A, B, C, valid_types)
- MPS.is_supported(device(A)) &&
+ return MPS.is_supported(device(A)) &&
eltype(A) == eltype(B) &&
(eltype(A), eltype(C)) in valid_types &&
# TODO: remove this limitation
diff --git a/test/examples.jl b/test/examples.jl
index d0f9017d..77fc65d6 100644
--- a/test/examples.jl
+++ b/test/examples.jl
@@ -19,7 +19,7 @@ cd(examples_dir) do
@testset for example in examples
mod = @eval module $(gensym()) end
@eval mod begin
- const TESTING=true
+ const TESTING = true
redirect_stdout(devnull) do
include($example)
end
diff --git a/test/linalg.jl b/test/linalg.jl
index 5d05d0dd..8f703169 100644
--- a/test/linalg.jl
+++ b/test/linalg.jl
@@ -25,7 +25,7 @@ if MPS.is_supported(device())
mtl_view_c = mtl_view_a * mtl_view_b
view_c = view_a * view_b
- @test Array(mtl_view_c) ≈ view_c
+ @test Array(mtl_view_c) ≈ view_c
end
using Metal: storagemode
diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl
index 90c6de9c..1d7c73f7 100644
--- a/test/mps/linalg.jl
+++ b/test/mps/linalg.jl
@@ -34,19 +34,19 @@ if MPS.is_supported(device())
end
@testset "batched matrix matrix multiplication" begin
- M = 8
- N = 7
- P = 9
+ M = 8
+ N = 7
+ P = 9
batch_size = 3
- rows_a = M
+ rows_a = M
cols_a = N
rows_b = N
- cols_b = P
+ cols_b = P
- rows_c = M
- cols_c = P
+ rows_c = M
+ cols_c = P
alpha = Float64(1)
beta = Float64(1)
diff --git a/test/mpsgraphs/core.jl b/test/mpsgraphs/core.jl
index ce716190..d538570f 100644
--- a/test/mpsgraphs/core.jl
+++ b/test/mpsgraphs/core.jl
@@ -1,19 +1,19 @@
if MPS.is_supported(device())
-using .MPS: MPSShape
-using .MPSGraphs: MPSGraph, MPSGraphDevice
-@testset "Core" begin
+ using .MPS: MPSShape
+ using .MPSGraphs: MPSGraph, MPSGraphDevice
+ @testset "Core" begin
-graph = MPSGraph()
-@test graph isa MPSGraph
+ graph = MPSGraph()
+ @test graph isa MPSGraph
-dev = device()
-graphdev = MPSGraphDevice(dev)
-@test graphdev isa MPSGraphDevice
-@test graphdev.type == MPSGraphs.MPSGraphDeviceTypeMetal
-@test graphdev.metalDevice == dev
+ dev = device()
+ graphdev = MPSGraphDevice(dev)
+ @test graphdev isa MPSGraphDevice
+ @test graphdev.type == MPSGraphs.MPSGraphDeviceTypeMetal
+ @test graphdev.metalDevice == dev
-end # @testset "Core"
+ end # @testset "Core"
end # MPS.is_supported(device())
diff --git a/test/mpsgraphs/linalg.jl b/test/mpsgraphs/linalg.jl
index 62438980..8711eb94 100644
--- a/test/mpsgraphs/linalg.jl
+++ b/test/mpsgraphs/linalg.jl
@@ -3,98 +3,98 @@ using LinearAlgebra
if MPS.is_supported(device())
-@testset "mixed-precision matrix matrix multiplication" begin
- N = 10
- rows_a = N
- cols_a = N
+ @testset "mixed-precision matrix matrix multiplication" begin
+ N = 10
+ rows_a = N
+ cols_a = N
- rows_b = N
- cols_b = N
+ rows_b = N
+ cols_b = N
- rows_c = rows_a
- cols_c = cols_b
+ rows_c = rows_a
+ cols_c = cols_b
- alpha = Float64(1)
- beta = Float64(1)
+ alpha = Float64(1)
+ beta = Float64(1)
- @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
- arr_a = rand(input_jl_type, (rows_a, cols_a))
- arr_b = rand(input_jl_type, (rows_b, cols_b))
- arr_c = zeros(accum_jl_type, (rows_c, cols_c))
+ @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+ arr_a = rand(input_jl_type, (rows_a, cols_a))
+ arr_b = rand(input_jl_type, (rows_b, cols_b))
+ arr_c = zeros(accum_jl_type, (rows_c, cols_c))
- buf_a = MtlArray{input_jl_type}(arr_a)
- buf_b = MtlArray{input_jl_type}(arr_b)
- buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c))
+ buf_a = MtlArray{input_jl_type}(arr_a)
+ buf_b = MtlArray{input_jl_type}(arr_b)
+ buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c))
- truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
+ truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c)
- MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+ MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
- @test all(Array(buf_c) .≈ truth_c)
+ @test all(Array(buf_c) .≈ truth_c)
+ end
end
-end
-@testset "batched matrix matrix multiplication" begin
- M = 8
- N = 7
- P = 9
- batch_size = 3
+ @testset "batched matrix matrix multiplication" begin
+ M = 8
+ N = 7
+ P = 9
+ batch_size = 3
- rows_a = M
- cols_a = N
+ rows_a = M
+ cols_a = N
- rows_b = N
- cols_b = P
+ rows_b = N
+ cols_b = P
- rows_c = M
- cols_c = P
+ rows_c = M
+ cols_c = P
- alpha = Float64(1)
- beta = Float64(1)
+ alpha = Float64(1)
+ beta = Float64(1)
- @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
- arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size))
- arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size))
- arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
+ @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES
+ arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size))
+ arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size))
+ arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size))
- buf_a = MtlArray{input_jl_type}(arr_a)
- buf_b = MtlArray{input_jl_type}(arr_b)
- buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+ buf_a = MtlArray{input_jl_type}(arr_a)
+ buf_b = MtlArray{input_jl_type}(arr_b)
+ buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
- truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
- for i in 1:batch_size
- @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
- end
+ truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size))
+ for i in 1:batch_size
+ @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i])
+ end
- MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
+ MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta)
- @test all(Array(buf_c) .≈ truth_c)
+ @test all(Array(buf_c) .≈ truth_c)
+ end
end
-end
-@testset "mixed-precision matrix vector multiplication" begin
- N = 10
- rows = N
- cols = N
+ @testset "mixed-precision matrix vector multiplication" begin
+ N = 10
+ rows = N
+ cols = N
- alpha = Float64(1)
- beta = Float64(0)
+ alpha = Float64(1)
+ beta = Float64(0)
- @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES
- arr_a = rand(input_jl_type, (rows, cols))
- arr_b = rand(input_jl_type, (rows))
- arr_c = zeros(accum_jl_type, (rows))
+ @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES
+ arr_a = rand(input_jl_type, (rows, cols))
+ arr_b = rand(input_jl_type, (rows))
+ arr_c = zeros(accum_jl_type, (rows))
- buf_a = MtlArray{input_jl_type}(arr_a)
- buf_b = MtlArray{input_jl_type}(arr_b)
- buf_c = MtlArray{accum_jl_type}(undef, (rows))
+ buf_a = MtlArray{input_jl_type}(arr_a)
+ buf_b = MtlArray{input_jl_type}(arr_b)
+ buf_c = MtlArray{accum_jl_type}(undef, (rows))
- truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c)
+ truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c)
- MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
+ MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta)
- @test all(Array(buf_c) .≈ truth_c)
+ @test all(Array(buf_c) .≈ truth_c)
+ end
end
-end
end # MPS.is_supported(device())
diff --git a/test/mpsgraphs/random.jl b/test/mpsgraphs/random.jl
index 4303ee83..35584c6f 100644
--- a/test/mpsgraphs/random.jl
+++ b/test/mpsgraphs/random.jl
@@ -2,24 +2,25 @@ using BFloat16s
if MPS.is_supported(device())
-using .MPSGraphs: MPSGraphRandomOpDescriptor, MPSGraphRandomDistributionNormal, MPSGraphRandomDistributionTruncatedNormal, MPSGraphRandomDistributionUniform
-@testset "MPSGraph random" begin
- # determined by looking at the error message when trying to construct
- # an invalid distribution/type combination
- for (dist, T) in [(MPSGraphRandomDistributionNormal, Float32),
- (MPSGraphRandomDistributionNormal, Float16),
- (MPSGraphRandomDistributionNormal, BFloat16),
- (MPSGraphRandomDistributionTruncatedNormal, Float32),
- (MPSGraphRandomDistributionTruncatedNormal, Float16),
- (MPSGraphRandomDistributionTruncatedNormal, BFloat16),
- (MPSGraphRandomDistributionUniform, Int64),
- (MPSGraphRandomDistributionUniform, Int32),
- (MPSGraphRandomDistributionUniform, Float32),
- (MPSGraphRandomDistributionUniform, Float16),
- (MPSGraphRandomDistributionUniform, BFloat16),
- ]
- @test MPSGraphRandomOpDescriptor(MPSGraphRandomDistributionNormal, Float32) isa MPSGraphRandomOpDescriptor
+ using .MPSGraphs: MPSGraphRandomOpDescriptor, MPSGraphRandomDistributionNormal, MPSGraphRandomDistributionTruncatedNormal, MPSGraphRandomDistributionUniform
+ @testset "MPSGraph random" begin
+ # determined by looking at the error message when trying to construct
+ # an invalid distribution/type combination
+ for (dist, T) in [
+ (MPSGraphRandomDistributionNormal, Float32),
+ (MPSGraphRandomDistributionNormal, Float16),
+ (MPSGraphRandomDistributionNormal, BFloat16),
+ (MPSGraphRandomDistributionTruncatedNormal, Float32),
+ (MPSGraphRandomDistributionTruncatedNormal, Float16),
+ (MPSGraphRandomDistributionTruncatedNormal, BFloat16),
+ (MPSGraphRandomDistributionUniform, Int64),
+ (MPSGraphRandomDistributionUniform, Int32),
+ (MPSGraphRandomDistributionUniform, Float32),
+ (MPSGraphRandomDistributionUniform, Float16),
+ (MPSGraphRandomDistributionUniform, BFloat16),
+ ]
+ @test MPSGraphRandomOpDescriptor(MPSGraphRandomDistributionNormal, Float32) isa MPSGraphRandomOpDescriptor
+ end
end
-end
end # MPS.is_supported(device())
Metal Benchmarks

| Benchmark suite | Current: faa8b6a | Previous: 369e292 | Ratio |
|---|---|---|---|
| private array/construct | 28361.083333333336 ns | 25757 ns | 1.10 |
| private array/broadcast | 464250 ns | 458125 ns | 1.01 |
| private array/random/randn/Float32 | 819958 ns | 761125 ns | 1.08 |
| private array/random/randn!/Float32 | 605083 ns | 612208 ns | 0.99 |
| private array/random/rand!/Int64 | 575541 ns | 558500 ns | 1.03 |
| private array/random/rand!/Float32 | 610645.5 ns | 592625 ns | 1.03 |
| private array/random/rand/Int64 | 784291 ns | 796000 ns | 0.99 |
| private array/random/rand/Float32 | 604979 ns | 623500 ns | 0.97 |
| private array/copyto!/gpu_to_gpu | 674791 ns | 655250 ns | 1.03 |
| private array/copyto!/cpu_to_gpu | 627479 ns | 804167 ns | 0.78 |
| private array/copyto!/gpu_to_cpu | 820479 ns | 804083 ns | 1.02 |
| private array/accumulate/1d | 1339541.5 ns | 1349500 ns | 0.99 |
| private array/accumulate/2d | 1403958 ns | 1395209 ns | 1.01 |
| private array/iteration/findall/int | 2105875 ns | 2038041.5 ns | 1.03 |
| private array/iteration/findall/bool | 1817125 ns | 1845313 ns | 0.98 |
| private array/iteration/findfirst/int | 1711000 ns | 1717625 ns | 1.00 |
| private array/iteration/findfirst/bool | 1663875 ns | 1673458.5 ns | 0.99 |
| private array/iteration/scalar | 3952354 ns | 3856833 ns | 1.02 |
| private array/iteration/logical | 3213437.5 ns | 3196542 ns | 1.01 |
| private array/iteration/findmin/1d | 1762166.5 ns | 1770187 ns | 1.00 |
| private array/iteration/findmin/2d | 1368917 ns | 1362333 ns | 1.00 |
| private array/reductions/reduce/1d | 1052500 ns | 1022584 ns | 1.03 |
| private array/reductions/reduce/2d | 662250 ns | 665709 ns | 0.99 |
| private array/reductions/mapreduce/1d | 1053792 ns | 1045166 ns | 1.01 |
| private array/reductions/mapreduce/2d | 666958 ns | 657625 ns | 1.01 |
| private array/permutedims/4d | 2494103.5 ns | 2547458.5 ns | 0.98 |
| private array/permutedims/2d | 1030875 ns | 1028750 ns | 1.00 |
| private array/permutedims/3d | 1580291 ns | 1582146 ns | 1.00 |
| private array/copy | 600708 ns | 621854 ns | 0.97 |
| latency/precompile | 9647484834 ns | 9111380625 ns | 1.06 |
| latency/ttfp | 3730773459 ns | 3718960708 ns | 1.00 |
| latency/import | 1254998396 ns | 1244311500 ns | 1.01 |
| integration/metaldevrt | 707042 ns | 714084 ns | 0.99 |
| integration/byval/slices=1 | 1560520.5 ns | 1562604 ns | 1.00 |
| integration/byval/slices=3 | 10160271 ns | 10396459 ns | 0.98 |
| integration/byval/reference | 1512917 ns | 1559083 ns | 0.97 |
| integration/byval/slices=2 | 2614916.5 ns | 2742208.5 ns | 0.95 |
| kernel/indexing | 463792 ns | 464021 ns | 1.00 |
| kernel/indexing_checked | 464020.5 ns | 461542 ns | 1.01 |
| kernel/launch | 9319.333333333334 ns | 8000 ns | 1.16 |
| metal/synchronization/stream | 14833 ns | 14959 ns | 0.99 |
| metal/synchronization/context | 15167 ns | 15208 ns | 1.00 |
| shared array/construct | 24565.916666666664 ns | 24638.833333333332 ns | 1.00 |
| shared array/broadcast | 462125 ns | 463854.5 ns | 1.00 |
| shared array/random/randn/Float32 | 835729.5 ns | 793500 ns | 1.05 |
| shared array/random/randn!/Float32 | 637479.5 ns | 631500 ns | 1.01 |
| shared array/random/rand!/Int64 | 587000 ns | 555208 ns | 1.06 |
| shared array/random/rand!/Float32 | 615042 ns | 585625 ns | 1.05 |
| shared array/random/rand/Int64 | 762292 ns | 791250 ns | 0.96 |
| shared array/random/rand/Float32 | 631520.5 ns | 608229 ns | 1.04 |
| shared array/copyto!/gpu_to_gpu | 84209 ns | 84459 ns | 1.00 |
| shared array/copyto!/cpu_to_gpu | 82875 ns | 89208.5 ns | 0.93 |
| shared array/copyto!/gpu_to_cpu | 83500 ns | 84084 ns | 0.99 |
| shared array/accumulate/1d | 1362584 ns | 1372541.5 ns | 0.99 |
| shared array/accumulate/2d | 1408521 ns | 1411292 ns | 1.00 |
| shared array/iteration/findall/int | 1844500 ns | 1806208 ns | 1.02 |
| shared array/iteration/findall/bool | 1584417 ns | 1616209 ns | 0.98 |
| shared array/iteration/findfirst/int | 1409437.5 ns | 1408145.5 ns | 1.00 |
| shared array/iteration/findfirst/bool | 1377000 ns | 1370770.5 ns | 1.00 |
| shared array/iteration/scalar | 157417 ns | 159000 ns | 0.99 |
| shared array/iteration/logical | 2993541.5 ns | 3018917 ns | 0.99 |
| shared array/iteration/findmin/1d | 1475208.5 ns | 1480458 ns | 1.00 |
| shared array/iteration/findmin/2d | 1371959 ns | 1385541 ns | 0.99 |
| shared array/reductions/reduce/1d | 743000 ns | 731625 ns | 1.02 |
| shared array/reductions/reduce/2d | 667000 ns | 669875 ns | 1.00 |
| shared array/reductions/mapreduce/1d | 748667 ns | 740125 ns | 1.01 |
| shared array/reductions/mapreduce/2d | 668354 ns | 668917 ns | 1.00 |
| shared array/permutedims/4d | 2529584 ns | 2498521 ns | 1.01 |
| shared array/permutedims/2d | 1032667 ns | 1037291 ns | 1.00 |
| shared array/permutedims/3d | 1611729 ns | 1601042 ns | 1.01 |
| shared array/copy | 244562.5 ns | 252291 ns | 0.97 |
This comment was automatically generated by workflow using github-action-benchmark.
FYI, you can have PRs target other PRs by setting the base branch, which should make it possible to review (only showing the relevant diff) while auto-updating to
I'd keep that issue open, given that there are MWEs using MPS directly. What's the performance of MPSGraph like compared to using MPS directly? No difference, I would assume?
Very nice about the PR-target trick; that'll be very useful for future PRs. I unlinked #381 from automatic closing. As for performance, I want to make some graphs, but from some quick benchmarks while testing, it seems slower for smaller matrices and comparable for larger ones. That's why I only made MPSGraphs the default for the types affected by #381.
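For anyone curious what "default only for the affected types" could look like, here is a minimal sketch of routing by element-type pair. The names `GRAPH_TYPES` and `choose_backend`, and the exact set of type pairs, are illustrative assumptions, not Metal.jl's actual dispatch:

```julia
# Hypothetical helper: pick MPSGraph only for type combinations where plain
# MPS matmul is broken (#381), and keep MPS elsewhere, since MPSGraph is
# slower for small matrices. The set below is an assumed placeholder.
const GRAPH_TYPES = Set([(Float16, Float16), (Float16, Float32)])

function choose_backend(inT::DataType, outT::DataType)
    return (inT, outT) in GRAPH_TYPES ? :mpsgraph : :mps
end

choose_backend(Float16, Float16)  # :mpsgraph under these assumptions
choose_backend(Float32, Float32)  # :mps
```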
Force-pushed 3b9f66a to 391bfb8 (compare)
Okay. I wrote a little script to compare MPS and MPSGraph matmul performance. It runs all the supported type variations and saves the results to a plot. The script is in dev/flopscomp.jl of this branch; it activates a temporary environment, installs this branch of Metal, runs the matmul comparison, and saves the results to a plot. To run it:

```julia
using Pkg
Pkg.activate(temp=true)
Pkg.add(url="https://github.com/christiangnrd/Metal.jl/", rev="MPSGraph")
Pkg.add(["GPUArrays", "Plots"])

using Metal
scriptpath = joinpath(dirname(pathof(Metal)), "..", "dev", "flopscomp.jl")
include(scriptpath)

# Save the figure `bench_all_1.svg` in the current working
# directory, or specify your own outpath
main(outpath="")
```
It'll download this branch, run the benchmarks, and spit out a figure. I ran it with

These results are a bit unfortunate? I don't know if we have a choice though, because the combinations where performance is worse are the cases affected by #381...
a4e8216 makes the operation actually in-place, while a86a55d makes every operation use Float32 under the hood. This makes memory usage quite a bit higher but greatly improves performance. I'd love opinions on whether we keep it that way. Batched multiplication where one array has only one matrix and the other has more is also now working. Latest performance on my M2 Max (compare with the first screenshot):
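A minimal CPU-side sketch of the accuracy/memory trade-off behind "Float32 under the hood" (plain Julia, not Metal.jl code): accumulating a Float16 product in Float32 needs wider intermediates but avoids the large rounding error of long Float16 dot products.

```julia
# Accumulate a matmul in a chosen type T; illustrative only.
function matmul_accum(::Type{T}, a, b) where {T}
    m, k = size(a); n = size(b, 2)
    c = Matrix{T}(undef, m, n)
    for j in 1:n, i in 1:m
        acc = zero(T)
        for l in 1:k
            acc += T(a[i, l]) * T(b[l, j])  # product and sum both rounded to T
        end
        c[i, j] = acc
    end
    return c
end

a = rand(Float16, 128, 128)
b = rand(Float16, 128, 128)
ref = matmul_accum(Float64, a, b)  # high-precision reference

err16 = maximum(abs.(matmul_accum(Float16, a, b) .- ref))  # narrow accumulation
err32 = maximum(abs.(matmul_accum(Float32, a, b) .- ref))  # wide accumulation
# err32 is typically orders of magnitude smaller than err16
```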
At the expense of accuracy? I'm surprised by the large performance differences, but given the issues with MPS matmul we don't really have a choice. That said, it's good that you kept the
I was trying to figure out why Float16 performance was worse with MPSGraph, and this one line added right after initializing the graph gave me some great insights:

```julia
function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, alpha::Number, beta::Number, transpose_a, transpose_b) where {Tc, Tab, Na, Nb}
    graph = MPSGraph()
    graph.options = MPSGraphs.MPSGraphOptionsVerbose # <- This line
    placeA = placeholderTensor(graph, size(a), Tab)
    placeB = placeholderTensor(graph, size(b), Tab)
    ...
```
It seems like when the output eltype is

Latest commit disables this by default for

I suspect the ANE may be better for other M-series processors with weaker GPUs but the same ANE as my 30-core M2 Max. I'll ask on Slack if people want to benchmark on their device and report performance here.
Thanks @haakon-e @sinhtrung! This is very helpful. I'm very glad to see the M4 Pro, as I was wondering whether its apparently much faster Neural Engine would make it faster than the GPU, but it seems the TOPS they're referring to may not be Float16 TOPS (NPUs typically seem to report Int8 OPS). It looks like we should leave the ANE unused, as at best it's only as good as the GPU.
The only code left running with MPS is the contiguous-views path.
Will unblock NNlib issue 614
If no one objects, I'll merge this tomorrow morning. I'm wondering if we should use MPS instead of MPSGraph in situations where it's faster and not broken. I'll open a PR for feedback soon.
This is ready for review once #565 is merged.
This adds the basic functionality for future extension of MPSGraph framework, and what's needed to replace the MPS matrix multiplication to work around #381.
Support for contiguous views (MtlArrays with an offset) should be added in a future PR; for now, they will use the old method.
Batched matmul is at the same level of support as the current code, but I think I can use `broadcastTensor` to enable batched matmul between 1 and N (or N and 1) matrices. That would potentially unblock FluxML/NNlib.jl#614, but it can also be a future PR.
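For reference, the 1-and-N batched semantics described above (which `broadcastTensor` would provide on the GPU side) look like this in plain Julia. This is a CPU sketch of the intended behavior, not the MPSGraph implementation:

```julia
# One matrix `a` multiplied against every slice of a batched array `b`,
# i.e. the "1 and N" batched matmul case.
a = rand(Float32, 4, 5)       # single matrix
b = rand(Float32, 5, 6, 3)    # batch of 3 matrices
c = Array{Float32}(undef, 4, 6, 3)
for i in axes(b, 3)
    @views c[:, :, i] = a * b[:, :, i]  # `a` is reused across the batch
end
size(c)  # (4, 6, 3)
```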