Skip to content

[MPS] Add index.Tensor and aten.logical_not #3267

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions backends/apple/mps/mps_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
MPSGraph,
MPSTensor,
OpType,
)

from executorch.backends.apple.mps.serialization.mps_graph_serialize import (
Expand Down Expand Up @@ -65,6 +66,7 @@ def preprocess(
input_ids=[],
output_ids=[],
constant_ids=[],
graph_type=OpType.mps_graph,
)

convert_model_to_fp16 = True
Expand Down Expand Up @@ -111,6 +113,16 @@ def handle_call_function(
mps_graph: MPSGraph,
) -> None:
logging.info(f"Visiting: {node}, {node.target.__name__}")

if (
"delegation_tag" in node.meta
and "metal_kernel" in node.meta["delegation_tag"]
):
logging.info(
f"Node '{node.target.__name__}' was marked as a Metal kernel by the MPSPartitioner!"
)
mps_graph.graph_type = OpType.metal_kernel

if node.target.__name__ in node_visitors:
node_visitors[node.target.__name__].define_node(node, mps_graph)
else:
Expand Down
77 changes: 76 additions & 1 deletion backends/apple/mps/operators/indexing_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Provided subject to the LICENSE file in the top level directory.
#

from typing import cast
from typing import cast, List

import torch
from executorch.backends.apple.mps.operators.node_visitor import (
Expand All @@ -13,9 +13,12 @@
from executorch.backends.apple.mps.serialization.mps_graph_schema import (
MPSEmbedding,
MPSGraph,
MPSIndexPut,
MPSIndexSelect,
MPSIndexTensor,
)
from executorch.backends.apple.mps.utils.mps_utils import get_input_node
from executorch.backends.transforms import get_shape
from executorch.exir.sym_util import eval_expr


Expand All @@ -40,6 +43,78 @@ def define_node(
mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class IndexTensorVisitor(NodeVisitor):
    """Serializes ``aten.index.Tensor`` (advanced indexing) nodes into the
    MPS graph as ``MPSIndexTensor`` ops."""

    target = "aten.index.Tensor"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        # The base (unary) node wires up the indexed input and the output;
        # each index tensor is serialized separately and referenced by id.
        mps_node = self.create_unary_node(node, mps_graph, MPSIndexTensor)
        index_nodes = cast(List[torch.fx.Node], node.args[1])
        for index_node in index_nodes:
            serialized_id = self.define_tensor(index_node, mps_graph)
            mps_node.mpsnode_union.indices_id.append(serialized_id)

        mps_graph.mps_nodes.append(mps_node)


# [MPS TODO]: Works on a single iteration of llama2, but subsequent tokens
# are wrong when using Index put. Disabling it for now.
@register_node_visitor
class IndexPutVisitor(NodeVisitor):
    """Serializes ``aten.index_put.default`` nodes into ``MPSIndexPut`` ops.

    Currently disabled via the placeholder ``target`` (see TODO above).
    """

    # target = "aten.index_put.default"
    target = "disabled"

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def infer_sizes(self, a: List[int], b: List[int]):
        """Right-align the two shapes and pick, per dimension, ``b``'s size
        when present, otherwise ``a``'s (broadcast-style size inference)."""
        ndim = max(len(a), len(b))
        expanded = [0] * ndim
        for offset in range(ndim):
            i = ndim - 1 - offset
            dim_a = len(a) - 1 - offset
            dim_b = len(b) - 1 - offset
            size_a = a[dim_a] if dim_a >= 0 else -1
            size_b = b[dim_b] if dim_b >= 0 else -1
            expanded[i] = size_a if size_b == -1 else size_b

        return expanded

    def define_node(
        self,
        node: torch.fx.Node,
        mps_graph: MPSGraph,
    ) -> None:
        mps_node = self.create_unary_node(node, mps_graph, MPSIndexPut)

        # If the updates tensor is neither 1-D nor the same rank as the
        # input, record the shape the values must be broadcast to.
        updates_shape = get_shape(node.args[2])
        input_shape = get_shape(node.args[0])
        broadcast_shape = []
        if len(updates_shape) not in (1, len(input_shape)):
            broadcast_shape = self.infer_sizes(input_shape, updates_shape)
        mps_node.mpsnode_union.values_shape = broadcast_shape

        for index_node in cast(List[torch.fx.Node], node.args[1]):
            mps_node.mpsnode_union.indices_id.append(
                self.define_tensor(index_node, mps_graph)
            )

        mps_node.mpsnode_union.values_id = self.define_tensor(
            get_input_node(node, 2), mps_graph
        )
        mps_graph.mps_nodes.append(mps_node)


@register_node_visitor
class EmbeddingVisitor(NodeVisitor):
target = "aten.embedding.default"
Expand Down
3 changes: 3 additions & 0 deletions backends/apple/mps/operators/unary_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
MPSLog,
MPSLog10,
MPSLog2,
MPSLogicalNot,
MPSNeg,
MPSReciprocal,
MPSRound,
Expand Down Expand Up @@ -79,6 +80,7 @@ class UnaryOpVisitor(NodeVisitor):
"aten.isnan.default",
"aten.isinf.default",
"aten.round.default",
"aten.logical_not.default",
]

def __init__(self, *args) -> None:
Expand Down Expand Up @@ -115,6 +117,7 @@ def __init__(self, *args) -> None:
exir_ops.edge.aten.isnan.default: MPSIsnan,
exir_ops.edge.aten.isinf.default: MPSIsinf,
exir_ops.edge.aten.round.default: MPSRound,
exir_ops.edge.aten.logical_not.default: MPSLogicalNot,
}

def define_node(
Expand Down
50 changes: 48 additions & 2 deletions backends/apple/mps/partition/mps_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
#

import logging
from typing import Any, Dict, List, Union
from typing import Any, cast, Dict, List, Union

import torch
from executorch.backends.apple.mps.mps_preprocess import MPSBackend
from executorch.backends.apple.mps.operators.node_visitor import get_node_visitors
from executorch.backends.apple.mps.utils.mps_utils import is_parameter
from executorch.backends.transforms import get_shape
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import (
generate_partitions_from_list_of_nodes,
Expand All @@ -20,6 +21,7 @@
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from executorch.exir.dialects._ops import ops as exir_ops
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import Partition
from torch.fx.passes.operator_support import OperatorSupportBase
Expand All @@ -28,6 +30,13 @@
logging.basicConfig(level=logging.DEBUG, format=FORMAT)


# ops implemented as Metal kernels.
METAL_KERNELS = [
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.index_put.default,
]


class MPSOperatorSupport(OperatorSupportBase):
def __init__(self, edge_program: torch.export.ExportedProgram, compiler_specs):
self.node_visitors = get_node_visitors(edge_program)
Expand Down Expand Up @@ -65,10 +74,47 @@ def generate_partitions(self, edge_program: ExportedProgram) -> List[Any]:
op_support=self.supported_ops,
)

def mps_graph_advanced_indexing_support(self, node: torch.fx.Node):
    """Return True when this advanced-indexing op can be dispatched to
    MPSGraph directly, False when it must fall back to a Metal kernel."""
    index_tensors = cast(List[torch.fx.Node], node.args[1])
    indexed_input = cast(torch.fx.Node, node.args[0])
    active_indices = sum(1 for t in index_tensors if t is not None)
    # MPSGraph can handle the case where every dimension of the sliced
    # tensor is indexed, or where only a single index tensor is present;
    # everything else goes to a Metal kernel.
    return active_indices == len(get_shape(indexed_input)) or active_indices == 1

def use_metal_kernel(self, node: torch.fx.Node) -> bool:
    """Return True when `node` must be lowered as a custom Metal kernel
    rather than expressed in MPSGraph.

    Fix: the original fell off the end (implicit ``None``) when
    ``node.target`` was not in ``METAL_KERNELS``; callers treat the result
    as a bool, so always return one explicitly.
    """
    if node.target in METAL_KERNELS:
        # Both current METAL_KERNELS entries are advanced-indexing ops: a
        # Metal kernel is needed only when MPSGraph can't express the
        # indexing pattern (see mps_graph_advanced_indexing_support).
        if (
            node.target == exir_ops.edge.aten.index.Tensor
            or node.target == exir_ops.edge.aten.index_put.default
        ):
            if not self.mps_graph_advanced_indexing_support(node):
                return True
    return False

def tag_nodes(self, partitions: List[Partition]) -> None:
    """Tag every node of every partition with a delegation tag.

    Metal-kernel ops are split into their own single-op partitions
    (``mps_<id>_metal_kernel_<n>``); all other nodes share the current
    MPSGraph partition tag (``mps_<id>_<n>``).

    Fix: the scraped span contained both the pre-change loop header
    (``for node in partition.nodes:``) and its replacement
    (``for node in sorted(partition.nodes):``) — diff artifact; only the
    post-change loop is kept.
    """
    for partition in partitions:
        crt_partition_counter = 0
        for node in sorted(partition.nodes):
            delegation_tag = f"mps_{partition.id}"
            if self.use_metal_kernel(node):
                logging.warning(f"[WARNING] Using Metal kernel for op {node.name}!")
                # Partition the Metal kernel into a separate partition
                crt_partition_counter += 1
                delegation_tag = (
                    f"{delegation_tag}_metal_kernel_{crt_partition_counter}"
                )
                # Bump again so the nodes after the kernel start a fresh
                # MPSGraph partition instead of sharing the kernel's tag.
                crt_partition_counter += 1
            else:
                delegation_tag = f"{delegation_tag}_{crt_partition_counter}"

            node.meta["delegation_tag"] = delegation_tag
            self.partition_tags[delegation_tag] = self.delegation_spec

Expand Down
23 changes: 23 additions & 0 deletions backends/apple/mps/runtime/MPSDevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,19 @@

#pragma once

// Obj-C headers
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>

// Runtime headers
#include <executorch/runtime/backend/interface.h>

// MPS headers
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>

#include <unordered_map>
#include <vector>

#define MB(x) (x * 1048576UL)

namespace torch {
Expand All @@ -25,6 +34,11 @@ enum class MacOSVersion : uint32_t {
MACOS_VER_14_0_PLUS,
};

// Runtime-compiled Metal shader libraries, used as keys into MPSDevice's
// library/PSO caches. MAX tracks the last valid entry.
enum class LibraryType : uint32_t {
  INDEXING_KERNELS = 0,
  MAX = INDEXING_KERNELS,
};

class MPSDevice {
public:
/**
Expand Down Expand Up @@ -53,9 +67,18 @@ class MPSDevice {

~MPSDevice();

/**
* Compile a PSO for a given library type.
* Once compiled, the library and PSOs are cached.
*/
Error compilePSO(LibraryType libraryType, const char* kernelName);
Error compileLibrary(LibraryType);

private:
static MPSDevice* _device;
id<MTLDevice> _mtl_device;
std::unordered_map<LibraryType, id<MTLLibrary>> _m_library_cache;
std::unordered_map<std::string, id<MTLComputePipelineState>> _m_pso_cache;
MPSDevice();
};

Expand Down
65 changes: 65 additions & 0 deletions backends/apple/mps/runtime/MPSDevice.mm
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@
static std::unique_ptr<MPSDevice> mps_device;
static std::once_flag mpsdev_init;

// Selects the Metal Shading Language version used when compiling runtime
// shader libraries. MPS advanced indexing needs at least Metal 2.0 (argument
// buffers / function constants); the host_name attribute needs 2.2 and ulong
// needs 2.3 (macOS 11+), so 2.3 is the baseline. Bumped to 3.0 only when both
// the SDK (__MAC_13_0) and the running OS (macOS13Plus) allow it.
static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device, bool macOS13Plus) {
  ET_CHECK_MSG([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
#if defined(__MAC_13_0)
  if (macOS13Plus) {
    return MTLLanguageVersion3_0;
  }
#endif
  return MTLLanguageVersion2_3;
}

MPSDevice::~MPSDevice() {
[_mtl_device release];
_mtl_device = nil;
Expand Down Expand Up @@ -79,6 +93,57 @@
}
}

// Returns the Metal shader source string for the given library type.
// NOTE(review): INDEXING_KERNELS still returns the "TODO" placeholder — the
// real indexing shader source has not been wired in yet, so compiling this
// library presumably fails until it is; confirm before relying on it.
const char* getLibraryCString(LibraryType libraryType) {
  switch (libraryType) {
    case LibraryType::INDEXING_KERNELS:
      return "TODO";
    default:
      // Unreachable for valid inputs; ET_CHECK_MSG(false, ...) traps here
      // (presumably does not return — confirm), so no value is produced.
      ET_CHECK_MSG(false, "Unhandled library type!");
  }
}

// Compiles the Metal source for `libraryType` (see getLibraryCString) into an
// MTLLibrary and caches it in _m_library_cache. Returns Error::Ok on success,
// Internal when the Metal compiler fails.
Error
MPSDevice::compileLibrary(LibraryType libraryType) {
  Error err = Error::Ok;
  NSError* error = nil;
  MTLCompileOptions* options = [MTLCompileOptions new];
  // Language version depends on the SDK and the running macOS version; see
  // getMetalLanguageVersion in this file.
  [options setLanguageVersion:getMetalLanguageVersion(_mtl_device, isMacOS13Plus(MacOSVersion::MACOS_VER_13_0_PLUS))];
  [options setFastMathEnabled:YES];
  // NOTE(review): the source is converted with NSASCIIStringEncoding — a
  // shader containing non-ASCII bytes would yield a nil NSString here;
  // confirm whether UTF-8 is needed.
  id<MTLLibrary> lib =
      [_mtl_device newLibraryWithSource:[NSString stringWithCString:getLibraryCString(libraryType)
                                                           encoding:NSASCIIStringEncoding]
                                options:options
                                  error:&error];

  // NOTE(review): message says "indexing library" although this path is
  // generic over libraryType; harmless today since INDEXING_KERNELS is the
  // only type.
  ET_CHECK_OR_RETURN_ERROR(
      lib != nil,
      Internal,
      "Failed to create indexing library, error: %s", [[error description] UTF8String]
  );

  _m_library_cache[libraryType] = lib;
  return err;
}

// Ensures the Metal library for `libraryType` is compiled and cached, then
// checks the PSO cache for `kernelName`. Returns Error::Ok on success,
// Internal when library compilation fails.
//
// Fixes: the failure message read "An error occured occured while compiling"
// (duplicated, misspelled word); and the scoped enum LibraryType was passed
// straight through varargs to %d — now cast to its underlying uint32_t and
// logged with %u.
Error
MPSDevice::compilePSO(LibraryType libraryType, const char* kernelName) {
  Error err = Error::Ok;
  if (_m_library_cache.find(libraryType) == _m_library_cache.end()) {
    ET_LOG(Debug, "Compiling library type: %u", static_cast<uint32_t>(libraryType));
    err = compileLibrary(libraryType);
    ET_CHECK_OR_RETURN_ERROR(
        err == Error::Ok,
        Internal,
        "An error occurred while compiling library %u", static_cast<uint32_t>(libraryType)
    );
  }
  if (_m_pso_cache.find(kernelName) == _m_pso_cache.end()) {
    ET_LOG(Debug, "Compiling kernel: %s", kernelName);
    // TODO(review): PSO creation is not implemented yet — the commented-out
    // self-call below would recurse rather than build a pipeline state, and
    // _m_pso_cache is never populated. Build the PSO via
    // newFunctionWithName:/newComputePipelineStateWithFunction:error: and
    // insert it into the cache.
    // err = compilePSO(libraryType, kernelName);
  }
  return err;
}

// Free-function convenience wrapper over the MPSDevice singleton's
// isMacOS13Plus version check.
bool isMacOS13OrNewer(MacOSVersion version) {
  return MPSDevice::getInstance()->isMacOS13Plus(version);
}
Expand Down
Loading