From 24a07ab6e6dd339941da0b8334227262bc2fab8f Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 29 Mar 2025 01:30:16 +0100
Subject: [PATCH 01/31] tts : implement mimi decoder

---
 .gitignore                           |   1 +
 common/common.cpp                    |  28 +
 common/common.h                      |  22 +
 examples/tts/CMakeLists.txt          |   6 +
 examples/tts/README-mimi.md          |  50 ++
 examples/tts/convert_mimi_to_gguf.py | 191 +++++++
 examples/tts/mimi.cpp                | 770 +++++++++++++++++++++++++++
 7 files changed, 1068 insertions(+)
 create mode 100644 examples/tts/README-mimi.md
 create mode 100644 examples/tts/convert_mimi_to_gguf.py
 create mode 100644 examples/tts/mimi.cpp

diff --git a/.gitignore b/.gitignore
index 2c67ad7f7c609..41fe1f31271d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -107,6 +107,7 @@ examples/server/*.gz.hpp
 !examples/*/*/*.kts
 !examples/sycl/*.bat
 !examples/sycl/*.sh
+/*.wav

 # Server Web UI temporary files
 node_modules
diff --git a/common/common.cpp b/common/common.cpp
index 18ffb4e738aee..30870980a148d 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2055,3 +2055,31 @@ common_grammar_trigger common_grammar_trigger::from_json(const json & in) {
     }
     return out;
 }
+
+//
+// Audio utils
+//
+
+bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
+    std::ofstream file(fname, std::ios::binary);
+    if (!file) {
+        LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str());
+        return false;
+    }
+
+    wav_header header;
+    header.sample_rate = sample_rate;
+    header.byte_rate   = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
+    header.block_align = header.num_channels * (header.bits_per_sample / 8);
+    header.data_size   = data.size() * (header.bits_per_sample / 8);
+    header.chunk_size  = 36 + header.data_size;
+
+    file.write(reinterpret_cast<const char *>(&header), sizeof(header));
+
+    for (const auto & sample : data) {
+        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
+        file.write(reinterpret_cast<const char *>(&pcm_sample), sizeof(pcm_sample));
+    }
+
+    return file.good();
+}
diff --git a/common/common.h b/common/common.h
index 1c0f199774976..0c67693149285 100644
--- a/common/common.h
+++ b/common/common.h
@@ -683,3 +683,25 @@
 const char * const LLM_KV_SPLIT_COUNT         = "split.count";
 const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 }
+
+//
+// Audio utils
+//
+
+struct wav_header {
+    char riff[4] = {'R', 'I', 'F', 'F'};
+    uint32_t chunk_size;
+    char wave[4] = {'W', 'A', 'V', 'E'};
+    char fmt[4] = {'f', 'm', 't', ' '};
+    uint32_t fmt_chunk_size = 16;
+    uint16_t audio_format = 1; // PCM
+    uint16_t num_channels = 1; // Mono
+    uint32_t sample_rate;
+    uint32_t byte_rate;
+    uint16_t block_align;
+    uint16_t bits_per_sample = 16;
+    char data[4] = {'d', 'a', 't', 'a'};
+    uint32_t data_size;
+};
+
+bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate);
diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt
index c72bd814c3b31..f76d834b18fec 100644
--- a/examples/tts/CMakeLists.txt
+++ b/examples/tts/CMakeLists.txt
@@ -3,3 +3,9 @@ add_executable(${TARGET} tts.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+set(TARGET llama-mimi)
+add_executable(${TARGET} mimi.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/tts/README-mimi.md b/examples/tts/README-mimi.md
new file mode 100644
index 0000000000000..b46f5f77b95d0
--- /dev/null
+++ b/examples/tts/README-mimi.md
@@ -0,0 +1,50 @@
+# llama.cpp/example/mimi
+
+This demonstrates running [Kyutai's Mimi](https://huggingface.co/kyutai/mimi) model via GGML.
+
+## Quickstart
+
+Convert the model to GGUF (no need to download it first; the script automatically downloads the `safetensors` file):
+
+```sh
+python examples/tts/convert_mimi_to_gguf.py
+
+# output file: kyutai-mimi.gguf
+
+# optionally, use q8_0 quantization for faster inference
+python examples/tts/convert_mimi_to_gguf.py --outtype q8_0
+```
+
+Then compile and run it:
+
+```sh
+cmake --build build -j --target llama-mimi
+
+./build/bin/llama-mimi kyutai-mimi.gguf codes.txt
+
+# output: output.wav
+
+# alternatively, use "dummy1" to get a "hey hello there" sample output file
+./build/bin/llama-mimi kyutai-mimi.gguf dummy1
+```
+
+Example of a code file (one code per line):
+
+```
+1263
+1597
+1596
+1477
+1540
+1720
+1433
+118
+1066
+1968
+1096
+232
+418
+566
+1653
+2010
+```
diff --git a/examples/tts/convert_mimi_to_gguf.py b/examples/tts/convert_mimi_to_gguf.py
new file mode 100644
index 0000000000000..5b44ef62103ba
--- /dev/null
+++ b/examples/tts/convert_mimi_to_gguf.py
@@ -0,0 +1,191 @@
+import gguf
+import argparse
+import logging
+import torch
+from typing import Union
+from pathlib import Path
+from torch import Tensor
+from transformers import MimiModel
+
+logger = logging.getLogger("mimi")
+
+
+class MimiModelConverter:
+    mimi_model: MimiModel
+    gguf_writer: gguf.GGUFWriter
+    fname_out: Path
+    ftype: gguf.LlamaFileType
+
+    def __init__(self,
+                 pretrained_model_name_or_path: Union[Path, str],
+                 fname_out: Path,
+                 ftype: gguf.LlamaFileType,
+                 is_big_endian: bool,):
+        self.mimi_model = MimiModel.from_pretrained(pretrained_model_name_or_path)
+        self.fname_out = fname_out
+        self.ftype = ftype
+        endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
+        self.gguf_writer = gguf.GGUFWriter(
+            path=None,
+            arch="if you see this, you are using the wrong file",
+            endianess=endianess)
+
+        assert self.mimi_model.config.architectures[0] == "MimiModel"
+
+        # load tensors
+        for name, data_torch in self.mimi_model.state_dict().items():
+            # convert any unsupported data types to float32
+            old_dtype = data_torch.dtype
+            if data_torch.dtype not in (torch.float16, torch.float32):
+                data_torch = data_torch.to(torch.float32)
+            self.add_tensor(name, data_torch, old_dtype)
+
+    def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype):
+        is_1d = len(data_torch.shape) == 1
+        is_bias = ".bias" in name
+        can_quantize = not is_1d and not is_bias
+        data_qtype = gguf.GGMLQuantizationType.F32
+
+        n_head = self.mimi_model.config.num_attention_heads
+        n_kv_head = self.mimi_model.config.num_key_value_heads
+        if name.endswith(("q_proj.weight", "q_proj.bias")):
+            data_torch = self.undo_permute(data_torch, n_head, n_head)
+        if name.endswith(("k_proj.weight", "k_proj.bias")):
+            data_torch = self.undo_permute(data_torch, n_head, n_kv_head)
+
+        # process codebook
+        if ".codebook.initialized" in name:
+            # "initialized" tensor
+            state_dict = self.mimi_model.state_dict()
+            embed_sum = state_dict[name.replace(".initialized", ".embed_sum")]
+            cluster_usage = state_dict[name.replace(".initialized", ".cluster_usage")]
+            # see modeling_mimi.py --> MimiEuclideanCodebook
+            data_torch = embed_sum / cluster_usage.clamp(min=self.mimi_model.config.norm_eps)[:, None]
+            name = name.replace(".initialized", "")
+
+        # ignore processed tensors
+        if
".cluster_usage" in name or ".embed_sum" in name: + return + + # transpose some tensors + if ".conv.bias" in name: + data_torch = data_torch.view((1, data_torch.shape[0])) + data_torch = data_torch.transpose(0, 1) + + # change view 3d to 2d + if "quantizer" in name and "_proj." in name: + assert data_torch.shape[2] == 1 + data_torch = data_torch.view((data_torch.shape[0], data_torch.shape[1])) + + # shorten name, otherwise it will be too long for ggml to read + name = name.replace("_residual_vector_quantizer", "_rvq") + + if can_quantize: + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + else: + raise ValueError(f"Unsupported file type: {self.ftype}") + + # Conv kernels are always F16 + if ".conv.weight" in name: + data_qtype = gguf.GGMLQuantizationType.F16 + + data = data_torch.numpy() + + try: + data = gguf.quants.quantize(data, data_qtype) + except Exception as e: + logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}" + logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) + + def write(self): + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + + @staticmethod + def undo_permute(weights: Tensor, n_head: int, n_head_kv: int): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Mimi safetensors model to GGUF",) + parser.add_argument( + "--outfile", type=Path, default="kyutai-mimi.gguf", + help="path to write to", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="directory or model ID containing model file (if model ID is specified, download from Hugging Face hub)", + nargs="?", + default="kyutai/mimi", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + + args = parser.parse_args() + if args.model is None: + parser.error("the following arguments are required: model") + return args + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + } + + logger.info(f"Loading 
model: {dir_model}")

    with torch.inference_mode():
        converter = MimiModelConverter(
            pretrained_model_name_or_path=dir_model,
            fname_out=args.outfile,
            ftype=ftype_map[args.outtype],
            is_big_endian=args.bigendian,
        )
        converter.write()


if __name__ == '__main__':
    main()

diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp
new file mode 100644
index 0000000000000..2c5833faa277b
--- /dev/null
+++ b/examples/tts/mimi.cpp
@@ -0,0 +1,770 @@
+#include "ggml.h"
+#include "ggml-cpp.h"
+#include "ggml-cpu.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+#include "gguf.h"
+
+#include "common.h"
+
+#include <cmath>
+#include <cstring>
+#include <fstream>
+#include <functional>
+#include <unordered_map>
+#include <vector>
+
+/**
+ * Implementation of Kyutai's Mimi model using GGML.
+ * Based on this research: https://github.com/ngxson/ggml-easy/blob/master/demo/kyutai-mimi.cpp
+ *
+ * NOTE: only the decoder is working for now.
+ *
+ * Background:
+ * - The audio codes can be generated using any Mimi-based model, for example: Moshi, Hibiki, Sesame, etc.
+ * - Audio codes must be in the order: (1 semantic component, 31 acoustic components) repeated N times
+ *
+ * How does it work?
+ * 1. The audio codes are passed to the RVQ (mimi_residual_vector_quantizer) to get the latent code
+ * 2. The latent code is passed to a mimi_conv_transpose_1d (depthwise) to upscale it
+ * 3. The upscaled code is passed to the transformer, which maps N frames to N frames
+ * 4. The output embeddings are then passed to SEANet (mimi_encoder_decoder) to get the final waveform
+ * 5. The waveform is written to a file
+ */
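+
+// Rough sketch of the decode flow described above (all names refer to the
+// structs and helpers defined later in this file):
+//
+//   codes --> mimi_residual_vector_quantizer::decode()       (1)
+//         --> mimi_conv_transpose_1d(), depthwise upsample   (2)
+//         --> mimi_transformer::forward()                    (3)
+//         --> mimi_encoder_decoder::forward()                (4)
+//         --> save_wav16()                                   (5)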
+
+// copied from https://huggingface.co/kyutai/mimi/blob/main/config.json
+struct mimi_config_t {
+    bool causal = true;
+    int max_position_embeddings = 8000;
+    int num_hidden_layers = 8;
+    int n_embd = 512;
+    int n_ffn = 2048;
+    int n_head = 8;
+    int n_head_kv = 8;
+    int n_rot = 64;
+    float norm_eps = 1e-5;
+    float rope_theta = 10000.0f;
+    int sliding_window = 250;
+    std::array<int, 4> upsampling_ratio   = {8, 6, 5, 4};
+    std::array<int, 4> downsampling_ratio = {4, 5, 6, 8}; // reverse of upsampling_ratio
+    // vector quantizer
+    float frame_rate = 12.5;
+    int audio_channels = 1;
+    int codebook_size = 2048;
+    int codebook_dim = 256;
+    int n_semantic_components = 1;
+    int n_acoustic_components = 31;
+    // decode
+    float trim_right_ratio = 1.0f;
+} mimi_config;
+
+// Adapted from https://github.com/ngxson/ggml-easy/blob/master/ggml-easy.h
+struct mimi_ggml_ctx {
+    gguf_context * ctx_gguf = nullptr;
+    ggml_context * ctx_data = nullptr;
+    ggml_context * ctx_gf   = nullptr;
+
+    // CPU-only for now, as many kernels are missing and we actually get less performance with GPU
+    ggml_backend_t backend = nullptr;
+    ggml_backend_buffer_t buf = nullptr;
+    ggml_backend_sched_ptr sched;
+
+    ggml_cgraph * gf = nullptr;
+    std::vector<uint8_t> buf_compute_meta;
+    int max_nodes = 16 * 1024;
+
+    std::unordered_map<std::string, ggml_tensor *> tensors;
+
+    mimi_ggml_ctx() {
+        backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
+        auto buft = ggml_backend_get_default_buffer_type(backend);
+        sched.reset(
+            ggml_backend_sched_new(&backend, &buft, 1, max_nodes, false)
+        );
+        buf_compute_meta.resize(max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
+    }
+
+    void load_gguf(const char * fname) {
+        ggml_context * meta = nullptr;
+
+        gguf_init_params params = {
+            /*.no_alloc = */ true,
+            /*.ctx      = */ &meta,
+        };
+
+        ctx_gguf = gguf_init_from_file(fname, params);
+
+        // load tensors
+        const int n_tensors = gguf_get_n_tensors(ctx_gguf);
+
+        std::vector<uint8_t> read_buf;
+        ggml_init_params ggml_params = {
+            /*.mem_size   =*/ (n_tensors + 1) * ggml_tensor_overhead(),
+            /*.mem_buffer =*/ NULL,
+            /*.no_alloc   =*/ true,
+        };
+
+        ctx_data = ggml_init(ggml_params);
+        auto fin = std::ifstream(fname, std::ios::binary);
+        if (!fin) {
+            ggml_free(meta);
+            throw std::runtime_error("cannot open model file for loading tensors");
+        }
+
+        // add tensors to context
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx_gguf, i);
+            ggml_tensor * t   = ggml_get_tensor(meta, name);
+            ggml_tensor * cur = ggml_dup_tensor(ctx_data, t);
+            ggml_set_name(cur, name);
+            tensors.insert({name, cur});
+        }
+
+        // alloc memory and offload data
+        ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
+        buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_data, buft);
+        ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
+        for (int i = 0; i < n_tensors; ++i) {
+            const char * name = gguf_get_tensor_name(ctx_gguf, i);
+            ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
+            const size_t offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i);
+            // printf("%s: Loading tensor \"%s\"\n", __func__, name);
+            fin.seekg(offset, std::ios::beg);
+            if (!fin) {
+                ggml_free(meta);
+                throw std::runtime_error(string_format("failed to seek for tensor: %s", name));
+            }
+            int num_bytes = ggml_nbytes(cur);
+            if (ggml_backend_buft_is_host(buft)) {
+                // for the CPU and Metal backend, we can read directly into the tensor
+                fin.read(reinterpret_cast<char *>(cur->data), num_bytes);
+            } else {
+                // read into a temporary buffer first, then copy to device memory
+                read_buf.resize(num_bytes);
+                fin.read(reinterpret_cast<char *>(read_buf.data()), num_bytes);
+                ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes);
+            }
+        }
+        printf("%s: Loaded %d tensors from %s\n", __func__, n_tensors, fname);
+        fin.close();
+
+        ggml_free(meta);
+    }
+
+    /**
+     * Build a cgraph using the given builder function.
+     *
+     * The built cgraph will be stored in `ctx.gf`
+     */
+    void build_graph(std::function<void(ggml_context *, ggml_cgraph *)> builder_fn) {
+        ggml_free(ctx_gf);
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ buf_compute_meta.size(),
+            /*.mem_buffer =*/ buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+
+        ctx_gf = ggml_init(params);
+        ggml_backend_sched_reset(sched.get());
+        gf = ggml_new_graph_custom(ctx_gf, max_nodes, false);
+
+        builder_fn(ctx_gf, gf);
+        ggml_backend_sched_alloc_graph(sched.get(), gf);
+    }
+
+    ggml_status compute() {
+        ggml_status status = ggml_backend_sched_graph_compute(sched.get(), gf);
+        return status;
+    }
+
+    void set_tensor_data(const std::string & name, const void * data) {
+        ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str());
+        if (!t) {
+            throw std::runtime_error(string_format("tensor not found: %s", name.c_str()));
+        }
+        ggml_backend_tensor_set(t, data, 0, ggml_nbytes(t));
+    }
+
+    std::pair<ggml_tensor *, std::vector<uint8_t>> get_tensor_data(const std::string & name) {
+        ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str());
+        if (!t) {
+            throw std::runtime_error(string_format("tensor not found: %s", name.c_str()));
+        }
+        std::vector<uint8_t> data(ggml_nbytes(t));
+        ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
+        return std::make_pair(t, data);
+    }
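+
+    // look up a weight tensor by printf-style name, for example:
+    //   get_weight("decoder.layers.%d.conv.weight", i_start + 1)
+    // (throws if the tensor is not present in the loaded GGUF)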
+    ggml_tensor * get_weight(const char * fmt, ...) {
+        std::vector<char> str(128);
+        va_list va;
+        va_start(va, fmt);
+        vsnprintf(str.data(), 128, fmt, va);
+        va_end(va);
+        auto it = tensors.find(str.data());
+        if (it == tensors.end()) {
+            throw std::runtime_error(string_format("weight tensor not found: %s", str.data()));
+        }
+        return it->second;
+    }
+
+    ~mimi_ggml_ctx() {
+        ggml_free(ctx_data);
+        gguf_free(ctx_gguf);
+        ggml_backend_buffer_free(buf);
+    }
+};
+
+///////////////////////////////////////////////////////////////////////////
+// extension to ggml.h
+// TODO: add these ops to the library (ofc with a more optimized kernel)
+
+
+// mode: (0) constant, (1) reflect, (2) replicate, (3) circular
+// value is only used in "constant"
+// only "constant" with 0.0f and "replicate" are implemented here
+static ggml_tensor * ggml_pad_ext(ggml_context * ctx0, ggml_tensor * x, int mode,
+        int64_t pad_left, int64_t pad_right, float value = 0.0f) {
+    GGML_ASSERT(value == 0.0f); // we can technically use ggml_arange, but for simplification we only support 0.0f
+    GGML_ASSERT(mode == 0 || mode == 2);
+    if (pad_left > 0) {
+        ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_left, x->ne[1]);
+        if (mode == 0) {
+            tmp = ggml_scale(ctx0, tmp, value);
+        } else if (mode == 2) {
+            ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], 0); // get first column
+            tmp = ggml_repeat(ctx0, elem, tmp);
+        }
+        x = ggml_concat(ctx0, tmp, x, 0);
+    }
+    if (pad_right > 0) {
+        ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_right, x->ne[1]);
+        if (mode == 0) {
+            tmp = ggml_scale(ctx0, tmp, value);
+        } else if (mode == 2) {
+            int64_t last = x->ne[0] - 1;
+            ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], last * ggml_element_size(x)); // get last column
+            tmp = ggml_repeat(ctx0, elem, tmp);
+        }
+        x = ggml_concat(ctx0, x, tmp, 0);
+    }
+    return x;
+}
+
+
+
+
+///////////////////////////////////////////////////////////////////////////
+// MimiConv and MimiConvTranspose
+
+static int64_t div_ceil(int64_t a, int64_t b) {
+    return a / b + (a % b ? 1 : 0);
+}
+
+static ggml_tensor * mimi_conv_1d(ggml_context * ctx0, ggml_tensor * x,
+        ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool pad_zero = true) {
+    int64_t kernel_size = (kernel->ne[0] - 1) * dilation + 1;
+    int64_t p_total = kernel_size - stride; // padding total
+    int64_t p_half  = p_total / 2;
+
+    int64_t n_frames  = div_ceil(x->ne[0] - kernel_size + p_total, stride);
+    int64_t ideal_len = n_frames * stride + kernel_size - p_total;
+    int64_t p_extra   = ideal_len - x->ne[0];
+
+    int64_t p_right = (mimi_config.causal ? 0 : p_half) + p_extra;
+    int64_t p_left  = p_total - (mimi_config.causal ? 0 : p_half);
+
+    x = ggml_pad_ext(ctx0, x, pad_zero ? 0 : 2, p_left, p_right);
+
+    x = ggml_conv_1d(ctx0, kernel, x, stride, 0, dilation);
+    if (bias) {
+        x = ggml_add(ctx0, x, bias);
+    }
+    ggml_set_name(x, "mimi_conv_1d");
+    return x;
+}
+
+static ggml_tensor * mimi_conv_transpose_1d(ggml_context * ctx0, ggml_tensor * x,
+        ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool depthwise) {
+    GGML_ASSERT(x->ne[1] == kernel->ne[2]);
+    int64_t n_rows = x->ne[1];
+    int64_t kernel_size = kernel->ne[0];
+    int64_t p_total = kernel_size - stride; // padding total
+
+    int64_t p_right = mimi_config.causal
+        ?
(float)p_total / mimi_config.trim_right_ratio + : p_total / 2; + int64_t p_left = p_total - p_right; + + ggml_tensor * out = nullptr; + + if (depthwise) { + for (int64_t ir = 0; ir < n_rows; ir++) { + ggml_tensor * row = ggml_view_1d(ctx0, x, + x->ne[0], ir*x->ne[0]*ggml_element_size(x)); + ggml_tensor * krn = ggml_view_1d(ctx0, kernel, + kernel->ne[0], ir*kernel->ne[0]*ggml_element_size(kernel)); + row = ggml_conv_transpose_1d(ctx0, krn, row, stride, 0, dilation); + // unpad (remove p_right and p_left columns) + row = ggml_view_1d(ctx0, row, row->ne[0] - p_total, p_left*ggml_element_size(row)); + + // TODO: concat can be slow, we should use ggml_view_1d/ggml_cpy to avoid realloc + out = out ? ggml_concat(ctx0, out, row, 1) : row; + } + + } else { + out = ggml_conv_transpose_1d(ctx0, kernel, x, stride, 0, dilation); + // unpad + out = ggml_view_2d(ctx0, out, + out->ne[0] - p_total, out->ne[1], + out->nb[1], p_left*ggml_element_size(out)); + } + + if (bias) { + out = ggml_add(ctx0, out, bias); + } + + return out; +} + + + +/////////////////////////////////////////////////////////////////////////// + +// based on MimiEncoder +// SEANet encoder as used by Mimi. +struct mimi_encoder_decoder { + mimi_ggml_ctx & ctx; + struct layer { + bool is_elu = false; + bool is_resnet = false; + bool is_transposed_conv = false; + ggml_tensor * conv_0_w; + ggml_tensor * conv_0_b; + ggml_tensor * conv_1_w; + ggml_tensor * conv_1_b; + int stride = 1; + }; + std::vector layers; + + std::array repeated_pattern = {1, 4, 7, 10}; + + mimi_encoder_decoder(mimi_ggml_ctx & ctx): ctx(ctx) { + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.0.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.0.conv.bias"), + }); + for (int i = 0; i < (int)repeated_pattern.size(); ++i) { + int i_start = repeated_pattern[i]; + // upsampling layers + layers.push_back({ + .is_elu = true, // layer (i_start) + }); + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.%d.conv.weight", i_start + 1), + .conv_0_b = ctx.get_weight("decoder.layers.%d.conv.bias", i_start + 1), + .stride = mimi_config.upsampling_ratio[i], + .is_transposed_conv = true, + }); + // residual layers + layers.push_back({ + .is_resnet = true, + .conv_0_w = ctx.get_weight("decoder.layers.%d.block.1.conv.weight", i_start + 2), + .conv_0_b = ctx.get_weight("decoder.layers.%d.block.1.conv.bias", i_start + 2), + .conv_1_w = ctx.get_weight("decoder.layers.%d.block.3.conv.weight", i_start + 2), + .conv_1_b = ctx.get_weight("decoder.layers.%d.block.3.conv.bias", i_start + 2), + }); + } + layers.push_back({ + .is_elu = true, // layer 13 + }); + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.14.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.14.conv.bias"), + }); + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input) { + ggml_tensor * x = input; + + for (auto & layer : layers) { + if (layer.is_elu) { + x = ggml_elu(ctx0, x); + } else if (layer.is_resnet) { + ggml_tensor * residual = x; + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, 1, 1); + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_1_w, layer.conv_1_b, 1, 1); + x = ggml_add(ctx0, x, residual); + } else { + x = layer.is_transposed_conv + ? 
mimi_conv_transpose_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1, false) + : mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1); + } + } + + return x; + } +}; + +struct mimi_transformer { + struct layer { + ggml_tensor * inp_norm_w; + ggml_tensor * inp_norm_b; + + ggml_tensor * attn_q; + ggml_tensor * attn_k; + ggml_tensor * attn_v; + ggml_tensor * attn_o; + ggml_tensor * attn_post_norm_w; + ggml_tensor * attn_post_norm_b; + ggml_tensor * attn_layer_scale; + + ggml_tensor * ffn_up; + ggml_tensor * ffn_down; + ggml_tensor * mlp_layer_scale; + }; + std::vector layers; + + mimi_transformer(mimi_ggml_ctx & ctx, const char * prefix, int n_layers) { + for (int il = 0; il < n_layers; il++) { + layers.push_back({ + .inp_norm_w = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.weight", prefix, il), + .inp_norm_b = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.bias", prefix, il), + + .attn_q = ctx.get_weight("%s_transformer.layers.%d.self_attn.q_proj.weight", prefix, il), + .attn_k = ctx.get_weight("%s_transformer.layers.%d.self_attn.k_proj.weight", prefix, il), + .attn_v = ctx.get_weight("%s_transformer.layers.%d.self_attn.v_proj.weight", prefix, il), + .attn_o = ctx.get_weight("%s_transformer.layers.%d.self_attn.o_proj.weight", prefix, il), + .attn_post_norm_w = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.weight", prefix, il), + .attn_post_norm_b = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.bias", prefix, il), + .attn_layer_scale = ctx.get_weight("%s_transformer.layers.%d.self_attn_layer_scale.scale", prefix, il), + + .ffn_up = ctx.get_weight("%s_transformer.layers.%d.mlp.fc1.weight", prefix, il), + .ffn_down = ctx.get_weight("%s_transformer.layers.%d.mlp.fc2.weight", prefix, il), + .mlp_layer_scale = ctx.get_weight("%s_transformer.layers.%d.mlp_layer_scale.scale", prefix, il), + }); + } + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input, ggml_tensor * inp_pos) { + int n_tokens = input->ne[1]; + ggml_tensor * x = input; + + auto layer_norm = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b) { + x = ggml_norm(ctx0, x, mimi_config.norm_eps); + x = ggml_mul(ctx0, x, w); + x = ggml_add(ctx0, x, b); + return x; + }; + + ggml_tensor * residual = input; + + for (auto & layer : layers) { + residual = x; + + // input layer norm + x = layer_norm(x, layer.inp_norm_w, layer.inp_norm_b); + + // self attention + { + ggml_tensor * q = ggml_mul_mat(ctx0, layer.attn_q, x); + ggml_tensor * k = ggml_mul_mat(ctx0, layer.attn_k, x); + ggml_tensor * v = ggml_mul_mat(ctx0, layer.attn_v, x); + + int n_embd_head = mimi_config.n_embd / mimi_config.n_head; + q = ggml_reshape_3d(ctx0, q, n_embd_head, mimi_config.n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, mimi_config.n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, mimi_config.n_head_kv, n_tokens); + + int n_rot = n_embd_head; + q = ggml_rope_inplace(ctx0, q, inp_pos, n_rot, 0); + q = ggml_cont(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3)); + + k = ggml_rope_inplace(ctx0, k, inp_pos, n_rot, 0); + k = ggml_cont(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3)); + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // mimic behavior of llama.cpp + kq = ggml_scale_inplace(ctx0, kq, 1.0f / std::sqrt(n_embd_head)); + ggml_tensor * kq_masked = ggml_diag_mask_inf_inplace(ctx0, kq, n_tokens); + kq = ggml_soft_max_inplace(ctx0, kq_masked); + + v = ggml_cont(ctx0, ggml_permute(ctx0, v, 1, 2, 0, 3)); + 
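+                // note: shapes below are in ggml order (ne[0] first), with
+                // n_embd_head = n_embd / n_head:
+                //   kq: [n_tokens, n_tokens,    n_head] -- masked, softmax-ed scores
+                //   v : [n_tokens, n_embd_head, n_head] -- transposed for the kqv matmul below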
+ ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + kqv = ggml_reshape_3d(ctx0, kqv, n_embd_head, n_tokens, mimi_config.n_head); + kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + kqv = ggml_cont_2d(ctx0, kqv, mimi_config.n_embd, n_tokens); + + x = ggml_mul_mat(ctx0, layer.attn_o, kqv); + } + + // residual + x = ggml_mul(ctx0, x, layer.attn_layer_scale); + x = ggml_add(ctx0, x, residual); + + residual = x; + x = layer_norm(x, layer.attn_post_norm_w, layer.attn_post_norm_b); + + // mlp + { + x = ggml_mul_mat(ctx0, layer.ffn_up, x); + x = ggml_gelu(ctx0, x); + x = ggml_mul_mat(ctx0, layer.ffn_down, x); + } + + // residual + x = ggml_mul(ctx0, x, layer.mlp_layer_scale); + x = ggml_add(ctx0, x, residual); + } + + return x; + } +}; + +struct mimi_residual_vector_quantizer { + struct component { + ggml_tensor * codebook; + }; + + ggml_tensor * semantic_inp_proj; + std::vector semantic_components; + ggml_tensor * semantic_out_proj; + + ggml_tensor * acoustic_inp_proj; + std::vector acoustic_components; + ggml_tensor * acoustic_out_proj; + + mimi_residual_vector_quantizer(mimi_ggml_ctx & ctx) { + semantic_inp_proj = ctx.get_weight("quantizer.semantic_rvq.input_proj.weight"); + semantic_out_proj = ctx.get_weight("quantizer.semantic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_semantic_components; i++) { + semantic_components.push_back({ + .codebook = ctx.get_weight("quantizer.semantic_rvq.layers.%d.codebook", i), + }); + } + acoustic_inp_proj = ctx.get_weight("quantizer.acoustic_rvq.input_proj.weight"); + acoustic_out_proj = ctx.get_weight("quantizer.acoustic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_acoustic_components; i++) { + acoustic_components.push_back({ + .codebook = ctx.get_weight("quantizer.acoustic_rvq.layers.%d.codebook", i), + }); + } + } + + // the input has shape [n_codes, n_codes_per_embd] + // first row is semantic, the rest are acoustic + // example: [ [semantic], [acoustic1], [acoustic2], ... 
]
+    ggml_tensor * decode(ggml_context * ctx0, ggml_tensor * input) {
+        GGML_ASSERT(input->type == GGML_TYPE_I32);
+
+        size_t  n_semantic       = semantic_components.size();
+        int64_t n_codes_per_embd = (n_semantic + acoustic_components.size());
+        int64_t n_codes          = input->ne[0] / n_codes_per_embd;
+
+        GGML_ASSERT(input->ne[0] % n_codes_per_embd == 0);
+
+        ggml_tensor * out_s = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes);
+        ggml_tensor * out_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes);
+        out_s = ggml_scale(ctx0, out_s, 0.0f); // clear
+        out_a = ggml_scale(ctx0, out_a, 0.0f); // clear
+
+        for (size_t ir = 0; ir < (size_t)n_codes_per_embd; ir++) {
+            ggml_tensor * row = ggml_view_1d(ctx0, input, n_codes, ir*n_codes*ggml_element_size(input));
+            if (ir < n_semantic) {
+                // semantic
+                ggml_tensor * codebook = semantic_components[ir].codebook;
+                ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row);
+                out_s = ggml_add(ctx0, out_s, embd);
+            } else {
+                // acoustic
+                ggml_tensor * codebook = acoustic_components[ir-n_semantic].codebook;
+                ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row);
+                out_a = ggml_add(ctx0, out_a, embd);
+            }
+        }
+
+        out_s = ggml_mul_mat(ctx0, semantic_out_proj, out_s);
+        out_a = ggml_mul_mat(ctx0, acoustic_out_proj, out_a);
+
+        return ggml_add(ctx0, out_s, out_a);
+    }
+};
+
+
+
+///////////////////////////////////////////////////////////////////////////
+// main program
+
+int main(int argc, const char ** argv) {
+    if (argc < 3) {
+        fprintf(stderr, "Usage: %s model.gguf codes.txt [output.wav]\n", argv[0]);
+        fprintf(stderr, "  Format of codes.txt file: one code per line\n");
+        fprintf(stderr, "  Replace codes.txt with dummy0 or dummy1 for testing\n");
+        fprintf(stderr, "    dummy0: uses codes 0, 1, 2, ..., 95, used for logits matching\n");
+        fprintf(stderr, "    dummy1: uses codes that output a 'hey hello there' sound\n");
+        return 1;
+    }
+
+    const char * model_path = argv[1];
+    const char * codes_path = argv[2];
+    const char * out_path   = argc < 4 ?
"output.wav" : argv[3]; + + mimi_ggml_ctx ctx; + ctx.load_gguf(model_path); + + // initialize components + mimi_encoder_decoder decoder(ctx); + mimi_transformer transformer(ctx, "decoder", mimi_config.num_hidden_layers); + mimi_residual_vector_quantizer quantizer(ctx); + + // load codes + std::vector codes; + if (strcmp(codes_path, "dummy0") == 0) { + printf("Using dummy0 codes\n"); + codes.resize(32 * 3); // [n_codes = 3, n_codes_per_embd = 32] + int n = 0; + for (int c = 0; c < 32; c++) { + for (int r = 0; r < 3; r++) { + codes[r*32 + c] = n++; + } + } + } else if (strcmp(codes_path, "dummy1") == 0) { + printf("Using dummy1 codes\n"); + codes = { + 1263 ,1597 ,1596 ,1477 ,1540 ,1720 ,1433 ,118 ,1066 ,1968 ,1096 ,232 ,418 ,566 ,1653 ,2010 , + 1029 ,1874 ,77 ,1803 ,123 ,908 ,97 ,1616 ,595 ,1170 ,1654 ,1211 ,1967 ,1579 ,1846 ,1462 , + 1962 ,175 ,1539 ,742 ,1065 ,1226 ,19 ,955 ,528 ,1031 ,659 ,1687 ,1173 ,1802 ,1031 ,1714 , + 1986 ,582 ,367 ,112 ,1245 ,1386 ,759 ,532 ,1472 ,1790 ,802 ,1213 ,1543 ,1916 ,1251 ,309 , + 1962 ,1280 ,1943 ,878 ,1588 ,1989 ,568 ,1463 ,1814 ,1095 ,103 ,583 ,976 ,998 ,871 ,587 , + 247 ,1698 ,1817 ,1024 ,268 ,597 ,45 ,1608 ,1880 ,2047 ,759 ,1578 ,1612 ,49 ,1031 ,1076 , + 927 ,1202 ,1601 ,1719 ,1670 ,412 ,568 ,1838 ,341 ,1265 ,1279 ,830 ,1997 ,32 ,1369 ,1686 , + 1307 ,419 ,1143 ,324 ,325 ,572 ,1597 ,1920 ,795 ,915 ,610 ,2000 ,819 ,718 ,1235 ,282 , + 1912 ,1911 ,141 ,1069 ,1485 ,642 ,1370 ,732 ,284 ,1407 ,1591 ,1002 ,939 ,671 ,951 ,1411 , + 1887 ,460 ,1588 ,1636 ,1312 ,232 ,969 ,1513 ,1336 ,1185 ,1660 ,4 ,926 ,1243 ,1077 ,1379 , + 704 ,85 ,257 ,1302 ,1029 ,1717 ,899 ,1345 ,355 ,1915 ,1007 ,315 ,1283 ,779 ,415 ,335 , + 1848 ,1786 ,469 ,295 ,380 ,1736 ,393 ,765 ,1921 ,836 ,374 ,1649 ,52 ,1633 ,759 ,548 , + 1922 ,47 ,564 ,893 ,34 ,131 ,1063 ,1657 ,474 ,1960 ,1255 ,1275 ,92 ,976 ,1217 ,483 , + 105 ,1746 ,1158 ,1557 ,1001 ,512 ,1668 ,1255 ,1045 ,1596 ,613 ,1272 ,1366 ,1147 ,411 ,831 , + 349 ,692 ,1435 ,2005 ,1465 ,37 ,892 ,95 ,460 ,557 ,1315 ,259 ,1978 ,1838 ,1232 ,2003 , + 1197 ,111 ,1953 ,1297 ,1843 ,671 ,1687 ,91 ,1788 ,1138 ,1896 ,399 ,615 ,758 ,1423 ,365 , + 288 ,632 ,876 ,875 ,1156 ,345 ,1189 ,638 ,1527 ,1981 ,1925 ,333 ,1353 ,473 ,1913 ,1443 , + 1634 ,1373 ,803 ,420 ,192 ,1440 ,1593 ,1925 ,784 ,831 ,552 ,807 ,1942 ,1289 ,612 ,511 , + 968 ,1091 ,30 ,828 ,1611 ,1241 ,1985 ,596 ,273 ,529 ,1182 ,302 ,726 ,1942 ,733 ,1590 , + 1564 ,214 ,1156 ,1722 ,1215 ,1837 ,1729 ,1823 ,672 ,116 ,340 ,396 ,721 ,462 ,1615 ,1380 , + 1459 ,1553 ,636 ,586 ,1148 ,1147 ,1941 ,471 ,876 ,127 ,1938 ,2002 ,1563 ,1121 ,857 ,1179 , + 1983 ,1324 ,1726 ,1445 ,295 ,270 ,896 ,1947 ,1740 ,1211 ,128 ,1266 ,734 ,715 ,1562 ,285 , + 1139 ,304 ,526 ,653 ,1270 ,320 ,484 ,22 ,687 ,1065 ,489 ,827 ,993 ,1654 ,431 ,1552 , + 1418 ,1604 ,455 ,841 ,412 ,848 ,475 ,540 ,1903 ,575 ,584 ,300 ,1079 ,189 ,1481 ,893 , + 228 ,1577 ,429 ,635 ,106 ,1536 ,176 ,348 ,1733 ,1570 ,537 ,1840 ,798 ,410 ,1714 ,1318 , + 487 ,332 ,1109 ,1744 ,283 ,692 ,681 ,1744 ,1008 ,1715 ,1956 ,1066 ,1768 ,1645 ,139 ,1967 , + 897 ,132 ,1010 ,1932 ,277 ,1536 ,1541 ,952 ,19 ,88 ,1663 ,1232 ,1681 ,1878 ,1241 ,1805 , + 89 ,1401 ,544 ,1061 ,1166 ,267 ,1351 ,1998 ,1623 ,1898 ,425 ,1320 ,2006 ,865 ,1981 ,823 , + 1243 ,471 ,485 ,1765 ,391 ,1281 ,1607 ,1418 ,116 ,1702 ,1725 ,512 ,1088 ,1375 ,1994 ,1738 , + 725 ,1471 ,811 ,1251 ,1156 ,1664 ,898 ,1511 ,1872 ,1717 ,444 ,1005 ,254 ,103 ,202 ,1769 , + 1511 ,433 ,284 ,721 ,1741 ,56 ,615 ,916 ,887 ,1253 ,916 ,535 ,1666 ,1713 ,741 ,873 , + 447 ,492 ,388 ,321 ,1860 ,1456 ,1658 ,1682 ,848 ,462 ,2034 ,1368 ,1609 ,1887 
,510 ,1516 , + }; + } else { + std::ifstream fin(codes_path); + if (!fin) { + fprintf(stderr, "Error: cannot open codes file: %s\n", codes_path); + return 1; + } + std::string line; + while (std::getline(fin, line)) { + // Skip empty lines + if (line.empty()) continue; + try { + int code = std::stoi(line); + codes.push_back(code); + } catch (const std::exception& e) { + fprintf(stderr, "Error parsing code: %s\n", line.c_str()); + return 1; + } + } + if (codes.empty()) { + fprintf(stderr, "Error: no codes found in file: %s\n", codes_path); + return 1; + } + + printf("Loaded %d codes from %s\n", (int)codes.size(), codes_path); + } + + // build cgraph + int n_pos = -1; + int n_codes = codes.size(); + int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; + GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiple of n_codes_per_embd"); + + ctx.build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf) { + ggml_tensor * inp_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_codes); + ggml_set_name(inp_dec, "inp_dec"); + ggml_set_input(inp_dec); + + // RVQ + ggml_tensor * embeddings = quantizer.decode(ctx_gf, inp_dec); + + // upsample + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = mimi_conv_transpose_1d(ctx_gf, embeddings, ctx.get_weight("upsample.conv.weight"), nullptr, 2, 1, true); + + // transformer + n_pos = embeddings->ne[0]; + ggml_tensor * pos_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_dec, "pos_dec"); + ggml_set_input(pos_dec); + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = transformer.forward(ctx_gf, embeddings, pos_dec); + + // SEANET decoder + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + ggml_tensor * output = decoder.forward(ctx_gf, embeddings); + + ggml_set_name(output, "output"); + ggml_set_output(output); + ggml_build_forward_expand(gf, output); + }); + + // position data + std::vector pos_data(1024); + for (int i = 0; i < (int)pos_data.size(); i++) { + pos_data[i] = i; + } + ctx.set_tensor_data("pos_dec", pos_data.data()); + + // code data (need to transpose it) + // code [n_codes, n_codes_per_embd] -> [n_codes_per_embd, n_codes] + std::vector codes_t(n_codes_per_embd * n_codes); + for (int i = 0; i < n_codes / n_codes_per_embd; i++) { + for (int j = 0; j < n_codes_per_embd; j++) { + int src_idx = i * n_codes_per_embd + j; + int dst_idx = j * (n_codes / n_codes_per_embd) + i; + codes_t[dst_idx] = codes[src_idx]; + } + } + ctx.set_tensor_data("inp_dec", codes_t.data()); + + ctx.compute(); + + auto output = ctx.get_tensor_data("output"); + auto output_tensor = output.first; + auto output_data = output.second; + printf("Output shape: [%lld, %lld]\n", output_tensor->ne[0], output_tensor->ne[1]); + + // print first 20 values + for (int i = 0; i < 20; i++) { + printf("%2.4f, ", ((float *)output_data.data())[i]); + } + printf("...\n"); + + // write to wav + std::vector wav_data(output_data.size() / sizeof(float)); + for (size_t i = 0; i < wav_data.size(); i++) { + wav_data[i] = ((float *)output_data.data())[i]; + } + printf("Writing to %s\n", out_path); + save_wav16(out_path, wav_data, 24000); +} From efeaa5712cb6489b9a704daf670d043e5e758347 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Mar 2025 09:06:00 +0100 Subject: [PATCH 02/31] fix llama-tts --- examples/tts/tts.cpp | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git 
a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 4cc42e1674ccc..b3461b5d273ef 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -71,46 +71,6 @@ static void print_usage(int, char ** argv) { LOG("\n"); } -struct wav_header { - char riff[4] = {'R', 'I', 'F', 'F'}; - uint32_t chunk_size; - char wave[4] = {'W', 'A', 'V', 'E'}; - char fmt[4] = {'f', 'm', 't', ' '}; - uint32_t fmt_chunk_size = 16; - uint16_t audio_format = 1; // PCM - uint16_t num_channels = 1; // Mono - uint32_t sample_rate; - uint32_t byte_rate; - uint16_t block_align; - uint16_t bits_per_sample = 16; - char data[4] = {'d', 'a', 't', 'a'}; - uint32_t data_size; -}; - -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate) { - std::ofstream file(fname, std::ios::binary); - if (!file) { - LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); - return false; - } - - wav_header header; - header.sample_rate = sample_rate; - header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); - header.block_align = header.num_channels * (header.bits_per_sample / 8); - header.data_size = data.size() * (header.bits_per_sample / 8); - header.chunk_size = 36 + header.data_size; - - file.write(reinterpret_cast(&header), sizeof(header)); - - for (const auto & sample : data) { - int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); - file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); - } - - return file.good(); -} - static void fill_hann_window(int length, bool periodic, float * output) { int offset = -1; if (periodic) { From a98f19918d7e6cff600d1bf0db15ea9cb9bff0da Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Mar 2025 09:51:10 +0100 Subject: [PATCH 03/31] put mimi_model into a shared header --- examples/tts/CMakeLists.txt | 2 +- examples/tts/README-mimi.md | 2 +- examples/tts/mimi-model.cpp | 720 ++++++++++++++++++++++++++++++++++++ examples/tts/mimi-model.h | 32 ++ examples/tts/mimi.cpp | 677 +-------------------------------- 5 files changed, 762 insertions(+), 671 deletions(-) create mode 100644 examples/tts/mimi-model.cpp create mode 100644 examples/tts/mimi-model.h diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt index f76d834b18fec..39e0a92c5acb4 100644 --- a/examples/tts/CMakeLists.txt +++ b/examples/tts/CMakeLists.txt @@ -5,7 +5,7 @@ target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) set(TARGET llama-mimi) -add_executable(${TARGET} mimi.cpp) +add_executable(${TARGET} mimi.cpp mimi-model.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/tts/README-mimi.md b/examples/tts/README-mimi.md index b46f5f77b95d0..6576a118291ad 100644 --- a/examples/tts/README-mimi.md +++ b/examples/tts/README-mimi.md @@ -24,7 +24,7 @@ cmake --build build -j --target llama-mimi # output: output.wav -# alternatively, use "dummy1" to get a "hey hello there" sample output file +# alternatively, use "dummy1" to get a "wah hello there" sample output file ./build/bin/llama-mimi kyutai-mimi.gguf dummy1 ``` diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp new file mode 100644 index 0000000000000..31ff86256ae10 --- /dev/null +++ b/examples/tts/mimi-model.cpp @@ -0,0 +1,720 @@ +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-cpu.h" 
+#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "gguf.h" + +#include "common.h" +#include "mimi-model.h" + +#include +#include +#include +#include +#include +#include +#include + +/** + * Implementation of Kyutai's Mimi model using GGML. + * Based on this research: https://github.com/ngxson/ggml-easy/blob/master/demo/kyutai-mimi.cpp + * + * NOTE: only decoder is working for now. + * + * Background: + * - The audio codes can be generated using any Mimi-based model, for example: Moshi, Hibiki, Sesame, etc + * - Audio codes must be in the order: (1 semantic component, 31 acoustic components) repeated N times + * + * How it works? + * 1. Audio code passed to RVQ (mimi_residual_vector_quantizer) to get the latent code + * 2. The latent code is passed to a mimi_conv_transpose_1d (depthwise) to upscale + * 3. The upscaled code is passed to transformer, it converts N frames to N frames + * 4. The output embeddings is then passed to SEANet (mimi_encoder_decoder) to get the final waveform + * 5. Waveform is written to a file + */ + +// copied from https://huggingface.co/kyutai/mimi/blob/main/config.json +struct mimi_config_t { + bool causal = true; + int sample_rate = 24000; + int max_position_embeddings = 8000; + int num_hidden_layers = 8; + int n_embd = 512; + int n_ffn = 2048; + int n_head = 8; + int n_head_kv = 8; + int n_rot = 64; + float norm_eps = 1e-5; + float rope_theta = 10000.0f; + int sliding_window = 250; + std::array upsampling_ratio = {8, 6, 5, 4}; + std::array downsampling_ratio = {4, 5, 6, 8}; // reverse of upsampling_ratio + // vector quantizer + float frame_rate = 12.5; + int audio_channels = 1; + int codebook_size = 2048; + int codebook_dim = 256; + int n_semantic_components = 1; + int n_acoustic_components = 31; + // decode + float trim_right_ratio = 1.0f; + int n_codes_per_frame = (sliding_window / 2) * (n_semantic_components + n_acoustic_components); +} mimi_config; + +// Adapted from https://github.com/ngxson/ggml-easy/blob/master/ggml-easy.h +struct mimi_ggml_ctx { + gguf_context * ctx_gguf = nullptr; + ggml_context * ctx_data = nullptr; + ggml_context * ctx_gf = nullptr; + + // CPU-only for now, as many kernels are missing and we actually get less performance with GPU + ggml_backend_t backend = nullptr; + ggml_backend_buffer_t buf = nullptr; + ggml_backend_sched_ptr sched; + + ggml_cgraph * gf = nullptr; + std::vector buf_compute_meta; + int max_nodes = 16 * 1024; + + std::unordered_map tensors; + + mimi_ggml_ctx() { + backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + auto buft = ggml_backend_get_default_buffer_type(backend); + sched.reset( + ggml_backend_sched_new(&backend, &buft, 1, max_nodes, false) + ); + buf_compute_meta.resize(max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); + } + + void load_gguf(const char * fname) { + ggml_context * meta = nullptr; + + gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &meta, + }; + + ctx_gguf = gguf_init_from_file(fname, params); + + // load tensors + const int n_tensors = gguf_get_n_tensors(ctx_gguf); + + std::vector read_buf; + ggml_init_params ggml_params = { + /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ctx_data = ggml_init(ggml_params); + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + ggml_free(meta); + throw std::runtime_error("cannot open model file for loading tensors"); + } + + // add tensors to context + for (int i = 0; i < n_tensors; ++i) { + const 
char * name = gguf_get_tensor_name(ctx_gguf, i); + ggml_tensor * t = ggml_get_tensor(meta, name); + ggml_tensor * cur = ggml_dup_tensor(ctx_data, t); + ggml_set_name(cur, name); + tensors.insert({name, cur}); + } + + // alloc memory and offload data + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_data, buft); + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + const size_t offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i); + // printf("%s: Loading tensor \"%s\"\n", __func__, name); + fin.seekg(offset, std::ios::beg); + if (!fin) { + ggml_free(meta); + throw std::runtime_error(string_format("failed to seek for tensor: %s", name)); + } + int num_bytes = ggml_nbytes(cur); + if (ggml_backend_buft_is_host(buft)) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(cur->data), num_bytes); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + } + } + printf("%s: Loaded %d tensors from %s\n", __func__, n_tensors, fname); + fin.close(); + + ggml_free(meta); + } + + /** + * Build a cgraph using the given builder function. + * + * The built cgraph will be stored in `ctx.gf` + */ + void build_graph(std::function builder_fn) { + ggml_free(ctx_gf); + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_gf = ggml_init(params); + ggml_backend_sched_reset(sched.get()); + gf = ggml_new_graph_custom(ctx_gf, max_nodes, false); + + builder_fn(ctx_gf, gf); + ggml_backend_sched_alloc_graph(sched.get(), gf); + } + + ggml_status compute() { + ggml_status status = ggml_backend_sched_graph_compute(sched.get(), gf); + return status; + } + + void set_tensor_data(const std::string & name, const void * data) { + ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); + if (!t) { + throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); + } + ggml_backend_tensor_set(t, data, 0, ggml_nbytes(t)); + } + + std::pair> get_tensor_data(const std::string & name) { + ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); + if (!t) { + throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); + } + std::vector data(ggml_nbytes(t)); + ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); + return std::make_pair(t, data); + } + + ggml_tensor * get_weight(const char *fmt, ...) 
{ + std::vector str(128); + va_list va; + va_start(va, fmt); + vsnprintf(str.data(), 128, fmt, va); + va_end(va); + auto it = tensors.find(str.data()); + if (it == tensors.end()) { + throw std::runtime_error(string_format("weight tensor not found: %s", str.data())); + } + return it->second; + } + + ~mimi_ggml_ctx() { + ggml_free(ctx_data); + gguf_free(ctx_gguf); + ggml_backend_buffer_free(buf); + } +}; + +/////////////////////////////////////////////////////////////////////////// +// extension to ggml.h +// TODO: add these ops to the library (ofc with a more optimized kernel) + + +// mode: (0) constant, (1) reflect, (2) replicate, (3) circular +// value is only used in "constant" +// only "constant" with 0.0f and "replicate" are implemented here +static ggml_tensor * ggml_pad_ext(ggml_context * ctx0, ggml_tensor * x, int mode, + int64_t pad_left, int64_t pad_right, float value = 0.0f) { + GGML_ASSERT(value == 0.0f); // we can technically use ggml_arange, but for simplication we only support 0.0f + GGML_ASSERT(mode == 0 || mode == 2); + if (pad_left > 0) { + ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_left, x->ne[1]); + if (mode == 0) { + tmp = ggml_scale(ctx0, tmp, value); + } else if (mode == 2) { + ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], 0); // get first column + tmp = ggml_repeat(ctx0, elem, tmp); + } + x = ggml_concat(ctx0, tmp, x, 0); + } + if (pad_right > 0) { + ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_right, x->ne[1]); + if (mode == 0) { + tmp = ggml_scale(ctx0, tmp, value); + } else if (mode == 2) { + int64_t last = x->ne[0] - 1; + ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], last * ggml_element_size(x)); // get last column + tmp = ggml_repeat(ctx0, elem, tmp); + } + x = ggml_concat(ctx0, x, tmp, 0); + } + return x; +} + + + + +/////////////////////////////////////////////////////////////////////////// +// MimiConv and MimiConvTranspose + +static int64_t div_ceil(int64_t a, int64_t b) { + return a / b + (a % b ? 1 : 0); +} + +static ggml_tensor * mimi_conv_1d(ggml_context * ctx0, ggml_tensor * x, + ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool pad_zero = true) { + int64_t kernel_size = (kernel->ne[0] - 1) * dilation + 1; + int64_t p_total = kernel_size - stride; // padding total + int64_t p_half = p_total / 2; + + int64_t n_frames = div_ceil(x->ne[0] - kernel_size + p_total, stride); + int64_t ideal_len = n_frames * stride + kernel_size - p_total; + int64_t p_extra = ideal_len - x->ne[0]; + + int64_t p_right = (mimi_config.causal ? 0 : p_half) + p_extra; + int64_t p_left = p_total - (mimi_config.causal ? 0 : p_half); + + x = ggml_pad_ext(ctx0, x, pad_zero ? 0 : 2, p_left, p_right); + + x = ggml_conv_1d(ctx0, kernel, x, stride, 0, dilation); + if (bias) { + x = ggml_add(ctx0, x, bias); + } + ggml_set_name(x, "mimi_conv_1d"); + return x; +} + +static ggml_tensor * mimi_conv_transpose_1d(ggml_context * ctx0, ggml_tensor * x, + ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool depthwise) { + GGML_ASSERT(x->ne[1] == kernel->ne[2]); + int64_t n_rows = x->ne[1]; + int64_t kernel_size = kernel->ne[0]; + int64_t p_total = kernel_size - stride; // padding total + + int64_t p_right = mimi_config.causal + ? 
(float)p_total / mimi_config.trim_right_ratio + : p_total / 2; + int64_t p_left = p_total - p_right; + + ggml_tensor * out = nullptr; + + if (depthwise) { + for (int64_t ir = 0; ir < n_rows; ir++) { + ggml_tensor * row = ggml_view_1d(ctx0, x, + x->ne[0], ir*x->ne[0]*ggml_element_size(x)); + ggml_tensor * krn = ggml_view_1d(ctx0, kernel, + kernel->ne[0], ir*kernel->ne[0]*ggml_element_size(kernel)); + row = ggml_conv_transpose_1d(ctx0, krn, row, stride, 0, dilation); + // unpad (remove p_right and p_left columns) + row = ggml_view_1d(ctx0, row, row->ne[0] - p_total, p_left*ggml_element_size(row)); + + // TODO: concat can be slow, we should use ggml_view_1d/ggml_cpy to avoid realloc + out = out ? ggml_concat(ctx0, out, row, 1) : row; + } + + } else { + out = ggml_conv_transpose_1d(ctx0, kernel, x, stride, 0, dilation); + // unpad + out = ggml_view_2d(ctx0, out, + out->ne[0] - p_total, out->ne[1], + out->nb[1], p_left*ggml_element_size(out)); + } + + if (bias) { + out = ggml_add(ctx0, out, bias); + } + + return out; +} + + + +/////////////////////////////////////////////////////////////////////////// + +// based on MimiEncoder +// SEANet encoder as used by Mimi. +struct mimi_encoder_decoder { + mimi_ggml_ctx & ctx; + struct layer { + bool is_elu = false; + bool is_resnet = false; + bool is_transposed_conv = false; + ggml_tensor * conv_0_w; + ggml_tensor * conv_0_b; + ggml_tensor * conv_1_w; + ggml_tensor * conv_1_b; + int stride = 1; + }; + std::vector layers; + + std::array repeated_pattern = {1, 4, 7, 10}; + + mimi_encoder_decoder(mimi_ggml_ctx & ctx): ctx(ctx) { + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.0.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.0.conv.bias"), + }); + for (int i = 0; i < (int)repeated_pattern.size(); ++i) { + int i_start = repeated_pattern[i]; + // upsampling layers + layers.push_back({ + .is_elu = true, // layer (i_start) + }); + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.%d.conv.weight", i_start + 1), + .conv_0_b = ctx.get_weight("decoder.layers.%d.conv.bias", i_start + 1), + .stride = mimi_config.upsampling_ratio[i], + .is_transposed_conv = true, + }); + // residual layers + layers.push_back({ + .is_resnet = true, + .conv_0_w = ctx.get_weight("decoder.layers.%d.block.1.conv.weight", i_start + 2), + .conv_0_b = ctx.get_weight("decoder.layers.%d.block.1.conv.bias", i_start + 2), + .conv_1_w = ctx.get_weight("decoder.layers.%d.block.3.conv.weight", i_start + 2), + .conv_1_b = ctx.get_weight("decoder.layers.%d.block.3.conv.bias", i_start + 2), + }); + } + layers.push_back({ + .is_elu = true, // layer 13 + }); + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.14.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.14.conv.bias"), + }); + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input) { + ggml_tensor * x = input; + + for (auto & layer : layers) { + if (layer.is_elu) { + x = ggml_elu(ctx0, x); + } else if (layer.is_resnet) { + ggml_tensor * residual = x; + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, 1, 1); + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_1_w, layer.conv_1_b, 1, 1); + x = ggml_add(ctx0, x, residual); + } else { + x = layer.is_transposed_conv + ? 
mimi_conv_transpose_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1, false) + : mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1); + } + } + + return x; + } +}; + +struct mimi_transformer { + struct layer { + ggml_tensor * inp_norm_w; + ggml_tensor * inp_norm_b; + + ggml_tensor * attn_q; + ggml_tensor * attn_k; + ggml_tensor * attn_v; + ggml_tensor * attn_o; + ggml_tensor * attn_post_norm_w; + ggml_tensor * attn_post_norm_b; + ggml_tensor * attn_layer_scale; + + ggml_tensor * ffn_up; + ggml_tensor * ffn_down; + ggml_tensor * mlp_layer_scale; + }; + std::vector layers; + + mimi_transformer(mimi_ggml_ctx & ctx, const char * prefix, int n_layers) { + for (int il = 0; il < n_layers; il++) { + layers.push_back({ + .inp_norm_w = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.weight", prefix, il), + .inp_norm_b = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.bias", prefix, il), + + .attn_q = ctx.get_weight("%s_transformer.layers.%d.self_attn.q_proj.weight", prefix, il), + .attn_k = ctx.get_weight("%s_transformer.layers.%d.self_attn.k_proj.weight", prefix, il), + .attn_v = ctx.get_weight("%s_transformer.layers.%d.self_attn.v_proj.weight", prefix, il), + .attn_o = ctx.get_weight("%s_transformer.layers.%d.self_attn.o_proj.weight", prefix, il), + .attn_post_norm_w = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.weight", prefix, il), + .attn_post_norm_b = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.bias", prefix, il), + .attn_layer_scale = ctx.get_weight("%s_transformer.layers.%d.self_attn_layer_scale.scale", prefix, il), + + .ffn_up = ctx.get_weight("%s_transformer.layers.%d.mlp.fc1.weight", prefix, il), + .ffn_down = ctx.get_weight("%s_transformer.layers.%d.mlp.fc2.weight", prefix, il), + .mlp_layer_scale = ctx.get_weight("%s_transformer.layers.%d.mlp_layer_scale.scale", prefix, il), + }); + } + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input, ggml_tensor * inp_pos) { + int n_tokens = input->ne[1]; + ggml_tensor * x = input; + + auto layer_norm = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b) { + x = ggml_norm(ctx0, x, mimi_config.norm_eps); + x = ggml_mul(ctx0, x, w); + x = ggml_add(ctx0, x, b); + return x; + }; + + ggml_tensor * residual = input; + + for (auto & layer : layers) { + residual = x; + + // input layer norm + x = layer_norm(x, layer.inp_norm_w, layer.inp_norm_b); + + // self attention + { + ggml_tensor * q = ggml_mul_mat(ctx0, layer.attn_q, x); + ggml_tensor * k = ggml_mul_mat(ctx0, layer.attn_k, x); + ggml_tensor * v = ggml_mul_mat(ctx0, layer.attn_v, x); + + int n_embd_head = mimi_config.n_embd / mimi_config.n_head; + q = ggml_reshape_3d(ctx0, q, n_embd_head, mimi_config.n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, mimi_config.n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, mimi_config.n_head_kv, n_tokens); + + int n_rot = n_embd_head; + q = ggml_rope_inplace(ctx0, q, inp_pos, n_rot, 0); + q = ggml_cont(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3)); + + k = ggml_rope_inplace(ctx0, k, inp_pos, n_rot, 0); + k = ggml_cont(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3)); + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // mimic behavior of llama.cpp + kq = ggml_scale_inplace(ctx0, kq, 1.0f / std::sqrt(n_embd_head)); + ggml_tensor * kq_masked = ggml_diag_mask_inf_inplace(ctx0, kq, n_tokens); + kq = ggml_soft_max_inplace(ctx0, kq_masked); + + v = ggml_cont(ctx0, ggml_permute(ctx0, v, 1, 2, 0, 3)); + 
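+                // note: same shape layout as the attention block in mimi.cpp:
+                //   kq: [n_tokens, n_tokens,    n_head] -- masked, softmax-ed scores
+                //   v : [n_tokens, n_embd_head, n_head] -- transposed for the kqv matmul below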
+ ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + kqv = ggml_reshape_3d(ctx0, kqv, n_embd_head, n_tokens, mimi_config.n_head); + kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + kqv = ggml_cont_2d(ctx0, kqv, mimi_config.n_embd, n_tokens); + + x = ggml_mul_mat(ctx0, layer.attn_o, kqv); + } + + // residual + x = ggml_mul(ctx0, x, layer.attn_layer_scale); + x = ggml_add(ctx0, x, residual); + + residual = x; + x = layer_norm(x, layer.attn_post_norm_w, layer.attn_post_norm_b); + + // mlp + { + x = ggml_mul_mat(ctx0, layer.ffn_up, x); + x = ggml_gelu(ctx0, x); + x = ggml_mul_mat(ctx0, layer.ffn_down, x); + } + + // residual + x = ggml_mul(ctx0, x, layer.mlp_layer_scale); + x = ggml_add(ctx0, x, residual); + } + + return x; + } +}; + +struct mimi_residual_vector_quantizer { + struct component { + ggml_tensor * codebook; + }; + + ggml_tensor * semantic_inp_proj; + std::vector semantic_components; + ggml_tensor * semantic_out_proj; + + ggml_tensor * acoustic_inp_proj; + std::vector acoustic_components; + ggml_tensor * acoustic_out_proj; + + mimi_residual_vector_quantizer(mimi_ggml_ctx & ctx) { + semantic_inp_proj = ctx.get_weight("quantizer.semantic_rvq.input_proj.weight"); + semantic_out_proj = ctx.get_weight("quantizer.semantic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_semantic_components; i++) { + semantic_components.push_back({ + .codebook = ctx.get_weight("quantizer.semantic_rvq.layers.%d.codebook", i), + }); + } + acoustic_inp_proj = ctx.get_weight("quantizer.acoustic_rvq.input_proj.weight"); + acoustic_out_proj = ctx.get_weight("quantizer.acoustic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_acoustic_components; i++) { + acoustic_components.push_back({ + .codebook = ctx.get_weight("quantizer.acoustic_rvq.layers.%d.codebook", i), + }); + } + } + + // the input has shape [n_codes, n_codes_per_embd] + // first row is semantic, the rest are acoustic + // example: [ [semantic], [acoustic1], [acoustic2], ... 
] + ggml_tensor * decode(ggml_context * ctx0, ggml_tensor * input) { + GGML_ASSERT(input->type == GGML_TYPE_I32); + + size_t n_semantic = semantic_components.size(); + int64_t n_codes_per_embd = (n_semantic + acoustic_components.size()); + int64_t n_codes = input->ne[0] / n_codes_per_embd; + + GGML_ASSERT(input->ne[0] % n_codes_per_embd == 0); + + ggml_tensor * out_s = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes); + ggml_tensor * out_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes); + out_s = ggml_scale(ctx0, out_s, 0.0f); // clear + out_a = ggml_scale(ctx0, out_a, 0.0f); // clear + + for (size_t ir = 0; ir < (size_t)n_codes_per_embd; ir++) { + ggml_tensor * row = ggml_view_1d(ctx0, input, n_codes, ir*n_codes*ggml_element_size(input)); + if (ir < n_semantic) { + // semantic + ggml_tensor * codebook = semantic_components[ir].codebook; + ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row); + out_s = ggml_add(ctx0, out_s, embd); + } else { + // acoustic + ggml_tensor * codebook = acoustic_components[ir-n_semantic].codebook; + ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row); + out_a = ggml_add(ctx0, out_a, embd); + } + } + + out_s = ggml_mul_mat(ctx0, semantic_out_proj, out_s); + out_a = ggml_mul_mat(ctx0, acoustic_out_proj, out_a); + + return ggml_add(ctx0, out_s, out_a); + } +}; + + +mimi_model::mimi_model(const char * fname, bool verbose) : verbose(verbose) { + ctx.reset(new mimi_ggml_ctx()); + ctx->load_gguf(fname); + + // initialize components + seanet_dec .reset(new mimi_encoder_decoder(*ctx)); + transformer_dec.reset(new mimi_transformer(*ctx, "decoder", mimi_config.num_hidden_layers)); + quantizer .reset(new mimi_residual_vector_quantizer(*ctx)); +} + +mimi_model::~mimi_model() { +} + +std::vector mimi_model::decode_frame(const std::vector & codes, int & n_past) { + // build cgraph + int n_pos = -1; + int n_codes = codes.size(); + int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; + GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiple of n_codes_per_embd"); + + ctx->build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf) { + ggml_tensor * inp_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_codes); + ggml_set_name(inp_dec, "inp_dec"); + ggml_set_input(inp_dec); + + // RVQ + ggml_tensor * embeddings = quantizer->decode(ctx_gf, inp_dec); + + // upsample + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = mimi_conv_transpose_1d(ctx_gf, embeddings, ctx->get_weight("upsample.conv.weight"), nullptr, 2, 1, true); + + // transformer + n_pos = embeddings->ne[0]; + ggml_tensor * pos_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_dec, "pos_dec"); + ggml_set_input(pos_dec); + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = transformer_dec->forward(ctx_gf, embeddings, pos_dec); + + // SEANET decoder + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + ggml_tensor * output = seanet_dec->forward(ctx_gf, embeddings); + + ggml_set_name(output, "output"); + ggml_set_output(output); + ggml_build_forward_expand(gf, output); + }); + + // position data + GGML_ASSERT(n_pos <= mimi_config.sliding_window); + std::vector pos_data(n_pos); + for (int i = 0; i < (int)pos_data.size(); i++) { + pos_data[i] = i + n_past; + } + n_past += n_pos; + if (verbose) { + printf("%s: n_pos: %d, n_past: %d\n", __func__, n_pos, n_past); + } + 
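    // note: positions are absolute (offset by n_past), so RoPE stays continuous
    // across consecutive frames; the GGML_ASSERT above guarantees that a single
    // frame never exceeds the transformer's sliding window (250 positions in the
    // default config)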
ctx->set_tensor_data("pos_dec", pos_data.data()); + + // code data (need to transpose it) + // code [n_codes, n_codes_per_embd] -> [n_codes_per_embd, n_codes] + std::vector codes_t(n_codes_per_embd * n_codes); + for (int i = 0; i < n_codes / n_codes_per_embd; i++) { + for (int j = 0; j < n_codes_per_embd; j++) { + int src_idx = i * n_codes_per_embd + j; + int dst_idx = j * (n_codes / n_codes_per_embd) + i; + codes_t[dst_idx] = codes[src_idx]; + } + } + ctx->set_tensor_data("inp_dec", codes_t.data()); + + ctx->compute(); + + auto output = ctx->get_tensor_data("output"); + // auto output_tensor = output.first; + auto output_data = output.second; + // printf("Output shape: [%lld, %lld]\n", output_tensor->ne[0], output_tensor->ne[1]); + + std::vector wav_data(output_data.size() / sizeof(float)); + for (size_t i = 0; i < wav_data.size(); i++) { + wav_data[i] = ((float *)output_data.data())[i]; + } + + return wav_data; +} + +std::vector mimi_model::decode(const std::vector & codes) { + std::vector output; + + if (verbose) { + printf("%s: n_codes: %zu\n", __func__, codes.size()); + } + + int64_t t_start = ggml_time_ms(); + int n_frames = 0; + + int n_past = 0; + for (size_t i = 0; i < codes.size(); i += mimi_config.n_codes_per_frame) { + size_t remaining = std::min((size_t)mimi_config.n_codes_per_frame, codes.size() - i); + std::vector frame(codes.begin() + i, codes.begin() + i + remaining); + + auto wav_data = decode_frame(frame, n_past); + output.insert(output.end(), wav_data.begin(), wav_data.end()); + + n_frames++; + } + + int64_t t_end = ggml_time_ms(); + if (verbose) { + printf("%s: n_frames: %d, time: %" PRId64 "ms, per_frame: %" PRId64 "ms\n", __func__, n_frames, t_end - t_start, (t_end - t_start) / n_frames); + } + + return output; +} + +int mimi_model::get_sample_rate() const { + return mimi_config.sample_rate; +} diff --git a/examples/tts/mimi-model.h b/examples/tts/mimi-model.h new file mode 100644 index 0000000000000..d48c19b5476e3 --- /dev/null +++ b/examples/tts/mimi-model.h @@ -0,0 +1,32 @@ +#pragma once + +#include "ggml.h" +#include +#include + +struct mimi_ggml_ctx; +struct mimi_encoder_decoder; +struct mimi_transformer; +struct mimi_residual_vector_quantizer; + +struct mimi_model { + bool verbose = false; + std::unique_ptr ctx; + + std::unique_ptr seanet_dec; + std::unique_ptr transformer_dec; + std::unique_ptr quantizer; + + mimi_model(const char * fname, bool verbose = false); + ~mimi_model(); + + int get_sample_rate() const; + + std::vector decode(const std::vector & codes); + + // TODO: implement encoding pass + // std::vector encode(const std::vector & wav_data); + +private: + std::vector decode_frame(const std::vector & codes, int & n_past); +}; diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp index 2c5833faa277b..052f546b43a23 100644 --- a/examples/tts/mimi.cpp +++ b/examples/tts/mimi.cpp @@ -1,610 +1,17 @@ -#include "ggml.h" -#include "ggml-cpp.h" -#include "ggml-cpu.h" -#include "ggml-alloc.h" -#include "ggml-backend.h" -#include "gguf.h" - #include "common.h" +#include "mimi-model.h" -#include #include -#include #include -#include -#include - -/** - * Implementation of Kyutai's Mimi model using GGML. - * Based on this research: https://github.com/ngxson/ggml-easy/blob/master/demo/kyutai-mimi.cpp - * - * NOTE: only decoder is working for now. 
- * - * Background: - * - The audio codes can be generated using any Mimi-based model, for example: Moshi, Hibiki, Sesame, etc - * - Audio codes must be in the order: (1 semantic component, 31 acoustic components) repeated N times - * - * How it works? - * 1. Audio code passed to RVQ (mimi_residual_vector_quantizer) to get the latent code - * 2. The latent code is passed to a mimi_conv_transpose_1d (depthwise) to upscale - * 3. The upscaled code is passed to transformer, it converts N frames to N frames - * 4. The output embeddings is then passed to SEANet (mimi_encoder_decoder) to get the final waveform - * 5. Waveform is written to a file - */ - -// copied from https://huggingface.co/kyutai/mimi/blob/main/config.json -struct mimi_config_t { - bool causal = true; - int max_position_embeddings = 8000; - int num_hidden_layers = 8; - int n_embd = 512; - int n_ffn = 2048; - int n_head = 8; - int n_head_kv = 8; - int n_rot = 64; - float norm_eps = 1e-5; - float rope_theta = 10000.0f; - int sliding_window = 250; - std::array upsampling_ratio = {8, 6, 5, 4}; - std::array downsampling_ratio = {4, 5, 6, 8}; // reverse of upsampling_ratio - // vector quantizer - float frame_rate = 12.5; - int audio_channels = 1; - int codebook_size = 2048; - int codebook_dim = 256; - int n_semantic_components = 1; - int n_acoustic_components = 31; - // decode - float trim_right_ratio = 1.0f; -} mimi_config; - -// Adapted from https://github.com/ngxson/ggml-easy/blob/master/ggml-easy.h -struct mimi_ggml_ctx { - gguf_context * ctx_gguf = nullptr; - ggml_context * ctx_data = nullptr; - ggml_context * ctx_gf = nullptr; - - // CPU-only for now, as many kernels are missing and we actually get less performance with GPU - ggml_backend_t backend = nullptr; - ggml_backend_buffer_t buf = nullptr; - ggml_backend_sched_ptr sched; - - ggml_cgraph * gf = nullptr; - std::vector buf_compute_meta; - int max_nodes = 16 * 1024; - - std::unordered_map tensors; - - mimi_ggml_ctx() { - backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); - auto buft = ggml_backend_get_default_buffer_type(backend); - sched.reset( - ggml_backend_sched_new(&backend, &buft, 1, max_nodes, false) - ); - buf_compute_meta.resize(max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); - } - - void load_gguf(const char * fname) { - ggml_context * meta = nullptr; - - gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &meta, - }; - - ctx_gguf = gguf_init_from_file(fname, params); - - // load tensors - const int n_tensors = gguf_get_n_tensors(ctx_gguf); - - std::vector read_buf; - ggml_init_params ggml_params = { - /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ctx_data = ggml_init(ggml_params); - auto fin = std::ifstream(fname, std::ios::binary); - if (!fin) { - ggml_free(meta); - throw std::runtime_error("cannot open model file for loading tensors"); - } - - // add tensors to context - for (int i = 0; i < n_tensors; ++i) { - const char * name = gguf_get_tensor_name(ctx_gguf, i); - ggml_tensor * t = ggml_get_tensor(meta, name); - ggml_tensor * cur = ggml_dup_tensor(ctx_data, t); - ggml_set_name(cur, name); - tensors.insert({name, cur}); - } - - // alloc memory and offload data - ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_data, buft); - ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - for (int i = 0; i < n_tensors; ++i) { - const char * 
name = gguf_get_tensor_name(ctx_gguf, i); - ggml_tensor * cur = ggml_get_tensor(ctx_data, name); - const size_t offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i); - // printf("%s: Loading tensor \"%s\"\n", __func__, name); - fin.seekg(offset, std::ios::beg); - if (!fin) { - ggml_free(meta); - throw std::runtime_error(string_format("failed to seek for tensor: %s", name)); - } - int num_bytes = ggml_nbytes(cur); - if (ggml_backend_buft_is_host(buft)) { - // for the CPU and Metal backend, we can read directly into the tensor - fin.read(reinterpret_cast(cur->data), num_bytes); - } else { - // read into a temporary buffer first, then copy to device memory - read_buf.resize(num_bytes); - fin.read(reinterpret_cast(read_buf.data()), num_bytes); - ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); - } - } - printf("%s: Loaded %d tensors from %s\n", __func__, n_tensors, fname); - fin.close(); - - ggml_free(meta); - } - - /** - * Build a cgraph using the given builder function. - * - * The built cgraph will be stored in `ctx.gf` - */ - void build_graph(std::function builder_fn) { - ggml_free(ctx_gf); - struct ggml_init_params params = { - /*.mem_size =*/ buf_compute_meta.size(), - /*.mem_buffer =*/ buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ctx_gf = ggml_init(params); - ggml_backend_sched_reset(sched.get()); - gf = ggml_new_graph_custom(ctx_gf, max_nodes, false); - - builder_fn(ctx_gf, gf); - ggml_backend_sched_alloc_graph(sched.get(), gf); - } - - ggml_status compute() { - ggml_status status = ggml_backend_sched_graph_compute(sched.get(), gf); - return status; - } - - void set_tensor_data(const std::string & name, const void * data) { - ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); - if (!t) { - throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); - } - ggml_backend_tensor_set(t, data, 0, ggml_nbytes(t)); - } - - std::pair> get_tensor_data(const std::string & name) { - ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); - if (!t) { - throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); - } - std::vector data(ggml_nbytes(t)); - ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); - return std::make_pair(t, data); - } - - ggml_tensor * get_weight(const char *fmt, ...) 
{ - std::vector str(128); - va_list va; - va_start(va, fmt); - vsnprintf(str.data(), 128, fmt, va); - va_end(va); - auto it = tensors.find(str.data()); - if (it == tensors.end()) { - throw std::runtime_error(string_format("weight tensor not found: %s", str.data())); - } - return it->second; - } - - ~mimi_ggml_ctx() { - ggml_free(ctx_data); - gguf_free(ctx_gguf); - ggml_backend_buffer_free(buf); - } -}; - -/////////////////////////////////////////////////////////////////////////// -// extension to ggml.h -// TODO: add these ops to the library (ofc with a more optimized kernel) - - -// mode: (0) constant, (1) reflect, (2) replicate, (3) circular -// value is only used in "constant" -// only "constant" with 0.0f and "replicate" are implemented here -static ggml_tensor * ggml_pad_ext(ggml_context * ctx0, ggml_tensor * x, int mode, - int64_t pad_left, int64_t pad_right, float value = 0.0f) { - GGML_ASSERT(value == 0.0f); // we can technically use ggml_arange, but for simplication we only support 0.0f - GGML_ASSERT(mode == 0 || mode == 2); - if (pad_left > 0) { - ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_left, x->ne[1]); - if (mode == 0) { - tmp = ggml_scale(ctx0, tmp, value); - } else if (mode == 2) { - ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], 0); // get first column - tmp = ggml_repeat(ctx0, elem, tmp); - } - x = ggml_concat(ctx0, tmp, x, 0); - } - if (pad_right > 0) { - ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_right, x->ne[1]); - if (mode == 0) { - tmp = ggml_scale(ctx0, tmp, value); - } else if (mode == 2) { - int64_t last = x->ne[0] - 1; - ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], last * ggml_element_size(x)); // get last column - tmp = ggml_repeat(ctx0, elem, tmp); - } - x = ggml_concat(ctx0, x, tmp, 0); - } - return x; -} - - - - -/////////////////////////////////////////////////////////////////////////// -// MimiConv and MimiConvTranspose - -static int64_t div_ceil(int64_t a, int64_t b) { - return a / b + (a % b ? 1 : 0); -} - -static ggml_tensor * mimi_conv_1d(ggml_context * ctx0, ggml_tensor * x, - ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool pad_zero = true) { - int64_t kernel_size = (kernel->ne[0] - 1) * dilation + 1; - int64_t p_total = kernel_size - stride; // padding total - int64_t p_half = p_total / 2; - - int64_t n_frames = div_ceil(x->ne[0] - kernel_size + p_total, stride); - int64_t ideal_len = n_frames * stride + kernel_size - p_total; - int64_t p_extra = ideal_len - x->ne[0]; - - int64_t p_right = (mimi_config.causal ? 0 : p_half) + p_extra; - int64_t p_left = p_total - (mimi_config.causal ? 0 : p_half); - - x = ggml_pad_ext(ctx0, x, pad_zero ? 0 : 2, p_left, p_right); - - x = ggml_conv_1d(ctx0, kernel, x, stride, 0, dilation); - if (bias) { - x = ggml_add(ctx0, x, bias); - } - ggml_set_name(x, "mimi_conv_1d"); - return x; -} - -static ggml_tensor * mimi_conv_transpose_1d(ggml_context * ctx0, ggml_tensor * x, - ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool depthwise) { - GGML_ASSERT(x->ne[1] == kernel->ne[2]); - int64_t n_rows = x->ne[1]; - int64_t kernel_size = kernel->ne[0]; - int64_t p_total = kernel_size - stride; // padding total - - int64_t p_right = mimi_config.causal - ? 
(float)p_total / mimi_config.trim_right_ratio - : p_total / 2; - int64_t p_left = p_total - p_right; - - ggml_tensor * out = nullptr; - - if (depthwise) { - for (int64_t ir = 0; ir < n_rows; ir++) { - ggml_tensor * row = ggml_view_1d(ctx0, x, - x->ne[0], ir*x->ne[0]*ggml_element_size(x)); - ggml_tensor * krn = ggml_view_1d(ctx0, kernel, - kernel->ne[0], ir*kernel->ne[0]*ggml_element_size(kernel)); - row = ggml_conv_transpose_1d(ctx0, krn, row, stride, 0, dilation); - // unpad (remove p_right and p_left columns) - row = ggml_view_1d(ctx0, row, row->ne[0] - p_total, p_left*ggml_element_size(row)); - - // TODO: concat can be slow, we should use ggml_view_1d/ggml_cpy to avoid realloc - out = out ? ggml_concat(ctx0, out, row, 1) : row; - } - - } else { - out = ggml_conv_transpose_1d(ctx0, kernel, x, stride, 0, dilation); - // unpad - out = ggml_view_2d(ctx0, out, - out->ne[0] - p_total, out->ne[1], - out->nb[1], p_left*ggml_element_size(out)); - } - - if (bias) { - out = ggml_add(ctx0, out, bias); - } - - return out; -} - -/////////////////////////////////////////////////////////////////////////// - -// based on MimiEncoder -// SEANet encoder as used by Mimi. -struct mimi_encoder_decoder { - mimi_ggml_ctx & ctx; - struct layer { - bool is_elu = false; - bool is_resnet = false; - bool is_transposed_conv = false; - ggml_tensor * conv_0_w; - ggml_tensor * conv_0_b; - ggml_tensor * conv_1_w; - ggml_tensor * conv_1_b; - int stride = 1; - }; - std::vector layers; - - std::array repeated_pattern = {1, 4, 7, 10}; - - mimi_encoder_decoder(mimi_ggml_ctx & ctx): ctx(ctx) { - layers.push_back({ - .conv_0_w = ctx.get_weight("decoder.layers.0.conv.weight"), - .conv_0_b = ctx.get_weight("decoder.layers.0.conv.bias"), - }); - for (int i = 0; i < (int)repeated_pattern.size(); ++i) { - int i_start = repeated_pattern[i]; - // upsampling layers - layers.push_back({ - .is_elu = true, // layer (i_start) - }); - layers.push_back({ - .conv_0_w = ctx.get_weight("decoder.layers.%d.conv.weight", i_start + 1), - .conv_0_b = ctx.get_weight("decoder.layers.%d.conv.bias", i_start + 1), - .stride = mimi_config.upsampling_ratio[i], - .is_transposed_conv = true, - }); - // residual layers - layers.push_back({ - .is_resnet = true, - .conv_0_w = ctx.get_weight("decoder.layers.%d.block.1.conv.weight", i_start + 2), - .conv_0_b = ctx.get_weight("decoder.layers.%d.block.1.conv.bias", i_start + 2), - .conv_1_w = ctx.get_weight("decoder.layers.%d.block.3.conv.weight", i_start + 2), - .conv_1_b = ctx.get_weight("decoder.layers.%d.block.3.conv.bias", i_start + 2), - }); - } - layers.push_back({ - .is_elu = true, // layer 13 - }); - layers.push_back({ - .conv_0_w = ctx.get_weight("decoder.layers.14.conv.weight"), - .conv_0_b = ctx.get_weight("decoder.layers.14.conv.bias"), - }); - } - - ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input) { - ggml_tensor * x = input; - - for (auto & layer : layers) { - if (layer.is_elu) { - x = ggml_elu(ctx0, x); - } else if (layer.is_resnet) { - ggml_tensor * residual = x; - x = ggml_elu(ctx0, x); - x = mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, 1, 1); - x = ggml_elu(ctx0, x); - x = mimi_conv_1d(ctx0, x, layer.conv_1_w, layer.conv_1_b, 1, 1); - x = ggml_add(ctx0, x, residual); - } else { - x = layer.is_transposed_conv - ? 
mimi_conv_transpose_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1, false) - : mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1); - } - } - - return x; - } -}; - -struct mimi_transformer { - struct layer { - ggml_tensor * inp_norm_w; - ggml_tensor * inp_norm_b; - - ggml_tensor * attn_q; - ggml_tensor * attn_k; - ggml_tensor * attn_v; - ggml_tensor * attn_o; - ggml_tensor * attn_post_norm_w; - ggml_tensor * attn_post_norm_b; - ggml_tensor * attn_layer_scale; - - ggml_tensor * ffn_up; - ggml_tensor * ffn_down; - ggml_tensor * mlp_layer_scale; - }; - std::vector layers; - - mimi_transformer(mimi_ggml_ctx & ctx, const char * prefix, int n_layers) { - for (int il = 0; il < n_layers; il++) { - layers.push_back({ - .inp_norm_w = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.weight", prefix, il), - .inp_norm_b = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.bias", prefix, il), - - .attn_q = ctx.get_weight("%s_transformer.layers.%d.self_attn.q_proj.weight", prefix, il), - .attn_k = ctx.get_weight("%s_transformer.layers.%d.self_attn.k_proj.weight", prefix, il), - .attn_v = ctx.get_weight("%s_transformer.layers.%d.self_attn.v_proj.weight", prefix, il), - .attn_o = ctx.get_weight("%s_transformer.layers.%d.self_attn.o_proj.weight", prefix, il), - .attn_post_norm_w = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.weight", prefix, il), - .attn_post_norm_b = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.bias", prefix, il), - .attn_layer_scale = ctx.get_weight("%s_transformer.layers.%d.self_attn_layer_scale.scale", prefix, il), - - .ffn_up = ctx.get_weight("%s_transformer.layers.%d.mlp.fc1.weight", prefix, il), - .ffn_down = ctx.get_weight("%s_transformer.layers.%d.mlp.fc2.weight", prefix, il), - .mlp_layer_scale = ctx.get_weight("%s_transformer.layers.%d.mlp_layer_scale.scale", prefix, il), - }); - } - } - - ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input, ggml_tensor * inp_pos) { - int n_tokens = input->ne[1]; - ggml_tensor * x = input; - - auto layer_norm = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b) { - x = ggml_norm(ctx0, x, mimi_config.norm_eps); - x = ggml_mul(ctx0, x, w); - x = ggml_add(ctx0, x, b); - return x; - }; - - ggml_tensor * residual = input; - - for (auto & layer : layers) { - residual = x; - - // input layer norm - x = layer_norm(x, layer.inp_norm_w, layer.inp_norm_b); - - // self attention - { - ggml_tensor * q = ggml_mul_mat(ctx0, layer.attn_q, x); - ggml_tensor * k = ggml_mul_mat(ctx0, layer.attn_k, x); - ggml_tensor * v = ggml_mul_mat(ctx0, layer.attn_v, x); - - int n_embd_head = mimi_config.n_embd / mimi_config.n_head; - q = ggml_reshape_3d(ctx0, q, n_embd_head, mimi_config.n_head, n_tokens); - k = ggml_reshape_3d(ctx0, k, n_embd_head, mimi_config.n_head_kv, n_tokens); - v = ggml_reshape_3d(ctx0, v, n_embd_head, mimi_config.n_head_kv, n_tokens); - - int n_rot = n_embd_head; - q = ggml_rope_inplace(ctx0, q, inp_pos, n_rot, 0); - q = ggml_cont(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3)); - - k = ggml_rope_inplace(ctx0, k, inp_pos, n_rot, 0); - k = ggml_cont(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3)); - - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // mimic behavior of llama.cpp - kq = ggml_scale_inplace(ctx0, kq, 1.0f / std::sqrt(n_embd_head)); - ggml_tensor * kq_masked = ggml_diag_mask_inf_inplace(ctx0, kq, n_tokens); - kq = ggml_soft_max_inplace(ctx0, kq_masked); - - v = ggml_cont(ctx0, ggml_permute(ctx0, v, 1, 2, 0, 3)); - 
- ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); - kqv = ggml_reshape_3d(ctx0, kqv, n_embd_head, n_tokens, mimi_config.n_head); - kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - kqv = ggml_cont_2d(ctx0, kqv, mimi_config.n_embd, n_tokens); - - x = ggml_mul_mat(ctx0, layer.attn_o, kqv); - } - - // residual - x = ggml_mul(ctx0, x, layer.attn_layer_scale); - x = ggml_add(ctx0, x, residual); - - residual = x; - x = layer_norm(x, layer.attn_post_norm_w, layer.attn_post_norm_b); - - // mlp - { - x = ggml_mul_mat(ctx0, layer.ffn_up, x); - x = ggml_gelu(ctx0, x); - x = ggml_mul_mat(ctx0, layer.ffn_down, x); - } - - // residual - x = ggml_mul(ctx0, x, layer.mlp_layer_scale); - x = ggml_add(ctx0, x, residual); - } - - return x; - } -}; - -struct mimi_residual_vector_quantizer { - struct component { - ggml_tensor * codebook; - }; - - ggml_tensor * semantic_inp_proj; - std::vector semantic_components; - ggml_tensor * semantic_out_proj; - - ggml_tensor * acoustic_inp_proj; - std::vector acoustic_components; - ggml_tensor * acoustic_out_proj; - - mimi_residual_vector_quantizer(mimi_ggml_ctx & ctx) { - semantic_inp_proj = ctx.get_weight("quantizer.semantic_rvq.input_proj.weight"); - semantic_out_proj = ctx.get_weight("quantizer.semantic_rvq.output_proj.weight"); - for (int i = 0; i < mimi_config.n_semantic_components; i++) { - semantic_components.push_back({ - .codebook = ctx.get_weight("quantizer.semantic_rvq.layers.%d.codebook", i), - }); - } - acoustic_inp_proj = ctx.get_weight("quantizer.acoustic_rvq.input_proj.weight"); - acoustic_out_proj = ctx.get_weight("quantizer.acoustic_rvq.output_proj.weight"); - for (int i = 0; i < mimi_config.n_acoustic_components; i++) { - acoustic_components.push_back({ - .codebook = ctx.get_weight("quantizer.acoustic_rvq.layers.%d.codebook", i), - }); - } - } - - // the input has shape [n_codes, n_codes_per_embd] - // first row is semantic, the rest are acoustic - // example: [ [semantic], [acoustic1], [acoustic2], ... 
]
-    ggml_tensor * decode(ggml_context * ctx0, ggml_tensor * input) {
-        GGML_ASSERT(input->type == GGML_TYPE_I32);
-
-        size_t n_semantic = semantic_components.size();
-        int64_t n_codes_per_embd = (n_semantic + acoustic_components.size());
-        int64_t n_codes = input->ne[0] / n_codes_per_embd;
-
-        GGML_ASSERT(input->ne[0] % n_codes_per_embd == 0);
-
-        ggml_tensor * out_s = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes);
-        ggml_tensor * out_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes);
-        out_s = ggml_scale(ctx0, out_s, 0.0f); // clear
-        out_a = ggml_scale(ctx0, out_a, 0.0f); // clear
-
-        for (size_t ir = 0; ir < (size_t)n_codes_per_embd; ir++) {
-            ggml_tensor * row = ggml_view_1d(ctx0, input, n_codes, ir*n_codes*ggml_element_size(input));
-            if (ir < n_semantic) {
-                // semantic
-                ggml_tensor * codebook = semantic_components[ir].codebook;
-                ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row);
-                out_s = ggml_add(ctx0, out_s, embd);
-            } else {
-                // acoustic
-                ggml_tensor * codebook = acoustic_components[ir-n_semantic].codebook;
-                ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row);
-                out_a = ggml_add(ctx0, out_a, embd);
-            }
-        }
-
-        out_s = ggml_mul_mat(ctx0, semantic_out_proj, out_s);
-        out_a = ggml_mul_mat(ctx0, acoustic_out_proj, out_a);
-
-        return ggml_add(ctx0, out_s, out_a);
-    }
-};
-
-
-
-///////////////////////////////////////////////////////////////////////////
-// main program
-
 int main(int argc, const char ** argv) {
     if (argc < 3) {
         fprintf(stderr, "Usage: %s model.gguf codes.txt [output.wav]\n", argv[0]);
         fprintf(stderr, "    Format of codes.txt file: one code per line\n");
         fprintf(stderr, "    Replace codes.txt with dummy0 and dummy1 for testing\n");
         fprintf(stderr, "    dummy0: using code 1, 2, 3,..., 96, used for logits matching\n");
-        fprintf(stderr, "    dummy1: using code that will outputs 'hey hello there' sound\n");
+        fprintf(stderr, "    dummy1: using codes that will output a 'wah hello there' sound\n");
         return 1;
     }
 
@@ -612,14 +19,6 @@ int main(int argc, const char ** argv) {
     const char * model_path = argv[1];
     const char * codes_path = argv[2];
     const char * out_path = argc < 4 ?
"output.wav" : argv[3]; - mimi_ggml_ctx ctx; - ctx.load_gguf(model_path); - - // initialize components - mimi_encoder_decoder decoder(ctx); - mimi_transformer transformer(ctx, "decoder", mimi_config.num_hidden_layers); - mimi_residual_vector_quantizer quantizer(ctx); - // load codes std::vector codes; if (strcmp(codes_path, "dummy0") == 0) { @@ -693,78 +92,18 @@ int main(int argc, const char ** argv) { printf("Loaded %d codes from %s\n", (int)codes.size(), codes_path); } - // build cgraph - int n_pos = -1; - int n_codes = codes.size(); - int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; - GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiple of n_codes_per_embd"); - - ctx.build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf) { - ggml_tensor * inp_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_codes); - ggml_set_name(inp_dec, "inp_dec"); - ggml_set_input(inp_dec); - - // RVQ - ggml_tensor * embeddings = quantizer.decode(ctx_gf, inp_dec); - - // upsample - embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); - embeddings = mimi_conv_transpose_1d(ctx_gf, embeddings, ctx.get_weight("upsample.conv.weight"), nullptr, 2, 1, true); - - // transformer - n_pos = embeddings->ne[0]; - ggml_tensor * pos_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_pos); - ggml_set_name(pos_dec, "pos_dec"); - ggml_set_input(pos_dec); - embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); - embeddings = transformer.forward(ctx_gf, embeddings, pos_dec); - - // SEANET decoder - embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); - ggml_tensor * output = decoder.forward(ctx_gf, embeddings); - - ggml_set_name(output, "output"); - ggml_set_output(output); - ggml_build_forward_expand(gf, output); - }); - - // position data - std::vector pos_data(1024); - for (int i = 0; i < (int)pos_data.size(); i++) { - pos_data[i] = i; - } - ctx.set_tensor_data("pos_dec", pos_data.data()); - - // code data (need to transpose it) - // code [n_codes, n_codes_per_embd] -> [n_codes_per_embd, n_codes] - std::vector codes_t(n_codes_per_embd * n_codes); - for (int i = 0; i < n_codes / n_codes_per_embd; i++) { - for (int j = 0; j < n_codes_per_embd; j++) { - int src_idx = i * n_codes_per_embd + j; - int dst_idx = j * (n_codes / n_codes_per_embd) + i; - codes_t[dst_idx] = codes[src_idx]; - } - } - ctx.set_tensor_data("inp_dec", codes_t.data()); - - ctx.compute(); - - auto output = ctx.get_tensor_data("output"); - auto output_tensor = output.first; - auto output_data = output.second; - printf("Output shape: [%lld, %lld]\n", output_tensor->ne[0], output_tensor->ne[1]); + mimi_model model(model_path, true); + std::vector wav_data = model.decode(codes); // print first 20 values + printf("Number of output samples: %d\n", (int)wav_data.size()); + printf("First 20 samples:\n"); for (int i = 0; i < 20; i++) { - printf("%2.4f, ", ((float *)output_data.data())[i]); + printf("%2.4f, ", wav_data[i]); } printf("...\n"); // write to wav - std::vector wav_data(output_data.size() / sizeof(float)); - for (size_t i = 0; i < wav_data.size(); i++) { - wav_data[i] = ((float *)output_data.data())[i]; - } printf("Writing to %s\n", out_path); - save_wav16(out_path, wav_data, 24000); + save_wav16(out_path, wav_data, model.get_sample_rate()); } From 891273cf3a678ea4fb4845c35f60af49360b0dbf Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Mar 2025 13:08:42 +0100 Subject: [PATCH 04/31] mimi : non-transposed input codes --- 
examples/tts/mimi-model.cpp | 14 +++----
 examples/tts/mimi-model.h   |  1 +
 examples/tts/mimi.cpp       | 78 +++++++++++++++++++------------------
 3 files changed, 48 insertions(+), 45 deletions(-)

diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp
index 31ff86256ae10..92bb47a8365d7 100644
--- a/examples/tts/mimi-model.cpp
+++ b/examples/tts/mimi-model.cpp
@@ -24,7 +24,8 @@
  *
  * Background:
  * - The audio codes can be generated using any Mimi-based model, for example: Moshi, Hibiki, Sesame, etc
- * - Audio codes must be in the order: (1 semantic component, 31 acoustic components) repeated N times
+ * - Audio codes must be in the order: N semantic codes followed by (N*31) acoustic codes
+ *   (in other words, a 32 x N matrix: one row of N codes per codebook)
  *
  * How it works?
  * 1. Audio code passed to RVQ (mimi_residual_vector_quantizer) to get the latent code
@@ -653,23 +654,22 @@ std::vector mimi_model::decode_frame(const std::vector & codes, int
     for (int i = 0; i < (int)pos_data.size(); i++) {
         pos_data[i] = i + n_past;
     }
-    n_past += n_pos;
     if (verbose) {
         printf("%s: n_pos: %d, n_past: %d\n", __func__, n_pos, n_past);
     }
+    n_past += n_pos;
 
     ctx->set_tensor_data("pos_dec", pos_data.data());
 
-    // code data (need to transpose it)
-    // code [n_codes, n_codes_per_embd] -> [n_codes_per_embd, n_codes]
-    std::vector codes_t(n_codes_per_embd * n_codes);
+    // code data
+    /*std::vector codes_t(n_codes_per_embd * n_codes);
     for (int i = 0; i < n_codes / n_codes_per_embd; i++) {
         for (int j = 0; j < n_codes_per_embd; j++) {
             int src_idx = i * n_codes_per_embd + j;
             int dst_idx = j * (n_codes / n_codes_per_embd) + i;
             codes_t[dst_idx] = codes[src_idx];
         }
-    }
-    ctx->set_tensor_data("inp_dec", codes_t.data());
+    }*/
+    ctx->set_tensor_data("inp_dec", codes.data());
 
     ctx->compute();
 
diff --git a/examples/tts/mimi-model.h b/examples/tts/mimi-model.h
index d48c19b5476e3..c26fd3bc08e9f 100644
--- a/examples/tts/mimi-model.h
+++ b/examples/tts/mimi-model.h
@@ -22,6 +22,7 @@ struct mimi_model {
 
     int get_sample_rate() const;
 
+    // layout of codes: N semantic codes followed by (N*31) acoustic codes
     std::vector decode(const std::vector & codes);
 
     // TODO: implement encoding pass
diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp
index 052f546b43a23..421c9e418ecc6 100644
--- a/examples/tts/mimi.cpp
+++ b/examples/tts/mimi.cpp
@@ -5,6 +5,11 @@
 
 #include 
 
+/**
+ * This file is used for testing and to showcase how to use the "mimi_model" class.
+ * Please keep it simple and easy to understand.
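+ *
+ * Minimal usage sketch (API per mimi-model.h; element types assumed to be
+ * int codes / float samples; error handling omitted):
+ *   mimi_model model("kyutai-mimi.gguf", true); // true = verbose
+ *   std::vector<int> codes = ...; // N semantic codes, then N*31 acoustic codes
+ *   std::vector<float> wav = model.decode(codes);
+ *   save_wav16("output.wav", wav, model.get_sample_rate());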
+ */ + int main(int argc, const char ** argv) { if (argc < 3) { fprintf(stderr, "Usage: %s model.gguf codes.txt [output.wav]\n", argv[0]); @@ -23,48 +28,45 @@ int main(int argc, const char ** argv) { std::vector codes; if (strcmp(codes_path, "dummy0") == 0) { printf("Using dummy0 codes\n"); - codes.resize(32 * 3); // [n_codes = 3, n_codes_per_embd = 32] - int n = 0; - for (int c = 0; c < 32; c++) { - for (int r = 0; r < 3; r++) { - codes[r*32 + c] = n++; - } + codes.resize(32 * 3); // [n_codes_per_embd = 32, n_codes = 3] + for (int i = 0; i < (int)codes.size(); i++) { + codes[i] = i; } } else if (strcmp(codes_path, "dummy1") == 0) { printf("Using dummy1 codes\n"); codes = { - 1263 ,1597 ,1596 ,1477 ,1540 ,1720 ,1433 ,118 ,1066 ,1968 ,1096 ,232 ,418 ,566 ,1653 ,2010 , - 1029 ,1874 ,77 ,1803 ,123 ,908 ,97 ,1616 ,595 ,1170 ,1654 ,1211 ,1967 ,1579 ,1846 ,1462 , - 1962 ,175 ,1539 ,742 ,1065 ,1226 ,19 ,955 ,528 ,1031 ,659 ,1687 ,1173 ,1802 ,1031 ,1714 , - 1986 ,582 ,367 ,112 ,1245 ,1386 ,759 ,532 ,1472 ,1790 ,802 ,1213 ,1543 ,1916 ,1251 ,309 , - 1962 ,1280 ,1943 ,878 ,1588 ,1989 ,568 ,1463 ,1814 ,1095 ,103 ,583 ,976 ,998 ,871 ,587 , - 247 ,1698 ,1817 ,1024 ,268 ,597 ,45 ,1608 ,1880 ,2047 ,759 ,1578 ,1612 ,49 ,1031 ,1076 , - 927 ,1202 ,1601 ,1719 ,1670 ,412 ,568 ,1838 ,341 ,1265 ,1279 ,830 ,1997 ,32 ,1369 ,1686 , - 1307 ,419 ,1143 ,324 ,325 ,572 ,1597 ,1920 ,795 ,915 ,610 ,2000 ,819 ,718 ,1235 ,282 , - 1912 ,1911 ,141 ,1069 ,1485 ,642 ,1370 ,732 ,284 ,1407 ,1591 ,1002 ,939 ,671 ,951 ,1411 , - 1887 ,460 ,1588 ,1636 ,1312 ,232 ,969 ,1513 ,1336 ,1185 ,1660 ,4 ,926 ,1243 ,1077 ,1379 , - 704 ,85 ,257 ,1302 ,1029 ,1717 ,899 ,1345 ,355 ,1915 ,1007 ,315 ,1283 ,779 ,415 ,335 , - 1848 ,1786 ,469 ,295 ,380 ,1736 ,393 ,765 ,1921 ,836 ,374 ,1649 ,52 ,1633 ,759 ,548 , - 1922 ,47 ,564 ,893 ,34 ,131 ,1063 ,1657 ,474 ,1960 ,1255 ,1275 ,92 ,976 ,1217 ,483 , - 105 ,1746 ,1158 ,1557 ,1001 ,512 ,1668 ,1255 ,1045 ,1596 ,613 ,1272 ,1366 ,1147 ,411 ,831 , - 349 ,692 ,1435 ,2005 ,1465 ,37 ,892 ,95 ,460 ,557 ,1315 ,259 ,1978 ,1838 ,1232 ,2003 , - 1197 ,111 ,1953 ,1297 ,1843 ,671 ,1687 ,91 ,1788 ,1138 ,1896 ,399 ,615 ,758 ,1423 ,365 , - 288 ,632 ,876 ,875 ,1156 ,345 ,1189 ,638 ,1527 ,1981 ,1925 ,333 ,1353 ,473 ,1913 ,1443 , - 1634 ,1373 ,803 ,420 ,192 ,1440 ,1593 ,1925 ,784 ,831 ,552 ,807 ,1942 ,1289 ,612 ,511 , - 968 ,1091 ,30 ,828 ,1611 ,1241 ,1985 ,596 ,273 ,529 ,1182 ,302 ,726 ,1942 ,733 ,1590 , - 1564 ,214 ,1156 ,1722 ,1215 ,1837 ,1729 ,1823 ,672 ,116 ,340 ,396 ,721 ,462 ,1615 ,1380 , - 1459 ,1553 ,636 ,586 ,1148 ,1147 ,1941 ,471 ,876 ,127 ,1938 ,2002 ,1563 ,1121 ,857 ,1179 , - 1983 ,1324 ,1726 ,1445 ,295 ,270 ,896 ,1947 ,1740 ,1211 ,128 ,1266 ,734 ,715 ,1562 ,285 , - 1139 ,304 ,526 ,653 ,1270 ,320 ,484 ,22 ,687 ,1065 ,489 ,827 ,993 ,1654 ,431 ,1552 , - 1418 ,1604 ,455 ,841 ,412 ,848 ,475 ,540 ,1903 ,575 ,584 ,300 ,1079 ,189 ,1481 ,893 , - 228 ,1577 ,429 ,635 ,106 ,1536 ,176 ,348 ,1733 ,1570 ,537 ,1840 ,798 ,410 ,1714 ,1318 , - 487 ,332 ,1109 ,1744 ,283 ,692 ,681 ,1744 ,1008 ,1715 ,1956 ,1066 ,1768 ,1645 ,139 ,1967 , - 897 ,132 ,1010 ,1932 ,277 ,1536 ,1541 ,952 ,19 ,88 ,1663 ,1232 ,1681 ,1878 ,1241 ,1805 , - 89 ,1401 ,544 ,1061 ,1166 ,267 ,1351 ,1998 ,1623 ,1898 ,425 ,1320 ,2006 ,865 ,1981 ,823 , - 1243 ,471 ,485 ,1765 ,391 ,1281 ,1607 ,1418 ,116 ,1702 ,1725 ,512 ,1088 ,1375 ,1994 ,1738 , - 725 ,1471 ,811 ,1251 ,1156 ,1664 ,898 ,1511 ,1872 ,1717 ,444 ,1005 ,254 ,103 ,202 ,1769 , - 1511 ,433 ,284 ,721 ,1741 ,56 ,615 ,916 ,887 ,1253 ,916 ,535 ,1666 ,1713 ,741 ,873 , - 447 ,492 ,388 ,321 ,1860 ,1456 ,1658 ,1682 ,848 ,462 
,2034 ,1368 ,1609 ,1887 ,510 ,1516 , + 1049 ,1415 ,1962 ,914 ,1372 ,704 ,1922 ,2036 ,288 ,968 ,193 ,1139 ,897 ,897 ,1243 ,1511 , + 1597 ,175 ,1280 ,1202 ,1911 ,85 ,47 ,692 ,632 ,251 ,1553 ,1735 ,1577 ,132 ,471 ,433 , + 1325 ,1539 ,1943 ,1601 ,141 ,257 ,564 ,1435 ,876 ,1096 ,636 ,61 ,1497 ,1010 ,485 ,284 , + 839 ,776 ,878 ,1719 ,1069 ,1302 ,893 ,2005 ,875 ,908 ,586 ,2001 ,186 ,1932 ,1765 ,721 , + 592 ,1046 ,1588 ,1670 ,1485 ,1141 ,34 ,1465 ,1156 ,1938 ,435 ,753 ,1418 ,277 ,391 ,1741 , + 1440 ,117 ,723 ,412 ,642 ,1717 ,131 ,37 ,345 ,112 ,1979 ,2034 ,1822 ,1536 ,1281 ,56 , + 1341 ,803 ,568 ,568 ,1370 ,1995 ,1063 ,892 ,273 ,895 ,1226 ,354 ,1726 ,1541 ,1607 ,615 , + 985 ,1499 ,1736 ,1838 ,702 ,1345 ,1657 ,511 ,1774 ,1787 ,945 ,1927 ,947 ,952 ,1418 ,916 , + 1239 ,1457 ,1021 ,341 ,284 ,882 ,474 ,1559 ,1923 ,273 ,1330 ,1406 ,1782 ,19 ,116 ,887 , + 1146 ,1307 ,983 ,1237 ,1407 ,1350 ,1960 ,1255 ,878 ,1979 ,1500 ,1939 ,1415 ,88 ,1702 ,1253 , + 1778 ,2 ,10 ,1279 ,999 ,1549 ,1049 ,373 ,1355 ,1200 ,1466 ,1009 ,75 ,2042 ,1725 ,916 , + 1636 ,1135 ,833 ,830 ,1758 ,2015 ,1275 ,1675 ,287 ,744 ,89 ,430 ,1724 ,1232 ,1692 ,535 , + 1485 ,1287 ,973 ,1815 ,314 ,2020 ,424 ,1085 ,982 ,1994 ,1563 ,1269 ,1769 ,1681 ,1082 ,1666 , + 1622 ,1039 ,1209 ,32 ,679 ,732 ,976 ,1462 ,805 ,402 ,1150 ,170 ,1529 ,2013 ,350 ,1175 , + 757 ,1124 ,1091 ,1369 ,1061 ,415 ,1217 ,1135 ,1360 ,1578 ,1205 ,1785 ,1835 ,1241 ,14 ,716 , + 480 ,716 ,681 ,1686 ,1624 ,335 ,865 ,1356 ,1688 ,307 ,366 ,541 ,1262 ,1167 ,59 ,269 , + 1899 ,1798 ,1606 ,1307 ,1549 ,1814 ,114 ,483 ,958 ,1919 ,1179 ,898 ,834 ,1526 ,386 ,447 , + 1481 ,201 ,779 ,419 ,430 ,1451 ,1000 ,156 ,1062 ,615 ,1353 ,414 ,1214 ,1487 ,882 ,32 , + 840 ,1517 ,334 ,1143 ,823 ,454 ,725 ,1298 ,1325 ,649 ,1737 ,913 ,685 ,761 ,2010 ,63 , + 1397 ,1299 ,765 ,1158 ,1809 ,1299 ,1585 ,1776 ,625 ,1539 ,830 ,1563 ,461 ,308 ,1438 ,321 , + 82 ,886 ,1836 ,325 ,1976 ,761 ,359 ,1136 ,1720 ,2036 ,904 ,719 ,526 ,1567 ,145 ,1860 , + 1565 ,1786 ,1400 ,1696 ,232 ,1736 ,512 ,518 ,1895 ,1854 ,1584 ,1393 ,1869 ,1702 ,789 ,1986 , + 116 ,521 ,150 ,1597 ,727 ,1916 ,815 ,1826 ,1382 ,653 ,1596 ,286 ,1373 ,177 ,1397 ,1009 , + 1449 ,353 ,877 ,93 ,266 ,1853 ,1255 ,872 ,1974 ,556 ,1885 ,857 ,992 ,5 ,1921 ,1849 , + 1038 ,1912 ,464 ,795 ,747 ,56 ,124 ,431 ,1868 ,609 ,855 ,1522 ,912 ,1709 ,1507 ,1062 , + 1015 ,1357 ,1487 ,4 ,253 ,1871 ,933 ,215 ,1228 ,633 ,1306 ,2024 ,1453 ,900 ,457 ,471 , + 436 ,1311 ,870 ,1032 ,134 ,984 ,1983 ,1103 ,1627 ,1627 ,414 ,1845 ,583 ,1699 ,1458 ,2018 , + 150 ,450 ,1114 ,369 ,267 ,1273 ,1136 ,1578 ,1063 ,1820 ,120 ,779 ,652 ,1266 ,1929 ,1213 , + 159 ,297 ,1703 ,819 ,93 ,247 ,1366 ,144 ,1617 ,1428 ,812 ,121 ,1637 ,1620 ,289 ,1557 , + 1414 ,971 ,476 ,1685 ,428 ,1802 ,653 ,1290 ,614 ,1663 ,1528 ,1344 ,798 ,1027 ,1305 ,990 , + 1740 ,1154 ,1839 ,912 ,731 ,602 ,1064 ,1508 ,834 ,1387 ,252 ,745 ,1034 ,1102 ,965 ,696 , + 1971 ,1729 ,666 ,282 ,1993 ,1551 ,1703 ,1124 ,1628 ,1725 ,107 ,808 ,1096 ,1753 ,500 ,677 , }; } else { std::ifstream fin(codes_path); From 6dca237b1c12a5732088001dc483e9495ed7eb16 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 29 Mar 2025 18:19:56 +0100 Subject: [PATCH 05/31] tts : add sesame csm --- examples/tts/CMakeLists.txt | 6 + examples/tts/convert_csm_to_gguf.py | 330 ++++++++++++++++++++++++++++ examples/tts/tts-csm.cpp | 120 ++++++++++ src/llama-arch.cpp | 48 ++-- src/llama-arch.h | 3 + src/llama-model.cpp | 24 +- src/llama-model.h | 5 + 7 files changed, 513 insertions(+), 23 deletions(-) create mode 100644 examples/tts/convert_csm_to_gguf.py create mode 100644 
examples/tts/tts-csm.cpp diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt index c72bd814c3b31..17d2ea08a074e 100644 --- a/examples/tts/CMakeLists.txt +++ b/examples/tts/CMakeLists.txt @@ -3,3 +3,9 @@ add_executable(${TARGET} tts.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-tts-csm) +add_executable(${TARGET} tts-csm.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/tts/convert_csm_to_gguf.py b/examples/tts/convert_csm_to_gguf.py new file mode 100644 index 0000000000000..ff91098a993ad --- /dev/null +++ b/examples/tts/convert_csm_to_gguf.py @@ -0,0 +1,330 @@ +import os +import sys +import argparse +import logging +import torch +from safetensors.torch import load_file +from typing import Union, Any, Dict +from pathlib import Path +from torch import Tensor +from huggingface_hub import hf_hub_download + +cur_path = sys.path +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent.parent.parent / 'gguf-py')) +import gguf + +sys.path = cur_path + +logger = logging.getLogger("csm") + + +# This converts directly one safetensors file to 2 GGUFs +# It is easier to do this way, rather than convert to 2 smaller HF models and then convert to GGUF +# This is because the Sesame model does not have built-in tokenizer + +def get_field_data(reader: gguf.GGUFReader, key: str) -> Any: + field = reader.get_field(key) + return field.contents() if field else None + +# copied from https://github.com/SesameAILabs/csm/blob/main/models.py +class Llama_3_2_1B: + vocab_size=128_256 + num_layers=16 + num_heads=32 + num_kv_heads=8 + embed_dim=2048 + max_seq_len=2048 + intermediate_dim=8192 + attn_dropout=0.0 + norm_eps=1e-5 + rope_base=500_000 + scale_factor=32 + + def write_gguf_metadata(self, fout: gguf.GGUFWriter, fvocab: gguf.GGUFReader): + arch = get_field_data(fvocab, gguf.Keys.General.ARCHITECTURE) + assert arch == "llama" + fout.add_type("model") + fout.add_block_count(self.num_layers) + fout.add_context_length(self.max_seq_len) + fout.add_feed_forward_length(self.intermediate_dim) + fout.add_embedding_length(self.embed_dim) + # attn + fout.add_head_count(self.num_heads) + fout.add_head_count_kv(self.num_kv_heads) + fout.add_rope_freq_base(self.rope_base) + # fout.add_rope_scaling_factor(self.scale_factor) # breaks if this is added + fout.add_rope_dimension_count(self.embed_dim // self.num_heads) + fout.add_layer_norm_rms_eps(self.norm_eps) + fout.add_key_length(self.embed_dim // self.num_heads) + fout.add_value_length(self.embed_dim // self.num_heads) + # vocab + fout.add_vocab_size(self.vocab_size) + fout.add_tokenizer_model(get_field_data(fvocab, gguf.Keys.Tokenizer.MODEL)) + fout.add_tokenizer_pre(get_field_data(fvocab, gguf.Keys.Tokenizer.PRE)) + fout.add_token_list(get_field_data(fvocab, gguf.Keys.Tokenizer.LIST)[:self.vocab_size]) + fout.add_token_types(get_field_data(fvocab, gguf.Keys.Tokenizer.TOKEN_TYPE)[:self.vocab_size]) + fout.add_token_merges(get_field_data(fvocab, gguf.Keys.Tokenizer.MERGES)) + fout.add_bos_token_id(get_field_data(fvocab, gguf.Keys.Tokenizer.BOS_ID)) + fout.add_eos_token_id(get_field_data(fvocab, gguf.Keys.Tokenizer.EOS_ID)) + +class Llama_3_2_100M(Llama_3_2_1B): + vocab_size=65_632 #128_256 + num_layers=4 + num_heads=8 + 
num_kv_heads=2 + embed_dim=1024 + max_seq_len=2048 + intermediate_dim=8192 + attn_dropout=0.0 + norm_eps=1e-5 + rope_base=500_000 + scale_factor=32 + +class CSMModelConverter: + state_dict: Dict[str, Tensor] + gguf_writer_backbone: gguf.GGUFWriter + gguf_writer_decoder: gguf.GGUFWriter + gguf_reader_vocab: gguf.GGUFReader + fname_out: Path + ftype: gguf.LlamaFileType + + projection_tensor: Tensor # projecting from n_embd_backbone (2048) to n_embd_decoder (1024) + + def __init__(self, + safetensors_path: Union[Path, str], + path_to_vocab_gguf: Path, + fname_out: Path, + ftype: gguf.LlamaFileType, + is_big_endian: bool,): + + if "" not in fname_out.name: + raise ValueError("Output file name must contain '' placeholder, for example: 'sesame-csm-.gguf'") + + self.state_dict = load_file(safetensors_path, device="cpu") + self.fname_out = fname_out + self.ftype = ftype + self.gguf_reader_vocab = gguf.GGUFReader(path_to_vocab_gguf) + endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + + # backbone + self.gguf_writer_backbone = gguf.GGUFWriter( + path=None, + arch="llama", + endianess=endianess) + + # decoder + self.gguf_writer_decoder = gguf.GGUFWriter( + path=None, + arch="llama", + endianess=endianess) + + Llama_3_2_1B().write_gguf_metadata(self.gguf_writer_backbone, self.gguf_reader_vocab) + Llama_3_2_100M().write_gguf_metadata(self.gguf_writer_decoder, self.gguf_reader_vocab) + + # get projection tensor) + for name, data_torch in self.state_dict.items(): + if name == "projection.weight": + self.projection_tensor = data_torch + break + + # load tensors + for component in ("backbone", "decoder"): + print() + print(f"Converting {component}...") + print() + for name, data_torch in self.state_dict.items(): + # convert any unsupported data types to float32 + old_dtype = data_torch.dtype + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + self.add_tensor(name, data_torch, old_dtype, component) + + def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype, component: str): + is_1d = len(data_torch.shape) == 1 + #is_embd = "_embeddings" in name + can_quantize = not is_1d #and not is_embd + data_qtype = gguf.GGMLQuantizationType.F32 + + is_backbone = False + is_decoder = False + + def rename_transformer(name: str) -> str: + # transformer + name = name.replace(".scale", ".weight") + name = name.replace("attn.k_proj", "attn_k") + name = name.replace("attn.q_proj", "attn_q") + name = name.replace("attn.v_proj", "attn_v") + name = name.replace("attn.output_proj", "attn_output") + name = name.replace("sa_norm", "attn_norm") + name = name.replace("mlp.w1", "ffn_gate") + name = name.replace("mlp.w2", "ffn_down") + name = name.replace("mlp.w3", "ffn_up") + name = name.replace("mlp_norm", "ffn_norm") + return name + + if "audio_embeddings." in name: + is_decoder = True + if component == "decoder": + name = name.replace("audio_embeddings.", "token_embd.") + data_torch = torch.mm(data_torch, self.projection_tensor.T) + print("Applied projection to audio_embeddings", data_torch.shape) + + elif "text_embeddings." in name: + is_backbone = True + name = name.replace("text_embeddings.", "token_embd.") + + elif "backbone." in name or "codebook0_head." in name: + is_backbone = True + name = name.replace("backbone.layers.", "blk.") + name = name.replace("backbone.norm.scale", "output_norm.weight") + name = rename_transformer(name) + + elif "decoder." 
in name: + is_decoder = True + name = name.replace("decoder.layers.", "blk.") + name = name.replace("decoder.norm.scale", "output_norm.weight") + name = rename_transformer(name) + + elif name == "audio_head": + is_decoder = True + name = "audio_head.weight" + + elif name == "projection.weight": + is_decoder = True + name = "inp_proj.weight" + self.projection_tensor = data_torch + + if can_quantize: + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + else: + raise ValueError(f"Unsupported file type: {self.ftype}") + + data = data_torch.numpy() + + try: + data = gguf.quants.quantize(data, data_qtype) + except Exception as e: + logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + if (is_backbone and component == "backbone") or (is_decoder and component == "decoder"): + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}" + logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + if component == "backbone": + self.gguf_writer_backbone.add_tensor(name, data, raw_dtype=data_qtype) + elif component == "decoder": + self.gguf_writer_decoder.add_tensor(name, data, raw_dtype=data_qtype) + + def write(self): + self._write_single(self.gguf_writer_backbone, "backbone") + self._write_single(self.gguf_writer_decoder, "decoder") + + def _write_single(self, gguf_writer: gguf.GGUFWriter, component: str): + output_path = str(self.fname_out).replace("", component) + gguf_writer.write_header_to_file(path=Path(output_path)) + gguf_writer.write_kv_data_to_file() + gguf_writer.write_tensors_to_file(progress=True) + gguf_writer.close() + + @staticmethod + def undo_permute(weights: Tensor, n_head: int, n_head_kv: int): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Sesame model to GGUFs (multiple files)",) + parser.add_argument( + "--outfile", type=Path, default="sesame-csm-.gguf", + help="path to write to, the '' placeholder is required and will be replaced with 'backbone' and 'decoder'", + ) + parser.add_argument( + "--vocab", type=Path, default="models/ggml-vocab-llama-bpe.gguf", + help="path to vocab GGUF", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="path to safetensors or model ID containing model file (if model ID is specified, download from Hugging Face hub)", + nargs="?", + default="sesame/csm-1b:model.safetensors", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + + args = parser.parse_args() + if args.model is None: + parser.error("the following arguments are 
required: model") + return args + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + path_vocab = args.vocab + + dir_parts = str(dir_model).split(":") + if len(dir_parts) == 2: + try: + dir_model = Path(hf_hub_download(dir_parts[0], dir_parts[1])) + except Exception as e: + print("Error downloading model from Hugging Face hub:", e) + print() + print("Please make sure you have access to the model") + print("Hint: you may need to set HF_TOKEN by running: huggingface-cli login") + + if not path_vocab.exists(): + raise FileNotFoundError(f"Vocab file not found: {path_vocab} ; Hint: download it from https://github.com/ggml-org/llama.cpp/blob/master/models/ggml-vocab-llama-bpe.gguf") + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + } + + logger.info(f"Loading model: {dir_model}") + + with torch.inference_mode(): + converter = CSMModelConverter( + safetensors_path=dir_model, + fname_out=args.outfile, + path_to_vocab_gguf=path_vocab, + ftype=ftype_map[args.outtype], + is_big_endian=args.bigendian, + ) + converter.write() + + +if __name__ == '__main__': + main() + diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp new file mode 100644 index 0000000000000..2835265a59530 --- /dev/null +++ b/examples/tts/tts-csm.cpp @@ -0,0 +1,120 @@ +#include "llama.h" +#include "common.h" +#include "log.h" +#include "arg.h" + +#include +#include +#include + +static void print_usage(int, char ** argv) { + LOG("\nexample usage:\n"); + LOG("\n %s TODO ", argv[0]); + LOG("\n"); +} + +// greedy sampling with custom n_vocab +static llama_token sample_greedy(const float * logits, int n_vocab) { + llama_token max_idx = -1; + float max_val = -FLT_MAX; + for (int i = 0; i < n_vocab; ++i) { + if (logits[i] > max_val) { + max_val = logits[i]; + max_idx = i; + } + } + return max_idx; +} + +// hook to retrieve the embeddings +static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) { + std::vector * embd = (std::vector *) user_data; + + if (t && strcmp(t->name, "result_norm") == 0) { + if (ask) return true; + + auto n_bytes = ggml_nbytes(t); + embd->resize(n_bytes); + ggml_backend_tensor_get(t, embd->data(), 0, n_bytes); + printf("result_norm\n"); + return true; + } + + return false; +} + +int main(int argc, char ** argv) { + common_params params; + + params.model = "sesame-csm-backbone.gguf"; + params.out_file = "output.wav"; + params.prompt = "[0]Hello from Sesame."; + + params.n_predict = 4096; + params.n_batch = 8192; + params.n_ctx = 8192; + + params.sampling.top_k = 4; + params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, }; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) { + return 1; + } + + llama_backend_init(); + llama_numa_init(params.numa); + + common_params params_decoder(params); // duplicate the params + string_replace_all(params_decoder.model, "-backbone", "-decoder"); + + std::vector embd; + params.cb_eval = ggml_callback; + params.cb_eval_user_data = &embd; + common_init_result llama_backbone = common_init_from_params(params); + llama_model * model_bb = llama_backbone.model.get(); + llama_context * ctx_bb = llama_backbone.context.get(); + + //common_init_result llama_decoder = common_init_from_params(params_decoder); + //llama_model * 
model_dc = llama_decoder.model.get(); + //llama_context * ctx_dc = llama_decoder.context.get(); + + if (model_bb == nullptr || ctx_bb == nullptr) { + return ENOENT; + } + + const llama_vocab * vocab = llama_model_get_vocab(model_bb); + llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true); + prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab)); + prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab)); + + printf("prompt tokens: \n"); + for (size_t i = 0; i < prompt_tokens.size(); ++i) { + printf("%d, ", prompt_tokens[i]); + } + printf("\n"); + + llama_batch batch = llama_batch_init(params.n_batch, 0, 1); + for (size_t i = 0; i < prompt_tokens.size(); ++i) { + common_batch_add(batch, prompt_tokens[i], i, { 0 }, false); + } + batch.logits[batch.n_tokens - 1] = true; + + if (llama_decode(ctx_bb, batch) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + //auto vocab_dc = llama_model_get_vocab(model_dc); + auto logits = llama_get_logits_ith(ctx_bb, batch.n_tokens - 1); + //printf("next tok: %d\n", sample_greedy(logits, llama_vocab_n_tokens(vocab_dc))); + for (size_t i = 0; i < 10; ++i) { + printf("%4.2f, ", logits[i]); + } + printf("next tok: %d\n", sample_greedy(logits, 65632)); + + for (size_t i = 0; i < 10; ++i) { + printf("%4.2f, ", embd[i]); + } + + return 0; +} diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 9e443d83029f5..268ef3c4f3e58 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -208,27 +208,30 @@ static const std::map> LLM_TENSOR_N { LLM_ARCH_LLAMA, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { 
LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_CSM_CBOOK_OUTPUT, "codebook0_head" }, + { LLM_TENSOR_CSM_AUDIO_OUTPUT, "audio_head" }, + { LLM_TENSOR_CSM_INP_PROJ, "inp_proj" }, }, }, { @@ -1570,6 +1573,9 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CSM_CBOOK_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CSM_AUDIO_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CSM_INP_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 39e3a2ce0565c..c43e8e7c62c3c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -347,6 +347,9 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_CSM_CBOOK_OUTPUT, + LLM_TENSOR_CSM_AUDIO_OUTPUT, + LLM_TENSOR_CSM_INP_PROJ, }; enum llm_tensor_layer { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a4f06112d2842..845b901b3e8a1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1622,6 +1622,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // csm sesame model + { + // TODO: maybe store these in gguf metadata + int64_t csm_audio_cbook_size = 2051; // audio codebook size + int64_t csm_acoustic_tokens = 31; // == number of acoutic tokens for Mimi + int64_t csm_backbone_n_embd = 2048; // used by decoder (n_embd_decoder != n_embd_backbone) + csm_output_cbook = create_tensor(tn(LLM_TENSOR_CSM_CBOOK_OUTPUT, "weight"), {n_embd, csm_audio_cbook_size}, TENSOR_NOT_REQUIRED); + csm_output_audio = create_tensor(tn(LLM_TENSOR_CSM_AUDIO_OUTPUT, "weight"), {csm_audio_cbook_size, n_embd, csm_acoustic_tokens}, TENSOR_NOT_REQUIRED); + csm_input_proj = create_tensor(tn(LLM_TENSOR_CSM_INP_PROJ, "weight"), {csm_backbone_n_embd, n_embd}, TENSOR_NOT_REQUIRED); + } + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4265,8 +4276,17 @@ struct llm_build_llama : public llm_graph_context { cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); + if (model.csm_output_cbook) { + // Sesame csm backbone + // hack: because n_cbook < n_vocab, we use the first logits for the codebook output + int64_t n_vocab = model.tok_embd->ne[1]; + int64_t n_codes = model.csm_output_cbook->ne[1]; + cur = build_lora_mm(model.csm_output_cbook, cur); + cur = ggml_pad(ctx0, cur, n_vocab - n_codes, 0, 0, 0); + } else { + // lm_head (normal case) + cur = build_lora_mm(model.output, cur); + } // For Granite architecture if (hparams.f_logit_scale) { diff --git a/src/llama-model.h b/src/llama-model.h index 0064d597a9613..6c368d691c0d0 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -337,6 +337,11 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + // sesame csm + struct ggml_tensor * csm_output_cbook = nullptr; // backbone codebook + struct ggml_tensor * csm_output_audio = nullptr; // audio decoder output + struct ggml_tensor * csm_input_proj = nullptr; // audio decoder input projection + std::vector layers; llama_model_params params; From 2d743b6758c37b68543ebc24ba77248e5912a129 
Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 30 Mar 2025 00:17:44 +0100 Subject: [PATCH 06/31] wip --- examples/tts/convert_csm_to_gguf.py | 32 ++-- examples/tts/tts-csm.cpp | 123 ++++++++++---- src/llama-arch.cpp | 33 +++- src/llama-arch.h | 4 +- src/llama-model.cpp | 250 +++++++++++++++++++++++++--- src/llama-model.h | 4 +- 6 files changed, 376 insertions(+), 70 deletions(-) diff --git a/examples/tts/convert_csm_to_gguf.py b/examples/tts/convert_csm_to_gguf.py index ff91098a993ad..183ea98b7076d 100644 --- a/examples/tts/convert_csm_to_gguf.py +++ b/examples/tts/convert_csm_to_gguf.py @@ -89,8 +89,6 @@ class CSMModelConverter: fname_out: Path ftype: gguf.LlamaFileType - projection_tensor: Tensor # projecting from n_embd_backbone (2048) to n_embd_decoder (1024) - def __init__(self, safetensors_path: Union[Path, str], path_to_vocab_gguf: Path, @@ -110,24 +108,18 @@ def __init__(self, # backbone self.gguf_writer_backbone = gguf.GGUFWriter( path=None, - arch="llama", + arch="llama-csm", endianess=endianess) # decoder self.gguf_writer_decoder = gguf.GGUFWriter( path=None, - arch="llama", + arch="llama-csm", endianess=endianess) Llama_3_2_1B().write_gguf_metadata(self.gguf_writer_backbone, self.gguf_reader_vocab) Llama_3_2_100M().write_gguf_metadata(self.gguf_writer_decoder, self.gguf_reader_vocab) - # get projection tensor) - for name, data_torch in self.state_dict.items(): - if name == "projection.weight": - self.projection_tensor = data_torch - break - # load tensors for component in ("backbone", "decoder"): print() @@ -165,10 +157,7 @@ def rename_transformer(name: str) -> str: if "audio_embeddings." in name: is_decoder = True - if component == "decoder": - name = name.replace("audio_embeddings.", "token_embd.") - data_torch = torch.mm(data_torch, self.projection_tensor.T) - print("Applied projection to audio_embeddings", data_torch.shape) + name = name.replace("audio_embeddings.", "audio_embd.") elif "text_embeddings." 
in name: is_backbone = True @@ -189,11 +178,18 @@ def rename_transformer(name: str) -> str: elif name == "audio_head": is_decoder = True name = "audio_head.weight" + if component == "decoder": + # add padding at the beginning so that build_lora_mm_id can be used + zero_tensor = torch.zeros(1, 1024, 2051) + data_torch = torch.cat([zero_tensor, data_torch], dim=0) + assert data_torch.shape == (32, 1024, 2051) + # then, transpose it + data_torch = data_torch.transpose(1, 2) elif name == "projection.weight": is_decoder = True - name = "inp_proj.weight" - self.projection_tensor = data_torch + is_backbone = True + name = "csm_proj.weight" if can_quantize: if self.ftype == gguf.LlamaFileType.ALL_F32: @@ -203,7 +199,9 @@ def rename_transformer(name: str) -> str: elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: data_qtype = gguf.GGMLQuantizationType.BF16 elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: - data_qtype = gguf.GGMLQuantizationType.Q8_0 + # decoder is very sensitive to quantization, do not quantize it lower than F16 + data_qtype = gguf.GGMLQuantizationType.Q8_0 if component != "decoder" \ + else gguf.GGMLQuantizationType.F16 else: raise ValueError(f"Unsupported file type: {self.ftype}") diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp index 2835265a59530..34eeb4b4db4d9 100644 --- a/examples/tts/tts-csm.cpp +++ b/examples/tts/tts-csm.cpp @@ -30,13 +30,12 @@ static llama_token sample_greedy(const float * logits, int n_vocab) { static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) { std::vector * embd = (std::vector *) user_data; - if (t && strcmp(t->name, "result_norm") == 0) { + if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) { if (ask) return true; - auto n_bytes = ggml_nbytes(t); - embd->resize(n_bytes); - ggml_backend_tensor_get(t, embd->data(), 0, n_bytes); - printf("result_norm\n"); + embd->resize(ggml_nelements(t)); + ggml_backend_tensor_get(t, embd->data(), 0, ggml_nbytes(t)); + // printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]); return true; } @@ -54,9 +53,6 @@ int main(int argc, char ** argv) { params.n_batch = 8192; params.n_ctx = 8192; - params.sampling.top_k = 4; - params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K, }; - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) { return 1; } @@ -64,24 +60,30 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - common_params params_decoder(params); // duplicate the params - string_replace_all(params_decoder.model, "-backbone", "-decoder"); - std::vector embd; params.cb_eval = ggml_callback; params.cb_eval_user_data = &embd; + params.warmup = false; + + common_params params_decoder(params); // duplicate the params + string_replace_all(params_decoder.model, "-backbone", "-decoder"); + common_init_result llama_backbone = common_init_from_params(params); llama_model * model_bb = llama_backbone.model.get(); llama_context * ctx_bb = llama_backbone.context.get(); - //common_init_result llama_decoder = common_init_from_params(params_decoder); - //llama_model * model_dc = llama_decoder.model.get(); - //llama_context * ctx_dc = llama_decoder.context.get(); + common_init_result llama_decoder = common_init_from_params(params_decoder); + llama_model * model_dc = llama_decoder.model.get(); + llama_context * ctx_dc = llama_decoder.context.get(); if (model_bb == nullptr || ctx_bb == nullptr) { return ENOENT; } + if (model_dc == nullptr || ctx_dc == nullptr) { + return 
ENOENT; + } + const llama_vocab * vocab = llama_model_get_vocab(model_bb); llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true); prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab)); @@ -93,27 +95,92 @@ int main(int argc, char ** argv) { } printf("\n"); + llama_pos n_past_bb = 0; llama_batch batch = llama_batch_init(params.n_batch, 0, 1); + common_batch_clear(batch); for (size_t i = 0; i < prompt_tokens.size(); ++i) { - common_batch_add(batch, prompt_tokens[i], i, { 0 }, false); + common_batch_add(batch, prompt_tokens[i], n_past_bb++, { 0 }, false); } batch.logits[batch.n_tokens - 1] = true; - if (llama_decode(ctx_bb, batch) != 0) { - LOG_ERR("%s: llama_decode() failed\n", __func__); - return 1; - } + std::vector inp_past_embd(2048, 0.0f); + llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1); - //auto vocab_dc = llama_model_get_vocab(model_dc); - auto logits = llama_get_logits_ith(ctx_bb, batch.n_tokens - 1); - //printf("next tok: %d\n", sample_greedy(logits, llama_vocab_n_tokens(vocab_dc))); - for (size_t i = 0; i < 10; ++i) { - printf("%4.2f, ", logits[i]); - } - printf("next tok: %d\n", sample_greedy(logits, 65632)); + for (int k = 0; k < 4; ++k) { + if (llama_decode(ctx_bb, k == 0 ? batch : batch_past_embd) != 0) { + LOG_ERR("%s: llama_decode() failed\n", __func__); + return 1; + } + + auto vocab_dc = llama_model_get_vocab(model_dc); + auto logits = llama_get_logits_ith(ctx_bb, k == 0 ? (batch.n_tokens - 1) : 0); + // for (size_t i = 0; i < 10; ++i) { + // printf("%4.2f, ", logits[i]); + // } + // printf("\n"); + + llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); + // printf("latent_token: %d\n", latent_token); + printf("%5d, ", latent_token); + + // for (size_t i = 0; i < 10; ++i) { + // printf("%4.2f, ", embd[i]); + // } + // printf("\n"); + + + + // decode + prompt_tokens.clear(); + prompt_tokens.push_back(latent_token); + inp_past_embd = std::vector(inp_past_embd.size(), 0.0f); + { + llama_kv_self_clear(ctx_dc); + llama_batch batch_embd = llama_batch_init(1, embd.size(), 1); + llama_batch batch_token = llama_batch_init(1, 0, 1); + { + batch_embd.n_tokens = 1; + batch_embd.pos[0] = 0; + batch_embd.seq_id[0][0] = 0; + batch_embd.n_seq_id[0] = 1; + batch_embd.logits[0] = false; + memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float)); + } + llama_decode(ctx_dc, batch_embd); + + llama_token audio_token = latent_token; + for (int i = 0; i < 31; ++i) { + common_batch_clear(batch_token); + // encoder vocab is further divided into 32 codebooks, each with 2051 entries + llama_token inp_tok = audio_token + 2051*i; + common_batch_add(batch_token, inp_tok, i+1, { 0 }, true); + llama_decode(ctx_dc, batch_token); + auto logits = llama_get_logits_ith(ctx_dc, 0); + audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); + printf("%d,", audio_token); + prompt_tokens.push_back(audio_token); + + GGML_ASSERT(inp_past_embd.size() == embd.size()); + for (size_t i = 0; i < inp_past_embd.size(); ++i) { + inp_past_embd[i] += embd[i]; + } + } + printf("\n"); + + llama_batch_free(batch_embd); + llama_batch_free(batch_token); + } - for (size_t i = 0; i < 10; ++i) { - printf("%4.2f, ", embd[i]); + // prepare for the next iteration + { + batch_past_embd.n_tokens = 1; + batch_past_embd.pos[0] = n_past_bb; + batch_past_embd.seq_id[0][0] = 0; + batch_past_embd.n_seq_id[0] = 1; + batch_past_embd.logits[0] = true; + memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * 
sizeof(float)); + } + n_past_bb++; } return 0; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 268ef3c4f3e58..fcdff0da60562 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -6,6 +6,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA_CSM, "llama-csm" }, { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, { LLM_ARCH_GROK, "grok" }, @@ -229,9 +230,36 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_LLAMA_CSM, // like LLM_ARCH_LLAMA, but with extra tensors for Sesame CSM + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_CSM_AUDIO_EMBD, "audio_embd" }, { LLM_TENSOR_CSM_CBOOK_OUTPUT, "codebook0_head" }, { LLM_TENSOR_CSM_AUDIO_OUTPUT, "audio_head" }, - { LLM_TENSOR_CSM_INP_PROJ, "inp_proj" }, + { LLM_TENSOR_CSM_PROJ, "csm_proj" }, }, }, { @@ -1573,9 +1601,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CSM_AUDIO_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, {LLM_TENSOR_CSM_CBOOK_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CSM_AUDIO_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CSM_INP_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CSM_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index c43e8e7c62c3c..4d39e88f0885b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -10,6 +10,7 @@ enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA_CSM, LLM_ARCH_DECI, LLM_ARCH_FALCON, LLM_ARCH_BAICHUAN, @@ -347,9 +348,10 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_CSM_AUDIO_EMBD, LLM_TENSOR_CSM_CBOOK_OUTPUT, LLM_TENSOR_CSM_AUDIO_OUTPUT, - LLM_TENSOR_CSM_INP_PROJ, + LLM_TENSOR_CSM_PROJ, }; enum llm_tensor_layer { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 845b901b3e8a1..73df168c31aaf 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -513,7 +513,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (arch == LLM_ARCH_LLAMA 
|| arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA + || arch == LLM_ARCH_LLAMA_CSM + || arch == LLM_ARCH_DECI + || arch == LLM_ARCH_FALCON + ) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -531,6 +535,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA_CSM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1622,17 +1627,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } - // csm sesame model - { - // TODO: maybe store these in gguf metadata - int64_t csm_audio_cbook_size = 2051; // audio codebook size - int64_t csm_acoustic_tokens = 31; // == number of acoutic tokens for Mimi - int64_t csm_backbone_n_embd = 2048; // used by decoder (n_embd_decoder != n_embd_backbone) - csm_output_cbook = create_tensor(tn(LLM_TENSOR_CSM_CBOOK_OUTPUT, "weight"), {n_embd, csm_audio_cbook_size}, TENSOR_NOT_REQUIRED); - csm_output_audio = create_tensor(tn(LLM_TENSOR_CSM_AUDIO_OUTPUT, "weight"), {csm_audio_cbook_size, n_embd, csm_acoustic_tokens}, TENSOR_NOT_REQUIRED); - csm_input_proj = create_tensor(tn(LLM_TENSOR_CSM_INP_PROJ, "weight"), {csm_backbone_n_embd, n_embd}, TENSOR_NOT_REQUIRED); - } - for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -1676,6 +1670,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_LLAMA_CSM: + { + // TODO: maybe store these in gguf metadata + int64_t csm_audio_cbook_size = 2051; // audio codebook size + int64_t csm_acoustic_tokens = 32; // equal to number of acoutic tokens for Mimi + //int64_t csm_n_audio_vocab = csm_audio_cbook_size*csm_acoustic_tokens; + + csm_output_cbook = create_tensor(tn(LLM_TENSOR_CSM_CBOOK_OUTPUT, "weight"), {n_embd, csm_audio_cbook_size}, TENSOR_NOT_REQUIRED); + + bool is_backbone = csm_output_cbook != nullptr; + + csm_output_audio = is_backbone ? nullptr + : create_tensor(tn(LLM_TENSOR_CSM_AUDIO_OUTPUT, "weight"), {n_embd, csm_audio_cbook_size, csm_acoustic_tokens}, 0); + + tok_embd = is_backbone + ? create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0) + : create_tensor(tn(LLM_TENSOR_CSM_AUDIO_EMBD, "weight"), {n_embd*2, n_vocab}, 0); + + csm_proj = is_backbone + ? 
create_tensor(tn(LLM_TENSOR_CSM_PROJ, "weight"), {n_embd, n_embd/2}, 0) + : create_tensor(tn(LLM_TENSOR_CSM_PROJ, "weight"), {n_embd*2, n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output tensor is either audio or code depends on backbone / decoder + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_DECI: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4276,21 +4312,190 @@ struct llm_build_llama : public llm_graph_context { cb(cur, "result_norm", -1); res->t_embd = cur; + // lm_head (normal case) + cur = build_lora_mm(model.output, cur); + + // For Granite architecture + if (hparams.f_logit_scale) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +// llama used by Sesame CSM +struct llm_build_llama_csm : public llm_graph_context { + llm_build_llama_csm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + bool is_backbone = model.csm_output_cbook; + bool is_decoder = !is_backbone; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * input_embd = inpL; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + if (is_decoder && inpL->ne[0] != hparams.n_embd) { + inpL = build_lora_mm(model.csm_proj, inpL); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", 
il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + if (model.csm_output_cbook) { // Sesame csm backbone - // hack: because n_cbook < n_vocab, we use the first logits for the codebook output + // hack: because n_cbook < n_vocab, we use the first logits for the output int64_t n_vocab = model.tok_embd->ne[1]; int64_t n_codes = model.csm_output_cbook->ne[1]; + ggml_tensor * last_h = cur; cur = build_lora_mm(model.csm_output_cbook, cur); cur = ggml_pad(ctx0, cur, n_vocab - n_codes, 0, 0, 0); - } else { - // lm_head (normal case) - cur = build_lora_mm(model.output, cur); - } - // For Granite architecture - if (hparams.f_logit_scale) { - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); + // project to csm decoder dim + last_h = build_lora_mm(model.csm_proj, last_h); + cb(last_h, "output_csm_proj", -1); // use callback to retrieve the result + ggml_build_forward_expand(gf, last_h); + + } else if (model.csm_output_audio && ggml_nelements(cur)) { + // Sesame csm decoder + // hack: because n_audio < n_vocab, we use the first logits for the output + cur = build_lora_mm_id(model.csm_output_audio, cur, inp_pos); + int64_t n_vocab = model.tok_embd->ne[1]; + int64_t n_codes = cur->ne[0]; + cur = ggml_pad(ctx0, cur, n_vocab - n_codes, cur->ne[1], 0, 0); + + // also get audio embeddings, which will be passed back to backbone to keep track of generation progress + if (ubatch.token) { + cb(input_embd, "output_audio_embd", -1); + ggml_build_forward_expand(gf, input_embd); + } + + } else { + 
// otherwise, dummy output } cb(cur, "result_output", -1); @@ -11896,6 +12101,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_LLAMA_CSM: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params, gf); @@ -12234,6 +12443,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA_CSM: case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: case LLM_ARCH_STARCODER: diff --git a/src/llama-model.h b/src/llama-model.h index 6c368d691c0d0..296f8f16e4712 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -338,9 +338,9 @@ struct llama_model { struct ggml_tensor * conv1d_b = nullptr; // sesame csm - struct ggml_tensor * csm_output_cbook = nullptr; // backbone codebook + struct ggml_tensor * csm_output_cbook = nullptr; // backbone output codebook struct ggml_tensor * csm_output_audio = nullptr; // audio decoder output - struct ggml_tensor * csm_input_proj = nullptr; // audio decoder input projection + struct ggml_tensor * csm_proj = nullptr; // to convert backbone dim to decoder dim std::vector layers; From f9162e7005469fefe9f4a73c1373003cd29aa61f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 30 Mar 2025 01:03:04 +0100 Subject: [PATCH 07/31] wip --- examples/tts/tts-csm.cpp | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp index 34eeb4b4db4d9..5b0a23b2141ad 100644 --- a/examples/tts/tts-csm.cpp +++ b/examples/tts/tts-csm.cpp @@ -106,7 +106,7 @@ int main(int argc, char ** argv) { std::vector inp_past_embd(2048, 0.0f); llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1); - for (int k = 0; k < 4; ++k) { + for (int k = 0; k < 32; ++k) { if (llama_decode(ctx_bb, k == 0 ? 
batch : batch_past_embd) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; @@ -121,7 +121,7 @@ int main(int argc, char ** argv) { llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); // printf("latent_token: %d\n", latent_token); - printf("%5d, ", latent_token); + printf("%d,", latent_token); // for (size_t i = 0; i < 10; ++i) { // printf("%4.2f, ", embd[i]); @@ -149,7 +149,9 @@ int main(int argc, char ** argv) { llama_decode(ctx_dc, batch_embd); llama_token audio_token = latent_token; - for (int i = 0; i < 31; ++i) { + int n_codes = 32; + int sum_codes = 0; + for (int i = 0; i < n_codes; ++i) { common_batch_clear(batch_token); // encoder vocab is further divided into 32 codebooks, each with 2051 entries llama_token inp_tok = audio_token + 2051*i; @@ -157,8 +159,13 @@ int main(int argc, char ** argv) { llama_decode(ctx_dc, batch_token); auto logits = llama_get_logits_ith(ctx_dc, 0); audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); - printf("%d,", audio_token); - prompt_tokens.push_back(audio_token); + + // discard last code + if (i < n_codes - 1) { + printf("%d,", audio_token); + prompt_tokens.push_back(audio_token); + sum_codes += audio_token; + } GGML_ASSERT(inp_past_embd.size() == embd.size()); for (size_t i = 0; i < inp_past_embd.size(); ++i) { @@ -169,8 +176,22 @@ int main(int argc, char ** argv) { llama_batch_free(batch_embd); llama_batch_free(batch_token); + + if (sum_codes == 0) { + return 0; // done + } } + // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb); + // for (size_t i = 0; i < inp_past_embd.size(); ++i) { + // printf("%4.4f, ", inp_past_embd[i]); + // if (i == 2) { + // printf("... "); + // i = inp_past_embd.size() - 4; + // } + // } + // printf("\n"); + // prepare for the next iteration { batch_past_embd.n_tokens = 1; From eae5f0e1ced91eaffc5f148147c982b3d38877e2 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 30 Mar 2025 10:50:35 +0200 Subject: [PATCH 08/31] add mimi_model::transpose_input --- examples/tts/mimi-model.cpp | 27 ++++++++++++++++++--------- examples/tts/mimi-model.h | 5 +++++ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp index 92bb47a8365d7..ded56ff317d63 100644 --- a/examples/tts/mimi-model.cpp +++ b/examples/tts/mimi-model.cpp @@ -617,7 +617,7 @@ std::vector mimi_model::decode_frame(const std::vector & codes, int int n_pos = -1; int n_codes = codes.size(); int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; - GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiple of n_codes_per_embd"); + GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiply of n_codes_per_embd"); ctx->build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf) { ggml_tensor * inp_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_codes); @@ -661,14 +661,6 @@ std::vector mimi_model::decode_frame(const std::vector & codes, int ctx->set_tensor_data("pos_dec", pos_data.data()); // code data - /*std::vector codes_t(n_codes_per_embd * n_codes); - for (int i = 0; i < n_codes / n_codes_per_embd; i++) { - for (int j = 0; j < n_codes_per_embd; j++) { - int src_idx = i * n_codes_per_embd + j; - int dst_idx = j * (n_codes / n_codes_per_embd) + i; - codes_t[dst_idx] = codes[src_idx]; - } - }*/ ctx->set_tensor_data("inp_dec", codes.data()); ctx->compute(); @@ -715,6 +707,23 @@ std::vector mimi_model::decode(const std::vector & codes) { return output; } 
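> Editor's note (not part of the patch): to make the layout change concrete, here is a minimal standalone sketch of what the `transpose_input` helper added just below computes. The sizes are illustrative only: N = 2 frames with 1 semantic + 3 acoustic codes per frame, instead of Mimi's 1 + 31.

```cpp
// interleaved -> planar transposition of RVQ codes (toy sizes)
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int n_codes_per_embd = 4; // 1 semantic + 3 acoustic (Mimi uses 1 + 31)
    // interleaved input: S0 A0a A0b A0c | S1 A1a A1b A1c
    std::vector<int> codes = { 0, 1, 2, 3, 10, 11, 12, 13 };
    const int n_codes  = (int) codes.size();
    const int n_frames = n_codes / n_codes_per_embd;
    assert(n_codes % n_codes_per_embd == 0);

    // planar output: S0 S1 | A0a A1a | A0b A1b | A0c A1c
    std::vector<int> codes_T(n_codes);
    for (int i = 0; i < n_frames; i++) {
        for (int j = 0; j < n_codes_per_embd; j++) {
            codes_T[j * n_frames + i] = codes[i * n_codes_per_embd + j];
        }
    }

    for (int c : codes_T) printf("%d ", c); // prints: 0 10 1 11 2 12 3 13
    printf("\n");
    return 0;
}
```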
+std::vector mimi_model::transpose_input(const std::vector & codes) {
+    int n_codes = codes.size();
+    int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components;
+    GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiple of n_codes_per_embd");
+
+    std::vector codes_T(n_codes_per_embd * n_codes);
+    for (int i = 0; i < n_codes / n_codes_per_embd; i++) {
+        for (int j = 0; j < n_codes_per_embd; j++) {
+            int src_idx = i * n_codes_per_embd + j;
+            int dst_idx = j * (n_codes / n_codes_per_embd) + i;
+            codes_T[dst_idx] = codes[src_idx];
+        }
+    }
+
+    return codes_T;
+}
+
 int mimi_model::get_sample_rate() const {
     return mimi_config.sample_rate;
 }
diff --git a/examples/tts/mimi-model.h b/examples/tts/mimi-model.h
index c26fd3bc08e9f..96945981513c0 100644
--- a/examples/tts/mimi-model.h
+++ b/examples/tts/mimi-model.h
@@ -22,6 +22,11 @@ struct mimi_model {
 
     int get_sample_rate() const;
 
+    // transpose layout:
+    // - from: (1 semantic code followed by 31 acoustic codes) repeated N times
+    // - to: N semantic codes followed by (N*31) acoustic codes
+    std::vector transpose_input(const std::vector & codes);
+
     // layout of codes: N semantic codes followed by (N*31) acoustic codes
     std::vector decode(const std::vector & codes);
 
From 43bf237e3975a80fe0a52204ad08c7d43999c594 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 10:52:18 +0200
Subject: [PATCH 09/31] fix build

---
 examples/tts/CMakeLists.txt | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt
index 39e0a92c5acb4..371c3bbf7434d 100644
--- a/examples/tts/CMakeLists.txt
+++ b/examples/tts/CMakeLists.txt
@@ -8,4 +8,5 @@ set(TARGET llama-mimi)
 add_executable(${TARGET} mimi.cpp mimi-model.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
+# for using C++ designated initializers, TODO: can be changed back to C++17 in the future
+target_compile_features(${TARGET} PRIVATE cxx_std_20)

From e618405d4b9040f9536e2acc6761eac004146969 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 11:07:18 +0200
Subject: [PATCH 10/31] fix build (2)

---
 examples/tts/mimi-model.cpp | 37 +++++++++++++++++++------------------
 1 file changed, 19 insertions(+), 18 deletions(-)

diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp
index ded56ff317d63..141dd1043923b 100644
--- a/examples/tts/mimi-model.cpp
+++ b/examples/tts/mimi-model.cpp
@@ -15,6 +15,7 @@
 #include
 #include
 #include
+#include
 
 /**
  * Implementation of Kyutai's Mimi model using GGML.
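> Editor's note (not part of the patch): the commit message only says "fix build", but the hunk that follows gives every tensor pointer an explicit `= nullptr` default. A plausible reason is that with partial designated-initializer lists, explicit defaults both document the fallback values and silence missing-field-initializer warnings on some compilers. A minimal sketch of the guarantee, with a hypothetical `layer` struct:

```cpp
// members omitted from a designated-initializer list fall back to their
// default member initializers (C++20)
#include <cstdio>

struct layer {
    bool  is_elu   = false;
    int * conv_0_w = nullptr; // the real code stores ggml_tensor pointers
    int   stride   = 1;
};

int main() {
    layer l = { .is_elu = true }; // conv_0_w and stride keep their defaults
    printf("%d %p %d\n", l.is_elu, (void *) l.conv_0_w, l.stride);
    return 0;
}
```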
@@ -344,10 +345,10 @@ struct mimi_encoder_decoder { bool is_elu = false; bool is_resnet = false; bool is_transposed_conv = false; - ggml_tensor * conv_0_w; - ggml_tensor * conv_0_b; - ggml_tensor * conv_1_w; - ggml_tensor * conv_1_b; + ggml_tensor * conv_0_w = nullptr; + ggml_tensor * conv_0_b = nullptr; + ggml_tensor * conv_1_w = nullptr; + ggml_tensor * conv_1_b = nullptr; int stride = 1; }; std::vector layers; @@ -415,20 +416,20 @@ struct mimi_encoder_decoder { struct mimi_transformer { struct layer { - ggml_tensor * inp_norm_w; - ggml_tensor * inp_norm_b; - - ggml_tensor * attn_q; - ggml_tensor * attn_k; - ggml_tensor * attn_v; - ggml_tensor * attn_o; - ggml_tensor * attn_post_norm_w; - ggml_tensor * attn_post_norm_b; - ggml_tensor * attn_layer_scale; - - ggml_tensor * ffn_up; - ggml_tensor * ffn_down; - ggml_tensor * mlp_layer_scale; + ggml_tensor * inp_norm_w = nullptr; + ggml_tensor * inp_norm_b = nullptr; + + ggml_tensor * attn_q = nullptr; + ggml_tensor * attn_k = nullptr; + ggml_tensor * attn_v = nullptr; + ggml_tensor * attn_o = nullptr; + ggml_tensor * attn_post_norm_w = nullptr; + ggml_tensor * attn_post_norm_b = nullptr; + ggml_tensor * attn_layer_scale = nullptr; + + ggml_tensor * ffn_up = nullptr; + ggml_tensor * ffn_down = nullptr; + ggml_tensor * mlp_layer_scale = nullptr; }; std::vector layers; From e185e0ac7fb10b2b933caae2823a04f44d703967 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 30 Mar 2025 11:44:34 +0200 Subject: [PATCH 11/31] fix build (3) --- examples/tts/mimi-model.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp index 141dd1043923b..0b1fabe86088e 100644 --- a/examples/tts/mimi-model.cpp +++ b/examples/tts/mimi-model.cpp @@ -16,6 +16,7 @@ #include #include #include +#include /** * Implementation of Kyutai's Mimi model using GGML. 
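> Editor's note (not part of the patch): the hunk below moves `.is_transposed_conv` to the top of the braced list because C++20 requires designated initializers to appear in the members' declaration order, and the struct declares the flags before the tensor pointers. A minimal illustration:

```cpp
// C++20: designators must follow declaration order
struct layer {
    bool   is_transposed_conv = false;
    void * conv_0_w           = nullptr;
    int    stride             = 1;
};

// ill-formed: designator order does not match declaration order
// layer bad = { .conv_0_w = nullptr, .is_transposed_conv = true };

layer good = { .is_transposed_conv = true, .conv_0_w = nullptr, .stride = 2 };

int main() { return good.stride == 2 ? 0 : 1; }
```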
@@ -367,10 +368,10 @@ struct mimi_encoder_decoder {
                 .is_elu = true, // layer (i_start)
             });
             layers.push_back({
+                .is_transposed_conv = true,
                 .conv_0_w = ctx.get_weight("decoder.layers.%d.conv.weight", i_start + 1),
                 .conv_0_b = ctx.get_weight("decoder.layers.%d.conv.bias", i_start + 1),
                 .stride = mimi_config.upsampling_ratio[i],
-                .is_transposed_conv = true,
             });
             // residual layers
             layers.push_back({

From ce83041ec3205b2586fca7d52ac9cef5c0ddc446 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 11:45:36 +0200
Subject: [PATCH 12/31] fix strcmp

---
 examples/tts/mimi.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp
index 421c9e418ecc6..502e0150634b7 100644
--- a/examples/tts/mimi.cpp
+++ b/examples/tts/mimi.cpp
@@ -3,6 +3,7 @@
 
 #include
 #include
+#include <cstring> // strcmp
 
 /**

From 61d8ad6aef03879ca7193a302a0f549a40d761cb Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 12:04:33 +0200
Subject: [PATCH 13/31] fix compilation on linux

---
 examples/tts/convert_mimi_to_gguf.py | 4 ++--
 examples/tts/mimi-model.cpp          | 2 ++
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/tts/convert_mimi_to_gguf.py b/examples/tts/convert_mimi_to_gguf.py
index 5b44ef62103ba..5dce72a398a91 100644
--- a/examples/tts/convert_mimi_to_gguf.py
+++ b/examples/tts/convert_mimi_to_gguf.py
@@ -5,13 +5,13 @@
 from typing import Union
 from pathlib import Path
 from torch import Tensor
-from transformers import MimiModel
+from transformers import MimiModel, PreTrainedModel
 
 logger = logging.getLogger("mimi")
 
 
 class MimiModelConverter:
-    mimi_model: MimiModel
+    mimi_model: PreTrainedModel
     gguf_writer: gguf.GGUFWriter
     fname_out: Path
     ftype: gguf.LlamaFileType
diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp
index 0b1fabe86088e..427aeff8658bf 100644
--- a/examples/tts/mimi-model.cpp
+++ b/examples/tts/mimi-model.cpp
@@ -17,6 +17,8 @@
 #include
 #include
 #include
+#include
+#include
 
 /**
  * Implementation of Kyutai's Mimi model using GGML.

From 40120540afccd23bcf31c95d720f3292097815e0 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 13:02:07 +0200
Subject: [PATCH 14/31] clean up

---
 examples/tts/README-csm.md          |  31 +++++++
 examples/tts/convert_csm_to_gguf.py |   2 +-
 examples/tts/tts-csm.cpp            | 134 ++++++++++++++++++----------
 3 files changed, 120 insertions(+), 47 deletions(-)
 create mode 100644 examples/tts/README-csm.md

diff --git a/examples/tts/README-csm.md b/examples/tts/README-csm.md
new file mode 100644
index 0000000000000..f660d8965ecbe
--- /dev/null
+++ b/examples/tts/README-csm.md
@@ -0,0 +1,31 @@
+# Sesame CSM
+
+To get the GGUF:
+
+```sh
+python examples/tts/convert_csm_to_gguf.py
+
+# default output files:
+# sesame-csm-backbone.gguf
+# sesame-csm-decoder.gguf
+
+# optionally, quantize it
+# (q8_0 is the lowest sensible scheme; quantizing further degrades quality too much)
+python examples/tts/convert_csm_to_gguf.py --outtype q8_0
+```
+
+Compile the example:
+
+```sh
+cmake --build build -j --target llama-tts-csm
+```
+
+Run the example:
+
+```sh
+./build/bin/llama-tts-csm -m sesame-csm-backbone.gguf -p "[0]Hello world."
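# editor's note (not part of the patch): the common -f and -o flags also work
# here, e.g. to synthesize a longer prompt from a text file (hypothetical name):
#   ./build/bin/llama-tts-csm -m sesame-csm-backbone.gguf -f prompt.txt -o speech.wav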
+# sesame-csm-backbone.gguf will automatically be loaded +# make sure the place these 2 GGUF files in the same directory + +# output file: output.wav +``` diff --git a/examples/tts/convert_csm_to_gguf.py b/examples/tts/convert_csm_to_gguf.py index 183ea98b7076d..09b7748c2a63d 100644 --- a/examples/tts/convert_csm_to_gguf.py +++ b/examples/tts/convert_csm_to_gguf.py @@ -95,7 +95,7 @@ def __init__(self, fname_out: Path, ftype: gguf.LlamaFileType, is_big_endian: bool,): - + if "" not in fname_out.name: raise ValueError("Output file name must contain '' placeholder, for example: 'sesame-csm-.gguf'") diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp index 5b0a23b2141ad..b4b01331d2d22 100644 --- a/examples/tts/tts-csm.cpp +++ b/examples/tts/tts-csm.cpp @@ -6,6 +6,10 @@ #include #include #include +#include // memcpy and strcmp +#include + +// For more details on how this works, see: https://github.com/ggml-org/llama.cpp/pull/12648 static void print_usage(int, char ** argv) { LOG("\nexample usage:\n"); @@ -30,6 +34,8 @@ static llama_token sample_greedy(const float * logits, int n_vocab) { static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) { std::vector * embd = (std::vector *) user_data; + // output_csm_proj is the embeddings output from backbone + // output_audio_embd is the embeddings output from decoder if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) { if (ask) return true; @@ -45,13 +51,10 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) { int main(int argc, char ** argv) { common_params params; - params.model = "sesame-csm-backbone.gguf"; - params.out_file = "output.wav"; - params.prompt = "[0]Hello from Sesame."; - - params.n_predict = 4096; - params.n_batch = 8192; - params.n_ctx = 8192; + params.model = "sesame-csm-backbone.gguf"; + params.out_file = "output.wav"; + params.prompt = "[0]Hello from Sesame."; + params.n_predict = 2048; // CSM's max trained seq length if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) { return 1; @@ -66,6 +69,7 @@ int main(int argc, char ** argv) { params.warmup = false; common_params params_decoder(params); // duplicate the params + params_decoder.n_ctx = 64; // we never use more than this string_replace_all(params_decoder.model, "-backbone", "-decoder"); common_init_result llama_backbone = common_init_from_params(params); @@ -96,77 +100,114 @@ int main(int argc, char ** argv) { printf("\n"); llama_pos n_past_bb = 0; - llama_batch batch = llama_batch_init(params.n_batch, 0, 1); - common_batch_clear(batch); + llama_batch batch_prompt = llama_batch_init(params.n_batch, 0, 1); + common_batch_clear(batch_prompt); for (size_t i = 0; i < prompt_tokens.size(); ++i) { - common_batch_add(batch, prompt_tokens[i], n_past_bb++, { 0 }, false); + common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false); } - batch.logits[batch.n_tokens - 1] = true; + batch_prompt.logits[batch_prompt.n_tokens - 1] = true; + // inp_past_embd is the "squashed" embeddings from the decoder std::vector inp_past_embd(2048, 0.0f); llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1); - for (int k = 0; k < 32; ++k) { - if (llama_decode(ctx_bb, k == 0 ? 
batch : batch_past_embd) != 0) { - LOG_ERR("%s: llama_decode() failed\n", __func__); + int64_t t_gb_start = ggml_time_ms(); // global start time + int64_t t_bb = 0; // backbone time + int64_t n_bb_gen = 0; // backbone generation count + int64_t t_dc = 0; // decoder time + int64_t n_dc_gen = 0; // decoder generation count + + bool is_stop = false; + + // backbone generation loop + for (int k = 0; k < params.n_predict; ++k) { + bool is_prompt_processing = k == 0; + + if (!is_prompt_processing) { + // generate the next RVQ semantic token + batch_past_embd.n_tokens = 1; + batch_past_embd.pos[0] = n_past_bb++; + batch_past_embd.seq_id[0][0] = 0; + batch_past_embd.n_seq_id[0] = 1; + batch_past_embd.logits[0] = true; + std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float)); + } + + int64_t t_bb_start = ggml_time_ms(); + if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) { + LOG_ERR("%s: backbone llama_decode() failed\n", __func__); return 1; } + n_bb_gen++; + t_bb += ggml_time_ms() - t_bb_start; auto vocab_dc = llama_model_get_vocab(model_dc); - auto logits = llama_get_logits_ith(ctx_bb, k == 0 ? (batch.n_tokens - 1) : 0); + auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0); // for (size_t i = 0; i < 10; ++i) { // printf("%4.2f, ", logits[i]); // } // printf("\n"); - llama_token latent_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); - // printf("latent_token: %d\n", latent_token); - printf("%d,", latent_token); + llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); + printf("%d,", semantic_tok); // for (size_t i = 0; i < 10; ++i) { // printf("%4.2f, ", embd[i]); // } // printf("\n"); - - // decode - prompt_tokens.clear(); - prompt_tokens.push_back(latent_token); + // decoder generation loop inp_past_embd = std::vector(inp_past_embd.size(), 0.0f); { llama_kv_self_clear(ctx_dc); llama_batch batch_embd = llama_batch_init(1, embd.size(), 1); llama_batch batch_token = llama_batch_init(1, 0, 1); + + // first "token" is the latent embeddings from backbone { batch_embd.n_tokens = 1; batch_embd.pos[0] = 0; batch_embd.seq_id[0][0] = 0; batch_embd.n_seq_id[0] = 1; batch_embd.logits[0] = false; - memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float)); + std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float)); + } + if (llama_decode(ctx_dc, batch_embd) != 0) { + LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__); + return 1; } - llama_decode(ctx_dc, batch_embd); - - llama_token audio_token = latent_token; + + // then, decode the semantic_tok to generate acoustic tokens + llama_token tok = semantic_tok; int n_codes = 32; - int sum_codes = 0; + int sum_codes = 0; // to check if all codes are 0 for (int i = 0; i < n_codes; ++i) { common_batch_clear(batch_token); // encoder vocab is further divided into 32 codebooks, each with 2051 entries - llama_token inp_tok = audio_token + 2051*i; + llama_token inp_tok = tok + 2051*i; common_batch_add(batch_token, inp_tok, i+1, { 0 }, true); - llama_decode(ctx_dc, batch_token); + + int64_t t_bb_start = ggml_time_ms(); + if (llama_decode(ctx_dc, batch_token) != 0) { + LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__); + return 1; + } + n_dc_gen++; + t_dc += ggml_time_ms() - t_bb_start; + + // sample the acoustic token auto logits = llama_get_logits_ith(ctx_dc, 0); - audio_token = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); + llama_token acoustic_tok = 
sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); - // discard last code + // discard last code (only for embeddings) if (i < n_codes - 1) { - printf("%d,", audio_token); - prompt_tokens.push_back(audio_token); - sum_codes += audio_token; + printf("%d,", acoustic_tok); + tok = acoustic_tok; // next input token + sum_codes += acoustic_tok; } + // do progressive hsum of embeddings GGML_ASSERT(inp_past_embd.size() == embd.size()); for (size_t i = 0; i < inp_past_embd.size(); ++i) { inp_past_embd[i] += embd[i]; @@ -177,9 +218,8 @@ int main(int argc, char ** argv) { llama_batch_free(batch_embd); llama_batch_free(batch_token); - if (sum_codes == 0) { - return 0; // done - } + // if all codes are 0, then we are done + is_stop = sum_codes == 0; } // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb); @@ -192,17 +232,19 @@ int main(int argc, char ** argv) { // } // printf("\n"); - // prepare for the next iteration - { - batch_past_embd.n_tokens = 1; - batch_past_embd.pos[0] = n_past_bb; - batch_past_embd.seq_id[0][0] = 0; - batch_past_embd.n_seq_id[0] = 1; - batch_past_embd.logits[0] = true; - memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float)); + if (is_stop) { + break; } - n_past_bb++; } + // print timing info + printf("\ntimings:\n"); + printf(" backbone: %" PRId64 " ms, %" PRId64 " generated token (%.2f tok/s)\n", t_bb, n_bb_gen, (float)n_bb_gen*1000/(float)t_bb); + printf(" decoder: %" PRId64 " ms, %" PRId64 " generated token (%.2f tok/s)\n", t_dc, n_dc_gen, (float)n_dc_gen*1000/(float)t_dc); + printf(" total: %" PRId64 " ms\n\n", ggml_time_ms() - t_gb_start); + + llama_batch_free(batch_prompt); + llama_batch_free(batch_past_embd); + return 0; } From 7ecce7645576854d095726810ff9787df5380f65 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 30 Mar 2025 14:21:55 +0200 Subject: [PATCH 15/31] working now --- examples/tts/CMakeLists.txt | 20 ++++++++------ examples/tts/convert_mimi_to_gguf.py | 2 +- examples/tts/mimi-model.cpp | 2 +- examples/tts/tts-csm.cpp | 41 ++++++++++++++++++++++++---- 4 files changed, 50 insertions(+), 15 deletions(-) diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt index 58a8599148bab..ab184a85ba17b 100644 --- a/examples/tts/CMakeLists.txt +++ b/examples/tts/CMakeLists.txt @@ -4,15 +4,19 @@ install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) -set(TARGET llama-tts-csm) -add_executable(${TARGET} tts-csm.cpp) +add_library(mimi-model mimi-model.h mimi-model.cpp) +target_link_libraries(mimi-model PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +# for using C++ designated initializers, TODO: can be changed back to C++17 in the future +target_compile_features(mimi-model PRIVATE cxx_std_20) + +set(TARGET llama-mimi) +add_executable(${TARGET} mimi.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(${TARGET} PRIVATE llama common mimi-model ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) -set(TARGET llama-mimi) -add_executable(${TARGET} mimi.cpp mimi-model.cpp) +set(TARGET llama-tts-csm) +add_executable(${TARGET} tts-csm.cpp) install(TARGETS ${TARGET} RUNTIME) -target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) -# for using C++ designated initializers, TODO: can be changed back to C++17 in the future -target_compile_features(${TARGET} PRIVATE 
cxx_std_20)
+target_link_libraries(${TARGET} PRIVATE llama common mimi-model ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/tts/convert_mimi_to_gguf.py b/examples/tts/convert_mimi_to_gguf.py
index 5dce72a398a91..81cb8f48cc25e 100644
--- a/examples/tts/convert_mimi_to_gguf.py
+++ b/examples/tts/convert_mimi_to_gguf.py
@@ -27,7 +27,7 @@ def __init__(self,
         endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
         self.gguf_writer = gguf.GGUFWriter(
             path=None,
-            arch="if you see this, you are using the wrong file",
+            arch="this model cannot be used as LLM, use it via --model-vocoder in TTS examples",
             endianess=endianess)
 
         assert self.mimi_model.config.architectures[0] == "MimiModel"
diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp
index 427aeff8658bf..3663201dc5971 100644
--- a/examples/tts/mimi-model.cpp
+++ b/examples/tts/mimi-model.cpp
@@ -716,7 +716,7 @@ std::vector mimi_model::transpose_input(const std::vector & codes) {
     int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components;
     GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiple of n_codes_per_embd");
 
-    std::vector codes_T(n_codes_per_embd * n_codes);
+    std::vector codes_T(n_codes);
     for (int i = 0; i < n_codes / n_codes_per_embd; i++) {
         for (int j = 0; j < n_codes_per_embd; j++) {
             int src_idx = i * n_codes_per_embd + j;
diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp
index b4b01331d2d22..843d7f6b79196 100644
--- a/examples/tts/tts-csm.cpp
+++ b/examples/tts/tts-csm.cpp
@@ -2,6 +2,7 @@
 #include "common.h"
 #include "log.h"
 #include "arg.h"
+#include "mimi-model.h"
 
 #include
 #include
@@ -13,7 +14,13 @@
 
 static void print_usage(int, char ** argv) {
     LOG("\nexample usage:\n");
-    LOG("\n    %s TODO ", argv[0]);
+    LOG("\n    By default, the model will be downloaded from https://huggingface.co/ggml-org/sesame-csm-1b-GGUF");
+    LOG("\n    %s -p \"[0]I have a dream that one day every valley shall be exalted\" -o output.wav", argv[0]);
+    LOG("\n");
+    LOG("\n    To use a local model, specify the path to the model file:");
+    LOG("\n    %s -p ... -m sesame-csm-backbone.gguf -mv kyutai-mimi.gguf -o output.wav", argv[0]);
+    LOG("\n");
+    LOG("\n    Note: the model needs 2 files to run, one ends with '-backbone.gguf' and the other ends with '-decoder.gguf'");
     LOG("\n");
 }
 
@@ -51,10 +58,15 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
 int main(int argc, char ** argv) {
     common_params params;
 
-    params.model     = "sesame-csm-backbone.gguf";
-    params.out_file  = "output.wav";
-    params.prompt    = "[0]Hello from Sesame.";
-    params.n_predict = 2048; // CSM's max trained seq length
+    params.model         = "sesame-csm-backbone.gguf";
+    params.vocoder.model = "kyutai-mimi.gguf";
+    params.out_file      = "output.wav";
+    params.prompt        = "[0]Hello from Sesame.";
+    params.n_predict     = 2048; // CSM's max trained seq length
+
+    // HF model
+    params.model_url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
+    params.vocoder.model_url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf";
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
         return 1;
@@ -71,6 +83,9 @@ int main(int argc, char ** argv) {
     common_params params_decoder(params); // duplicate the params
     params_decoder.n_ctx = 64; // we never use more than this
     string_replace_all(params_decoder.model, "-backbone", "-decoder");
+    if (!params_decoder.model_url.empty()) {
+        string_replace_all(params_decoder.model_url, "-backbone", "-decoder");
+    }
 
     common_init_result llama_backbone = common_init_from_params(params);
     llama_model * model_bb = llama_backbone.model.get();
@@ -88,6 +103,8 @@ int main(int argc, char ** argv) {
         return ENOENT;
     }
 
+    mimi_model mimi(params.vocoder.model.c_str(), true);
+
     const llama_vocab * vocab = llama_model_get_vocab(model_bb);
     llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
     prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
@@ -118,6 +135,7 @@ int main(int argc, char ** argv) {
     int64_t n_dc_gen = 0; // decoder generation count
 
     bool is_stop = false;
+    std::vector generated_codes;
 
     // backbone generation loop
     for (int k = 0; k < params.n_predict; ++k) {
@@ -150,6 +168,7 @@ int main(int argc, char ** argv) {
 
         llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
         printf("%d,", semantic_tok);
+        generated_codes.push_back(semantic_tok);
 
         // for (size_t i = 0; i < 10; ++i) {
         //     printf("%4.2f, ", embd[i]);
@@ -205,6 +224,7 @@ int main(int argc, char ** argv) {
                     printf("%d,", acoustic_tok);
                     tok = acoustic_tok; // next input token
                     sum_codes += acoustic_tok;
+                    generated_codes.push_back(acoustic_tok);
                 }
 
                 // do progressive hsum of embeddings
@@ -246,5 +266,16 @@ int main(int argc, char ** argv) {
     llama_batch_free(batch_prompt);
     llama_batch_free(batch_past_embd);
 
+    printf("decode %zu RVQ tokens into wav...\n", generated_codes.size());
+    generated_codes = mimi.transpose_input(generated_codes);
+    std::vector wav_data = mimi.decode(generated_codes);
+
+    if (!save_wav16(params.out_file.c_str(), wav_data, mimi.get_sample_rate())) {
+        LOG_ERR("Failed to save wav file\n");
+        return 1;
+    }
+
+    printf("\n");
+
     return 0;
 }
From 6976682fbc2226675610fedfdfd2ea9b8ce231a4 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 14:28:13 +0200
Subject: [PATCH 16/31] update readme

---
 examples/tts/README-csm.md | 32 ++++++++++++++++++++++++--------
 1 file changed, 24 insertions(+), 8 deletions(-)

diff --git a/examples/tts/README-csm.md b/examples/tts/README-csm.md
index f660d8965ecbe..676b9889e157d 100644
--- a/examples/tts/README-csm.md
+++ b/examples/tts/README-csm.md
@@ -1,5 +1,27 @@
 # Sesame CSM
 
+This demo shows running inference of [Sesame CSM](https://github.com/SesameAILabs/csm) using llama.cpp / GGML.
+
+It contains 3 components (each has its own GGUF file):
+1. Backbone LLM
+2. Decoder LLM
+3. Mimi decoder
+
+## Quick start
+
+By default, all GGUF files are downloaded from [the ggml-org Hugging Face account](https://huggingface.co/ggml-org/sesame-csm-1b-GGUF)
+
+```sh
+# build (make sure to have LLAMA_CURL enabled)
+cmake -B build -DLLAMA_CURL=ON
+cmake --build build -j --target llama-tts-csm
+
+# run it
+./build/bin/llama-tts-csm -p "[0]Hi, my name is Xuan Son. I am a software engineer at Hugging Face."
+```
+
+## Convert the model yourself
+
 To get the GGUF:
 
 ```sh
@@ -14,16 +36,10 @@ python examples/tts/convert_csm_to_gguf.py
 python examples/tts/convert_csm_to_gguf.py --outtype q8_0
 ```
 
-Compile the example:
-
-```sh
-cmake --build build -j --target llama-tts-csm
-```
-
-Run the example:
+Run the example using local files:
 
 ```sh
-./build/bin/llama-tts-csm -m sesame-csm-backbone.gguf -p "[0]Hello world."
+./build/bin/llama-tts-csm -m sesame-csm-backbone.gguf -mv kyutai-mimi.gguf -p "[0]Hello world."
 
 # sesame-csm-backbone.gguf will automatically be loaded
 # make sure the place these 2 GGUF files in the same directory

From 1e9afd9d816881d7a244c7639eda2eea85da3ad3 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 14:45:04 +0200
Subject: [PATCH 17/31] nits

---
 examples/tts/mimi.cpp    | 1 +
 examples/tts/tts-csm.cpp | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp
index 502e0150634b7..17047c9b4f1ce 100644
--- a/examples/tts/mimi.cpp
+++ b/examples/tts/mimi.cpp
@@ -79,6 +79,7 @@ int main(int argc, const char ** argv) {
     while (std::getline(fin, line)) {
         // Skip empty lines
         if (line.empty()) continue;
+        // TODO: support both comma (with spaces) and new line
         try {
             int code = std::stoi(line);
             codes.push_back(code);
diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp
index 843d7f6b79196..4d77e5f6d3169 100644
--- a/examples/tts/tts-csm.cpp
+++ b/examples/tts/tts-csm.cpp
@@ -270,6 +270,8 @@ int main(int argc, char ** argv) {
     generated_codes = mimi.transpose_input(generated_codes);
     std::vector wav_data = mimi.decode(generated_codes);
 
+    printf("output wav file: %s\n", params.out_file.c_str());
+
     if (!save_wav16(params.out_file.c_str(), wav_data, mimi.get_sample_rate())) {
         LOG_ERR("Failed to save wav file\n");
         return 1;

From 40ab1ab30755d78f59c1e8ce16cc249a4e0c5954 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 14:51:29 +0200
Subject: [PATCH 18/31] fix mul_mat_id read out-of-bound

---
 examples/tts/convert_csm_to_gguf.py | 6 +++---
 src/llama-model.cpp                 | 4 ++--
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/tts/convert_csm_to_gguf.py b/examples/tts/convert_csm_to_gguf.py
index 09b7748c2a63d..53f586f19962d 100644
--- a/examples/tts/convert_csm_to_gguf.py
+++ b/examples/tts/convert_csm_to_gguf.py
@@ -179,10 +179,10 @@ def rename_transformer(name: str) -> str:
             is_decoder = True
             name = "audio_head.weight"
             if component == "decoder":
-                # add padding at the beginning so that build_lora_mm_id can be used
+                # add padding at the beginning and the end so that build_lora_mm_id can be used
                 zero_tensor = torch.zeros(1, 1024, 2051)
-                data_torch = torch.cat([zero_tensor, data_torch], dim=0)
-                assert data_torch.shape == (32, 1024, 2051)
+                data_torch = torch.cat([zero_tensor,
From eaba2bfbcf7af45a140c6dad3803084df29cb922 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 15:05:32 +0200
Subject: [PATCH 19/31] will this fix windows build?

---
 examples/tts/CMakeLists.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt
index ab184a85ba17b..e66c298db461a 100644
--- a/examples/tts/CMakeLists.txt
+++ b/examples/tts/CMakeLists.txt
@@ -4,7 +4,7 @@ install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
 
-add_library(mimi-model mimi-model.h mimi-model.cpp)
+add_library(mimi-model STATIC mimi-model.h mimi-model.cpp)
 target_link_libraries(mimi-model PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 # for using C++ designated initializers, TODO: can be changed back to C++17 in the future
 target_compile_features(mimi-model PRIVATE cxx_std_20)

From 5fe27efcebc0f4d01cc0a52c763db5365ec0634c Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sun, 30 Mar 2025 23:49:49 +0200
Subject: [PATCH 20/31] (try) fixing problem with long text

---
 examples/tts/tts-csm.cpp | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp
index 4d77e5f6d3169..fb5146a3bbcc9 100644
--- a/examples/tts/tts-csm.cpp
+++ b/examples/tts/tts-csm.cpp
@@ -13,7 +13,7 @@
 // For more details on how this works, see: https://github.com/ggml-org/llama.cpp/pull/12648
 
 static void print_usage(int, char ** argv) {
-    LOG("\nexample usage:\n");
+    LOG("\nExample usage:\n");
     LOG("\n    By default, model will be downloaded from https://huggingface.co/ggml-org/sesame-csm-1b-GGUF");
     LOG("\n    %s -p \"[0]I have a dream that one day every valley shall be exalted\" -o output.wav", argv[0]);
     LOG("\n");
@@ -22,6 +22,11 @@ static void print_usage(int, char ** argv) {
     LOG("\n");
     LOG("\n    Note: the model needs 2 files to run: one ends with '-backbone.gguf' and the other ends with '-decoder.gguf'");
     LOG("\n");
+    LOG("\nPrompt format:");
+    LOG("\n    Each line must start with a speaker ID in square brackets, followed by the text. A full stop is recommended at the end of each turn");
+    LOG("\n    Example: [0]Hello world.");
+    LOG("\n    If you want to enter long text, use -f file.txt to read from file");
+    LOG("\n");
 }
 
 // greedy sampling with custom n_vocab
@@ -61,7 +66,7 @@ int main(int argc, char ** argv) {
     params.model         = "sesame-csm-backbone.gguf";
     params.vocoder.model = "kyutai-mimi.gguf";
     params.out_file      = "output.wav";
-    params.prompt        = "[0]Hello from Sesame.";
+    params.prompt        = "";
     params.n_predict     = 2048; // CSM's max trained seq length
 
     // HF model
@@ -75,6 +80,11 @@ int main(int argc, char ** argv) {
     llama_backend_init();
     llama_numa_init(params.numa);
 
+    if (params.prompt.empty()) {
+        LOG_ERR("prompt is empty\n");
+        return 1;
+    }
+
     std::vector<float> embd;
     params.cb_eval = ggml_callback;
     params.cb_eval_user_data = &embd;
@@ -167,7 +177,7 @@ int main(int argc, char ** argv) {
         //     printf("\n");
 
         llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc));
-        printf("%d,", semantic_tok);
+        printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok);
         generated_codes.push_back(semantic_tok);
 
         //     for (size_t i = 0; i < 10; ++i) {
         //         printf("%4.2f, ", embd[i]);
@@ -200,7 +210,7 @@ int main(int argc, char ** argv) {
             // then, decode the semantic_tok to generate acoustic tokens
             llama_token tok = semantic_tok;
             int n_codes = 32;
-            int sum_codes = 0; // to check if all codes are 0
+            int sum_codes = semantic_tok; // to check if all codes are 0
             for (int i = 0; i < n_codes; ++i) {
                 common_batch_clear(batch_token);
                 // encoder vocab is further divided into 32 codebooks, each with 2051 entries
@@ -228,9 +238,12 @@ int main(int argc, char ** argv) {
             }
 
             // do progressive hsum of embeddings
-            GGML_ASSERT(inp_past_embd.size() == embd.size());
-            for (size_t i = 0; i < inp_past_embd.size(); ++i) {
-                inp_past_embd[i] += embd[i];
+            // skip first semantic code
+            if (i > 0) {
+                GGML_ASSERT(inp_past_embd.size() == embd.size());
+                for (size_t i = 0; i < inp_past_embd.size(); ++i) {
+                    inp_past_embd[i] += embd[i];
+                }
             }
         }
         printf("\n");
@@ -253,6 +266,8 @@ int main(int argc, char ** argv) {
         //     printf("\n");
 
         if (is_stop) {
+            // remove last 32 codes since they will be all zeros
+            generated_codes.resize(generated_codes.size() - 32);
            break;
         }
     }
From c796ee0f6620ea049e9e8fe63310e1078b3d8cae Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Mon, 31 Mar 2025 00:02:25 +0200
Subject: [PATCH 21/31] mimi: fix frame splitting

---
 examples/tts/mimi-model.cpp | 3 ++-
 examples/tts/mimi-model.h   | 4 ++--
 examples/tts/mimi.cpp       | 2 ++
 examples/tts/tts-csm.cpp    | 1 -
 4 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp
index 3663201dc5971..fee88c679e1f3 100644
--- a/examples/tts/mimi-model.cpp
+++ b/examples/tts/mimi-model.cpp
@@ -665,7 +665,8 @@ std::vector<float> mimi_model::decode_frame(const std::vector<int> & codes, int
     ctx->set_tensor_data("pos_dec", pos_data.data());
 
     // code data
-    ctx->set_tensor_data("inp_dec", codes.data());
+    auto codes_T = mimi_model::transpose_input(codes);
+    ctx->set_tensor_data("inp_dec", codes_T.data());
 
     ctx->compute();
 
diff --git a/examples/tts/mimi-model.h b/examples/tts/mimi-model.h
index 96945981513c0..1ded07e875e7d 100644
--- a/examples/tts/mimi-model.h
+++ b/examples/tts/mimi-model.h
@@ -25,9 +25,9 @@ struct mimi_model {
     // transpose layout:
     // - from: (1 semantic code followed by 31 acoustic codes) repeated N times
     // - to: N semantic codes followed by (N*31) acoustic codes
-    std::vector<int> transpose_input(const std::vector<int> & codes);
+    static std::vector<int> transpose_input(const std::vector<int> & codes);
 
-    // layout of codes: N semantic codes followed by (N*31) acoustic codes
+    // layout of codes: (1 semantic code followed by 31 acoustic codes) repeated N times
     std::vector<float> decode(const std::vector<int> & codes);
 
     // TODO: implement encoding pass
diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp
index 17047c9b4f1ce..293f2fb775c3b 100644
--- a/examples/tts/mimi.cpp
+++ b/examples/tts/mimi.cpp
@@ -69,6 +69,8 @@ int main(int argc, const char ** argv) {
             1740 ,1154 ,1839 ,912 ,731 ,602 ,1064 ,1508 ,834 ,1387 ,252 ,745 ,1034 ,1102 ,965 ,696 ,
             1971 ,1729 ,666 ,282 ,1993 ,1551 ,1703 ,1124 ,1628 ,1725 ,107 ,808 ,1096 ,1753 ,500 ,677 ,
         };
+        // this particular example is pre-transposed, we need to undo that
+        codes = mimi_model::transpose_input(codes);
     } else {
         std::ifstream fin(codes_path);
         if (!fin) {
diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp
index fb5146a3bbcc9..2dda42198637d 100644
--- a/examples/tts/tts-csm.cpp
+++ b/examples/tts/tts-csm.cpp
@@ -282,7 +282,6 @@ int main(int argc, char ** argv) {
     llama_batch_free(batch_past_embd);
 
     printf("decode %zu RVQ tokens into wav...\n", generated_codes.size());
-    generated_codes = mimi.transpose_input(generated_codes);
     std::vector<float> wav_data = mimi.decode(generated_codes);
 
     printf("output wav file: %s\n", params.out_file.c_str());
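Aside (not part of the patch, with assumed constants): Mimi nominally produces 12.5 frames per second of audio at 24 kHz, i.e. 1920 samples per frame; neither number is stated in this series. Under that assumption, a rough sanity check relating the size of a flat code stream to the expected wav length:

```cpp
#include <cassert>
#include <cstdio>

int main() {
    // assumed constants, not taken from the patch
    const size_t n_codebooks = 32;   // codes per Mimi frame (1 semantic + 31 acoustic)
    const size_t n_codes     = 6400; // e.g. a generated stream
    assert(n_codes % n_codebooks == 0); // whole frames only
    const size_t n_frames  = n_codes / n_codebooks;
    const size_t n_samples = n_frames * 1920; // 24000 Hz / 12.5 frames per second
    printf("%zu frames -> %zu samples (%.1f s at 24 kHz)\n",
           n_frames, n_samples, n_samples / 24000.0);
    return 0;
}
```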
From e31a75c209af16376fb53ea829def34ca4fbad9c Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Mon, 31 Mar 2025 11:53:24 +0200
Subject: [PATCH 22/31] fix mimi example dummy1

---
 examples/tts/mimi-model.h | 11 ++++---
 examples/tts/mimi.cpp     | 66 +++++++++++++++++++--------------------
 2 files changed, 38 insertions(+), 39 deletions(-)

diff --git a/examples/tts/mimi-model.h b/examples/tts/mimi-model.h
index 1ded07e875e7d..eb5eb46c22807 100644
--- a/examples/tts/mimi-model.h
+++ b/examples/tts/mimi-model.h
@@ -22,11 +22,6 @@ struct mimi_model {
 
     int get_sample_rate() const;
 
-    // transpose layout:
-    // - from: (1 semantic code followed by 31 acoustic codes) repeated N times
-    // - to: N semantic codes followed by (N*31) acoustic codes
-    static std::vector<int> transpose_input(const std::vector<int> & codes);
-
     // layout of codes: (1 semantic code followed by 31 acoustic codes) repeated N times
     std::vector<float> decode(const std::vector<int> & codes);
 
     // TODO: implement encoding pass
@@ -35,4 +30,10 @@ struct mimi_model {
 
 private:
     std::vector<float> decode_frame(const std::vector<int> & codes, int & n_past);
+
+    // transpose layout (from streaming layout to non-streaming):
+    // - from: (1 semantic code followed by 31 acoustic codes) repeated N times
+    // - to: N semantic codes followed by (N*31) acoustic codes
+    // streaming layout is 1-31, 1-31, 1-31, ..., used for real-time processing
+    static std::vector<int> transpose_input(const std::vector<int> & codes);
 };
diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp
index 293f2fb775c3b..a50bd44f599a9 100644
--- a/examples/tts/mimi.cpp
+++ b/examples/tts/mimi.cpp
@@ -36,41 +36,39 @@ int main(int argc, const char ** argv) {
     } else if (strcmp(codes_path, "dummy1") == 0) {
         printf("Using dummy1 codes\n");
         codes = {
-            1049 ,1415 ,1962 ,914 ,1372 ,704 ,1922 ,2036 ,288 ,968 ,193 ,1139 ,897 ,897 ,1243 ,1511 ,
-            1597 ,175 ,1280 ,1202 ,1911 ,85 ,47 ,692 ,632 ,251 ,1553 ,1735 ,1577 ,132 ,471 ,433 ,
-            1325 ,1539 ,1943 ,1601 ,141 ,257 ,564 ,1435 ,876 ,1096 ,636 ,61 ,1497 ,1010 ,485 ,284 ,
-            839 ,776 ,878 ,1719 ,1069 ,1302 ,893 ,2005 ,875 ,908 ,586 ,2001 ,186 ,1932 ,1765 ,721 ,
-            592 ,1046 ,1588 ,1670 ,1485 ,1141 ,34 ,1465 ,1156 ,1938 ,435 ,753 ,1418 ,277 ,391 ,1741 ,
-            1440 ,117 ,723 ,412 ,642 ,1717 ,131 ,37 ,345 ,112 ,1979 
,2034 ,1822 ,1536 ,1281 ,56 , - 1341 ,803 ,568 ,568 ,1370 ,1995 ,1063 ,892 ,273 ,895 ,1226 ,354 ,1726 ,1541 ,1607 ,615 , - 985 ,1499 ,1736 ,1838 ,702 ,1345 ,1657 ,511 ,1774 ,1787 ,945 ,1927 ,947 ,952 ,1418 ,916 , - 1239 ,1457 ,1021 ,341 ,284 ,882 ,474 ,1559 ,1923 ,273 ,1330 ,1406 ,1782 ,19 ,116 ,887 , - 1146 ,1307 ,983 ,1237 ,1407 ,1350 ,1960 ,1255 ,878 ,1979 ,1500 ,1939 ,1415 ,88 ,1702 ,1253 , - 1778 ,2 ,10 ,1279 ,999 ,1549 ,1049 ,373 ,1355 ,1200 ,1466 ,1009 ,75 ,2042 ,1725 ,916 , - 1636 ,1135 ,833 ,830 ,1758 ,2015 ,1275 ,1675 ,287 ,744 ,89 ,430 ,1724 ,1232 ,1692 ,535 , - 1485 ,1287 ,973 ,1815 ,314 ,2020 ,424 ,1085 ,982 ,1994 ,1563 ,1269 ,1769 ,1681 ,1082 ,1666 , - 1622 ,1039 ,1209 ,32 ,679 ,732 ,976 ,1462 ,805 ,402 ,1150 ,170 ,1529 ,2013 ,350 ,1175 , - 757 ,1124 ,1091 ,1369 ,1061 ,415 ,1217 ,1135 ,1360 ,1578 ,1205 ,1785 ,1835 ,1241 ,14 ,716 , - 480 ,716 ,681 ,1686 ,1624 ,335 ,865 ,1356 ,1688 ,307 ,366 ,541 ,1262 ,1167 ,59 ,269 , - 1899 ,1798 ,1606 ,1307 ,1549 ,1814 ,114 ,483 ,958 ,1919 ,1179 ,898 ,834 ,1526 ,386 ,447 , - 1481 ,201 ,779 ,419 ,430 ,1451 ,1000 ,156 ,1062 ,615 ,1353 ,414 ,1214 ,1487 ,882 ,32 , - 840 ,1517 ,334 ,1143 ,823 ,454 ,725 ,1298 ,1325 ,649 ,1737 ,913 ,685 ,761 ,2010 ,63 , - 1397 ,1299 ,765 ,1158 ,1809 ,1299 ,1585 ,1776 ,625 ,1539 ,830 ,1563 ,461 ,308 ,1438 ,321 , - 82 ,886 ,1836 ,325 ,1976 ,761 ,359 ,1136 ,1720 ,2036 ,904 ,719 ,526 ,1567 ,145 ,1860 , - 1565 ,1786 ,1400 ,1696 ,232 ,1736 ,512 ,518 ,1895 ,1854 ,1584 ,1393 ,1869 ,1702 ,789 ,1986 , - 116 ,521 ,150 ,1597 ,727 ,1916 ,815 ,1826 ,1382 ,653 ,1596 ,286 ,1373 ,177 ,1397 ,1009 , - 1449 ,353 ,877 ,93 ,266 ,1853 ,1255 ,872 ,1974 ,556 ,1885 ,857 ,992 ,5 ,1921 ,1849 , - 1038 ,1912 ,464 ,795 ,747 ,56 ,124 ,431 ,1868 ,609 ,855 ,1522 ,912 ,1709 ,1507 ,1062 , - 1015 ,1357 ,1487 ,4 ,253 ,1871 ,933 ,215 ,1228 ,633 ,1306 ,2024 ,1453 ,900 ,457 ,471 , - 436 ,1311 ,870 ,1032 ,134 ,984 ,1983 ,1103 ,1627 ,1627 ,414 ,1845 ,583 ,1699 ,1458 ,2018 , - 150 ,450 ,1114 ,369 ,267 ,1273 ,1136 ,1578 ,1063 ,1820 ,120 ,779 ,652 ,1266 ,1929 ,1213 , - 159 ,297 ,1703 ,819 ,93 ,247 ,1366 ,144 ,1617 ,1428 ,812 ,121 ,1637 ,1620 ,289 ,1557 , - 1414 ,971 ,476 ,1685 ,428 ,1802 ,653 ,1290 ,614 ,1663 ,1528 ,1344 ,798 ,1027 ,1305 ,990 , - 1740 ,1154 ,1839 ,912 ,731 ,602 ,1064 ,1508 ,834 ,1387 ,252 ,745 ,1034 ,1102 ,965 ,696 , - 1971 ,1729 ,666 ,282 ,1993 ,1551 ,1703 ,1124 ,1628 ,1725 ,107 ,808 ,1096 ,1753 ,500 ,677 , + 1049 ,1597 ,1325 ,839 ,592 ,1440 ,1341 ,985 ,1239 ,1146 ,1778 ,1636 ,1485 ,1622 ,757 ,480 , + 1899 ,1481 ,840 ,1397 ,82 ,1565 ,116 ,1449 ,1038 ,1015 ,436 ,150 ,159 ,1414 ,1740 ,1971 , + 1415 ,175 ,1539 ,776 ,1046 ,117 ,803 ,1499 ,1457 ,1307 ,2 ,1135 ,1287 ,1039 ,1124 ,716 , + 1798 ,201 ,1517 ,1299 ,886 ,1786 ,521 ,353 ,1912 ,1357 ,1311 ,450 ,297 ,971 ,1154 ,1729 , + 1962 ,1280 ,1943 ,878 ,1588 ,723 ,568 ,1736 ,1021 ,983 ,10 ,833 ,973 ,1209 ,1091 ,681 , + 1606 ,779 ,334 ,765 ,1836 ,1400 ,150 ,877 ,464 ,1487 ,870 ,1114 ,1703 ,476 ,1839 ,666 , + 914 ,1202 ,1601 ,1719 ,1670 ,412 ,568 ,1838 ,341 ,1237 ,1279 ,830 ,1815 ,32 ,1369 ,1686 , + 1307 ,419 ,1143 ,1158 ,325 ,1696 ,1597 ,93 ,795 ,4 ,1032 ,369 ,819 ,1685 ,912 ,282 , + 1372 ,1911 ,141 ,1069 ,1485 ,642 ,1370 ,702 ,284 ,1407 ,999 ,1758 ,314 ,679 ,1061 ,1624 , + 1549 ,430 ,823 ,1809 ,1976 ,232 ,727 ,266 ,747 ,253 ,134 ,267 ,93 ,428 ,731 ,1993 , + 704 ,85 ,257 ,1302 ,1141 ,1717 ,1995 ,1345 ,882 ,1350 ,1549 ,2015 ,2020 ,732 ,415 ,335 , + 1814 ,1451 ,454 ,1299 ,761 ,1736 ,1916 ,1853 ,56 ,1871 ,984 ,1273 ,247 ,1802 ,602 ,1551 , + 1922 ,47 ,564 ,893 ,34 ,131 ,1063 ,1657 ,474 ,1960 ,1049 ,1275 
,424 ,976 ,1217 ,865 , + 114 ,1000 ,725 ,1585 ,359 ,512 ,815 ,1255 ,124 ,933 ,1983 ,1136 ,1366 ,653 ,1064 ,1703 , + 2036 ,692 ,1435 ,2005 ,1465 ,37 ,892 ,511 ,1559 ,1255 ,373 ,1675 ,1085 ,1462 ,1135 ,1356 , + 483 ,156 ,1298 ,1776 ,1136 ,518 ,1826 ,872 ,431 ,215 ,1103 ,1578 ,144 ,1290 ,1508 ,1124 , + 288 ,632 ,876 ,875 ,1156 ,345 ,273 ,1774 ,1923 ,878 ,1355 ,287 ,982 ,805 ,1360 ,1688 , + 958 ,1062 ,1325 ,625 ,1720 ,1895 ,1382 ,1974 ,1868 ,1228 ,1627 ,1063 ,1617 ,614 ,834 ,1628 , + 968 ,251 ,1096 ,908 ,1938 ,112 ,895 ,1787 ,273 ,1979 ,1200 ,744 ,1994 ,402 ,1578 ,307 , + 1919 ,615 ,649 ,1539 ,2036 ,1854 ,653 ,556 ,609 ,633 ,1627 ,1820 ,1428 ,1663 ,1387 ,1725 , + 193 ,1553 ,636 ,586 ,435 ,1979 ,1226 ,945 ,1330 ,1500 ,1466 ,89 ,1563 ,1150 ,1205 ,366 , + 1179 ,1353 ,1737 ,830 ,904 ,1584 ,1596 ,1885 ,855 ,1306 ,414 ,120 ,812 ,1528 ,252 ,107 , + 1139 ,1735 ,61 ,2001 ,753 ,2034 ,354 ,1927 ,1406 ,1939 ,1009 ,430 ,1269 ,170 ,1785 ,541 , + 898 ,414 ,913 ,1563 ,719 ,1393 ,286 ,857 ,1522 ,2024 ,1845 ,779 ,121 ,1344 ,745 ,808 , + 897 ,1577 ,1497 ,186 ,1418 ,1822 ,1726 ,947 ,1782 ,1415 ,75 ,1724 ,1769 ,1529 ,1835 ,1262 , + 834 ,1214 ,685 ,461 ,526 ,1869 ,1373 ,992 ,912 ,1453 ,583 ,652 ,1637 ,798 ,1034 ,1096 , + 897 ,132 ,1010 ,1932 ,277 ,1536 ,1541 ,952 ,19 ,88 ,2042 ,1232 ,1681 ,2013 ,1241 ,1167 , + 1526 ,1487 ,761 ,308 ,1567 ,1702 ,177 ,5 ,1709 ,900 ,1699 ,1266 ,1620 ,1027 ,1102 ,1753 , + 1243 ,471 ,485 ,1765 ,391 ,1281 ,1607 ,1418 ,116 ,1702 ,1725 ,1692 ,1082 ,350 ,14 ,59 , + 386 ,882 ,2010 ,1438 ,145 ,789 ,1397 ,1921 ,1507 ,457 ,1458 ,1929 ,289 ,1305 ,965 ,500 , + 1511 ,433 ,284 ,721 ,1741 ,56 ,615 ,916 ,887 ,1253 ,916 ,535 ,1666 ,1175 ,716 ,269 , + 447 ,32 ,63 ,321 ,1860 ,1986 ,1009 ,1849 ,1062 ,471 ,2018 ,1213 ,1557 ,990 ,696 ,677 , }; - // this particular example is pre-transposed, we need to undo that - codes = mimi_model::transpose_input(codes); } else { std::ifstream fin(codes_path); if (!fin) { From 5be8e7d64a5ff2c47b503a96d9e4730718b84b5b Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 31 Mar 2025 12:41:12 +0200 Subject: [PATCH 23/31] add top-k and temp sampling --- examples/tts/tts-csm.cpp | 54 +++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp index 2dda42198637d..19c6d46a834a2 100644 --- a/examples/tts/tts-csm.cpp +++ b/examples/tts/tts-csm.cpp @@ -29,17 +29,27 @@ static void print_usage(int, char ** argv) { LOG("\n"); } -// greedy sampling with custom n_vocab -static llama_token sample_greedy(const float * logits, int n_vocab) { - llama_token max_idx = -1; - float max_val = -FLT_MAX; - for (int i = 0; i < n_vocab; ++i) { - if (logits[i] > max_val) { - max_val = logits[i]; - max_idx = i; - } +// sampling with custom n_vocab +// modified version of llama_sampler_sample() +static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) { + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } - return max_idx; + + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(smpl, &cur_p); + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + auto token = cur_p.data[cur_p.selected].id; + llama_sampler_accept(smpl, token); + return token; } // hook to retrieve the embeddings @@ -63,11 +73,13 @@ 
static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) { int main(int argc, char ** argv) { common_params params; - params.model = "sesame-csm-backbone.gguf"; - params.vocoder.model = "kyutai-mimi.gguf"; - params.out_file = "output.wav"; - params.prompt = ""; - params.n_predict = 2048; // CSM's max trained seq length + params.model = "sesame-csm-backbone.gguf"; + params.vocoder.model = "kyutai-mimi.gguf"; + params.out_file = "output.wav"; + params.prompt = ""; + params.n_predict = 2048; // CSM's max trained seq length + params.sampling.top_k = 50; // default param from CSM python code + params.sampling.temp = 0.9; // default param from CSM python code // HF model params.model_url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf"; @@ -115,11 +127,19 @@ int main(int argc, char ** argv) { mimi_model mimi(params.vocoder.model.c_str(), true); + // tokenize the prompt const llama_vocab * vocab = llama_model_get_vocab(model_bb); llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true); prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab)); prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab)); + // init sampler + // the python implementation only has top-k and temperature sampling, so we'll use just that + llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(params.sampling.temp)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(params.sampling.seed)); + printf("prompt tokens: \n"); for (size_t i = 0; i < prompt_tokens.size(); ++i) { printf("%d, ", prompt_tokens[i]); @@ -176,7 +196,7 @@ int main(int argc, char ** argv) { // } // printf("\n"); - llama_token semantic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); + llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok); generated_codes.push_back(semantic_tok); @@ -227,7 +247,7 @@ int main(int argc, char ** argv) { // sample the acoustic token auto logits = llama_get_logits_ith(ctx_dc, 0); - llama_token acoustic_tok = sample_greedy(logits, llama_vocab_n_tokens(vocab_dc)); + llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); // discard last code (only for embeddings) if (i < n_codes - 1) { From 90231cc2514907ae2779ad85cdb88af3fe21b49a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 1 Apr 2025 09:51:27 +0200 Subject: [PATCH 24/31] much better on long generation --- examples/tts/tts-csm.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp index 19c6d46a834a2..915653c518a73 100644 --- a/examples/tts/tts-csm.cpp +++ b/examples/tts/tts-csm.cpp @@ -258,12 +258,9 @@ int main(int argc, char ** argv) { } // do progressive hsum of embeddings - // skip first semantic code - if (i > 0) { - GGML_ASSERT(inp_past_embd.size() == embd.size()); - for (size_t i = 0; i < inp_past_embd.size(); ++i) { - inp_past_embd[i] += embd[i]; - } + GGML_ASSERT(inp_past_embd.size() == embd.size()); + for (size_t i = 0; i < inp_past_embd.size(); ++i) { + inp_past_embd[i] += embd[i]; } } printf("\n"); From e9dc47687c8a92135f8d15db8f7e2057598995bb Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 2 
Apr 2025 16:41:28 +0200 Subject: [PATCH 25/31] fix tts-csm --- examples/tts/tts-csm.cpp | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp index 915653c518a73..3cb844615681e 100644 --- a/examples/tts/tts-csm.cpp +++ b/examples/tts/tts-csm.cpp @@ -73,17 +73,17 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) { int main(int argc, char ** argv) { common_params params; - params.model = "sesame-csm-backbone.gguf"; - params.vocoder.model = "kyutai-mimi.gguf"; - params.out_file = "output.wav"; - params.prompt = ""; - params.n_predict = 2048; // CSM's max trained seq length - params.sampling.top_k = 50; // default param from CSM python code - params.sampling.temp = 0.9; // default param from CSM python code + params.model.path = "sesame-csm-backbone.gguf"; + params.vocoder.model.path = "kyutai-mimi.gguf"; + params.out_file = "output.wav"; + params.prompt = ""; + params.n_predict = 2048; // CSM's max trained seq length + params.sampling.top_k = 50; // default param from CSM python code + params.sampling.temp = 0.9; // default param from CSM python code // HF model - params.model_url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf"; - params.vocoder.model_url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf"; + params.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf"; + params.vocoder.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf"; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) { return 1; @@ -104,10 +104,8 @@ int main(int argc, char ** argv) { common_params params_decoder(params); // duplicate the params params_decoder.n_ctx = 64; // we never use more than this - string_replace_all(params_decoder.model, "-backbone", "-decoder"); - if (!params_decoder.model_url.empty()) { - string_replace_all(params_decoder.model_url, "-backbone", "-decoder"); - } + string_replace_all(params_decoder.model.path, "-backbone", "-decoder"); + string_replace_all(params_decoder.model.url, "-backbone", "-decoder"); common_init_result llama_backbone = common_init_from_params(params); llama_model * model_bb = llama_backbone.model.get(); @@ -125,7 +123,7 @@ int main(int argc, char ** argv) { return ENOENT; } - mimi_model mimi(params.vocoder.model.c_str(), true); + mimi_model mimi(params.vocoder.model.path.c_str(), true); // tokenize the prompt const llama_vocab * vocab = llama_model_get_vocab(model_bb); From c681257e58d171c8a33fa867f555539a1db8ff0a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 2 Apr 2025 17:31:29 +0200 Subject: [PATCH 26/31] ability to do multi-turns --- examples/tts/csm-demo.txt | 5 + examples/tts/tts-csm.cpp | 293 +++++++++++++++++++++----------------- 2 files changed, 171 insertions(+), 127 deletions(-) create mode 100644 examples/tts/csm-demo.txt diff --git a/examples/tts/csm-demo.txt b/examples/tts/csm-demo.txt new file mode 100644 index 0000000000000..1c913388bfb3d --- /dev/null +++ b/examples/tts/csm-demo.txt @@ -0,0 +1,5 @@ +[0]Hey how are you doing. +[1]Pretty good, pretty good. +[0]I'm great, so happy to be speaking to you. +What about you? +[1]Me too, this is some cool stuff huh? 
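Aside (not part of the patch): given the demo transcript above, the turn splitter introduced below (`get_speaker_turns()`) should produce four turns, with the third spanning two physical lines. A self-contained sketch to verify the regex behavior, mirroring but not identical to the example code:

```cpp
#include <cassert>
#include <regex>
#include <string>
#include <vector>

int main() {
    const std::string input =
        "[0]Hey how are you doing.\n"
        "[1]Pretty good, pretty good.\n"
        "[0]I'm great, so happy to be speaking to you.\nWhat about you?\n"
        "[1]Me too, this is some cool stuff huh?\n";
    // same regex as the patch: capture from one "[N]" up to the next (or end of input)
    const std::regex re(R"((\[\d+\][\s\S]*?)(?=\[\d+\]|$))");
    std::vector<std::string> turns;
    for (auto it = std::sregex_iterator(input.begin(), input.end(), re);
         it != std::sregex_iterator(); ++it) {
        turns.push_back((*it)[1].str());
    }
    assert(turns.size() == 4); // the third turn spans two physical lines
    return 0;
}
```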
diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp
index 3cb844615681e..a8a9bd22d955b 100644
--- a/examples/tts/tts-csm.cpp
+++ b/examples/tts/tts-csm.cpp
@@ -5,6 +5,7 @@
 #include "mimi-model.h"
 
 #include
+#include <regex>
 #include
 #include <cstring> // memcpy and strcmp
@@ -23,12 +24,39 @@ static void print_usage(int, char ** argv) {
     LOG("\n    Note: the model needs 2 files to run: one ends with '-backbone.gguf' and the other ends with '-decoder.gguf'");
     LOG("\n");
     LOG("\nPrompt format:");
-    LOG("\n    Each line must start with a speaker ID in square brackets, followed by the text. A full stop is recommended at the end of each turn");
-    LOG("\n    Example: [0]Hello world.");
+    LOG("\n    Each line must start with a speaker ID in square brackets, followed by the text. One turn per line. A full stop is recommended at the end of each turn");
+    LOG("\n    Example:");
+    LOG("\n        [0]Hey how are you doing.");
+    LOG("\n        [1]Pretty good, pretty good.");
     LOG("\n    If you want to enter long text, use -f file.txt to read from file");
     LOG("\n");
 }
 
+// split text containing "[N]..." into speaker turns
+static std::vector<std::string> get_speaker_turns(const std::string & input) {
+    if (input.empty()) {
+        LOG_ERR("Empty input\n");
+        return {};
+    }
+    if (input[0] != '[') {
+        LOG_ERR("Invalid input format: missing speaker ID\n");
+        return {};
+    }
+    std::regex re(R"((\[\d+\][\s\S]*?)(?=\[\d+\]|$))");
+    std::smatch match;
+    std::vector<std::string> turns;
+    std::string::const_iterator searchStart(input.cbegin());
+    while (std::regex_search(searchStart, input.cend(), match, re)) {
+        std::string turn = match[1].str();
+        if (turn.empty()) {
+            continue;
+        }
+        turns.push_back(turn);
+        searchStart = match.suffix().first;
+    }
+    return turns;
+}
+
 // sampling with custom n_vocab
 // modified version of llama_sampler_sample()
 static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) {
@@ -81,9 +109,11 @@ int main(int argc, char ** argv) {
     params.sampling.top_k = 50;   // default param from CSM python code
     params.sampling.temp  = 0.9;  // default param from CSM python code
 
-    // HF model
-    params.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
-    params.vocoder.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf";
+    // HF model (hack: we temporarily reuse speculative.model as the decoder model, only to get it downloaded)
+    params.model.url              = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf";
+    params.speculative.model.path = "sesame-csm-decoder.gguf";
+    params.speculative.model.url  = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-decoder.gguf";
+    params.vocoder.model.url      = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf";
 
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) {
         return 1;
@@ -125,12 +155,6 @@ int main(int argc, char ** argv) {
 
     mimi_model mimi(params.vocoder.model.path.c_str(), true);
 
-    // tokenize the prompt
-    const llama_vocab * vocab = llama_model_get_vocab(model_bb);
-    llama_tokens prompt_tokens = common_tokenize(vocab, params.prompt, false, true);
-    prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
-    prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
-
     // init sampler
     // the python implementation only has top-k and temperature sampling, so we'll use just that
     llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params()));
@@ 
-138,19 +162,8 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(params.sampling.temp)); llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(params.sampling.seed)); - printf("prompt tokens: \n"); - for (size_t i = 0; i < prompt_tokens.size(); ++i) { - printf("%d, ", prompt_tokens[i]); - } - printf("\n"); - - llama_pos n_past_bb = 0; llama_batch batch_prompt = llama_batch_init(params.n_batch, 0, 1); - common_batch_clear(batch_prompt); - for (size_t i = 0; i < prompt_tokens.size(); ++i) { - common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false); - } - batch_prompt.logits[batch_prompt.n_tokens - 1] = true; + llama_pos n_past_bb = 0; // inp_past_embd is the "squashed" embeddings from the decoder std::vector inp_past_embd(2048, 0.0f); @@ -162,128 +175,154 @@ int main(int argc, char ** argv) { int64_t t_dc = 0; // decoder time int64_t n_dc_gen = 0; // decoder generation count - bool is_stop = false; std::vector generated_codes; - // backbone generation loop - for (int k = 0; k < params.n_predict; ++k) { - bool is_prompt_processing = k == 0; - - if (!is_prompt_processing) { - // generate the next RVQ semantic token - batch_past_embd.n_tokens = 1; - batch_past_embd.pos[0] = n_past_bb++; - batch_past_embd.seq_id[0][0] = 0; - batch_past_embd.n_seq_id[0] = 1; - batch_past_embd.logits[0] = true; - std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float)); - } + auto turns = get_speaker_turns(params.prompt); + + for (const std::string & turn : turns) { + // tokenize the turn + llama_tokens prompt_tokens; + { + printf("\n---\nturn: %s\n\n", turn.c_str()); + const llama_vocab * vocab = llama_model_get_vocab(model_bb); + prompt_tokens = common_tokenize(vocab, turn, false, true); + prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab)); + prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab)); + + printf("prompt (%zu tokens): \n", prompt_tokens.size()); + for (size_t i = 0; i < prompt_tokens.size(); ++i) { + printf("%d, ", prompt_tokens[i]); + } + printf("\n"); - int64_t t_bb_start = ggml_time_ms(); - if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) { - LOG_ERR("%s: backbone llama_decode() failed\n", __func__); - return 1; + common_batch_clear(batch_prompt); + for (size_t i = 0; i < prompt_tokens.size(); ++i) { + common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false); + } + batch_prompt.logits[batch_prompt.n_tokens - 1] = true; } - n_bb_gen++; - t_bb += ggml_time_ms() - t_bb_start; - auto vocab_dc = llama_model_get_vocab(model_dc); - auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? 
(batch_prompt.n_tokens - 1) : 0); - // for (size_t i = 0; i < 10; ++i) { - // printf("%4.2f, ", logits[i]); - // } - // printf("\n"); + // backbone generation loop + bool is_end_of_turn = false; + for (int k = 0; k < params.n_predict; ++k) { + bool is_prompt_processing = k == 0; + + if (!is_prompt_processing) { + // generate the next RVQ semantic token + batch_past_embd.n_tokens = 1; + batch_past_embd.pos[0] = n_past_bb++; + batch_past_embd.seq_id[0][0] = 0; + batch_past_embd.n_seq_id[0] = 1; + batch_past_embd.logits[0] = true; + std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float)); + } - llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); - printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok); - generated_codes.push_back(semantic_tok); + int64_t t_bb_start = ggml_time_ms(); + if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) { + LOG_ERR("%s: backbone llama_decode() failed\n", __func__); + return 1; + } + n_bb_gen++; + t_bb += ggml_time_ms() - t_bb_start; - // for (size_t i = 0; i < 10; ++i) { - // printf("%4.2f, ", embd[i]); - // } - // printf("\n"); + auto vocab_dc = llama_model_get_vocab(model_dc); + auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0); + // for (size_t i = 0; i < 10; ++i) { + // printf("%4.2f, ", logits[i]); + // } + // printf("\n"); + llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); + printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok); + generated_codes.push_back(semantic_tok); - // decoder generation loop - inp_past_embd = std::vector(inp_past_embd.size(), 0.0f); - { - llama_kv_self_clear(ctx_dc); - llama_batch batch_embd = llama_batch_init(1, embd.size(), 1); - llama_batch batch_token = llama_batch_init(1, 0, 1); + // for (size_t i = 0; i < 10; ++i) { + // printf("%4.2f, ", embd[i]); + // } + // printf("\n"); - // first "token" is the latent embeddings from backbone - { - batch_embd.n_tokens = 1; - batch_embd.pos[0] = 0; - batch_embd.seq_id[0][0] = 0; - batch_embd.n_seq_id[0] = 1; - batch_embd.logits[0] = false; - std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float)); - } - if (llama_decode(ctx_dc, batch_embd) != 0) { - LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__); - return 1; - } - // then, decode the semantic_tok to generate acoustic tokens - llama_token tok = semantic_tok; - int n_codes = 32; - int sum_codes = semantic_tok; // to check if all codes are 0 - for (int i = 0; i < n_codes; ++i) { - common_batch_clear(batch_token); - // encoder vocab is further divided into 32 codebooks, each with 2051 entries - llama_token inp_tok = tok + 2051*i; - common_batch_add(batch_token, inp_tok, i+1, { 0 }, true); - - int64_t t_bb_start = ggml_time_ms(); - if (llama_decode(ctx_dc, batch_token) != 0) { - LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__); - return 1; + // decoder generation loop + inp_past_embd = std::vector(inp_past_embd.size(), 0.0f); + { + llama_kv_self_clear(ctx_dc); + llama_batch batch_embd = llama_batch_init(1, embd.size(), 1); + llama_batch batch_token = llama_batch_init(1, 0, 1); + + // first "token" is the latent embeddings from backbone + { + batch_embd.n_tokens = 1; + batch_embd.pos[0] = 0; + batch_embd.seq_id[0][0] = 0; + batch_embd.n_seq_id[0] = 1; + batch_embd.logits[0] = false; + std::memcpy(batch_embd.embd, embd.data(), embd.size() * 
sizeof(float)); } - n_dc_gen++; - t_dc += ggml_time_ms() - t_bb_start; - - // sample the acoustic token - auto logits = llama_get_logits_ith(ctx_dc, 0); - llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); - - // discard last code (only for embeddings) - if (i < n_codes - 1) { - printf("%d,", acoustic_tok); - tok = acoustic_tok; // next input token - sum_codes += acoustic_tok; - generated_codes.push_back(acoustic_tok); + if (llama_decode(ctx_dc, batch_embd) != 0) { + LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__); + return 1; } - // do progressive hsum of embeddings - GGML_ASSERT(inp_past_embd.size() == embd.size()); - for (size_t i = 0; i < inp_past_embd.size(); ++i) { - inp_past_embd[i] += embd[i]; + // then, decode the semantic_tok to generate acoustic tokens + llama_token tok = semantic_tok; + int n_codes = 32; + int sum_codes = semantic_tok; // to check if all codes are 0 + for (int i = 0; i < n_codes; ++i) { + common_batch_clear(batch_token); + // encoder vocab is further divided into 32 codebooks, each with 2051 entries + llama_token inp_tok = tok + 2051*i; + common_batch_add(batch_token, inp_tok, i+1, { 0 }, true); + + int64_t t_bb_start = ggml_time_ms(); + if (llama_decode(ctx_dc, batch_token) != 0) { + LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__); + return 1; + } + n_dc_gen++; + t_dc += ggml_time_ms() - t_bb_start; + + // sample the acoustic token + auto logits = llama_get_logits_ith(ctx_dc, 0); + llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); + + // discard last code (only for embeddings) + if (i < n_codes - 1) { + printf("%d,", acoustic_tok); + tok = acoustic_tok; // next input token + sum_codes += acoustic_tok; + generated_codes.push_back(acoustic_tok); + } + + // do progressive hsum of embeddings + GGML_ASSERT(inp_past_embd.size() == embd.size()); + for (size_t i = 0; i < inp_past_embd.size(); ++i) { + inp_past_embd[i] += embd[i]; + } } - } - printf("\n"); + printf("\n"); - llama_batch_free(batch_embd); - llama_batch_free(batch_token); + llama_batch_free(batch_embd); + llama_batch_free(batch_token); - // if all codes are 0, then we are done - is_stop = sum_codes == 0; - } + // if all codes are 0, then we are done + is_end_of_turn = sum_codes == 0; + } - // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb); - // for (size_t i = 0; i < inp_past_embd.size(); ++i) { - // printf("%4.4f, ", inp_past_embd[i]); - // if (i == 2) { - // printf("... "); - // i = inp_past_embd.size() - 4; - // } - // } - // printf("\n"); - - if (is_stop) { - // remove last 32 codes since they will be all zeros - generated_codes.resize(generated_codes.size() - 32); - break; + // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb); + // for (size_t i = 0; i < inp_past_embd.size(); ++i) { + // printf("%4.4f, ", inp_past_embd[i]); + // if (i == 2) { + // printf("... 
"); + // i = inp_past_embd.size() - 4; + // } + // } + // printf("\n"); + + if (is_end_of_turn) { + // remove last 32 codes since they will be all zeros + generated_codes.resize(generated_codes.size() - 32); + break; + } } } From d17809999de22b03abb451bd4530a7ccb4ae98ba Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 3 Apr 2025 14:34:39 +0200 Subject: [PATCH 27/31] add audio EOS token --- examples/tts/tts-csm.cpp | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp index a8a9bd22d955b..937948425cdc4 100644 --- a/examples/tts/tts-csm.cpp +++ b/examples/tts/tts-csm.cpp @@ -225,6 +225,11 @@ int main(int argc, char ** argv) { n_bb_gen++; t_bb += ggml_time_ms() - t_bb_start; + if (is_end_of_turn) { + // done decoding audio's EOS token + break; + } + auto vocab_dc = llama_model_get_vocab(model_dc); auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0); // for (size_t i = 0; i < 10; ++i) { @@ -304,8 +309,13 @@ int main(int argc, char ** argv) { llama_batch_free(batch_embd); llama_batch_free(batch_token); - // if all codes are 0, then we are done + // if all codes are 0, then we are done (got audio EOS token) + // note: we still need to run backbone decode one more time to decode the audio's EOS token is_end_of_turn = sum_codes == 0; + if (is_end_of_turn) { + // remove last 32 codes since they will be all zeros + generated_codes.resize(generated_codes.size() - 32); + } } // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb); @@ -317,12 +327,6 @@ int main(int argc, char ** argv) { // } // } // printf("\n"); - - if (is_end_of_turn) { - // remove last 32 codes since they will be all zeros - generated_codes.resize(generated_codes.size() - 32); - break; - } } } From d1de6cc5ee7ea9c7db3856930c495e1e2ed4418d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 9 Apr 2025 17:32:05 +0200 Subject: [PATCH 28/31] add speaker reference --- examples/tts/csm_generate_speaker.py | 79 ++ examples/tts/tts-csm-data.h | 1513 ++++++++++++++++++++++++++ examples/tts/tts-csm.cpp | 182 +++- src/llama-model.cpp | 8 + 4 files changed, 1753 insertions(+), 29 deletions(-) create mode 100644 examples/tts/csm_generate_speaker.py create mode 100644 examples/tts/tts-csm-data.h diff --git a/examples/tts/csm_generate_speaker.py b/examples/tts/csm_generate_speaker.py new file mode 100644 index 0000000000000..0dc6929a23d4c --- /dev/null +++ b/examples/tts/csm_generate_speaker.py @@ -0,0 +1,79 @@ +import argparse +from pathlib import Path +from transformers import MimiModel, AutoFeatureExtractor +from transformers.models.mimi.modeling_mimi import MimiEncoderOutput + +from scipy.io.wavfile import read +from scipy.signal import resample +import numpy as np + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate speaker reference file, used by llama-tts-csm example",) + parser.add_argument( + "--model-path", type=Path, + help="custom Mimi model path (safetensors model). 
If not specified, will use the default model from Hugging Face hub", + ) + parser.add_argument( + "infile", type=Path, + help="the wav input file to use for generating the speaker reference file", + nargs="?", + ) + # parser.add_argument( + # "outfile", type=Path, + # help="the output file, defaults to the input file with .codes suffix", + # nargs="?", + # ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + if args.infile is None: + raise ValueError("Input file is required") + + if not args.infile.exists(): + raise FileNotFoundError(f"Input file {args.infile} not found") + + # if args.outfile is None: + # args.outfile = args.infile.with_suffix(".codes") + + model = MimiModel.from_pretrained(args.model_path or "kyutai/mimi") + feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_path or "kyutai/mimi") + + inp_audio = read(args.infile) + original_sample_rate = inp_audio[0] + audio_data = inp_audio[1] + + # If stereo, get only the first channel + if len(audio_data.shape) > 1 and audio_data.shape[1] >= 2: + audio_data = audio_data[:, 0] + + # resample + target_sample_rate = 24000 + number_of_samples = round(len(audio_data) * float(target_sample_rate) / original_sample_rate) + resampled_audio = resample(audio_data, number_of_samples) + resampled_audio = resampled_audio / max(np.max(np.abs(resampled_audio)), 1e-10) + + # pre-process the inputs + audio_sample = np.array(resampled_audio, dtype=float) + inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") + print('inputs', inputs["input_values"], inputs["input_values"].shape) + + # encode + encoder_outputs = model.encode(inputs["input_values"]) + assert isinstance(encoder_outputs, MimiEncoderOutput), "encoder_outputs should be of type MimiEncoderOutput" + + # output + flattened_audio_codes = encoder_outputs.audio_codes.transpose(-1, -2).flatten() + for i in range(0, len(flattened_audio_codes), 16): + for code in flattened_audio_codes[i:i+16].tolist(): + print(f"{code:<5}", end=",") + print() + + +if __name__ == '__main__': + main() diff --git a/examples/tts/tts-csm-data.h b/examples/tts/tts-csm-data.h new file mode 100644 index 0000000000000..c3c47ca7ac3a2 --- /dev/null +++ b/examples/tts/tts-csm-data.h @@ -0,0 +1,1513 @@ +#pragma once + +#include + +// https://huggingface.co/spaces/sesame/csm-1b/blob/main/prompts/conversational_a.wav +const char * default_speaker_a_text = "[0]like revising for an exam I'd have to try and like keep up the momentum because I'd start really early I'd be like okay I'm gonna start revising now and then like you're revising for ages and then I just like start losing steam I didn't do that for the exam we had recently to be fair that was a more of a last minute scenario but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I sort of start the day with this not like a panic but like a"; +std::initializer_list default_speaker_a_codes = { + 1952 ,425 ,59 ,331 ,2022 ,592 ,648 ,917 ,849 ,1427 ,133 ,1238 ,1045 ,897 ,303 ,413 , + 890 ,171 ,1726 ,1991 ,439 ,743 ,1129 ,1343 ,1406 ,493 ,2003 ,1541 ,401 ,662 ,325 ,879 , + 809 ,1597 ,420 ,193 ,780 ,618 ,1643 ,178 ,1151 ,927 ,1995 ,1857 ,1947 ,353 ,577 ,1493 , + 1598 ,407 ,814 ,5 ,1768 ,170 ,294 ,1386 ,972 ,1158 ,336 ,1300 ,1126 ,1687 ,1088 ,330 , + 526 ,1423 ,1324 ,1444 ,2032 ,1765 ,1127 ,736 ,536 ,307 ,572 ,1136 ,672 ,1324 ,750 ,17 , + 1334 ,129 ,1322 ,362 ,76 ,561 ,1773 ,827 ,1861 ,1876 ,1455 ,245 ,2045 ,1872 ,1660 ,1505 , + 1365 ,338 ,1205 
,1503 ,1177 ,1064 ,203 ,1684 ,805 ,1944 ,1661 ,1128 ,1135 ,504 ,133 ,652 , + 120 ,901 ,1821 ,1828 ,1248 ,1131 ,157 ,604 ,938 ,1520 ,884 ,963 ,1306 ,421 ,1214 ,912 , + 1417 ,10 ,1713 ,1128 ,1158 ,360 ,958 ,1912 ,68 ,1677 ,1496 ,1945 ,1596 ,1641 ,1385 ,1097 , + 1961 ,1096 ,421 ,894 ,883 ,1804 ,252 ,1662 ,1180 ,919 ,1706 ,777 ,1562 ,158 ,1638 ,483 , + 371 ,588 ,1890 ,683 ,1573 ,645 ,1331 ,213 ,1822 ,1458 ,27 ,85 ,174 ,250 ,1881 ,255 , + 186 ,1592 ,1951 ,777 ,1466 ,1542 ,183 ,431 ,1173 ,744 ,526 ,1814 ,98 ,997 ,1376 ,1009 , + 728 ,1206 ,762 ,776 ,791 ,487 ,45 ,993 ,2002 ,249 ,544 ,1845 ,662 ,357 ,1760 ,1896 , + 1582 ,1822 ,760 ,1586 ,173 ,163 ,1541 ,1443 ,697 ,975 ,1775 ,1759 ,768 ,61 ,251 ,1620 , + 819 ,852 ,1539 ,691 ,1655 ,420 ,1158 ,1890 ,728 ,569 ,925 ,1092 ,1550 ,1502 ,194 ,1728 , + 1180 ,1393 ,1021 ,1896 ,529 ,408 ,1816 ,1537 ,647 ,1701 ,766 ,1099 ,1442 ,1481 ,1026 ,1770 , + 994 ,520 ,852 ,464 ,44 ,1739 ,1285 ,1143 ,1466 ,1637 ,1980 ,553 ,2037 ,329 ,1464 ,1938 , + 519 ,590 ,1175 ,157 ,398 ,806 ,12 ,1488 ,1565 ,1534 ,1484 ,1712 ,170 ,431 ,1166 ,555 , + 313 ,1423 ,1867 ,76 ,239 ,469 ,159 ,2014 ,323 ,1254 ,601 ,451 ,1014 ,176 ,970 ,1048 , + 229 ,1322 ,536 ,1979 ,376 ,283 ,618 ,2019 ,1702 ,1272 ,1968 ,75 ,1943 ,462 ,251 ,686 , + 1791 ,1005 ,779 ,815 ,1075 ,932 ,1956 ,1206 ,1853 ,1639 ,1568 ,1794 ,274 ,622 ,1633 ,867 , + 21 ,515 ,2041 ,845 ,879 ,198 ,442 ,579 ,1326 ,1734 ,523 ,531 ,197 ,1806 ,821 ,901 , + 2038 ,194 ,424 ,1942 ,625 ,1186 ,139 ,1654 ,1647 ,699 ,1996 ,1992 ,1917 ,1503 ,1818 ,297 , + 1190 ,694 ,638 ,1001 ,1918 ,707 ,291 ,911 ,36 ,501 ,1976 ,761 ,592 ,1994 ,1587 ,672 , + 93 ,322 ,747 ,1016 ,920 ,959 ,529 ,567 ,109 ,69 ,953 ,1381 ,1258 ,2020 ,441 ,38 , + 620 ,194 ,1230 ,1806 ,1737 ,1550 ,2029 ,1518 ,875 ,976 ,952 ,542 ,2040 ,577 ,1946 ,625 , + 82 ,1581 ,167 ,810 ,1380 ,1095 ,1784 ,97 ,1122 ,1335 ,185 ,428 ,83 ,1399 ,1610 ,854 , + 1714 ,1003 ,197 ,2034 ,80 ,1392 ,575 ,1955 ,340 ,604 ,827 ,443 ,1549 ,792 ,1593 ,1750 , + 429 ,1702 ,288 ,1370 ,925 ,1276 ,1954 ,734 ,371 ,1657 ,1707 ,1945 ,1855 ,145 ,1045 ,312 , + 590 ,1189 ,1542 ,1255 ,457 ,1484 ,738 ,731 ,1667 ,1033 ,1058 ,47 ,1061 ,1315 ,866 ,2008 , + 704 ,183 ,201 ,238 ,128 ,1736 ,926 ,1210 ,479 ,1873 ,698 ,1092 ,197 ,1081 ,1837 ,1883 , + 1721 ,806 ,730 ,531 ,1049 ,1428 ,266 ,894 ,499 ,1525 ,1283 ,1520 ,4 ,1291 ,870 ,1674 , + 235 ,301 ,213 ,286 ,1414 ,1570 ,914 ,410 ,55 ,1037 ,1631 ,1689 ,313 ,1012 ,1241 ,1951 , + 1932 ,1531 ,752 ,1727 ,1667 ,694 ,1754 ,2011 ,1645 ,428 ,387 ,291 ,327 ,1961 ,1666 ,418 , + 1339 ,901 ,1147 ,1894 ,811 ,242 ,1302 ,546 ,721 ,62 ,680 ,1439 ,140 ,258 ,1846 ,411 , + 747 ,1981 ,1665 ,58 ,1411 ,1116 ,340 ,874 ,1498 ,1470 ,794 ,741 ,131 ,938 ,783 ,736 , + 2030 ,1947 ,750 ,130 ,744 ,84 ,864 ,1264 ,1114 ,1275 ,244 ,54 ,1003 ,97 ,1002 ,1608 , + 1617 ,1260 ,945 ,894 ,524 ,664 ,59 ,810 ,235 ,1839 ,141 ,1430 ,2018 ,1385 ,220 ,51 , + 646 ,1638 ,505 ,825 ,1177 ,1445 ,1291 ,293 ,779 ,1023 ,337 ,1155 ,1171 ,1379 ,1205 ,214 , + 1557 ,1312 ,684 ,2039 ,1925 ,39 ,1242 ,1928 ,222 ,1987 ,938 ,509 ,1093 ,1172 ,663 ,922 , + 1468 ,266 ,551 ,54 ,212 ,1058 ,389 ,294 ,1396 ,771 ,360 ,1415 ,209 ,11 ,208 ,818 , + 1841 ,1828 ,1293 ,409 ,1058 ,1503 ,1208 ,1593 ,993 ,330 ,1527 ,713 ,1925 ,382 ,780 ,149 , + 75 ,538 ,1999 ,1932 ,800 ,1486 ,1692 ,470 ,2000 ,1661 ,404 ,1638 ,225 ,1780 ,256 ,384 , + 189 ,1987 ,456 ,2034 ,1056 ,1890 ,827 ,406 ,748 ,978 ,1202 ,727 ,227 ,1310 ,1101 ,1045 , + 918 ,1628 ,1599 ,544 ,2000 ,95 ,96 ,1302 ,712 ,257 ,1806 ,1293 ,17 ,1579 ,426 ,432 , + 1832 ,1987 ,1032 ,739 ,613 ,44 ,1881 ,1361 ,1113 ,1700 ,790 ,1582 
,335 ,1837 ,273 ,755 , + 877 ,133 ,984 ,1698 ,361 ,764 ,353 ,1574 ,498 ,791 ,67 ,1572 ,804 ,1875 ,1102 ,91 , + 955 ,773 ,2008 ,693 ,129 ,1523 ,290 ,862 ,1752 ,552 ,1732 ,632 ,1407 ,1230 ,1013 ,2025 , + 854 ,1044 ,1764 ,409 ,190 ,1485 ,125 ,1134 ,538 ,2034 ,1456 ,577 ,990 ,1493 ,1587 ,526 , + 1320 ,480 ,827 ,290 ,1837 ,679 ,99 ,1852 ,866 ,1798 ,163 ,943 ,1806 ,1979 ,31 ,1999 , + 702 ,31 ,1852 ,1072 ,63 ,1550 ,1440 ,999 ,530 ,1493 ,1 ,405 ,1877 ,136 ,1413 ,1525 , + 402 ,8 ,250 ,786 ,304 ,1426 ,1600 ,1852 ,1063 ,215 ,313 ,1269 ,1875 ,490 ,383 ,1117 , + 769 ,1515 ,1535 ,164 ,1019 ,102 ,326 ,1255 ,120 ,1542 ,1996 ,1027 ,1731 ,1430 ,802 ,485 , + 210 ,646 ,1758 ,443 ,1270 ,1953 ,771 ,643 ,699 ,393 ,47 ,1314 ,941 ,1218 ,481 ,764 , + 666 ,243 ,783 ,546 ,267 ,555 ,825 ,2008 ,1210 ,1542 ,1165 ,439 ,1736 ,1204 ,166 ,1942 , + 32 ,646 ,1490 ,1402 ,423 ,1953 ,1353 ,717 ,724 ,847 ,115 ,951 ,1995 ,1688 ,1047 ,1752 , + 448 ,243 ,783 ,290 ,1736 ,1443 ,666 ,1744 ,1210 ,1992 ,1165 ,253 ,1123 ,113 ,166 ,1684 , + 1978 ,1829 ,618 ,1853 ,1255 ,1067 ,1353 ,717 ,724 ,591 ,569 ,1124 ,35 ,97 ,332 ,1752 , + 1850 ,243 ,783 ,290 ,1736 ,1443 ,666 ,1744 ,1210 ,1370 ,1165 ,436 ,1908 ,113 ,644 ,851 , + 1978 ,251 ,1736 ,1853 ,1255 ,1626 ,377 ,1586 ,204 ,591 ,1538 ,951 ,1995 ,1688 ,427 ,1833 , + 1850 ,243 ,783 ,290 ,1736 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,436 ,1123 ,113 ,144 ,851 , + 1978 ,1829 ,1736 ,1406 ,1255 ,1626 ,1332 ,1586 ,204 ,847 ,1538 ,483 ,35 ,1688 ,1047 ,1833 , + 1850 ,243 ,783 ,290 ,1736 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,1908 ,1430 ,644 ,1684 , + 32 ,1419 ,1736 ,1402 ,692 ,1953 ,377 ,717 ,204 ,847 ,1538 ,1388 ,1995 ,1440 ,1047 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,436 ,1908 ,650 ,644 ,851 , + 1978 ,251 ,1736 ,1853 ,1255 ,1626 ,377 ,1586 ,204 ,591 ,569 ,951 ,1995 ,1688 ,1140 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,253 ,1908 ,650 ,166 ,1684 , + 1978 ,1419 ,290 ,1402 ,1255 ,1626 ,1353 ,717 ,724 ,591 ,569 ,1388 ,1995 ,97 ,332 ,1833 , + 481 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,253 ,1908 ,650 ,802 ,1684 , + 1978 ,1419 ,290 ,1402 ,1255 ,1626 ,1353 ,717 ,204 ,591 ,1538 ,1388 ,1995 ,1440 ,332 ,1752 , + 481 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,436 ,1736 ,1112 ,644 ,1684 , + 1978 ,251 ,1490 ,1402 ,610 ,1953 ,1353 ,717 ,204 ,591 ,1538 ,951 ,35 ,1440 ,1047 ,1752 , + 481 ,243 ,1178 ,546 ,267 ,555 ,976 ,1648 ,739 ,374 ,1165 ,253 ,1908 ,113 ,166 ,851 , + 1978 ,1829 ,1736 ,1853 ,1255 ,1067 ,1353 ,1774 ,724 ,591 ,569 ,1124 ,35 ,1688 ,332 ,1562 , + 481 ,243 ,1178 ,546 ,267 ,555 ,976 ,1648 ,1210 ,374 ,1165 ,439 ,1912 ,1204 ,144 ,851 , + 32 ,646 ,1490 ,1428 ,692 ,1626 ,1332 ,1774 ,724 ,847 ,1538 ,1124 ,1995 ,1688 ,427 ,1833 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,436 ,194 ,1430 ,644 ,1684 , + 32 ,1642 ,1736 ,1402 ,1908 ,1626 ,377 ,717 ,204 ,591 ,1538 ,1388 ,422 ,1440 ,427 ,1833 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,436 ,194 ,1430 ,644 ,1684 , + 1978 ,1642 ,1736 ,1402 ,1908 ,1626 ,377 ,717 ,204 ,591 ,1538 ,483 ,422 ,1440 ,1140 ,1752 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,825 ,2008 ,739 ,1370 ,1165 ,436 ,1101 ,650 ,853 ,610 , + 1978 ,251 ,290 ,1406 ,692 ,1626 ,1497 ,1774 ,724 ,591 ,1538 ,951 ,1995 ,1440 ,332 ,1562 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,825 ,2008 ,739 ,1370 ,1165 ,436 ,1101 ,650 ,853 ,851 , + 1978 ,251 ,290 ,1406 ,692 ,1626 ,1497 ,1586 ,204 ,591 ,1538 ,951 ,422 ,1440 ,332 ,1562 , + 384 ,1211 ,1456 ,417 ,267 ,347 ,666 ,1744 ,1210 ,1370 ,1165 ,760 ,1123 ,1492 ,853 ,851 , + 210 
,1829 ,1490 ,1406 ,1133 ,1067 ,1095 ,1586 ,1423 ,973 ,1841 ,254 ,1995 ,1688 ,1047 ,1562 , + 1826 ,1052 ,658 ,1507 ,73 ,2010 ,1666 ,1273 ,306 ,1500 ,2040 ,730 ,1395 ,1907 ,570 ,1218 , + 1816 ,1681 ,1615 ,909 ,1860 ,1490 ,526 ,1998 ,2029 ,17 ,209 ,912 ,1919 ,2020 ,155 ,1806 , + 481 ,243 ,1697 ,546 ,481 ,555 ,1871 ,1648 ,1972 ,1992 ,1028 ,253 ,1123 ,1430 ,644 ,851 , + 200 ,1829 ,1490 ,1397 ,692 ,1067 ,377 ,717 ,204 ,591 ,1538 ,1388 ,1995 ,1688 ,332 ,1562 , + 481 ,243 ,1178 ,1348 ,1335 ,1572 ,976 ,2008 ,739 ,1542 ,1165 ,436 ,1908 ,1112 ,644 ,851 , + 1978 ,1419 ,1736 ,1402 ,692 ,1897 ,377 ,1586 ,1423 ,847 ,115 ,1388 ,35 ,1688 ,1140 ,1752 , + 481 ,243 ,1178 ,1348 ,1335 ,1443 ,976 ,2008 ,1210 ,1992 ,1165 ,436 ,194 ,1112 ,644 ,851 , + 1978 ,1829 ,1490 ,1402 ,610 ,1897 ,377 ,1586 ,724 ,591 ,1538 ,1124 ,1995 ,1688 ,1140 ,1833 , + 384 ,991 ,1686 ,1709 ,568 ,1356 ,1871 ,1868 ,322 ,1546 ,675 ,1439 ,1700 ,839 ,148 ,465 , + 435 ,271 ,63 ,1314 ,65 ,992 ,1201 ,641 ,1033 ,1325 ,7 ,1792 ,369 ,473 ,271 ,1549 , + 1738 ,1521 ,146 ,1846 ,56 ,457 ,1658 ,1739 ,1379 ,2028 ,937 ,1457 ,712 ,345 ,1877 ,5 , + 386 ,613 ,1007 ,686 ,2030 ,1093 ,107 ,722 ,1476 ,125 ,1068 ,201 ,207 ,1234 ,159 ,128 , + 522 ,1511 ,742 ,405 ,547 ,1176 ,546 ,1078 ,464 ,1834 ,1400 ,487 ,1703 ,921 ,148 ,1587 , + 382 ,166 ,1972 ,1540 ,1375 ,1785 ,789 ,83 ,983 ,1138 ,1484 ,1347 ,437 ,367 ,744 ,1370 , + 785 ,1190 ,1614 ,1453 ,1715 ,1975 ,1246 ,1068 ,990 ,1216 ,1669 ,1892 ,117 ,491 ,938 ,542 , + 1969 ,148 ,0 ,704 ,1035 ,790 ,1274 ,1828 ,445 ,1530 ,703 ,1656 ,530 ,1749 ,1322 ,1485 , + 354 ,1854 ,110 ,1445 ,1526 ,1262 ,64 ,278 ,1474 ,1239 ,1986 ,1345 ,1177 ,286 ,382 ,171 , + 464 ,1428 ,1722 ,347 ,864 ,61 ,602 ,2033 ,1684 ,561 ,348 ,1535 ,1728 ,1179 ,416 ,1411 , + 521 ,344 ,62 ,1606 ,1473 ,1163 ,29 ,885 ,906 ,573 ,1032 ,1870 ,300 ,924 ,852 ,55 , + 587 ,1673 ,904 ,495 ,1585 ,1804 ,1294 ,1133 ,561 ,1089 ,1175 ,1075 ,1117 ,1365 ,137 ,1124 , + 521 ,749 ,1590 ,1947 ,1602 ,302 ,1109 ,1610 ,441 ,613 ,680 ,213 ,1584 ,1909 ,1520 ,276 , + 461 ,493 ,1934 ,346 ,780 ,201 ,564 ,1350 ,1494 ,892 ,616 ,975 ,585 ,802 ,1508 ,1302 , + 1686 ,1976 ,349 ,1393 ,825 ,368 ,99 ,798 ,1384 ,472 ,546 ,442 ,1709 ,1021 ,418 ,932 , + 1264 ,541 ,1769 ,1987 ,1229 ,1007 ,896 ,1120 ,327 ,544 ,579 ,1758 ,1150 ,1103 ,329 ,1955 , + 1548 ,578 ,1879 ,862 ,509 ,1158 ,1278 ,1200 ,937 ,145 ,766 ,1907 ,83 ,1903 ,1683 ,691 , + 65 ,1096 ,769 ,737 ,1146 ,819 ,1617 ,1650 ,636 ,1535 ,707 ,419 ,214 ,661 ,1215 ,808 , + 1548 ,1351 ,769 ,1461 ,1823 ,156 ,890 ,526 ,1694 ,392 ,36 ,845 ,658 ,1336 ,597 ,1807 , + 1597 ,1173 ,1225 ,1225 ,274 ,2035 ,1087 ,2039 ,896 ,846 ,592 ,415 ,688 ,1522 ,1222 ,1728 , + 1109 ,1398 ,1764 ,1826 ,1034 ,2023 ,914 ,1239 ,1534 ,1342 ,197 ,830 ,723 ,854 ,2011 ,1132 , + 272 ,315 ,744 ,145 ,1838 ,791 ,162 ,757 ,1749 ,1110 ,267 ,781 ,532 ,1187 ,869 ,192 , + 29 ,740 ,1051 ,1626 ,432 ,1966 ,725 ,396 ,1048 ,512 ,418 ,1787 ,1838 ,990 ,1205 ,1464 , + 947 ,525 ,1303 ,1325 ,624 ,1697 ,438 ,951 ,757 ,1125 ,390 ,177 ,1343 ,1273 ,746 ,834 , + 268 ,1190 ,1562 ,284 ,473 ,955 ,895 ,1553 ,747 ,1339 ,890 ,1804 ,1300 ,1537 ,201 ,166 , + 1040 ,84 ,1872 ,631 ,331 ,353 ,22 ,1982 ,576 ,162 ,84 ,1097 ,1067 ,752 ,463 ,1609 , + 1558 ,740 ,1916 ,2015 ,1906 ,201 ,1110 ,1708 ,853 ,675 ,357 ,1727 ,938 ,986 ,2016 ,509 , + 1385 ,1985 ,1948 ,1347 ,1297 ,390 ,1344 ,1199 ,1208 ,566 ,258 ,450 ,1599 ,53 ,100 ,806 , + 199 ,1054 ,1544 ,1716 ,696 ,1983 ,835 ,1281 ,538 ,1199 ,203 ,765 ,1961 ,611 ,546 ,396 , + 256 ,382 ,647 ,419 ,1370 ,800 ,1614 ,825 ,1040 ,264 ,514 ,1901 ,1713 ,1273 ,860 ,1656 , + 1912 ,1879 ,1037 ,1604 ,577 
,1507 ,1170 ,1010 ,1375 ,892 ,1242 ,1843 ,1286 ,1041 ,1503 ,1215 , + 1395 ,648 ,2044 ,995 ,1372 ,474 ,310 ,517 ,1278 ,743 ,1903 ,469 ,1985 ,1855 ,9 ,2015 , + 533 ,497 ,1455 ,46 ,1568 ,432 ,1524 ,1735 ,1274 ,349 ,1250 ,73 ,405 ,1600 ,783 ,509 , + 1385 ,228 ,793 ,768 ,827 ,39 ,442 ,310 ,2044 ,1561 ,1861 ,88 ,1598 ,1385 ,1949 ,1337 , + 1756 ,1727 ,1501 ,985 ,647 ,2044 ,1974 ,195 ,853 ,1731 ,681 ,1854 ,556 ,775 ,613 ,1765 , + 649 ,1481 ,1288 ,1858 ,1442 ,1623 ,785 ,270 ,579 ,1325 ,420 ,1564 ,20 ,1643 ,822 ,639 , + 833 ,1202 ,1645 ,519 ,1386 ,1247 ,909 ,644 ,871 ,1193 ,1692 ,542 ,1131 ,507 ,1301 ,1654 , + 612 ,887 ,1246 ,1246 ,1937 ,1365 ,168 ,913 ,1788 ,1473 ,1986 ,1357 ,736 ,220 ,1946 ,1171 , + 929 ,1636 ,448 ,1565 ,1333 ,1593 ,647 ,43 ,1099 ,1679 ,1065 ,632 ,652 ,993 ,1342 ,1186 , + 785 ,1992 ,260 ,1311 ,662 ,1490 ,1879 ,1475 ,1661 ,1946 ,1880 ,372 ,790 ,446 ,1367 ,989 , + 1141 ,185 ,277 ,698 ,476 ,1177 ,1597 ,1519 ,1553 ,1254 ,1975 ,374 ,1943 ,606 ,2046 ,930 , + 1566 ,1510 ,1451 ,512 ,740 ,1829 ,1114 ,1968 ,1644 ,1150 ,1827 ,910 ,1448 ,1339 ,381 ,422 , + 215 ,1603 ,344 ,1162 ,294 ,1511 ,316 ,671 ,531 ,827 ,1211 ,1217 ,1684 ,1161 ,1370 ,111 , + 139 ,101 ,521 ,1984 ,1714 ,452 ,1177 ,634 ,319 ,122 ,618 ,2030 ,2041 ,769 ,862 ,1237 , + 1929 ,1867 ,1878 ,1754 ,1686 ,1239 ,1529 ,663 ,1061 ,1095 ,633 ,1998 ,157 ,1838 ,297 ,1001 , + 1887 ,1890 ,591 ,1110 ,754 ,1273 ,481 ,245 ,1587 ,1087 ,1964 ,1011 ,615 ,148 ,967 ,1451 , + 1384 ,870 ,76 ,874 ,1455 ,1587 ,190 ,782 ,801 ,164 ,696 ,1228 ,990 ,862 ,1825 ,1928 , + 1956 ,182 ,137 ,1592 ,1108 ,442 ,513 ,2027 ,1828 ,1302 ,1066 ,1626 ,557 ,1655 ,604 ,2041 , + 1924 ,504 ,78 ,1597 ,956 ,1019 ,1744 ,1340 ,738 ,1903 ,1582 ,1002 ,534 ,413 ,1397 ,655 , + 294 ,728 ,1240 ,1992 ,1557 ,769 ,178 ,1518 ,680 ,232 ,850 ,1483 ,340 ,875 ,73 ,216 , + 1915 ,332 ,1280 ,1530 ,920 ,146 ,1895 ,18 ,81 ,1895 ,779 ,1564 ,953 ,399 ,1627 ,291 , + 109 ,453 ,101 ,611 ,613 ,1660 ,952 ,1386 ,1926 ,623 ,270 ,242 ,506 ,892 ,391 ,712 , + 1384 ,172 ,912 ,916 ,921 ,1077 ,1528 ,379 ,960 ,293 ,330 ,1805 ,451 ,1362 ,1596 ,1033 , + 427 ,1210 ,637 ,1788 ,1426 ,1896 ,2015 ,693 ,544 ,1538 ,416 ,1137 ,668 ,1310 ,1456 ,1092 , + 964 ,846 ,556 ,828 ,573 ,1096 ,761 ,1075 ,19 ,2025 ,1598 ,791 ,1725 ,234 ,1204 ,680 , + 657 ,1615 ,523 ,1362 ,1299 ,1405 ,217 ,575 ,994 ,1090 ,195 ,1537 ,1234 ,1880 ,172 ,1574 , + 1559 ,1440 ,574 ,607 ,1574 ,1894 ,1998 ,1508 ,39 ,577 ,1388 ,838 ,1074 ,1493 ,627 ,1742 , + 854 ,1142 ,267 ,130 ,45 ,169 ,1036 ,818 ,875 ,1157 ,1701 ,1420 ,455 ,283 ,1937 ,1722 , + 547 ,1312 ,370 ,917 ,1441 ,607 ,125 ,1828 ,1106 ,391 ,356 ,1233 ,1507 ,1084 ,1019 ,659 , + 1324 ,1706 ,1749 ,1767 ,73 ,1006 ,1293 ,627 ,590 ,30 ,1363 ,764 ,630 ,583 ,1484 ,1418 , + 1862 ,2019 ,1481 ,90 ,1822 ,1623 ,1836 ,311 ,506 ,1204 ,1973 ,1280 ,1057 ,557 ,1743 ,1994 , + 44 ,1818 ,1313 ,885 ,862 ,1200 ,887 ,1641 ,1921 ,277 ,1347 ,521 ,1269 ,166 ,388 ,993 , + 1221 ,752 ,963 ,2015 ,1529 ,691 ,783 ,1125 ,55 ,1257 ,190 ,1968 ,1962 ,1225 ,1593 ,335 , + 301 ,362 ,1102 ,112 ,48 ,1359 ,1437 ,924 ,1210 ,1581 ,1147 ,717 ,206 ,655 ,1247 ,1352 , + 496 ,1527 ,1037 ,1258 ,1296 ,1999 ,1840 ,1352 ,578 ,484 ,1736 ,1105 ,914 ,781 ,934 ,7 , + 1894 ,804 ,1197 ,1321 ,1546 ,180 ,1713 ,871 ,1467 ,698 ,1142 ,1179 ,1174 ,1812 ,942 ,1277 , + 1030 ,200 ,856 ,941 ,169 ,1680 ,969 ,227 ,229 ,831 ,1665 ,175 ,992 ,2020 ,754 ,1541 , + 275 ,1187 ,1155 ,237 ,580 ,2008 ,304 ,784 ,890 ,1243 ,1498 ,583 ,1694 ,1205 ,772 ,265 , + 1225 ,380 ,1464 ,1249 ,1779 ,308 ,567 ,1364 ,397 ,252 ,197 ,1787 ,468 ,460 ,1781 ,386 , + 1024 ,926 ,1262 ,1108 ,618 ,839 ,839 ,1234 
,1257 ,1669 ,392 ,965 ,1161 ,810 ,832 ,803 , + 93 ,386 ,1252 ,1260 ,1866 ,1975 ,517 ,171 ,1144 ,1570 ,1158 ,1590 ,1761 ,544 ,839 ,1626 , + 1839 ,1232 ,616 ,2 ,743 ,1646 ,698 ,852 ,953 ,88 ,1712 ,295 ,257 ,1832 ,1863 ,2008 , + 1765 ,1729 ,214 ,112 ,1012 ,589 ,815 ,141 ,1683 ,256 ,1647 ,1952 ,364 ,1243 ,1571 ,1208 , + 1353 ,1485 ,1199 ,1896 ,1676 ,1931 ,1720 ,1340 ,7 ,910 ,1686 ,467 ,90 ,1837 ,1015 ,1858 , + 1127 ,559 ,1604 ,726 ,1465 ,1543 ,1861 ,1644 ,382 ,1641 ,1130 ,1451 ,173 ,474 ,1628 ,1415 , + 1128 ,912 ,1167 ,1433 ,2033 ,511 ,1410 ,571 ,171 ,315 ,1533 ,769 ,262 ,1544 ,630 ,244 , + 632 ,501 ,910 ,1315 ,913 ,1150 ,719 ,237 ,1678 ,282 ,320 ,245 ,1557 ,1053 ,831 ,1366 , + 2008 ,488 ,1343 ,191 ,2029 ,193 ,1358 ,248 ,1699 ,637 ,1034 ,196 ,347 ,688 ,1502 ,380 , + 728 ,872 ,713 ,1871 ,1165 ,1017 ,397 ,1567 ,332 ,616 ,19 ,1792 ,978 ,1123 ,1397 ,537 , + 1172 ,694 ,1705 ,1723 ,1046 ,593 ,780 ,2002 ,725 ,115 ,1419 ,730 ,485 ,678 ,57 ,938 , + 389 ,1287 ,1313 ,1918 ,43 ,668 ,1878 ,1728 ,1786 ,1987 ,1874 ,1863 ,1236 ,1124 ,1726 ,337 , + 1596 ,1870 ,1547 ,1780 ,151 ,185 ,1456 ,1093 ,1603 ,1534 ,1096 ,1317 ,1206 ,1081 ,1300 ,315 , + 103 ,110 ,1042 ,79 ,1822 ,285 ,633 ,1763 ,875 ,172 ,1604 ,1013 ,1829 ,1551 ,314 ,750 , + 1352 ,1139 ,202 ,1432 ,1649 ,938 ,1037 ,906 ,1252 ,1359 ,586 ,1861 ,1295 ,1376 ,1904 ,1164 , + 524 ,1398 ,469 ,194 ,2019 ,811 ,1221 ,1520 ,815 ,1369 ,1099 ,1285 ,492 ,152 ,1289 ,1742 , + 533 ,1029 ,1592 ,560 ,116 ,852 ,268 ,2029 ,1932 ,423 ,1277 ,721 ,544 ,347 ,1534 ,933 , + 1222 ,1983 ,170 ,1511 ,1239 ,1792 ,846 ,1854 ,1876 ,1410 ,1989 ,1884 ,1629 ,894 ,1185 ,1567 , + 252 ,773 ,632 ,1794 ,109 ,1804 ,976 ,758 ,417 ,1529 ,676 ,203 ,1522 ,771 ,1777 ,131 , + 495 ,1373 ,1645 ,2016 ,543 ,1695 ,1171 ,1895 ,994 ,1987 ,296 ,418 ,1194 ,1189 ,1595 ,1801 , + 1334 ,773 ,762 ,434 ,1368 ,1249 ,1738 ,1546 ,1939 ,1019 ,550 ,531 ,1552 ,1362 ,323 ,316 , + 400 ,1961 ,766 ,1201 ,875 ,2028 ,211 ,111 ,508 ,758 ,598 ,906 ,63 ,681 ,42 ,1988 , + 1732 ,1184 ,1270 ,1490 ,202 ,692 ,1961 ,1057 ,852 ,978 ,894 ,1082 ,1048 ,888 ,889 ,1047 , + 860 ,254 ,1833 ,19 ,38 ,896 ,14 ,1245 ,2028 ,416 ,886 ,213 ,1617 ,807 ,442 ,1422 , + 1899 ,667 ,595 ,111 ,79 ,1161 ,938 ,1020 ,603 ,1527 ,1402 ,1747 ,2022 ,1376 ,735 ,418 , + 1140 ,1785 ,338 ,1633 ,1881 ,1556 ,916 ,84 ,1378 ,1147 ,1462 ,1415 ,1829 ,726 ,1436 ,645 , + 1552 ,1459 ,1719 ,1535 ,892 ,1933 ,1163 ,672 ,1203 ,1231 ,1503 ,772 ,1272 ,1918 ,107 ,2036 , + 1367 ,968 ,1989 ,888 ,2019 ,1376 ,1767 ,2025 ,368 ,29 ,1358 ,952 ,1348 ,116 ,1002 ,65 , + 1970 ,1522 ,1784 ,523 ,173 ,1765 ,904 ,1572 ,432 ,71 ,1460 ,1278 ,347 ,300 ,502 ,136 , + 317 ,902 ,1669 ,1738 ,777 ,1076 ,1441 ,553 ,949 ,1906 ,622 ,1409 ,285 ,1081 ,1125 ,256 , + 1467 ,1165 ,390 ,171 ,109 ,1342 ,421 ,856 ,1616 ,597 ,787 ,1375 ,1070 ,903 ,1264 ,230 , + 317 ,856 ,130 ,677 ,216 ,212 ,211 ,49 ,732 ,1883 ,2015 ,1564 ,1278 ,1340 ,621 ,79 , + 624 ,1117 ,1087 ,1876 ,1489 ,711 ,1089 ,1912 ,191 ,1510 ,171 ,526 ,1420 ,136 ,848 ,1586 , + 877 ,376 ,1865 ,1875 ,1401 ,1032 ,973 ,736 ,1559 ,1067 ,2026 ,347 ,1074 ,143 ,656 ,1912 , + 100 ,25 ,959 ,813 ,1115 ,1534 ,986 ,1154 ,426 ,1305 ,1600 ,1228 ,416 ,763 ,534 ,2004 , + 854 ,55 ,1523 ,1290 ,311 ,1032 ,542 ,1398 ,1660 ,1427 ,2043 ,815 ,118 ,1515 ,163 ,907 , + 1511 ,439 ,224 ,1569 ,327 ,370 ,1662 ,454 ,155 ,234 ,1153 ,1461 ,1599 ,1905 ,1922 ,1973 , + 702 ,1540 ,183 ,1071 ,291 ,1431 ,1506 ,1567 ,1214 ,883 ,1991 ,1544 ,234 ,1657 ,885 ,1211 , + 1471 ,763 ,418 ,1021 ,1928 ,745 ,1507 ,507 ,1826 ,858 ,650 ,1589 ,459 ,221 ,1168 ,1879 , + 34 ,1700 ,1178 ,97 ,1019 ,555 ,666 ,1744 ,1210 ,1542 
,415 ,436 ,1101 ,1430 ,853 ,1942 , + 200 ,251 ,1490 ,1402 ,1908 ,1626 ,1353 ,717 ,204 ,591 ,47 ,1388 ,687 ,1440 ,1140 ,1833 , + 666 ,243 ,783 ,142 ,481 ,555 ,666 ,1648 ,1210 ,1542 ,1165 ,253 ,1912 ,650 ,166 ,851 , + 1978 ,1419 ,290 ,1853 ,1255 ,1626 ,1353 ,1586 ,724 ,847 ,1538 ,951 ,1995 ,97 ,332 ,1752 , + 448 ,243 ,783 ,142 ,481 ,1030 ,666 ,2008 ,739 ,1370 ,1165 ,1383 ,1908 ,650 ,853 ,851 , + 32 ,646 ,290 ,1428 ,692 ,1897 ,1497 ,1586 ,204 ,973 ,1538 ,951 ,1995 ,14 ,1047 ,1752 , + 84 ,243 ,783 ,142 ,481 ,1030 ,666 ,2008 ,739 ,1370 ,1165 ,436 ,194 ,650 ,144 ,1684 , + 1978 ,646 ,290 ,1402 ,1908 ,1626 ,1332 ,717 ,724 ,591 ,47 ,483 ,422 ,1440 ,1047 ,1833 , + 84 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,1101 ,650 ,644 ,610 , + 1978 ,251 ,290 ,1853 ,692 ,1897 ,1353 ,1774 ,724 ,591 ,1538 ,1388 ,35 ,1440 ,332 ,1833 , + 7 ,203 ,265 ,290 ,306 ,93 ,104 ,583 ,1938 ,278 ,618 ,1040 ,321 ,1213 ,166 ,1732 , + 959 ,271 ,1531 ,172 ,1133 ,1680 ,359 ,1509 ,1110 ,1591 ,260 ,254 ,1334 ,2023 ,911 ,1752 , + 739 ,1068 ,811 ,1473 ,1141 ,301 ,1784 ,1374 ,791 ,1505 ,402 ,1444 ,1321 ,1625 ,397 ,1711 , + 653 ,514 ,1779 ,1949 ,1648 ,998 ,289 ,1555 ,1342 ,1723 ,54 ,1238 ,1654 ,1538 ,798 ,823 , + 739 ,1274 ,682 ,460 ,1631 ,120 ,411 ,1277 ,761 ,1117 ,2030 ,1587 ,1961 ,1468 ,1538 ,772 , + 1306 ,1725 ,828 ,419 ,362 ,981 ,1583 ,1843 ,2024 ,1650 ,306 ,1062 ,1913 ,650 ,1441 ,1040 , + 1417 ,774 ,1530 ,1086 ,1018 ,1496 ,1015 ,885 ,142 ,870 ,1121 ,1829 ,1907 ,1089 ,403 ,1411 , + 23 ,770 ,1480 ,553 ,1711 ,530 ,1905 ,860 ,1972 ,736 ,582 ,119 ,1965 ,1941 ,1724 ,1425 , + 728 ,580 ,1209 ,454 ,990 ,1507 ,1411 ,824 ,1306 ,407 ,1630 ,1968 ,735 ,848 ,574 ,1851 , + 2002 ,1186 ,1661 ,132 ,1082 ,217 ,1619 ,761 ,1465 ,1416 ,1146 ,88 ,1191 ,1555 ,236 ,1506 , + 1353 ,1748 ,1434 ,563 ,837 ,1612 ,1514 ,481 ,1272 ,770 ,99 ,988 ,1413 ,1560 ,273 ,1656 , + 642 ,1759 ,30 ,1163 ,629 ,1705 ,297 ,1732 ,1467 ,802 ,1138 ,701 ,570 ,1466 ,330 ,1435 , + 1608 ,1945 ,407 ,1259 ,1545 ,1828 ,486 ,1851 ,675 ,1515 ,1664 ,1395 ,700 ,1054 ,938 ,1903 , + 1566 ,668 ,1663 ,70 ,409 ,1363 ,108 ,525 ,1986 ,1474 ,1211 ,1952 ,1175 ,1419 ,1710 ,574 , + 1516 ,1527 ,95 ,664 ,1029 ,439 ,1716 ,1333 ,815 ,26 ,867 ,1269 ,730 ,429 ,509 ,1977 , + 1618 ,1651 ,328 ,1499 ,1037 ,618 ,202 ,979 ,1952 ,536 ,1322 ,1041 ,1649 ,1279 ,2011 ,290 , + 1230 ,392 ,936 ,598 ,597 ,1628 ,1904 ,603 ,761 ,804 ,839 ,461 ,1729 ,781 ,1938 ,2017 , + 141 ,370 ,827 ,1623 ,545 ,266 ,484 ,1926 ,352 ,493 ,70 ,847 ,1864 ,707 ,1430 ,1552 , + 1178 ,1503 ,1090 ,1938 ,862 ,1763 ,224 ,1012 ,1167 ,1395 ,877 ,688 ,837 ,1044 ,601 ,1031 , + 1542 ,665 ,859 ,1707 ,113 ,1694 ,2021 ,575 ,1217 ,112 ,483 ,52 ,1861 ,2036 ,744 ,97 , + 1451 ,867 ,647 ,454 ,1480 ,1956 ,981 ,1288 ,996 ,1393 ,595 ,1575 ,1870 ,891 ,673 ,385 , + 1411 ,756 ,929 ,765 ,1897 ,1085 ,1124 ,1363 ,1561 ,1627 ,474 ,875 ,1925 ,422 ,741 ,1119 , + 819 ,1354 ,1492 ,921 ,1041 ,469 ,641 ,1532 ,180 ,1157 ,1381 ,1620 ,2024 ,895 ,495 ,1820 , + 1903 ,780 ,1415 ,1646 ,71 ,1933 ,967 ,1773 ,253 ,1305 ,1042 ,1342 ,1521 ,1392 ,1045 ,649 , + 1497 ,710 ,1169 ,1064 ,1509 ,1987 ,468 ,1292 ,664 ,773 ,78 ,578 ,2029 ,497 ,53 ,394 , + 1992 ,1709 ,767 ,1202 ,1054 ,388 ,2007 ,1772 ,815 ,1081 ,1141 ,30 ,1641 ,1316 ,1647 ,311 , + 576 ,694 ,1578 ,1418 ,1323 ,706 ,2013 ,663 ,83 ,268 ,1359 ,1912 ,1004 ,235 ,345 ,420 , + 900 ,429 ,1301 ,1615 ,1812 ,1187 ,1625 ,1571 ,105 ,1466 ,765 ,2013 ,1506 ,1295 ,1171 ,730 , + 872 ,1446 ,1076 ,1145 ,528 ,480 ,736 ,1663 ,1649 ,1419 ,1808 ,851 ,1075 ,1931 ,392 ,1646 , + 1570 ,736 ,122 ,1580 ,702 ,2014 ,382 ,1434 ,974 ,1679 ,876 ,167 ,338 
,334 ,594 ,1614 , + 872 ,20 ,302 ,2044 ,1376 ,1213 ,1698 ,278 ,1035 ,128 ,669 ,1123 ,479 ,282 ,512 ,530 , + 1260 ,1469 ,1804 ,228 ,751 ,1773 ,1677 ,498 ,567 ,1510 ,468 ,1820 ,1041 ,707 ,1683 ,784 , + 1678 ,1453 ,2026 ,1451 ,972 ,755 ,1569 ,1559 ,1864 ,973 ,823 ,405 ,901 ,874 ,1689 ,770 , + 1855 ,1120 ,1148 ,321 ,701 ,1488 ,801 ,1365 ,1108 ,241 ,761 ,1985 ,34 ,479 ,252 ,1008 , + 1149 ,148 ,1025 ,529 ,616 ,1007 ,1589 ,1200 ,1676 ,1678 ,146 ,931 ,353 ,346 ,1642 ,185 , + 1985 ,1232 ,1969 ,1091 ,16 ,1097 ,526 ,1054 ,1387 ,1317 ,1385 ,95 ,1467 ,2043 ,421 ,1218 , + 1149 ,2010 ,794 ,67 ,811 ,1644 ,1735 ,1834 ,1151 ,1839 ,487 ,520 ,298 ,329 ,617 ,1728 , + 823 ,150 ,1012 ,1749 ,691 ,422 ,1914 ,240 ,1692 ,1792 ,742 ,634 ,1977 ,1804 ,1973 ,851 , + 390 ,1945 ,228 ,871 ,595 ,964 ,796 ,206 ,829 ,1145 ,973 ,1777 ,1556 ,1082 ,1282 ,1296 , + 1031 ,441 ,751 ,2004 ,1176 ,800 ,1411 ,906 ,2 ,1755 ,1381 ,282 ,97 ,1981 ,458 ,1495 , + 802 ,440 ,642 ,1586 ,573 ,116 ,1324 ,612 ,1029 ,1266 ,460 ,489 ,901 ,79 ,1563 ,758 , + 1639 ,1009 ,1293 ,1894 ,1643 ,1608 ,34 ,438 ,640 ,1629 ,766 ,1189 ,693 ,1647 ,1222 ,1864 , + 93 ,629 ,2021 ,370 ,1423 ,363 ,343 ,1294 ,570 ,258 ,823 ,1404 ,1937 ,232 ,477 ,715 , + 1429 ,287 ,584 ,592 ,274 ,1949 ,1420 ,501 ,1308 ,261 ,1778 ,49 ,94 ,709 ,1965 ,1581 , + 1960 ,1541 ,1068 ,188 ,1387 ,362 ,1892 ,1778 ,38 ,1007 ,31 ,151 ,355 ,1823 ,693 ,1917 , + 364 ,945 ,1886 ,37 ,1377 ,995 ,54 ,237 ,787 ,277 ,840 ,1526 ,1560 ,1744 ,395 ,754 , + 1338 ,788 ,1158 ,629 ,2038 ,865 ,667 ,234 ,687 ,1739 ,1811 ,1406 ,1252 ,688 ,1642 ,1457 , + 214 ,1151 ,1916 ,1581 ,1221 ,311 ,1347 ,152 ,1303 ,1815 ,705 ,16 ,1274 ,241 ,153 ,1048 , + 1794 ,1908 ,256 ,1942 ,893 ,11 ,271 ,1115 ,1106 ,554 ,316 ,990 ,1081 ,411 ,95 ,1407 , + 758 ,1523 ,77 ,1962 ,281 ,1871 ,1945 ,1929 ,81 ,797 ,1076 ,1467 ,37 ,790 ,1412 ,1442 , + 740 ,1153 ,533 ,1029 ,1453 ,1697 ,202 ,1052 ,1447 ,2028 ,1040 ,1372 ,1149 ,565 ,1551 ,1511 , + 1300 ,1292 ,292 ,333 ,893 ,1869 ,1761 ,2022 ,2017 ,1501 ,693 ,1647 ,110 ,1241 ,135 ,425 , + 1453 ,416 ,225 ,563 ,171 ,1386 ,1518 ,1330 ,759 ,1170 ,651 ,1037 ,20 ,288 ,843 ,472 , + 1378 ,1067 ,1466 ,1303 ,357 ,1011 ,222 ,1620 ,1913 ,1962 ,1684 ,10 ,1870 ,1703 ,949 ,1571 , + 1274 ,70 ,1313 ,93 ,534 ,436 ,1214 ,855 ,1375 ,835 ,592 ,1919 ,942 ,953 ,1034 ,837 , + 1612 ,1838 ,445 ,1717 ,1225 ,210 ,1612 ,237 ,700 ,766 ,415 ,237 ,1788 ,1593 ,75 ,869 , + 1790 ,539 ,1677 ,653 ,1735 ,343 ,1686 ,1001 ,1073 ,1587 ,509 ,49 ,1770 ,444 ,1429 ,1183 , + 1935 ,473 ,947 ,1890 ,1364 ,43 ,1344 ,31 ,1255 ,271 ,336 ,2010 ,733 ,764 ,1065 ,1688 , + 389 ,785 ,50 ,1205 ,1269 ,804 ,1728 ,671 ,1390 ,152 ,946 ,51 ,1400 ,622 ,1425 ,1612 , + 1346 ,1842 ,997 ,1636 ,959 ,1989 ,1288 ,877 ,704 ,762 ,1265 ,353 ,884 ,1413 ,1947 ,1118 , + 186 ,287 ,1220 ,236 ,38 ,1069 ,327 ,948 ,767 ,2000 ,1023 ,1281 ,1014 ,591 ,1254 ,986 , + 196 ,1598 ,1121 ,1710 ,910 ,414 ,1627 ,1794 ,1819 ,1543 ,594 ,1588 ,496 ,1311 ,1649 ,1228 , + 1307 ,520 ,157 ,828 ,264 ,1069 ,837 ,568 ,887 ,1318 ,1704 ,141 ,791 ,376 ,1149 ,1032 , + 175 ,1658 ,1288 ,1047 ,1133 ,39 ,687 ,1066 ,18 ,17 ,883 ,1667 ,171 ,1983 ,1327 ,54 , + 2042 ,1700 ,1029 ,164 ,915 ,347 ,976 ,754 ,1972 ,1992 ,1458 ,253 ,1123 ,1430 ,144 ,1942 , + 1531 ,251 ,618 ,1428 ,1255 ,1626 ,1332 ,74 ,1423 ,973 ,115 ,1845 ,422 ,97 ,1047 ,1752 , + 541 ,243 ,1697 ,164 ,1736 ,1030 ,976 ,2008 ,1210 ,91 ,1165 ,436 ,1912 ,113 ,144 ,851 , + 1978 ,1419 ,1736 ,1853 ,692 ,1953 ,1332 ,1586 ,724 ,847 ,47 ,1388 ,35 ,14 ,1047 ,1752 , + 48 ,355 ,962 ,523 ,1514 ,20 ,1505 ,2015 ,435 ,954 ,583 ,1916 ,1883 ,1427 ,716 ,1091 , + 1663 ,797 ,1529 ,1861 
,897 ,219 ,357 ,643 ,948 ,543 ,1582 ,1543 ,687 ,419 ,1556 ,1470 , + 1945 ,1974 ,1323 ,1156 ,420 ,54 ,1607 ,583 ,435 ,954 ,1012 ,436 ,1001 ,1571 ,603 ,1279 , + 821 ,2002 ,723 ,1347 ,1405 ,424 ,1301 ,709 ,684 ,52 ,429 ,1168 ,687 ,1464 ,1342 ,1823 , + 201 ,438 ,246 ,751 ,636 ,960 ,1714 ,1408 ,161 ,1852 ,1111 ,1416 ,969 ,1105 ,1237 ,1591 , + 376 ,139 ,1733 ,705 ,780 ,286 ,1508 ,1104 ,163 ,1981 ,1824 ,507 ,1869 ,1003 ,1452 ,371 , + 1359 ,1285 ,1984 ,1299 ,371 ,148 ,727 ,255 ,1744 ,1424 ,708 ,1988 ,188 ,680 ,533 ,656 , + 1667 ,1459 ,1066 ,1678 ,229 ,1727 ,130 ,2045 ,519 ,413 ,1825 ,586 ,688 ,297 ,134 ,598 , + 1661 ,122 ,1095 ,309 ,40 ,1896 ,457 ,36 ,589 ,1170 ,1701 ,1088 ,1738 ,1601 ,931 ,275 , + 1441 ,792 ,607 ,842 ,41 ,95 ,1470 ,636 ,12 ,1645 ,170 ,1827 ,1553 ,1168 ,1452 ,652 , + 1242 ,1715 ,1865 ,443 ,1318 ,844 ,1045 ,1668 ,1540 ,550 ,1344 ,298 ,623 ,1175 ,1270 ,1535 , + 1156 ,547 ,926 ,1415 ,1775 ,486 ,163 ,524 ,255 ,1717 ,711 ,527 ,1984 ,961 ,992 ,1413 , + 627 ,1450 ,971 ,448 ,621 ,926 ,839 ,1628 ,1059 ,158 ,147 ,1073 ,1884 ,935 ,481 ,1622 , + 1478 ,498 ,900 ,1294 ,507 ,560 ,1505 ,1862 ,1461 ,1604 ,601 ,1472 ,890 ,758 ,1339 ,60 , + 867 ,252 ,1560 ,154 ,400 ,1688 ,887 ,1090 ,1256 ,2020 ,1466 ,1293 ,1349 ,1166 ,791 ,679 , + 723 ,607 ,2 ,47 ,893 ,580 ,337 ,1981 ,364 ,1704 ,113 ,451 ,1100 ,172 ,1076 ,1277 , + 709 ,1515 ,1626 ,164 ,306 ,1030 ,428 ,1648 ,1972 ,357 ,1458 ,760 ,1912 ,1714 ,1821 ,1942 , + 1029 ,646 ,1490 ,1402 ,1255 ,1953 ,1945 ,1774 ,699 ,847 ,1538 ,1124 ,35 ,1924 ,850 ,573 , + 768 ,243 ,783 ,142 ,481 ,555 ,666 ,1744 ,1210 ,1542 ,1165 ,253 ,1101 ,113 ,166 ,851 , + 32 ,1419 ,618 ,1853 ,1255 ,1953 ,1332 ,1774 ,724 ,847 ,569 ,1388 ,1930 ,1688 ,427 ,1752 , + 768 ,243 ,783 ,142 ,481 ,555 ,976 ,1648 ,1210 ,374 ,1165 ,436 ,194 ,1112 ,853 ,851 , + 1978 ,1829 ,1736 ,1406 ,610 ,1953 ,1497 ,1586 ,1423 ,847 ,1538 ,1388 ,1930 ,14 ,332 ,1833 , + 768 ,243 ,783 ,142 ,481 ,555 ,976 ,1648 ,1210 ,374 ,1165 ,436 ,194 ,1112 ,644 ,851 , + 1978 ,646 ,1736 ,1402 ,1908 ,1067 ,377 ,1586 ,1423 ,591 ,1538 ,483 ,1930 ,14 ,427 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,976 ,1744 ,739 ,1992 ,662 ,436 ,1908 ,650 ,166 ,667 , + 1978 ,251 ,1736 ,1402 ,1255 ,1067 ,1353 ,717 ,604 ,847 ,569 ,1124 ,35 ,14 ,1047 ,1833 , + 1850 ,243 ,783 ,1348 ,1335 ,1030 ,976 ,1744 ,1210 ,1992 ,662 ,253 ,1908 ,650 ,802 ,667 , + 1978 ,1642 ,290 ,1402 ,1255 ,1067 ,1353 ,1774 ,604 ,591 ,1538 ,483 ,1930 ,1688 ,1047 ,1833 , + 758 ,427 ,199 ,1697 ,839 ,1167 ,121 ,1630 ,1833 ,1546 ,963 ,291 ,814 ,1094 ,496 ,478 , + 200 ,901 ,1100 ,808 ,1802 ,352 ,796 ,619 ,1350 ,777 ,1847 ,1314 ,936 ,943 ,1448 ,573 , + 1268 ,1920 ,1504 ,1179 ,1626 ,1724 ,1856 ,2004 ,1349 ,959 ,542 ,210 ,1973 ,1517 ,210 ,1395 , + 1522 ,1216 ,76 ,994 ,491 ,489 ,1139 ,1287 ,1375 ,1151 ,403 ,1740 ,1072 ,979 ,1389 ,777 , + 1221 ,1192 ,1021 ,705 ,731 ,573 ,818 ,328 ,853 ,1037 ,1976 ,563 ,1934 ,175 ,1303 ,320 , + 1975 ,1835 ,960 ,1246 ,827 ,93 ,385 ,782 ,1482 ,217 ,387 ,672 ,1003 ,1001 ,428 ,44 , + 1906 ,1209 ,474 ,1084 ,917 ,1621 ,1590 ,1750 ,1854 ,705 ,1129 ,648 ,1770 ,761 ,41 ,164 , + 1569 ,1044 ,912 ,1346 ,580 ,1636 ,290 ,683 ,1004 ,986 ,1762 ,1535 ,1275 ,820 ,853 ,785 , + 1906 ,50 ,513 ,1311 ,1858 ,1413 ,895 ,120 ,1970 ,770 ,1606 ,910 ,1294 ,614 ,593 ,1796 , + 1541 ,1039 ,970 ,1797 ,1311 ,1343 ,1250 ,793 ,117 ,637 ,408 ,1860 ,1274 ,650 ,1707 ,1062 , + 1346 ,1038 ,17 ,454 ,1513 ,1700 ,886 ,483 ,1415 ,1138 ,1690 ,826 ,1132 ,1481 ,1599 ,15 , + 1882 ,1607 ,1412 ,944 ,784 ,659 ,1330 ,278 ,1464 ,1895 ,1287 ,657 ,273 ,602 ,1837 ,405 , + 726 ,1406 ,1077 ,941 ,359 ,1272 ,916 ,255 ,1129 
,1277 ,1762 ,182 ,69 ,622 ,2036 ,761 , + 1014 ,1255 ,1406 ,733 ,1162 ,660 ,1526 ,436 ,1579 ,1392 ,103 ,441 ,1198 ,1079 ,232 ,355 , + 290 ,661 ,873 ,166 ,1619 ,700 ,753 ,1513 ,2027 ,2035 ,1750 ,1956 ,787 ,44 ,471 ,963 , + 1761 ,940 ,1581 ,498 ,1264 ,843 ,1955 ,1258 ,1689 ,1032 ,562 ,1500 ,1903 ,767 ,1229 ,111 , + 1918 ,1540 ,401 ,383 ,1930 ,1453 ,77 ,842 ,152 ,1706 ,1561 ,1133 ,1857 ,561 ,993 ,194 , + 192 ,1480 ,1304 ,1028 ,1197 ,219 ,843 ,1440 ,1762 ,1390 ,313 ,1561 ,687 ,133 ,772 ,1424 , + 1430 ,310 ,146 ,713 ,1338 ,747 ,580 ,978 ,1301 ,208 ,277 ,1385 ,1367 ,230 ,907 ,1790 , + 601 ,905 ,786 ,553 ,1413 ,357 ,88 ,203 ,352 ,1886 ,1225 ,1980 ,664 ,2047 ,939 ,100 , + 1378 ,715 ,2006 ,1606 ,691 ,531 ,403 ,133 ,1301 ,717 ,1054 ,21 ,1525 ,1715 ,634 ,368 , + 1938 ,4 ,878 ,1440 ,796 ,1399 ,1980 ,537 ,1777 ,715 ,747 ,395 ,827 ,562 ,1463 ,1301 , + 1204 ,580 ,232 ,1682 ,393 ,453 ,1170 ,599 ,1518 ,983 ,1680 ,763 ,1988 ,1896 ,274 ,382 , + 1320 ,547 ,155 ,1422 ,251 ,114 ,1357 ,1078 ,602 ,689 ,907 ,1078 ,1848 ,290 ,887 ,575 , + 409 ,96 ,159 ,1437 ,960 ,58 ,518 ,1090 ,1100 ,916 ,802 ,1217 ,188 ,1421 ,560 ,2039 , + 1848 ,1941 ,1928 ,1292 ,1359 ,883 ,701 ,189 ,1248 ,785 ,280 ,763 ,360 ,73 ,1987 ,372 , + 1174 ,136 ,436 ,162 ,1333 ,1706 ,1255 ,323 ,664 ,557 ,226 ,1642 ,1382 ,691 ,89 ,1582 , + 1886 ,2025 ,1051 ,1809 ,1679 ,1404 ,518 ,1491 ,671 ,377 ,997 ,908 ,402 ,234 ,921 ,1631 , + 1682 ,1318 ,820 ,726 ,875 ,492 ,1644 ,36 ,117 ,1208 ,1088 ,666 ,1955 ,1524 ,1716 ,1487 , + 1193 ,1953 ,1982 ,1783 ,1949 ,1405 ,780 ,800 ,334 ,203 ,1495 ,1615 ,1121 ,1918 ,514 ,1871 , + 1033 ,365 ,502 ,1066 ,1657 ,1581 ,232 ,1442 ,1227 ,297 ,1946 ,1792 ,996 ,669 ,725 ,2023 , + 975 ,1976 ,1591 ,1168 ,925 ,238 ,802 ,1116 ,1941 ,1629 ,1231 ,1395 ,1847 ,244 ,174 ,1554 , + 1010 ,199 ,61 ,1196 ,977 ,891 ,1963 ,845 ,1250 ,562 ,2047 ,1945 ,601 ,691 ,1140 ,1559 , + 765 ,920 ,1149 ,671 ,1409 ,1299 ,1249 ,461 ,1227 ,1678 ,1269 ,1697 ,636 ,784 ,443 ,239 , + 1010 ,618 ,126 ,920 ,287 ,1669 ,39 ,727 ,60 ,795 ,1310 ,311 ,1944 ,382 ,1216 ,955 , + 288 ,1295 ,31 ,576 ,339 ,2029 ,1111 ,1567 ,5 ,1449 ,1506 ,449 ,992 ,705 ,1107 ,274 , + 450 ,1885 ,1355 ,723 ,1572 ,379 ,190 ,1922 ,700 ,1917 ,1071 ,510 ,963 ,555 ,1686 ,216 , + 1552 ,1613 ,593 ,407 ,360 ,1217 ,598 ,1176 ,184 ,485 ,765 ,1989 ,327 ,198 ,4 ,139 , + 1095 ,98 ,12 ,1034 ,1164 ,1367 ,518 ,727 ,871 ,1689 ,320 ,133 ,218 ,1841 ,672 ,1175 , + 158 ,1015 ,837 ,1714 ,1045 ,1820 ,1744 ,999 ,2028 ,1239 ,1503 ,728 ,1472 ,243 ,713 ,1832 , + 1110 ,446 ,412 ,1293 ,40 ,583 ,557 ,1017 ,1106 ,1805 ,1176 ,1190 ,582 ,1943 ,983 ,923 , + 1599 ,20 ,1531 ,1377 ,1870 ,1621 ,1658 ,1480 ,1260 ,1986 ,688 ,775 ,1930 ,963 ,1448 ,1269 , + 2005 ,1363 ,996 ,386 ,1135 ,89 ,1531 ,1808 ,767 ,1314 ,486 ,1055 ,1760 ,222 ,1224 ,1189 , + 1568 ,1173 ,1652 ,734 ,131 ,895 ,560 ,133 ,1618 ,1569 ,543 ,368 ,1201 ,1657 ,552 ,1258 , + 2005 ,1078 ,674 ,2021 ,2047 ,1920 ,37 ,1757 ,19 ,1955 ,1376 ,575 ,1160 ,1345 ,180 ,1019 , + 125 ,902 ,438 ,1471 ,291 ,1903 ,1113 ,561 ,645 ,1174 ,286 ,1934 ,194 ,1998 ,1300 ,160 , + 40 ,81 ,77 ,1342 ,555 ,955 ,377 ,1804 ,1976 ,1505 ,253 ,37 ,8 ,216 ,197 ,445 , + 66 ,425 ,458 ,1747 ,1396 ,210 ,437 ,1585 ,1228 ,1105 ,215 ,309 ,1746 ,1547 ,1062 ,52 , + 1489 ,1744 ,374 ,1797 ,819 ,903 ,513 ,1454 ,1338 ,10 ,9 ,1407 ,1820 ,561 ,383 ,1057 , + 1857 ,1950 ,568 ,1927 ,469 ,1373 ,1199 ,753 ,1586 ,1291 ,1887 ,906 ,1904 ,195 ,1079 ,1341 , + 1621 ,1597 ,480 ,1225 ,1677 ,716 ,1603 ,1628 ,245 ,158 ,34 ,619 ,202 ,1702 ,1594 ,1555 , + 306 ,339 ,352 ,725 ,407 ,1491 ,2008 ,74 ,765 ,49 ,573 ,191 ,78 ,1260 ,2043 ,1282 , + 302 ,81 
,1223 ,521 ,1749 ,1571 ,1461 ,1302 ,1548 ,1867 ,147 ,1091 ,1231 ,1900 ,1165 ,828 , + 663 ,136 ,1160 ,1765 ,616 ,1070 ,922 ,1522 ,775 ,1292 ,1248 ,1462 ,1057 ,1818 ,1627 ,563 , + 302 ,991 ,627 ,1527 ,947 ,1567 ,1440 ,5 ,1960 ,864 ,468 ,1810 ,528 ,206 ,110 ,58 , + 835 ,513 ,805 ,57 ,1218 ,1680 ,1458 ,885 ,1603 ,1292 ,260 ,1431 ,58 ,388 ,1625 ,1585 , + 1364 ,387 ,25 ,1667 ,427 ,1199 ,1819 ,1266 ,1655 ,1633 ,823 ,1920 ,273 ,804 ,966 ,311 , + 1475 ,1357 ,397 ,105 ,1982 ,1036 ,1596 ,1260 ,1408 ,958 ,731 ,1700 ,2033 ,858 ,1425 ,1361 , + 476 ,1961 ,1045 ,334 ,1848 ,554 ,537 ,1707 ,1857 ,145 ,1447 ,730 ,1323 ,1195 ,336 ,529 , + 131 ,1769 ,452 ,813 ,509 ,629 ,886 ,832 ,289 ,898 ,889 ,691 ,384 ,701 ,948 ,392 , + 4 ,1540 ,667 ,1034 ,1142 ,1040 ,607 ,1739 ,809 ,1769 ,1516 ,1552 ,676 ,381 ,1996 ,880 , + 835 ,590 ,1138 ,1861 ,131 ,1845 ,1940 ,249 ,1565 ,771 ,1076 ,690 ,1427 ,1553 ,826 ,195 , + 769 ,1700 ,1178 ,164 ,267 ,1443 ,1238 ,2008 ,1866 ,374 ,415 ,436 ,1123 ,1430 ,853 ,667 , + 210 ,1642 ,290 ,1428 ,1255 ,1067 ,1353 ,963 ,604 ,973 ,1538 ,483 ,422 ,97 ,1140 ,1833 , + 666 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,194 ,1112 ,144 ,1448 , + 1978 ,251 ,1490 ,1853 ,610 ,1897 ,1332 ,1774 ,724 ,591 ,1538 ,951 ,1930 ,1440 ,1047 ,1752 , + 768 ,243 ,783 ,142 ,481 ,1030 ,976 ,1744 ,1972 ,91 ,1165 ,253 ,1908 ,1430 ,644 ,851 , + 1978 ,1642 ,1490 ,1402 ,692 ,1953 ,377 ,74 ,204 ,591 ,1538 ,1388 ,1160 ,1688 ,332 ,1956 , + 768 ,1864 ,265 ,164 ,89 ,555 ,976 ,1744 ,739 ,59 ,1778 ,253 ,1912 ,591 ,921 ,480 , + 260 ,1642 ,2015 ,82 ,423 ,1444 ,1821 ,963 ,765 ,1792 ,1538 ,1543 ,987 ,1171 ,427 ,904 , + 787 ,212 ,1884 ,194 ,1601 ,765 ,742 ,1850 ,967 ,1317 ,1310 ,605 ,1208 ,882 ,943 ,701 , + 29 ,58 ,479 ,490 ,943 ,1601 ,1685 ,961 ,36 ,147 ,75 ,874 ,1612 ,632 ,754 ,32 , + 1783 ,400 ,409 ,872 ,920 ,1212 ,613 ,1669 ,1704 ,480 ,1527 ,1430 ,241 ,1809 ,404 ,666 , + 1413 ,1308 ,1018 ,1381 ,1906 ,828 ,305 ,212 ,779 ,535 ,1225 ,1748 ,109 ,1319 ,478 ,776 , + 208 ,51 ,409 ,1811 ,2009 ,1060 ,216 ,1084 ,1225 ,1366 ,723 ,1902 ,1304 ,216 ,433 ,1866 , + 983 ,1158 ,157 ,1766 ,449 ,400 ,1405 ,1676 ,796 ,305 ,319 ,1890 ,1003 ,1335 ,1457 ,89 , + 208 ,232 ,1419 ,1850 ,1419 ,1060 ,137 ,1512 ,631 ,1830 ,279 ,1808 ,1994 ,1872 ,402 ,986 , + 1808 ,85 ,21 ,1279 ,567 ,4 ,544 ,1151 ,1379 ,295 ,682 ,1113 ,1953 ,757 ,1180 ,1068 , + 208 ,449 ,759 ,1768 ,1300 ,567 ,1102 ,183 ,648 ,1885 ,645 ,1225 ,1440 ,214 ,938 ,1818 , + 113 ,1370 ,746 ,681 ,588 ,1972 ,386 ,926 ,1581 ,1971 ,286 ,776 ,1673 ,1017 ,1125 ,1855 , + 2005 ,655 ,126 ,1533 ,1799 ,1851 ,934 ,1628 ,693 ,1487 ,90 ,753 ,1956 ,1427 ,734 ,1205 , + 259 ,1693 ,729 ,737 ,650 ,23 ,772 ,327 ,1930 ,61 ,488 ,763 ,599 ,797 ,40 ,1254 , + 1285 ,1540 ,328 ,1562 ,279 ,726 ,160 ,1529 ,292 ,624 ,1165 ,963 ,1979 ,543 ,1552 ,262 , + 1993 ,1364 ,1665 ,654 ,123 ,1092 ,1199 ,1022 ,711 ,607 ,1405 ,1589 ,687 ,436 ,1349 ,805 , + 192 ,1700 ,1212 ,1620 ,267 ,1443 ,326 ,1648 ,1866 ,357 ,662 ,253 ,1908 ,1714 ,853 ,667 , + 32 ,1829 ,1490 ,1402 ,692 ,1626 ,377 ,1774 ,604 ,591 ,133 ,1388 ,1995 ,1688 ,332 ,1562 , + 1419 ,145 ,1539 ,1384 ,491 ,474 ,183 ,807 ,1214 ,939 ,1017 ,1054 ,1698 ,1660 ,35 ,513 , + 835 ,165 ,407 ,461 ,398 ,870 ,950 ,304 ,1881 ,1099 ,669 ,65 ,346 ,1134 ,901 ,111 , + 642 ,1851 ,910 ,1278 ,417 ,1737 ,1130 ,609 ,779 ,379 ,1617 ,488 ,1449 ,1969 ,973 ,508 , + 295 ,1762 ,207 ,1038 ,595 ,1662 ,107 ,2008 ,1673 ,1158 ,436 ,1559 ,1252 ,122 ,1216 ,761 , + 1621 ,1783 ,1502 ,350 ,200 ,553 ,712 ,88 ,767 ,899 ,143 ,1548 ,814 ,900 ,851 ,1031 , + 734 ,1218 ,1779 ,440 ,1558 ,1656 ,1455 ,1029 ,1181 ,2042 ,1591 ,1916 
,1052 ,1659 ,1008 ,278 , + 1174 ,373 ,1605 ,1634 ,323 ,1286 ,1645 ,490 ,994 ,1598 ,784 ,5 ,1973 ,1064 ,1132 ,104 , + 612 ,967 ,1071 ,1898 ,253 ,1032 ,2021 ,1180 ,485 ,1176 ,1114 ,1907 ,1290 ,486 ,143 ,1567 , + 1174 ,1944 ,1461 ,364 ,349 ,586 ,774 ,1064 ,1983 ,631 ,1914 ,906 ,928 ,546 ,1736 ,467 , + 933 ,1269 ,882 ,404 ,675 ,296 ,1096 ,472 ,1321 ,314 ,1326 ,1490 ,997 ,1744 ,1191 ,1928 , + 1308 ,1799 ,124 ,1944 ,1511 ,1244 ,1508 ,1173 ,1346 ,1127 ,1607 ,1836 ,951 ,821 ,842 ,625 , + 63 ,621 ,1225 ,388 ,308 ,650 ,1031 ,1447 ,1382 ,1986 ,1381 ,942 ,913 ,1352 ,1692 ,722 , + 523 ,1284 ,774 ,1815 ,1895 ,1187 ,246 ,1062 ,1250 ,1405 ,55 ,340 ,1741 ,859 ,1292 ,1690 , + 63 ,1096 ,817 ,318 ,1120 ,97 ,1040 ,383 ,1572 ,1214 ,351 ,1168 ,1382 ,195 ,1647 ,83 , + 1688 ,1198 ,739 ,1319 ,507 ,351 ,886 ,1803 ,157 ,1100 ,864 ,310 ,1299 ,623 ,426 ,390 , + 1620 ,113 ,1252 ,1212 ,1294 ,116 ,1557 ,694 ,829 ,552 ,644 ,1870 ,950 ,10 ,910 ,290 , + 1907 ,1665 ,307 ,2032 ,1944 ,588 ,1505 ,1888 ,347 ,1225 ,1528 ,337 ,797 ,983 ,274 ,965 , + 1937 ,1812 ,1956 ,1822 ,513 ,839 ,640 ,115 ,621 ,1649 ,2041 ,1079 ,1109 ,28 ,1561 ,1879 , + 898 ,498 ,1012 ,1133 ,1044 ,846 ,202 ,533 ,1748 ,2023 ,1954 ,1522 ,412 ,1200 ,1768 ,1360 , + 750 ,1409 ,404 ,1881 ,138 ,904 ,1265 ,870 ,121 ,638 ,1756 ,1793 ,1009 ,1104 ,313 ,62 , + 497 ,994 ,61 ,1785 ,580 ,87 ,1324 ,1190 ,369 ,846 ,1607 ,1704 ,676 ,1422 ,1339 ,537 , + 1757 ,1982 ,438 ,1849 ,150 ,1884 ,882 ,568 ,781 ,1446 ,1137 ,1260 ,1678 ,1834 ,765 ,1489 , + 1797 ,1515 ,1626 ,164 ,306 ,1443 ,183 ,662 ,1561 ,91 ,675 ,1625 ,1101 ,1430 ,144 ,1942 , + 1458 ,963 ,1736 ,1047 ,1908 ,1626 ,655 ,643 ,1252 ,591 ,47 ,1388 ,987 ,1820 ,632 ,1752 , + 448 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,194 ,1714 ,644 ,1684 , + 1978 ,1642 ,1490 ,1428 ,692 ,1897 ,377 ,717 ,1423 ,591 ,1538 ,1388 ,1995 ,97 ,332 ,1752 , + 7 ,1597 ,265 ,1712 ,306 ,906 ,825 ,1744 ,1338 ,374 ,1165 ,253 ,1449 ,1204 ,144 ,1942 , + 210 ,251 ,75 ,1402 ,1768 ,1491 ,1353 ,717 ,11 ,973 ,47 ,597 ,1065 ,1688 ,1140 ,1956 , + 1883 ,590 ,958 ,1287 ,1457 ,1040 ,225 ,1197 ,2001 ,1963 ,1306 ,355 ,1291 ,555 ,810 ,854 , + 489 ,1072 ,1344 ,1283 ,906 ,113 ,466 ,370 ,1226 ,1706 ,901 ,1611 ,1970 ,182 ,1330 ,1887 , + 1956 ,132 ,1509 ,710 ,973 ,464 ,1248 ,1120 ,1599 ,1248 ,1944 ,1452 ,743 ,1711 ,1776 ,408 , + 1151 ,513 ,1005 ,1548 ,1929 ,1123 ,272 ,654 ,534 ,353 ,892 ,203 ,372 ,920 ,91 ,1349 , + 1688 ,815 ,173 ,4 ,771 ,1993 ,395 ,1447 ,1273 ,1553 ,1387 ,2032 ,498 ,968 ,1667 ,627 , + 520 ,1941 ,1323 ,956 ,1696 ,674 ,1402 ,322 ,1188 ,766 ,861 ,585 ,1346 ,1247 ,1407 ,2021 , + 964 ,585 ,1092 ,1847 ,1720 ,430 ,79 ,1940 ,489 ,1519 ,866 ,848 ,2011 ,886 ,2042 ,338 , + 1668 ,909 ,684 ,1771 ,581 ,677 ,39 ,1984 ,428 ,1127 ,1525 ,551 ,1925 ,308 ,1054 ,19 , + 1914 ,1251 ,1338 ,2016 ,1484 ,1040 ,1168 ,1615 ,1687 ,832 ,1895 ,1563 ,1962 ,2029 ,1012 ,471 , + 118 ,571 ,1292 ,579 ,480 ,929 ,1111 ,1051 ,542 ,1602 ,1871 ,1803 ,1943 ,870 ,589 ,1668 , + 1004 ,1541 ,992 ,466 ,1040 ,786 ,1065 ,881 ,622 ,481 ,122 ,1093 ,641 ,267 ,961 ,386 , + 298 ,1604 ,1789 ,758 ,65 ,65 ,989 ,1691 ,955 ,1876 ,440 ,1987 ,2047 ,735 ,1975 ,345 , + 1227 ,440 ,881 ,533 ,770 ,1870 ,137 ,108 ,357 ,1149 ,1536 ,698 ,1585 ,906 ,741 ,2015 , + 1966 ,1253 ,512 ,768 ,1579 ,1070 ,1971 ,1414 ,1717 ,1892 ,1944 ,1109 ,360 ,755 ,654 ,1673 , + 1424 ,1727 ,320 ,1354 ,711 ,1191 ,611 ,1329 ,809 ,1416 ,1262 ,153 ,1192 ,1863 ,650 ,1511 , + 508 ,403 ,716 ,984 ,1399 ,818 ,213 ,601 ,172 ,89 ,1323 ,1543 ,168 ,1149 ,970 ,1780 , + 2007 ,175 ,170 ,785 ,1322 ,1167 ,1706 ,463 ,1801 ,982 ,428 ,1484 ,1327 ,94 ,1075 ,1800 
, + 1305 ,1502 ,686 ,681 ,935 ,698 ,1791 ,1535 ,170 ,701 ,1856 ,1598 ,71 ,1637 ,1824 ,36 , + 1999 ,1028 ,1326 ,403 ,1956 ,184 ,576 ,531 ,767 ,238 ,654 ,502 ,1005 ,1072 ,1047 ,1128 , + 976 ,1409 ,1972 ,1596 ,148 ,1978 ,1105 ,1740 ,1775 ,159 ,150 ,285 ,11 ,1101 ,88 ,1895 , + 1432 ,971 ,1767 ,1990 ,1849 ,1081 ,328 ,735 ,366 ,2017 ,914 ,1975 ,121 ,255 ,1232 ,66 , + 1665 ,1232 ,1424 ,1525 ,533 ,641 ,1341 ,1074 ,1855 ,435 ,404 ,197 ,797 ,1589 ,1279 ,110 , + 13 ,156 ,351 ,46 ,2015 ,1041 ,765 ,1416 ,1633 ,1723 ,1164 ,7 ,1697 ,912 ,1976 ,175 , + 1894 ,1259 ,25 ,1213 ,481 ,1342 ,918 ,1297 ,527 ,956 ,789 ,760 ,573 ,1374 ,1315 ,1395 , + 1147 ,1354 ,642 ,1196 ,50 ,1127 ,1313 ,914 ,1582 ,81 ,554 ,917 ,632 ,1268 ,520 ,714 , + 542 ,1006 ,253 ,1527 ,1182 ,261 ,101 ,1014 ,1324 ,1658 ,1728 ,1036 ,349 ,378 ,1644 ,116 , + 1678 ,1981 ,955 ,1219 ,1247 ,1709 ,1387 ,336 ,456 ,989 ,819 ,2023 ,1129 ,1946 ,1740 ,597 , + 528 ,1280 ,1723 ,1385 ,1961 ,430 ,1477 ,730 ,1382 ,891 ,267 ,1046 ,736 ,194 ,2010 ,1513 , + 1381 ,387 ,448 ,1561 ,820 ,1065 ,801 ,267 ,1393 ,418 ,981 ,1042 ,985 ,2041 ,1787 ,1591 , + 1423 ,1019 ,39 ,2021 ,1628 ,829 ,1787 ,1404 ,657 ,978 ,1859 ,296 ,689 ,1377 ,696 ,1660 , + 1287 ,1024 ,44 ,1537 ,848 ,1014 ,1495 ,1779 ,1135 ,78 ,969 ,1899 ,1151 ,41 ,1257 ,1679 , + 213 ,6 ,1319 ,1058 ,1818 ,1637 ,956 ,1696 ,338 ,1163 ,1183 ,1719 ,500 ,1997 ,170 ,830 , + 35 ,1333 ,1840 ,1224 ,959 ,502 ,1910 ,1738 ,1127 ,1373 ,1706 ,1611 ,1577 ,1822 ,409 ,1913 , + 582 ,598 ,118 ,545 ,1355 ,69 ,984 ,1385 ,824 ,647 ,1497 ,1603 ,1159 ,1504 ,935 ,884 , + 1082 ,1623 ,61 ,566 ,515 ,1220 ,998 ,394 ,137 ,1102 ,1415 ,412 ,1274 ,988 ,168 ,1418 , + 2042 ,314 ,1892 ,382 ,163 ,1161 ,1775 ,1957 ,1884 ,1914 ,284 ,836 ,1253 ,894 ,994 ,1875 , + 1909 ,290 ,990 ,1932 ,1994 ,1102 ,1970 ,21 ,1585 ,1071 ,2017 ,115 ,1671 ,653 ,999 ,1041 , + 116 ,1188 ,1890 ,1070 ,1750 ,694 ,450 ,944 ,1332 ,1135 ,270 ,622 ,1959 ,49 ,311 ,186 , + 497 ,1104 ,924 ,1964 ,420 ,716 ,574 ,763 ,1041 ,820 ,1012 ,36 ,1704 ,836 ,745 ,1585 , + 1352 ,84 ,1587 ,1279 ,1768 ,787 ,340 ,1929 ,524 ,2036 ,1175 ,1336 ,612 ,534 ,798 ,1516 , + 1263 ,1515 ,783 ,142 ,1736 ,2010 ,976 ,1744 ,1561 ,1542 ,1048 ,1383 ,1343 ,1421 ,644 ,610 , + 1993 ,1829 ,1923 ,1853 ,610 ,1067 ,377 ,717 ,18 ,973 ,1538 ,1124 ,346 ,97 ,427 ,1833 , + 541 ,243 ,783 ,1348 ,267 ,1443 ,825 ,2008 ,1210 ,1992 ,1165 ,1383 ,1101 ,1714 ,802 ,1448 , + 32 ,1642 ,1490 ,1406 ,692 ,1897 ,1353 ,1774 ,1423 ,591 ,1538 ,951 ,35 ,97 ,1047 ,1752 , + 1771 ,243 ,1559 ,1348 ,1736 ,1443 ,666 ,1648 ,1972 ,1370 ,415 ,143 ,1343 ,1112 ,144 ,1448 , + 200 ,646 ,1490 ,1406 ,1255 ,1626 ,1497 ,1774 ,711 ,591 ,47 ,1388 ,1930 ,97 ,1216 ,1752 , + 448 ,243 ,1697 ,546 ,1736 ,1443 ,976 ,1744 ,1210 ,1370 ,1165 ,253 ,1908 ,650 ,644 ,851 , + 1978 ,1419 ,290 ,1853 ,692 ,1067 ,377 ,1774 ,724 ,847 ,47 ,951 ,1995 ,97 ,332 ,1833 , + 448 ,243 ,1697 ,1348 ,1736 ,1443 ,1978 ,1744 ,1210 ,1992 ,1165 ,253 ,1912 ,1430 ,802 ,851 , + 32 ,251 ,1736 ,1428 ,692 ,1067 ,1332 ,1586 ,724 ,973 ,1538 ,951 ,422 ,1688 ,1047 ,1956 , + 448 ,243 ,1697 ,164 ,1736 ,1572 ,976 ,1744 ,739 ,1542 ,1165 ,253 ,1908 ,1714 ,644 ,1684 , + 32 ,1419 ,1736 ,1406 ,1908 ,1067 ,1497 ,717 ,1423 ,847 ,47 ,1388 ,422 ,1688 ,332 ,1562 , + 448 ,243 ,1697 ,164 ,1736 ,1572 ,976 ,1744 ,1210 ,1542 ,1165 ,253 ,1908 ,1714 ,644 ,851 , + 1978 ,1419 ,1736 ,1406 ,1908 ,1953 ,377 ,1586 ,1423 ,591 ,1538 ,1388 ,1995 ,1688 ,332 ,1752 , + 1850 ,243 ,1697 ,1348 ,1736 ,1443 ,1978 ,1744 ,1210 ,1370 ,1165 ,1190 ,1912 ,1430 ,144 ,851 , + 1978 ,251 ,1736 ,1406 ,692 ,1067 ,1497 ,1586 ,724 ,591 ,47 ,951 ,1930 ,1688 ,1047 ,1833 
, + 384 ,1211 ,1622 ,1562 ,1836 ,555 ,477 ,1648 ,1938 ,278 ,675 ,253 ,194 ,650 ,1336 ,1511 , + 32 ,1829 ,618 ,82 ,123 ,154 ,1332 ,1277 ,172 ,1040 ,1538 ,1124 ,2008 ,1688 ,332 ,1209 , + 420 ,1260 ,573 ,183 ,877 ,27 ,1797 ,1879 ,1200 ,546 ,100 ,1093 ,1300 ,80 ,1060 ,1640 , + 1789 ,168 ,1725 ,1579 ,2046 ,1469 ,1888 ,1990 ,76 ,1692 ,839 ,1116 ,692 ,551 ,686 ,1650 , + 1738 ,693 ,311 ,1527 ,623 ,1339 ,755 ,335 ,753 ,357 ,1456 ,1304 ,761 ,1770 ,1377 ,695 , + 447 ,1584 ,1529 ,1330 ,942 ,776 ,1362 ,986 ,1611 ,1332 ,429 ,44 ,1704 ,1718 ,1190 ,1004 , + 293 ,1700 ,1178 ,97 ,1736 ,1572 ,666 ,2008 ,1972 ,1992 ,415 ,436 ,1123 ,113 ,644 ,851 , + 1978 ,646 ,290 ,1428 ,423 ,1067 ,377 ,963 ,1423 ,1343 ,47 ,483 ,1995 ,97 ,1047 ,1562 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,739 ,374 ,1165 ,436 ,1101 ,113 ,644 ,1448 , + 1978 ,646 ,1736 ,1406 ,692 ,1897 ,377 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,1688 ,1047 ,1956 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,1210 ,374 ,1165 ,436 ,1912 ,1714 ,644 ,1448 , + 1978 ,251 ,1736 ,1428 ,692 ,1897 ,377 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,1688 ,1047 ,1833 , + 384 ,243 ,1559 ,1348 ,1736 ,1443 ,1978 ,1744 ,1210 ,1992 ,1165 ,1190 ,1912 ,650 ,144 ,851 , + 32 ,251 ,290 ,1428 ,423 ,1953 ,1497 ,1586 ,724 ,591 ,47 ,1388 ,422 ,1688 ,1047 ,1752 , + 1208 ,1172 ,595 ,321 ,1594 ,614 ,387 ,1684 ,401 ,1656 ,1055 ,1638 ,2 ,1161 ,988 ,734 , + 1243 ,569 ,515 ,1627 ,1985 ,1226 ,742 ,1862 ,1994 ,1461 ,913 ,1032 ,419 ,1784 ,1220 ,1771 , + 1382 ,1949 ,2008 ,89 ,1099 ,1158 ,185 ,1016 ,298 ,874 ,1630 ,1599 ,968 ,1069 ,380 ,957 , + 1583 ,856 ,1522 ,1681 ,1855 ,945 ,1388 ,1974 ,1848 ,1825 ,658 ,866 ,1248 ,741 ,1999 ,628 , + 982 ,1879 ,1663 ,1021 ,1638 ,261 ,1718 ,834 ,1809 ,449 ,1462 ,1438 ,968 ,1144 ,405 ,1910 , + 1138 ,1395 ,498 ,224 ,1334 ,1143 ,169 ,1911 ,1876 ,396 ,1367 ,1522 ,1794 ,273 ,553 ,557 , + 1635 ,520 ,328 ,1034 ,610 ,657 ,2031 ,663 ,1594 ,47 ,358 ,1109 ,1735 ,1722 ,1806 ,1868 , + 844 ,974 ,438 ,493 ,144 ,1997 ,1612 ,1668 ,1141 ,822 ,487 ,420 ,769 ,1529 ,1963 ,1117 , + 1263 ,1515 ,1626 ,164 ,1736 ,347 ,666 ,2008 ,739 ,1370 ,1165 ,436 ,1908 ,1430 ,144 ,667 , + 748 ,1642 ,1736 ,1402 ,1908 ,1897 ,377 ,1396 ,604 ,973 ,693 ,951 ,422 ,1688 ,427 ,1417 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,739 ,1542 ,1165 ,253 ,1123 ,650 ,644 ,851 , + 32 ,1419 ,290 ,1406 ,692 ,1897 ,377 ,1586 ,204 ,591 ,1538 ,951 ,422 ,1688 ,332 ,1833 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,1210 ,1542 ,1165 ,436 ,1736 ,1112 ,144 ,1448 , + 1978 ,1642 ,1736 ,1853 ,610 ,1897 ,1332 ,1774 ,724 ,591 ,1538 ,951 ,422 ,1688 ,1047 ,1752 , + 1850 ,243 ,1697 ,164 ,1736 ,1030 ,976 ,1744 ,739 ,1992 ,1165 ,253 ,1101 ,1714 ,644 ,1684 , + 32 ,1419 ,1490 ,1406 ,692 ,1067 ,377 ,717 ,204 ,847 ,1538 ,951 ,1930 ,97 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1912 ,1714 ,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1953 ,377 ,717 ,724 ,591 ,1538 ,1124 ,1995 ,1688 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1912 ,1714 ,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1953 ,377 ,717 ,724 ,591 ,1538 ,951 ,1995 ,1688 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,1210 ,374 ,1165 ,436 ,1912 ,1714 ,644 ,1448 , + 1978 ,251 ,1736 ,1428 ,692 ,1897 ,377 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,1688 ,1047 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,1210 ,1370 ,1165 ,253 ,1912 ,1714 ,644 ,667 , + 1978 ,251 ,618 ,1428 ,1255 ,1626 ,377 ,1774 ,604 ,973 ,1538 ,1388 ,35 ,1688 ,332 ,1833 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1912 ,1714 
,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1626 ,377 ,717 ,724 ,591 ,1538 ,1124 ,422 ,1688 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1101 ,113 ,644 ,851 , + 1978 ,1642 ,618 ,1406 ,1255 ,1067 ,377 ,1586 ,724 ,591 ,47 ,1388 ,422 ,1688 ,1047 ,1956 , + 1850 ,243 ,1178 ,1348 ,267 ,1443 ,666 ,1744 ,1210 ,1370 ,1165 ,439 ,1912 ,1430 ,144 ,851 , + 1978 ,646 ,290 ,1428 ,1255 ,1067 ,1332 ,717 ,724 ,847 ,899 ,1124 ,422 ,97 ,332 ,1956 , + 481 ,243 ,783 ,142 ,481 ,1030 ,976 ,1744 ,1210 ,1992 ,1165 ,253 ,1912 ,1714 ,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1953 ,377 ,717 ,724 ,973 ,1538 ,1388 ,1995 ,97 ,1047 ,1752 , + 1850 ,243 ,1178 ,546 ,267 ,1030 ,976 ,2008 ,739 ,1370 ,1165 ,1383 ,1908 ,1714 ,144 ,851 , + 32 ,1419 ,1736 ,1428 ,692 ,1897 ,1332 ,1586 ,724 ,847 ,569 ,951 ,422 ,1688 ,1047 ,1752 , + 384 ,243 ,1559 ,1348 ,1736 ,555 ,825 ,1648 ,739 ,91 ,1165 ,253 ,1908 ,1112 ,144 ,1684 , + 1978 ,1642 ,1490 ,1402 ,610 ,1626 ,1332 ,1774 ,724 ,591 ,47 ,1388 ,35 ,97 ,427 ,904 , + 1885 ,1216 ,1203 ,1196 ,1317 ,64 ,1893 ,524 ,384 ,240 ,1361 ,1850 ,690 ,1967 ,934 ,1357 , + 172 ,1201 ,583 ,539 ,76 ,964 ,615 ,1090 ,413 ,1384 ,193 ,782 ,1428 ,1873 ,455 ,1827 , + 314 ,143 ,1286 ,1368 ,64 ,1219 ,1012 ,2025 ,1708 ,782 ,1899 ,1802 ,1109 ,107 ,159 ,1807 , + 1425 ,1415 ,1281 ,1949 ,2029 ,158 ,110 ,265 ,1961 ,422 ,1269 ,1214 ,409 ,484 ,1831 ,1972 , + 569 ,1493 ,1743 ,1290 ,791 ,99 ,2009 ,577 ,1766 ,969 ,1919 ,1524 ,364 ,132 ,1865 ,208 , + 1671 ,807 ,1600 ,1765 ,758 ,401 ,565 ,1975 ,621 ,1608 ,119 ,1310 ,823 ,5 ,1143 ,242 , + 963 ,400 ,1743 ,1907 ,1610 ,1972 ,1713 ,749 ,1806 ,1272 ,985 ,1699 ,1321 ,1992 ,507 ,590 , + 1018 ,1615 ,179 ,867 ,1067 ,637 ,1646 ,1077 ,800 ,30 ,1084 ,1429 ,823 ,918 ,461 ,321 , + 476 ,971 ,984 ,1290 ,1902 ,1504 ,194 ,1107 ,1322 ,1648 ,2041 ,1683 ,1048 ,105 ,1901 ,1627 , + 151 ,1287 ,731 ,1440 ,687 ,17 ,545 ,762 ,1170 ,1867 ,389 ,544 ,489 ,1455 ,800 ,350 , + 1035 ,669 ,1149 ,417 ,1067 ,716 ,1752 ,479 ,1571 ,1316 ,1162 ,271 ,666 ,1783 ,170 ,966 , + 518 ,1764 ,52 ,1139 ,35 ,670 ,1064 ,874 ,1668 ,169 ,1272 ,957 ,644 ,237 ,107 ,963 , + 1696 ,348 ,420 ,1057 ,411 ,718 ,164 ,318 ,755 ,551 ,1756 ,497 ,739 ,807 ,1524 ,176 , + 82 ,197 ,1179 ,1389 ,151 ,293 ,176 ,1089 ,1757 ,976 ,1175 ,268 ,442 ,1234 ,1518 ,542 , + 1268 ,1172 ,1362 ,1857 ,691 ,817 ,1138 ,515 ,1566 ,1071 ,1989 ,1732 ,1419 ,2033 ,1210 ,1457 , + 1937 ,1322 ,315 ,473 ,1527 ,1141 ,1766 ,957 ,236 ,652 ,1382 ,2021 ,1231 ,1221 ,1224 ,471 , + 443 ,160 ,762 ,981 ,735 ,957 ,12 ,1342 ,1053 ,1380 ,602 ,784 ,1434 ,1472 ,1665 ,1469 , + 1682 ,1008 ,311 ,1184 ,1292 ,1189 ,1219 ,1425 ,1214 ,978 ,337 ,89 ,1542 ,1360 ,1443 ,1736 , + 707 ,774 ,1870 ,33 ,1110 ,490 ,1826 ,1419 ,394 ,1172 ,96 ,743 ,1169 ,227 ,822 ,247 , + 1779 ,280 ,1580 ,993 ,1933 ,1639 ,33 ,923 ,1036 ,851 ,1254 ,1576 ,1882 ,243 ,1369 ,354 , + 1038 ,1398 ,984 ,1648 ,1765 ,1932 ,1995 ,744 ,600 ,462 ,588 ,478 ,637 ,1041 ,1633 ,474 , + 740 ,1839 ,44 ,665 ,1661 ,1774 ,1306 ,912 ,552 ,689 ,828 ,926 ,1729 ,1943 ,665 ,1743 , + 1401 ,453 ,1647 ,1796 ,357 ,1109 ,1566 ,1755 ,1296 ,1266 ,313 ,1419 ,702 ,1571 ,27 ,1184 , + 1842 ,620 ,757 ,1734 ,1870 ,1124 ,1117 ,1133 ,862 ,219 ,1690 ,1316 ,1231 ,848 ,1177 ,653 , + 409 ,1256 ,1915 ,1716 ,1244 ,75 ,1514 ,1677 ,1333 ,290 ,111 ,378 ,1112 ,437 ,1848 ,754 , + 1028 ,208 ,727 ,1368 ,1453 ,759 ,1230 ,1326 ,1344 ,557 ,404 ,40 ,817 ,1531 ,1681 ,80 , + 390 ,134 ,906 ,1749 ,554 ,695 ,719 ,514 ,478 ,1593 ,1955 ,1189 ,1348 ,1494 ,1503 ,1513 , + 1579 ,1104 ,1660 ,1362 ,1985 ,1814 ,1579 ,793 ,62 ,1979 ,843 ,1868 ,514 ,1919 ,1396 ,785 , + 
1635 ,618 ,124 ,21 ,791 ,225 ,895 ,774 ,1167 ,1658 ,1421 ,1494 ,793 ,1582 ,368 ,1755 , + 16 ,1571 ,824 ,1604 ,1201 ,585 ,1867 ,514 ,314 ,1097 ,1667 ,1703 ,2047 ,851 ,539 ,139 , + 609 ,1663 ,1534 ,1846 ,701 ,54 ,1067 ,243 ,1860 ,1500 ,1179 ,1618 ,1700 ,1722 ,788 ,700 , + 684 ,1829 ,1441 ,241 ,1405 ,1410 ,1353 ,1086 ,848 ,2 ,1778 ,49 ,1431 ,1076 ,1477 ,1420 , + 1345 ,1216 ,1235 ,1453 ,978 ,737 ,1613 ,1134 ,1657 ,965 ,33 ,1694 ,1560 ,1483 ,228 ,1341 , + 275 ,660 ,266 ,956 ,1279 ,1086 ,486 ,1858 ,1690 ,500 ,696 ,237 ,1920 ,322 ,625 ,1798 , + 1384 ,1558 ,78 ,406 ,1050 ,208 ,548 ,1462 ,1016 ,19 ,179 ,1099 ,152 ,267 ,1883 ,1041 , + 1457 ,735 ,1541 ,311 ,2010 ,1566 ,1500 ,1743 ,1507 ,1045 ,37 ,1470 ,679 ,589 ,410 ,268 , + 1382 ,538 ,1300 ,218 ,1446 ,109 ,524 ,1327 ,591 ,496 ,215 ,678 ,862 ,1521 ,1394 ,1350 , + 43 ,772 ,1721 ,129 ,1756 ,1485 ,372 ,932 ,1556 ,1537 ,1598 ,243 ,409 ,1619 ,397 ,1814 , + 1401 ,1976 ,876 ,1796 ,1735 ,706 ,515 ,1530 ,138 ,988 ,760 ,1217 ,552 ,123 ,975 ,45 , + 798 ,1644 ,484 ,1166 ,576 ,1772 ,1438 ,250 ,822 ,1761 ,2006 ,351 ,1013 ,584 ,414 ,4 , + 1332 ,359 ,1311 ,48 ,192 ,1853 ,567 ,544 ,1071 ,1110 ,24 ,967 ,1366 ,1022 ,420 ,1086 , + 566 ,1598 ,1723 ,240 ,327 ,554 ,1111 ,1288 ,1827 ,784 ,1661 ,1102 ,962 ,1183 ,1218 ,1497 , + 1086 ,388 ,1906 ,523 ,400 ,641 ,1106 ,1930 ,761 ,1702 ,763 ,1054 ,510 ,570 ,1781 ,1344 , + 1154 ,1867 ,378 ,663 ,1801 ,1213 ,971 ,795 ,786 ,631 ,1488 ,248 ,271 ,198 ,852 ,1373 , + 1498 ,1935 ,1280 ,1924 ,1886 ,475 ,423 ,1300 ,1103 ,940 ,1410 ,484 ,1464 ,1732 ,1438 ,188 , + 1913 ,292 ,67 ,708 ,670 ,484 ,832 ,166 ,2036 ,1945 ,2012 ,151 ,171 ,117 ,1614 ,729 , + 1498 ,964 ,393 ,1464 ,557 ,693 ,552 ,1804 ,1161 ,2032 ,713 ,243 ,1671 ,342 ,172 ,2036 , + 1652 ,1796 ,1746 ,5 ,1270 ,542 ,112 ,1559 ,810 ,1325 ,2020 ,1177 ,624 ,983 ,1318 ,225 , + 670 ,127 ,774 ,410 ,114 ,531 ,1068 ,927 ,600 ,1738 ,1604 ,619 ,180 ,83 ,1443 ,1662 , + 542 ,980 ,641 ,385 ,919 ,712 ,357 ,672 ,886 ,1900 ,476 ,1268 ,369 ,1687 ,1603 ,807 , + 1346 ,411 ,180 ,1893 ,668 ,161 ,1692 ,394 ,1454 ,661 ,787 ,1310 ,71 ,509 ,822 ,532 , + 1886 ,158 ,1301 ,1260 ,964 ,1602 ,1195 ,1711 ,694 ,1728 ,1193 ,1989 ,1036 ,254 ,1952 ,920 , + 409 ,1366 ,1993 ,617 ,1279 ,1244 ,722 ,950 ,185 ,304 ,395 ,661 ,1976 ,1176 ,379 ,1622 , + 648 ,1873 ,1152 ,463 ,1565 ,1265 ,1626 ,984 ,1892 ,501 ,154 ,1139 ,1765 ,1775 ,1441 ,1704 , + 1750 ,1257 ,1595 ,1181 ,234 ,942 ,1209 ,512 ,1046 ,7 ,360 ,1809 ,1592 ,294 ,417 ,1428 , + 1617 ,781 ,1717 ,1929 ,1317 ,261 ,1161 ,609 ,370 ,895 ,1289 ,1908 ,970 ,577 ,800 ,1235 , + 1434 ,728 ,1681 ,166 ,131 ,1354 ,723 ,1516 ,68 ,729 ,634 ,512 ,1020 ,583 ,707 ,1306 , + 394 ,1539 ,913 ,619 ,1905 ,391 ,2034 ,1664 ,1702 ,1785 ,1858 ,291 ,1945 ,237 ,1068 ,1117 , + 13 ,808 ,1544 ,183 ,45 ,1522 ,577 ,479 ,1347 ,1477 ,1938 ,393 ,882 ,993 ,172 ,980 , + 268 ,415 ,1138 ,472 ,38 ,1213 ,328 ,509 ,1461 ,503 ,1995 ,744 ,2008 ,1707 ,204 ,633 , + 1378 ,756 ,661 ,26 ,2036 ,333 ,337 ,992 ,1942 ,527 ,1294 ,1854 ,1765 ,1291 ,1711 ,800 , + 1503 ,1752 ,1538 ,101 ,1838 ,715 ,1298 ,487 ,1641 ,1464 ,1671 ,1874 ,1499 ,1878 ,372 ,1974 , + 1139 ,1221 ,1343 ,1498 ,639 ,382 ,901 ,1531 ,258 ,1991 ,703 ,1020 ,1994 ,2018 ,1603 ,81 , + 66 ,99 ,1852 ,1012 ,327 ,566 ,1414 ,456 ,2021 ,1063 ,239 ,2045 ,42 ,445 ,1225 ,368 , + 1346 ,924 ,684 ,519 ,521 ,679 ,1700 ,2012 ,978 ,82 ,970 ,691 ,900 ,1539 ,258 ,1428 , + 1612 ,1882 ,1432 ,520 ,429 ,27 ,230 ,806 ,1062 ,696 ,1877 ,433 ,1314 ,341 ,252 ,474 , + 871 ,2043 ,689 ,465 ,361 ,1801 ,938 ,312 ,24 ,64 ,460 ,206 ,1442 ,1128 ,904 ,1107 , + 141 ,382 ,1492 ,35 ,1428 ,751 ,525 ,1694 ,1332 
,1790 ,1665 ,722 ,313 ,685 ,147 ,72 , + 1315 ,889 ,893 ,1431 ,1401 ,1260 ,18 ,1172 ,1111 ,1110 ,423 ,1855 ,575 ,1117 ,141 ,871 , + 1877 ,1754 ,1657 ,938 ,1991 ,1633 ,308 ,1444 ,1360 ,359 ,1493 ,299 ,1223 ,1888 ,1414 ,1933 , + 599 ,593 ,167 ,1352 ,1953 ,1914 ,24 ,1703 ,838 ,139 ,650 ,1239 ,63 ,203 ,1654 ,1959 , + 1486 ,980 ,837 ,1697 ,570 ,626 ,1630 ,1590 ,1247 ,1950 ,1147 ,1841 ,327 ,1203 ,1703 ,983 , + 1797 ,1367 ,1535 ,1348 ,1532 ,716 ,1002 ,1180 ,1548 ,345 ,524 ,1716 ,1256 ,1977 ,802 ,667 , + 168 ,1874 ,1288 ,1047 ,1780 ,531 ,1056 ,1849 ,711 ,2002 ,133 ,906 ,254 ,1447 ,989 ,1718 , + 541 ,1700 ,1178 ,142 ,481 ,1443 ,976 ,1744 ,1437 ,374 ,415 ,1190 ,1123 ,113 ,144 ,851 , + 1978 ,1642 ,1736 ,1402 ,1255 ,1897 ,1332 ,717 ,1423 ,973 ,47 ,1124 ,422 ,1688 ,1047 ,1562 , + 192 ,203 ,265 ,556 ,1736 ,395 ,666 ,13 ,739 ,1338 ,1704 ,515 ,86 ,1492 ,2003 ,1555 , + 1316 ,86 ,1923 ,440 ,143 ,219 ,116 ,977 ,1673 ,1913 ,1847 ,929 ,987 ,957 ,1216 ,1718 , + 1883 ,710 ,186 ,1055 ,544 ,1275 ,1707 ,236 ,1534 ,543 ,247 ,1950 ,1540 ,886 ,464 ,1666 , + 577 ,2029 ,419 ,152 ,1791 ,1154 ,859 ,677 ,1814 ,1380 ,951 ,471 ,136 ,1361 ,341 ,1464 , + 772 ,1879 ,1922 ,1826 ,832 ,1027 ,229 ,1295 ,515 ,1392 ,60 ,310 ,1205 ,55 ,1306 ,1503 , + 580 ,1322 ,934 ,555 ,1807 ,317 ,32 ,358 ,575 ,1494 ,731 ,1526 ,691 ,1637 ,1086 ,1781 , + 68 ,1661 ,1360 ,698 ,1508 ,852 ,1970 ,1429 ,1500 ,979 ,1843 ,444 ,1526 ,678 ,1390 ,1503 , + 1583 ,1693 ,605 ,1117 ,1760 ,1846 ,777 ,731 ,1808 ,1860 ,557 ,378 ,1261 ,971 ,341 ,1590 , + 1024 ,62 ,1611 ,454 ,865 ,1642 ,685 ,1175 ,1882 ,1035 ,1859 ,1589 ,250 ,1101 ,671 ,1608 , + 205 ,1812 ,1548 ,1784 ,898 ,168 ,603 ,54 ,1288 ,1957 ,1645 ,36 ,1017 ,840 ,1683 ,448 , + 1734 ,291 ,1914 ,1804 ,976 ,449 ,319 ,1940 ,2019 ,1632 ,674 ,567 ,788 ,1646 ,347 ,963 , + 1963 ,227 ,1618 ,130 ,1104 ,1888 ,1884 ,973 ,1576 ,1465 ,1066 ,741 ,884 ,837 ,1338 ,1343 , + 409 ,1662 ,1421 ,780 ,578 ,24 ,290 ,1691 ,616 ,240 ,1929 ,497 ,1391 ,1517 ,1455 ,596 , + 601 ,969 ,1713 ,1291 ,543 ,1673 ,291 ,785 ,1386 ,1707 ,520 ,1320 ,1179 ,984 ,742 ,441 , + 1507 ,1671 ,1605 ,312 ,635 ,685 ,1776 ,220 ,1528 ,378 ,744 ,969 ,1544 ,949 ,1614 ,1413 , + 1905 ,227 ,328 ,522 ,160 ,33 ,418 ,736 ,533 ,269 ,1797 ,109 ,218 ,1075 ,884 ,1468 , + 718 ,1578 ,213 ,477 ,1187 ,1871 ,399 ,927 ,639 ,1921 ,348 ,1890 ,1246 ,1229 ,501 ,709 , + 1963 ,658 ,1305 ,1398 ,602 ,163 ,1762 ,539 ,34 ,1540 ,2013 ,134 ,293 ,223 ,1317 ,1442 , + 706 ,670 ,327 ,18 ,372 ,1426 ,274 ,439 ,1371 ,308 ,1331 ,1606 ,1647 ,1656 ,1549 ,1950 , + 288 ,1033 ,1483 ,1959 ,200 ,935 ,725 ,465 ,1213 ,321 ,1786 ,1762 ,2025 ,1151 ,970 ,853 , + 231 ,1433 ,19 ,1458 ,35 ,1938 ,1677 ,738 ,859 ,1157 ,1602 ,1501 ,1172 ,1834 ,643 ,1085 , + 1376 ,1570 ,1317 ,1162 ,612 ,1275 ,637 ,302 ,156 ,1300 ,896 ,1663 ,284 ,753 ,1739 ,638 , + 1817 ,1515 ,1325 ,291 ,1642 ,1981 ,477 ,1551 ,1639 ,376 ,2040 ,1259 ,650 ,355 ,1691 ,1938 , + 530 ,1692 ,858 ,1139 ,1870 ,402 ,1928 ,804 ,1192 ,1179 ,133 ,1139 ,2047 ,357 ,127 ,310 , + 1697 ,138 ,291 ,1176 ,1595 ,1524 ,1495 ,433 ,1757 ,84 ,10 ,972 ,1556 ,962 ,279 ,1325 , + 1505 ,1308 ,1993 ,290 ,930 ,1975 ,242 ,782 ,987 ,601 ,312 ,457 ,471 ,1528 ,40 ,107 , + 802 ,936 ,597 ,1398 ,144 ,30 ,189 ,487 ,1003 ,1256 ,252 ,1286 ,934 ,1020 ,1242 ,1741 , + 506 ,1976 ,1550 ,422 ,508 ,319 ,2041 ,1126 ,2021 ,1284 ,1762 ,898 ,1948 ,1380 ,1776 ,1800 , + 1312 ,9 ,1825 ,921 ,459 ,553 ,422 ,630 ,435 ,1023 ,1024 ,520 ,1704 ,1631 ,198 ,213 , + 1852 ,177 ,1647 ,1084 ,1433 ,989 ,116 ,1704 ,1088 ,1608 ,1041 ,1820 ,228 ,1244 ,383 ,1199 , + 1046 ,494 ,1175 ,1536 ,799 ,5 ,170 ,364 ,1357 ,97 ,1394 ,2038 ,461 
,1581 ,1086 ,805 , + 1252 ,191 ,1826 ,594 ,1636 ,1189 ,674 ,295 ,1544 ,520 ,1449 ,1065 ,30 ,1402 ,509 ,619 , + 1650 ,656 ,1369 ,812 ,1380 ,39 ,1452 ,1457 ,637 ,1600 ,455 ,1931 ,1464 ,231 ,965 ,1547 , + 1627 ,1654 ,245 ,1383 ,129 ,1596 ,1918 ,1069 ,71 ,496 ,1054 ,798 ,490 ,1592 ,472 ,3 , + 1751 ,90 ,1323 ,1057 ,604 ,1644 ,271 ,507 ,926 ,723 ,314 ,1915 ,970 ,627 ,330 ,1319 , + 1389 ,934 ,1304 ,1375 ,407 ,1771 ,882 ,1555 ,1356 ,2033 ,785 ,909 ,1364 ,1939 ,1474 ,2025 , + 504 ,10 ,678 ,1891 ,1292 ,1001 ,1173 ,1117 ,1661 ,134 ,593 ,536 ,2026 ,34 ,1316 ,489 , + 277 ,1768 ,590 ,1319 ,180 ,1940 ,675 ,218 ,1832 ,457 ,203 ,444 ,1958 ,1932 ,1139 ,479 , + 1199 ,364 ,1344 ,479 ,1390 ,413 ,1074 ,41 ,32 ,1335 ,1646 ,775 ,395 ,1106 ,160 ,980 , + 398 ,1802 ,1127 ,217 ,1406 ,338 ,185 ,1683 ,1465 ,260 ,806 ,1443 ,2023 ,1278 ,1677 ,1239 , + 415 ,1425 ,382 ,1632 ,749 ,201 ,1592 ,2038 ,1296 ,1080 ,1060 ,1306 ,1208 ,307 ,192 ,1801 , + 540 ,1414 ,1010 ,984 ,1897 ,1362 ,9 ,2023 ,1814 ,376 ,477 ,903 ,571 ,1821 ,248 ,139 , + 1378 ,1603 ,1427 ,1335 ,132 ,1086 ,1838 ,1986 ,1172 ,748 ,1000 ,481 ,276 ,1827 ,1309 ,1064 , + 1507 ,904 ,1213 ,196 ,1019 ,1189 ,1619 ,574 ,1222 ,1750 ,493 ,1786 ,985 ,1866 ,276 ,1598 , + 454 ,464 ,1235 ,1452 ,196 ,1454 ,1237 ,1152 ,1463 ,1973 ,569 ,1041 ,740 ,1829 ,804 ,295 , + 1739 ,1123 ,1248 ,556 ,1777 ,1453 ,1350 ,2047 ,149 ,1211 ,575 ,410 ,152 ,1836 ,1010 ,913 , + 706 ,670 ,732 ,1385 ,1344 ,489 ,73 ,1590 ,1438 ,1663 ,1020 ,887 ,193 ,117 ,1268 ,1730 , + 186 ,1611 ,721 ,1897 ,1594 ,338 ,448 ,463 ,1083 ,1187 ,618 ,1651 ,1218 ,825 ,814 ,242 , + 1290 ,1157 ,1836 ,656 ,714 ,1525 ,829 ,946 ,1346 ,95 ,688 ,181 ,993 ,778 ,780 ,84 , + 139 ,1688 ,452 ,1383 ,1326 ,18 ,1086 ,1443 ,1761 ,1860 ,851 ,1835 ,1850 ,1094 ,317 ,595 , + 1280 ,1196 ,1490 ,237 ,1231 ,1026 ,1846 ,1817 ,347 ,1011 ,1609 ,1382 ,203 ,1724 ,965 ,1683 , + 1653 ,223 ,1726 ,1520 ,223 ,443 ,1868 ,791 ,1703 ,759 ,1755 ,1529 ,1078 ,1175 ,1150 ,856 , + 1280 ,1885 ,1288 ,55 ,1778 ,1520 ,824 ,1945 ,671 ,590 ,1720 ,93 ,1888 ,6 ,1311 ,1795 , + 1218 ,825 ,465 ,1163 ,247 ,301 ,1192 ,968 ,414 ,482 ,712 ,1799 ,744 ,793 ,1291 ,1170 , + 1228 ,1185 ,1024 ,250 ,1097 ,837 ,115 ,1178 ,1113 ,976 ,420 ,311 ,1391 ,1793 ,900 ,1848 , + 947 ,1739 ,1038 ,1005 ,1364 ,171 ,1612 ,127 ,1938 ,1891 ,682 ,993 ,196 ,290 ,330 ,294 , + 13 ,1974 ,1726 ,989 ,1745 ,1652 ,1607 ,865 ,858 ,336 ,534 ,1665 ,438 ,224 ,608 ,1591 , + 200 ,1644 ,1373 ,5 ,291 ,407 ,687 ,1849 ,1939 ,1332 ,342 ,57 ,1520 ,1820 ,2043 ,1855 , + 1331 ,1159 ,1501 ,1334 ,711 ,1926 ,850 ,1439 ,386 ,374 ,1325 ,1757 ,1912 ,1369 ,1707 ,425 , + 2040 ,473 ,1544 ,725 ,728 ,670 ,357 ,1004 ,249 ,1340 ,421 ,1033 ,1609 ,14 ,1706 ,805 , + 541 ,243 ,783 ,546 ,267 ,1030 ,976 ,2008 ,1437 ,1542 ,662 ,436 ,194 ,1430 ,644 ,851 , + 1978 ,1642 ,1923 ,1853 ,423 ,1626 ,377 ,1586 ,1423 ,591 ,47 ,1388 ,35 ,1688 ,332 ,1562 , + 666 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,1437 ,1370 ,662 ,1190 ,1908 ,1714 ,802 ,1448 , + 1978 ,1642 ,1490 ,1402 ,692 ,1067 ,1353 ,1774 ,1423 ,847 ,1538 ,1388 ,35 ,1440 ,1047 ,1752 , + 448 ,243 ,1697 ,164 ,1736 ,1030 ,666 ,1744 ,1437 ,1370 ,1165 ,1190 ,1912 ,1714 ,644 ,1448 , + 1978 ,646 ,1736 ,1428 ,1255 ,1897 ,377 ,1774 ,1423 ,847 ,1538 ,1124 ,35 ,97 ,332 ,1562 , + 192 ,243 ,1697 ,1348 ,1736 ,1030 ,1978 ,1744 ,739 ,1992 ,1165 ,439 ,1101 ,1430 ,802 ,1448 , + 32 ,646 ,1736 ,1406 ,1255 ,1067 ,1353 ,717 ,604 ,591 ,47 ,1388 ,422 ,97 ,427 ,1956 , + 675 ,153 ,1546 ,818 ,1052 ,948 ,1790 ,462 ,477 ,64 ,807 ,1863 ,1936 ,872 ,384 ,615 , + 74 ,1996 ,1935 ,1445 ,166 ,1798 ,1344 ,569 ,286 ,58 ,1716 ,506 ,357 ,13 ,381 ,974 , 
+ 780 ,1949 ,1620 ,810 ,153 ,697 ,650 ,1851 ,199 ,69 ,1434 ,1458 ,1402 ,1265 ,89 ,1720 , + 58 ,1167 ,1433 ,883 ,1086 ,1253 ,629 ,1613 ,1573 ,1653 ,178 ,19 ,713 ,1079 ,1321 ,363 , + 1315 ,1697 ,1547 ,1696 ,139 ,814 ,878 ,855 ,256 ,1826 ,948 ,1838 ,1928 ,727 ,1600 ,1022 , + 333 ,918 ,1712 ,1508 ,498 ,1577 ,877 ,1159 ,492 ,1208 ,529 ,279 ,1300 ,1796 ,287 ,1329 , + 976 ,419 ,756 ,67 ,1742 ,2029 ,449 ,1617 ,520 ,1256 ,922 ,1234 ,1490 ,476 ,1983 ,697 , + 497 ,1570 ,794 ,1888 ,1307 ,10 ,23 ,1313 ,1799 ,684 ,157 ,1036 ,1419 ,1377 ,129 ,958 , + 297 ,106 ,1944 ,500 ,1734 ,247 ,934 ,472 ,1357 ,940 ,1344 ,1016 ,1161 ,133 ,86 ,627 , + 1940 ,1460 ,1500 ,1827 ,1936 ,468 ,1340 ,538 ,909 ,1958 ,1765 ,1518 ,1405 ,250 ,1200 ,992 , + 846 ,596 ,1819 ,1450 ,2005 ,1569 ,733 ,1190 ,469 ,1992 ,1048 ,605 ,1912 ,837 ,853 ,1938 , + 1050 ,1331 ,77 ,1858 ,1169 ,511 ,1093 ,1774 ,699 ,1438 ,569 ,559 ,207 ,369 ,1783 ,1709 , + 420 ,1828 ,1206 ,1543 ,18 ,1006 ,93 ,101 ,28 ,103 ,7 ,1029 ,978 ,472 ,1353 ,2024 , + 282 ,1410 ,67 ,1973 ,1751 ,676 ,1271 ,1922 ,897 ,1130 ,704 ,941 ,1438 ,788 ,1897 ,871 , + 235 ,199 ,1592 ,1796 ,1802 ,511 ,1317 ,1832 ,754 ,1543 ,1517 ,970 ,1869 ,1570 ,1319 ,541 , + 862 ,1639 ,1973 ,442 ,333 ,1903 ,889 ,221 ,1351 ,25 ,1367 ,1020 ,1936 ,1567 ,902 ,734 , + 1382 ,364 ,1257 ,676 ,1967 ,99 ,829 ,1440 ,600 ,584 ,936 ,592 ,1304 ,1011 ,1864 ,412 , + 471 ,1517 ,958 ,650 ,300 ,1269 ,1246 ,1198 ,1451 ,1497 ,273 ,1828 ,1553 ,615 ,688 ,649 , + 1150 ,1499 ,602 ,538 ,173 ,1370 ,1054 ,322 ,1332 ,327 ,1446 ,1622 ,876 ,1780 ,1471 ,706 , + 672 ,1170 ,1150 ,1301 ,1162 ,2023 ,1810 ,1504 ,865 ,1088 ,1185 ,1900 ,719 ,821 ,418 ,630 , + 1220 ,1478 ,1902 ,940 ,139 ,546 ,642 ,400 ,998 ,272 ,614 ,1283 ,342 ,470 ,432 ,1000 , + 826 ,1772 ,1857 ,2 ,1177 ,294 ,358 ,1815 ,63 ,1130 ,830 ,538 ,2046 ,1477 ,1260 ,1725 , + 1760 ,1532 ,1379 ,1072 ,1010 ,1794 ,1324 ,1488 ,1663 ,1856 ,49 ,1084 ,83 ,1969 ,259 ,1292 , + 431 ,952 ,740 ,1700 ,234 ,827 ,722 ,1112 ,444 ,481 ,1446 ,415 ,1074 ,379 ,1992 ,388 , + 1327 ,1234 ,676 ,780 ,1538 ,1033 ,1941 ,1630 ,303 ,879 ,430 ,25 ,2037 ,1839 ,173 ,206 , + 1161 ,1346 ,793 ,1260 ,628 ,1884 ,1470 ,803 ,69 ,471 ,1431 ,1848 ,519 ,1906 ,1852 ,699 , + 1928 ,1700 ,1559 ,1562 ,340 ,1443 ,976 ,1744 ,1210 ,1542 ,662 ,1757 ,1912 ,1714 ,853 ,241 , + 1978 ,251 ,290 ,1402 ,610 ,1897 ,1497 ,1774 ,204 ,591 ,899 ,1124 ,35 ,1440 ,481 ,970 , + 666 ,243 ,783 ,142 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,143 ,1912 ,1714 ,644 ,1448 , + 32 ,1642 ,1490 ,1428 ,692 ,1067 ,1497 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,97 ,1047 ,1752 , + 384 ,243 ,1697 ,1348 ,1736 ,1030 ,1978 ,1744 ,739 ,1992 ,1165 ,1190 ,1101 ,1430 ,144 ,1448 , + 1978 ,1642 ,1490 ,1406 ,1255 ,1897 ,1332 ,1774 ,724 ,591 ,47 ,1388 ,422 ,97 ,1140 ,1956 , + 1913 ,1123 ,1568 ,48 ,1380 ,1374 ,1027 ,99 ,947 ,74 ,1780 ,874 ,1170 ,828 ,1792 ,882 , + 1431 ,557 ,1477 ,1049 ,132 ,846 ,895 ,573 ,552 ,759 ,1064 ,1987 ,1541 ,671 ,1844 ,993 , + 854 ,1855 ,54 ,1418 ,830 ,1315 ,988 ,1612 ,1923 ,713 ,1902 ,1564 ,1535 ,1076 ,1781 ,1474 , + 287 ,856 ,1841 ,1246 ,1014 ,812 ,325 ,1397 ,431 ,568 ,1932 ,1376 ,331 ,1743 ,1098 ,50 , + 974 ,1652 ,108 ,719 ,1196 ,306 ,1677 ,728 ,1498 ,832 ,773 ,1973 ,1140 ,1870 ,1939 ,1522 , + 203 ,1590 ,448 ,1014 ,1915 ,470 ,1124 ,1472 ,1369 ,1870 ,1595 ,717 ,994 ,1498 ,619 ,2024 , + 1173 ,1154 ,1076 ,331 ,950 ,1869 ,1603 ,302 ,1381 ,130 ,239 ,1174 ,1852 ,758 ,48 ,648 , + 1496 ,676 ,733 ,1431 ,1240 ,81 ,1873 ,1051 ,278 ,442 ,1282 ,1175 ,1055 ,1474 ,1548 ,62 , + 1797 ,1515 ,1535 ,1450 ,1019 ,1030 ,666 ,1744 ,1106 ,605 ,1165 ,436 ,1101 ,1430 ,144 ,851 , + 1978 ,646 
,618 ,1853 ,610 ,1953 ,377 ,717 ,1423 ,17 ,1538 ,513 ,1930 ,1640 ,1140 ,1752 , + 666 ,243 ,783 ,142 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,439 ,1101 ,650 ,644 ,1942 , + 1978 ,646 ,618 ,1406 ,692 ,1953 ,377 ,717 ,204 ,591 ,115 ,1124 ,422 ,1688 ,1140 ,1562 , + 448 ,243 ,1697 ,164 ,1736 ,1443 ,666 ,1868 ,739 ,1370 ,1165 ,143 ,1101 ,650 ,644 ,1448 , + 32 ,1829 ,290 ,1406 ,1255 ,1067 ,377 ,1774 ,204 ,591 ,1538 ,951 ,35 ,97 ,427 ,1562 , + 503 ,113 ,963 ,880 ,1070 ,187 ,93 ,344 ,500 ,514 ,1271 ,852 ,1858 ,670 ,1202 ,1010 , + 297 ,1772 ,793 ,1197 ,296 ,1477 ,569 ,999 ,1171 ,1855 ,283 ,1090 ,1479 ,1792 ,1279 ,842 , + 1554 ,643 ,66 ,53 ,1760 ,887 ,1905 ,534 ,1398 ,654 ,1564 ,246 ,1609 ,420 ,1622 ,199 , + 484 ,556 ,924 ,889 ,1432 ,1132 ,129 ,437 ,709 ,469 ,332 ,855 ,1676 ,1738 ,279 ,150 , + 207 ,1845 ,1327 ,584 ,1347 ,249 ,436 ,1111 ,219 ,309 ,267 ,50 ,1647 ,1274 ,1686 ,183 , + 484 ,368 ,103 ,448 ,857 ,955 ,499 ,41 ,1121 ,1181 ,1134 ,878 ,491 ,1619 ,1190 ,1705 , + 1078 ,1290 ,1234 ,1462 ,2 ,1088 ,2012 ,956 ,1749 ,761 ,1664 ,806 ,829 ,646 ,1745 ,1362 , + 1117 ,1545 ,1712 ,298 ,505 ,1921 ,772 ,1431 ,2016 ,1903 ,1300 ,447 ,2000 ,1869 ,358 ,1019 , + 955 ,1516 ,555 ,629 ,796 ,1931 ,1855 ,181 ,1245 ,2020 ,998 ,1157 ,727 ,390 ,263 ,1369 , + 1490 ,746 ,1830 ,1951 ,820 ,1401 ,400 ,1505 ,1715 ,1349 ,627 ,303 ,284 ,894 ,442 ,2043 , + 1789 ,545 ,496 ,1025 ,832 ,1973 ,670 ,158 ,1603 ,672 ,15 ,1183 ,1848 ,204 ,2044 ,1194 , + 604 ,498 ,1454 ,1786 ,1952 ,560 ,400 ,1355 ,641 ,778 ,631 ,1596 ,888 ,392 ,135 ,599 , + 853 ,393 ,159 ,1085 ,696 ,865 ,1492 ,1915 ,944 ,2035 ,1951 ,1775 ,1526 ,1376 ,526 ,1483 , + 579 ,754 ,590 ,1520 ,680 ,1881 ,1501 ,1111 ,1611 ,1395 ,744 ,1826 ,1229 ,1587 ,1770 ,1071 , + 1676 ,136 ,307 ,519 ,1355 ,69 ,452 ,1546 ,871 ,1396 ,433 ,292 ,1895 ,1382 ,1193 ,1004 , + 397 ,900 ,1363 ,1791 ,1384 ,508 ,1597 ,1708 ,719 ,756 ,51 ,1127 ,497 ,1124 ,1465 ,1251 , + 138 ,65 ,424 ,636 ,635 ,874 ,1818 ,1399 ,983 ,736 ,1358 ,975 ,1818 ,250 ,751 ,1094 , + 1151 ,1751 ,1157 ,1091 ,1007 ,500 ,1387 ,1331 ,322 ,1998 ,631 ,1743 ,324 ,1696 ,1591 ,985 , + 1401 ,1355 ,1669 ,677 ,1364 ,1346 ,911 ,1340 ,967 ,1979 ,1139 ,162 ,1221 ,1601 ,316 ,273 , + 1857 ,708 ,777 ,766 ,1279 ,1408 ,1787 ,1137 ,1472 ,866 ,1790 ,1608 ,618 ,461 ,1353 ,1704 , + 1308 ,1031 ,1302 ,570 ,278 ,1981 ,229 ,1520 ,1226 ,708 ,754 ,1874 ,550 ,360 ,3 ,355 , + 1304 ,1609 ,122 ,1161 ,1236 ,28 ,1732 ,1053 ,1786 ,429 ,1454 ,1439 ,1358 ,7 ,645 ,1050 , + 409 ,653 ,244 ,632 ,1699 ,1644 ,1242 ,944 ,386 ,337 ,2028 ,1731 ,1252 ,636 ,788 ,1765 , + 1844 ,1616 ,641 ,1373 ,2036 ,185 ,1832 ,1183 ,73 ,267 ,1886 ,844 ,781 ,1586 ,606 ,1871 , + 2007 ,983 ,1899 ,547 ,1073 ,1592 ,2014 ,1529 ,1031 ,1251 ,1805 ,1040 ,719 ,349 ,1079 ,1943 , + 940 ,1903 ,1028 ,615 ,446 ,1409 ,1778 ,638 ,431 ,341 ,1186 ,792 ,1585 ,1670 ,1557 ,1879 , + 998 ,839 ,1284 ,696 ,213 ,339 ,1564 ,689 ,2003 ,1299 ,685 ,1573 ,888 ,293 ,1715 ,948 , + 1378 ,903 ,1837 ,1025 ,1264 ,877 ,230 ,899 ,4 ,612 ,38 ,1579 ,1977 ,593 ,241 ,260 , + 450 ,929 ,321 ,1387 ,1427 ,360 ,1711 ,667 ,451 ,109 ,1162 ,1704 ,1874 ,1358 ,837 ,1862 , + 372 ,714 ,207 ,440 ,1535 ,591 ,1056 ,1582 ,1667 ,1354 ,1405 ,279 ,1904 ,1712 ,1382 ,540 , + 1608 ,383 ,973 ,320 ,597 ,1584 ,1419 ,1703 ,269 ,1842 ,544 ,1059 ,1090 ,1348 ,297 ,1088 , + 1676 ,305 ,1960 ,1542 ,888 ,1808 ,1830 ,760 ,1906 ,491 ,870 ,21 ,1061 ,267 ,278 ,845 , +}; + +// https://huggingface.co/spaces/sesame/csm-1b/blob/main/prompts/conversational_b.wav +const char * default_speaker_b_text = "[1]like a super Mario level. Like it's very like high detail. 
And like, once you get into the park, it just like, everything looks like a computer game and they have all these, like, you know, if, if there's like a, you know, like in a Mario game, they will have like a question block. And if you like, you know, punch it, a coin will come out. So like everyone, when they come into the park, they get like this little bracelet and then you can go punching question blocks around.";
+std::initializer_list<int> default_speaker_b_codes = {
+ 1049 ,1864 ,658 ,896 ,819 ,515 ,641 ,1248 ,53 ,278 ,1037 ,141 ,1423 ,565 ,828 ,986 ,
+ 1993 ,1692 ,170 ,1357 ,1780 ,1845 ,967 ,1253 ,1587 ,1854 ,1778 ,1165 ,58 ,575 ,1499 ,491 ,
+ 919 ,934 ,1446 ,392 ,328 ,2020 ,1418 ,1652 ,1117 ,291 ,488 ,1168 ,1989 ,931 ,894 ,140 ,
+ 1820 ,1666 ,1655 ,2038 ,1092 ,1370 ,826 ,1499 ,176 ,554 ,188 ,708 ,1548 ,224 ,437 ,1884 ,
+ 599 ,1960 ,976 ,1150 ,826 ,860 ,287 ,723 ,1818 ,533 ,1790 ,1859 ,1919 ,393 ,652 ,1792 ,
+ 783 ,539 ,252 ,1414 ,1035 ,340 ,1448 ,1506 ,194 ,132 ,109 ,1425 ,741 ,366 ,1157 ,1659 ,
+ 510 ,32 ,759 ,458 ,1226 ,359 ,484 ,540 ,68 ,592 ,975 ,789 ,905 ,1556 ,1323 ,1601 ,
+ 1079 ,1516 ,858 ,170 ,1971 ,1674 ,376 ,1190 ,1346 ,1617 ,368 ,1488 ,308 ,1712 ,423 ,1834 ,
+ 1273 ,1719 ,1334 ,1293 ,1420 ,430 ,1324 ,427 ,1081 ,927 ,1214 ,191 ,1400 ,1292 ,777 ,622 ,
+ 895 ,956 ,2012 ,1040 ,509 ,198 ,821 ,1365 ,1597 ,978 ,548 ,1608 ,1341 ,1148 ,380 ,1511 ,
+ 887 ,879 ,1368 ,305 ,161 ,1121 ,1191 ,1839 ,1986 ,507 ,1540 ,1206 ,1511 ,1948 ,1549 ,1340 ,
+ 583 ,1124 ,515 ,1691 ,1224 ,1357 ,446 ,1070 ,2011 ,1301 ,971 ,789 ,2002 ,1502 ,851 ,193 ,
+ 1295 ,1132 ,2041 ,1522 ,1753 ,869 ,588 ,555 ,1012 ,492 ,48 ,1274 ,1701 ,1733 ,1185 ,635 ,
+ 1881 ,1916 ,1964 ,1907 ,1296 ,467 ,94 ,1245 ,350 ,293 ,476 ,1537 ,689 ,2028 ,1684 ,819 ,
+ 1764 ,1684 ,2002 ,1017 ,1485 ,633 ,1064 ,626 ,1287 ,499 ,131 ,470 ,581 ,1930 ,1585 ,1957 ,
+ 1078 ,830 ,1664 ,1405 ,1471 ,1697 ,942 ,599 ,510 ,75 ,1118 ,992 ,1435 ,756 ,1021 ,1048 ,
+ 1407 ,1158 ,534 ,1168 ,1501 ,1105 ,697 ,602 ,1626 ,1479 ,1187 ,361 ,1651 ,1426 ,557 ,334 ,
+ 1157 ,76 ,877 ,1501 ,321 ,1122 ,597 ,1359 ,1507 ,1344 ,1894 ,984 ,209 ,2043 ,821 ,1230 ,
+ 1610 ,135 ,877 ,662 ,956 ,1905 ,746 ,1324 ,1610 ,1476 ,1091 ,781 ,1734 ,216 ,1595 ,1619 ,
+ 1242 ,1857 ,1331 ,122 ,1415 ,679 ,1437 ,502 ,899 ,546 ,1377 ,353 ,1835 ,1312 ,1333 ,1798 ,
+ 1141 ,795 ,288 ,2020 ,884 ,1734 ,407 ,1357 ,2035 ,1645 ,1609 ,1224 ,1651 ,2025 ,1874 ,1776 ,
+ 494 ,553 ,1859 ,297 ,1451 ,199 ,944 ,391 ,1481 ,621 ,108 ,1837 ,1079 ,845 ,1964 ,1153 ,
+ 719 ,611 ,941 ,1020 ,476 ,1582 ,1413 ,979 ,1224 ,170 ,1747 ,1550 ,530 ,80 ,1982 ,230 ,
+ 1715 ,732 ,1806 ,755 ,1844 ,114 ,476 ,247 ,1772 ,838 ,445 ,1916 ,564 ,263 ,1367 ,938 ,
+ 1914 ,1090 ,1334 ,920 ,1072 ,810 ,176 ,1539 ,1385 ,877 ,1750 ,1422 ,1431 ,1806 ,1950 ,445 ,
+ 430 ,495 ,1691 ,1634 ,1505 ,1201 ,1014 ,72 ,203 ,478 ,593 ,1895 ,657 ,1343 ,1432 ,967 ,
+ 1005 ,448 ,318 ,1583 ,376 ,1303 ,1009 ,1238 ,1130 ,447 ,1604 ,553 ,107 ,142 ,795 ,277 ,
+ 1109 ,1718 ,389 ,1012 ,1475 ,1054 ,1741 ,1366 ,1140 ,1851 ,527 ,1929 ,1186 ,1544 ,792 ,1870 ,
+ 1473 ,1745 ,1309 ,859 ,1138 ,1582 ,177 ,1518 ,260 ,1483 ,1866 ,1873 ,491 ,780 ,1015 ,1967 ,
+ 624 ,2004 ,530 ,2015 ,75 ,313 ,223 ,1627 ,1635 ,693 ,322 ,1843 ,474 ,1114 ,1613 ,1561 ,
+ 1358 ,975 ,68 ,20 ,1056 ,975 ,13 ,1095 ,1754 ,949 ,58 ,1791 ,1560 ,1116 ,668 ,1398 ,
+ 886 ,403 ,441 ,1945 ,2002 ,564 ,1671 ,591 ,1913 ,1076 ,687 ,1789 ,1235 ,684 ,1914 ,170 ,
+ 126 ,960 ,323 ,390 ,1200 ,1069 ,1710 ,169 ,421 ,1008 ,615 ,1322 ,115 ,1973 ,474 ,1099 ,
+ 712 ,1658 ,1344 ,1333 ,1850 ,745 ,1112 ,231 ,1905 ,59 ,1227 ,1834
,612 ,558 ,492 ,555 , + 1895 ,397 ,156 ,316 ,592 ,1652 ,1334 ,1538 ,1936 ,1521 ,1709 ,705 ,645 ,226 ,851 ,715 , + 63 ,272 ,749 ,282 ,908 ,1950 ,1154 ,696 ,699 ,270 ,1351 ,41 ,1934 ,1431 ,994 ,272 , + 1557 ,1168 ,373 ,386 ,340 ,1707 ,845 ,1665 ,1353 ,1416 ,1867 ,439 ,442 ,1705 ,272 ,458 , + 210 ,1419 ,258 ,786 ,469 ,507 ,78 ,753 ,604 ,531 ,902 ,1388 ,170 ,1030 ,489 ,492 , + 1743 ,477 ,1178 ,1348 ,481 ,347 ,825 ,1665 ,1353 ,91 ,1165 ,439 ,1123 ,1204 ,853 ,1942 , + 200 ,1829 ,1736 ,1668 ,692 ,1897 ,1497 ,1396 ,204 ,847 ,1538 ,483 ,1995 ,14 ,271 ,1752 , + 716 ,1056 ,1029 ,546 ,340 ,347 ,976 ,1665 ,1210 ,605 ,415 ,439 ,1908 ,1204 ,220 ,1684 , + 32 ,251 ,1736 ,1853 ,692 ,1067 ,1353 ,1774 ,724 ,847 ,569 ,1388 ,422 ,97 ,427 ,1833 , + 1850 ,1056 ,1029 ,142 ,481 ,1443 ,666 ,1665 ,1210 ,1370 ,1165 ,439 ,1123 ,113 ,144 ,851 , + 1978 ,1829 ,618 ,1428 ,1908 ,1897 ,1332 ,1586 ,1423 ,847 ,1538 ,1388 ,1995 ,14 ,427 ,1833 , + 752 ,1056 ,1029 ,142 ,1736 ,1443 ,976 ,1665 ,739 ,91 ,415 ,143 ,1908 ,113 ,220 ,1684 , + 32 ,1642 ,290 ,1853 ,1908 ,1897 ,1353 ,1774 ,724 ,847 ,569 ,483 ,1995 ,1440 ,1047 ,1833 , + 919 ,184 ,503 ,2040 ,1509 ,1253 ,1209 ,7 ,484 ,354 ,872 ,792 ,1345 ,351 ,1874 ,139 , + 894 ,1177 ,203 ,2045 ,1663 ,587 ,1735 ,1451 ,1285 ,1283 ,633 ,1487 ,395 ,1255 ,1978 ,1546 , + 854 ,737 ,2002 ,343 ,235 ,985 ,1636 ,1391 ,515 ,1192 ,1290 ,16 ,1114 ,331 ,1475 ,1679 , + 1255 ,816 ,1872 ,512 ,1931 ,1124 ,479 ,863 ,414 ,1401 ,42 ,1938 ,95 ,238 ,455 ,875 , + 979 ,1538 ,319 ,1950 ,1107 ,488 ,750 ,1691 ,1611 ,1273 ,724 ,930 ,1816 ,331 ,1081 ,796 , + 510 ,937 ,943 ,1607 ,323 ,214 ,568 ,458 ,826 ,799 ,1833 ,1843 ,1008 ,1525 ,1183 ,11 , + 946 ,836 ,1539 ,847 ,820 ,1902 ,1728 ,634 ,1150 ,644 ,1376 ,400 ,120 ,1304 ,1891 ,1963 , + 1509 ,1081 ,1361 ,1246 ,1178 ,887 ,401 ,1190 ,1471 ,358 ,206 ,960 ,1569 ,520 ,1761 ,1353 , + 561 ,817 ,274 ,1883 ,1420 ,430 ,82 ,212 ,1379 ,2009 ,1472 ,1441 ,1481 ,1222 ,1501 ,1215 , + 1298 ,11 ,1023 ,605 ,1674 ,1660 ,1519 ,584 ,1587 ,175 ,436 ,388 ,95 ,99 ,1795 ,1677 , + 655 ,543 ,1761 ,790 ,1983 ,961 ,662 ,129 ,1458 ,523 ,1838 ,599 ,1902 ,1010 ,1598 ,128 , + 877 ,527 ,1077 ,1228 ,545 ,1338 ,1980 ,792 ,530 ,987 ,1444 ,595 ,1369 ,1601 ,1425 ,496 , + 1917 ,695 ,1192 ,954 ,1419 ,118 ,567 ,1334 ,142 ,372 ,1200 ,1715 ,1607 ,606 ,1277 ,749 , + 1570 ,399 ,422 ,962 ,2009 ,772 ,46 ,1583 ,685 ,80 ,1578 ,1123 ,342 ,476 ,1491 ,993 , + 1460 ,657 ,557 ,262 ,583 ,1090 ,1418 ,355 ,1275 ,395 ,1074 ,498 ,468 ,1173 ,314 ,1411 , + 947 ,300 ,1935 ,1587 ,1608 ,207 ,725 ,333 ,587 ,1927 ,1256 ,180 ,1534 ,24 ,1904 ,385 , + 1460 ,1249 ,443 ,178 ,433 ,1132 ,382 ,990 ,866 ,1703 ,1092 ,2013 ,2021 ,480 ,727 ,694 , + 510 ,1733 ,1022 ,1706 ,1437 ,1024 ,997 ,283 ,326 ,1694 ,347 ,708 ,1428 ,83 ,1461 ,46 , + 862 ,1739 ,508 ,1976 ,1506 ,168 ,83 ,1854 ,270 ,1110 ,612 ,873 ,339 ,1211 ,1709 ,799 , + 1860 ,1158 ,1307 ,813 ,989 ,1278 ,434 ,766 ,1334 ,1506 ,1726 ,405 ,143 ,806 ,713 ,49 , + 854 ,291 ,1922 ,1982 ,328 ,151 ,724 ,774 ,1345 ,1986 ,13 ,1445 ,310 ,280 ,1123 ,1913 , + 894 ,1131 ,1241 ,386 ,361 ,1332 ,709 ,143 ,1085 ,1645 ,98 ,266 ,1406 ,305 ,1158 ,1978 , + 398 ,986 ,881 ,1523 ,1338 ,1060 ,1138 ,1748 ,1239 ,842 ,198 ,1155 ,1791 ,555 ,1746 ,1706 , + 450 ,493 ,568 ,781 ,2044 ,159 ,1487 ,1135 ,743 ,282 ,1857 ,965 ,1203 ,1452 ,838 ,604 , + 977 ,1224 ,343 ,1393 ,1770 ,430 ,1713 ,567 ,1998 ,627 ,11 ,348 ,1167 ,291 ,1577 ,337 , + 728 ,589 ,971 ,173 ,159 ,245 ,471 ,1982 ,393 ,1219 ,2039 ,421 ,215 ,1350 ,29 ,805 , + 1792 ,307 ,1530 ,303 ,898 ,560 ,463 ,488 ,375 ,217 ,1705 ,569 ,114 ,1895 ,654 ,32 , + 232 ,1259 ,492 ,1980 ,449 ,1940 
,1330 ,1462 ,1627 ,993 ,1782 ,1013 ,791 ,1734 ,1446 ,250 , + 1322 ,1058 ,1334 ,1615 ,1183 ,1850 ,1858 ,862 ,1687 ,760 ,1241 ,1520 ,779 ,1096 ,276 ,175 , + 398 ,1069 ,333 ,1857 ,646 ,1521 ,984 ,115 ,1655 ,122 ,810 ,170 ,52 ,1720 ,127 ,1062 , + 1412 ,26 ,464 ,2002 ,420 ,1097 ,1860 ,1031 ,496 ,153 ,1770 ,1178 ,542 ,65 ,1001 ,1923 , + 1434 ,365 ,443 ,1280 ,893 ,1591 ,1941 ,119 ,1775 ,1278 ,620 ,1829 ,469 ,46 ,952 ,1716 , + 1158 ,700 ,1881 ,1098 ,503 ,922 ,1472 ,1382 ,1088 ,1965 ,1127 ,1093 ,632 ,1128 ,1787 ,1267 , + 1875 ,675 ,2043 ,236 ,1433 ,543 ,1609 ,1061 ,1598 ,887 ,1212 ,425 ,393 ,1775 ,1552 ,1384 , + 1623 ,1941 ,1264 ,1223 ,2045 ,851 ,1495 ,109 ,496 ,582 ,1959 ,1460 ,355 ,1343 ,1442 ,620 , + 1125 ,1984 ,1385 ,352 ,1443 ,1030 ,11 ,454 ,135 ,309 ,1085 ,1259 ,1118 ,1159 ,28 ,646 , + 336 ,1465 ,659 ,1321 ,726 ,851 ,422 ,1226 ,633 ,127 ,1748 ,1704 ,169 ,1980 ,1631 ,1889 , + 990 ,1867 ,793 ,1803 ,1251 ,449 ,1470 ,336 ,513 ,1066 ,438 ,1429 ,211 ,941 ,1154 ,1075 , + 847 ,662 ,465 ,1643 ,1181 ,1478 ,204 ,773 ,113 ,179 ,1914 ,274 ,486 ,843 ,636 ,260 , + 995 ,1464 ,646 ,1953 ,162 ,1003 ,244 ,1952 ,610 ,1813 ,228 ,135 ,889 ,1136 ,1198 ,868 , + 1938 ,1838 ,622 ,458 ,64 ,440 ,1309 ,276 ,1331 ,580 ,1839 ,217 ,70 ,994 ,340 ,2020 , + 1772 ,1761 ,759 ,1712 ,1220 ,1303 ,1876 ,1916 ,589 ,65 ,1054 ,1661 ,266 ,1997 ,1949 ,2017 , + 1938 ,135 ,784 ,1694 ,296 ,358 ,1310 ,1148 ,253 ,1958 ,1779 ,1663 ,849 ,349 ,871 ,317 , + 484 ,214 ,1226 ,203 ,2006 ,955 ,824 ,867 ,1130 ,1243 ,1537 ,1619 ,1905 ,107 ,1867 ,1163 , + 1590 ,1944 ,105 ,1774 ,1696 ,440 ,1952 ,15 ,734 ,700 ,1695 ,982 ,341 ,1712 ,1258 ,1596 , + 1538 ,1276 ,1794 ,1785 ,1192 ,1235 ,1433 ,1241 ,1526 ,496 ,821 ,878 ,551 ,1699 ,1912 ,648 , + 1718 ,1735 ,1446 ,1419 ,700 ,33 ,3 ,404 ,516 ,687 ,1118 ,756 ,741 ,1347 ,1600 ,892 , + 1661 ,877 ,1683 ,673 ,1632 ,489 ,1284 ,1281 ,656 ,263 ,1123 ,1413 ,1219 ,372 ,706 ,1413 , + 860 ,1768 ,307 ,1709 ,1239 ,1234 ,90 ,944 ,432 ,860 ,1220 ,1637 ,1939 ,131 ,1606 ,648 , + 1805 ,755 ,1291 ,1512 ,509 ,1472 ,699 ,616 ,977 ,51 ,1311 ,284 ,1782 ,1018 ,519 ,205 , + 1537 ,2031 ,0 ,966 ,251 ,1340 ,1254 ,1458 ,781 ,927 ,541 ,1896 ,526 ,736 ,513 ,446 , + 1141 ,921 ,819 ,1654 ,546 ,924 ,1307 ,237 ,4 ,1446 ,847 ,1974 ,1947 ,619 ,1245 ,1623 , + 2042 ,477 ,1697 ,142 ,340 ,478 ,2015 ,1665 ,1866 ,1370 ,2040 ,1107 ,1783 ,113 ,1821 ,1189 , + 937 ,1796 ,77 ,1047 ,1218 ,154 ,1453 ,1480 ,765 ,777 ,569 ,1388 ,285 ,1218 ,1249 ,764 , + 666 ,316 ,1559 ,164 ,267 ,1443 ,1978 ,1665 ,1210 ,91 ,1165 ,1190 ,1123 ,1430 ,144 ,1684 , + 32 ,1829 ,1736 ,1402 ,692 ,1953 ,1332 ,1774 ,1423 ,847 ,1538 ,951 ,35 ,1440 ,427 ,1833 , + 835 ,366 ,214 ,2002 ,1383 ,1571 ,280 ,1534 ,1539 ,1058 ,1871 ,1836 ,1242 ,790 ,1923 ,339 , + 312 ,394 ,1304 ,1289 ,1817 ,900 ,1585 ,1400 ,941 ,625 ,393 ,1645 ,346 ,264 ,1353 ,830 , + 685 ,1178 ,1529 ,1623 ,1045 ,926 ,1688 ,184 ,1558 ,1366 ,259 ,631 ,489 ,994 ,263 ,1857 , + 1494 ,488 ,1453 ,844 ,1511 ,636 ,1308 ,271 ,1436 ,751 ,1131 ,1813 ,1281 ,2020 ,66 ,48 , + 873 ,472 ,353 ,253 ,1915 ,1554 ,401 ,1076 ,1954 ,932 ,293 ,1251 ,1266 ,264 ,1875 ,265 , + 1674 ,506 ,1291 ,1275 ,1690 ,1038 ,826 ,1390 ,1235 ,702 ,839 ,1080 ,816 ,868 ,340 ,1975 , + 993 ,1607 ,1046 ,1902 ,1238 ,1756 ,971 ,2010 ,920 ,1137 ,39 ,694 ,1903 ,1056 ,125 ,984 , + 144 ,883 ,1899 ,1025 ,1540 ,1356 ,866 ,1427 ,799 ,1072 ,488 ,1546 ,1194 ,1156 ,1587 ,1008 , + 2000 ,617 ,1212 ,1471 ,1906 ,1237 ,1196 ,516 ,1733 ,849 ,1467 ,1451 ,556 ,1379 ,1229 ,1150 , + 602 ,1431 ,23 ,979 ,1702 ,291 ,1095 ,1549 ,1402 ,1153 ,786 ,1093 ,469 ,967 ,1758 ,720 , + 1039 ,156 ,659 ,1817 ,381 ,1197 
,1046 ,981 ,1770 ,1769 ,1017 ,286 ,304 ,644 ,418 ,44 , + 704 ,1696 ,178 ,1832 ,1786 ,989 ,359 ,638 ,1067 ,501 ,229 ,909 ,1869 ,1200 ,1614 ,176 , + 1320 ,1168 ,1315 ,1620 ,1175 ,1273 ,1254 ,594 ,1599 ,774 ,1246 ,350 ,582 ,1977 ,369 ,447 , + 1184 ,97 ,1012 ,5 ,1764 ,1456 ,526 ,179 ,1816 ,1218 ,2020 ,761 ,171 ,834 ,17 ,324 , + 1331 ,477 ,1178 ,691 ,481 ,1273 ,183 ,1744 ,1106 ,1601 ,415 ,436 ,1912 ,1024 ,144 ,851 , + 210 ,1642 ,618 ,1428 ,692 ,1953 ,1332 ,1774 ,1423 ,1591 ,569 ,483 ,1609 ,1658 ,1047 ,1086 , + 1928 ,1056 ,1029 ,142 ,1736 ,1443 ,976 ,2008 ,739 ,91 ,1028 ,143 ,194 ,1112 ,853 ,851 , + 571 ,473 ,290 ,1406 ,1255 ,1623 ,377 ,1774 ,724 ,847 ,1887 ,1388 ,254 ,436 ,332 ,1780 , + 1771 ,1056 ,480 ,546 ,581 ,11 ,326 ,101 ,1743 ,1521 ,1791 ,229 ,1485 ,650 ,853 ,1860 , + 1531 ,1829 ,618 ,1428 ,1003 ,251 ,377 ,717 ,604 ,847 ,883 ,760 ,1930 ,1924 ,905 ,1833 , + 657 ,1549 ,627 ,1264 ,503 ,30 ,397 ,1671 ,228 ,1160 ,1115 ,1471 ,618 ,361 ,2044 ,631 , + 1652 ,1735 ,679 ,1953 ,251 ,253 ,867 ,83 ,273 ,1876 ,335 ,1335 ,680 ,614 ,1585 ,1875 , + 919 ,296 ,1277 ,753 ,617 ,164 ,1213 ,905 ,590 ,628 ,548 ,45 ,1132 ,1585 ,1350 ,1160 , + 84 ,1724 ,1223 ,851 ,704 ,1060 ,87 ,35 ,924 ,1554 ,1258 ,278 ,992 ,1240 ,158 ,1821 , + 1187 ,1626 ,1103 ,1473 ,1489 ,1600 ,1504 ,1011 ,72 ,296 ,908 ,1271 ,707 ,1774 ,1755 ,932 , + 1694 ,1513 ,1829 ,48 ,989 ,835 ,640 ,1925 ,121 ,1905 ,1320 ,1158 ,577 ,1060 ,441 ,258 , + 392 ,1955 ,497 ,727 ,1216 ,629 ,1139 ,1171 ,400 ,2013 ,1689 ,983 ,1027 ,273 ,1189 ,852 , + 1763 ,550 ,1405 ,1787 ,594 ,1323 ,1121 ,1190 ,1825 ,1190 ,1806 ,1082 ,2004 ,813 ,1768 ,1591 , + 1941 ,496 ,1274 ,1821 ,494 ,127 ,1304 ,1244 ,1113 ,1283 ,1135 ,1932 ,683 ,906 ,1671 ,767 , + 502 ,2041 ,450 ,977 ,1772 ,929 ,1747 ,440 ,669 ,1581 ,1677 ,1877 ,341 ,1730 ,842 ,975 , + 713 ,1145 ,1487 ,1875 ,689 ,549 ,1182 ,17 ,1744 ,499 ,453 ,789 ,573 ,1867 ,1728 ,575 , + 1818 ,1538 ,40 ,1288 ,1011 ,1061 ,685 ,241 ,1589 ,1928 ,701 ,1211 ,835 ,409 ,1109 ,755 , + 783 ,1135 ,1485 ,776 ,1840 ,908 ,202 ,1364 ,653 ,1040 ,451 ,1834 ,867 ,331 ,613 ,172 , + 350 ,1076 ,1647 ,876 ,644 ,1968 ,1863 ,1262 ,1080 ,1101 ,1435 ,166 ,467 ,1620 ,2010 ,972 , + 368 ,777 ,911 ,1293 ,678 ,498 ,1412 ,132 ,1032 ,216 ,672 ,1446 ,79 ,1930 ,1062 ,1110 , + 1179 ,1822 ,108 ,224 ,682 ,433 ,1729 ,1062 ,1469 ,1263 ,824 ,1673 ,971 ,1183 ,1224 ,727 , + 70 ,657 ,1030 ,507 ,511 ,1410 ,1775 ,1100 ,518 ,544 ,957 ,565 ,464 ,973 ,1127 ,1806 , + 1005 ,213 ,437 ,475 ,1248 ,1158 ,600 ,121 ,1290 ,915 ,685 ,265 ,1976 ,1376 ,721 ,551 , + 1164 ,676 ,1046 ,1832 ,33 ,327 ,567 ,1754 ,1073 ,1815 ,326 ,1717 ,1883 ,1034 ,1759 ,1191 , + 209 ,1788 ,1255 ,647 ,1577 ,548 ,569 ,189 ,1997 ,159 ,672 ,865 ,1014 ,1855 ,813 ,1390 , + 1689 ,1383 ,1248 ,1207 ,869 ,985 ,1749 ,1762 ,906 ,32 ,261 ,869 ,1347 ,313 ,1438 ,754 , + 552 ,753 ,497 ,1655 ,1269 ,693 ,1714 ,1250 ,689 ,1924 ,263 ,641 ,1346 ,1632 ,1721 ,1724 , + 577 ,1783 ,498 ,1709 ,985 ,182 ,574 ,1481 ,952 ,497 ,1769 ,252 ,1188 ,1848 ,1447 ,868 , + 607 ,8 ,1975 ,1471 ,487 ,1466 ,369 ,1365 ,1273 ,462 ,776 ,318 ,1208 ,868 ,1416 ,14 , + 577 ,397 ,1819 ,959 ,823 ,1867 ,837 ,1893 ,2044 ,345 ,427 ,1040 ,1428 ,1745 ,513 ,1937 , + 200 ,1881 ,890 ,768 ,334 ,507 ,655 ,1918 ,972 ,667 ,1041 ,1954 ,1312 ,1881 ,1440 ,523 , + 754 ,363 ,328 ,1969 ,1797 ,1033 ,904 ,582 ,904 ,2023 ,118 ,295 ,1495 ,1977 ,1014 ,411 , + 1993 ,1106 ,77 ,493 ,1513 ,843 ,526 ,509 ,844 ,164 ,1841 ,69 ,58 ,1585 ,850 ,94 , + 206 ,1821 ,981 ,487 ,19 ,1744 ,352 ,187 ,123 ,511 ,250 ,462 ,1629 ,282 ,1421 ,1084 , + 121 ,873 ,739 ,68 ,691 ,378 ,221 ,1754 ,1557 ,1041 ,1120 ,228 ,295 ,1725 
,591 ,2008 , + 1235 ,737 ,315 ,153 ,1729 ,1381 ,42 ,1968 ,298 ,864 ,194 ,78 ,1603 ,1464 ,1140 ,588 , + 166 ,31 ,264 ,347 ,395 ,1619 ,1417 ,1064 ,8 ,489 ,1255 ,307 ,567 ,1222 ,752 ,1739 , + 1300 ,1600 ,981 ,571 ,2026 ,1420 ,1439 ,408 ,412 ,1279 ,435 ,942 ,512 ,1304 ,1312 ,406 , + 2019 ,1914 ,1455 ,1460 ,1930 ,1271 ,1926 ,215 ,0 ,608 ,1880 ,1226 ,1556 ,142 ,808 ,500 , + 1479 ,482 ,2 ,1399 ,842 ,233 ,1564 ,698 ,206 ,1930 ,1768 ,1349 ,1740 ,320 ,1651 ,182 , + 1967 ,871 ,249 ,676 ,1026 ,770 ,1500 ,1046 ,1695 ,614 ,1829 ,341 ,1564 ,1399 ,1138 ,1142 , + 926 ,198 ,149 ,435 ,402 ,682 ,1622 ,1015 ,2018 ,1681 ,616 ,642 ,1330 ,1198 ,1745 ,810 , + 33 ,600 ,1430 ,264 ,611 ,797 ,120 ,1311 ,573 ,1344 ,1196 ,1083 ,2046 ,1655 ,229 ,1635 , + 1143 ,1014 ,1037 ,465 ,1336 ,951 ,20 ,441 ,1892 ,587 ,744 ,290 ,1978 ,499 ,1987 ,526 , + 169 ,461 ,1471 ,22 ,632 ,819 ,346 ,2040 ,422 ,29 ,933 ,1191 ,972 ,223 ,1568 ,875 , + 1143 ,1223 ,1883 ,421 ,1235 ,1470 ,798 ,118 ,48 ,557 ,256 ,737 ,1205 ,377 ,690 ,494 , + 1928 ,1848 ,1854 ,1447 ,759 ,718 ,1781 ,1243 ,1759 ,1415 ,1899 ,73 ,836 ,648 ,729 ,264 , + 1066 ,1706 ,1672 ,936 ,403 ,335 ,135 ,1077 ,440 ,1681 ,596 ,1565 ,2002 ,191 ,878 ,1212 , + 69 ,886 ,137 ,1495 ,1684 ,761 ,478 ,1271 ,1 ,326 ,119 ,1746 ,1095 ,484 ,1178 ,1786 , + 272 ,184 ,941 ,493 ,820 ,351 ,654 ,435 ,1445 ,1175 ,195 ,330 ,1272 ,1401 ,1330 ,672 , + 562 ,1482 ,7 ,1968 ,1903 ,1927 ,1606 ,759 ,1425 ,537 ,1735 ,116 ,1674 ,1446 ,1769 ,1492 , + 366 ,820 ,593 ,715 ,393 ,1499 ,927 ,963 ,1424 ,1416 ,1768 ,788 ,1900 ,1721 ,1760 ,1036 , + 775 ,317 ,1062 ,35 ,124 ,1938 ,971 ,859 ,656 ,2028 ,970 ,644 ,2039 ,1529 ,1290 ,1584 , + 1056 ,1500 ,1234 ,467 ,1116 ,671 ,1481 ,372 ,1672 ,2046 ,2039 ,1188 ,169 ,1947 ,9 ,1840 , + 503 ,642 ,919 ,687 ,1952 ,530 ,445 ,1364 ,253 ,1127 ,790 ,1756 ,1422 ,1557 ,716 ,686 , + 114 ,429 ,1902 ,712 ,1921 ,1534 ,273 ,925 ,2002 ,709 ,191 ,255 ,410 ,613 ,1554 ,1798 , + 1665 ,645 ,887 ,218 ,1892 ,16 ,1665 ,447 ,1725 ,475 ,803 ,1274 ,581 ,32 ,398 ,1068 , + 929 ,1461 ,408 ,458 ,406 ,172 ,1420 ,305 ,1 ,884 ,1399 ,1544 ,221 ,557 ,540 ,1055 , + 1351 ,1706 ,1174 ,313 ,367 ,341 ,1955 ,164 ,87 ,1429 ,1434 ,1156 ,1117 ,393 ,629 ,679 , + 1894 ,2026 ,1269 ,1420 ,137 ,24 ,494 ,1970 ,404 ,525 ,862 ,598 ,1615 ,1172 ,318 ,159 , + 955 ,127 ,149 ,1383 ,1806 ,371 ,1490 ,1986 ,644 ,158 ,882 ,1308 ,1547 ,1183 ,957 ,595 , + 1020 ,869 ,485 ,2021 ,1517 ,265 ,168 ,1967 ,892 ,1741 ,1881 ,998 ,1362 ,1271 ,149 ,986 , + 1966 ,1204 ,1304 ,1188 ,473 ,1703 ,1761 ,798 ,1332 ,574 ,631 ,1593 ,19 ,1245 ,1097 ,1977 , + 2042 ,1625 ,1029 ,97 ,1542 ,1335 ,976 ,972 ,1972 ,1187 ,1176 ,760 ,108 ,1722 ,1653 ,1218 , + 32 ,1410 ,581 ,82 ,1908 ,866 ,526 ,601 ,73 ,503 ,1984 ,895 ,987 ,907 ,1706 ,1517 , + 1327 ,1370 ,953 ,870 ,670 ,1858 ,1608 ,890 ,1794 ,1955 ,1082 ,89 ,1362 ,1600 ,708 ,369 , + 1685 ,201 ,438 ,1263 ,1764 ,524 ,1504 ,1022 ,805 ,1919 ,747 ,76 ,1031 ,1820 ,482 ,1877 , + 857 ,1744 ,666 ,1450 ,475 ,701 ,666 ,1190 ,1938 ,1187 ,112 ,1912 ,1123 ,1430 ,1821 ,750 , + 748 ,153 ,1219 ,973 ,407 ,348 ,2024 ,963 ,482 ,2002 ,1538 ,1954 ,389 ,14 ,1496 ,1833 , + 716 ,316 ,1559 ,290 ,1736 ,555 ,819 ,1665 ,739 ,91 ,1165 ,36 ,812 ,1430 ,220 ,1448 , + 200 ,646 ,1490 ,1402 ,1908 ,1897 ,1353 ,643 ,1423 ,1291 ,1263 ,951 ,35 ,1440 ,1140 ,1256 , + 84 ,1056 ,783 ,290 ,267 ,1443 ,976 ,1744 ,1210 ,1992 ,1165 ,439 ,1101 ,1430 ,644 ,1684 , + 32 ,646 ,1736 ,1406 ,1908 ,1626 ,377 ,717 ,204 ,973 ,1538 ,951 ,35 ,97 ,1047 ,1752 , + 84 ,243 ,1697 ,290 ,1736 ,1572 ,976 ,1648 ,1210 ,91 ,415 ,143 ,812 ,650 ,144 ,851 , + 1978 ,646 ,290 ,1406 ,1255 ,1067 ,1497 ,1586 ,724 
,973 ,569 ,951 ,35 ,1440 ,332 ,1562 , + 752 ,243 ,1697 ,1348 ,1335 ,1572 ,976 ,2008 ,739 ,91 ,415 ,436 ,1912 ,1430 ,144 ,241 , + 32 ,251 ,1736 ,1402 ,1908 ,1953 ,1332 ,1774 ,724 ,973 ,569 ,951 ,35 ,1688 ,427 ,1752 , + 322 ,522 ,844 ,1274 ,464 ,1280 ,806 ,460 ,1961 ,1242 ,898 ,916 ,1536 ,134 ,910 ,753 , + 1422 ,1795 ,1809 ,1175 ,1095 ,104 ,1746 ,1602 ,402 ,1328 ,938 ,280 ,1414 ,1062 ,33 ,97 , + 801 ,368 ,33 ,634 ,1692 ,1922 ,289 ,1915 ,1547 ,1629 ,440 ,176 ,1011 ,705 ,189 ,45 , + 151 ,191 ,630 ,34 ,1732 ,577 ,115 ,2021 ,606 ,1286 ,522 ,5 ,1005 ,1013 ,204 ,615 , + 856 ,383 ,1749 ,805 ,1778 ,1094 ,603 ,53 ,1218 ,1411 ,159 ,162 ,667 ,1345 ,586 ,837 , + 1933 ,889 ,122 ,1210 ,1296 ,1486 ,1648 ,496 ,1842 ,856 ,1531 ,1706 ,74 ,1588 ,1036 ,191 , + 256 ,1939 ,1544 ,1746 ,194 ,2008 ,1505 ,1427 ,1281 ,414 ,1276 ,1487 ,1303 ,1633 ,471 ,428 , + 254 ,1187 ,1530 ,516 ,1063 ,771 ,1435 ,1156 ,303 ,1935 ,1864 ,237 ,1336 ,1600 ,238 ,1026 , + 1599 ,1526 ,726 ,121 ,1020 ,507 ,87 ,1235 ,1494 ,1182 ,901 ,479 ,2024 ,33 ,678 ,1457 , + 807 ,714 ,325 ,1238 ,704 ,876 ,1656 ,1440 ,904 ,1791 ,46 ,329 ,1390 ,1036 ,1995 ,342 , + 390 ,739 ,1428 ,1364 ,463 ,776 ,83 ,1766 ,357 ,834 ,2019 ,1970 ,944 ,318 ,895 ,457 , + 1538 ,470 ,1656 ,1356 ,345 ,953 ,494 ,1380 ,299 ,1682 ,1733 ,533 ,1239 ,1265 ,363 ,1974 , + 599 ,1552 ,713 ,343 ,1835 ,1425 ,35 ,574 ,1203 ,904 ,703 ,1776 ,1683 ,1907 ,1567 ,418 , + 1103 ,1541 ,456 ,1938 ,1675 ,1651 ,1119 ,271 ,61 ,1905 ,519 ,278 ,1462 ,975 ,350 ,1074 , + 1508 ,99 ,784 ,1958 ,2007 ,1704 ,1375 ,1831 ,1215 ,1608 ,1960 ,1496 ,1274 ,1594 ,1333 ,1291 , + 617 ,257 ,215 ,2 ,466 ,1166 ,544 ,1430 ,1859 ,1903 ,1361 ,1428 ,1872 ,1379 ,1949 ,503 , + 552 ,1657 ,1384 ,1592 ,1643 ,678 ,1017 ,1480 ,1148 ,294 ,1149 ,1971 ,475 ,1317 ,110 ,501 , + 1905 ,1093 ,1599 ,1848 ,1489 ,1260 ,1475 ,65 ,1849 ,1687 ,1068 ,704 ,1385 ,351 ,1616 ,269 , + 1600 ,1081 ,1602 ,1222 ,215 ,477 ,35 ,135 ,612 ,1462 ,1754 ,1493 ,853 ,386 ,347 ,61 , + 1112 ,1289 ,926 ,1011 ,1468 ,1408 ,1105 ,1390 ,779 ,1289 ,455 ,424 ,11 ,1298 ,1818 ,1097 , + 2001 ,1802 ,939 ,1850 ,1629 ,911 ,76 ,801 ,691 ,655 ,255 ,264 ,1020 ,86 ,1329 ,1223 , + 1733 ,1292 ,1193 ,743 ,820 ,373 ,830 ,1956 ,1894 ,113 ,840 ,728 ,705 ,1981 ,974 ,457 , + 1029 ,472 ,202 ,1009 ,852 ,1391 ,1377 ,1058 ,655 ,1089 ,1333 ,459 ,1098 ,1996 ,37 ,594 , + 849 ,466 ,26 ,1746 ,187 ,1112 ,325 ,1888 ,1560 ,742 ,139 ,1373 ,893 ,547 ,1440 ,333 , + 877 ,1624 ,1606 ,1965 ,1572 ,477 ,456 ,226 ,1042 ,956 ,711 ,1341 ,497 ,167 ,215 ,1995 , + 1350 ,1975 ,1810 ,558 ,1897 ,3 ,602 ,2035 ,1294 ,755 ,1196 ,869 ,472 ,614 ,403 ,1756 , + 1015 ,1244 ,1583 ,1336 ,1708 ,1399 ,914 ,782 ,1152 ,18 ,1895 ,1869 ,1389 ,1605 ,1618 ,1973 , + 156 ,1068 ,2016 ,828 ,1285 ,1970 ,1503 ,561 ,1506 ,501 ,1684 ,581 ,759 ,394 ,2002 ,989 , + 984 ,1442 ,1529 ,1944 ,589 ,601 ,2015 ,1840 ,1051 ,568 ,1965 ,1633 ,2006 ,338 ,530 ,1694 , + 144 ,127 ,1893 ,1003 ,276 ,990 ,2033 ,2045 ,1635 ,1099 ,694 ,246 ,1434 ,606 ,736 ,2047 , + 1428 ,1599 ,615 ,1294 ,281 ,1894 ,1639 ,409 ,443 ,218 ,2046 ,679 ,1673 ,1274 ,139 ,1986 , + 1968 ,1649 ,1542 ,354 ,66 ,584 ,645 ,1558 ,1116 ,775 ,680 ,1557 ,1254 ,256 ,1037 ,961 , + 854 ,1841 ,643 ,1874 ,1897 ,1363 ,687 ,1747 ,1460 ,329 ,595 ,1371 ,880 ,55 ,1889 ,1184 , + 1120 ,1808 ,1700 ,1396 ,1750 ,1208 ,1416 ,204 ,1900 ,426 ,1785 ,770 ,1052 ,173 ,1256 ,2030 , + 527 ,963 ,1273 ,499 ,1983 ,844 ,667 ,1127 ,1079 ,168 ,726 ,1487 ,1772 ,91 ,1571 ,1453 , + 1691 ,1926 ,1561 ,895 ,1869 ,809 ,1782 ,770 ,1265 ,820 ,889 ,755 ,833 ,901 ,1494 ,985 , + 1108 ,286 ,1212 ,1034 ,1837 ,1335 ,410 ,1602 ,1770 ,122 ,1422 ,240 ,1875 
,1600 ,1121 ,583 , + 959 ,510 ,715 ,150 ,1951 ,918 ,1357 ,1574 ,273 ,130 ,1886 ,732 ,1521 ,883 ,1275 ,643 , + 1327 ,725 ,710 ,744 ,1994 ,1773 ,540 ,172 ,156 ,1534 ,1406 ,183 ,1352 ,665 ,131 ,641 , + 769 ,392 ,388 ,1626 ,425 ,652 ,1105 ,582 ,1339 ,32 ,979 ,2004 ,1744 ,1054 ,1761 ,710 , + 1640 ,669 ,1121 ,1287 ,1837 ,317 ,2041 ,1744 ,1186 ,605 ,1223 ,1107 ,812 ,1722 ,1844 ,1785 , + 1385 ,1670 ,75 ,973 ,555 ,1067 ,526 ,659 ,1415 ,49 ,794 ,483 ,285 ,2015 ,989 ,1199 , + 666 ,1056 ,1029 ,546 ,267 ,1030 ,825 ,2008 ,1437 ,91 ,662 ,436 ,1912 ,1430 ,644 ,241 , + 1978 ,251 ,1736 ,1406 ,610 ,1626 ,1353 ,1586 ,1423 ,847 ,1538 ,951 ,35 ,1440 ,332 ,1562 , + 1140 ,255 ,376 ,596 ,278 ,529 ,945 ,1142 ,637 ,950 ,1461 ,741 ,495 ,1965 ,128 ,1190 , + 531 ,529 ,712 ,152 ,877 ,1056 ,500 ,501 ,473 ,1963 ,1910 ,601 ,1616 ,1229 ,130 ,438 , + 510 ,54 ,402 ,1289 ,1696 ,823 ,894 ,275 ,1195 ,1943 ,1220 ,904 ,1933 ,1290 ,876 ,1556 , + 1985 ,2016 ,1176 ,877 ,23 ,1774 ,1257 ,536 ,118 ,1615 ,297 ,1890 ,83 ,888 ,513 ,1894 , + 1697 ,1582 ,375 ,1654 ,1363 ,821 ,1647 ,309 ,1362 ,1927 ,1741 ,1777 ,902 ,2036 ,846 ,1867 , + 1862 ,498 ,1431 ,1028 ,1612 ,1629 ,1405 ,1125 ,874 ,265 ,139 ,314 ,1693 ,1313 ,359 ,682 , + 985 ,1425 ,1278 ,1822 ,388 ,430 ,943 ,1064 ,1670 ,199 ,1247 ,461 ,865 ,366 ,439 ,1024 , + 1117 ,510 ,950 ,1333 ,1241 ,1902 ,1290 ,1412 ,1145 ,1166 ,2012 ,1839 ,1149 ,1100 ,51 ,1877 , + 1540 ,665 ,1220 ,853 ,8 ,1143 ,1868 ,1553 ,1724 ,1781 ,1982 ,1104 ,495 ,832 ,1510 ,1145 , + 1229 ,1549 ,730 ,1438 ,1453 ,1734 ,841 ,1378 ,203 ,1720 ,286 ,1204 ,1680 ,1266 ,584 ,994 , + 1495 ,580 ,228 ,1984 ,1947 ,813 ,1795 ,1936 ,1201 ,1106 ,1762 ,524 ,883 ,1179 ,1223 ,1657 , + 1018 ,404 ,1619 ,937 ,394 ,1267 ,1759 ,679 ,1997 ,279 ,456 ,609 ,1895 ,157 ,267 ,1354 , + 1373 ,903 ,1399 ,1966 ,298 ,256 ,1773 ,1467 ,1485 ,1352 ,1882 ,1140 ,523 ,1827 ,504 ,878 , + 1467 ,567 ,2028 ,96 ,830 ,1679 ,144 ,281 ,1720 ,195 ,1886 ,391 ,136 ,1216 ,1542 ,1786 , + 1929 ,124 ,1766 ,1623 ,765 ,671 ,372 ,1304 ,549 ,626 ,1534 ,1384 ,42 ,1656 ,1714 ,1171 , + 684 ,1422 ,209 ,1767 ,1862 ,684 ,1989 ,961 ,993 ,1869 ,728 ,873 ,1413 ,1502 ,1545 ,581 , + 1575 ,1905 ,1329 ,1381 ,290 ,1305 ,1236 ,735 ,312 ,1128 ,1058 ,1435 ,789 ,137 ,444 ,1444 , + 1729 ,64 ,1185 ,1745 ,355 ,1057 ,44 ,2025 ,814 ,19 ,1118 ,141 ,892 ,874 ,1391 ,422 , + 535 ,1632 ,497 ,1070 ,1403 ,1548 ,42 ,250 ,635 ,956 ,43 ,113 ,334 ,332 ,1949 ,59 , + 1678 ,1916 ,1535 ,108 ,373 ,1659 ,608 ,1908 ,934 ,1400 ,561 ,21 ,958 ,1138 ,1720 ,763 , + 1120 ,1383 ,728 ,1110 ,1044 ,1330 ,1646 ,1172 ,1765 ,1223 ,626 ,1094 ,1195 ,1731 ,1512 ,1093 , + 1196 ,173 ,447 ,271 ,1433 ,92 ,1976 ,1907 ,1157 ,151 ,479 ,1936 ,1960 ,1643 ,1698 ,963 , + 1307 ,397 ,886 ,1287 ,1168 ,906 ,837 ,1084 ,1858 ,671 ,1867 ,889 ,681 ,1094 ,1821 ,1196 , + 1999 ,562 ,1012 ,45 ,281 ,111 ,1093 ,582 ,306 ,253 ,1272 ,107 ,1745 ,1511 ,1194 ,1211 , + 1557 ,1520 ,919 ,309 ,1168 ,680 ,2041 ,156 ,739 ,671 ,1704 ,253 ,1101 ,1024 ,1630 ,851 , + 445 ,1642 ,766 ,1402 ,1908 ,1491 ,1453 ,717 ,2044 ,1772 ,47 ,597 ,285 ,97 ,489 ,1562 , + 222 ,1168 ,1456 ,1034 ,1873 ,347 ,712 ,1648 ,1866 ,345 ,1458 ,240 ,1908 ,650 ,144 ,241 , + 306 ,1821 ,655 ,1406 ,811 ,39 ,1154 ,1774 ,204 ,973 ,1016 ,929 ,987 ,97 ,1477 ,492 , + 1737 ,901 ,1531 ,691 ,1019 ,964 ,1359 ,549 ,1407 ,605 ,516 ,240 ,1824 ,1527 ,496 ,44 , + 1663 ,1829 ,1928 ,1148 ,11 ,348 ,101 ,804 ,1088 ,459 ,2034 ,1680 ,239 ,1864 ,171 ,1082 , + 1737 ,222 ,573 ,421 ,490 ,91 ,183 ,382 ,1423 ,1868 ,680 ,497 ,1867 ,1872 ,792 ,1698 , + 968 ,1043 ,1096 ,223 ,461 ,1503 ,297 ,1567 ,1739 ,517 ,542 ,1752 ,57 ,1240 ,261 ,1163 , + 
1863 ,1245 ,552 ,217 ,1763 ,2044 ,523 ,1245 ,1975 ,269 ,819 ,25 ,1921 ,1102 ,1224 ,1424 , + 908 ,1436 ,943 ,526 ,1327 ,1781 ,596 ,1427 ,725 ,1616 ,1335 ,1982 ,1109 ,1468 ,1060 ,1477 , + 45 ,750 ,920 ,1964 ,81 ,757 ,866 ,754 ,1476 ,1779 ,1995 ,1964 ,1362 ,136 ,167 ,721 , + 669 ,1730 ,568 ,1678 ,551 ,2018 ,653 ,1450 ,570 ,1471 ,19 ,354 ,1043 ,1234 ,929 ,19 , + 411 ,1851 ,1626 ,921 ,932 ,1540 ,1607 ,101 ,1629 ,1439 ,9 ,497 ,1717 ,1076 ,381 ,1848 , + 960 ,2029 ,902 ,493 ,533 ,2030 ,624 ,516 ,880 ,215 ,29 ,845 ,500 ,1377 ,1335 ,1126 , + 639 ,1295 ,1586 ,596 ,382 ,744 ,840 ,1204 ,747 ,1239 ,1846 ,1118 ,1143 ,996 ,510 ,991 , + 464 ,1072 ,1514 ,893 ,656 ,1512 ,1473 ,1691 ,312 ,830 ,703 ,482 ,815 ,801 ,1074 ,741 , + 337 ,856 ,509 ,284 ,1609 ,853 ,377 ,1986 ,1350 ,530 ,1138 ,1663 ,788 ,1792 ,706 ,812 , + 364 ,387 ,50 ,1009 ,969 ,50 ,1292 ,770 ,583 ,432 ,846 ,1383 ,1699 ,752 ,1624 ,2010 , + 1539 ,531 ,1649 ,1294 ,665 ,241 ,2007 ,1268 ,1843 ,363 ,1581 ,1764 ,1891 ,1774 ,1913 ,1536 , + 788 ,817 ,469 ,138 ,343 ,1752 ,201 ,1855 ,791 ,975 ,1380 ,550 ,1727 ,336 ,999 ,298 , + 1144 ,468 ,435 ,1385 ,311 ,664 ,856 ,440 ,682 ,890 ,1463 ,935 ,1919 ,999 ,1382 ,1408 , + 1053 ,1705 ,1789 ,1360 ,1563 ,601 ,523 ,973 ,987 ,1244 ,775 ,1519 ,1697 ,1764 ,1896 ,624 , + 376 ,1829 ,348 ,734 ,1062 ,486 ,1358 ,1370 ,1729 ,724 ,1877 ,1087 ,946 ,325 ,1359 ,1847 , + 868 ,1682 ,1259 ,1502 ,1470 ,1718 ,1927 ,1123 ,1131 ,279 ,388 ,1576 ,1575 ,1910 ,119 ,784 , + 1335 ,661 ,277 ,1504 ,2018 ,1846 ,811 ,470 ,231 ,300 ,1868 ,924 ,1189 ,1613 ,102 ,138 , + 1925 ,611 ,1611 ,1588 ,868 ,239 ,422 ,1123 ,871 ,270 ,560 ,1016 ,215 ,541 ,305 ,126 , + 1849 ,1985 ,130 ,1528 ,1625 ,1968 ,987 ,1999 ,1714 ,1703 ,1822 ,1922 ,1522 ,519 ,374 ,1648 , + 1732 ,1102 ,138 ,1097 ,658 ,70 ,1518 ,256 ,914 ,1886 ,147 ,1926 ,99 ,2034 ,2023 ,1857 , + 1543 ,707 ,1889 ,727 ,66 ,567 ,1446 ,1318 ,1885 ,389 ,442 ,702 ,1882 ,298 ,822 ,803 , + 1132 ,1076 ,6 ,2002 ,643 ,1394 ,757 ,1252 ,1454 ,491 ,203 ,109 ,923 ,1818 ,1295 ,1507 , + 871 ,306 ,640 ,1975 ,1918 ,1842 ,959 ,1035 ,1495 ,411 ,1640 ,1083 ,904 ,1583 ,675 ,1096 , + 59 ,1541 ,1316 ,1262 ,292 ,1142 ,2037 ,482 ,1625 ,931 ,296 ,175 ,1721 ,461 ,1156 ,810 , + 1240 ,1462 ,402 ,482 ,1623 ,1468 ,960 ,356 ,148 ,233 ,1333 ,1528 ,1986 ,602 ,50 ,1663 , + 242 ,31 ,1162 ,1706 ,1591 ,706 ,1619 ,1921 ,297 ,2037 ,143 ,852 ,1695 ,516 ,716 ,802 , + 1513 ,46 ,1360 ,1873 ,1330 ,815 ,1066 ,1988 ,74 ,443 ,129 ,1764 ,1923 ,621 ,334 ,504 , + 1252 ,66 ,1495 ,146 ,250 ,854 ,532 ,869 ,1082 ,1019 ,927 ,1544 ,1284 ,104 ,374 ,746 , + 475 ,431 ,1237 ,1318 ,1625 ,207 ,856 ,773 ,1374 ,807 ,549 ,163 ,355 ,605 ,1855 ,33 , + 1716 ,33 ,1748 ,1032 ,1992 ,1614 ,1905 ,931 ,594 ,1745 ,141 ,2019 ,1472 ,583 ,708 ,1892 , + 110 ,1878 ,1062 ,1796 ,1824 ,1998 ,64 ,887 ,234 ,1770 ,321 ,1004 ,503 ,1066 ,1926 ,1220 , + 31 ,206 ,1350 ,2039 ,1748 ,1258 ,808 ,1185 ,1346 ,1294 ,280 ,1828 ,1216 ,680 ,834 ,598 , + 1428 ,1780 ,1159 ,1295 ,503 ,1691 ,1867 ,680 ,114 ,543 ,803 ,1699 ,678 ,1786 ,1267 ,1740 , + 1933 ,1982 ,1775 ,1150 ,266 ,1826 ,1549 ,1886 ,690 ,617 ,1503 ,1175 ,10 ,12 ,273 ,950 , + 1136 ,905 ,825 ,1640 ,271 ,959 ,1080 ,1973 ,1956 ,1038 ,1735 ,1479 ,1170 ,475 ,1716 ,960 , + 1009 ,1168 ,1541 ,223 ,178 ,1928 ,321 ,1094 ,1796 ,1824 ,1136 ,1932 ,680 ,1868 ,878 ,1333 , + 166 ,1381 ,798 ,205 ,1081 ,1253 ,2017 ,837 ,1092 ,397 ,1205 ,1813 ,1642 ,935 ,698 ,1086 , + 1753 ,1341 ,1064 ,1735 ,147 ,419 ,999 ,525 ,1091 ,423 ,897 ,1442 ,1867 ,662 ,962 ,535 , + 527 ,1433 ,302 ,299 ,1161 ,202 ,13 ,1926 ,1300 ,712 ,2018 ,762 ,583 ,1016 ,1221 ,719 , + 1001 ,824 ,826 ,789 ,1396 
,1063 ,1646 ,427 ,987 ,1738 ,527 ,912 ,1857 ,920 ,647 ,954 , + 702 ,54 ,1276 ,746 ,1114 ,882 ,41 ,1739 ,959 ,861 ,1588 ,1555 ,1544 ,1999 ,1399 ,45 , + 683 ,634 ,1072 ,835 ,1634 ,396 ,1100 ,1363 ,1709 ,1600 ,1117 ,280 ,1837 ,521 ,394 ,1221 , + 1327 ,1966 ,530 ,1918 ,1142 ,436 ,1000 ,1771 ,453 ,1825 ,1515 ,124 ,374 ,990 ,329 ,1031 , + 835 ,1342 ,637 ,1739 ,498 ,858 ,1288 ,1189 ,879 ,1085 ,1810 ,1796 ,79 ,1737 ,1166 ,222 , + 2034 ,783 ,1028 ,1941 ,1053 ,689 ,641 ,1422 ,875 ,815 ,881 ,943 ,1979 ,342 ,2016 ,859 , + 1611 ,1178 ,75 ,1415 ,1990 ,1710 ,352 ,284 ,1451 ,1626 ,1880 ,1395 ,1427 ,627 ,798 ,1014 , + 1155 ,1119 ,552 ,1917 ,1837 ,1992 ,1874 ,488 ,887 ,624 ,1246 ,2034 ,1059 ,381 ,921 ,1814 , + 1050 ,177 ,279 ,157 ,143 ,1884 ,1397 ,1915 ,850 ,1537 ,920 ,302 ,1052 ,397 ,1964 ,1417 , + 1648 ,600 ,2040 ,1712 ,840 ,210 ,257 ,507 ,2021 ,954 ,1028 ,1625 ,1352 ,968 ,672 ,923 , + 1917 ,642 ,1017 ,1853 ,553 ,1621 ,1019 ,1216 ,1769 ,485 ,133 ,1845 ,769 ,1687 ,895 ,1454 , + 45 ,600 ,552 ,1136 ,946 ,482 ,825 ,1291 ,6 ,1238 ,1017 ,1625 ,122 ,1213 ,788 ,849 , + 402 ,1798 ,1518 ,1116 ,1487 ,1488 ,526 ,874 ,848 ,1370 ,421 ,1873 ,1351 ,1864 ,271 ,1276 , + 1569 ,1625 ,1873 ,290 ,378 ,1030 ,819 ,1648 ,785 ,91 ,2006 ,88 ,1908 ,1714 ,1653 ,741 , + 1906 ,963 ,359 ,1406 ,123 ,128 ,107 ,753 ,1205 ,1932 ,2034 ,69 ,369 ,1983 ,1196 ,1276 , + 148 ,477 ,1178 ,1334 ,481 ,555 ,1978 ,1648 ,1210 ,278 ,34 ,77 ,194 ,1839 ,469 ,1684 , + 32 ,646 ,1490 ,1853 ,813 ,154 ,377 ,963 ,724 ,1275 ,115 ,1845 ,35 ,77 ,332 ,573 , + 969 ,1625 ,1535 ,1334 ,998 ,1198 ,666 ,151 ,1866 ,1992 ,2006 ,253 ,1736 ,1969 ,1653 ,1684 , + 1531 ,1429 ,1065 ,1116 ,1908 ,1626 ,1497 ,1614 ,711 ,847 ,1538 ,483 ,1252 ,1999 ,1216 ,1454 , + 969 ,1625 ,234 ,1334 ,1736 ,2010 ,160 ,1180 ,469 ,1542 ,2006 ,77 ,1485 ,2034 ,1653 ,1448 , + 32 ,554 ,1758 ,1402 ,407 ,1897 ,1497 ,1774 ,724 ,973 ,1016 ,1431 ,559 ,436 ,489 ,1562 , + 148 ,477 ,1178 ,1334 ,481 ,1030 ,1978 ,1665 ,1686 ,1370 ,1165 ,77 ,1123 ,2034 ,144 ,1684 , + 1978 ,1419 ,290 ,443 ,692 ,1135 ,377 ,1774 ,711 ,1343 ,569 ,1458 ,1034 ,1864 ,1349 ,1956 , + 293 ,1056 ,1697 ,290 ,481 ,1572 ,976 ,1665 ,1437 ,1542 ,118 ,1383 ,1736 ,1714 ,853 ,667 , + 1978 ,1419 ,1490 ,1406 ,692 ,1626 ,1332 ,1774 ,1423 ,847 ,1538 ,483 ,35 ,1688 ,1140 ,1752 , + 1850 ,1056 ,1697 ,1348 ,481 ,347 ,666 ,2008 ,1437 ,374 ,662 ,439 ,1101 ,650 ,144 ,851 , + 32 ,646 ,618 ,1406 ,610 ,1626 ,1332 ,1586 ,724 ,973 ,569 ,483 ,1995 ,14 ,427 ,1562 , + 771 ,1192 ,1877 ,2030 ,1598 ,1469 ,2002 ,714 ,1881 ,1379 ,1847 ,467 ,705 ,1287 ,1946 ,1117 , + 712 ,533 ,698 ,1437 ,1005 ,839 ,1050 ,1457 ,2006 ,1447 ,1053 ,1791 ,989 ,1240 ,1716 ,1009 , + 361 ,86 ,1430 ,330 ,646 ,789 ,710 ,529 ,831 ,215 ,213 ,463 ,1375 ,1370 ,491 ,1733 , + 1485 ,1040 ,197 ,1811 ,1810 ,1346 ,419 ,1819 ,270 ,249 ,1552 ,1885 ,977 ,547 ,778 ,381 , + 1748 ,1985 ,547 ,1400 ,909 ,1326 ,1992 ,1715 ,296 ,1498 ,120 ,356 ,964 ,1097 ,813 ,799 , + 1536 ,35 ,38 ,1889 ,1155 ,1127 ,117 ,1743 ,446 ,913 ,318 ,1979 ,810 ,273 ,1515 ,1123 , + 1507 ,823 ,932 ,549 ,1910 ,1024 ,1097 ,1019 ,1585 ,1756 ,1107 ,1405 ,1892 ,1143 ,1749 ,1335 , + 1394 ,1158 ,590 ,332 ,1826 ,1062 ,441 ,532 ,1094 ,902 ,1177 ,347 ,1274 ,655 ,2002 ,2037 , + 1020 ,1987 ,869 ,1276 ,2000 ,1782 ,216 ,1266 ,2000 ,1034 ,1348 ,1121 ,1859 ,484 ,1420 ,1367 , + 322 ,755 ,1021 ,830 ,1259 ,1093 ,365 ,2012 ,853 ,1013 ,1445 ,1882 ,1349 ,475 ,448 ,345 , + 577 ,1482 ,323 ,417 ,1935 ,716 ,1067 ,977 ,1431 ,1170 ,738 ,889 ,1978 ,1213 ,1776 ,721 , + 673 ,2014 ,633 ,2018 ,789 ,1771 ,101 ,646 ,1105 ,1310 ,280 ,119 ,397 ,1999 ,1611 ,672 , + 863 ,1922 ,742 ,1628 
,1677 ,1520 ,1741 ,800 ,336 ,471 ,111 ,242 ,1132 ,510 ,1537 ,984 , + 720 ,1922 ,2047 ,353 ,815 ,1329 ,608 ,1048 ,931 ,1136 ,169 ,1044 ,1466 ,833 ,928 ,791 , + 939 ,320 ,71 ,1243 ,920 ,702 ,118 ,1355 ,376 ,1176 ,742 ,1663 ,1722 ,665 ,403 ,1306 , + 329 ,214 ,772 ,1844 ,1374 ,620 ,798 ,1185 ,1823 ,1916 ,83 ,279 ,1653 ,1613 ,1621 ,1728 , + 939 ,1635 ,147 ,1947 ,1624 ,683 ,1293 ,1361 ,995 ,1836 ,935 ,201 ,1041 ,1126 ,1490 ,1779 , + 970 ,633 ,1043 ,495 ,839 ,79 ,598 ,965 ,980 ,1876 ,1249 ,20 ,1900 ,1347 ,1864 ,1753 , + 939 ,636 ,1953 ,1377 ,493 ,1292 ,1457 ,940 ,2029 ,942 ,528 ,1066 ,1112 ,1275 ,699 ,654 , + 1756 ,835 ,1012 ,1039 ,665 ,1306 ,44 ,1372 ,347 ,1583 ,857 ,240 ,260 ,521 ,1568 ,260 , + 1196 ,1912 ,953 ,414 ,1628 ,964 ,1759 ,833 ,1483 ,883 ,40 ,1887 ,1339 ,1246 ,1088 ,1532 , + 1583 ,1295 ,1903 ,706 ,2009 ,1025 ,1194 ,1857 ,1571 ,330 ,1636 ,1898 ,0 ,916 ,1870 ,519 , + 801 ,1122 ,249 ,634 ,1842 ,541 ,866 ,1842 ,1956 ,169 ,1009 ,267 ,1724 ,383 ,102 ,1213 , + 1747 ,307 ,1872 ,683 ,1053 ,1460 ,71 ,366 ,1331 ,1482 ,917 ,1346 ,1351 ,240 ,1070 ,1052 , + 976 ,1851 ,156 ,139 ,264 ,726 ,379 ,836 ,868 ,193 ,198 ,1764 ,1268 ,1399 ,1672 ,1915 , + 1195 ,1757 ,407 ,1361 ,1060 ,1163 ,435 ,1104 ,1424 ,1765 ,1040 ,1227 ,394 ,1302 ,466 ,1822 , + 322 ,398 ,503 ,1527 ,1607 ,1012 ,824 ,306 ,30 ,1057 ,973 ,347 ,472 ,1560 ,959 ,963 , + 1229 ,362 ,595 ,1888 ,234 ,1800 ,1197 ,1969 ,1595 ,1734 ,1514 ,1942 ,1102 ,295 ,186 ,393 , + 1475 ,1688 ,716 ,1263 ,2015 ,535 ,1004 ,717 ,540 ,1642 ,951 ,1858 ,555 ,1647 ,1433 ,96 , + 1822 ,1843 ,1444 ,1435 ,352 ,71 ,136 ,399 ,1278 ,1734 ,788 ,810 ,345 ,490 ,536 ,743 , + 801 ,963 ,588 ,1753 ,1044 ,345 ,1431 ,1181 ,1023 ,1102 ,398 ,593 ,1086 ,846 ,1756 ,776 , + 1647 ,898 ,28 ,2047 ,1559 ,118 ,1063 ,13 ,1422 ,133 ,684 ,1143 ,130 ,33 ,261 ,532 , + 801 ,575 ,385 ,297 ,1843 ,437 ,474 ,1854 ,661 ,957 ,1566 ,816 ,834 ,1114 ,677 ,778 , + 262 ,1356 ,185 ,1375 ,2023 ,589 ,1815 ,1156 ,1747 ,41 ,46 ,1294 ,1850 ,693 ,1607 ,1860 , + 1475 ,1186 ,1909 ,549 ,656 ,1139 ,479 ,1026 ,1452 ,1677 ,1410 ,1226 ,955 ,1524 ,1730 ,303 , + 1949 ,1267 ,561 ,1923 ,2007 ,1656 ,125 ,763 ,106 ,1695 ,494 ,1894 ,846 ,13 ,1763 ,676 , + 1212 ,1005 ,514 ,1055 ,186 ,246 ,822 ,397 ,517 ,732 ,5 ,1005 ,1354 ,1730 ,777 ,176 , + 732 ,278 ,730 ,1350 ,437 ,2011 ,680 ,769 ,310 ,1506 ,955 ,242 ,1323 ,224 ,315 ,640 , + 947 ,973 ,1401 ,706 ,1566 ,418 ,95 ,1818 ,106 ,1791 ,488 ,1668 ,1682 ,629 ,1845 ,340 , + 425 ,1605 ,1616 ,189 ,107 ,1375 ,1437 ,1391 ,414 ,915 ,1832 ,364 ,222 ,1051 ,474 ,1500 , + 1174 ,821 ,1368 ,549 ,1894 ,1527 ,908 ,993 ,991 ,1183 ,724 ,773 ,1591 ,170 ,1999 ,1813 , + 404 ,299 ,1731 ,1799 ,55 ,349 ,43 ,729 ,230 ,1318 ,104 ,1050 ,471 ,325 ,1217 ,622 , + 390 ,1567 ,478 ,1394 ,1379 ,1229 ,803 ,1093 ,677 ,245 ,1182 ,1838 ,265 ,876 ,121 ,1205 , + 223 ,1666 ,1706 ,2011 ,391 ,417 ,244 ,1530 ,1355 ,1651 ,767 ,886 ,181 ,1317 ,1417 ,838 , + 599 ,909 ,648 ,1386 ,809 ,1958 ,807 ,1838 ,1215 ,298 ,258 ,1522 ,997 ,1632 ,1257 ,1538 , + 1694 ,257 ,112 ,1963 ,1466 ,656 ,1739 ,1441 ,1197 ,1522 ,898 ,11 ,547 ,508 ,55 ,912 , + 974 ,845 ,1389 ,1821 ,869 ,1371 ,1229 ,1747 ,501 ,1452 ,1879 ,806 ,1674 ,205 ,1372 ,1959 , + 1146 ,1182 ,1256 ,127 ,269 ,111 ,327 ,1792 ,285 ,693 ,1495 ,160 ,128 ,980 ,1376 ,667 , + 734 ,905 ,705 ,1309 ,328 ,928 ,1605 ,851 ,227 ,1677 ,1108 ,1403 ,239 ,281 ,671 ,547 , + 465 ,791 ,2022 ,1919 ,1727 ,826 ,474 ,1698 ,691 ,923 ,599 ,1444 ,975 ,1973 ,216 ,735 , + 1990 ,563 ,1853 ,1714 ,1024 ,1036 ,1299 ,1376 ,1231 ,206 ,1252 ,165 ,1551 ,1613 ,1643 ,1108 , + 132 ,564 ,1593 ,1419 ,289 ,1925 ,1910 ,202 ,1963 ,987 ,1918 
,9 ,1653 ,630 ,1859 ,985 , + 306 ,523 ,1593 ,1358 ,1509 ,48 ,1740 ,875 ,327 ,1933 ,250 ,194 ,500 ,701 ,1242 ,1715 , + 1712 ,809 ,1056 ,398 ,764 ,1116 ,322 ,1644 ,287 ,1048 ,288 ,1313 ,1398 ,1738 ,552 ,2025 , + 866 ,682 ,1125 ,1921 ,61 ,1706 ,366 ,1081 ,172 ,1120 ,615 ,470 ,412 ,982 ,2008 ,1514 , + 430 ,1840 ,1803 ,180 ,802 ,1292 ,1694 ,816 ,1609 ,2011 ,104 ,1336 ,1683 ,1421 ,397 ,1960 , + 1020 ,286 ,1616 ,184 ,1197 ,1697 ,1613 ,727 ,1288 ,505 ,922 ,334 ,1738 ,100 ,1719 ,585 , + 1207 ,214 ,168 ,1636 ,1503 ,1779 ,1977 ,1770 ,644 ,782 ,183 ,931 ,962 ,738 ,859 ,632 , + 1255 ,91 ,537 ,1894 ,1801 ,1697 ,837 ,944 ,1186 ,1384 ,1037 ,1062 ,1300 ,1932 ,1821 ,1591 , + 920 ,1491 ,1736 ,955 ,608 ,585 ,743 ,1093 ,1205 ,531 ,133 ,1672 ,1571 ,1115 ,1561 ,1759 , + 771 ,572 ,2 ,436 ,1427 ,195 ,148 ,1172 ,1158 ,1420 ,1557 ,1284 ,750 ,1069 ,1406 ,707 , + 969 ,1262 ,989 ,527 ,1633 ,43 ,2022 ,2002 ,1175 ,1192 ,1733 ,1186 ,1958 ,1358 ,338 ,512 , + 361 ,1797 ,417 ,1887 ,1678 ,1974 ,1015 ,1578 ,1944 ,1375 ,1286 ,206 ,504 ,599 ,690 ,76 , + 1863 ,335 ,1159 ,201 ,1826 ,654 ,1479 ,1840 ,471 ,142 ,2003 ,1244 ,1476 ,1043 ,2033 ,102 , + 1748 ,1745 ,1836 ,1484 ,1221 ,792 ,1860 ,1256 ,449 ,1550 ,630 ,340 ,436 ,1371 ,188 ,779 , + 422 ,1117 ,1241 ,690 ,1244 ,63 ,1315 ,1746 ,1069 ,1475 ,1975 ,1301 ,1531 ,383 ,1871 ,1179 , + 383 ,1936 ,443 ,211 ,1106 ,39 ,934 ,359 ,1138 ,942 ,1412 ,1240 ,160 ,1483 ,348 ,431 , + 853 ,283 ,754 ,1648 ,1321 ,44 ,989 ,1913 ,262 ,197 ,993 ,1318 ,1973 ,946 ,449 ,1352 , + 854 ,250 ,816 ,1309 ,1670 ,572 ,736 ,1815 ,797 ,1611 ,1441 ,114 ,1985 ,196 ,1416 ,186 , + 46 ,1806 ,162 ,433 ,1302 ,1125 ,1368 ,629 ,115 ,44 ,165 ,1831 ,1865 ,1537 ,762 ,724 , + 974 ,1667 ,349 ,725 ,486 ,1169 ,221 ,1753 ,201 ,1306 ,913 ,1791 ,1910 ,742 ,525 ,787 , + 1919 ,253 ,97 ,1316 ,1394 ,906 ,56 ,1620 ,24 ,1556 ,177 ,389 ,2011 ,12 ,757 ,940 , + 649 ,1936 ,1214 ,789 ,509 ,631 ,922 ,1221 ,744 ,1451 ,1024 ,172 ,1610 ,43 ,1102 ,807 , + 57 ,1137 ,1452 ,403 ,861 ,510 ,391 ,1209 ,827 ,880 ,390 ,579 ,52 ,859 ,1662 ,1178 , + 429 ,675 ,1817 ,211 ,211 ,1580 ,1500 ,1599 ,24 ,511 ,409 ,983 ,1339 ,1880 ,1037 ,288 , + 1585 ,1589 ,1174 ,1252 ,111 ,1850 ,1230 ,1562 ,2032 ,1861 ,823 ,572 ,578 ,90 ,47 ,583 , + 944 ,1971 ,705 ,543 ,72 ,411 ,1701 ,867 ,814 ,34 ,747 ,1250 ,1472 ,1014 ,184 ,988 , + 1890 ,919 ,1695 ,1787 ,1958 ,724 ,424 ,556 ,51 ,1556 ,1312 ,1968 ,11 ,1566 ,342 ,1195 , + 427 ,496 ,1401 ,1858 ,337 ,1474 ,882 ,968 ,1172 ,890 ,1572 ,19 ,112 ,1613 ,1273 ,1197 , + 62 ,379 ,52 ,1232 ,1867 ,897 ,985 ,985 ,941 ,1344 ,372 ,1660 ,2006 ,841 ,157 ,1868 , + 229 ,1246 ,1399 ,1875 ,566 ,1713 ,384 ,203 ,291 ,1872 ,1423 ,762 ,608 ,1748 ,1094 ,1259 , + 336 ,1729 ,440 ,1128 ,922 ,1733 ,1254 ,248 ,750 ,984 ,1240 ,218 ,1547 ,38 ,2027 ,1663 , + 1622 ,1212 ,1419 ,1515 ,717 ,364 ,1867 ,163 ,816 ,1245 ,1982 ,2005 ,875 ,1964 ,1279 ,60 , + 95 ,722 ,1263 ,1481 ,659 ,1171 ,27 ,320 ,513 ,1578 ,1617 ,767 ,1170 ,1085 ,1162 ,228 , + 1332 ,313 ,1439 ,1630 ,245 ,1609 ,1370 ,1686 ,1840 ,633 ,215 ,1718 ,1726 ,306 ,1563 ,90 , + 112 ,1118 ,1394 ,1765 ,1115 ,725 ,1895 ,1430 ,1220 ,1084 ,1831 ,231 ,161 ,867 ,1303 ,403 , + 1999 ,910 ,1388 ,634 ,846 ,612 ,1662 ,334 ,1919 ,1996 ,1095 ,1002 ,1855 ,1759 ,1376 ,22 , + 1471 ,1918 ,1250 ,1292 ,806 ,1774 ,1221 ,1160 ,1044 ,293 ,39 ,168 ,1340 ,668 ,582 ,314 , + 871 ,677 ,2031 ,888 ,495 ,946 ,182 ,1642 ,616 ,982 ,1398 ,1974 ,74 ,20 ,692 ,1338 , + 513 ,1564 ,76 ,1470 ,61 ,617 ,1221 ,370 ,501 ,7 ,1976 ,1258 ,606 ,372 ,344 ,1303 , + 1864 ,324 ,1745 ,952 ,1563 ,1824 ,6 ,1389 ,2036 ,1827 ,806 ,1483 ,1977 ,1882 ,37 ,413 , + 1476 ,843 ,1853 
,1337 ,1762 ,972 ,491 ,1852 ,673 ,724 ,808 ,447 ,121 ,395 ,1792 ,1026 , + 235 ,307 ,1575 ,194 ,1687 ,1795 ,1921 ,1895 ,1178 ,1556 ,429 ,850 ,613 ,660 ,1471 ,576 , + 250 ,1555 ,1268 ,1024 ,1163 ,1680 ,1706 ,319 ,1655 ,1156 ,610 ,766 ,1549 ,1915 ,1667 ,568 , + 138 ,936 ,50 ,444 ,1258 ,1163 ,318 ,195 ,1829 ,123 ,893 ,2025 ,1637 ,902 ,807 ,1745 , + 1789 ,1229 ,1007 ,1970 ,1402 ,1948 ,414 ,906 ,756 ,1560 ,277 ,1269 ,807 ,864 ,142 ,2002 , + 599 ,1797 ,1419 ,1745 ,1944 ,1377 ,270 ,18 ,1880 ,1950 ,1591 ,70 ,1853 ,1022 ,2035 ,979 , + 1846 ,688 ,856 ,160 ,1627 ,1262 ,300 ,151 ,1054 ,1129 ,1448 ,451 ,712 ,1555 ,86 ,801 , + 1173 ,815 ,1456 ,1218 ,1783 ,1420 ,686 ,1775 ,1343 ,396 ,701 ,441 ,1080 ,647 ,694 ,1720 , + 1883 ,758 ,235 ,1493 ,86 ,505 ,1915 ,1206 ,385 ,1619 ,442 ,1038 ,190 ,717 ,984 ,1432 , + 324 ,1046 ,277 ,1858 ,419 ,1299 ,2000 ,311 ,735 ,1975 ,1491 ,305 ,1264 ,739 ,1143 ,414 , + 606 ,305 ,1077 ,1951 ,1258 ,1443 ,935 ,194 ,1628 ,1906 ,382 ,591 ,1682 ,211 ,1048 ,1435 , + 309 ,1349 ,932 ,671 ,893 ,1828 ,839 ,999 ,1644 ,774 ,1273 ,264 ,1550 ,253 ,234 ,426 , + 1032 ,2009 ,1477 ,1972 ,705 ,1047 ,253 ,1756 ,1732 ,333 ,1245 ,513 ,1978 ,1990 ,1531 ,722 , + 1520 ,1406 ,1549 ,1850 ,66 ,1878 ,660 ,1985 ,44 ,656 ,1344 ,1141 ,335 ,419 ,1488 ,548 , + 709 ,1003 ,1195 ,147 ,1766 ,1916 ,431 ,1831 ,1833 ,97 ,634 ,1244 ,133 ,1448 ,191 ,281 , + 760 ,1421 ,66 ,1519 ,1771 ,1122 ,67 ,1625 ,902 ,1093 ,176 ,2041 ,865 ,1434 ,1486 ,302 , + 1818 ,70 ,181 ,790 ,1724 ,1417 ,1316 ,2004 ,919 ,35 ,1098 ,1545 ,1959 ,322 ,761 ,1651 , + 422 ,828 ,1773 ,1105 ,816 ,1513 ,1143 ,1280 ,213 ,763 ,1681 ,106 ,1643 ,322 ,1158 ,1446 , + 888 ,672 ,1239 ,400 ,1019 ,64 ,891 ,59 ,1964 ,1844 ,240 ,1608 ,433 ,141 ,975 ,1916 , + 1925 ,858 ,1923 ,1691 ,216 ,1317 ,45 ,877 ,1428 ,1411 ,1354 ,1774 ,430 ,1769 ,1088 ,374 , + 167 ,655 ,1348 ,301 ,1240 ,1611 ,1587 ,1421 ,554 ,1429 ,718 ,1855 ,1077 ,1948 ,1463 ,1952 , + 680 ,989 ,382 ,1955 ,1695 ,326 ,972 ,1286 ,1419 ,225 ,981 ,898 ,409 ,161 ,192 ,1242 , + 521 ,991 ,1114 ,1335 ,92 ,837 ,2041 ,923 ,1411 ,1467 ,1422 ,973 ,1818 ,739 ,635 ,234 , + 1991 ,1454 ,699 ,1332 ,131 ,1258 ,1431 ,12 ,759 ,87 ,1817 ,1615 ,1325 ,1780 ,704 ,1599 , + 149 ,918 ,1117 ,336 ,480 ,1418 ,609 ,578 ,941 ,1987 ,1692 ,1847 ,787 ,1946 ,114 ,584 , + 140 ,286 ,1856 ,184 ,933 ,198 ,179 ,1407 ,232 ,1044 ,1256 ,1639 ,1901 ,1165 ,1041 ,369 , + 1949 ,668 ,130 ,95 ,883 ,358 ,1117 ,800 ,294 ,1934 ,1718 ,1651 ,750 ,124 ,864 ,139 , + 808 ,11 ,1830 ,325 ,1199 ,1285 ,1224 ,1785 ,2016 ,2007 ,488 ,789 ,1257 ,947 ,437 ,387 , + 227 ,740 ,43 ,969 ,165 ,504 ,1148 ,499 ,209 ,956 ,1278 ,1075 ,1395 ,1056 ,1702 ,1365 , + 1948 ,1587 ,134 ,936 ,753 ,1850 ,1802 ,1210 ,708 ,1361 ,811 ,1799 ,276 ,847 ,1499 ,616 , + 1934 ,1262 ,128 ,1971 ,1335 ,1996 ,607 ,680 ,1315 ,1878 ,1042 ,612 ,1399 ,683 ,1018 ,1535 , + 1441 ,726 ,1405 ,249 ,1382 ,1244 ,2041 ,1337 ,370 ,537 ,1183 ,895 ,636 ,556 ,1148 ,1656 , + 508 ,113 ,926 ,1701 ,1713 ,1294 ,1677 ,904 ,666 ,44 ,259 ,102 ,509 ,670 ,1128 ,1601 , + 386 ,586 ,263 ,343 ,125 ,456 ,2020 ,1673 ,1417 ,1230 ,1608 ,1669 ,1004 ,333 ,1167 ,786 , + 78 ,206 ,972 ,1657 ,1834 ,972 ,1799 ,777 ,63 ,89 ,1909 ,1235 ,566 ,1109 ,1230 ,1094 , + 1687 ,1694 ,889 ,1051 ,721 ,378 ,750 ,1839 ,1753 ,1913 ,67 ,1662 ,1913 ,674 ,1956 ,925 , + 639 ,619 ,21 ,381 ,965 ,603 ,1888 ,1719 ,1098 ,1641 ,1387 ,1182 ,1388 ,958 ,222 ,919 , + 725 ,1013 ,1789 ,870 ,303 ,414 ,1818 ,95 ,76 ,239 ,6 ,8 ,1329 ,1766 ,1136 ,1995 , + 1052 ,220 ,1505 ,45 ,885 ,736 ,897 ,1599 ,767 ,1105 ,183 ,674 ,1008 ,1483 ,101 ,326 , + 599 ,1544 ,173 ,78 ,132 ,1032 ,847 ,1941 ,1787 ,397 
,1660 ,1166 ,1379 ,1343 ,1437 ,364 , + 1733 ,728 ,1970 ,1804 ,1627 ,802 ,324 ,510 ,847 ,1940 ,657 ,2017 ,10 ,1980 ,1467 ,1865 , + 1817 ,605 ,1465 ,1296 ,1082 ,1697 ,1142 ,1301 ,500 ,1663 ,2014 ,1000 ,1349 ,785 ,201 ,1775 , + 1022 ,1218 ,354 ,1881 ,294 ,1977 ,330 ,447 ,1662 ,1667 ,404 ,1944 ,633 ,300 ,190 ,1613 , + 394 ,1229 ,1878 ,249 ,819 ,251 ,1589 ,1601 ,1909 ,1637 ,757 ,1133 ,1175 ,900 ,1168 ,448 , + 797 ,646 ,52 ,1525 ,1133 ,1456 ,1199 ,1004 ,222 ,125 ,1435 ,1343 ,1064 ,1356 ,1394 ,523 , + 192 ,477 ,1697 ,97 ,589 ,714 ,1871 ,1744 ,469 ,345 ,34 ,102 ,1736 ,811 ,166 ,1032 , + 200 ,1419 ,1736 ,1428 ,610 ,251 ,1353 ,74 ,792 ,847 ,1272 ,191 ,1160 ,1864 ,1047 ,1752 , + 448 ,316 ,1559 ,1348 ,481 ,1443 ,825 ,1744 ,1561 ,91 ,1048 ,436 ,1736 ,1430 ,802 ,851 , + 1978 ,251 ,1736 ,1428 ,1255 ,1897 ,655 ,1586 ,204 ,591 ,569 ,951 ,35 ,14 ,427 ,1562 , + 1850 ,316 ,783 ,164 ,267 ,1572 ,976 ,1744 ,1210 ,374 ,118 ,439 ,1908 ,1714 ,853 ,851 , + 210 ,646 ,1736 ,1406 ,1908 ,1897 ,1497 ,1586 ,1423 ,847 ,1538 ,951 ,1930 ,97 ,1047 ,1752 , + 1850 ,1700 ,1178 ,290 ,1736 ,1572 ,976 ,1744 ,1210 ,91 ,662 ,436 ,1343 ,1430 ,644 ,851 , + 1978 ,251 ,1490 ,1428 ,1908 ,1067 ,377 ,1586 ,1423 ,973 ,1538 ,1388 ,1995 ,14 ,427 ,1956 , + 7 ,1056 ,1456 ,390 ,340 ,1572 ,1978 ,1744 ,739 ,459 ,1165 ,1190 ,1912 ,1714 ,853 ,241 , + 929 ,1419 ,1736 ,385 ,423 ,1953 ,1332 ,717 ,604 ,591 ,1911 ,906 ,1930 ,14 ,1140 ,1752 , + 1833 ,811 ,480 ,290 ,481 ,1443 ,1978 ,1648 ,739 ,1370 ,415 ,253 ,202 ,2043 ,644 ,851 , + 1900 ,1829 ,1490 ,1853 ,1218 ,1953 ,1332 ,1004 ,724 ,1704 ,899 ,559 ,1995 ,1688 ,271 ,1956 , + 293 ,1056 ,1029 ,546 ,267 ,1443 ,825 ,2008 ,1437 ,1542 ,662 ,253 ,1736 ,1714 ,802 ,851 , + 1978 ,251 ,618 ,1402 ,1908 ,1626 ,1497 ,1586 ,724 ,973 ,1538 ,1124 ,35 ,1688 ,1047 ,1562 , + 1850 ,1056 ,1697 ,1348 ,481 ,1443 ,666 ,1744 ,1972 ,374 ,415 ,253 ,1736 ,1714 ,802 ,1684 , + 32 ,1642 ,290 ,1047 ,1908 ,1626 ,1353 ,1774 ,1423 ,973 ,1538 ,483 ,1995 ,97 ,1047 ,1562 , + 1890 ,2038 ,1668 ,939 ,1684 ,1799 ,1286 ,82 ,2029 ,1696 ,1587 ,428 ,437 ,1711 ,322 ,1514 , + 615 ,1571 ,1396 ,1859 ,509 ,1163 ,5 ,697 ,85 ,201 ,1109 ,1921 ,162 ,21 ,186 ,852 , + 361 ,133 ,645 ,1929 ,1446 ,230 ,1688 ,494 ,1446 ,890 ,1264 ,1689 ,824 ,1345 ,1942 ,1783 , + 752 ,1549 ,1579 ,1799 ,477 ,384 ,253 ,945 ,429 ,487 ,855 ,610 ,970 ,335 ,1390 ,1365 , + 1748 ,656 ,1060 ,175 ,2036 ,627 ,1827 ,1540 ,461 ,1517 ,913 ,60 ,973 ,1265 ,693 ,301 , + 1795 ,147 ,1826 ,365 ,1505 ,1250 ,184 ,975 ,81 ,1953 ,259 ,784 ,179 ,486 ,1254 ,77 , + 1631 ,1518 ,1448 ,2026 ,1502 ,54 ,617 ,963 ,904 ,790 ,1295 ,676 ,1009 ,201 ,898 ,1869 , + 638 ,876 ,2013 ,32 ,1952 ,1007 ,160 ,1303 ,1365 ,833 ,242 ,1219 ,213 ,1484 ,1514 ,851 , + 1255 ,994 ,1016 ,1673 ,623 ,1737 ,469 ,2016 ,1639 ,1500 ,1176 ,350 ,1783 ,1863 ,394 ,1492 , + 1224 ,1810 ,1884 ,1369 ,358 ,1843 ,1658 ,1314 ,1390 ,668 ,1938 ,235 ,1543 ,876 ,757 ,933 , + 577 ,82 ,658 ,966 ,863 ,1007 ,89 ,612 ,887 ,1182 ,2040 ,375 ,1084 ,2007 ,1311 ,1028 , + 835 ,1612 ,637 ,677 ,608 ,555 ,813 ,179 ,1344 ,910 ,766 ,1682 ,1904 ,351 ,1730 ,1871 , + 1362 ,1520 ,1456 ,183 ,811 ,1652 ,904 ,1300 ,1210 ,1769 ,1516 ,1383 ,1343 ,1492 ,170 ,485 , + 1063 ,1642 ,1490 ,172 ,1218 ,1667 ,634 ,1163 ,804 ,726 ,47 ,1388 ,866 ,1459 ,475 ,573 , + 1195 ,1376 ,127 ,408 ,910 ,1385 ,640 ,1747 ,1381 ,1841 ,454 ,677 ,1572 ,772 ,1543 ,1798 , + 55 ,1908 ,1140 ,1330 ,1500 ,1591 ,538 ,1262 ,492 ,1016 ,181 ,569 ,1018 ,1516 ,1536 ,1739 , + 214 ,1083 ,1309 ,1271 ,104 ,1213 ,1722 ,375 ,928 ,363 ,19 ,1984 ,1538 ,629 ,1621 ,334 , + 234 ,193 ,1616 ,408 ,2029 ,1365 ,688 ,1675 ,432 ,1485 ,326 ,805 
,1170 ,2044 ,1271 ,920 , + 214 ,278 ,705 ,1994 ,1341 ,867 ,885 ,440 ,1390 ,113 ,1866 ,982 ,1800 ,2023 ,1965 ,1628 , + 851 ,1556 ,233 ,1945 ,930 ,526 ,1851 ,1090 ,1160 ,770 ,608 ,1751 ,641 ,1835 ,1486 ,918 , + 805 ,351 ,996 ,1671 ,56 ,1907 ,229 ,980 ,984 ,1283 ,1256 ,1957 ,985 ,1748 ,698 ,1527 , + 734 ,1471 ,1369 ,581 ,215 ,369 ,476 ,1666 ,439 ,635 ,1374 ,1446 ,5 ,1605 ,337 ,53 , + 872 ,1666 ,432 ,1673 ,353 ,769 ,577 ,159 ,568 ,974 ,1777 ,1413 ,870 ,766 ,89 ,670 , + 1689 ,1077 ,1404 ,108 ,55 ,1064 ,649 ,111 ,1975 ,1406 ,1121 ,724 ,253 ,1938 ,14 ,1185 , + 2000 ,809 ,750 ,1767 ,795 ,1020 ,1414 ,165 ,1506 ,659 ,802 ,1646 ,1643 ,1164 ,630 ,1349 , + 1013 ,354 ,49 ,1226 ,1225 ,351 ,951 ,1236 ,1144 ,833 ,450 ,137 ,831 ,217 ,1026 ,1287 , + 1285 ,1783 ,1314 ,309 ,121 ,1856 ,401 ,1529 ,836 ,519 ,1162 ,286 ,886 ,916 ,1844 ,878 , + 707 ,690 ,1758 ,20 ,368 ,818 ,950 ,638 ,345 ,433 ,1090 ,1713 ,1580 ,1017 ,628 ,1086 , + 574 ,1873 ,1574 ,1736 ,690 ,1661 ,203 ,512 ,607 ,1853 ,1631 ,536 ,1182 ,243 ,1892 ,573 , + 1787 ,533 ,1024 ,962 ,71 ,596 ,1442 ,1694 ,856 ,661 ,1236 ,1635 ,1650 ,1474 ,1867 ,1392 , + 1367 ,2019 ,465 ,1306 ,681 ,1791 ,1540 ,1523 ,1984 ,1827 ,285 ,1282 ,912 ,466 ,294 ,357 , + 693 ,1179 ,117 ,1492 ,1566 ,702 ,698 ,966 ,695 ,1365 ,478 ,1148 ,617 ,1375 ,143 ,1907 , + 443 ,948 ,1883 ,550 ,1545 ,1777 ,1956 ,1570 ,1652 ,1925 ,1840 ,173 ,1287 ,664 ,1267 ,1337 , + 628 ,147 ,1218 ,542 ,950 ,1393 ,778 ,1341 ,1613 ,1833 ,783 ,531 ,1702 ,198 ,1615 ,932 , + 1076 ,1642 ,1388 ,1693 ,276 ,1012 ,493 ,1543 ,1505 ,775 ,543 ,1976 ,1529 ,1558 ,41 ,1914 , + 1641 ,161 ,1605 ,230 ,1710 ,1162 ,987 ,1669 ,1951 ,1212 ,975 ,1154 ,867 ,1138 ,470 ,212 , + 38 ,795 ,1238 ,1723 ,1507 ,1077 ,1409 ,1982 ,2043 ,1102 ,2047 ,269 ,629 ,197 ,1524 ,1160 , + 770 ,260 ,1081 ,776 ,825 ,37 ,1805 ,1061 ,1622 ,438 ,352 ,736 ,1203 ,1351 ,175 ,1313 , + 736 ,813 ,1463 ,40 ,1095 ,927 ,977 ,1756 ,1045 ,1872 ,457 ,1937 ,563 ,1929 ,1884 ,1162 , + 186 ,1464 ,46 ,74 ,1372 ,625 ,849 ,1842 ,846 ,1533 ,2046 ,1385 ,1870 ,90 ,1941 ,587 , + 965 ,1253 ,1156 ,1618 ,524 ,1147 ,422 ,376 ,1384 ,581 ,405 ,943 ,1483 ,1648 ,772 ,1556 , + 1354 ,380 ,1904 ,1697 ,691 ,637 ,1730 ,1189 ,1092 ,1379 ,1584 ,104 ,604 ,937 ,1427 ,574 , + 577 ,1520 ,1016 ,309 ,942 ,1522 ,1524 ,1628 ,836 ,1170 ,1310 ,1610 ,139 ,36 ,446 ,241 , + 1041 ,104 ,1665 ,1900 ,469 ,1909 ,433 ,1612 ,1671 ,1591 ,1076 ,784 ,1992 ,1640 ,712 ,1937 , + 1696 ,427 ,161 ,1697 ,200 ,1652 ,50 ,390 ,1852 ,697 ,209 ,769 ,908 ,914 ,787 ,959 , + 530 ,287 ,731 ,518 ,1120 ,77 ,378 ,1170 ,482 ,459 ,115 ,906 ,1730 ,258 ,1587 ,1274 , + 1266 ,1423 ,1668 ,139 ,888 ,2043 ,1972 ,1113 ,963 ,1460 ,319 ,408 ,1560 ,998 ,1955 ,1040 , + 382 ,517 ,492 ,1865 ,650 ,1184 ,1547 ,609 ,1744 ,674 ,1839 ,910 ,511 ,7 ,883 ,861 , + 1154 ,869 ,1856 ,843 ,1460 ,969 ,1401 ,1074 ,29 ,1561 ,1737 ,283 ,1075 ,1750 ,265 ,527 , + 1394 ,1649 ,1185 ,366 ,797 ,284 ,687 ,775 ,1257 ,1527 ,143 ,1956 ,1895 ,627 ,1005 ,720 , + 1618 ,1436 ,1590 ,1439 ,2028 ,484 ,1123 ,1280 ,1470 ,1142 ,481 ,1569 ,1176 ,1997 ,1321 ,1023 , + 1954 ,1622 ,369 ,979 ,429 ,1186 ,1445 ,1085 ,1556 ,1216 ,63 ,748 ,1474 ,117 ,402 ,1519 , + 1181 ,594 ,1812 ,1297 ,365 ,39 ,822 ,1690 ,1385 ,1550 ,528 ,1519 ,1853 ,431 ,1648 ,321 , + 366 ,474 ,62 ,1984 ,1038 ,516 ,1775 ,754 ,1249 ,109 ,474 ,982 ,1790 ,1590 ,1853 ,1287 , + 1866 ,1760 ,1866 ,53 ,741 ,600 ,1841 ,2024 ,889 ,323 ,1257 ,1575 ,153 ,1205 ,1400 ,898 , + 1211 ,480 ,689 ,1995 ,1727 ,29 ,1887 ,1710 ,119 ,1623 ,1833 ,1952 ,305 ,373 ,1421 ,1914 , + 1864 ,1268 ,443 ,1881 ,377 ,1616 ,68 ,1669 ,2022 ,1097 ,24 ,345 ,790 ,1235 ,980 ,1660 , 
+ 1620 ,240 ,1370 ,1894 ,204 ,200 ,309 ,1350 ,283 ,316 ,150 ,1283 ,133 ,1358 ,1103 ,1318 , + 1240 ,1626 ,1349 ,1321 ,401 ,319 ,2020 ,513 ,1264 ,1083 ,217 ,287 ,1375 ,1047 ,1052 ,304 , + 780 ,150 ,40 ,507 ,1773 ,266 ,1689 ,1655 ,31 ,1402 ,711 ,482 ,920 ,211 ,981 ,1524 , + 734 ,376 ,752 ,1397 ,40 ,219 ,1378 ,482 ,1948 ,200 ,918 ,842 ,849 ,1779 ,68 ,2000 , + 437 ,1624 ,1737 ,1289 ,432 ,1847 ,1104 ,174 ,1393 ,1467 ,483 ,996 ,1308 ,1407 ,1544 ,414 , + 429 ,656 ,728 ,187 ,1224 ,1230 ,1223 ,622 ,551 ,1410 ,687 ,193 ,1741 ,620 ,389 ,1397 , + 1804 ,1609 ,965 ,558 ,828 ,1718 ,1776 ,599 ,957 ,1100 ,110 ,1594 ,705 ,686 ,528 ,1577 , + 932 ,1541 ,983 ,607 ,1398 ,1753 ,1634 ,1767 ,1513 ,1278 ,163 ,928 ,319 ,828 ,1241 ,1357 , + 1778 ,1041 ,679 ,1471 ,1789 ,1089 ,285 ,481 ,1697 ,143 ,438 ,1244 ,790 ,1402 ,630 ,96 , + 866 ,1498 ,45 ,2002 ,75 ,1600 ,492 ,647 ,604 ,1825 ,1681 ,2003 ,652 ,1232 ,1687 ,1826 , + 980 ,1199 ,1520 ,403 ,957 ,1249 ,264 ,1827 ,587 ,1318 ,1596 ,542 ,1087 ,564 ,1212 ,26 , + 1035 ,1535 ,1945 ,1021 ,1929 ,1554 ,1792 ,904 ,4 ,471 ,1640 ,434 ,1349 ,1281 ,1038 ,1054 , + 646 ,1730 ,1557 ,211 ,1449 ,569 ,790 ,5 ,934 ,1608 ,1275 ,1141 ,1295 ,930 ,1682 ,1290 , + 1020 ,363 ,537 ,1673 ,1801 ,1356 ,417 ,1538 ,904 ,278 ,1966 ,652 ,1475 ,1705 ,1149 ,715 , + 937 ,1692 ,1896 ,808 ,1256 ,531 ,561 ,508 ,11 ,810 ,1091 ,1808 ,213 ,77 ,901 ,1021 , + 1160 ,1646 ,709 ,1264 ,16 ,1652 ,1524 ,1156 ,505 ,611 ,605 ,860 ,2014 ,1705 ,1246 ,1218 , + 646 ,1644 ,841 ,1989 ,705 ,507 ,254 ,1386 ,1205 ,669 ,1343 ,432 ,1365 ,251 ,1704 ,1256 , + 1160 ,996 ,1466 ,825 ,1792 ,1514 ,168 ,948 ,1961 ,1625 ,1029 ,1455 ,508 ,807 ,702 ,1604 , + 1902 ,1689 ,1479 ,1916 ,728 ,1291 ,1142 ,1248 ,1875 ,152 ,1587 ,1528 ,1809 ,1244 ,1705 ,1475 , + 1046 ,523 ,1817 ,859 ,1502 ,1003 ,2001 ,428 ,184 ,79 ,571 ,305 ,225 ,1461 ,659 ,821 , + 1265 ,1356 ,1495 ,1920 ,1866 ,344 ,593 ,276 ,1342 ,525 ,433 ,526 ,1289 ,766 ,1871 ,941 , + 1674 ,892 ,969 ,14 ,1761 ,1765 ,1718 ,1509 ,236 ,411 ,114 ,419 ,1276 ,1574 ,1626 ,109 , + 1076 ,1653 ,432 ,1681 ,1657 ,519 ,1431 ,1064 ,1882 ,1329 ,1397 ,854 ,1387 ,1355 ,348 ,1132 , + 1693 ,1768 ,1448 ,655 ,1992 ,1080 ,125 ,1262 ,74 ,1425 ,1006 ,496 ,1871 ,920 ,1623 ,181 , + 894 ,475 ,533 ,1808 ,258 ,1960 ,677 ,781 ,1662 ,628 ,1446 ,916 ,166 ,1806 ,1642 ,1573 , + 1824 ,820 ,340 ,309 ,1761 ,515 ,1660 ,1370 ,953 ,1259 ,784 ,1985 ,1080 ,479 ,1427 ,1560 , + 1069 ,1431 ,823 ,1472 ,335 ,239 ,762 ,1077 ,523 ,54 ,535 ,827 ,1913 ,1012 ,1447 ,265 , + 599 ,737 ,1938 ,1089 ,1852 ,451 ,1144 ,1721 ,863 ,552 ,125 ,1398 ,610 ,1304 ,1879 ,177 , + 1582 ,1015 ,686 ,1978 ,1599 ,601 ,1465 ,206 ,102 ,740 ,1474 ,350 ,1451 ,1710 ,962 ,909 , + 130 ,986 ,896 ,1271 ,1167 ,1152 ,2014 ,1494 ,1710 ,354 ,227 ,2004 ,1272 ,178 ,1157 ,249 , + 1214 ,1421 ,1165 ,1073 ,315 ,1884 ,187 ,1589 ,895 ,1728 ,1945 ,1834 ,868 ,904 ,1599 ,1670 , + 812 ,774 ,1549 ,343 ,1273 ,1325 ,1898 ,1097 ,474 ,1241 ,1144 ,374 ,1315 ,1040 ,2 ,1138 , + 740 ,1295 ,127 ,1585 ,85 ,429 ,856 ,1932 ,1694 ,2041 ,479 ,74 ,569 ,897 ,804 ,1559 , + 964 ,908 ,1380 ,771 ,1658 ,379 ,154 ,1118 ,1946 ,1849 ,196 ,1084 ,853 ,1209 ,1307 ,1441 , + 1205 ,678 ,1827 ,1073 ,1364 ,101 ,1756 ,1437 ,483 ,242 ,148 ,675 ,1338 ,669 ,1457 ,1601 , + 450 ,1432 ,580 ,306 ,1783 ,493 ,955 ,458 ,136 ,1903 ,1065 ,176 ,1622 ,425 ,190 ,746 , + 960 ,1534 ,1036 ,1107 ,808 ,456 ,1601 ,497 ,1018 ,1140 ,148 ,1627 ,1176 ,954 ,1819 ,1493 , + 2023 ,1374 ,791 ,733 ,2020 ,934 ,715 ,1139 ,2013 ,1988 ,1616 ,1384 ,548 ,306 ,1681 ,599 , + 1558 ,50 ,193 ,371 ,1460 ,650 ,165 ,1129 ,541 ,1875 ,603 ,1281 ,853 ,282 ,1104 ,1582 , + 592 ,897 
,1542 ,1921 ,1690 ,1817 ,1416 ,453 ,1034 ,665 ,846 ,1755 ,596 ,433 ,1095 ,109 , + 1925 ,1973 ,10 ,998 ,1410 ,2002 ,1874 ,1187 ,1475 ,5 ,1821 ,213 ,1766 ,1232 ,1114 ,33 , + 1452 ,1439 ,999 ,1603 ,1109 ,1424 ,1818 ,1813 ,1342 ,90 ,238 ,962 ,481 ,1251 ,1643 ,666 , + 1573 ,332 ,451 ,262 ,1640 ,1085 ,1100 ,143 ,1523 ,1928 ,1419 ,561 ,1148 ,1659 ,153 ,1578 , + 491 ,250 ,1480 ,124 ,1855 ,1105 ,1969 ,1725 ,1386 ,155 ,39 ,332 ,152 ,1323 ,706 ,1212 , + 1443 ,1085 ,2030 ,1733 ,853 ,497 ,1773 ,1329 ,1568 ,1663 ,516 ,1113 ,200 ,1489 ,1951 ,1540 , + 1024 ,578 ,493 ,2014 ,307 ,1513 ,944 ,1892 ,610 ,121 ,972 ,658 ,1551 ,940 ,1744 ,1059 , + 1733 ,1628 ,1487 ,558 ,327 ,1812 ,1978 ,1740 ,1591 ,959 ,916 ,218 ,1975 ,216 ,578 ,1175 , + 231 ,872 ,139 ,887 ,1552 ,1420 ,1506 ,451 ,1674 ,941 ,261 ,651 ,319 ,1451 ,1479 ,1530 , + 1523 ,1950 ,284 ,183 ,124 ,228 ,414 ,1049 ,1102 ,504 ,1193 ,575 ,1506 ,1749 ,790 ,523 , + 1705 ,54 ,1873 ,547 ,932 ,862 ,2000 ,1142 ,927 ,1182 ,687 ,1534 ,1223 ,469 ,2038 ,1212 , + 945 ,736 ,1152 ,420 ,18 ,1960 ,656 ,1030 ,1364 ,429 ,579 ,108 ,354 ,875 ,1998 ,939 , + 1980 ,138 ,690 ,1469 ,745 ,822 ,1665 ,148 ,1634 ,225 ,1027 ,1141 ,1789 ,894 ,1756 ,728 , + 455 ,1986 ,517 ,1162 ,2030 ,1139 ,1309 ,91 ,1553 ,194 ,1616 ,824 ,163 ,49 ,244 ,1593 , + 1729 ,399 ,1990 ,1921 ,887 ,1272 ,1274 ,1619 ,971 ,1703 ,1974 ,1420 ,1127 ,766 ,103 ,1296 , + 762 ,927 ,455 ,1830 ,678 ,349 ,1606 ,1790 ,1479 ,613 ,2002 ,1208 ,214 ,186 ,426 ,1407 , + 1033 ,857 ,307 ,658 ,1081 ,509 ,811 ,1198 ,741 ,1682 ,816 ,1630 ,598 ,1498 ,1519 ,382 , + 907 ,935 ,257 ,1138 ,432 ,1397 ,1587 ,979 ,988 ,1747 ,1720 ,874 ,985 ,1342 ,1268 ,169 , + 970 ,871 ,1902 ,1116 ,531 ,1961 ,1773 ,1207 ,1075 ,663 ,888 ,1435 ,2038 ,1262 ,235 ,102 , + 1742 ,357 ,1881 ,1404 ,944 ,41 ,946 ,1080 ,1199 ,296 ,1072 ,335 ,1480 ,22 ,527 ,1363 , + 1289 ,1014 ,717 ,1020 ,1926 ,1700 ,957 ,848 ,516 ,1436 ,1272 ,725 ,1923 ,101 ,1044 ,462 , + 1854 ,1958 ,964 ,363 ,1955 ,858 ,1619 ,1659 ,1203 ,919 ,1299 ,431 ,1917 ,1045 ,193 ,330 , + 11 ,1065 ,549 ,1266 ,1526 ,1001 ,773 ,80 ,1337 ,1831 ,1745 ,731 ,958 ,463 ,150 ,308 , + 245 ,230 ,1152 ,1866 ,1181 ,401 ,472 ,1267 ,372 ,1372 ,1169 ,327 ,201 ,1547 ,1030 ,1755 , + 1066 ,748 ,1615 ,48 ,403 ,1127 ,1638 ,947 ,1532 ,1531 ,1286 ,346 ,700 ,1616 ,1632 ,1497 , + 523 ,803 ,584 ,686 ,1094 ,231 ,1620 ,26 ,1323 ,2016 ,1363 ,1046 ,1306 ,1216 ,98 ,1498 , + 366 ,341 ,1158 ,115 ,1270 ,1966 ,465 ,874 ,911 ,874 ,1923 ,1608 ,2017 ,1696 ,936 ,1161 , + 749 ,733 ,493 ,1872 ,432 ,2 ,1126 ,1858 ,596 ,357 ,1138 ,1718 ,869 ,125 ,295 ,608 , + 1056 ,1505 ,1900 ,274 ,1101 ,1636 ,1654 ,146 ,1181 ,1072 ,329 ,12 ,1926 ,1410 ,958 ,796 , + 18 ,222 ,1453 ,1467 ,959 ,587 ,1247 ,952 ,1627 ,1240 ,78 ,1543 ,884 ,1132 ,426 ,1771 , + 929 ,1329 ,1011 ,1314 ,202 ,1034 ,795 ,1522 ,618 ,736 ,566 ,1670 ,1424 ,1565 ,1485 ,1657 , + 1734 ,1191 ,339 ,1190 ,894 ,1536 ,1011 ,633 ,1149 ,856 ,1193 ,1746 ,543 ,1421 ,1641 ,1197 , + 1610 ,656 ,1103 ,1178 ,268 ,718 ,464 ,503 ,1742 ,1758 ,558 ,1761 ,951 ,164 ,823 ,1487 , + 1183 ,1939 ,821 ,194 ,1806 ,243 ,1649 ,1220 ,211 ,1935 ,1848 ,1310 ,1720 ,993 ,303 ,1504 , + 1845 ,286 ,1081 ,1461 ,844 ,1335 ,1285 ,491 ,1381 ,916 ,531 ,173 ,820 ,831 ,1472 ,1206 , + 2002 ,1931 ,1650 ,1780 ,684 ,293 ,335 ,848 ,445 ,699 ,261 ,18 ,170 ,1286 ,105 ,1124 , + 1746 ,1005 ,1867 ,725 ,648 ,1534 ,571 ,226 ,361 ,712 ,1659 ,1457 ,1778 ,846 ,697 ,72 , + 1804 ,1559 ,499 ,680 ,1728 ,1982 ,1879 ,1696 ,1397 ,1219 ,289 ,1574 ,213 ,1152 ,1658 ,61 , + 134 ,979 ,237 ,309 ,653 ,1564 ,1216 ,1509 ,1306 ,1569 ,2038 ,1911 ,774 ,1304 ,1667 ,1034 , + 173 ,128 ,1334 ,955 
,1317 ,649 ,1609 ,307 ,68 ,1379 ,424 ,1865 ,226 ,1539 ,624 ,955 , + 890 ,155 ,468 ,1834 ,1135 ,1220 ,1198 ,606 ,677 ,1517 ,1920 ,1210 ,562 ,1716 ,4 ,372 , + 21 ,1785 ,529 ,1829 ,275 ,980 ,1792 ,459 ,609 ,1044 ,1312 ,1193 ,1859 ,1534 ,865 ,1372 , + 1601 ,246 ,718 ,1785 ,932 ,192 ,1475 ,63 ,1689 ,399 ,115 ,1923 ,1903 ,1665 ,27 ,1299 , + 1363 ,396 ,700 ,388 ,1737 ,1095 ,364 ,1046 ,753 ,136 ,1310 ,765 ,261 ,716 ,1266 ,57 , + 599 ,261 ,1251 ,510 ,2032 ,25 ,450 ,1755 ,863 ,1371 ,935 ,569 ,36 ,1638 ,491 ,936 , + 180 ,952 ,1485 ,823 ,1771 ,1838 ,419 ,2018 ,243 ,1036 ,1406 ,948 ,1180 ,1517 ,240 ,839 , + 1635 ,1294 ,1531 ,1794 ,1829 ,2013 ,1355 ,1775 ,345 ,1365 ,669 ,237 ,217 ,764 ,326 ,1524 , + 1305 ,397 ,1255 ,1273 ,1132 ,1289 ,1881 ,990 ,573 ,1217 ,1579 ,458 ,795 ,1053 ,1923 ,107 , + 134 ,1444 ,1702 ,561 ,618 ,1533 ,270 ,1192 ,376 ,149 ,141 ,1644 ,1261 ,543 ,298 ,2019 , + 672 ,1109 ,1172 ,3 ,212 ,139 ,1879 ,1207 ,257 ,172 ,58 ,1317 ,644 ,67 ,1335 ,1146 , + 409 ,1496 ,1484 ,324 ,1727 ,901 ,222 ,522 ,1842 ,1174 ,257 ,1354 ,724 ,1725 ,522 ,670 , + 1033 ,577 ,1521 ,1170 ,1606 ,554 ,39 ,1153 ,1583 ,435 ,41 ,1573 ,389 ,7 ,1888 ,1214 , + 1428 ,1466 ,543 ,1723 ,1832 ,1799 ,1020 ,215 ,334 ,1284 ,1445 ,250 ,1197 ,515 ,1673 ,352 , + 620 ,1001 ,1997 ,147 ,1396 ,458 ,197 ,444 ,283 ,183 ,1790 ,1228 ,751 ,1824 ,1853 ,1577 , + 689 ,1859 ,645 ,1389 ,1581 ,68 ,713 ,1072 ,984 ,1733 ,1285 ,1407 ,19 ,1693 ,1494 ,1678 , + 323 ,1130 ,665 ,382 ,73 ,1133 ,74 ,1699 ,1706 ,303 ,692 ,119 ,1575 ,208 ,1961 ,1903 , + 609 ,1783 ,1220 ,1653 ,581 ,78 ,589 ,1450 ,962 ,1318 ,100 ,398 ,1992 ,330 ,1047 ,1991 , + 1063 ,926 ,1758 ,344 ,624 ,721 ,1453 ,1756 ,142 ,461 ,259 ,1613 ,1682 ,1305 ,632 ,1050 , + 1276 ,1871 ,1111 ,925 ,1439 ,321 ,1423 ,1142 ,1446 ,1673 ,781 ,302 ,19 ,661 ,1238 ,684 , + 1511 ,1855 ,2009 ,76 ,1983 ,387 ,298 ,1785 ,1071 ,407 ,979 ,1718 ,320 ,16 ,186 ,886 , + 1446 ,1474 ,887 ,343 ,1733 ,14 ,788 ,1075 ,1004 ,169 ,817 ,1822 ,1179 ,620 ,178 ,998 , + 1732 ,1322 ,26 ,1259 ,1861 ,1517 ,548 ,414 ,1486 ,1929 ,315 ,35 ,1306 ,1304 ,1910 ,132 , + 438 ,1688 ,981 ,637 ,1939 ,1190 ,1506 ,142 ,1247 ,1205 ,884 ,1209 ,54 ,1812 ,2004 ,571 , + 2042 ,860 ,310 ,859 ,1116 ,467 ,1990 ,43 ,952 ,215 ,352 ,633 ,251 ,909 ,1554 ,785 , + 523 ,328 ,1116 ,1136 ,819 ,1858 ,1807 ,249 ,557 ,570 ,854 ,667 ,137 ,669 ,745 ,1810 , + 676 ,217 ,982 ,1728 ,1234 ,1196 ,553 ,663 ,999 ,1953 ,1415 ,1237 ,1628 ,1093 ,965 ,734 , + 390 ,1051 ,375 ,1802 ,717 ,1543 ,1950 ,179 ,1457 ,713 ,39 ,1517 ,973 ,1028 ,1349 ,1164 , + 965 ,1709 ,162 ,1909 ,1100 ,1717 ,860 ,684 ,319 ,1731 ,738 ,347 ,2042 ,1610 ,1525 ,1455 , + 208 ,1656 ,1334 ,1178 ,1034 ,507 ,1703 ,1886 ,1423 ,362 ,1068 ,14 ,1855 ,1099 ,231 ,1245 , + 1395 ,371 ,621 ,203 ,872 ,841 ,1673 ,1976 ,1584 ,675 ,1174 ,1986 ,643 ,848 ,1354 ,1212 , + 749 ,549 ,849 ,1271 ,1444 ,969 ,1102 ,1949 ,1412 ,482 ,245 ,133 ,1373 ,1011 ,1717 ,1848 , + 1739 ,1513 ,712 ,1519 ,965 ,1042 ,1298 ,278 ,199 ,2020 ,549 ,1251 ,1918 ,1334 ,1978 ,1784 , + 1020 ,1625 ,552 ,135 ,242 ,936 ,624 ,388 ,904 ,820 ,1704 ,242 ,1300 ,914 ,901 ,1119 , + 1669 ,1950 ,1138 ,725 ,608 ,873 ,254 ,270 ,91 ,2 ,1959 ,1446 ,1608 ,559 ,1477 ,1454 , + 1928 ,477 ,1873 ,1411 ,420 ,317 ,819 ,1648 ,1791 ,2000 ,161 ,439 ,1912 ,565 ,220 ,1343 , + 210 ,646 ,1490 ,725 ,1802 ,39 ,1497 ,1774 ,2044 ,1343 ,762 ,1808 ,422 ,1440 ,481 ,1956 , + 973 ,800 ,552 ,1917 ,946 ,1652 ,160 ,582 ,356 ,891 ,1165 ,1655 ,812 ,811 ,644 ,1523 , + 748 ,1874 ,1518 ,228 ,423 ,1905 ,1683 ,1093 ,619 ,599 ,133 ,513 ,941 ,1260 ,989 ,1086 , + 1968 ,1625 ,552 ,1334 ,946 ,300 ,1512 ,1321 ,1686 ,604 ,504 
,1631 ,1275 ,1112 ,1246 ,1448 , + 1562 ,1593 ,1544 ,443 ,1504 ,873 ,749 ,249 ,619 ,1664 ,342 ,139 ,170 ,195 ,1911 ,643 , + 758 ,724 ,552 ,1289 ,894 ,1735 ,1394 ,13 ,1429 ,1903 ,1633 ,667 ,2037 ,1169 ,1816 ,553 , + 1445 ,1364 ,42 ,1551 ,1370 ,57 ,1351 ,1414 ,942 ,1340 ,488 ,1702 ,141 ,1502 ,1308 ,174 , + 1584 ,710 ,214 ,1996 ,1021 ,560 ,1648 ,1341 ,1951 ,1902 ,1372 ,2047 ,967 ,814 ,1238 ,1322 , + 1317 ,299 ,679 ,659 ,1849 ,1822 ,716 ,1656 ,1089 ,94 ,705 ,985 ,787 ,569 ,744 ,1899 , + 1332 ,1738 ,13 ,845 ,1010 ,158 ,194 ,1965 ,889 ,386 ,1343 ,1886 ,134 ,332 ,1567 ,1960 , + 635 ,1993 ,797 ,357 ,1517 ,114 ,1397 ,1844 ,1687 ,1703 ,418 ,116 ,280 ,66 ,965 ,555 , + 348 ,1603 ,1284 ,121 ,1824 ,104 ,720 ,1415 ,351 ,880 ,1106 ,1845 ,697 ,132 ,670 ,1572 , + 678 ,485 ,988 ,934 ,1451 ,1050 ,1953 ,499 ,129 ,1625 ,1192 ,1924 ,1668 ,108 ,891 ,576 , + 20 ,665 ,1146 ,1509 ,389 ,1077 ,493 ,110 ,48 ,535 ,1187 ,1970 ,418 ,1869 ,1548 ,139 , + 478 ,665 ,391 ,1030 ,182 ,1120 ,1984 ,1095 ,1540 ,1637 ,532 ,527 ,1077 ,482 ,696 ,972 , + 1870 ,707 ,1392 ,1113 ,1469 ,1116 ,1436 ,1010 ,768 ,1606 ,1051 ,1745 ,95 ,101 ,135 ,964 , + 1750 ,722 ,1894 ,497 ,1541 ,862 ,1863 ,1500 ,1977 ,1906 ,1435 ,1785 ,644 ,1112 ,938 ,668 , + 853 ,1871 ,444 ,1105 ,1670 ,61 ,240 ,1162 ,154 ,1764 ,1404 ,585 ,1160 ,1796 ,199 ,1161 , + 1356 ,1942 ,1227 ,416 ,1994 ,419 ,174 ,1512 ,1619 ,751 ,1758 ,1892 ,1607 ,1154 ,1200 ,616 , + 1674 ,1899 ,1082 ,597 ,544 ,619 ,1605 ,1055 ,257 ,1584 ,743 ,935 ,1879 ,316 ,1621 ,248 , + 1772 ,697 ,1451 ,1696 ,1488 ,162 ,485 ,1261 ,799 ,1019 ,1689 ,874 ,971 ,605 ,367 ,1532 , + 1674 ,1942 ,1328 ,47 ,1632 ,1343 ,1177 ,1914 ,1428 ,735 ,980 ,322 ,834 ,872 ,1044 ,595 , + 783 ,835 ,1171 ,432 ,880 ,1274 ,966 ,814 ,1732 ,1740 ,359 ,1157 ,868 ,832 ,1721 ,681 , + 1824 ,773 ,1878 ,1370 ,365 ,1605 ,437 ,1932 ,1832 ,821 ,1700 ,1075 ,47 ,1153 ,1724 ,1585 , + 1111 ,213 ,154 ,1094 ,1892 ,279 ,1534 ,811 ,1182 ,466 ,1728 ,1985 ,222 ,274 ,112 ,943 , + 1056 ,150 ,818 ,1300 ,1749 ,274 ,598 ,408 ,947 ,465 ,1510 ,372 ,762 ,1120 ,533 ,68 , + 36 ,1022 ,746 ,1125 ,251 ,1560 ,1205 ,927 ,1582 ,484 ,985 ,633 ,1876 ,143 ,1317 ,1344 , + 1754 ,79 ,1647 ,1462 ,1997 ,1157 ,1507 ,1013 ,1460 ,444 ,435 ,1801 ,48 ,2025 ,1049 ,1379 , + 331 ,675 ,64 ,953 ,1034 ,771 ,1275 ,728 ,2001 ,1966 ,443 ,1752 ,975 ,787 ,1432 ,596 , + 149 ,1441 ,930 ,680 ,925 ,1790 ,1082 ,1746 ,1316 ,1907 ,473 ,37 ,220 ,1512 ,1824 ,1837 , + 1117 ,629 ,306 ,29 ,2037 ,71 ,901 ,1276 ,1144 ,1984 ,564 ,781 ,1693 ,1615 ,2000 ,1540 , + 43 ,1190 ,2039 ,358 ,1468 ,1371 ,1132 ,412 ,826 ,556 ,1174 ,1089 ,649 ,997 ,1476 ,1924 , + 114 ,1981 ,2004 ,1575 ,1562 ,689 ,1445 ,324 ,1835 ,904 ,1500 ,713 ,1785 ,1397 ,757 ,1528 , + 389 ,284 ,959 ,1218 ,752 ,1370 ,1374 ,1077 ,879 ,491 ,1697 ,491 ,19 ,315 ,275 ,970 , + 544 ,1716 ,454 ,1541 ,1317 ,353 ,1622 ,2041 ,479 ,342 ,79 ,1603 ,133 ,1340 ,1050 ,681 , + 609 ,979 ,1676 ,1400 ,187 ,1564 ,1860 ,1954 ,666 ,1581 ,1804 ,1451 ,1415 ,189 ,298 ,1962 , + 624 ,1114 ,2036 ,1941 ,467 ,468 ,101 ,1462 ,1138 ,177 ,349 ,376 ,425 ,130 ,1838 ,63 , + 309 ,809 ,1676 ,885 ,144 ,1722 ,21 ,338 ,630 ,668 ,1691 ,798 ,1310 ,1893 ,429 ,755 , + 191 ,512 ,798 ,685 ,453 ,955 ,2012 ,1253 ,1560 ,1129 ,1275 ,591 ,977 ,1474 ,1662 ,1392 , + 1920 ,142 ,1809 ,1178 ,343 ,1363 ,885 ,1241 ,794 ,1092 ,277 ,151 ,956 ,1976 ,1188 ,528 , + 1152 ,526 ,1957 ,269 ,648 ,1051 ,894 ,219 ,1292 ,1812 ,28 ,825 ,463 ,315 ,476 ,406 , + 760 ,1121 ,337 ,1886 ,503 ,248 ,1023 ,769 ,1549 ,219 ,571 ,1545 ,453 ,1115 ,1039 ,130 , + 1436 ,750 ,1870 ,1455 ,485 ,1850 ,1010 ,1852 ,1324 ,574 ,1941 ,554 ,1741 ,1455 ,163 ,630 
, + 950 ,1933 ,1168 ,2004 ,822 ,130 ,1247 ,1318 ,1451 ,392 ,1901 ,805 ,1207 ,114 ,417 ,1733 , + 1082 ,799 ,1076 ,1911 ,781 ,1633 ,1883 ,676 ,1060 ,1943 ,2015 ,451 ,231 ,622 ,1275 ,654 , + 1870 ,972 ,786 ,407 ,1069 ,1834 ,781 ,1573 ,1506 ,1200 ,318 ,1616 ,792 ,1950 ,1507 ,13 , + 1752 ,563 ,1044 ,846 ,1806 ,755 ,1132 ,253 ,1810 ,1506 ,492 ,560 ,917 ,684 ,200 ,1932 , + 1404 ,412 ,2012 ,487 ,935 ,687 ,571 ,1442 ,1252 ,215 ,578 ,577 ,22 ,1538 ,910 ,1148 , + 1333 ,413 ,18 ,532 ,1823 ,1689 ,1786 ,1984 ,37 ,859 ,1316 ,1008 ,1136 ,2026 ,1290 ,743 , + 277 ,1051 ,547 ,1178 ,296 ,297 ,1645 ,105 ,494 ,329 ,622 ,493 ,459 ,887 ,720 ,341 , + 594 ,814 ,724 ,1951 ,1455 ,1402 ,1674 ,1357 ,1712 ,829 ,1467 ,722 ,1145 ,307 ,309 ,126 , + 140 ,218 ,402 ,1307 ,1474 ,1919 ,1375 ,1568 ,93 ,420 ,93 ,401 ,622 ,1092 ,1637 ,1101 , + 110 ,1402 ,55 ,1495 ,1676 ,1913 ,1751 ,195 ,454 ,1681 ,826 ,539 ,1503 ,261 ,387 ,1654 , + 1556 ,1414 ,1251 ,1461 ,1695 ,531 ,1155 ,1552 ,1843 ,1987 ,1758 ,18 ,169 ,1906 ,1156 ,181 , + 438 ,1394 ,1659 ,1811 ,277 ,2031 ,478 ,1620 ,964 ,424 ,858 ,304 ,283 ,1568 ,517 ,2040 , + 1080 ,501 ,1799 ,856 ,1610 ,1257 ,377 ,723 ,408 ,1599 ,913 ,1688 ,952 ,1972 ,1654 ,1999 , + 1884 ,617 ,432 ,1754 ,1486 ,1873 ,123 ,1722 ,1247 ,515 ,1470 ,923 ,1984 ,1446 ,280 ,687 , + 1765 ,1955 ,46 ,81 ,1600 ,1077 ,1325 ,147 ,1138 ,1173 ,1401 ,1101 ,1116 ,1826 ,868 ,1542 , + 531 ,218 ,459 ,1583 ,8 ,1616 ,1355 ,1458 ,1970 ,715 ,1167 ,1219 ,1726 ,1137 ,1174 ,1166 , + 1004 ,349 ,1264 ,1178 ,1574 ,1635 ,342 ,1163 ,337 ,1149 ,1068 ,1965 ,838 ,1937 ,903 ,1190 , + 1173 ,200 ,910 ,377 ,429 ,631 ,776 ,1106 ,126 ,142 ,143 ,1723 ,566 ,1904 ,531 ,2038 , + 1262 ,1739 ,1264 ,1870 ,1878 ,1904 ,475 ,1333 ,873 ,1362 ,1852 ,719 ,490 ,1838 ,1587 ,1213 , + 1380 ,399 ,2016 ,133 ,1784 ,1612 ,1818 ,117 ,95 ,625 ,1239 ,1894 ,585 ,1567 ,1591 ,1984 , + 1660 ,809 ,1527 ,1887 ,1689 ,657 ,946 ,1211 ,449 ,1613 ,1328 ,781 ,447 ,680 ,1074 ,1078 , + 1239 ,481 ,1620 ,1299 ,1780 ,354 ,1779 ,1386 ,863 ,1188 ,1959 ,107 ,390 ,875 ,421 ,436 , + 30 ,1706 ,127 ,252 ,1434 ,927 ,806 ,1705 ,1749 ,672 ,1072 ,1562 ,966 ,1182 ,324 ,1916 , + 1076 ,1150 ,267 ,396 ,1010 ,79 ,1444 ,1316 ,1387 ,899 ,1087 ,875 ,367 ,575 ,1982 ,798 , + 1879 ,1572 ,319 ,1089 ,848 ,380 ,1235 ,293 ,1418 ,982 ,587 ,821 ,1090 ,1752 ,350 ,1398 , + 173 ,50 ,1066 ,94 ,857 ,1462 ,1664 ,1671 ,1305 ,564 ,1334 ,423 ,193 ,1545 ,1215 ,919 , + 1000 ,1552 ,1805 ,1490 ,1577 ,337 ,636 ,555 ,1072 ,898 ,808 ,1989 ,256 ,1034 ,561 ,1202 , + 803 ,122 ,177 ,1903 ,1397 ,1552 ,1750 ,91 ,1543 ,794 ,1632 ,577 ,996 ,153 ,595 ,654 , + 862 ,1859 ,1144 ,1326 ,1485 ,539 ,1709 ,1583 ,19 ,1242 ,838 ,871 ,1622 ,911 ,347 ,310 , + 1423 ,1173 ,192 ,678 ,1085 ,1395 ,1173 ,1283 ,849 ,1946 ,303 ,999 ,1290 ,1579 ,472 ,1491 , + 1026 ,112 ,791 ,1381 ,1390 ,83 ,365 ,1065 ,1047 ,8 ,1660 ,1717 ,1787 ,1554 ,689 ,565 , + 552 ,262 ,502 ,1270 ,964 ,1276 ,1707 ,1763 ,1610 ,265 ,597 ,423 ,1824 ,1522 ,18 ,1424 , + 609 ,520 ,549 ,1321 ,1568 ,1839 ,688 ,1664 ,1929 ,1109 ,328 ,519 ,1499 ,914 ,1411 ,1815 , + 889 ,515 ,1665 ,826 ,841 ,693 ,763 ,1633 ,1091 ,1636 ,1682 ,1610 ,752 ,1566 ,1774 ,1159 , + 919 ,1132 ,1175 ,19 ,456 ,1828 ,967 ,603 ,795 ,828 ,188 ,654 ,291 ,510 ,1740 ,787 , + 513 ,1576 ,293 ,1399 ,973 ,751 ,1101 ,1344 ,1649 ,1699 ,1478 ,554 ,112 ,1411 ,817 ,312 , + 599 ,756 ,802 ,343 ,438 ,1358 ,1376 ,1970 ,449 ,1180 ,1544 ,1674 ,955 ,1030 ,1627 ,1779 , + 638 ,567 ,1191 ,2022 ,1636 ,840 ,353 ,45 ,592 ,1118 ,1711 ,1884 ,1396 ,1928 ,349 ,545 , + 1705 ,54 ,617 ,959 ,250 ,1728 ,2030 ,565 ,1505 ,158 ,2045 ,1393 ,1242 ,1767 ,378 ,1502 , + 521 ,1695 
+    521 ,1695 ,1466 ,149 ,959 ,687 ,914 ,1776 ,960 ,1029 ,661 ,788 ,1557 ,1027 ,1721 ,586 ,
+    2047 ,130 ,1902 ,1283 ,403 ,1225 ,460 ,105 ,489 ,1293 ,1846 ,1499 ,608 ,244 ,1131 ,401 ,
+    1985 ,1844 ,660 ,1259 ,1586 ,1195 ,1782 ,572 ,455 ,1427 ,1989 ,1905 ,412 ,1784 ,746 ,1060 ,
+    90 ,1636 ,19 ,914 ,1176 ,1496 ,154 ,168 ,771 ,1722 ,1158 ,1174 ,2022 ,1806 ,1344 ,759 ,
+    1279 ,1467 ,945 ,666 ,487 ,1409 ,1999 ,1259 ,930 ,45 ,273 ,540 ,1014 ,272 ,1108 ,1605 ,
+    1223 ,1961 ,401 ,655 ,1065 ,80 ,1652 ,1075 ,1103 ,150 ,949 ,579 ,465 ,1678 ,657 ,1298 ,
+    702 ,1800 ,396 ,1583 ,1296 ,1974 ,306 ,1366 ,492 ,911 ,1346 ,1259 ,1343 ,1109 ,329 ,1589 ,
+    1565 ,1882 ,1314 ,353 ,1773 ,251 ,30 ,510 ,781 ,1187 ,961 ,1473 ,1550 ,381 ,63 ,1826 ,
+    1275 ,1842 ,1138 ,1747 ,751 ,402 ,602 ,167 ,1189 ,54 ,576 ,1974 ,466 ,537 ,805 ,1117 ,
+};
diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp
index 937948425cdc4..d47b0e598dd28 100644
--- a/examples/tts/tts-csm.cpp
+++ b/examples/tts/tts-csm.cpp
@@ -1,9 +1,12 @@
+#include "ggml.h"
 #include "llama.h"
 #include "common.h"
 #include "log.h"
 #include "arg.h"
 #include "mimi-model.h"
+#include "tts-csm-data.h"
+#include
 #include
 #include
 #include
@@ -32,8 +35,14 @@ static void print_usage(int, char ** argv) {
     LOG("\n");
 }
+struct speaker_turn {
+    std::string text;
+    std::vector<float> audio_embd; // only used for system prompt (speaker reference) processing
+    size_t n_embd_tokens = 0;
+};
+
 // split text containing "[N]..." into speaker turns
-static std::vector<std::string> get_speaker_turns(const std::string & input) {
+static std::vector<speaker_turn> get_speaker_turns(const std::string & input) {
     if (input.empty()) {
         LOG_ERR("Empty input\n");
         return {};
     }
@@ -44,19 +53,60 @@ static std::vector<std::string> get_speaker_turns(const std::string & input) {
     }
     std::regex re(R"((\[\d+\][\s\S]*?)(?=\[\d+\]|$))");
     std::smatch match;
-    std::vector<std::string> turns;
+    std::vector<speaker_turn> turns;
     std::string::const_iterator searchStart(input.cbegin());
     while (std::regex_search(searchStart, input.cend(), match, re)) {
-        std::string turn = match[1].str();
-        if (turn.empty()) {
+        std::string turn_text = match[1].str();
+        if (turn_text.empty()) {
             continue;
         }
+        // clean up newline, the model is quite sensitive to this
+        string_replace_all(turn_text, "\n", " ");
+        turn_text = string_strip(turn_text);
+        // add turn
+        speaker_turn turn;
+        turn.text = turn_text;
         turns.push_back(turn);
         searchStart = match.suffix().first;
     }
     return turns;
 }
 
+static speaker_turn get_ref_speaker_turn(const char * text, std::initializer_list<int> & codes, std::vector<float> & codebook) {
+    const size_t n_embd = 2048;
+    const size_t n_codes_per_codebook = 2051;
+    const size_t n_codebooks = 32;
+    GGML_ASSERT(codebook.size() == n_embd * n_codes_per_codebook * n_codebooks);
+    GGML_ASSERT(codes.size() % 32 == 0);
+
+    // 1 frame = 32 codes
+    size_t n_frames = codes.size() / n_codebooks;
+    speaker_turn turn;
+    turn.text = text;
+    turn.audio_embd.reserve((n_frames+1) * n_embd);
+    turn.n_embd_tokens = n_frames+1; // +1 for EOS frame
+
+    for (size_t i_fr = 0; i_fr <= n_frames; i_fr++) {
+        std::vector<float> frame_embd_sum(n_embd, 0.0f);
+
+        for (size_t i_cb = 0; i_cb < n_codebooks; i_cb++) {
+            const size_t code = i_fr == n_frames
+                ? 0 // insert audio EOS for last pseudo-frame
+                : codes.begin()[i_fr*n_codebooks + i_cb];
+            printf("code %zu: %zu, codebook entry %zu\n", i_cb, code, i_cb*n_codes_per_codebook + code);
+            float * entry = codebook.data() + i_cb*n_codes_per_codebook*n_embd + code*n_embd;
+            for (size_t i_embd = 0; i_embd < n_embd; i_embd++) {
+                frame_embd_sum[i_embd] += entry[i_embd];
+            }
+        }
+
+        turn.audio_embd.insert(turn.audio_embd.end(), frame_embd_sum.begin(), frame_embd_sum.end());
+    }
+
+    GGML_ASSERT(turn.audio_embd.size() == (n_frames+1) * n_embd);
+    return turn;
+}
+
 // sampling with custom n_vocab
 // modified version of llama_sampler_sample()
 static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) {
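Note: `get_ref_speaker_turn` above treats the codebook captured from the model as one flat buffer of n_codebooks x n_codes_per_codebook x n_embd floats, and builds each reference frame embedding by summing one codebook row per RVQ level. The following is a minimal, runnable sketch of just that indexing, with toy dimensions in place of the patch's 32/2051/2048; all names and values here are illustrative, not part of the patch:

```cpp
#include <cstdio>
#include <vector>

int main() {
    // toy sizes; the patch uses n_codebooks = 32, n_codes_per_codebook = 2051, n_embd = 2048
    const size_t n_codebooks = 2, n_codes_per_codebook = 4, n_embd = 3;

    // flattened codebook: row (i_cb, code) starts at (i_cb*n_codes_per_codebook + code) * n_embd
    std::vector<float> codebook(n_codebooks * n_codes_per_codebook * n_embd);
    for (size_t i = 0; i < codebook.size(); i++) codebook[i] = (float) i;

    // one frame = one code per codebook; its embedding is the sum of the selected rows
    const size_t frame_codes[] = {1, 3};
    std::vector<float> frame_embd(n_embd, 0.0f);
    for (size_t i_cb = 0; i_cb < n_codebooks; i_cb++) {
        const float * row = codebook.data() + (i_cb*n_codes_per_codebook + frame_codes[i_cb])*n_embd;
        for (size_t i = 0; i < n_embd; i++) {
            frame_embd[i] += row[i];
        }
    }
    for (size_t i = 0; i < n_embd; i++) printf("%g ", frame_embd[i]); // prints 24 26 28
    printf("\n");
    return 0;
}
```

The pointer arithmetic is the same as the `entry` computation in the patch, just with the multiplication by `n_embd` factored out.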
@@ -80,24 +130,75 @@ static llama_token sample_token(struct llama_sampler * smpl, const float * logit
     return token;
 }
 
+struct hook_data {
+    std::vector<float> embd;
+    std::vector<float> codebook;
+};
+
 // hook to retrieve the embeddings
 static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
-    std::vector<float> * embd = (std::vector<float> *) user_data;
+    hook_data * data = (hook_data *) user_data;
 
     // output_csm_proj is the embeddings output from backbone
     // output_audio_embd is the embeddings output from decoder
     if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) {
         if (ask) return true;
 
-        embd->resize(ggml_nelements(t));
-        ggml_backend_tensor_get(t, embd->data(), 0, ggml_nbytes(t));
+        GGML_ASSERT(t->type == GGML_TYPE_F32);
+        data->embd.resize(ggml_nelements(t));
+        ggml_backend_tensor_get(t, data->embd.data(), 0, ggml_nbytes(t));
         // printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]);
         return true;
     }
 
+    if (t && strncmp(t->name, "audio_embd.weight", 18) == 0) {
+        if (ask) return true;
+
+        printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]);
+        GGML_ASSERT(t->type == GGML_TYPE_F32);
+        GGML_ASSERT(t->ne[0] == 2048); // backbone embd size
+        data->codebook.resize(ggml_nelements(t));
+        ggml_backend_tensor_get(t, data->codebook.data(), 0, ggml_nbytes(t));
+        return true;
+    }
+
     return false;
 }
 
+// convenience wrapper around llama_batch to handle memory allocation
+struct decode_embd_batch {
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id>   seq_id_0;
+    std::vector<llama_seq_id *> seq_ids;
+    std::vector<int8_t>         logits;
+    llama_batch batch;
+    decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
+        pos     .resize(n_tokens);
+        n_seq_id.resize(n_tokens);
+        seq_ids .resize(n_tokens + 1);
+        logits  .resize(n_tokens);
+        seq_id_0.resize(1);
+        seq_id_0[0] = seq_id;
+        seq_ids [n_tokens] = nullptr;
+        batch = {
+            /*n_tokens =*/ n_tokens,
+            /*tokens   =*/ nullptr,
+            /*embd     =*/ embd,
+            /*pos      =*/ pos.data(),
+            /*n_seq_id =*/ n_seq_id.data(),
+            /*seq_id   =*/ seq_ids.data(),
+            /*logits   =*/ logits.data(),
+        };
+        for (int i = 0; i < n_tokens; i++) {
+            batch.pos     [i] = pos_0 + i;
+            batch.n_seq_id[i] = 1;
+            batch.seq_id  [i] = seq_id_0.data();
+            batch.logits  [i] = false;
+        }
+    }
+};
+
 int main(int argc, char ** argv) {
     common_params params;
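Note: `decode_embd_batch` exists because the `common_batch_*` helpers only deal in token ids, while this example needs to submit raw embeddings through `llama_batch::embd`. A hedged usage sketch, assuming an already-initialized `llama_context * ctx`; the position and token count are illustrative:

```cpp
// feed 3 precomputed embedding "tokens" into sequence 0, starting at position 10
const int32_t n_tokens = 3;
const int32_t n_embd   = llama_model_n_embd(llama_get_model(ctx));
std::vector<float> embd(n_tokens * n_embd); // would hold real frame embeddings

decode_embd_batch batch_embd(embd.data(), n_tokens, /*pos_0=*/10, /*seq_id=*/0);
batch_embd.batch.logits[n_tokens - 1] = true; // optionally request logits for the last position
if (llama_decode(ctx, batch_embd.batch) != 0) {
    LOG_ERR("llama_decode(embeddings) failed\n");
}
```

The constructor leaves all `logits` flags false, which is exactly what the speaker-reference path below wants; flipping the last flag is only needed when the caller intends to sample from the result.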
@@ -127,10 +228,9 @@
         return 1;
     }
 
-    std::vector<float> embd;
+    hook_data cb_data;
     params.cb_eval = ggml_callback;
-    params.cb_eval_user_data = &embd;
-    params.warmup = false;
+    params.cb_eval_user_data = &cb_data;
 
     common_params params_decoder(params); // duplicate the params
     params_decoder.n_ctx = 64; // we never use more than this
@@ -177,15 +277,22 @@
     std::vector<llama_token> generated_codes;
 
-    auto turns = get_speaker_turns(params.prompt);
+    std::vector<speaker_turn> turns;
+    // speaker reference
+    turns.push_back(get_ref_speaker_turn(default_speaker_a_text, default_speaker_a_codes, cb_data.codebook));
+    turns.push_back(get_ref_speaker_turn(default_speaker_b_text, default_speaker_b_codes, cb_data.codebook));
+
+    // user input
+    auto custom_turns = get_speaker_turns(params.prompt);
+    turns.insert(turns.end(), custom_turns.begin(), custom_turns.end());
 
-    for (const std::string & turn : turns) {
+    for (speaker_turn & turn : turns) {
         // tokenize the turn
         llama_tokens prompt_tokens;
         {
-            printf("\n---\nturn: %s\n\n", turn.c_str());
+            printf("\n---\n\nturn: %s\n\n", turn.text.c_str());
             const llama_vocab * vocab = llama_model_get_vocab(model_bb);
-            prompt_tokens = common_tokenize(vocab, turn, false, true);
+            prompt_tokens = common_tokenize(vocab, turn.text, false, true);
             prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab));
             prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab));
 
@@ -193,21 +300,38 @@
             for (size_t i = 0; i < prompt_tokens.size(); ++i) {
                 printf("%d, ", prompt_tokens[i]);
             }
-            printf("\n");
+            printf("\n\n");
 
             common_batch_clear(batch_prompt);
             for (size_t i = 0; i < prompt_tokens.size(); ++i) {
                 common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false);
             }
             batch_prompt.logits[batch_prompt.n_tokens - 1] = true;
+
+            if (llama_decode(ctx_bb, batch_prompt) != 0) {
+                LOG_ERR("%s: backbone llama_decode(text) failed\n", __func__);
+                return 1;
+            }
+        }
+
+        // optionally process the system prompt (speaker reference)
+        if (turn.n_embd_tokens) {
+            decode_embd_batch batch_embd(turn.audio_embd.data(), turn.n_embd_tokens, n_past_bb, 0);
+            if (llama_decode(ctx_bb, batch_embd.batch) != 0) {
+                LOG_ERR("%s: backbone llama_decode(embeddings) failed\n", __func__);
+                return 1;
+            }
+            LOG_INF("%s: backbone done decoding %zu audio codes\n\n", __func__, turn.n_embd_tokens);
+            n_past_bb += turn.n_embd_tokens;
+            continue; // no need to generate the audio
         }
 
         // backbone generation loop
         bool is_end_of_turn = false;
         for (int k = 0; k < params.n_predict; ++k) {
-            bool is_prompt_processing = k == 0;
+            bool is_first_tok = k == 0;
 
-            if (!is_prompt_processing) {
+            if (!is_first_tok) {
                 // generate the next RVQ semantic token
                 batch_past_embd.n_tokens = 1;
                 batch_past_embd.pos[0] = n_past_bb++;
@@ -215,15 +339,15 @@
                 batch_past_embd.n_seq_id[0] = 1;
                 batch_past_embd.logits[0] = true;
                 std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float));
-            }
 
-            int64_t t_bb_start = ggml_time_ms();
-            if (llama_decode(ctx_bb, is_prompt_processing ? batch_prompt : batch_past_embd) != 0) {
-                LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
-                return 1;
+                int64_t t_bb_start = ggml_time_ms();
+                if (llama_decode(ctx_bb, batch_past_embd) != 0) {
+                    LOG_ERR("%s: backbone llama_decode() failed\n", __func__);
+                    return 1;
+                }
+                n_bb_gen++;
+                t_bb += ggml_time_ms() - t_bb_start;
             }
-            n_bb_gen++;
-            t_bb += ggml_time_ms() - t_bb_start;
 
             if (is_end_of_turn) {
                 // done decoding audio's EOS token
@@ -231,7 +355,7 @@
             }
 
             auto vocab_dc = llama_model_get_vocab(model_dc);
-            auto logits = llama_get_logits_ith(ctx_bb, is_prompt_processing ? (batch_prompt.n_tokens - 1) : 0);
+            auto logits = llama_get_logits_ith(ctx_bb, is_first_tok ? (batch_prompt.n_tokens - 1) : 0);
             // for (size_t i = 0; i < 10; ++i) {
             //     printf("%4.2f, ", logits[i]);
             // }
@@ -251,7 +375,7 @@
         inp_past_embd = std::vector<float>(inp_past_embd.size(), 0.0f);
         {
             llama_kv_self_clear(ctx_dc);
-            llama_batch batch_embd = llama_batch_init(1, embd.size(), 1);
+            llama_batch batch_embd = llama_batch_init(1, cb_data.embd.size(), 1);
             llama_batch batch_token = llama_batch_init(1, 0, 1);
 
             // first "token" is the latent embeddings from backbone
@@ -261,7 +385,7 @@
                 batch_embd.seq_id[0][0] = 0;
                 batch_embd.n_seq_id[0] = 1;
                 batch_embd.logits[0] = false;
-                std::memcpy(batch_embd.embd, embd.data(), embd.size() * sizeof(float));
+                std::memcpy(batch_embd.embd, cb_data.embd.data(), cb_data.embd.size() * sizeof(float));
             }
             if (llama_decode(ctx_dc, batch_embd) != 0) {
                 LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__);
@@ -299,9 +423,9 @@
             }
 
             // do progressive hsum of embeddings
-            GGML_ASSERT(inp_past_embd.size() == embd.size());
+            GGML_ASSERT(inp_past_embd.size() == cb_data.embd.size());
             for (size_t i = 0; i < inp_past_embd.size(); ++i) {
-                inp_past_embd[i] += embd[i];
+                inp_past_embd[i] += cb_data.embd[i];
             }
         }
         printf("\n");
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 81d460d55e75c..54aa80ce3dbd2 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -4639,6 +4639,14 @@ struct llm_build_llama_csm : public llm_graph_context
         inpL = build_inp_embd(model.tok_embd);
 
+        // hacky way to get the audio embedding from user code (used in prompt processing)
+        // this will be triggered during warmup
+        if (is_decoder && n_tokens == 2) {
+            ggml_tensor * tmp = ggml_cast(ctx0, model.tok_embd, GGML_TYPE_F32);
+            cb(tmp, "audio_embd.weight", -1);
+            ggml_build_forward_expand(gf, tmp);
+        }
+
         ggml_tensor * input_embd = inpL;
 
         // inp_pos - contains the positions
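Note: this is the end of the main CSM patch. One detail worth calling out: a frame's input embedding for the backbone is not produced in one shot; the RVQ decoder runs once per codebook, and each step's `output_audio_embd` is accumulated into `inp_past_embd` (the "progressive hsum" above). A toy, runnable recap of just that accumulation; sizes and values here are made up:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const size_t n_embd = 4, n_codebooks = 3;
    std::vector<float> inp_past_embd(n_embd, 0.0f);

    // each decoder step yields one embedding; the running sum over all
    // codebook steps forms the frame embedding fed back to the backbone
    for (size_t i_cb = 0; i_cb < n_codebooks; i_cb++) {
        std::vector<float> embd(n_embd, (float)(i_cb + 1)); // stand-in for cb_data.embd
        for (size_t i = 0; i < n_embd; i++) {
            inp_past_embd[i] += embd[i];
        }
    }
    for (float v : inp_past_embd) printf("%g ", v); // prints 6 6 6 6
    printf("\n");
    return 0;
}
```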
From 9533fb752cb6e7a5d57436d015be41134ba8ead7 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 23 Apr 2025 14:30:16 +0200
Subject: [PATCH 29/31] fix build_attn

---
 src/llama-model.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 5c4be5b8c3729..cd549e986c2a9 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -4783,7 +4783,7 @@ struct llm_build_llama_csm : public llm_graph_context
             cur = build_attn(inp_attn, gf,
                     model.layers[il].wo, model.layers[il].bo,
-                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                    Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
         }
 
         if (il == n_layer - 1) {

From e5bb5606976dd305aea072adc721451d8ab7d00a Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 23 Apr 2025 14:44:45 +0200
Subject: [PATCH 30/31] rm print

---
 examples/tts/tts-csm.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp
index d47b0e598dd28..d9a5ef1102d89 100644
--- a/examples/tts/tts-csm.cpp
+++ b/examples/tts/tts-csm.cpp
@@ -154,7 +154,7 @@ static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) {
     if (t && strncmp(t->name, "audio_embd.weight", 18) == 0) {
         if (ask) return true;
 
-        printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]);
+        // printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]);
         GGML_ASSERT(t->type == GGML_TYPE_F32);
         GGML_ASSERT(t->ne[0] == 2048); // backbone embd size
         data->codebook.resize(ggml_nelements(t));
         ggml_backend_tensor_get(t, data->codebook.data(), 0, ggml_nbytes(t));
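Note: patches 28 and 30 both lean on llama.cpp's graph-eval callback (`common_params::cb_eval`), which ggml invokes twice per tensor: first with `ask == true` to check whether the callback wants the tensor at all, then again once the data has been computed. A minimal sketch of capturing one named tensor this way, mirroring the `ggml_callback`/`hook_data` pair above; it assumes `ggml.h` and `common.h` are included along with `<cstring>`/`<vector>`, and the tensor name is the one this PR emits:

```cpp
// capture one tensor by name from the compute graph
static bool capture_cb(struct ggml_tensor * t, bool ask, void * user_data) {
    auto * out = (std::vector<float> *) user_data;
    if (t && strcmp(t->name, "output_csm_proj") == 0) {
        if (ask) return true;            // pass 1: declare interest in this tensor
        out->resize(ggml_nelements(t));  // pass 2: data is ready, copy it out
        ggml_backend_tensor_get(t, out->data(), 0, ggml_nbytes(t));
        return true;
    }
    return false; // not interested in anything else
}

// wiring, as done in tts-csm.cpp's main():
//   std::vector<float> captured;
//   params.cb_eval           = capture_cb;
//   params.cb_eval_user_data = &captured;
```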
From c1cd710f592e6f825d8d9e199479ece0c53e2259 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Wed, 23 Apr 2025 14:56:17 +0200
Subject: [PATCH 31/31] fix pyright

---
 examples/tts/csm_generate_speaker.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/examples/tts/csm_generate_speaker.py b/examples/tts/csm_generate_speaker.py
index 0dc6929a23d4c..a06dee6846eac 100644
--- a/examples/tts/csm_generate_speaker.py
+++ b/examples/tts/csm_generate_speaker.py
@@ -3,6 +3,7 @@
 from transformers import MimiModel, AutoFeatureExtractor
 from transformers.models.mimi.modeling_mimi import MimiEncoderOutput
 
+# pyright: reportMissingImports=false
 from scipy.io.wavfile import read
 from scipy.signal import resample
 import numpy as np