Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,21 @@ public static IResourceBuilder<OllamaResource> WithGPUSupport(this IResourceBuil

return vendor switch
{
OllamaGpuVendor.Nvidia => builder.WithContainerRuntimeArgs("--gpus", "all"),
OllamaGpuVendor.Nvidia => builder.WithNvidiaGPUSupport(),
OllamaGpuVendor.AMD => builder.WithAMDGPUSupport(),
_ => throw new ArgumentException("Invalid GPU vendor", nameof(vendor))
};
}

private static IResourceBuilder<OllamaResource> WithNvidiaGPUSupport(this IResourceBuilder<OllamaResource> builder)
{
return builder.ApplicationBuilder.GetContainerRuntime() switch
{
"podman" => builder.WithContainerRuntimeArgs("--device", "nvidia.com/gpu=all"),
_ => builder.WithContainerRuntimeArgs("--gpus", "all"),
};
}

private static IResourceBuilder<OllamaResource> WithAMDGPUSupport(this IResourceBuilder<OllamaResource> builder)
{
if (builder.Resource.TryGetLastAnnotation<ContainerImageAnnotation>(out var containerAnnotation))
Expand Down Expand Up @@ -170,4 +179,9 @@ private static string ToJson(this object obj)
{
return JsonSerializer.Serialize(obj, jsonSerializerOptions);
}

private static string? GetContainerRuntime(this IDistributedApplicationBuilder builder)
// For config names, see https://github.com/dotnet/aspire/blob/91481d9b5a5602d3d641fa8ade554c95bcac22b5/src/Shared/KnownConfigNames.cs
=> (builder.Configuration["ASPIRE_CONTAINER_RUNTIME"]
?? builder.Configuration["DOTNET_ASPIRE_CONTAINER_RUNTIME"])?.ToLowerInvariant();
}
Original file line number Diff line number Diff line change
Expand Up @@ -619,25 +619,43 @@ public async Task WithNvidiaGPUSupport()

using var app = builder.Build();

var appModel = app.Services.GetRequiredService<DistributedApplicationModel>();
var resource = AssertSingleResource<OllamaResource>(app);

var resource = Assert.Single(appModel.Resources.OfType<OllamaResource>());
await AssertContainerRuntimeArgs(resource,
"--gpus",
"all");
}

Assert.True(resource.TryGetLastAnnotation(out ContainerRuntimeArgsCallbackAnnotation? argsAnnotations));
ContainerRuntimeArgsCallbackContext context = new([]);
await argsAnnotations.Callback(context);
[Fact]
public async Task WithNvidiaGPUSupportOnPodman()
{
var builder = DistributedApplication.CreateBuilder();
builder.Configuration["ASPIRE_CONTAINER_RUNTIME"] = "podman";
_ = builder.AddOllama("ollama").WithGPUSupport(OllamaGpuVendor.Nvidia);

Assert.Collection(
context.Args,
arg =>
{
Assert.Equal("--gpus", arg);
},
arg =>
{
Assert.Equal("all", arg);
}
);
using var app = builder.Build();

var resource = AssertSingleResource<OllamaResource>(app);

await AssertContainerRuntimeArgs(resource,
"--device",
"nvidia.com/gpu=all");
}

[Fact]
public async Task WithNvidiaGPUSupportOnPodmanLegacy()
{
var builder = DistributedApplication.CreateBuilder();
builder.Configuration["DOTNET_ASPIRE_CONTAINER_RUNTIME"] = "podman";
_ = builder.AddOllama("ollama").WithGPUSupport(OllamaGpuVendor.Nvidia);

using var app = builder.Build();

var resource = AssertSingleResource<OllamaResource>(app);

await AssertContainerRuntimeArgs(resource,
"--device",
"nvidia.com/gpu=all");
}

[Fact]
Expand All @@ -648,36 +666,34 @@ public async Task WithAMDGPUSupport()

using var app = builder.Build();

var appModel = app.Services.GetRequiredService<DistributedApplicationModel>();
var resource = AssertSingleResource<OllamaResource>(app);

var resource = Assert.Single(appModel.Resources.OfType<OllamaResource>());
await AssertContainerRuntimeArgs(resource,
"--device", "/dev/kfd",
"--device", "/dev/dri");

Assert.True(resource.TryGetLastAnnotation<ContainerImageAnnotation>(out var imageAnnotation));
Assert.NotNull(imageAnnotation);
Assert.EndsWith("-rocm", imageAnnotation.Tag, StringComparison.OrdinalIgnoreCase);
}

private static T AssertSingleResource<T>(DistributedApplication app) where T : ContainerResource
{
var appModel = app.Services.GetRequiredService<DistributedApplicationModel>();
return Assert.Single(appModel.Resources.OfType<T>());
}

private static async Task AssertContainerRuntimeArgs(ContainerResource resource, params string[] expectedArgs)
{
Assert.True(resource.TryGetLastAnnotation(out ContainerRuntimeArgsCallbackAnnotation? argsAnnotations));
Assert.NotNull(argsAnnotations);
ContainerRuntimeArgsCallbackContext context = new([]);
await argsAnnotations.Callback(context);

Assert.Collection(
context.Args,
arg =>
{
Assert.Equal("--device", arg);
},
arg =>
{
Assert.Equal("/dev/kfd", arg);
},
arg =>
{
Assert.Equal("--device", arg);
},
arg =>
{
Assert.Equal("/dev/dri", arg);
}
);

Assert.True(resource.TryGetLastAnnotation<ContainerImageAnnotation>(out var imageAnnotation));
Assert.NotNull(imageAnnotation);
Assert.EndsWith("-rocm", imageAnnotation.Tag, StringComparison.OrdinalIgnoreCase);
Assert.Equal(expectedArgs.Length, context.Args.Count);
for (int i = 0; i < expectedArgs.Length; i++)
{
Assert.Equal(expectedArgs[i], context.Args[i]);
}
}
}
Loading