172 changes: 0 additions & 172 deletions LLama.Examples/Examples/BatchedDecoding.cs

This file was deleted.

138 changes: 138 additions & 0 deletions LLama.Examples/Examples/BatchedExecutorFork.cs
@@ -0,0 +1,138 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with the forked
/// conversations sharing the prompt's cache
/// </summary>
public class BatchedExecutorFork
{
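// Fork all active conversations once every n_split tokens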
private const int n_split = 16;
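// Total number of tokens to generate in each conversation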
private const int n_len = 64;

public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

Console.WriteLine("Prompt (leave blank to select automatically):");
var prompt = Console.ReadLine();
if (string.IsNullOrWhiteSpace(prompt))
prompt = "Not many people know that";

// Create an executor that can evaluate a batch of conversations together
var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Evaluate the initial prompt to create one conversation
var start = executor.Prompt(prompt);
await executor.Infer();

// Create the root node of the tree
var root = new Node(start);

// Run inference loop
for (var i = 0; i < n_len; i++)
{
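// The initial prompt was already inferred before the loop, so the first iteration can go straight to sampling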
if (i != 0)
await executor.Infer();

// Occasionally fork all the active conversations
if (i != 0 && i % n_split == 0)
root.Split();

// Sample all active conversations
root.Sample();
}

Console.WriteLine($"{prompt}...");
root.Print(1);

Console.WriteLine("Press any key to exit demo");
Console.ReadKey(true);
}

class Node
{
private readonly StreamingTokenDecoder _decoder;

private readonly DefaultSamplingPipeline _sampler;
private Conversation? _conversation;

private Node? _left;
private Node? _right;

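// A node is either a leaf holding one live conversation, or an internal node whose two children were forked from it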
public int ActiveConversationCount => _conversation != null ? 1 : _left!.ActiveConversationCount + _right!.ActiveConversationCount;

public Node(Conversation conversation)
{
_sampler = new DefaultSamplingPipeline();
_conversation = conversation;
_decoder = new StreamingTokenDecoder(conversation.Executor.Context);
}

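// Recursively sample one token from every live conversation in this subtree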
public void Sample()
{
if (_conversation == null)
{
_left?.Sample();
_right?.Sample();
return;
}

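// A conversation which has been prompted cannot be sampled again until inference has run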
if (_conversation.RequiresInference)
return;

// Sample one token
var ctx = _conversation.Executor.Context.NativeHandle;
var logitsCopy = _conversation.Sample().ToArray();
var token = _sampler.Sample(ctx, logitsCopy, Array.Empty<LLamaToken>());
_sampler.Accept(ctx, token);
_decoder.Add(token);

// Prompt the conversation with this token, to continue generating from there
_conversation.Prompt(token);
}

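// Fork every live conversation in this subtree into two children, disposing the originals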
public void Split()
{
if (_conversation != null)
{
_left = new Node(_conversation.Fork());
_right = new Node(_conversation.Fork());

_conversation.Dispose();
_conversation = null;
}
else
{
_left?.Split();
_right?.Split();
}
}

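// Recursively print the tokens generated in this subtree, one color and indent level per depth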
public void Print(int indentation)
{
var colors = new[] { ConsoleColor.Red, ConsoleColor.Green, ConsoleColor.Blue, ConsoleColor.Yellow, ConsoleColor.White };
Console.ForegroundColor = colors[indentation % colors.Length];

var message = _decoder.Read().ReplaceLineEndings("");

var prefix = new string(' ', indentation * 3);
var suffix = _conversation == null ? "..." : "";
Console.WriteLine($"{prefix}...{message}{suffix}");

_left?.Print(indentation + 2);
_right?.Print(indentation + 2);
}
}
}