diff --git a/extension_cpp_cmake/.gitignore b/extension_cpp_cmake/.gitignore
new file mode 100644
index 0000000..b26c417
--- /dev/null
+++ b/extension_cpp_cmake/.gitignore
@@ -0,0 +1,12 @@
+__pycache__/
+.cache
+*.i
+*.ii
+*.gpu
+*.ptx
+*.cubin
+*.fatbin
+build/**
+extension_cpp/lib/**
+compile_commands.json
+benchmarks/data
diff --git a/extension_cpp_cmake/.pre-commit-config.yaml b/extension_cpp_cmake/.pre-commit-config.yaml
new file mode 100644
index 0000000..6b12549
--- /dev/null
+++ b/extension_cpp_cmake/.pre-commit-config.yaml
@@ -0,0 +1,26 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v3.2.0
+  hooks:
+  - id: trailing-whitespace
+  - id: end-of-file-fixer
+  - id: check-yaml
+  - id: check-added-large-files
+
+- repo: https://github.com/omnilib/ufmt
+  rev: v2.1.0
+  hooks:
+  - id: ufmt
+    additional_dependencies:
+      - black == 23.3.0
+      - usort == 1.0.6
+      - ufmt == 2.1.0
+      - libcst == 1.0.1
+
+# clang-format hook disabled for now (missing host field?)
+# - repo: https://github.com/pre-commit/mirrors-clang-format
+#   rev: v17.0.5
+#   hooks:
+#   - id: clang-format
diff --git a/extension_cpp_cmake/CMakeLists.txt b/extension_cpp_cmake/CMakeLists.txt
new file mode 100644
index 0000000..ed4876f
--- /dev/null
+++ b/extension_cpp_cmake/CMakeLists.txt
@@ -0,0 +1,54 @@
+cmake_minimum_required(VERSION 3.26 FATAL_ERROR)
+
+project(
+  ${SKBUILD_PROJECT_NAME}
+  VERSION ${SKBUILD_PROJECT_VERSION}
+  LANGUAGES CXX CUDA)
+
+# Set the C++ standard for all targets
+set(CMAKE_CXX_STANDARD 20) # Note: this might be unsafe, since PyTorch itself is built with C++17
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+# Enable better clangd support
+set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
+
+find_package(Python REQUIRED COMPONENTS Interpreter Development)
+execute_process(
+  COMMAND "${Python_EXECUTABLE}" "-c" "import torch;print(torch.utils.cmake_prefix_path)"
+  OUTPUT_VARIABLE PT_CMAKE_PREFIX
+  COMMAND_ECHO STDOUT
+  OUTPUT_STRIP_TRAILING_WHITESPACE
+  COMMAND_ERROR_IS_FATAL ANY
+)
+
+# Cache CMAKE_CUDA_ARCHITECTURES, which Torch's CMake config seems to reset
+set(TMP_STORE_CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}")
+set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH};${PT_CMAKE_PREFIX})
+
+find_package(Torch REQUIRED CONFIG)
+set(CMAKE_CUDA_ARCHITECTURES "${TMP_STORE_CUDA_ARCHITECTURES}") # restore the cached value
+
+# Extension source files
+file(GLOB_RECURSE CU_SOURCES csrc/*.cu)
+file(GLOB_RECURSE CPP_SOURCES csrc/*.cpp)
+message(STATUS "CU_SOURCES: ${CU_SOURCES}")
+message(STATUS "CPP_SOURCES: ${CPP_SOURCES}")
+
+add_library(${SKBUILD_PROJECT_NAME} SHARED
+  ${CU_SOURCES}
+  ${CPP_SOURCES}
+)
+
+# Set the library output directory; I think this is what makes editable Ninja builds
+# pick the library up in-tree
+set_target_properties(${SKBUILD_PROJECT_NAME} PROPERTIES
+  LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/${SKBUILD_PROJECT_NAME}/lib"
+)
+# Add include directories to the library
+target_include_directories(${SKBUILD_PROJECT_NAME} PUBLIC csrc/include)
+
+# Define TORCH_EXTENSION_NAME so the (empty) PYBIND11_MODULE in csrc/cpp/lltm.cpp
+# gets a real module name
+target_compile_definitions(${SKBUILD_PROJECT_NAME} PRIVATE TORCH_EXTENSION_NAME=${SKBUILD_PROJECT_NAME})
+
+# Link the library against Torch and Python
+target_link_libraries(${SKBUILD_PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES} Python::Python)
+
+# Install the library into the wheel distribution
+install(TARGETS ${SKBUILD_PROJECT_NAME}
+  LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}/lib
+)
diff --git a/extension_cpp_cmake/README.md b/extension_cpp_cmake/README.md
new file mode 100644
index 0000000..d7a3bd2
--- /dev/null
+++ b/extension_cpp_cmake/README.md
@@ -0,0 +1,11 @@
+# Extension Template
+
+This is a template for creating a new PyTorch extension written in C++ or CUDA.
+It contains the basic structure and files needed to build, install, and test the extension.
+
+### Build and install the extension
+
+```bash
+pip install -v --no-build-isolation -e .
+```
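+
+### Quick usage check
+
+A minimal sketch of calling the op from Python after installation; the shapes
+mirror `reference_lltm` in `extension_cpp/__init__.py`, and the concrete sizes
+here are arbitrary.
+
+```python
+import torch
+
+import extension_cpp  # loads lib/libextension_cpp.so and registers the ops
+
+batch, features, state = 4, 32, 16
+X = torch.randn(batch, features)
+W = torch.randn(3 * state, features + state)  # one weight matrix for all three gates
+b = torch.randn(1, 3 * state)
+h = torch.randn(batch, state)
+C = torch.randn(batch, state)
+
+new_h, new_C = extension_cpp.lltm(X, W, b, h, C)
+print(new_h.shape, new_C.shape)  # both (batch, state)
+```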
diff --git a/extension_cpp_cmake/csrc/cpp/lltm.cpp b/extension_cpp_cmake/csrc/cpp/lltm.cpp
new file mode 100644
index 0000000..fa48cae
--- /dev/null
+++ b/extension_cpp_cmake/csrc/cpp/lltm.cpp
@@ -0,0 +1,100 @@
+#include <torch/extension.h>
+
+#include <tuple>
+
+// s'(z) = (1 - s(z)) * s(z)
+torch::Tensor d_sigmoid(torch::Tensor z) {
+  auto s = torch::sigmoid(z);
+  return (1 - s) * s;
+}
+
+// tanh'(z) = 1 - tanh^2(z)
+torch::Tensor d_tanh(torch::Tensor z) {
+  return 1 - z.tanh().pow(2);
+}
+
+// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0 }
+torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
+  auto e = z.exp();
+  auto mask = (alpha * (e - 1)) < 0;
+  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+           torch::Tensor, torch::Tensor, torch::Tensor>
+lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
+  auto gates = gate_weights.chunk(3, /*dim=*/1);
+
+  auto input_gate = torch::sigmoid(gates[0]);
+  auto output_gate = torch::sigmoid(gates[1]);
+  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
+
+  auto new_cell = old_cell + candidate_cell * input_gate;
+  auto new_h = torch::tanh(new_cell) * output_gate;
+
+  return {new_h,
+          new_cell,
+          input_gate,
+          output_gate,
+          candidate_cell,
+          X,
+          gate_weights};
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+           torch::Tensor>
+lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_output_gate = torch::tanh(new_cell) * grad_h;
+  auto d_tanh_new_cell = output_gate * grad_h;
+  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
+
+  auto d_old_cell = d_new_cell;
+  auto d_candidate_cell = input_gate * d_new_cell;
+  auto d_input_gate = candidate_cell * d_new_cell;
+
+  auto gates = gate_weights.chunk(3, /*dim=*/1);
+  d_input_gate *= d_sigmoid(gates[0]);
+  d_output_gate *= d_sigmoid(gates[1]);
+  d_candidate_cell *= d_elu(gates[2]);
+
+  auto d_gates =
+      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
+
+  auto d_weights = d_gates.t().mm(X);
+  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
+
+  auto d_X = d_gates.mm(weights);
+  const auto state_size = grad_h.size(1);
+  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
+  auto d_input = d_X.slice(/*dim=*/1, state_size);
+
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
+}
+
+// Registers an (empty) pybind11 extension module; the operators themselves are
+// registered with TORCH_LIBRARY below, so nothing needs to be bound here.
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
+
+// Defines the operator schemas
+TORCH_LIBRARY(extension_cpp, m) {
+  m.def("lltm_forward(Tensor input, Tensor weights, Tensor bias, Tensor old_h, Tensor old_cell) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)");
+  m.def("lltm_backward(Tensor grad_h, Tensor grad_cell, Tensor new_cell, Tensor input_gate, Tensor output_gate, Tensor candidate_cell, Tensor X, Tensor gate_weights, Tensor weights) -> (Tensor, Tensor, Tensor, Tensor, Tensor)");
+}
+
+// Registers CPU implementations for lltm_forward, lltm_backward
+TORCH_LIBRARY_IMPL(extension_cpp, CPU, m) {
+  m.impl("lltm_forward", &lltm_forward);
+  m.impl("lltm_backward", &lltm_backward);
+}
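+
+// Hedged usage sketch (comment only): once Python has loaded
+// libextension_cpp.so, the schemas registered above are callable directly,
+// which is exactly what extension_cpp/__init__.py wraps in an
+// autograd.Function:
+//
+//   outs = torch.ops.extension_cpp.lltm_forward(input, weights, bias, old_h, old_cell)
+//   new_h, new_cell = outs[:2]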
diff --git a/extension_cpp_cmake/csrc/cuda/lltm_cuda.cu b/extension_cpp_cmake/csrc/cuda/lltm_cuda.cu
new file mode 100644
index 0000000..78ffb6c
--- /dev/null
+++ b/extension_cpp_cmake/csrc/cuda/lltm_cuda.cu
@@ -0,0 +1,183 @@
+#include <torch/extension.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <tuple>
+
+namespace {
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t sigmoid(scalar_t z) {
+  return 1.0 / (1.0 + exp(-z));
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_sigmoid(scalar_t z) {
+  const auto s = sigmoid(z);
+  return (1.0 - s) * s;
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_tanh(scalar_t z) {
+  const auto t = tanh(z);
+  return 1 - (t * t);
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t elu(scalar_t z, scalar_t alpha = 1.0) {
+  return fmax(0.0, z) + fmin(0.0, alpha * (exp(z) - 1.0));
+}
+
+template <typename scalar_t>
+__device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {
+  const auto e = exp(z);
+  const auto d_relu = z < 0.0 ? 0.0 : 1.0;
+  return d_relu + (((alpha * (e - 1.0)) < 0.0) ? (alpha * e) : 0.0);
+}
+
+template <typename scalar_t>
+__global__ void lltm_cuda_forward_kernel(
+    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> gates,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> old_cell,
+    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> new_h,
+    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> new_cell,
+    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> input_gate,
+    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> output_gate,
+    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> candidate_cell) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < gates.size(2)) {
+    input_gate[n][c] = sigmoid(gates[n][0][c]);
+    output_gate[n][c] = sigmoid(gates[n][1][c]);
+    candidate_cell[n][c] = elu(gates[n][2][c]);
+    new_cell[n][c] =
+        old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
+    new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
+  }
+}
+
+template <typename scalar_t>
+__global__ void lltm_cuda_backward_kernel(
+    torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> d_old_cell,
+    torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> d_gates,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> grad_h,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> grad_cell,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> new_cell,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> input_gate,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> output_gate,
+    const torch::PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t> candidate_cell,
+    const torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> gate_weights) {
+  // batch index
+  const int n = blockIdx.y;
+  // column index
+  const int c = blockIdx.x * blockDim.x + threadIdx.x;
+  if (c < d_gates.size(2)) {
+    const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
+    const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
+    const auto d_new_cell =
+        d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];
+
+    d_old_cell[n][c] = d_new_cell;
+    const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
+    const auto d_input_gate = candidate_cell[n][c] * d_new_cell;
+
+    d_gates[n][0][c] =
+        d_input_gate * d_sigmoid(gate_weights[n][0][c]);
+    d_gates[n][1][c] =
+        d_output_gate * d_sigmoid(gate_weights[n][1][c]);
+    d_gates[n][2][c] =
+        d_candidate_cell * d_elu(gate_weights[n][2][c]);
+  }
+}
+} // namespace
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+           torch::Tensor, torch::Tensor, torch::Tensor>
+lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
+
+  const auto batch_size = old_cell.size(0);
+  const auto state_size = old_cell.size(1);
+
+  auto gates = gate_weights.reshape({batch_size, 3, state_size});
+  auto new_h = torch::zeros_like(old_cell);
+  auto new_cell = torch::zeros_like(old_cell);
+  auto input_gate = torch::zeros_like(old_cell);
+  auto output_gate = torch::zeros_like(old_cell);
+  auto candidate_cell = torch::zeros_like(old_cell);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+
+  AT_DISPATCH_FLOATING_TYPES(gates.scalar_type(), "lltm_forward_cuda", ([&] {
+    lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
+        gates.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
+        old_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        new_h.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        new_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        input_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        output_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        candidate_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>());
+  }));
+
+  return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
+}
+
+std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
+           torch::Tensor>
+lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gates,
+    torch::Tensor weights) {
+  auto d_old_cell = torch::zeros_like(new_cell);
+  auto d_gates = torch::zeros_like(gates);
+
+  auto grad_h_contig = grad_h.contiguous();
+  auto grad_cell_contig = grad_cell.contiguous();
+
+  const auto batch_size = new_cell.size(0);
+  const auto state_size = new_cell.size(1);
+
+  const int threads = 1024;
+  const dim3 blocks((state_size + threads - 1) / threads, batch_size);
+
+  AT_DISPATCH_FLOATING_TYPES(X.scalar_type(), "lltm_backward_cuda", ([&] {
+    lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
+        d_old_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        d_gates.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>(),
+        grad_h_contig.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        grad_cell_contig.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        new_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        input_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        output_gate.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        candidate_cell.packed_accessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>(),
+        gates.packed_accessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>());
+  }));
+
+  auto d_gate_weights = d_gates.flatten(1, 2);
+  auto d_weights = d_gate_weights.t().mm(X);
+  auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);
+
+  auto d_X = d_gate_weights.mm(weights);
+  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
+  auto d_input = d_X.slice(/*dim=*/1, state_size);
+
+  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
+}
+
+// Registers CUDA implementations for lltm_forward, lltm_backward
+TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
+  m.impl("lltm_forward", &lltm_cuda_forward);
+  m.impl("lltm_backward", &lltm_cuda_backward);
+}
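+
+// Launch-configuration arithmetic, worked through for the benchmark defaults
+// (batch_size = 16, state_size = 128; the sizes are illustrative only):
+//   threads  = 1024 threads per block along the state dimension
+//   blocks.x = ceil(128 / 1024) = 1,  blocks.y = 16 (one block row per batch element)
+// Each thread then owns one (batch, state) cell, and the `c < size` guard in
+// the kernels masks the unused threads of the last x-block.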
diff --git a/extension_cpp_cmake/csrc/include/saturated_cast.h b/extension_cpp_cmake/csrc/include/saturated_cast.h
new file mode 100644
index 0000000..4d4d780
--- /dev/null
+++ b/extension_cpp_cmake/csrc/include/saturated_cast.h
@@ -0,0 +1,12 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace driss_torch {
+
+at::Tensor saturated_cast(const at::Tensor &input, const at::Tensor &attn_mask,
+                          at::ScalarType dtype, bool transpose);
+at::Tensor saturated_cast_meta(const at::Tensor &input,
+                               const at::Tensor &attn_mask,
+                               at::ScalarType dtype, bool transpose);
+} // namespace driss_torch
diff --git a/extension_cpp_cmake/csrc/include/utils.h b/extension_cpp_cmake/csrc/include/utils.h
new file mode 100644
index 0000000..3e38c86
--- /dev/null
+++ b/extension_cpp_cmake/csrc/include/utils.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <cuda_runtime.h>
+#include <functional>
+
+
+namespace driss_torch {
+
+// printf from a kernel, but only from thread 0 of block 0
+template <typename... Args>
+__device__ void thread_zero_print(const char *fmt, Args &&...args) {
+  if (threadIdx.x == 0 && blockIdx.x == 0) {
+    printf(fmt, std::forward<Args>(args)...);
+  }
+}
+
+// error checking macro
+#define cudaCheckErrors(msg)                                        \
+  do {                                                              \
+    cudaError_t __err = cudaGetLastError();                         \
+    if (__err != cudaSuccess) {                                     \
+      fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", msg,       \
+              cudaGetErrorString(__err), __FILE__, __LINE__);       \
+      fprintf(stderr, "*** FAILED - ABORTING\n");                   \
+      exit(1);                                                      \
+    }                                                               \
+  } while (0)
+
+// Integer ceiling division, e.g. ceil_div(130, 128) == 2
+template <typename T, typename Y>
+__host__ __device__ T ceil_div(T a, Y b) {
+  return a / b + (a % b != 0);
+}
+
+// Functions whose implementation is in the .cu file
+extern "C" {
+
+float kernel_time(std::function<void()> kernelLauncher);
+
+} // extern "C"
+
+} // namespace driss_torch
diff --git a/extension_cpp_cmake/extension_cpp/__init__.py b/extension_cpp_cmake/extension_cpp/__init__.py
new file mode 100644
index 0000000..b1e3345
--- /dev/null
+++ b/extension_cpp_cmake/extension_cpp/__init__.py
@@ -0,0 +1,69 @@
+from pathlib import Path
+from typing import Tuple
+
+import torch
+from torch import Tensor
+
+# Loading the shared library runs its TORCH_LIBRARY/TORCH_LIBRARY_IMPL
+# registrations, which makes torch.ops.extension_cpp.* available.
+lib_path = Path(__file__).parent / "lib" / "libextension_cpp.so"
+torch.ops.load_library(str(lib_path.resolve()))
+# torch.ops.import_module("extension_cpp.abstract_impls")
+
+
+__all__ = ["lltm", "reference_lltm"]
+
+
+def lltm(
+    input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
+) -> Tuple[Tensor, Tensor]:
+    return LLTMFunction.apply(input, weights, bias, old_h, old_cell)
+
+
+class LLTMFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, input, weights, bias, old_h, old_cell):
+        outputs = torch.ops.extension_cpp.lltm_forward.default(
+            input, weights, bias, old_h, old_cell
+        )
+        new_h, new_cell = outputs[:2]
+        variables = list(outputs[1:]) + [weights]
+        ctx.save_for_backward(*variables)
+
+        return new_h, new_cell
+
+    @staticmethod
+    @torch.autograd.function.once_differentiable
+    def backward(ctx, grad_h, grad_cell):
+        (
+            d_old_h,
+            d_input,
+            d_weights,
+            d_bias,
+            d_old_cell,
+        ) = torch.ops.extension_cpp.lltm_backward.default(
+            grad_h, grad_cell, *ctx.saved_tensors
+        )
+        return d_input, d_weights, d_bias, d_old_h, d_old_cell
+
+
+def reference_lltm(
+    input: Tensor, weights: Tensor, bias: Tensor, old_h: Tensor, old_cell: Tensor
+) -> Tuple[Tensor, Tensor]:
+    X = torch.cat([old_h, input], dim=1)
+
+    # Compute the input, output and candidate cell gates with one MM.
+    gate_weights = torch.nn.functional.linear(X, weights, bias)
+    # Split the combined gate weight matrix into its components.
+    gates = gate_weights.chunk(3, dim=1)
+
+    input_gate = torch.sigmoid(gates[0])
+    output_gate = torch.sigmoid(gates[1])
+    # Here we use an ELU instead of the usual tanh.
+    candidate_cell = torch.nn.functional.elu(gates[2])
+
+    # Compute the new cell state.
+    new_cell = old_cell + candidate_cell * input_gate
+    # Compute the new hidden state and output.
+    new_h = torch.tanh(new_cell) * output_gate
+
+    return new_h, new_cell
diff --git a/extension_cpp_cmake/extension_cpp/abstract_impls.py b/extension_cpp_cmake/extension_cpp/abstract_impls.py
new file mode 100644
index 0000000..68d51cd
--- /dev/null
+++ b/extension_cpp_cmake/extension_cpp/abstract_impls.py
@@ -0,0 +1,12 @@
+import torch
+from torch.library import impl_abstract
+
+
+# Leftover stub from another project, kept as a reference for how an abstract
+# impl is registered.
+# @impl_abstract("DrissTorch::saturated_cast")
+# def saturated_cast_meta(
+#     x: torch.Tensor,
+#     scale: torch.Tensor,
+#     out_dtype: torch.dtype,
+#     transpose: bool = False,
+# ):
+#     return torch.empty_like(x, dtype=out_dtype)
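+
+# A hedged sketch, kept commented out like the stub above: an abstract ("fake")
+# implementation for lltm_forward would let FakeTensor/torch.compile trace the
+# op without running the kernels. The op name comes from the TORCH_LIBRARY
+# block in csrc/cpp/lltm.cpp; the output shapes follow the CPU implementation,
+# and the body below is an untested assumption.
+#
+# @impl_abstract("extension_cpp::lltm_forward")
+# def lltm_forward_abstract(input, weights, bias, old_h, old_cell):
+#     batch, state = old_cell.shape
+#     X = torch.empty(
+#         batch, old_h.size(1) + input.size(1), dtype=input.dtype, device=input.device
+#     )
+#     gate_weights = torch.empty(
+#         batch, 3 * state, dtype=input.dtype, device=input.device
+#     )
+#     # (new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights)
+#     return (
+#         torch.empty_like(old_cell),
+#         torch.empty_like(old_cell),
+#         torch.empty_like(old_cell),
+#         torch.empty_like(old_cell),
+#         torch.empty_like(old_cell),
+#         X,
+#         gate_weights,
+#     )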
diff --git a/extension_cpp_cmake/pyproject.toml b/extension_cpp_cmake/pyproject.toml
new file mode 100644
index 0000000..464baa3
--- /dev/null
+++ b/extension_cpp_cmake/pyproject.toml
@@ -0,0 +1,65 @@
+[build-system]
+# Specify which version of torch to build against
+requires = ["scikit-build-core>=0.3.3", "torch==2.2"]
+build-backend = "scikit_build_core.build"
+
+[project]
+name = "extension_cpp"
+version = "0.0.1"
+authors = [{ name = "Driss Guessous", email = "drisspguessous@gmail.com" }]
+description = "A Pytorch Extension with C++ and CMake"
+readme = "README.md"
+requires-python = ">=3.9"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: MIT License",
+    "Operating System :: OS Independent",
+]
+
+dependencies = []
+
+[project.optional-dependencies]
+dev = [
+    "black==23.3.0",
+    "usort==1.0.6",
+    "ufmt==2.1.0",
+    "libcst==1.0.1",
+    "pre-commit==3.6.0",
+    "bumpver",
+    "pip-tools",
+    "pytest",
+]
+
+# For a more comprehensive list of options see this link:
+# https://scikit-build-core.readthedocs.io/en/latest/configuration.html
+
+[tool.scikit-build]
+cmake.verbose = true
+logging.level = "INFO"
+cmake.build-type = "RelWithDebInfo"
+build-dir = "build"
+
+# Specify CMake defines:
+[tool.scikit-build.cmake.define]
+TORCH_CUDA_ARCH_LIST = "9.0"
+
+[tool.usort]
+first_party_detection = false
+
+[tool.black]
+target-version = ["py39"]
+line-length = 99
+include = '\.pyi?$'
+exclude = '''
+/(
+    \.git
+  | \.hg
+  | \.mypy_cache
+  | \.tox
+  | \.venv
+  | _build
+  | buck-out
+  | build
+  | dist
+)/
+'''
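+
+# Hedged note: with scikit-build-core, the define in
+# [tool.scikit-build.cmake.define] above can usually be overridden per-install
+# via pip config-settings instead of editing this file, e.g.
+#   pip install . -C cmake.define.TORCH_CUDA_ARCH_LIST=8.6
+# (the exact flag syntax depends on the scikit-build-core version; treat this
+# as a sketch).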
diff --git a/extension_cpp_cmake/test/benchmark.py b/extension_cpp_cmake/test/benchmark.py
new file mode 100644
index 0000000..2272bda
--- /dev/null
+++ b/extension_cpp_cmake/test/benchmark.py
@@ -0,0 +1,83 @@
+import argparse
+import math
+import time
+
+import torch
+
+TIME_SCALES = {"s": 1, "ms": 1000, "us": 1000000}
+
+# Example invocation: python test/benchmark.py cuda -r 1000
+parser = argparse.ArgumentParser()
+parser.add_argument("example", choices=["py", "cpp", "cuda"])
+parser.add_argument("-b", "--batch-size", type=int, default=16)
+parser.add_argument("-f", "--features", type=int, default=32)
+parser.add_argument("-s", "--state-size", type=int, default=128)
+parser.add_argument("-r", "--runs", type=int, default=100)
+parser.add_argument("--scale", choices=["s", "ms", "us"], default="us")
+parser.add_argument("-c", "--cuda", action="store_true")
+parser.add_argument("-d", "--double", action="store_true")
+options = parser.parse_args()
+
+if options.example == "py":
+    from extension_cpp import reference_lltm as LLTM
+else:
+    from extension_cpp import lltm as LLTM
+if options.example == "cuda":
+    options.cuda = True
+
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+dtype = torch.float64 if options.double else torch.float32
+
+kwargs = {"dtype": dtype, "device": device, "requires_grad": True}
+batch_size = options.batch_size
+features = options.features
+state_size = options.state_size
+X = torch.randn(batch_size, features, **kwargs)
+h = torch.randn(batch_size, state_size, **kwargs)
+C = torch.randn(batch_size, state_size, **kwargs)
+W = torch.randn(3 * state_size, features + state_size, **kwargs)
+b = torch.randn(1, 3 * state_size, **kwargs)
+
+# Warm up and force CUDA initialization
+new_h, new_C = LLTM(X, W, b, h, C)
+(new_h.sum() + new_C.sum()).backward()
+
+
+def sync():
+    # CUDA kernels launch asynchronously; synchronize so that the wall-clock
+    # timings below measure kernel execution, not just launch overhead.
+    if options.cuda:
+        torch.cuda.synchronize()
+
+
+forward_min = math.inf
+forward_time = 0
+backward_min = math.inf
+backward_time = 0
+for _ in range(options.runs):
+    X.grad = None
+    h.grad = None
+    C.grad = None
+    W.grad = None
+    b.grad = None
+
+    sync()
+    start = time.time()
+    new_h, new_C = LLTM(X, W, b, h, C)
+    sync()
+    elapsed = time.time() - start
+    forward_min = min(forward_min, elapsed)
+    forward_time += elapsed
+
+    start = time.time()
+    (new_h.sum() + new_C.sum()).backward()
+    sync()
+    elapsed = time.time() - start
+    backward_min = min(backward_min, elapsed)
+    backward_time += elapsed
+
+scale = TIME_SCALES[options.scale]
+forward_min *= scale
+backward_min *= scale
+forward_average = forward_time / options.runs * scale
+backward_average = backward_time / options.runs * scale
+
+print(
+    "Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}".format(
+        forward_min, forward_average, backward_min, backward_average, options.scale
+    )
+)
diff --git a/extension_cpp_cmake/test/test_extension.py b/extension_cpp_cmake/test/test_extension.py
new file mode 100644
index 0000000..55844a0
--- /dev/null
+++ b/extension_cpp_cmake/test/test_extension.py
@@ -0,0 +1,59 @@
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import TestCase
+from torch.testing._internal.optests import opcheck
+
+from extension_cpp import lltm, reference_lltm
+
+
+def sample_inputs(device):
+    batch_size = 3
+    features = 17
+    state_size = 5
+    kwargs = {"dtype": torch.float64, "device": device, "requires_grad": True}
+    X = torch.randn(batch_size, features, **kwargs)
+    h = torch.randn(batch_size, state_size, **kwargs)
+    C = torch.randn(batch_size, state_size, **kwargs)
+    W = torch.randn(3 * state_size, features + state_size, **kwargs)
+    b = torch.randn(1, 3 * state_size, **kwargs)
+    return X, W, b, h, C
+
+
+class TestLLTM(TestCase):
+    def _test_correctness(self, device):
+        args = sample_inputs(device)
+        result = lltm(*args)
+        expected = reference_lltm(*args)
+        self.assertEqual(len(result), len(expected))
+        torch.testing.assert_close(result, expected)
+
+    def test_correctness_cpu(self):
+        self._test_correctness("cpu")
+
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_correctness_cuda(self):
+        self._test_correctness("cuda")
+
+    def _test_gradients(self, device):
+        args = sample_inputs(device)
+        torch.autograd.gradcheck(lltm, args)
+
+    def test_gradients_cpu(self):
+        self._test_gradients("cpu")
+
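+    # Hedged sketch: `opcheck` (imported above but otherwise unused) can
+    # validate the custom-op registration itself. Only the schema test is run
+    # here, since the fake-tensor and autograd-registration checks would need
+    # the abstract impls / dispatcher autograd kernels that this template
+    # leaves commented out. The torch.ops path matches the TORCH_LIBRARY block
+    # in csrc/cpp/lltm.cpp.
+    def _test_opcheck(self, device):
+        args = sample_inputs(device)
+        opcheck(
+            torch.ops.extension_cpp.lltm_forward.default,
+            args,
+            test_utils="test_schema",
+        )
+
+    def test_opcheck_cpu(self):
+        self._test_opcheck("cpu")
+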
+    # This is supposed to succeed; there is probably a bug in the CUDA kernel.
+    @unittest.expectedFailure
+    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
+    def test_gradients_cuda(self):
+        self._test_gradients("cuda")
+
+
+if __name__ == "__main__":
+    unittest.main()