diff --git a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs index 64ca03b12..e01876806 100644 --- a/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs +++ b/src/CommunityToolkit.Aspire.Hosting.Ollama/OllamaResourceBuilderExtensions.cs @@ -101,12 +101,21 @@ public static IResourceBuilder WithGPUSupport(this IResourceBuil return vendor switch { - OllamaGpuVendor.Nvidia => builder.WithContainerRuntimeArgs("--gpus", "all"), + OllamaGpuVendor.Nvidia => builder.WithNvidiaGPUSupport(), OllamaGpuVendor.AMD => builder.WithAMDGPUSupport(), _ => throw new ArgumentException("Invalid GPU vendor", nameof(vendor)) }; } + private static IResourceBuilder WithNvidiaGPUSupport(this IResourceBuilder builder) + { + return builder.ApplicationBuilder.GetContainerRuntime() switch + { + "podman" => builder.WithContainerRuntimeArgs("--device", "nvidia.com/gpu=all"), + _ => builder.WithContainerRuntimeArgs("--gpus", "all"), + }; + } + private static IResourceBuilder WithAMDGPUSupport(this IResourceBuilder builder) { if (builder.Resource.TryGetLastAnnotation(out var containerAnnotation)) @@ -170,4 +179,9 @@ private static string ToJson(this object obj) { return JsonSerializer.Serialize(obj, jsonSerializerOptions); } + + private static string? GetContainerRuntime(this IDistributedApplicationBuilder builder) + // For config names, see https://github.com/dotnet/aspire/blob/91481d9b5a5602d3d641fa8ade554c95bcac22b5/src/Shared/KnownConfigNames.cs + => (builder.Configuration["ASPIRE_CONTAINER_RUNTIME"] + ?? builder.Configuration["DOTNET_ASPIRE_CONTAINER_RUNTIME"])?.ToLowerInvariant(); } diff --git a/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs b/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs index cf7d2b7c9..a53abb7f9 100644 --- a/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs +++ b/tests/CommunityToolkit.Aspire.Hosting.Ollama.Tests/AddOllamaTests.cs @@ -619,25 +619,43 @@ public async Task WithNvidiaGPUSupport() using var app = builder.Build(); - var appModel = app.Services.GetRequiredService(); + var resource = AssertSingleResource(app); - var resource = Assert.Single(appModel.Resources.OfType()); + await AssertContainerRuntimeArgs(resource, + "--gpus", + "all"); + } - Assert.True(resource.TryGetLastAnnotation(out ContainerRuntimeArgsCallbackAnnotation? argsAnnotations)); - ContainerRuntimeArgsCallbackContext context = new([]); - await argsAnnotations.Callback(context); + [Fact] + public async Task WithNvidiaGPUSupportOnPodman() + { + var builder = DistributedApplication.CreateBuilder(); + builder.Configuration["ASPIRE_CONTAINER_RUNTIME"] = "podman"; + _ = builder.AddOllama("ollama").WithGPUSupport(OllamaGpuVendor.Nvidia); - Assert.Collection( - context.Args, - arg => - { - Assert.Equal("--gpus", arg); - }, - arg => - { - Assert.Equal("all", arg); - } - ); + using var app = builder.Build(); + + var resource = AssertSingleResource(app); + + await AssertContainerRuntimeArgs(resource, + "--device", + "nvidia.com/gpu=all"); + } + + [Fact] + public async Task WithNvidiaGPUSupportOnPodmanLegacy() + { + var builder = DistributedApplication.CreateBuilder(); + builder.Configuration["DOTNET_ASPIRE_CONTAINER_RUNTIME"] = "podman"; + _ = builder.AddOllama("ollama").WithGPUSupport(OllamaGpuVendor.Nvidia); + + using var app = builder.Build(); + + var resource = AssertSingleResource(app); + + await AssertContainerRuntimeArgs(resource, + "--device", + "nvidia.com/gpu=all"); } [Fact] @@ -648,36 +666,34 @@ public async Task WithAMDGPUSupport() using var app = builder.Build(); - var appModel = app.Services.GetRequiredService(); + var resource = AssertSingleResource(app); - var resource = Assert.Single(appModel.Resources.OfType()); + await AssertContainerRuntimeArgs(resource, + "--device", "/dev/kfd", + "--device", "/dev/dri"); + + Assert.True(resource.TryGetLastAnnotation(out var imageAnnotation)); + Assert.NotNull(imageAnnotation); + Assert.EndsWith("-rocm", imageAnnotation.Tag, StringComparison.OrdinalIgnoreCase); + } + private static T AssertSingleResource(DistributedApplication app) where T : ContainerResource + { + var appModel = app.Services.GetRequiredService(); + return Assert.Single(appModel.Resources.OfType()); + } + + private static async Task AssertContainerRuntimeArgs(ContainerResource resource, params string[] expectedArgs) + { Assert.True(resource.TryGetLastAnnotation(out ContainerRuntimeArgsCallbackAnnotation? argsAnnotations)); + Assert.NotNull(argsAnnotations); ContainerRuntimeArgsCallbackContext context = new([]); await argsAnnotations.Callback(context); - Assert.Collection( - context.Args, - arg => - { - Assert.Equal("--device", arg); - }, - arg => - { - Assert.Equal("/dev/kfd", arg); - }, - arg => - { - Assert.Equal("--device", arg); - }, - arg => - { - Assert.Equal("/dev/dri", arg); - } - ); - - Assert.True(resource.TryGetLastAnnotation(out var imageAnnotation)); - Assert.NotNull(imageAnnotation); - Assert.EndsWith("-rocm", imageAnnotation.Tag, StringComparison.OrdinalIgnoreCase); + Assert.Equal(expectedArgs.Length, context.Args.Count); + for (int i = 0; i < expectedArgs.Length; i++) + { + Assert.Equal(expectedArgs[i], context.Args[i]); + } } }