diff --git a/.gitignore b/.gitignore index 2c67ad7f7c609..41fe1f31271d2 100644 --- a/.gitignore +++ b/.gitignore @@ -107,6 +107,7 @@ examples/server/*.gz.hpp !examples/*/*/*.kts !examples/sycl/*.bat !examples/sycl/*.sh +/*.wav # Server Web UI temporary files node_modules diff --git a/common/common.cpp b/common/common.cpp index 94f545f815c27..b5668ddfdb2c9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1565,3 +1565,31 @@ common_control_vector_data common_control_vector_load(const std::vector & data, int sample_rate) { + std::ofstream file(fname, std::ios::binary); + if (!file) { + LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); + return false; + } + + wav_header header; + header.sample_rate = sample_rate; + header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); + header.block_align = header.num_channels * (header.bits_per_sample / 8); + header.data_size = data.size() * (header.bits_per_sample / 8); + header.chunk_size = 36 + header.data_size; + + file.write(reinterpret_cast(&header), sizeof(header)); + + for (const auto & sample : data) { + int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); + file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); + } + + return file.good(); +} diff --git a/common/common.h b/common/common.h index e6eaa8e80cf05..9012e657fbefc 100644 --- a/common/common.h +++ b/common/common.h @@ -662,3 +662,25 @@ const char * const LLM_KV_SPLIT_COUNT = "split.count"; const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; } + +// +// Audio utils +// + +struct wav_header { + char riff[4] = {'R', 'I', 'F', 'F'}; + uint32_t chunk_size; + char wave[4] = {'W', 'A', 'V', 'E'}; + char fmt[4] = {'f', 'm', 't', ' '}; + uint32_t fmt_chunk_size = 16; + uint16_t audio_format = 1; // PCM + uint16_t num_channels = 1; // Mono + uint32_t sample_rate; + uint32_t byte_rate; + uint16_t block_align; + uint16_t bits_per_sample = 16; + char data[4] = {'d', 'a', 't', 'a'}; + uint32_t data_size; +}; + +bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate); diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt index c72bd814c3b31..e66c298db461a 100644 --- a/examples/tts/CMakeLists.txt +++ b/examples/tts/CMakeLists.txt @@ -3,3 +3,20 @@ add_executable(${TARGET} tts.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) + +add_library(mimi-model STATIC mimi-model.h mimi-model.cpp) +target_link_libraries(mimi-model PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT}) +# for using C++ designated initializers, TODO: can be changed back to C++17 in the future +target_compile_features(mimi-model PRIVATE cxx_std_20) + +set(TARGET llama-mimi) +add_executable(${TARGET} mimi.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common mimi-model ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +set(TARGET llama-tts-csm) +add_executable(${TARGET} tts-csm.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE llama common mimi-model ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/tts/README-csm.md b/examples/tts/README-csm.md new file mode 100644 index 0000000000000..676b9889e157d --- /dev/null +++ b/examples/tts/README-csm.md @@ -0,0 +1,47 @@ +# Sesame CSM + +This demo shows running inference of [Sesame CSM](https://github.com/SesameAILabs/csm) using llama.cpp / GGML + +It contains 3 components (each has its own GGUF file): +1. Backbone LLM +2. Decoder LLM +3. Mimi decoder + +## Quick start + +By default, all GGUF files are downloaded from [ggml-org Hugging Face's account](https://huggingface.co/ggml-org/sesame-csm-1b-GGUF) + +```sh +# build (make sure to have LLAMA_CURL enabled) +cmake -B build -DLLAMA_CURL=ON +cmake --build build -j --target llama-tts-csm + +# run it +./build/bin/llama-tts-csm -p "[0]Hi, my name is Xuan Son. I am software engineer at Hugging Face." +``` + +## Convert the model yourself + +To get the GGUF: + +```sh +python examples/tts/convert_csm_to_gguf.py + +# default output files: +# sesame-csm-backbone.gguf +# sesame-csm-decoder.gguf + +# optionally, quantize it +# (lowest scheme is q8_0, it does not make sense to quantize further, quality degrades too much) +python examples/tts/convert_csm_to_gguf.py --outtype q8_0 +``` + +Run the example using local file: + +```sh +./build/bin/llama-tts-csm -m sesame-csm-backbone.gguf -mv kyutai-mimi.gguf -p "[0]Hello world." +# sesame-csm-backbone.gguf will automatically be loaded +# make sure the place these 2 GGUF files in the same directory + +# output file: output.wav +``` diff --git a/examples/tts/README-mimi.md b/examples/tts/README-mimi.md new file mode 100644 index 0000000000000..6576a118291ad --- /dev/null +++ b/examples/tts/README-mimi.md @@ -0,0 +1,50 @@ +# llama.cpp/example/mimi + +This demonstrates running [Kyutai's Mimi](https://huggingface.co/kyutai/mimi) model via GGML. + +## Quickstart + +Convert model to GGUF (no need to download, the script will automatically download the `safetensors` file) + +```sh +python examples/tts/convert_mimi_to_gguf.py + +# output file: kyutai-mimi.gguf + +# optionally, use q8_0 quantization for faster speed +python examples/tts/convert_mimi_to_gguf.py --outtype q8_0 +``` + +Then compile, run it: + +```sh +cmake --build build -j --target llama-mimi + +./build/bin/llama-mimi kyutai-mimi.gguf codes.txt + +# output: output.wav + +# alternatively, use "dummy1" to get a "wah hello there" sample output file +./build/bin/llama-mimi kyutai-mimi.gguf dummy1 +``` + +Example of code file (one code per line): + +``` +1263 +1597 +1596 +1477 +1540 +1720 +1433 +118 +1066 +1968 +1096 +232 +418 +566 +1653 +2010 +``` diff --git a/examples/tts/convert_csm_to_gguf.py b/examples/tts/convert_csm_to_gguf.py new file mode 100644 index 0000000000000..53f586f19962d --- /dev/null +++ b/examples/tts/convert_csm_to_gguf.py @@ -0,0 +1,328 @@ +import os +import sys +import argparse +import logging +import torch +from safetensors.torch import load_file +from typing import Union, Any, Dict +from pathlib import Path +from torch import Tensor +from huggingface_hub import hf_hub_download + +cur_path = sys.path +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent.parent.parent / 'gguf-py')) +import gguf + +sys.path = cur_path + +logger = logging.getLogger("csm") + + +# This converts directly one safetensors file to 2 GGUFs +# It is easier to do this way, rather than convert to 2 smaller HF models and then convert to GGUF +# This is because the Sesame model does not have built-in tokenizer + +def get_field_data(reader: gguf.GGUFReader, key: str) -> Any: + field = reader.get_field(key) + return field.contents() if field else None + +# copied from https://github.com/SesameAILabs/csm/blob/main/models.py +class Llama_3_2_1B: + vocab_size=128_256 + num_layers=16 + num_heads=32 + num_kv_heads=8 + embed_dim=2048 + max_seq_len=2048 + intermediate_dim=8192 + attn_dropout=0.0 + norm_eps=1e-5 + rope_base=500_000 + scale_factor=32 + + def write_gguf_metadata(self, fout: gguf.GGUFWriter, fvocab: gguf.GGUFReader): + arch = get_field_data(fvocab, gguf.Keys.General.ARCHITECTURE) + assert arch == "llama" + fout.add_type("model") + fout.add_block_count(self.num_layers) + fout.add_context_length(self.max_seq_len) + fout.add_feed_forward_length(self.intermediate_dim) + fout.add_embedding_length(self.embed_dim) + # attn + fout.add_head_count(self.num_heads) + fout.add_head_count_kv(self.num_kv_heads) + fout.add_rope_freq_base(self.rope_base) + # fout.add_rope_scaling_factor(self.scale_factor) # breaks if this is added + fout.add_rope_dimension_count(self.embed_dim // self.num_heads) + fout.add_layer_norm_rms_eps(self.norm_eps) + fout.add_key_length(self.embed_dim // self.num_heads) + fout.add_value_length(self.embed_dim // self.num_heads) + # vocab + fout.add_vocab_size(self.vocab_size) + fout.add_tokenizer_model(get_field_data(fvocab, gguf.Keys.Tokenizer.MODEL)) + fout.add_tokenizer_pre(get_field_data(fvocab, gguf.Keys.Tokenizer.PRE)) + fout.add_token_list(get_field_data(fvocab, gguf.Keys.Tokenizer.LIST)[:self.vocab_size]) + fout.add_token_types(get_field_data(fvocab, gguf.Keys.Tokenizer.TOKEN_TYPE)[:self.vocab_size]) + fout.add_token_merges(get_field_data(fvocab, gguf.Keys.Tokenizer.MERGES)) + fout.add_bos_token_id(get_field_data(fvocab, gguf.Keys.Tokenizer.BOS_ID)) + fout.add_eos_token_id(get_field_data(fvocab, gguf.Keys.Tokenizer.EOS_ID)) + +class Llama_3_2_100M(Llama_3_2_1B): + vocab_size=65_632 #128_256 + num_layers=4 + num_heads=8 + num_kv_heads=2 + embed_dim=1024 + max_seq_len=2048 + intermediate_dim=8192 + attn_dropout=0.0 + norm_eps=1e-5 + rope_base=500_000 + scale_factor=32 + +class CSMModelConverter: + state_dict: Dict[str, Tensor] + gguf_writer_backbone: gguf.GGUFWriter + gguf_writer_decoder: gguf.GGUFWriter + gguf_reader_vocab: gguf.GGUFReader + fname_out: Path + ftype: gguf.LlamaFileType + + def __init__(self, + safetensors_path: Union[Path, str], + path_to_vocab_gguf: Path, + fname_out: Path, + ftype: gguf.LlamaFileType, + is_big_endian: bool,): + + if "" not in fname_out.name: + raise ValueError("Output file name must contain '' placeholder, for example: 'sesame-csm-.gguf'") + + self.state_dict = load_file(safetensors_path, device="cpu") + self.fname_out = fname_out + self.ftype = ftype + self.gguf_reader_vocab = gguf.GGUFReader(path_to_vocab_gguf) + endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + + # backbone + self.gguf_writer_backbone = gguf.GGUFWriter( + path=None, + arch="llama-csm", + endianess=endianess) + + # decoder + self.gguf_writer_decoder = gguf.GGUFWriter( + path=None, + arch="llama-csm", + endianess=endianess) + + Llama_3_2_1B().write_gguf_metadata(self.gguf_writer_backbone, self.gguf_reader_vocab) + Llama_3_2_100M().write_gguf_metadata(self.gguf_writer_decoder, self.gguf_reader_vocab) + + # load tensors + for component in ("backbone", "decoder"): + print() + print(f"Converting {component}...") + print() + for name, data_torch in self.state_dict.items(): + # convert any unsupported data types to float32 + old_dtype = data_torch.dtype + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + self.add_tensor(name, data_torch, old_dtype, component) + + def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype, component: str): + is_1d = len(data_torch.shape) == 1 + #is_embd = "_embeddings" in name + can_quantize = not is_1d #and not is_embd + data_qtype = gguf.GGMLQuantizationType.F32 + + is_backbone = False + is_decoder = False + + def rename_transformer(name: str) -> str: + # transformer + name = name.replace(".scale", ".weight") + name = name.replace("attn.k_proj", "attn_k") + name = name.replace("attn.q_proj", "attn_q") + name = name.replace("attn.v_proj", "attn_v") + name = name.replace("attn.output_proj", "attn_output") + name = name.replace("sa_norm", "attn_norm") + name = name.replace("mlp.w1", "ffn_gate") + name = name.replace("mlp.w2", "ffn_down") + name = name.replace("mlp.w3", "ffn_up") + name = name.replace("mlp_norm", "ffn_norm") + return name + + if "audio_embeddings." in name: + is_decoder = True + name = name.replace("audio_embeddings.", "audio_embd.") + + elif "text_embeddings." in name: + is_backbone = True + name = name.replace("text_embeddings.", "token_embd.") + + elif "backbone." in name or "codebook0_head." in name: + is_backbone = True + name = name.replace("backbone.layers.", "blk.") + name = name.replace("backbone.norm.scale", "output_norm.weight") + name = rename_transformer(name) + + elif "decoder." in name: + is_decoder = True + name = name.replace("decoder.layers.", "blk.") + name = name.replace("decoder.norm.scale", "output_norm.weight") + name = rename_transformer(name) + + elif name == "audio_head": + is_decoder = True + name = "audio_head.weight" + if component == "decoder": + # add padding at the beginning and the end so that build_lora_mm_id can be used + zero_tensor = torch.zeros(1, 1024, 2051) + data_torch = torch.cat([zero_tensor, data_torch, zero_tensor], dim=0) + assert data_torch.shape == (33, 1024, 2051) + # then, transpose it + data_torch = data_torch.transpose(1, 2) + + elif name == "projection.weight": + is_decoder = True + is_backbone = True + name = "csm_proj.weight" + + if can_quantize: + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + # decoder is very sensitive to quantization, do not quantize it lower than F16 + data_qtype = gguf.GGMLQuantizationType.Q8_0 if component != "decoder" \ + else gguf.GGMLQuantizationType.F16 + else: + raise ValueError(f"Unsupported file type: {self.ftype}") + + data = data_torch.numpy() + + try: + data = gguf.quants.quantize(data, data_qtype) + except Exception as e: + logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + if (is_backbone and component == "backbone") or (is_decoder and component == "decoder"): + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}" + logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + if component == "backbone": + self.gguf_writer_backbone.add_tensor(name, data, raw_dtype=data_qtype) + elif component == "decoder": + self.gguf_writer_decoder.add_tensor(name, data, raw_dtype=data_qtype) + + def write(self): + self._write_single(self.gguf_writer_backbone, "backbone") + self._write_single(self.gguf_writer_decoder, "decoder") + + def _write_single(self, gguf_writer: gguf.GGUFWriter, component: str): + output_path = str(self.fname_out).replace("", component) + gguf_writer.write_header_to_file(path=Path(output_path)) + gguf_writer.write_kv_data_to_file() + gguf_writer.write_tensors_to_file(progress=True) + gguf_writer.close() + + @staticmethod + def undo_permute(weights: Tensor, n_head: int, n_head_kv: int): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Sesame model to GGUFs (multiple files)",) + parser.add_argument( + "--outfile", type=Path, default="sesame-csm-.gguf", + help="path to write to, the '' placeholder is required and will be replaced with 'backbone' and 'decoder'", + ) + parser.add_argument( + "--vocab", type=Path, default="models/ggml-vocab-llama-bpe.gguf", + help="path to vocab GGUF", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="path to safetensors or model ID containing model file (if model ID is specified, download from Hugging Face hub)", + nargs="?", + default="sesame/csm-1b:model.safetensors", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + + args = parser.parse_args() + if args.model is None: + parser.error("the following arguments are required: model") + return args + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + path_vocab = args.vocab + + dir_parts = str(dir_model).split(":") + if len(dir_parts) == 2: + try: + dir_model = Path(hf_hub_download(dir_parts[0], dir_parts[1])) + except Exception as e: + print("Error downloading model from Hugging Face hub:", e) + print() + print("Please make sure you have access to the model") + print("Hint: you may need to set HF_TOKEN by running: huggingface-cli login") + + if not path_vocab.exists(): + raise FileNotFoundError(f"Vocab file not found: {path_vocab} ; Hint: download it from https://github.com/ggml-org/llama.cpp/blob/master/models/ggml-vocab-llama-bpe.gguf") + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + } + + logger.info(f"Loading model: {dir_model}") + + with torch.inference_mode(): + converter = CSMModelConverter( + safetensors_path=dir_model, + fname_out=args.outfile, + path_to_vocab_gguf=path_vocab, + ftype=ftype_map[args.outtype], + is_big_endian=args.bigendian, + ) + converter.write() + + +if __name__ == '__main__': + main() + diff --git a/examples/tts/convert_mimi_to_gguf.py b/examples/tts/convert_mimi_to_gguf.py new file mode 100644 index 0000000000000..81cb8f48cc25e --- /dev/null +++ b/examples/tts/convert_mimi_to_gguf.py @@ -0,0 +1,191 @@ +import gguf +import argparse +import logging +import torch +from typing import Union +from pathlib import Path +from torch import Tensor +from transformers import MimiModel, PreTrainedModel + +logger = logging.getLogger("mimi") + + +class MimiModelConverter: + mimi_model: PreTrainedModel + gguf_writer: gguf.GGUFWriter + fname_out: Path + ftype: gguf.LlamaFileType + + def __init__(self, + pretrained_model_name_or_path: Union[Path, str], + fname_out: Path, + ftype: gguf.LlamaFileType, + is_big_endian: bool,): + self.mimi_model = MimiModel.from_pretrained(pretrained_model_name_or_path) + self.fname_out = fname_out + self.ftype = ftype + endianess = gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE + self.gguf_writer = gguf.GGUFWriter( + path=None, + arch="this model cannot be used as LLM, use it via --model-vocoder in TTS examples", + endianess=endianess) + + assert self.mimi_model.config.architectures[0] == "MimiModel" + + # load tensors + for name, data_torch in self.mimi_model.state_dict().items(): + # convert any unsupported data types to float32 + old_dtype = data_torch.dtype + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + self.add_tensor(name, data_torch, old_dtype) + + def add_tensor(self, name: str, data_torch: Tensor, old_dtype: torch.dtype): + is_1d = len(data_torch.shape) == 1 + is_bias = ".bias" in name + can_quantize = not is_1d and not is_bias + data_qtype = gguf.GGMLQuantizationType.F32 + + n_head = self.mimi_model.config.num_attention_heads + n_kv_head = self.mimi_model.config.num_key_value_heads + if name.endswith(("q_proj.weight", "q_proj.bias")): + data_torch = self.undo_permute(data_torch, n_head, n_head) + if name.endswith(("k_proj.weight", "k_proj.bias")): + data_torch = self.undo_permute(data_torch, n_head, n_kv_head) + + # process codebook + if ".codebook.initialized" in name: + # "initialized" tensor + state_dict = self.mimi_model.state_dict() + embed_sum = state_dict[name.replace(".initialized", ".embed_sum")] + cluster_usage = state_dict[name.replace(".initialized", ".cluster_usage")] + # see modeling_mimi.py --> MimiEuclideanCodebook + data_torch = embed_sum / cluster_usage.clamp(min=self.mimi_model.config.norm_eps)[:, None] + name = name.replace(".initialized", "") + + # ignore processed tensors + if ".cluster_usage" in name or ".embed_sum" in name: + return + + # transpose some tensors + if ".conv.bias" in name: + data_torch = data_torch.view((1, data_torch.shape[0])) + data_torch = data_torch.transpose(0, 1) + + # change view 3d to 2d + if "quantizer" in name and "_proj." in name: + assert data_torch.shape[2] == 1 + data_torch = data_torch.view((data_torch.shape[0], data_torch.shape[1])) + + # shorten name, otherwise it will be too long for ggml to read + name = name.replace("_residual_vector_quantizer", "_rvq") + + if can_quantize: + if self.ftype == gguf.LlamaFileType.ALL_F32: + data_qtype = gguf.GGMLQuantizationType.F32 + elif self.ftype == gguf.LlamaFileType.MOSTLY_F16: + data_qtype = gguf.GGMLQuantizationType.F16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16: + data_qtype = gguf.GGMLQuantizationType.BF16 + elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: + data_qtype = gguf.GGMLQuantizationType.Q8_0 + else: + raise ValueError(f"Unsupported file type: {self.ftype}") + + # Conv kernels are always F16 + if ".conv.weight" in name: + data_qtype = gguf.GGMLQuantizationType.F16 + + data = data_torch.numpy() + + try: + data = gguf.quants.quantize(data, data_qtype) + except Exception as e: + logger.error(f"Error quantizing tensor '{name}': {e}, fallback to F16") + data_qtype = gguf.GGMLQuantizationType.F16 + data = gguf.quants.quantize(data, data_qtype) + + # reverse shape to make it similar to the internal ggml dimension order + shape_str = f"{{{', '.join(str(n) for n in reversed(data_torch.shape))}}}" + logger.info(f"{f'%-32s' % f'{name},'} {old_dtype} --> {data_qtype.name}, shape = {shape_str}") + + self.gguf_writer.add_tensor(name, data, raw_dtype=data_qtype) + + def write(self): + self.gguf_writer.write_header_to_file(path=self.fname_out) + self.gguf_writer.write_kv_data_to_file() + self.gguf_writer.write_tensors_to_file(progress=True) + self.gguf_writer.close() + + @staticmethod + def undo_permute(weights: Tensor, n_head: int, n_head_kv: int): + if n_head_kv is not None and n_head != n_head_kv: + n_head = n_head_kv + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Convert Mimi safetensors model to GGUF",) + parser.add_argument( + "--outfile", type=Path, default="kyutai-mimi.gguf", + help="path to write to", + ) + parser.add_argument( + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16", + help="output format", + ) + parser.add_argument( + "--bigendian", action="store_true", + help="model is executed on big endian machine", + ) + parser.add_argument( + "model", type=Path, + help="directory or model ID containing model file (if model ID is specified, download from Hugging Face hub)", + nargs="?", + default="kyutai/mimi", + ) + parser.add_argument( + "--verbose", action="store_true", + help="increase output verbosity", + ) + + args = parser.parse_args() + if args.model is None: + parser.error("the following arguments are required: model") + return args + + +def main() -> None: + args = parse_args() + + if args.verbose: + logging.basicConfig(level=logging.DEBUG) + else: + logging.basicConfig(level=logging.INFO) + + dir_model = args.model + + ftype_map: dict[str, gguf.LlamaFileType] = { + "f32": gguf.LlamaFileType.ALL_F32, + "f16": gguf.LlamaFileType.MOSTLY_F16, + "bf16": gguf.LlamaFileType.MOSTLY_BF16, + "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + } + + logger.info(f"Loading model: {dir_model}") + + with torch.inference_mode(): + converter = MimiModelConverter( + pretrained_model_name_or_path=dir_model, + fname_out=args.outfile, + ftype=ftype_map[args.outtype], + is_big_endian=args.bigendian, + ) + converter.write() + + +if __name__ == '__main__': + main() + diff --git a/examples/tts/csm-demo.txt b/examples/tts/csm-demo.txt new file mode 100644 index 0000000000000..1c913388bfb3d --- /dev/null +++ b/examples/tts/csm-demo.txt @@ -0,0 +1,5 @@ +[0]Hey how are you doing. +[1]Pretty good, pretty good. +[0]I'm great, so happy to be speaking to you. +What about you? +[1]Me too, this is some cool stuff huh? diff --git a/examples/tts/csm_generate_speaker.py b/examples/tts/csm_generate_speaker.py new file mode 100644 index 0000000000000..a06dee6846eac --- /dev/null +++ b/examples/tts/csm_generate_speaker.py @@ -0,0 +1,80 @@ +import argparse +from pathlib import Path +from transformers import MimiModel, AutoFeatureExtractor +from transformers.models.mimi.modeling_mimi import MimiEncoderOutput + +# pyright: reportMissingImports=false +from scipy.io.wavfile import read +from scipy.signal import resample +import numpy as np + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate speaker reference file, used by llama-tts-csm example",) + parser.add_argument( + "--model-path", type=Path, + help="custom Mimi model path (safetensors model). If not specified, will use the default model from Hugging Face hub", + ) + parser.add_argument( + "infile", type=Path, + help="the wav input file to use for generating the speaker reference file", + nargs="?", + ) + # parser.add_argument( + # "outfile", type=Path, + # help="the output file, defaults to the input file with .codes suffix", + # nargs="?", + # ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + if args.infile is None: + raise ValueError("Input file is required") + + if not args.infile.exists(): + raise FileNotFoundError(f"Input file {args.infile} not found") + + # if args.outfile is None: + # args.outfile = args.infile.with_suffix(".codes") + + model = MimiModel.from_pretrained(args.model_path or "kyutai/mimi") + feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_path or "kyutai/mimi") + + inp_audio = read(args.infile) + original_sample_rate = inp_audio[0] + audio_data = inp_audio[1] + + # If stereo, get only the first channel + if len(audio_data.shape) > 1 and audio_data.shape[1] >= 2: + audio_data = audio_data[:, 0] + + # resample + target_sample_rate = 24000 + number_of_samples = round(len(audio_data) * float(target_sample_rate) / original_sample_rate) + resampled_audio = resample(audio_data, number_of_samples) + resampled_audio = resampled_audio / max(np.max(np.abs(resampled_audio)), 1e-10) + + # pre-process the inputs + audio_sample = np.array(resampled_audio, dtype=float) + inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt") + print('inputs', inputs["input_values"], inputs["input_values"].shape) + + # encode + encoder_outputs = model.encode(inputs["input_values"]) + assert isinstance(encoder_outputs, MimiEncoderOutput), "encoder_outputs should be of type MimiEncoderOutput" + + # output + flattened_audio_codes = encoder_outputs.audio_codes.transpose(-1, -2).flatten() + for i in range(0, len(flattened_audio_codes), 16): + for code in flattened_audio_codes[i:i+16].tolist(): + print(f"{code:<5}", end=",") + print() + + +if __name__ == '__main__': + main() diff --git a/examples/tts/mimi-model.cpp b/examples/tts/mimi-model.cpp new file mode 100644 index 0000000000000..fee88c679e1f3 --- /dev/null +++ b/examples/tts/mimi-model.cpp @@ -0,0 +1,734 @@ +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-cpu.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "gguf.h" + +#include "common.h" +#include "mimi-model.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * Implementation of Kyutai's Mimi model using GGML. + * Based on this research: https://github.com/ngxson/ggml-easy/blob/master/demo/kyutai-mimi.cpp + * + * NOTE: only decoder is working for now. + * + * Background: + * - The audio codes can be generated using any Mimi-based model, for example: Moshi, Hibiki, Sesame, etc + * - Audio codes must be in the order: N semantic codes followed by (N*31) acoustic codes + * (In other words, input matrix has shape 32 cols x N rows) + * + * How it works? + * 1. Audio code passed to RVQ (mimi_residual_vector_quantizer) to get the latent code + * 2. The latent code is passed to a mimi_conv_transpose_1d (depthwise) to upscale + * 3. The upscaled code is passed to transformer, it converts N frames to N frames + * 4. The output embeddings is then passed to SEANet (mimi_encoder_decoder) to get the final waveform + * 5. Waveform is written to a file + */ + +// copied from https://huggingface.co/kyutai/mimi/blob/main/config.json +struct mimi_config_t { + bool causal = true; + int sample_rate = 24000; + int max_position_embeddings = 8000; + int num_hidden_layers = 8; + int n_embd = 512; + int n_ffn = 2048; + int n_head = 8; + int n_head_kv = 8; + int n_rot = 64; + float norm_eps = 1e-5; + float rope_theta = 10000.0f; + int sliding_window = 250; + std::array upsampling_ratio = {8, 6, 5, 4}; + std::array downsampling_ratio = {4, 5, 6, 8}; // reverse of upsampling_ratio + // vector quantizer + float frame_rate = 12.5; + int audio_channels = 1; + int codebook_size = 2048; + int codebook_dim = 256; + int n_semantic_components = 1; + int n_acoustic_components = 31; + // decode + float trim_right_ratio = 1.0f; + int n_codes_per_frame = (sliding_window / 2) * (n_semantic_components + n_acoustic_components); +} mimi_config; + +// Adapted from https://github.com/ngxson/ggml-easy/blob/master/ggml-easy.h +struct mimi_ggml_ctx { + gguf_context * ctx_gguf = nullptr; + ggml_context * ctx_data = nullptr; + ggml_context * ctx_gf = nullptr; + + // CPU-only for now, as many kernels are missing and we actually get less performance with GPU + ggml_backend_t backend = nullptr; + ggml_backend_buffer_t buf = nullptr; + ggml_backend_sched_ptr sched; + + ggml_cgraph * gf = nullptr; + std::vector buf_compute_meta; + int max_nodes = 16 * 1024; + + std::unordered_map tensors; + + mimi_ggml_ctx() { + backend = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + auto buft = ggml_backend_get_default_buffer_type(backend); + sched.reset( + ggml_backend_sched_new(&backend, &buft, 1, max_nodes, false) + ); + buf_compute_meta.resize(max_nodes * ggml_tensor_overhead() + ggml_graph_overhead()); + } + + void load_gguf(const char * fname) { + ggml_context * meta = nullptr; + + gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &meta, + }; + + ctx_gguf = gguf_init_from_file(fname, params); + + // load tensors + const int n_tensors = gguf_get_n_tensors(ctx_gguf); + + std::vector read_buf; + ggml_init_params ggml_params = { + /*.mem_size =*/ (n_tensors + 1) * ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ctx_data = ggml_init(ggml_params); + auto fin = std::ifstream(fname, std::ios::binary); + if (!fin) { + ggml_free(meta); + throw std::runtime_error("cannot open model file for loading tensors"); + } + + // add tensors to context + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + ggml_tensor * t = ggml_get_tensor(meta, name); + ggml_tensor * cur = ggml_dup_tensor(ctx_data, t); + ggml_set_name(cur, name); + tensors.insert({name, cur}); + } + + // alloc memory and offload data + ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend); + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx_data, buft); + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + const size_t offset = gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, i); + // printf("%s: Loading tensor \"%s\"\n", __func__, name); + fin.seekg(offset, std::ios::beg); + if (!fin) { + ggml_free(meta); + throw std::runtime_error(string_format("failed to seek for tensor: %s", name)); + } + int num_bytes = ggml_nbytes(cur); + if (ggml_backend_buft_is_host(buft)) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(cur->data), num_bytes); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + } + } + printf("%s: Loaded %d tensors from %s\n", __func__, n_tensors, fname); + fin.close(); + + ggml_free(meta); + } + + /** + * Build a cgraph using the given builder function. + * + * The built cgraph will be stored in `ctx.gf` + */ + void build_graph(std::function builder_fn) { + ggml_free(ctx_gf); + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_gf = ggml_init(params); + ggml_backend_sched_reset(sched.get()); + gf = ggml_new_graph_custom(ctx_gf, max_nodes, false); + + builder_fn(ctx_gf, gf); + ggml_backend_sched_alloc_graph(sched.get(), gf); + } + + ggml_status compute() { + ggml_status status = ggml_backend_sched_graph_compute(sched.get(), gf); + return status; + } + + void set_tensor_data(const std::string & name, const void * data) { + ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); + if (!t) { + throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); + } + ggml_backend_tensor_set(t, data, 0, ggml_nbytes(t)); + } + + std::pair> get_tensor_data(const std::string & name) { + ggml_tensor * t = ggml_get_tensor(ctx_gf, name.c_str()); + if (!t) { + throw std::runtime_error(string_format("tensor not found: %s", name.c_str())); + } + std::vector data(ggml_nbytes(t)); + ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t)); + return std::make_pair(t, data); + } + + ggml_tensor * get_weight(const char *fmt, ...) { + std::vector str(128); + va_list va; + va_start(va, fmt); + vsnprintf(str.data(), 128, fmt, va); + va_end(va); + auto it = tensors.find(str.data()); + if (it == tensors.end()) { + throw std::runtime_error(string_format("weight tensor not found: %s", str.data())); + } + return it->second; + } + + ~mimi_ggml_ctx() { + ggml_free(ctx_data); + gguf_free(ctx_gguf); + ggml_backend_buffer_free(buf); + } +}; + +/////////////////////////////////////////////////////////////////////////// +// extension to ggml.h +// TODO: add these ops to the library (ofc with a more optimized kernel) + + +// mode: (0) constant, (1) reflect, (2) replicate, (3) circular +// value is only used in "constant" +// only "constant" with 0.0f and "replicate" are implemented here +static ggml_tensor * ggml_pad_ext(ggml_context * ctx0, ggml_tensor * x, int mode, + int64_t pad_left, int64_t pad_right, float value = 0.0f) { + GGML_ASSERT(value == 0.0f); // we can technically use ggml_arange, but for simplication we only support 0.0f + GGML_ASSERT(mode == 0 || mode == 2); + if (pad_left > 0) { + ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_left, x->ne[1]); + if (mode == 0) { + tmp = ggml_scale(ctx0, tmp, value); + } else if (mode == 2) { + ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], 0); // get first column + tmp = ggml_repeat(ctx0, elem, tmp); + } + x = ggml_concat(ctx0, tmp, x, 0); + } + if (pad_right > 0) { + ggml_tensor * tmp = ggml_new_tensor_2d(ctx0, x->type, pad_right, x->ne[1]); + if (mode == 0) { + tmp = ggml_scale(ctx0, tmp, value); + } else if (mode == 2) { + int64_t last = x->ne[0] - 1; + ggml_tensor * elem = ggml_view_2d(ctx0, x, 1, x->ne[1], x->nb[1], last * ggml_element_size(x)); // get last column + tmp = ggml_repeat(ctx0, elem, tmp); + } + x = ggml_concat(ctx0, x, tmp, 0); + } + return x; +} + + + + +/////////////////////////////////////////////////////////////////////////// +// MimiConv and MimiConvTranspose + +static int64_t div_ceil(int64_t a, int64_t b) { + return a / b + (a % b ? 1 : 0); +} + +static ggml_tensor * mimi_conv_1d(ggml_context * ctx0, ggml_tensor * x, + ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool pad_zero = true) { + int64_t kernel_size = (kernel->ne[0] - 1) * dilation + 1; + int64_t p_total = kernel_size - stride; // padding total + int64_t p_half = p_total / 2; + + int64_t n_frames = div_ceil(x->ne[0] - kernel_size + p_total, stride); + int64_t ideal_len = n_frames * stride + kernel_size - p_total; + int64_t p_extra = ideal_len - x->ne[0]; + + int64_t p_right = (mimi_config.causal ? 0 : p_half) + p_extra; + int64_t p_left = p_total - (mimi_config.causal ? 0 : p_half); + + x = ggml_pad_ext(ctx0, x, pad_zero ? 0 : 2, p_left, p_right); + + x = ggml_conv_1d(ctx0, kernel, x, stride, 0, dilation); + if (bias) { + x = ggml_add(ctx0, x, bias); + } + ggml_set_name(x, "mimi_conv_1d"); + return x; +} + +static ggml_tensor * mimi_conv_transpose_1d(ggml_context * ctx0, ggml_tensor * x, + ggml_tensor * kernel, ggml_tensor * bias, int stride, int dilation, bool depthwise) { + GGML_ASSERT(x->ne[1] == kernel->ne[2]); + int64_t n_rows = x->ne[1]; + int64_t kernel_size = kernel->ne[0]; + int64_t p_total = kernel_size - stride; // padding total + + int64_t p_right = mimi_config.causal + ? (float)p_total / mimi_config.trim_right_ratio + : p_total / 2; + int64_t p_left = p_total - p_right; + + ggml_tensor * out = nullptr; + + if (depthwise) { + for (int64_t ir = 0; ir < n_rows; ir++) { + ggml_tensor * row = ggml_view_1d(ctx0, x, + x->ne[0], ir*x->ne[0]*ggml_element_size(x)); + ggml_tensor * krn = ggml_view_1d(ctx0, kernel, + kernel->ne[0], ir*kernel->ne[0]*ggml_element_size(kernel)); + row = ggml_conv_transpose_1d(ctx0, krn, row, stride, 0, dilation); + // unpad (remove p_right and p_left columns) + row = ggml_view_1d(ctx0, row, row->ne[0] - p_total, p_left*ggml_element_size(row)); + + // TODO: concat can be slow, we should use ggml_view_1d/ggml_cpy to avoid realloc + out = out ? ggml_concat(ctx0, out, row, 1) : row; + } + + } else { + out = ggml_conv_transpose_1d(ctx0, kernel, x, stride, 0, dilation); + // unpad + out = ggml_view_2d(ctx0, out, + out->ne[0] - p_total, out->ne[1], + out->nb[1], p_left*ggml_element_size(out)); + } + + if (bias) { + out = ggml_add(ctx0, out, bias); + } + + return out; +} + + + +/////////////////////////////////////////////////////////////////////////// + +// based on MimiEncoder +// SEANet encoder as used by Mimi. +struct mimi_encoder_decoder { + mimi_ggml_ctx & ctx; + struct layer { + bool is_elu = false; + bool is_resnet = false; + bool is_transposed_conv = false; + ggml_tensor * conv_0_w = nullptr; + ggml_tensor * conv_0_b = nullptr; + ggml_tensor * conv_1_w = nullptr; + ggml_tensor * conv_1_b = nullptr; + int stride = 1; + }; + std::vector layers; + + std::array repeated_pattern = {1, 4, 7, 10}; + + mimi_encoder_decoder(mimi_ggml_ctx & ctx): ctx(ctx) { + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.0.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.0.conv.bias"), + }); + for (int i = 0; i < (int)repeated_pattern.size(); ++i) { + int i_start = repeated_pattern[i]; + // upsampling layers + layers.push_back({ + .is_elu = true, // layer (i_start) + }); + layers.push_back({ + .is_transposed_conv = true, + .conv_0_w = ctx.get_weight("decoder.layers.%d.conv.weight", i_start + 1), + .conv_0_b = ctx.get_weight("decoder.layers.%d.conv.bias", i_start + 1), + .stride = mimi_config.upsampling_ratio[i], + }); + // residual layers + layers.push_back({ + .is_resnet = true, + .conv_0_w = ctx.get_weight("decoder.layers.%d.block.1.conv.weight", i_start + 2), + .conv_0_b = ctx.get_weight("decoder.layers.%d.block.1.conv.bias", i_start + 2), + .conv_1_w = ctx.get_weight("decoder.layers.%d.block.3.conv.weight", i_start + 2), + .conv_1_b = ctx.get_weight("decoder.layers.%d.block.3.conv.bias", i_start + 2), + }); + } + layers.push_back({ + .is_elu = true, // layer 13 + }); + layers.push_back({ + .conv_0_w = ctx.get_weight("decoder.layers.14.conv.weight"), + .conv_0_b = ctx.get_weight("decoder.layers.14.conv.bias"), + }); + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input) { + ggml_tensor * x = input; + + for (auto & layer : layers) { + if (layer.is_elu) { + x = ggml_elu(ctx0, x); + } else if (layer.is_resnet) { + ggml_tensor * residual = x; + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, 1, 1); + x = ggml_elu(ctx0, x); + x = mimi_conv_1d(ctx0, x, layer.conv_1_w, layer.conv_1_b, 1, 1); + x = ggml_add(ctx0, x, residual); + } else { + x = layer.is_transposed_conv + ? mimi_conv_transpose_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1, false) + : mimi_conv_1d(ctx0, x, layer.conv_0_w, layer.conv_0_b, layer.stride, 1); + } + } + + return x; + } +}; + +struct mimi_transformer { + struct layer { + ggml_tensor * inp_norm_w = nullptr; + ggml_tensor * inp_norm_b = nullptr; + + ggml_tensor * attn_q = nullptr; + ggml_tensor * attn_k = nullptr; + ggml_tensor * attn_v = nullptr; + ggml_tensor * attn_o = nullptr; + ggml_tensor * attn_post_norm_w = nullptr; + ggml_tensor * attn_post_norm_b = nullptr; + ggml_tensor * attn_layer_scale = nullptr; + + ggml_tensor * ffn_up = nullptr; + ggml_tensor * ffn_down = nullptr; + ggml_tensor * mlp_layer_scale = nullptr; + }; + std::vector layers; + + mimi_transformer(mimi_ggml_ctx & ctx, const char * prefix, int n_layers) { + for (int il = 0; il < n_layers; il++) { + layers.push_back({ + .inp_norm_w = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.weight", prefix, il), + .inp_norm_b = ctx.get_weight("%s_transformer.layers.%d.input_layernorm.bias", prefix, il), + + .attn_q = ctx.get_weight("%s_transformer.layers.%d.self_attn.q_proj.weight", prefix, il), + .attn_k = ctx.get_weight("%s_transformer.layers.%d.self_attn.k_proj.weight", prefix, il), + .attn_v = ctx.get_weight("%s_transformer.layers.%d.self_attn.v_proj.weight", prefix, il), + .attn_o = ctx.get_weight("%s_transformer.layers.%d.self_attn.o_proj.weight", prefix, il), + .attn_post_norm_w = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.weight", prefix, il), + .attn_post_norm_b = ctx.get_weight("%s_transformer.layers.%d.post_attention_layernorm.bias", prefix, il), + .attn_layer_scale = ctx.get_weight("%s_transformer.layers.%d.self_attn_layer_scale.scale", prefix, il), + + .ffn_up = ctx.get_weight("%s_transformer.layers.%d.mlp.fc1.weight", prefix, il), + .ffn_down = ctx.get_weight("%s_transformer.layers.%d.mlp.fc2.weight", prefix, il), + .mlp_layer_scale = ctx.get_weight("%s_transformer.layers.%d.mlp_layer_scale.scale", prefix, il), + }); + } + } + + ggml_tensor * forward(ggml_context * ctx0, ggml_tensor * input, ggml_tensor * inp_pos) { + int n_tokens = input->ne[1]; + ggml_tensor * x = input; + + auto layer_norm = [&](ggml_tensor * x, ggml_tensor * w, ggml_tensor * b) { + x = ggml_norm(ctx0, x, mimi_config.norm_eps); + x = ggml_mul(ctx0, x, w); + x = ggml_add(ctx0, x, b); + return x; + }; + + ggml_tensor * residual = input; + + for (auto & layer : layers) { + residual = x; + + // input layer norm + x = layer_norm(x, layer.inp_norm_w, layer.inp_norm_b); + + // self attention + { + ggml_tensor * q = ggml_mul_mat(ctx0, layer.attn_q, x); + ggml_tensor * k = ggml_mul_mat(ctx0, layer.attn_k, x); + ggml_tensor * v = ggml_mul_mat(ctx0, layer.attn_v, x); + + int n_embd_head = mimi_config.n_embd / mimi_config.n_head; + q = ggml_reshape_3d(ctx0, q, n_embd_head, mimi_config.n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, mimi_config.n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, mimi_config.n_head_kv, n_tokens); + + int n_rot = n_embd_head; + q = ggml_rope_inplace(ctx0, q, inp_pos, n_rot, 0); + q = ggml_cont(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3)); + + k = ggml_rope_inplace(ctx0, k, inp_pos, n_rot, 0); + k = ggml_cont(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3)); + + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + ggml_mul_mat_set_prec(kq, GGML_PREC_F32); // mimic behavior of llama.cpp + kq = ggml_scale_inplace(ctx0, kq, 1.0f / std::sqrt(n_embd_head)); + ggml_tensor * kq_masked = ggml_diag_mask_inf_inplace(ctx0, kq, n_tokens); + kq = ggml_soft_max_inplace(ctx0, kq_masked); + + v = ggml_cont(ctx0, ggml_permute(ctx0, v, 1, 2, 0, 3)); + + ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); + kqv = ggml_reshape_3d(ctx0, kqv, n_embd_head, n_tokens, mimi_config.n_head); + kqv = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + kqv = ggml_cont_2d(ctx0, kqv, mimi_config.n_embd, n_tokens); + + x = ggml_mul_mat(ctx0, layer.attn_o, kqv); + } + + // residual + x = ggml_mul(ctx0, x, layer.attn_layer_scale); + x = ggml_add(ctx0, x, residual); + + residual = x; + x = layer_norm(x, layer.attn_post_norm_w, layer.attn_post_norm_b); + + // mlp + { + x = ggml_mul_mat(ctx0, layer.ffn_up, x); + x = ggml_gelu(ctx0, x); + x = ggml_mul_mat(ctx0, layer.ffn_down, x); + } + + // residual + x = ggml_mul(ctx0, x, layer.mlp_layer_scale); + x = ggml_add(ctx0, x, residual); + } + + return x; + } +}; + +struct mimi_residual_vector_quantizer { + struct component { + ggml_tensor * codebook; + }; + + ggml_tensor * semantic_inp_proj; + std::vector semantic_components; + ggml_tensor * semantic_out_proj; + + ggml_tensor * acoustic_inp_proj; + std::vector acoustic_components; + ggml_tensor * acoustic_out_proj; + + mimi_residual_vector_quantizer(mimi_ggml_ctx & ctx) { + semantic_inp_proj = ctx.get_weight("quantizer.semantic_rvq.input_proj.weight"); + semantic_out_proj = ctx.get_weight("quantizer.semantic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_semantic_components; i++) { + semantic_components.push_back({ + .codebook = ctx.get_weight("quantizer.semantic_rvq.layers.%d.codebook", i), + }); + } + acoustic_inp_proj = ctx.get_weight("quantizer.acoustic_rvq.input_proj.weight"); + acoustic_out_proj = ctx.get_weight("quantizer.acoustic_rvq.output_proj.weight"); + for (int i = 0; i < mimi_config.n_acoustic_components; i++) { + acoustic_components.push_back({ + .codebook = ctx.get_weight("quantizer.acoustic_rvq.layers.%d.codebook", i), + }); + } + } + + // the input has shape [n_codes, n_codes_per_embd] + // first row is semantic, the rest are acoustic + // example: [ [semantic], [acoustic1], [acoustic2], ... ] + ggml_tensor * decode(ggml_context * ctx0, ggml_tensor * input) { + GGML_ASSERT(input->type == GGML_TYPE_I32); + + size_t n_semantic = semantic_components.size(); + int64_t n_codes_per_embd = (n_semantic + acoustic_components.size()); + int64_t n_codes = input->ne[0] / n_codes_per_embd; + + GGML_ASSERT(input->ne[0] % n_codes_per_embd == 0); + + ggml_tensor * out_s = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes); + ggml_tensor * out_a = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mimi_config.codebook_dim, n_codes); + out_s = ggml_scale(ctx0, out_s, 0.0f); // clear + out_a = ggml_scale(ctx0, out_a, 0.0f); // clear + + for (size_t ir = 0; ir < (size_t)n_codes_per_embd; ir++) { + ggml_tensor * row = ggml_view_1d(ctx0, input, n_codes, ir*n_codes*ggml_element_size(input)); + if (ir < n_semantic) { + // semantic + ggml_tensor * codebook = semantic_components[ir].codebook; + ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row); + out_s = ggml_add(ctx0, out_s, embd); + } else { + // acoustic + ggml_tensor * codebook = acoustic_components[ir-n_semantic].codebook; + ggml_tensor * embd = ggml_get_rows(ctx0, codebook, row); + out_a = ggml_add(ctx0, out_a, embd); + } + } + + out_s = ggml_mul_mat(ctx0, semantic_out_proj, out_s); + out_a = ggml_mul_mat(ctx0, acoustic_out_proj, out_a); + + return ggml_add(ctx0, out_s, out_a); + } +}; + + +mimi_model::mimi_model(const char * fname, bool verbose) : verbose(verbose) { + ctx.reset(new mimi_ggml_ctx()); + ctx->load_gguf(fname); + + // initialize components + seanet_dec .reset(new mimi_encoder_decoder(*ctx)); + transformer_dec.reset(new mimi_transformer(*ctx, "decoder", mimi_config.num_hidden_layers)); + quantizer .reset(new mimi_residual_vector_quantizer(*ctx)); +} + +mimi_model::~mimi_model() { +} + +std::vector mimi_model::decode_frame(const std::vector & codes, int & n_past) { + // build cgraph + int n_pos = -1; + int n_codes = codes.size(); + int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; + GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiply of n_codes_per_embd"); + + ctx->build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf) { + ggml_tensor * inp_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_codes); + ggml_set_name(inp_dec, "inp_dec"); + ggml_set_input(inp_dec); + + // RVQ + ggml_tensor * embeddings = quantizer->decode(ctx_gf, inp_dec); + + // upsample + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = mimi_conv_transpose_1d(ctx_gf, embeddings, ctx->get_weight("upsample.conv.weight"), nullptr, 2, 1, true); + + // transformer + n_pos = embeddings->ne[0]; + ggml_tensor * pos_dec = ggml_new_tensor_1d(ctx_gf, GGML_TYPE_I32, n_pos); + ggml_set_name(pos_dec, "pos_dec"); + ggml_set_input(pos_dec); + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + embeddings = transformer_dec->forward(ctx_gf, embeddings, pos_dec); + + // SEANET decoder + embeddings = ggml_cont(ctx_gf, ggml_transpose(ctx_gf, embeddings)); + ggml_tensor * output = seanet_dec->forward(ctx_gf, embeddings); + + ggml_set_name(output, "output"); + ggml_set_output(output); + ggml_build_forward_expand(gf, output); + }); + + // position data + GGML_ASSERT(n_pos <= mimi_config.sliding_window); + std::vector pos_data(n_pos); + for (int i = 0; i < (int)pos_data.size(); i++) { + pos_data[i] = i + n_past; + } + if (verbose) { + printf("%s: n_pos: %d, n_past: %d\n", __func__, n_pos, n_past); + } + n_past += n_pos; + ctx->set_tensor_data("pos_dec", pos_data.data()); + + // code data + auto codes_T = mimi_model::transpose_input(codes); + ctx->set_tensor_data("inp_dec", codes_T.data()); + + ctx->compute(); + + auto output = ctx->get_tensor_data("output"); + // auto output_tensor = output.first; + auto output_data = output.second; + // printf("Output shape: [%lld, %lld]\n", output_tensor->ne[0], output_tensor->ne[1]); + + std::vector wav_data(output_data.size() / sizeof(float)); + for (size_t i = 0; i < wav_data.size(); i++) { + wav_data[i] = ((float *)output_data.data())[i]; + } + + return wav_data; +} + +std::vector mimi_model::decode(const std::vector & codes) { + std::vector output; + + if (verbose) { + printf("%s: n_codes: %zu\n", __func__, codes.size()); + } + + int64_t t_start = ggml_time_ms(); + int n_frames = 0; + + int n_past = 0; + for (size_t i = 0; i < codes.size(); i += mimi_config.n_codes_per_frame) { + size_t remaining = std::min((size_t)mimi_config.n_codes_per_frame, codes.size() - i); + std::vector frame(codes.begin() + i, codes.begin() + i + remaining); + + auto wav_data = decode_frame(frame, n_past); + output.insert(output.end(), wav_data.begin(), wav_data.end()); + + n_frames++; + } + + int64_t t_end = ggml_time_ms(); + if (verbose) { + printf("%s: n_frames: %d, time: %" PRId64 "ms, per_frame: %" PRId64 "ms\n", __func__, n_frames, t_end - t_start, (t_end - t_start) / n_frames); + } + + return output; +} + +std::vector mimi_model::transpose_input(const std::vector & codes) { + int n_codes = codes.size(); + int n_codes_per_embd = mimi_config.n_semantic_components + mimi_config.n_acoustic_components; + GGML_ASSERT(n_codes % n_codes_per_embd == 0 && "number of codes must be a multiply of n_codes_per_embd"); + + std::vector codes_T(n_codes); + for (int i = 0; i < n_codes / n_codes_per_embd; i++) { + for (int j = 0; j < n_codes_per_embd; j++) { + int src_idx = i * n_codes_per_embd + j; + int dst_idx = j * (n_codes / n_codes_per_embd) + i; + codes_T[dst_idx] = codes[src_idx]; + } + } + + return codes_T; +} + +int mimi_model::get_sample_rate() const { + return mimi_config.sample_rate; +} diff --git a/examples/tts/mimi-model.h b/examples/tts/mimi-model.h new file mode 100644 index 0000000000000..eb5eb46c22807 --- /dev/null +++ b/examples/tts/mimi-model.h @@ -0,0 +1,39 @@ +#pragma once + +#include "ggml.h" +#include +#include + +struct mimi_ggml_ctx; +struct mimi_encoder_decoder; +struct mimi_transformer; +struct mimi_residual_vector_quantizer; + +struct mimi_model { + bool verbose = false; + std::unique_ptr ctx; + + std::unique_ptr seanet_dec; + std::unique_ptr transformer_dec; + std::unique_ptr quantizer; + + mimi_model(const char * fname, bool verbose = false); + ~mimi_model(); + + int get_sample_rate() const; + + // layout of codes: (1 semantic code followed by 31 acoustic codes) repeast N times + std::vector decode(const std::vector & codes); + + // TODO: implement encoding pass + // std::vector encode(const std::vector & wav_data); + +private: + std::vector decode_frame(const std::vector & codes, int & n_past); + + // transpose layout (from streaming layout to non-streaming): + // - from: (1 semantic code followed by 31 acoustic codes) repeast N times + // - to: N semantic codes followed by (N*31) acoustic codes + // streaming layout is 1-31, 1-31, 1-31, ..., used for real-time processing + static std::vector transpose_input(const std::vector & codes); +}; diff --git a/examples/tts/mimi.cpp b/examples/tts/mimi.cpp new file mode 100644 index 0000000000000..a50bd44f599a9 --- /dev/null +++ b/examples/tts/mimi.cpp @@ -0,0 +1,113 @@ +#include "common.h" +#include "mimi-model.h" + +#include +#include +#include // strcmp + + +/** + * This file is used for testing and showcase how to use "mimi_model" class. + * Please keep it simple and easy to understand. + */ + +int main(int argc, const char ** argv) { + if (argc < 3) { + fprintf(stderr, "Usage: %s model.gguf codes.txt [output.wav]\n", argv[0]); + fprintf(stderr, " Format of codes.txt file: one code per line\n"); + fprintf(stderr, " Replace codes.txt with dummy0 and dummy1 for testing\n"); + fprintf(stderr, " dummy0: using code 1, 2, 3,..., 96, used for logits matching\n"); + fprintf(stderr, " dummy1: using code that will outputs 'wah hello there' sound\n"); + return 1; + } + + const char * model_path = argv[1]; + const char * codes_path = argv[2]; + const char * out_path = argc < 4 ? "output.wav" : argv[3]; + + // load codes + std::vector codes; + if (strcmp(codes_path, "dummy0") == 0) { + printf("Using dummy0 codes\n"); + codes.resize(32 * 3); // [n_codes_per_embd = 32, n_codes = 3] + for (int i = 0; i < (int)codes.size(); i++) { + codes[i] = i; + } + } else if (strcmp(codes_path, "dummy1") == 0) { + printf("Using dummy1 codes\n"); + codes = { + 1049 ,1597 ,1325 ,839 ,592 ,1440 ,1341 ,985 ,1239 ,1146 ,1778 ,1636 ,1485 ,1622 ,757 ,480 , + 1899 ,1481 ,840 ,1397 ,82 ,1565 ,116 ,1449 ,1038 ,1015 ,436 ,150 ,159 ,1414 ,1740 ,1971 , + 1415 ,175 ,1539 ,776 ,1046 ,117 ,803 ,1499 ,1457 ,1307 ,2 ,1135 ,1287 ,1039 ,1124 ,716 , + 1798 ,201 ,1517 ,1299 ,886 ,1786 ,521 ,353 ,1912 ,1357 ,1311 ,450 ,297 ,971 ,1154 ,1729 , + 1962 ,1280 ,1943 ,878 ,1588 ,723 ,568 ,1736 ,1021 ,983 ,10 ,833 ,973 ,1209 ,1091 ,681 , + 1606 ,779 ,334 ,765 ,1836 ,1400 ,150 ,877 ,464 ,1487 ,870 ,1114 ,1703 ,476 ,1839 ,666 , + 914 ,1202 ,1601 ,1719 ,1670 ,412 ,568 ,1838 ,341 ,1237 ,1279 ,830 ,1815 ,32 ,1369 ,1686 , + 1307 ,419 ,1143 ,1158 ,325 ,1696 ,1597 ,93 ,795 ,4 ,1032 ,369 ,819 ,1685 ,912 ,282 , + 1372 ,1911 ,141 ,1069 ,1485 ,642 ,1370 ,702 ,284 ,1407 ,999 ,1758 ,314 ,679 ,1061 ,1624 , + 1549 ,430 ,823 ,1809 ,1976 ,232 ,727 ,266 ,747 ,253 ,134 ,267 ,93 ,428 ,731 ,1993 , + 704 ,85 ,257 ,1302 ,1141 ,1717 ,1995 ,1345 ,882 ,1350 ,1549 ,2015 ,2020 ,732 ,415 ,335 , + 1814 ,1451 ,454 ,1299 ,761 ,1736 ,1916 ,1853 ,56 ,1871 ,984 ,1273 ,247 ,1802 ,602 ,1551 , + 1922 ,47 ,564 ,893 ,34 ,131 ,1063 ,1657 ,474 ,1960 ,1049 ,1275 ,424 ,976 ,1217 ,865 , + 114 ,1000 ,725 ,1585 ,359 ,512 ,815 ,1255 ,124 ,933 ,1983 ,1136 ,1366 ,653 ,1064 ,1703 , + 2036 ,692 ,1435 ,2005 ,1465 ,37 ,892 ,511 ,1559 ,1255 ,373 ,1675 ,1085 ,1462 ,1135 ,1356 , + 483 ,156 ,1298 ,1776 ,1136 ,518 ,1826 ,872 ,431 ,215 ,1103 ,1578 ,144 ,1290 ,1508 ,1124 , + 288 ,632 ,876 ,875 ,1156 ,345 ,273 ,1774 ,1923 ,878 ,1355 ,287 ,982 ,805 ,1360 ,1688 , + 958 ,1062 ,1325 ,625 ,1720 ,1895 ,1382 ,1974 ,1868 ,1228 ,1627 ,1063 ,1617 ,614 ,834 ,1628 , + 968 ,251 ,1096 ,908 ,1938 ,112 ,895 ,1787 ,273 ,1979 ,1200 ,744 ,1994 ,402 ,1578 ,307 , + 1919 ,615 ,649 ,1539 ,2036 ,1854 ,653 ,556 ,609 ,633 ,1627 ,1820 ,1428 ,1663 ,1387 ,1725 , + 193 ,1553 ,636 ,586 ,435 ,1979 ,1226 ,945 ,1330 ,1500 ,1466 ,89 ,1563 ,1150 ,1205 ,366 , + 1179 ,1353 ,1737 ,830 ,904 ,1584 ,1596 ,1885 ,855 ,1306 ,414 ,120 ,812 ,1528 ,252 ,107 , + 1139 ,1735 ,61 ,2001 ,753 ,2034 ,354 ,1927 ,1406 ,1939 ,1009 ,430 ,1269 ,170 ,1785 ,541 , + 898 ,414 ,913 ,1563 ,719 ,1393 ,286 ,857 ,1522 ,2024 ,1845 ,779 ,121 ,1344 ,745 ,808 , + 897 ,1577 ,1497 ,186 ,1418 ,1822 ,1726 ,947 ,1782 ,1415 ,75 ,1724 ,1769 ,1529 ,1835 ,1262 , + 834 ,1214 ,685 ,461 ,526 ,1869 ,1373 ,992 ,912 ,1453 ,583 ,652 ,1637 ,798 ,1034 ,1096 , + 897 ,132 ,1010 ,1932 ,277 ,1536 ,1541 ,952 ,19 ,88 ,2042 ,1232 ,1681 ,2013 ,1241 ,1167 , + 1526 ,1487 ,761 ,308 ,1567 ,1702 ,177 ,5 ,1709 ,900 ,1699 ,1266 ,1620 ,1027 ,1102 ,1753 , + 1243 ,471 ,485 ,1765 ,391 ,1281 ,1607 ,1418 ,116 ,1702 ,1725 ,1692 ,1082 ,350 ,14 ,59 , + 386 ,882 ,2010 ,1438 ,145 ,789 ,1397 ,1921 ,1507 ,457 ,1458 ,1929 ,289 ,1305 ,965 ,500 , + 1511 ,433 ,284 ,721 ,1741 ,56 ,615 ,916 ,887 ,1253 ,916 ,535 ,1666 ,1175 ,716 ,269 , + 447 ,32 ,63 ,321 ,1860 ,1986 ,1009 ,1849 ,1062 ,471 ,2018 ,1213 ,1557 ,990 ,696 ,677 , + }; + } else { + std::ifstream fin(codes_path); + if (!fin) { + fprintf(stderr, "Error: cannot open codes file: %s\n", codes_path); + return 1; + } + std::string line; + while (std::getline(fin, line)) { + // Skip empty lines + if (line.empty()) continue; + // TODO: support both comma (with spaces) and new line + try { + int code = std::stoi(line); + codes.push_back(code); + } catch (const std::exception& e) { + fprintf(stderr, "Error parsing code: %s\n", line.c_str()); + return 1; + } + } + if (codes.empty()) { + fprintf(stderr, "Error: no codes found in file: %s\n", codes_path); + return 1; + } + + printf("Loaded %d codes from %s\n", (int)codes.size(), codes_path); + } + + mimi_model model(model_path, true); + std::vector wav_data = model.decode(codes); + + // print first 20 values + printf("Number of output samples: %d\n", (int)wav_data.size()); + printf("First 20 samples:\n"); + for (int i = 0; i < 20; i++) { + printf("%2.4f, ", wav_data[i]); + } + printf("...\n"); + + // write to wav + printf("Writing to %s\n", out_path); + save_wav16(out_path, wav_data, model.get_sample_rate()); +} diff --git a/examples/tts/tts-csm-data.h b/examples/tts/tts-csm-data.h new file mode 100644 index 0000000000000..c3c47ca7ac3a2 --- /dev/null +++ b/examples/tts/tts-csm-data.h @@ -0,0 +1,1513 @@ +#pragma once + +#include + +// https://huggingface.co/spaces/sesame/csm-1b/blob/main/prompts/conversational_a.wav +const char * default_speaker_a_text = "[0]like revising for an exam I'd have to try and like keep up the momentum because I'd start really early I'd be like okay I'm gonna start revising now and then like you're revising for ages and then I just like start losing steam I didn't do that for the exam we had recently to be fair that was a more of a last minute scenario but like yeah I'm trying to like yeah I noticed this yesterday that like Mondays I sort of start the day with this not like a panic but like a"; +std::initializer_list default_speaker_a_codes = { + 1952 ,425 ,59 ,331 ,2022 ,592 ,648 ,917 ,849 ,1427 ,133 ,1238 ,1045 ,897 ,303 ,413 , + 890 ,171 ,1726 ,1991 ,439 ,743 ,1129 ,1343 ,1406 ,493 ,2003 ,1541 ,401 ,662 ,325 ,879 , + 809 ,1597 ,420 ,193 ,780 ,618 ,1643 ,178 ,1151 ,927 ,1995 ,1857 ,1947 ,353 ,577 ,1493 , + 1598 ,407 ,814 ,5 ,1768 ,170 ,294 ,1386 ,972 ,1158 ,336 ,1300 ,1126 ,1687 ,1088 ,330 , + 526 ,1423 ,1324 ,1444 ,2032 ,1765 ,1127 ,736 ,536 ,307 ,572 ,1136 ,672 ,1324 ,750 ,17 , + 1334 ,129 ,1322 ,362 ,76 ,561 ,1773 ,827 ,1861 ,1876 ,1455 ,245 ,2045 ,1872 ,1660 ,1505 , + 1365 ,338 ,1205 ,1503 ,1177 ,1064 ,203 ,1684 ,805 ,1944 ,1661 ,1128 ,1135 ,504 ,133 ,652 , + 120 ,901 ,1821 ,1828 ,1248 ,1131 ,157 ,604 ,938 ,1520 ,884 ,963 ,1306 ,421 ,1214 ,912 , + 1417 ,10 ,1713 ,1128 ,1158 ,360 ,958 ,1912 ,68 ,1677 ,1496 ,1945 ,1596 ,1641 ,1385 ,1097 , + 1961 ,1096 ,421 ,894 ,883 ,1804 ,252 ,1662 ,1180 ,919 ,1706 ,777 ,1562 ,158 ,1638 ,483 , + 371 ,588 ,1890 ,683 ,1573 ,645 ,1331 ,213 ,1822 ,1458 ,27 ,85 ,174 ,250 ,1881 ,255 , + 186 ,1592 ,1951 ,777 ,1466 ,1542 ,183 ,431 ,1173 ,744 ,526 ,1814 ,98 ,997 ,1376 ,1009 , + 728 ,1206 ,762 ,776 ,791 ,487 ,45 ,993 ,2002 ,249 ,544 ,1845 ,662 ,357 ,1760 ,1896 , + 1582 ,1822 ,760 ,1586 ,173 ,163 ,1541 ,1443 ,697 ,975 ,1775 ,1759 ,768 ,61 ,251 ,1620 , + 819 ,852 ,1539 ,691 ,1655 ,420 ,1158 ,1890 ,728 ,569 ,925 ,1092 ,1550 ,1502 ,194 ,1728 , + 1180 ,1393 ,1021 ,1896 ,529 ,408 ,1816 ,1537 ,647 ,1701 ,766 ,1099 ,1442 ,1481 ,1026 ,1770 , + 994 ,520 ,852 ,464 ,44 ,1739 ,1285 ,1143 ,1466 ,1637 ,1980 ,553 ,2037 ,329 ,1464 ,1938 , + 519 ,590 ,1175 ,157 ,398 ,806 ,12 ,1488 ,1565 ,1534 ,1484 ,1712 ,170 ,431 ,1166 ,555 , + 313 ,1423 ,1867 ,76 ,239 ,469 ,159 ,2014 ,323 ,1254 ,601 ,451 ,1014 ,176 ,970 ,1048 , + 229 ,1322 ,536 ,1979 ,376 ,283 ,618 ,2019 ,1702 ,1272 ,1968 ,75 ,1943 ,462 ,251 ,686 , + 1791 ,1005 ,779 ,815 ,1075 ,932 ,1956 ,1206 ,1853 ,1639 ,1568 ,1794 ,274 ,622 ,1633 ,867 , + 21 ,515 ,2041 ,845 ,879 ,198 ,442 ,579 ,1326 ,1734 ,523 ,531 ,197 ,1806 ,821 ,901 , + 2038 ,194 ,424 ,1942 ,625 ,1186 ,139 ,1654 ,1647 ,699 ,1996 ,1992 ,1917 ,1503 ,1818 ,297 , + 1190 ,694 ,638 ,1001 ,1918 ,707 ,291 ,911 ,36 ,501 ,1976 ,761 ,592 ,1994 ,1587 ,672 , + 93 ,322 ,747 ,1016 ,920 ,959 ,529 ,567 ,109 ,69 ,953 ,1381 ,1258 ,2020 ,441 ,38 , + 620 ,194 ,1230 ,1806 ,1737 ,1550 ,2029 ,1518 ,875 ,976 ,952 ,542 ,2040 ,577 ,1946 ,625 , + 82 ,1581 ,167 ,810 ,1380 ,1095 ,1784 ,97 ,1122 ,1335 ,185 ,428 ,83 ,1399 ,1610 ,854 , + 1714 ,1003 ,197 ,2034 ,80 ,1392 ,575 ,1955 ,340 ,604 ,827 ,443 ,1549 ,792 ,1593 ,1750 , + 429 ,1702 ,288 ,1370 ,925 ,1276 ,1954 ,734 ,371 ,1657 ,1707 ,1945 ,1855 ,145 ,1045 ,312 , + 590 ,1189 ,1542 ,1255 ,457 ,1484 ,738 ,731 ,1667 ,1033 ,1058 ,47 ,1061 ,1315 ,866 ,2008 , + 704 ,183 ,201 ,238 ,128 ,1736 ,926 ,1210 ,479 ,1873 ,698 ,1092 ,197 ,1081 ,1837 ,1883 , + 1721 ,806 ,730 ,531 ,1049 ,1428 ,266 ,894 ,499 ,1525 ,1283 ,1520 ,4 ,1291 ,870 ,1674 , + 235 ,301 ,213 ,286 ,1414 ,1570 ,914 ,410 ,55 ,1037 ,1631 ,1689 ,313 ,1012 ,1241 ,1951 , + 1932 ,1531 ,752 ,1727 ,1667 ,694 ,1754 ,2011 ,1645 ,428 ,387 ,291 ,327 ,1961 ,1666 ,418 , + 1339 ,901 ,1147 ,1894 ,811 ,242 ,1302 ,546 ,721 ,62 ,680 ,1439 ,140 ,258 ,1846 ,411 , + 747 ,1981 ,1665 ,58 ,1411 ,1116 ,340 ,874 ,1498 ,1470 ,794 ,741 ,131 ,938 ,783 ,736 , + 2030 ,1947 ,750 ,130 ,744 ,84 ,864 ,1264 ,1114 ,1275 ,244 ,54 ,1003 ,97 ,1002 ,1608 , + 1617 ,1260 ,945 ,894 ,524 ,664 ,59 ,810 ,235 ,1839 ,141 ,1430 ,2018 ,1385 ,220 ,51 , + 646 ,1638 ,505 ,825 ,1177 ,1445 ,1291 ,293 ,779 ,1023 ,337 ,1155 ,1171 ,1379 ,1205 ,214 , + 1557 ,1312 ,684 ,2039 ,1925 ,39 ,1242 ,1928 ,222 ,1987 ,938 ,509 ,1093 ,1172 ,663 ,922 , + 1468 ,266 ,551 ,54 ,212 ,1058 ,389 ,294 ,1396 ,771 ,360 ,1415 ,209 ,11 ,208 ,818 , + 1841 ,1828 ,1293 ,409 ,1058 ,1503 ,1208 ,1593 ,993 ,330 ,1527 ,713 ,1925 ,382 ,780 ,149 , + 75 ,538 ,1999 ,1932 ,800 ,1486 ,1692 ,470 ,2000 ,1661 ,404 ,1638 ,225 ,1780 ,256 ,384 , + 189 ,1987 ,456 ,2034 ,1056 ,1890 ,827 ,406 ,748 ,978 ,1202 ,727 ,227 ,1310 ,1101 ,1045 , + 918 ,1628 ,1599 ,544 ,2000 ,95 ,96 ,1302 ,712 ,257 ,1806 ,1293 ,17 ,1579 ,426 ,432 , + 1832 ,1987 ,1032 ,739 ,613 ,44 ,1881 ,1361 ,1113 ,1700 ,790 ,1582 ,335 ,1837 ,273 ,755 , + 877 ,133 ,984 ,1698 ,361 ,764 ,353 ,1574 ,498 ,791 ,67 ,1572 ,804 ,1875 ,1102 ,91 , + 955 ,773 ,2008 ,693 ,129 ,1523 ,290 ,862 ,1752 ,552 ,1732 ,632 ,1407 ,1230 ,1013 ,2025 , + 854 ,1044 ,1764 ,409 ,190 ,1485 ,125 ,1134 ,538 ,2034 ,1456 ,577 ,990 ,1493 ,1587 ,526 , + 1320 ,480 ,827 ,290 ,1837 ,679 ,99 ,1852 ,866 ,1798 ,163 ,943 ,1806 ,1979 ,31 ,1999 , + 702 ,31 ,1852 ,1072 ,63 ,1550 ,1440 ,999 ,530 ,1493 ,1 ,405 ,1877 ,136 ,1413 ,1525 , + 402 ,8 ,250 ,786 ,304 ,1426 ,1600 ,1852 ,1063 ,215 ,313 ,1269 ,1875 ,490 ,383 ,1117 , + 769 ,1515 ,1535 ,164 ,1019 ,102 ,326 ,1255 ,120 ,1542 ,1996 ,1027 ,1731 ,1430 ,802 ,485 , + 210 ,646 ,1758 ,443 ,1270 ,1953 ,771 ,643 ,699 ,393 ,47 ,1314 ,941 ,1218 ,481 ,764 , + 666 ,243 ,783 ,546 ,267 ,555 ,825 ,2008 ,1210 ,1542 ,1165 ,439 ,1736 ,1204 ,166 ,1942 , + 32 ,646 ,1490 ,1402 ,423 ,1953 ,1353 ,717 ,724 ,847 ,115 ,951 ,1995 ,1688 ,1047 ,1752 , + 448 ,243 ,783 ,290 ,1736 ,1443 ,666 ,1744 ,1210 ,1992 ,1165 ,253 ,1123 ,113 ,166 ,1684 , + 1978 ,1829 ,618 ,1853 ,1255 ,1067 ,1353 ,717 ,724 ,591 ,569 ,1124 ,35 ,97 ,332 ,1752 , + 1850 ,243 ,783 ,290 ,1736 ,1443 ,666 ,1744 ,1210 ,1370 ,1165 ,436 ,1908 ,113 ,644 ,851 , + 1978 ,251 ,1736 ,1853 ,1255 ,1626 ,377 ,1586 ,204 ,591 ,1538 ,951 ,1995 ,1688 ,427 ,1833 , + 1850 ,243 ,783 ,290 ,1736 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,436 ,1123 ,113 ,144 ,851 , + 1978 ,1829 ,1736 ,1406 ,1255 ,1626 ,1332 ,1586 ,204 ,847 ,1538 ,483 ,35 ,1688 ,1047 ,1833 , + 1850 ,243 ,783 ,290 ,1736 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,1908 ,1430 ,644 ,1684 , + 32 ,1419 ,1736 ,1402 ,692 ,1953 ,377 ,717 ,204 ,847 ,1538 ,1388 ,1995 ,1440 ,1047 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,436 ,1908 ,650 ,644 ,851 , + 1978 ,251 ,1736 ,1853 ,1255 ,1626 ,377 ,1586 ,204 ,591 ,569 ,951 ,1995 ,1688 ,1140 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,253 ,1908 ,650 ,166 ,1684 , + 1978 ,1419 ,290 ,1402 ,1255 ,1626 ,1353 ,717 ,724 ,591 ,569 ,1388 ,1995 ,97 ,332 ,1833 , + 481 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,739 ,1370 ,1165 ,253 ,1908 ,650 ,802 ,1684 , + 1978 ,1419 ,290 ,1402 ,1255 ,1626 ,1353 ,717 ,204 ,591 ,1538 ,1388 ,1995 ,1440 ,332 ,1752 , + 481 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,436 ,1736 ,1112 ,644 ,1684 , + 1978 ,251 ,1490 ,1402 ,610 ,1953 ,1353 ,717 ,204 ,591 ,1538 ,951 ,35 ,1440 ,1047 ,1752 , + 481 ,243 ,1178 ,546 ,267 ,555 ,976 ,1648 ,739 ,374 ,1165 ,253 ,1908 ,113 ,166 ,851 , + 1978 ,1829 ,1736 ,1853 ,1255 ,1067 ,1353 ,1774 ,724 ,591 ,569 ,1124 ,35 ,1688 ,332 ,1562 , + 481 ,243 ,1178 ,546 ,267 ,555 ,976 ,1648 ,1210 ,374 ,1165 ,439 ,1912 ,1204 ,144 ,851 , + 32 ,646 ,1490 ,1428 ,692 ,1626 ,1332 ,1774 ,724 ,847 ,1538 ,1124 ,1995 ,1688 ,427 ,1833 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,436 ,194 ,1430 ,644 ,1684 , + 32 ,1642 ,1736 ,1402 ,1908 ,1626 ,377 ,717 ,204 ,591 ,1538 ,1388 ,422 ,1440 ,427 ,1833 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,436 ,194 ,1430 ,644 ,1684 , + 1978 ,1642 ,1736 ,1402 ,1908 ,1626 ,377 ,717 ,204 ,591 ,1538 ,483 ,422 ,1440 ,1140 ,1752 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,825 ,2008 ,739 ,1370 ,1165 ,436 ,1101 ,650 ,853 ,610 , + 1978 ,251 ,290 ,1406 ,692 ,1626 ,1497 ,1774 ,724 ,591 ,1538 ,951 ,1995 ,1440 ,332 ,1562 , + 481 ,243 ,1178 ,546 ,267 ,1030 ,825 ,2008 ,739 ,1370 ,1165 ,436 ,1101 ,650 ,853 ,851 , + 1978 ,251 ,290 ,1406 ,692 ,1626 ,1497 ,1586 ,204 ,591 ,1538 ,951 ,422 ,1440 ,332 ,1562 , + 384 ,1211 ,1456 ,417 ,267 ,347 ,666 ,1744 ,1210 ,1370 ,1165 ,760 ,1123 ,1492 ,853 ,851 , + 210 ,1829 ,1490 ,1406 ,1133 ,1067 ,1095 ,1586 ,1423 ,973 ,1841 ,254 ,1995 ,1688 ,1047 ,1562 , + 1826 ,1052 ,658 ,1507 ,73 ,2010 ,1666 ,1273 ,306 ,1500 ,2040 ,730 ,1395 ,1907 ,570 ,1218 , + 1816 ,1681 ,1615 ,909 ,1860 ,1490 ,526 ,1998 ,2029 ,17 ,209 ,912 ,1919 ,2020 ,155 ,1806 , + 481 ,243 ,1697 ,546 ,481 ,555 ,1871 ,1648 ,1972 ,1992 ,1028 ,253 ,1123 ,1430 ,644 ,851 , + 200 ,1829 ,1490 ,1397 ,692 ,1067 ,377 ,717 ,204 ,591 ,1538 ,1388 ,1995 ,1688 ,332 ,1562 , + 481 ,243 ,1178 ,1348 ,1335 ,1572 ,976 ,2008 ,739 ,1542 ,1165 ,436 ,1908 ,1112 ,644 ,851 , + 1978 ,1419 ,1736 ,1402 ,692 ,1897 ,377 ,1586 ,1423 ,847 ,115 ,1388 ,35 ,1688 ,1140 ,1752 , + 481 ,243 ,1178 ,1348 ,1335 ,1443 ,976 ,2008 ,1210 ,1992 ,1165 ,436 ,194 ,1112 ,644 ,851 , + 1978 ,1829 ,1490 ,1402 ,610 ,1897 ,377 ,1586 ,724 ,591 ,1538 ,1124 ,1995 ,1688 ,1140 ,1833 , + 384 ,991 ,1686 ,1709 ,568 ,1356 ,1871 ,1868 ,322 ,1546 ,675 ,1439 ,1700 ,839 ,148 ,465 , + 435 ,271 ,63 ,1314 ,65 ,992 ,1201 ,641 ,1033 ,1325 ,7 ,1792 ,369 ,473 ,271 ,1549 , + 1738 ,1521 ,146 ,1846 ,56 ,457 ,1658 ,1739 ,1379 ,2028 ,937 ,1457 ,712 ,345 ,1877 ,5 , + 386 ,613 ,1007 ,686 ,2030 ,1093 ,107 ,722 ,1476 ,125 ,1068 ,201 ,207 ,1234 ,159 ,128 , + 522 ,1511 ,742 ,405 ,547 ,1176 ,546 ,1078 ,464 ,1834 ,1400 ,487 ,1703 ,921 ,148 ,1587 , + 382 ,166 ,1972 ,1540 ,1375 ,1785 ,789 ,83 ,983 ,1138 ,1484 ,1347 ,437 ,367 ,744 ,1370 , + 785 ,1190 ,1614 ,1453 ,1715 ,1975 ,1246 ,1068 ,990 ,1216 ,1669 ,1892 ,117 ,491 ,938 ,542 , + 1969 ,148 ,0 ,704 ,1035 ,790 ,1274 ,1828 ,445 ,1530 ,703 ,1656 ,530 ,1749 ,1322 ,1485 , + 354 ,1854 ,110 ,1445 ,1526 ,1262 ,64 ,278 ,1474 ,1239 ,1986 ,1345 ,1177 ,286 ,382 ,171 , + 464 ,1428 ,1722 ,347 ,864 ,61 ,602 ,2033 ,1684 ,561 ,348 ,1535 ,1728 ,1179 ,416 ,1411 , + 521 ,344 ,62 ,1606 ,1473 ,1163 ,29 ,885 ,906 ,573 ,1032 ,1870 ,300 ,924 ,852 ,55 , + 587 ,1673 ,904 ,495 ,1585 ,1804 ,1294 ,1133 ,561 ,1089 ,1175 ,1075 ,1117 ,1365 ,137 ,1124 , + 521 ,749 ,1590 ,1947 ,1602 ,302 ,1109 ,1610 ,441 ,613 ,680 ,213 ,1584 ,1909 ,1520 ,276 , + 461 ,493 ,1934 ,346 ,780 ,201 ,564 ,1350 ,1494 ,892 ,616 ,975 ,585 ,802 ,1508 ,1302 , + 1686 ,1976 ,349 ,1393 ,825 ,368 ,99 ,798 ,1384 ,472 ,546 ,442 ,1709 ,1021 ,418 ,932 , + 1264 ,541 ,1769 ,1987 ,1229 ,1007 ,896 ,1120 ,327 ,544 ,579 ,1758 ,1150 ,1103 ,329 ,1955 , + 1548 ,578 ,1879 ,862 ,509 ,1158 ,1278 ,1200 ,937 ,145 ,766 ,1907 ,83 ,1903 ,1683 ,691 , + 65 ,1096 ,769 ,737 ,1146 ,819 ,1617 ,1650 ,636 ,1535 ,707 ,419 ,214 ,661 ,1215 ,808 , + 1548 ,1351 ,769 ,1461 ,1823 ,156 ,890 ,526 ,1694 ,392 ,36 ,845 ,658 ,1336 ,597 ,1807 , + 1597 ,1173 ,1225 ,1225 ,274 ,2035 ,1087 ,2039 ,896 ,846 ,592 ,415 ,688 ,1522 ,1222 ,1728 , + 1109 ,1398 ,1764 ,1826 ,1034 ,2023 ,914 ,1239 ,1534 ,1342 ,197 ,830 ,723 ,854 ,2011 ,1132 , + 272 ,315 ,744 ,145 ,1838 ,791 ,162 ,757 ,1749 ,1110 ,267 ,781 ,532 ,1187 ,869 ,192 , + 29 ,740 ,1051 ,1626 ,432 ,1966 ,725 ,396 ,1048 ,512 ,418 ,1787 ,1838 ,990 ,1205 ,1464 , + 947 ,525 ,1303 ,1325 ,624 ,1697 ,438 ,951 ,757 ,1125 ,390 ,177 ,1343 ,1273 ,746 ,834 , + 268 ,1190 ,1562 ,284 ,473 ,955 ,895 ,1553 ,747 ,1339 ,890 ,1804 ,1300 ,1537 ,201 ,166 , + 1040 ,84 ,1872 ,631 ,331 ,353 ,22 ,1982 ,576 ,162 ,84 ,1097 ,1067 ,752 ,463 ,1609 , + 1558 ,740 ,1916 ,2015 ,1906 ,201 ,1110 ,1708 ,853 ,675 ,357 ,1727 ,938 ,986 ,2016 ,509 , + 1385 ,1985 ,1948 ,1347 ,1297 ,390 ,1344 ,1199 ,1208 ,566 ,258 ,450 ,1599 ,53 ,100 ,806 , + 199 ,1054 ,1544 ,1716 ,696 ,1983 ,835 ,1281 ,538 ,1199 ,203 ,765 ,1961 ,611 ,546 ,396 , + 256 ,382 ,647 ,419 ,1370 ,800 ,1614 ,825 ,1040 ,264 ,514 ,1901 ,1713 ,1273 ,860 ,1656 , + 1912 ,1879 ,1037 ,1604 ,577 ,1507 ,1170 ,1010 ,1375 ,892 ,1242 ,1843 ,1286 ,1041 ,1503 ,1215 , + 1395 ,648 ,2044 ,995 ,1372 ,474 ,310 ,517 ,1278 ,743 ,1903 ,469 ,1985 ,1855 ,9 ,2015 , + 533 ,497 ,1455 ,46 ,1568 ,432 ,1524 ,1735 ,1274 ,349 ,1250 ,73 ,405 ,1600 ,783 ,509 , + 1385 ,228 ,793 ,768 ,827 ,39 ,442 ,310 ,2044 ,1561 ,1861 ,88 ,1598 ,1385 ,1949 ,1337 , + 1756 ,1727 ,1501 ,985 ,647 ,2044 ,1974 ,195 ,853 ,1731 ,681 ,1854 ,556 ,775 ,613 ,1765 , + 649 ,1481 ,1288 ,1858 ,1442 ,1623 ,785 ,270 ,579 ,1325 ,420 ,1564 ,20 ,1643 ,822 ,639 , + 833 ,1202 ,1645 ,519 ,1386 ,1247 ,909 ,644 ,871 ,1193 ,1692 ,542 ,1131 ,507 ,1301 ,1654 , + 612 ,887 ,1246 ,1246 ,1937 ,1365 ,168 ,913 ,1788 ,1473 ,1986 ,1357 ,736 ,220 ,1946 ,1171 , + 929 ,1636 ,448 ,1565 ,1333 ,1593 ,647 ,43 ,1099 ,1679 ,1065 ,632 ,652 ,993 ,1342 ,1186 , + 785 ,1992 ,260 ,1311 ,662 ,1490 ,1879 ,1475 ,1661 ,1946 ,1880 ,372 ,790 ,446 ,1367 ,989 , + 1141 ,185 ,277 ,698 ,476 ,1177 ,1597 ,1519 ,1553 ,1254 ,1975 ,374 ,1943 ,606 ,2046 ,930 , + 1566 ,1510 ,1451 ,512 ,740 ,1829 ,1114 ,1968 ,1644 ,1150 ,1827 ,910 ,1448 ,1339 ,381 ,422 , + 215 ,1603 ,344 ,1162 ,294 ,1511 ,316 ,671 ,531 ,827 ,1211 ,1217 ,1684 ,1161 ,1370 ,111 , + 139 ,101 ,521 ,1984 ,1714 ,452 ,1177 ,634 ,319 ,122 ,618 ,2030 ,2041 ,769 ,862 ,1237 , + 1929 ,1867 ,1878 ,1754 ,1686 ,1239 ,1529 ,663 ,1061 ,1095 ,633 ,1998 ,157 ,1838 ,297 ,1001 , + 1887 ,1890 ,591 ,1110 ,754 ,1273 ,481 ,245 ,1587 ,1087 ,1964 ,1011 ,615 ,148 ,967 ,1451 , + 1384 ,870 ,76 ,874 ,1455 ,1587 ,190 ,782 ,801 ,164 ,696 ,1228 ,990 ,862 ,1825 ,1928 , + 1956 ,182 ,137 ,1592 ,1108 ,442 ,513 ,2027 ,1828 ,1302 ,1066 ,1626 ,557 ,1655 ,604 ,2041 , + 1924 ,504 ,78 ,1597 ,956 ,1019 ,1744 ,1340 ,738 ,1903 ,1582 ,1002 ,534 ,413 ,1397 ,655 , + 294 ,728 ,1240 ,1992 ,1557 ,769 ,178 ,1518 ,680 ,232 ,850 ,1483 ,340 ,875 ,73 ,216 , + 1915 ,332 ,1280 ,1530 ,920 ,146 ,1895 ,18 ,81 ,1895 ,779 ,1564 ,953 ,399 ,1627 ,291 , + 109 ,453 ,101 ,611 ,613 ,1660 ,952 ,1386 ,1926 ,623 ,270 ,242 ,506 ,892 ,391 ,712 , + 1384 ,172 ,912 ,916 ,921 ,1077 ,1528 ,379 ,960 ,293 ,330 ,1805 ,451 ,1362 ,1596 ,1033 , + 427 ,1210 ,637 ,1788 ,1426 ,1896 ,2015 ,693 ,544 ,1538 ,416 ,1137 ,668 ,1310 ,1456 ,1092 , + 964 ,846 ,556 ,828 ,573 ,1096 ,761 ,1075 ,19 ,2025 ,1598 ,791 ,1725 ,234 ,1204 ,680 , + 657 ,1615 ,523 ,1362 ,1299 ,1405 ,217 ,575 ,994 ,1090 ,195 ,1537 ,1234 ,1880 ,172 ,1574 , + 1559 ,1440 ,574 ,607 ,1574 ,1894 ,1998 ,1508 ,39 ,577 ,1388 ,838 ,1074 ,1493 ,627 ,1742 , + 854 ,1142 ,267 ,130 ,45 ,169 ,1036 ,818 ,875 ,1157 ,1701 ,1420 ,455 ,283 ,1937 ,1722 , + 547 ,1312 ,370 ,917 ,1441 ,607 ,125 ,1828 ,1106 ,391 ,356 ,1233 ,1507 ,1084 ,1019 ,659 , + 1324 ,1706 ,1749 ,1767 ,73 ,1006 ,1293 ,627 ,590 ,30 ,1363 ,764 ,630 ,583 ,1484 ,1418 , + 1862 ,2019 ,1481 ,90 ,1822 ,1623 ,1836 ,311 ,506 ,1204 ,1973 ,1280 ,1057 ,557 ,1743 ,1994 , + 44 ,1818 ,1313 ,885 ,862 ,1200 ,887 ,1641 ,1921 ,277 ,1347 ,521 ,1269 ,166 ,388 ,993 , + 1221 ,752 ,963 ,2015 ,1529 ,691 ,783 ,1125 ,55 ,1257 ,190 ,1968 ,1962 ,1225 ,1593 ,335 , + 301 ,362 ,1102 ,112 ,48 ,1359 ,1437 ,924 ,1210 ,1581 ,1147 ,717 ,206 ,655 ,1247 ,1352 , + 496 ,1527 ,1037 ,1258 ,1296 ,1999 ,1840 ,1352 ,578 ,484 ,1736 ,1105 ,914 ,781 ,934 ,7 , + 1894 ,804 ,1197 ,1321 ,1546 ,180 ,1713 ,871 ,1467 ,698 ,1142 ,1179 ,1174 ,1812 ,942 ,1277 , + 1030 ,200 ,856 ,941 ,169 ,1680 ,969 ,227 ,229 ,831 ,1665 ,175 ,992 ,2020 ,754 ,1541 , + 275 ,1187 ,1155 ,237 ,580 ,2008 ,304 ,784 ,890 ,1243 ,1498 ,583 ,1694 ,1205 ,772 ,265 , + 1225 ,380 ,1464 ,1249 ,1779 ,308 ,567 ,1364 ,397 ,252 ,197 ,1787 ,468 ,460 ,1781 ,386 , + 1024 ,926 ,1262 ,1108 ,618 ,839 ,839 ,1234 ,1257 ,1669 ,392 ,965 ,1161 ,810 ,832 ,803 , + 93 ,386 ,1252 ,1260 ,1866 ,1975 ,517 ,171 ,1144 ,1570 ,1158 ,1590 ,1761 ,544 ,839 ,1626 , + 1839 ,1232 ,616 ,2 ,743 ,1646 ,698 ,852 ,953 ,88 ,1712 ,295 ,257 ,1832 ,1863 ,2008 , + 1765 ,1729 ,214 ,112 ,1012 ,589 ,815 ,141 ,1683 ,256 ,1647 ,1952 ,364 ,1243 ,1571 ,1208 , + 1353 ,1485 ,1199 ,1896 ,1676 ,1931 ,1720 ,1340 ,7 ,910 ,1686 ,467 ,90 ,1837 ,1015 ,1858 , + 1127 ,559 ,1604 ,726 ,1465 ,1543 ,1861 ,1644 ,382 ,1641 ,1130 ,1451 ,173 ,474 ,1628 ,1415 , + 1128 ,912 ,1167 ,1433 ,2033 ,511 ,1410 ,571 ,171 ,315 ,1533 ,769 ,262 ,1544 ,630 ,244 , + 632 ,501 ,910 ,1315 ,913 ,1150 ,719 ,237 ,1678 ,282 ,320 ,245 ,1557 ,1053 ,831 ,1366 , + 2008 ,488 ,1343 ,191 ,2029 ,193 ,1358 ,248 ,1699 ,637 ,1034 ,196 ,347 ,688 ,1502 ,380 , + 728 ,872 ,713 ,1871 ,1165 ,1017 ,397 ,1567 ,332 ,616 ,19 ,1792 ,978 ,1123 ,1397 ,537 , + 1172 ,694 ,1705 ,1723 ,1046 ,593 ,780 ,2002 ,725 ,115 ,1419 ,730 ,485 ,678 ,57 ,938 , + 389 ,1287 ,1313 ,1918 ,43 ,668 ,1878 ,1728 ,1786 ,1987 ,1874 ,1863 ,1236 ,1124 ,1726 ,337 , + 1596 ,1870 ,1547 ,1780 ,151 ,185 ,1456 ,1093 ,1603 ,1534 ,1096 ,1317 ,1206 ,1081 ,1300 ,315 , + 103 ,110 ,1042 ,79 ,1822 ,285 ,633 ,1763 ,875 ,172 ,1604 ,1013 ,1829 ,1551 ,314 ,750 , + 1352 ,1139 ,202 ,1432 ,1649 ,938 ,1037 ,906 ,1252 ,1359 ,586 ,1861 ,1295 ,1376 ,1904 ,1164 , + 524 ,1398 ,469 ,194 ,2019 ,811 ,1221 ,1520 ,815 ,1369 ,1099 ,1285 ,492 ,152 ,1289 ,1742 , + 533 ,1029 ,1592 ,560 ,116 ,852 ,268 ,2029 ,1932 ,423 ,1277 ,721 ,544 ,347 ,1534 ,933 , + 1222 ,1983 ,170 ,1511 ,1239 ,1792 ,846 ,1854 ,1876 ,1410 ,1989 ,1884 ,1629 ,894 ,1185 ,1567 , + 252 ,773 ,632 ,1794 ,109 ,1804 ,976 ,758 ,417 ,1529 ,676 ,203 ,1522 ,771 ,1777 ,131 , + 495 ,1373 ,1645 ,2016 ,543 ,1695 ,1171 ,1895 ,994 ,1987 ,296 ,418 ,1194 ,1189 ,1595 ,1801 , + 1334 ,773 ,762 ,434 ,1368 ,1249 ,1738 ,1546 ,1939 ,1019 ,550 ,531 ,1552 ,1362 ,323 ,316 , + 400 ,1961 ,766 ,1201 ,875 ,2028 ,211 ,111 ,508 ,758 ,598 ,906 ,63 ,681 ,42 ,1988 , + 1732 ,1184 ,1270 ,1490 ,202 ,692 ,1961 ,1057 ,852 ,978 ,894 ,1082 ,1048 ,888 ,889 ,1047 , + 860 ,254 ,1833 ,19 ,38 ,896 ,14 ,1245 ,2028 ,416 ,886 ,213 ,1617 ,807 ,442 ,1422 , + 1899 ,667 ,595 ,111 ,79 ,1161 ,938 ,1020 ,603 ,1527 ,1402 ,1747 ,2022 ,1376 ,735 ,418 , + 1140 ,1785 ,338 ,1633 ,1881 ,1556 ,916 ,84 ,1378 ,1147 ,1462 ,1415 ,1829 ,726 ,1436 ,645 , + 1552 ,1459 ,1719 ,1535 ,892 ,1933 ,1163 ,672 ,1203 ,1231 ,1503 ,772 ,1272 ,1918 ,107 ,2036 , + 1367 ,968 ,1989 ,888 ,2019 ,1376 ,1767 ,2025 ,368 ,29 ,1358 ,952 ,1348 ,116 ,1002 ,65 , + 1970 ,1522 ,1784 ,523 ,173 ,1765 ,904 ,1572 ,432 ,71 ,1460 ,1278 ,347 ,300 ,502 ,136 , + 317 ,902 ,1669 ,1738 ,777 ,1076 ,1441 ,553 ,949 ,1906 ,622 ,1409 ,285 ,1081 ,1125 ,256 , + 1467 ,1165 ,390 ,171 ,109 ,1342 ,421 ,856 ,1616 ,597 ,787 ,1375 ,1070 ,903 ,1264 ,230 , + 317 ,856 ,130 ,677 ,216 ,212 ,211 ,49 ,732 ,1883 ,2015 ,1564 ,1278 ,1340 ,621 ,79 , + 624 ,1117 ,1087 ,1876 ,1489 ,711 ,1089 ,1912 ,191 ,1510 ,171 ,526 ,1420 ,136 ,848 ,1586 , + 877 ,376 ,1865 ,1875 ,1401 ,1032 ,973 ,736 ,1559 ,1067 ,2026 ,347 ,1074 ,143 ,656 ,1912 , + 100 ,25 ,959 ,813 ,1115 ,1534 ,986 ,1154 ,426 ,1305 ,1600 ,1228 ,416 ,763 ,534 ,2004 , + 854 ,55 ,1523 ,1290 ,311 ,1032 ,542 ,1398 ,1660 ,1427 ,2043 ,815 ,118 ,1515 ,163 ,907 , + 1511 ,439 ,224 ,1569 ,327 ,370 ,1662 ,454 ,155 ,234 ,1153 ,1461 ,1599 ,1905 ,1922 ,1973 , + 702 ,1540 ,183 ,1071 ,291 ,1431 ,1506 ,1567 ,1214 ,883 ,1991 ,1544 ,234 ,1657 ,885 ,1211 , + 1471 ,763 ,418 ,1021 ,1928 ,745 ,1507 ,507 ,1826 ,858 ,650 ,1589 ,459 ,221 ,1168 ,1879 , + 34 ,1700 ,1178 ,97 ,1019 ,555 ,666 ,1744 ,1210 ,1542 ,415 ,436 ,1101 ,1430 ,853 ,1942 , + 200 ,251 ,1490 ,1402 ,1908 ,1626 ,1353 ,717 ,204 ,591 ,47 ,1388 ,687 ,1440 ,1140 ,1833 , + 666 ,243 ,783 ,142 ,481 ,555 ,666 ,1648 ,1210 ,1542 ,1165 ,253 ,1912 ,650 ,166 ,851 , + 1978 ,1419 ,290 ,1853 ,1255 ,1626 ,1353 ,1586 ,724 ,847 ,1538 ,951 ,1995 ,97 ,332 ,1752 , + 448 ,243 ,783 ,142 ,481 ,1030 ,666 ,2008 ,739 ,1370 ,1165 ,1383 ,1908 ,650 ,853 ,851 , + 32 ,646 ,290 ,1428 ,692 ,1897 ,1497 ,1586 ,204 ,973 ,1538 ,951 ,1995 ,14 ,1047 ,1752 , + 84 ,243 ,783 ,142 ,481 ,1030 ,666 ,2008 ,739 ,1370 ,1165 ,436 ,194 ,650 ,144 ,1684 , + 1978 ,646 ,290 ,1402 ,1908 ,1626 ,1332 ,717 ,724 ,591 ,47 ,483 ,422 ,1440 ,1047 ,1833 , + 84 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,1101 ,650 ,644 ,610 , + 1978 ,251 ,290 ,1853 ,692 ,1897 ,1353 ,1774 ,724 ,591 ,1538 ,1388 ,35 ,1440 ,332 ,1833 , + 7 ,203 ,265 ,290 ,306 ,93 ,104 ,583 ,1938 ,278 ,618 ,1040 ,321 ,1213 ,166 ,1732 , + 959 ,271 ,1531 ,172 ,1133 ,1680 ,359 ,1509 ,1110 ,1591 ,260 ,254 ,1334 ,2023 ,911 ,1752 , + 739 ,1068 ,811 ,1473 ,1141 ,301 ,1784 ,1374 ,791 ,1505 ,402 ,1444 ,1321 ,1625 ,397 ,1711 , + 653 ,514 ,1779 ,1949 ,1648 ,998 ,289 ,1555 ,1342 ,1723 ,54 ,1238 ,1654 ,1538 ,798 ,823 , + 739 ,1274 ,682 ,460 ,1631 ,120 ,411 ,1277 ,761 ,1117 ,2030 ,1587 ,1961 ,1468 ,1538 ,772 , + 1306 ,1725 ,828 ,419 ,362 ,981 ,1583 ,1843 ,2024 ,1650 ,306 ,1062 ,1913 ,650 ,1441 ,1040 , + 1417 ,774 ,1530 ,1086 ,1018 ,1496 ,1015 ,885 ,142 ,870 ,1121 ,1829 ,1907 ,1089 ,403 ,1411 , + 23 ,770 ,1480 ,553 ,1711 ,530 ,1905 ,860 ,1972 ,736 ,582 ,119 ,1965 ,1941 ,1724 ,1425 , + 728 ,580 ,1209 ,454 ,990 ,1507 ,1411 ,824 ,1306 ,407 ,1630 ,1968 ,735 ,848 ,574 ,1851 , + 2002 ,1186 ,1661 ,132 ,1082 ,217 ,1619 ,761 ,1465 ,1416 ,1146 ,88 ,1191 ,1555 ,236 ,1506 , + 1353 ,1748 ,1434 ,563 ,837 ,1612 ,1514 ,481 ,1272 ,770 ,99 ,988 ,1413 ,1560 ,273 ,1656 , + 642 ,1759 ,30 ,1163 ,629 ,1705 ,297 ,1732 ,1467 ,802 ,1138 ,701 ,570 ,1466 ,330 ,1435 , + 1608 ,1945 ,407 ,1259 ,1545 ,1828 ,486 ,1851 ,675 ,1515 ,1664 ,1395 ,700 ,1054 ,938 ,1903 , + 1566 ,668 ,1663 ,70 ,409 ,1363 ,108 ,525 ,1986 ,1474 ,1211 ,1952 ,1175 ,1419 ,1710 ,574 , + 1516 ,1527 ,95 ,664 ,1029 ,439 ,1716 ,1333 ,815 ,26 ,867 ,1269 ,730 ,429 ,509 ,1977 , + 1618 ,1651 ,328 ,1499 ,1037 ,618 ,202 ,979 ,1952 ,536 ,1322 ,1041 ,1649 ,1279 ,2011 ,290 , + 1230 ,392 ,936 ,598 ,597 ,1628 ,1904 ,603 ,761 ,804 ,839 ,461 ,1729 ,781 ,1938 ,2017 , + 141 ,370 ,827 ,1623 ,545 ,266 ,484 ,1926 ,352 ,493 ,70 ,847 ,1864 ,707 ,1430 ,1552 , + 1178 ,1503 ,1090 ,1938 ,862 ,1763 ,224 ,1012 ,1167 ,1395 ,877 ,688 ,837 ,1044 ,601 ,1031 , + 1542 ,665 ,859 ,1707 ,113 ,1694 ,2021 ,575 ,1217 ,112 ,483 ,52 ,1861 ,2036 ,744 ,97 , + 1451 ,867 ,647 ,454 ,1480 ,1956 ,981 ,1288 ,996 ,1393 ,595 ,1575 ,1870 ,891 ,673 ,385 , + 1411 ,756 ,929 ,765 ,1897 ,1085 ,1124 ,1363 ,1561 ,1627 ,474 ,875 ,1925 ,422 ,741 ,1119 , + 819 ,1354 ,1492 ,921 ,1041 ,469 ,641 ,1532 ,180 ,1157 ,1381 ,1620 ,2024 ,895 ,495 ,1820 , + 1903 ,780 ,1415 ,1646 ,71 ,1933 ,967 ,1773 ,253 ,1305 ,1042 ,1342 ,1521 ,1392 ,1045 ,649 , + 1497 ,710 ,1169 ,1064 ,1509 ,1987 ,468 ,1292 ,664 ,773 ,78 ,578 ,2029 ,497 ,53 ,394 , + 1992 ,1709 ,767 ,1202 ,1054 ,388 ,2007 ,1772 ,815 ,1081 ,1141 ,30 ,1641 ,1316 ,1647 ,311 , + 576 ,694 ,1578 ,1418 ,1323 ,706 ,2013 ,663 ,83 ,268 ,1359 ,1912 ,1004 ,235 ,345 ,420 , + 900 ,429 ,1301 ,1615 ,1812 ,1187 ,1625 ,1571 ,105 ,1466 ,765 ,2013 ,1506 ,1295 ,1171 ,730 , + 872 ,1446 ,1076 ,1145 ,528 ,480 ,736 ,1663 ,1649 ,1419 ,1808 ,851 ,1075 ,1931 ,392 ,1646 , + 1570 ,736 ,122 ,1580 ,702 ,2014 ,382 ,1434 ,974 ,1679 ,876 ,167 ,338 ,334 ,594 ,1614 , + 872 ,20 ,302 ,2044 ,1376 ,1213 ,1698 ,278 ,1035 ,128 ,669 ,1123 ,479 ,282 ,512 ,530 , + 1260 ,1469 ,1804 ,228 ,751 ,1773 ,1677 ,498 ,567 ,1510 ,468 ,1820 ,1041 ,707 ,1683 ,784 , + 1678 ,1453 ,2026 ,1451 ,972 ,755 ,1569 ,1559 ,1864 ,973 ,823 ,405 ,901 ,874 ,1689 ,770 , + 1855 ,1120 ,1148 ,321 ,701 ,1488 ,801 ,1365 ,1108 ,241 ,761 ,1985 ,34 ,479 ,252 ,1008 , + 1149 ,148 ,1025 ,529 ,616 ,1007 ,1589 ,1200 ,1676 ,1678 ,146 ,931 ,353 ,346 ,1642 ,185 , + 1985 ,1232 ,1969 ,1091 ,16 ,1097 ,526 ,1054 ,1387 ,1317 ,1385 ,95 ,1467 ,2043 ,421 ,1218 , + 1149 ,2010 ,794 ,67 ,811 ,1644 ,1735 ,1834 ,1151 ,1839 ,487 ,520 ,298 ,329 ,617 ,1728 , + 823 ,150 ,1012 ,1749 ,691 ,422 ,1914 ,240 ,1692 ,1792 ,742 ,634 ,1977 ,1804 ,1973 ,851 , + 390 ,1945 ,228 ,871 ,595 ,964 ,796 ,206 ,829 ,1145 ,973 ,1777 ,1556 ,1082 ,1282 ,1296 , + 1031 ,441 ,751 ,2004 ,1176 ,800 ,1411 ,906 ,2 ,1755 ,1381 ,282 ,97 ,1981 ,458 ,1495 , + 802 ,440 ,642 ,1586 ,573 ,116 ,1324 ,612 ,1029 ,1266 ,460 ,489 ,901 ,79 ,1563 ,758 , + 1639 ,1009 ,1293 ,1894 ,1643 ,1608 ,34 ,438 ,640 ,1629 ,766 ,1189 ,693 ,1647 ,1222 ,1864 , + 93 ,629 ,2021 ,370 ,1423 ,363 ,343 ,1294 ,570 ,258 ,823 ,1404 ,1937 ,232 ,477 ,715 , + 1429 ,287 ,584 ,592 ,274 ,1949 ,1420 ,501 ,1308 ,261 ,1778 ,49 ,94 ,709 ,1965 ,1581 , + 1960 ,1541 ,1068 ,188 ,1387 ,362 ,1892 ,1778 ,38 ,1007 ,31 ,151 ,355 ,1823 ,693 ,1917 , + 364 ,945 ,1886 ,37 ,1377 ,995 ,54 ,237 ,787 ,277 ,840 ,1526 ,1560 ,1744 ,395 ,754 , + 1338 ,788 ,1158 ,629 ,2038 ,865 ,667 ,234 ,687 ,1739 ,1811 ,1406 ,1252 ,688 ,1642 ,1457 , + 214 ,1151 ,1916 ,1581 ,1221 ,311 ,1347 ,152 ,1303 ,1815 ,705 ,16 ,1274 ,241 ,153 ,1048 , + 1794 ,1908 ,256 ,1942 ,893 ,11 ,271 ,1115 ,1106 ,554 ,316 ,990 ,1081 ,411 ,95 ,1407 , + 758 ,1523 ,77 ,1962 ,281 ,1871 ,1945 ,1929 ,81 ,797 ,1076 ,1467 ,37 ,790 ,1412 ,1442 , + 740 ,1153 ,533 ,1029 ,1453 ,1697 ,202 ,1052 ,1447 ,2028 ,1040 ,1372 ,1149 ,565 ,1551 ,1511 , + 1300 ,1292 ,292 ,333 ,893 ,1869 ,1761 ,2022 ,2017 ,1501 ,693 ,1647 ,110 ,1241 ,135 ,425 , + 1453 ,416 ,225 ,563 ,171 ,1386 ,1518 ,1330 ,759 ,1170 ,651 ,1037 ,20 ,288 ,843 ,472 , + 1378 ,1067 ,1466 ,1303 ,357 ,1011 ,222 ,1620 ,1913 ,1962 ,1684 ,10 ,1870 ,1703 ,949 ,1571 , + 1274 ,70 ,1313 ,93 ,534 ,436 ,1214 ,855 ,1375 ,835 ,592 ,1919 ,942 ,953 ,1034 ,837 , + 1612 ,1838 ,445 ,1717 ,1225 ,210 ,1612 ,237 ,700 ,766 ,415 ,237 ,1788 ,1593 ,75 ,869 , + 1790 ,539 ,1677 ,653 ,1735 ,343 ,1686 ,1001 ,1073 ,1587 ,509 ,49 ,1770 ,444 ,1429 ,1183 , + 1935 ,473 ,947 ,1890 ,1364 ,43 ,1344 ,31 ,1255 ,271 ,336 ,2010 ,733 ,764 ,1065 ,1688 , + 389 ,785 ,50 ,1205 ,1269 ,804 ,1728 ,671 ,1390 ,152 ,946 ,51 ,1400 ,622 ,1425 ,1612 , + 1346 ,1842 ,997 ,1636 ,959 ,1989 ,1288 ,877 ,704 ,762 ,1265 ,353 ,884 ,1413 ,1947 ,1118 , + 186 ,287 ,1220 ,236 ,38 ,1069 ,327 ,948 ,767 ,2000 ,1023 ,1281 ,1014 ,591 ,1254 ,986 , + 196 ,1598 ,1121 ,1710 ,910 ,414 ,1627 ,1794 ,1819 ,1543 ,594 ,1588 ,496 ,1311 ,1649 ,1228 , + 1307 ,520 ,157 ,828 ,264 ,1069 ,837 ,568 ,887 ,1318 ,1704 ,141 ,791 ,376 ,1149 ,1032 , + 175 ,1658 ,1288 ,1047 ,1133 ,39 ,687 ,1066 ,18 ,17 ,883 ,1667 ,171 ,1983 ,1327 ,54 , + 2042 ,1700 ,1029 ,164 ,915 ,347 ,976 ,754 ,1972 ,1992 ,1458 ,253 ,1123 ,1430 ,144 ,1942 , + 1531 ,251 ,618 ,1428 ,1255 ,1626 ,1332 ,74 ,1423 ,973 ,115 ,1845 ,422 ,97 ,1047 ,1752 , + 541 ,243 ,1697 ,164 ,1736 ,1030 ,976 ,2008 ,1210 ,91 ,1165 ,436 ,1912 ,113 ,144 ,851 , + 1978 ,1419 ,1736 ,1853 ,692 ,1953 ,1332 ,1586 ,724 ,847 ,47 ,1388 ,35 ,14 ,1047 ,1752 , + 48 ,355 ,962 ,523 ,1514 ,20 ,1505 ,2015 ,435 ,954 ,583 ,1916 ,1883 ,1427 ,716 ,1091 , + 1663 ,797 ,1529 ,1861 ,897 ,219 ,357 ,643 ,948 ,543 ,1582 ,1543 ,687 ,419 ,1556 ,1470 , + 1945 ,1974 ,1323 ,1156 ,420 ,54 ,1607 ,583 ,435 ,954 ,1012 ,436 ,1001 ,1571 ,603 ,1279 , + 821 ,2002 ,723 ,1347 ,1405 ,424 ,1301 ,709 ,684 ,52 ,429 ,1168 ,687 ,1464 ,1342 ,1823 , + 201 ,438 ,246 ,751 ,636 ,960 ,1714 ,1408 ,161 ,1852 ,1111 ,1416 ,969 ,1105 ,1237 ,1591 , + 376 ,139 ,1733 ,705 ,780 ,286 ,1508 ,1104 ,163 ,1981 ,1824 ,507 ,1869 ,1003 ,1452 ,371 , + 1359 ,1285 ,1984 ,1299 ,371 ,148 ,727 ,255 ,1744 ,1424 ,708 ,1988 ,188 ,680 ,533 ,656 , + 1667 ,1459 ,1066 ,1678 ,229 ,1727 ,130 ,2045 ,519 ,413 ,1825 ,586 ,688 ,297 ,134 ,598 , + 1661 ,122 ,1095 ,309 ,40 ,1896 ,457 ,36 ,589 ,1170 ,1701 ,1088 ,1738 ,1601 ,931 ,275 , + 1441 ,792 ,607 ,842 ,41 ,95 ,1470 ,636 ,12 ,1645 ,170 ,1827 ,1553 ,1168 ,1452 ,652 , + 1242 ,1715 ,1865 ,443 ,1318 ,844 ,1045 ,1668 ,1540 ,550 ,1344 ,298 ,623 ,1175 ,1270 ,1535 , + 1156 ,547 ,926 ,1415 ,1775 ,486 ,163 ,524 ,255 ,1717 ,711 ,527 ,1984 ,961 ,992 ,1413 , + 627 ,1450 ,971 ,448 ,621 ,926 ,839 ,1628 ,1059 ,158 ,147 ,1073 ,1884 ,935 ,481 ,1622 , + 1478 ,498 ,900 ,1294 ,507 ,560 ,1505 ,1862 ,1461 ,1604 ,601 ,1472 ,890 ,758 ,1339 ,60 , + 867 ,252 ,1560 ,154 ,400 ,1688 ,887 ,1090 ,1256 ,2020 ,1466 ,1293 ,1349 ,1166 ,791 ,679 , + 723 ,607 ,2 ,47 ,893 ,580 ,337 ,1981 ,364 ,1704 ,113 ,451 ,1100 ,172 ,1076 ,1277 , + 709 ,1515 ,1626 ,164 ,306 ,1030 ,428 ,1648 ,1972 ,357 ,1458 ,760 ,1912 ,1714 ,1821 ,1942 , + 1029 ,646 ,1490 ,1402 ,1255 ,1953 ,1945 ,1774 ,699 ,847 ,1538 ,1124 ,35 ,1924 ,850 ,573 , + 768 ,243 ,783 ,142 ,481 ,555 ,666 ,1744 ,1210 ,1542 ,1165 ,253 ,1101 ,113 ,166 ,851 , + 32 ,1419 ,618 ,1853 ,1255 ,1953 ,1332 ,1774 ,724 ,847 ,569 ,1388 ,1930 ,1688 ,427 ,1752 , + 768 ,243 ,783 ,142 ,481 ,555 ,976 ,1648 ,1210 ,374 ,1165 ,436 ,194 ,1112 ,853 ,851 , + 1978 ,1829 ,1736 ,1406 ,610 ,1953 ,1497 ,1586 ,1423 ,847 ,1538 ,1388 ,1930 ,14 ,332 ,1833 , + 768 ,243 ,783 ,142 ,481 ,555 ,976 ,1648 ,1210 ,374 ,1165 ,436 ,194 ,1112 ,644 ,851 , + 1978 ,646 ,1736 ,1402 ,1908 ,1067 ,377 ,1586 ,1423 ,591 ,1538 ,483 ,1930 ,14 ,427 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,976 ,1744 ,739 ,1992 ,662 ,436 ,1908 ,650 ,166 ,667 , + 1978 ,251 ,1736 ,1402 ,1255 ,1067 ,1353 ,717 ,604 ,847 ,569 ,1124 ,35 ,14 ,1047 ,1833 , + 1850 ,243 ,783 ,1348 ,1335 ,1030 ,976 ,1744 ,1210 ,1992 ,662 ,253 ,1908 ,650 ,802 ,667 , + 1978 ,1642 ,290 ,1402 ,1255 ,1067 ,1353 ,1774 ,604 ,591 ,1538 ,483 ,1930 ,1688 ,1047 ,1833 , + 758 ,427 ,199 ,1697 ,839 ,1167 ,121 ,1630 ,1833 ,1546 ,963 ,291 ,814 ,1094 ,496 ,478 , + 200 ,901 ,1100 ,808 ,1802 ,352 ,796 ,619 ,1350 ,777 ,1847 ,1314 ,936 ,943 ,1448 ,573 , + 1268 ,1920 ,1504 ,1179 ,1626 ,1724 ,1856 ,2004 ,1349 ,959 ,542 ,210 ,1973 ,1517 ,210 ,1395 , + 1522 ,1216 ,76 ,994 ,491 ,489 ,1139 ,1287 ,1375 ,1151 ,403 ,1740 ,1072 ,979 ,1389 ,777 , + 1221 ,1192 ,1021 ,705 ,731 ,573 ,818 ,328 ,853 ,1037 ,1976 ,563 ,1934 ,175 ,1303 ,320 , + 1975 ,1835 ,960 ,1246 ,827 ,93 ,385 ,782 ,1482 ,217 ,387 ,672 ,1003 ,1001 ,428 ,44 , + 1906 ,1209 ,474 ,1084 ,917 ,1621 ,1590 ,1750 ,1854 ,705 ,1129 ,648 ,1770 ,761 ,41 ,164 , + 1569 ,1044 ,912 ,1346 ,580 ,1636 ,290 ,683 ,1004 ,986 ,1762 ,1535 ,1275 ,820 ,853 ,785 , + 1906 ,50 ,513 ,1311 ,1858 ,1413 ,895 ,120 ,1970 ,770 ,1606 ,910 ,1294 ,614 ,593 ,1796 , + 1541 ,1039 ,970 ,1797 ,1311 ,1343 ,1250 ,793 ,117 ,637 ,408 ,1860 ,1274 ,650 ,1707 ,1062 , + 1346 ,1038 ,17 ,454 ,1513 ,1700 ,886 ,483 ,1415 ,1138 ,1690 ,826 ,1132 ,1481 ,1599 ,15 , + 1882 ,1607 ,1412 ,944 ,784 ,659 ,1330 ,278 ,1464 ,1895 ,1287 ,657 ,273 ,602 ,1837 ,405 , + 726 ,1406 ,1077 ,941 ,359 ,1272 ,916 ,255 ,1129 ,1277 ,1762 ,182 ,69 ,622 ,2036 ,761 , + 1014 ,1255 ,1406 ,733 ,1162 ,660 ,1526 ,436 ,1579 ,1392 ,103 ,441 ,1198 ,1079 ,232 ,355 , + 290 ,661 ,873 ,166 ,1619 ,700 ,753 ,1513 ,2027 ,2035 ,1750 ,1956 ,787 ,44 ,471 ,963 , + 1761 ,940 ,1581 ,498 ,1264 ,843 ,1955 ,1258 ,1689 ,1032 ,562 ,1500 ,1903 ,767 ,1229 ,111 , + 1918 ,1540 ,401 ,383 ,1930 ,1453 ,77 ,842 ,152 ,1706 ,1561 ,1133 ,1857 ,561 ,993 ,194 , + 192 ,1480 ,1304 ,1028 ,1197 ,219 ,843 ,1440 ,1762 ,1390 ,313 ,1561 ,687 ,133 ,772 ,1424 , + 1430 ,310 ,146 ,713 ,1338 ,747 ,580 ,978 ,1301 ,208 ,277 ,1385 ,1367 ,230 ,907 ,1790 , + 601 ,905 ,786 ,553 ,1413 ,357 ,88 ,203 ,352 ,1886 ,1225 ,1980 ,664 ,2047 ,939 ,100 , + 1378 ,715 ,2006 ,1606 ,691 ,531 ,403 ,133 ,1301 ,717 ,1054 ,21 ,1525 ,1715 ,634 ,368 , + 1938 ,4 ,878 ,1440 ,796 ,1399 ,1980 ,537 ,1777 ,715 ,747 ,395 ,827 ,562 ,1463 ,1301 , + 1204 ,580 ,232 ,1682 ,393 ,453 ,1170 ,599 ,1518 ,983 ,1680 ,763 ,1988 ,1896 ,274 ,382 , + 1320 ,547 ,155 ,1422 ,251 ,114 ,1357 ,1078 ,602 ,689 ,907 ,1078 ,1848 ,290 ,887 ,575 , + 409 ,96 ,159 ,1437 ,960 ,58 ,518 ,1090 ,1100 ,916 ,802 ,1217 ,188 ,1421 ,560 ,2039 , + 1848 ,1941 ,1928 ,1292 ,1359 ,883 ,701 ,189 ,1248 ,785 ,280 ,763 ,360 ,73 ,1987 ,372 , + 1174 ,136 ,436 ,162 ,1333 ,1706 ,1255 ,323 ,664 ,557 ,226 ,1642 ,1382 ,691 ,89 ,1582 , + 1886 ,2025 ,1051 ,1809 ,1679 ,1404 ,518 ,1491 ,671 ,377 ,997 ,908 ,402 ,234 ,921 ,1631 , + 1682 ,1318 ,820 ,726 ,875 ,492 ,1644 ,36 ,117 ,1208 ,1088 ,666 ,1955 ,1524 ,1716 ,1487 , + 1193 ,1953 ,1982 ,1783 ,1949 ,1405 ,780 ,800 ,334 ,203 ,1495 ,1615 ,1121 ,1918 ,514 ,1871 , + 1033 ,365 ,502 ,1066 ,1657 ,1581 ,232 ,1442 ,1227 ,297 ,1946 ,1792 ,996 ,669 ,725 ,2023 , + 975 ,1976 ,1591 ,1168 ,925 ,238 ,802 ,1116 ,1941 ,1629 ,1231 ,1395 ,1847 ,244 ,174 ,1554 , + 1010 ,199 ,61 ,1196 ,977 ,891 ,1963 ,845 ,1250 ,562 ,2047 ,1945 ,601 ,691 ,1140 ,1559 , + 765 ,920 ,1149 ,671 ,1409 ,1299 ,1249 ,461 ,1227 ,1678 ,1269 ,1697 ,636 ,784 ,443 ,239 , + 1010 ,618 ,126 ,920 ,287 ,1669 ,39 ,727 ,60 ,795 ,1310 ,311 ,1944 ,382 ,1216 ,955 , + 288 ,1295 ,31 ,576 ,339 ,2029 ,1111 ,1567 ,5 ,1449 ,1506 ,449 ,992 ,705 ,1107 ,274 , + 450 ,1885 ,1355 ,723 ,1572 ,379 ,190 ,1922 ,700 ,1917 ,1071 ,510 ,963 ,555 ,1686 ,216 , + 1552 ,1613 ,593 ,407 ,360 ,1217 ,598 ,1176 ,184 ,485 ,765 ,1989 ,327 ,198 ,4 ,139 , + 1095 ,98 ,12 ,1034 ,1164 ,1367 ,518 ,727 ,871 ,1689 ,320 ,133 ,218 ,1841 ,672 ,1175 , + 158 ,1015 ,837 ,1714 ,1045 ,1820 ,1744 ,999 ,2028 ,1239 ,1503 ,728 ,1472 ,243 ,713 ,1832 , + 1110 ,446 ,412 ,1293 ,40 ,583 ,557 ,1017 ,1106 ,1805 ,1176 ,1190 ,582 ,1943 ,983 ,923 , + 1599 ,20 ,1531 ,1377 ,1870 ,1621 ,1658 ,1480 ,1260 ,1986 ,688 ,775 ,1930 ,963 ,1448 ,1269 , + 2005 ,1363 ,996 ,386 ,1135 ,89 ,1531 ,1808 ,767 ,1314 ,486 ,1055 ,1760 ,222 ,1224 ,1189 , + 1568 ,1173 ,1652 ,734 ,131 ,895 ,560 ,133 ,1618 ,1569 ,543 ,368 ,1201 ,1657 ,552 ,1258 , + 2005 ,1078 ,674 ,2021 ,2047 ,1920 ,37 ,1757 ,19 ,1955 ,1376 ,575 ,1160 ,1345 ,180 ,1019 , + 125 ,902 ,438 ,1471 ,291 ,1903 ,1113 ,561 ,645 ,1174 ,286 ,1934 ,194 ,1998 ,1300 ,160 , + 40 ,81 ,77 ,1342 ,555 ,955 ,377 ,1804 ,1976 ,1505 ,253 ,37 ,8 ,216 ,197 ,445 , + 66 ,425 ,458 ,1747 ,1396 ,210 ,437 ,1585 ,1228 ,1105 ,215 ,309 ,1746 ,1547 ,1062 ,52 , + 1489 ,1744 ,374 ,1797 ,819 ,903 ,513 ,1454 ,1338 ,10 ,9 ,1407 ,1820 ,561 ,383 ,1057 , + 1857 ,1950 ,568 ,1927 ,469 ,1373 ,1199 ,753 ,1586 ,1291 ,1887 ,906 ,1904 ,195 ,1079 ,1341 , + 1621 ,1597 ,480 ,1225 ,1677 ,716 ,1603 ,1628 ,245 ,158 ,34 ,619 ,202 ,1702 ,1594 ,1555 , + 306 ,339 ,352 ,725 ,407 ,1491 ,2008 ,74 ,765 ,49 ,573 ,191 ,78 ,1260 ,2043 ,1282 , + 302 ,81 ,1223 ,521 ,1749 ,1571 ,1461 ,1302 ,1548 ,1867 ,147 ,1091 ,1231 ,1900 ,1165 ,828 , + 663 ,136 ,1160 ,1765 ,616 ,1070 ,922 ,1522 ,775 ,1292 ,1248 ,1462 ,1057 ,1818 ,1627 ,563 , + 302 ,991 ,627 ,1527 ,947 ,1567 ,1440 ,5 ,1960 ,864 ,468 ,1810 ,528 ,206 ,110 ,58 , + 835 ,513 ,805 ,57 ,1218 ,1680 ,1458 ,885 ,1603 ,1292 ,260 ,1431 ,58 ,388 ,1625 ,1585 , + 1364 ,387 ,25 ,1667 ,427 ,1199 ,1819 ,1266 ,1655 ,1633 ,823 ,1920 ,273 ,804 ,966 ,311 , + 1475 ,1357 ,397 ,105 ,1982 ,1036 ,1596 ,1260 ,1408 ,958 ,731 ,1700 ,2033 ,858 ,1425 ,1361 , + 476 ,1961 ,1045 ,334 ,1848 ,554 ,537 ,1707 ,1857 ,145 ,1447 ,730 ,1323 ,1195 ,336 ,529 , + 131 ,1769 ,452 ,813 ,509 ,629 ,886 ,832 ,289 ,898 ,889 ,691 ,384 ,701 ,948 ,392 , + 4 ,1540 ,667 ,1034 ,1142 ,1040 ,607 ,1739 ,809 ,1769 ,1516 ,1552 ,676 ,381 ,1996 ,880 , + 835 ,590 ,1138 ,1861 ,131 ,1845 ,1940 ,249 ,1565 ,771 ,1076 ,690 ,1427 ,1553 ,826 ,195 , + 769 ,1700 ,1178 ,164 ,267 ,1443 ,1238 ,2008 ,1866 ,374 ,415 ,436 ,1123 ,1430 ,853 ,667 , + 210 ,1642 ,290 ,1428 ,1255 ,1067 ,1353 ,963 ,604 ,973 ,1538 ,483 ,422 ,97 ,1140 ,1833 , + 666 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,194 ,1112 ,144 ,1448 , + 1978 ,251 ,1490 ,1853 ,610 ,1897 ,1332 ,1774 ,724 ,591 ,1538 ,951 ,1930 ,1440 ,1047 ,1752 , + 768 ,243 ,783 ,142 ,481 ,1030 ,976 ,1744 ,1972 ,91 ,1165 ,253 ,1908 ,1430 ,644 ,851 , + 1978 ,1642 ,1490 ,1402 ,692 ,1953 ,377 ,74 ,204 ,591 ,1538 ,1388 ,1160 ,1688 ,332 ,1956 , + 768 ,1864 ,265 ,164 ,89 ,555 ,976 ,1744 ,739 ,59 ,1778 ,253 ,1912 ,591 ,921 ,480 , + 260 ,1642 ,2015 ,82 ,423 ,1444 ,1821 ,963 ,765 ,1792 ,1538 ,1543 ,987 ,1171 ,427 ,904 , + 787 ,212 ,1884 ,194 ,1601 ,765 ,742 ,1850 ,967 ,1317 ,1310 ,605 ,1208 ,882 ,943 ,701 , + 29 ,58 ,479 ,490 ,943 ,1601 ,1685 ,961 ,36 ,147 ,75 ,874 ,1612 ,632 ,754 ,32 , + 1783 ,400 ,409 ,872 ,920 ,1212 ,613 ,1669 ,1704 ,480 ,1527 ,1430 ,241 ,1809 ,404 ,666 , + 1413 ,1308 ,1018 ,1381 ,1906 ,828 ,305 ,212 ,779 ,535 ,1225 ,1748 ,109 ,1319 ,478 ,776 , + 208 ,51 ,409 ,1811 ,2009 ,1060 ,216 ,1084 ,1225 ,1366 ,723 ,1902 ,1304 ,216 ,433 ,1866 , + 983 ,1158 ,157 ,1766 ,449 ,400 ,1405 ,1676 ,796 ,305 ,319 ,1890 ,1003 ,1335 ,1457 ,89 , + 208 ,232 ,1419 ,1850 ,1419 ,1060 ,137 ,1512 ,631 ,1830 ,279 ,1808 ,1994 ,1872 ,402 ,986 , + 1808 ,85 ,21 ,1279 ,567 ,4 ,544 ,1151 ,1379 ,295 ,682 ,1113 ,1953 ,757 ,1180 ,1068 , + 208 ,449 ,759 ,1768 ,1300 ,567 ,1102 ,183 ,648 ,1885 ,645 ,1225 ,1440 ,214 ,938 ,1818 , + 113 ,1370 ,746 ,681 ,588 ,1972 ,386 ,926 ,1581 ,1971 ,286 ,776 ,1673 ,1017 ,1125 ,1855 , + 2005 ,655 ,126 ,1533 ,1799 ,1851 ,934 ,1628 ,693 ,1487 ,90 ,753 ,1956 ,1427 ,734 ,1205 , + 259 ,1693 ,729 ,737 ,650 ,23 ,772 ,327 ,1930 ,61 ,488 ,763 ,599 ,797 ,40 ,1254 , + 1285 ,1540 ,328 ,1562 ,279 ,726 ,160 ,1529 ,292 ,624 ,1165 ,963 ,1979 ,543 ,1552 ,262 , + 1993 ,1364 ,1665 ,654 ,123 ,1092 ,1199 ,1022 ,711 ,607 ,1405 ,1589 ,687 ,436 ,1349 ,805 , + 192 ,1700 ,1212 ,1620 ,267 ,1443 ,326 ,1648 ,1866 ,357 ,662 ,253 ,1908 ,1714 ,853 ,667 , + 32 ,1829 ,1490 ,1402 ,692 ,1626 ,377 ,1774 ,604 ,591 ,133 ,1388 ,1995 ,1688 ,332 ,1562 , + 1419 ,145 ,1539 ,1384 ,491 ,474 ,183 ,807 ,1214 ,939 ,1017 ,1054 ,1698 ,1660 ,35 ,513 , + 835 ,165 ,407 ,461 ,398 ,870 ,950 ,304 ,1881 ,1099 ,669 ,65 ,346 ,1134 ,901 ,111 , + 642 ,1851 ,910 ,1278 ,417 ,1737 ,1130 ,609 ,779 ,379 ,1617 ,488 ,1449 ,1969 ,973 ,508 , + 295 ,1762 ,207 ,1038 ,595 ,1662 ,107 ,2008 ,1673 ,1158 ,436 ,1559 ,1252 ,122 ,1216 ,761 , + 1621 ,1783 ,1502 ,350 ,200 ,553 ,712 ,88 ,767 ,899 ,143 ,1548 ,814 ,900 ,851 ,1031 , + 734 ,1218 ,1779 ,440 ,1558 ,1656 ,1455 ,1029 ,1181 ,2042 ,1591 ,1916 ,1052 ,1659 ,1008 ,278 , + 1174 ,373 ,1605 ,1634 ,323 ,1286 ,1645 ,490 ,994 ,1598 ,784 ,5 ,1973 ,1064 ,1132 ,104 , + 612 ,967 ,1071 ,1898 ,253 ,1032 ,2021 ,1180 ,485 ,1176 ,1114 ,1907 ,1290 ,486 ,143 ,1567 , + 1174 ,1944 ,1461 ,364 ,349 ,586 ,774 ,1064 ,1983 ,631 ,1914 ,906 ,928 ,546 ,1736 ,467 , + 933 ,1269 ,882 ,404 ,675 ,296 ,1096 ,472 ,1321 ,314 ,1326 ,1490 ,997 ,1744 ,1191 ,1928 , + 1308 ,1799 ,124 ,1944 ,1511 ,1244 ,1508 ,1173 ,1346 ,1127 ,1607 ,1836 ,951 ,821 ,842 ,625 , + 63 ,621 ,1225 ,388 ,308 ,650 ,1031 ,1447 ,1382 ,1986 ,1381 ,942 ,913 ,1352 ,1692 ,722 , + 523 ,1284 ,774 ,1815 ,1895 ,1187 ,246 ,1062 ,1250 ,1405 ,55 ,340 ,1741 ,859 ,1292 ,1690 , + 63 ,1096 ,817 ,318 ,1120 ,97 ,1040 ,383 ,1572 ,1214 ,351 ,1168 ,1382 ,195 ,1647 ,83 , + 1688 ,1198 ,739 ,1319 ,507 ,351 ,886 ,1803 ,157 ,1100 ,864 ,310 ,1299 ,623 ,426 ,390 , + 1620 ,113 ,1252 ,1212 ,1294 ,116 ,1557 ,694 ,829 ,552 ,644 ,1870 ,950 ,10 ,910 ,290 , + 1907 ,1665 ,307 ,2032 ,1944 ,588 ,1505 ,1888 ,347 ,1225 ,1528 ,337 ,797 ,983 ,274 ,965 , + 1937 ,1812 ,1956 ,1822 ,513 ,839 ,640 ,115 ,621 ,1649 ,2041 ,1079 ,1109 ,28 ,1561 ,1879 , + 898 ,498 ,1012 ,1133 ,1044 ,846 ,202 ,533 ,1748 ,2023 ,1954 ,1522 ,412 ,1200 ,1768 ,1360 , + 750 ,1409 ,404 ,1881 ,138 ,904 ,1265 ,870 ,121 ,638 ,1756 ,1793 ,1009 ,1104 ,313 ,62 , + 497 ,994 ,61 ,1785 ,580 ,87 ,1324 ,1190 ,369 ,846 ,1607 ,1704 ,676 ,1422 ,1339 ,537 , + 1757 ,1982 ,438 ,1849 ,150 ,1884 ,882 ,568 ,781 ,1446 ,1137 ,1260 ,1678 ,1834 ,765 ,1489 , + 1797 ,1515 ,1626 ,164 ,306 ,1443 ,183 ,662 ,1561 ,91 ,675 ,1625 ,1101 ,1430 ,144 ,1942 , + 1458 ,963 ,1736 ,1047 ,1908 ,1626 ,655 ,643 ,1252 ,591 ,47 ,1388 ,987 ,1820 ,632 ,1752 , + 448 ,243 ,783 ,142 ,481 ,1030 ,976 ,2008 ,739 ,1992 ,1165 ,436 ,194 ,1714 ,644 ,1684 , + 1978 ,1642 ,1490 ,1428 ,692 ,1897 ,377 ,717 ,1423 ,591 ,1538 ,1388 ,1995 ,97 ,332 ,1752 , + 7 ,1597 ,265 ,1712 ,306 ,906 ,825 ,1744 ,1338 ,374 ,1165 ,253 ,1449 ,1204 ,144 ,1942 , + 210 ,251 ,75 ,1402 ,1768 ,1491 ,1353 ,717 ,11 ,973 ,47 ,597 ,1065 ,1688 ,1140 ,1956 , + 1883 ,590 ,958 ,1287 ,1457 ,1040 ,225 ,1197 ,2001 ,1963 ,1306 ,355 ,1291 ,555 ,810 ,854 , + 489 ,1072 ,1344 ,1283 ,906 ,113 ,466 ,370 ,1226 ,1706 ,901 ,1611 ,1970 ,182 ,1330 ,1887 , + 1956 ,132 ,1509 ,710 ,973 ,464 ,1248 ,1120 ,1599 ,1248 ,1944 ,1452 ,743 ,1711 ,1776 ,408 , + 1151 ,513 ,1005 ,1548 ,1929 ,1123 ,272 ,654 ,534 ,353 ,892 ,203 ,372 ,920 ,91 ,1349 , + 1688 ,815 ,173 ,4 ,771 ,1993 ,395 ,1447 ,1273 ,1553 ,1387 ,2032 ,498 ,968 ,1667 ,627 , + 520 ,1941 ,1323 ,956 ,1696 ,674 ,1402 ,322 ,1188 ,766 ,861 ,585 ,1346 ,1247 ,1407 ,2021 , + 964 ,585 ,1092 ,1847 ,1720 ,430 ,79 ,1940 ,489 ,1519 ,866 ,848 ,2011 ,886 ,2042 ,338 , + 1668 ,909 ,684 ,1771 ,581 ,677 ,39 ,1984 ,428 ,1127 ,1525 ,551 ,1925 ,308 ,1054 ,19 , + 1914 ,1251 ,1338 ,2016 ,1484 ,1040 ,1168 ,1615 ,1687 ,832 ,1895 ,1563 ,1962 ,2029 ,1012 ,471 , + 118 ,571 ,1292 ,579 ,480 ,929 ,1111 ,1051 ,542 ,1602 ,1871 ,1803 ,1943 ,870 ,589 ,1668 , + 1004 ,1541 ,992 ,466 ,1040 ,786 ,1065 ,881 ,622 ,481 ,122 ,1093 ,641 ,267 ,961 ,386 , + 298 ,1604 ,1789 ,758 ,65 ,65 ,989 ,1691 ,955 ,1876 ,440 ,1987 ,2047 ,735 ,1975 ,345 , + 1227 ,440 ,881 ,533 ,770 ,1870 ,137 ,108 ,357 ,1149 ,1536 ,698 ,1585 ,906 ,741 ,2015 , + 1966 ,1253 ,512 ,768 ,1579 ,1070 ,1971 ,1414 ,1717 ,1892 ,1944 ,1109 ,360 ,755 ,654 ,1673 , + 1424 ,1727 ,320 ,1354 ,711 ,1191 ,611 ,1329 ,809 ,1416 ,1262 ,153 ,1192 ,1863 ,650 ,1511 , + 508 ,403 ,716 ,984 ,1399 ,818 ,213 ,601 ,172 ,89 ,1323 ,1543 ,168 ,1149 ,970 ,1780 , + 2007 ,175 ,170 ,785 ,1322 ,1167 ,1706 ,463 ,1801 ,982 ,428 ,1484 ,1327 ,94 ,1075 ,1800 , + 1305 ,1502 ,686 ,681 ,935 ,698 ,1791 ,1535 ,170 ,701 ,1856 ,1598 ,71 ,1637 ,1824 ,36 , + 1999 ,1028 ,1326 ,403 ,1956 ,184 ,576 ,531 ,767 ,238 ,654 ,502 ,1005 ,1072 ,1047 ,1128 , + 976 ,1409 ,1972 ,1596 ,148 ,1978 ,1105 ,1740 ,1775 ,159 ,150 ,285 ,11 ,1101 ,88 ,1895 , + 1432 ,971 ,1767 ,1990 ,1849 ,1081 ,328 ,735 ,366 ,2017 ,914 ,1975 ,121 ,255 ,1232 ,66 , + 1665 ,1232 ,1424 ,1525 ,533 ,641 ,1341 ,1074 ,1855 ,435 ,404 ,197 ,797 ,1589 ,1279 ,110 , + 13 ,156 ,351 ,46 ,2015 ,1041 ,765 ,1416 ,1633 ,1723 ,1164 ,7 ,1697 ,912 ,1976 ,175 , + 1894 ,1259 ,25 ,1213 ,481 ,1342 ,918 ,1297 ,527 ,956 ,789 ,760 ,573 ,1374 ,1315 ,1395 , + 1147 ,1354 ,642 ,1196 ,50 ,1127 ,1313 ,914 ,1582 ,81 ,554 ,917 ,632 ,1268 ,520 ,714 , + 542 ,1006 ,253 ,1527 ,1182 ,261 ,101 ,1014 ,1324 ,1658 ,1728 ,1036 ,349 ,378 ,1644 ,116 , + 1678 ,1981 ,955 ,1219 ,1247 ,1709 ,1387 ,336 ,456 ,989 ,819 ,2023 ,1129 ,1946 ,1740 ,597 , + 528 ,1280 ,1723 ,1385 ,1961 ,430 ,1477 ,730 ,1382 ,891 ,267 ,1046 ,736 ,194 ,2010 ,1513 , + 1381 ,387 ,448 ,1561 ,820 ,1065 ,801 ,267 ,1393 ,418 ,981 ,1042 ,985 ,2041 ,1787 ,1591 , + 1423 ,1019 ,39 ,2021 ,1628 ,829 ,1787 ,1404 ,657 ,978 ,1859 ,296 ,689 ,1377 ,696 ,1660 , + 1287 ,1024 ,44 ,1537 ,848 ,1014 ,1495 ,1779 ,1135 ,78 ,969 ,1899 ,1151 ,41 ,1257 ,1679 , + 213 ,6 ,1319 ,1058 ,1818 ,1637 ,956 ,1696 ,338 ,1163 ,1183 ,1719 ,500 ,1997 ,170 ,830 , + 35 ,1333 ,1840 ,1224 ,959 ,502 ,1910 ,1738 ,1127 ,1373 ,1706 ,1611 ,1577 ,1822 ,409 ,1913 , + 582 ,598 ,118 ,545 ,1355 ,69 ,984 ,1385 ,824 ,647 ,1497 ,1603 ,1159 ,1504 ,935 ,884 , + 1082 ,1623 ,61 ,566 ,515 ,1220 ,998 ,394 ,137 ,1102 ,1415 ,412 ,1274 ,988 ,168 ,1418 , + 2042 ,314 ,1892 ,382 ,163 ,1161 ,1775 ,1957 ,1884 ,1914 ,284 ,836 ,1253 ,894 ,994 ,1875 , + 1909 ,290 ,990 ,1932 ,1994 ,1102 ,1970 ,21 ,1585 ,1071 ,2017 ,115 ,1671 ,653 ,999 ,1041 , + 116 ,1188 ,1890 ,1070 ,1750 ,694 ,450 ,944 ,1332 ,1135 ,270 ,622 ,1959 ,49 ,311 ,186 , + 497 ,1104 ,924 ,1964 ,420 ,716 ,574 ,763 ,1041 ,820 ,1012 ,36 ,1704 ,836 ,745 ,1585 , + 1352 ,84 ,1587 ,1279 ,1768 ,787 ,340 ,1929 ,524 ,2036 ,1175 ,1336 ,612 ,534 ,798 ,1516 , + 1263 ,1515 ,783 ,142 ,1736 ,2010 ,976 ,1744 ,1561 ,1542 ,1048 ,1383 ,1343 ,1421 ,644 ,610 , + 1993 ,1829 ,1923 ,1853 ,610 ,1067 ,377 ,717 ,18 ,973 ,1538 ,1124 ,346 ,97 ,427 ,1833 , + 541 ,243 ,783 ,1348 ,267 ,1443 ,825 ,2008 ,1210 ,1992 ,1165 ,1383 ,1101 ,1714 ,802 ,1448 , + 32 ,1642 ,1490 ,1406 ,692 ,1897 ,1353 ,1774 ,1423 ,591 ,1538 ,951 ,35 ,97 ,1047 ,1752 , + 1771 ,243 ,1559 ,1348 ,1736 ,1443 ,666 ,1648 ,1972 ,1370 ,415 ,143 ,1343 ,1112 ,144 ,1448 , + 200 ,646 ,1490 ,1406 ,1255 ,1626 ,1497 ,1774 ,711 ,591 ,47 ,1388 ,1930 ,97 ,1216 ,1752 , + 448 ,243 ,1697 ,546 ,1736 ,1443 ,976 ,1744 ,1210 ,1370 ,1165 ,253 ,1908 ,650 ,644 ,851 , + 1978 ,1419 ,290 ,1853 ,692 ,1067 ,377 ,1774 ,724 ,847 ,47 ,951 ,1995 ,97 ,332 ,1833 , + 448 ,243 ,1697 ,1348 ,1736 ,1443 ,1978 ,1744 ,1210 ,1992 ,1165 ,253 ,1912 ,1430 ,802 ,851 , + 32 ,251 ,1736 ,1428 ,692 ,1067 ,1332 ,1586 ,724 ,973 ,1538 ,951 ,422 ,1688 ,1047 ,1956 , + 448 ,243 ,1697 ,164 ,1736 ,1572 ,976 ,1744 ,739 ,1542 ,1165 ,253 ,1908 ,1714 ,644 ,1684 , + 32 ,1419 ,1736 ,1406 ,1908 ,1067 ,1497 ,717 ,1423 ,847 ,47 ,1388 ,422 ,1688 ,332 ,1562 , + 448 ,243 ,1697 ,164 ,1736 ,1572 ,976 ,1744 ,1210 ,1542 ,1165 ,253 ,1908 ,1714 ,644 ,851 , + 1978 ,1419 ,1736 ,1406 ,1908 ,1953 ,377 ,1586 ,1423 ,591 ,1538 ,1388 ,1995 ,1688 ,332 ,1752 , + 1850 ,243 ,1697 ,1348 ,1736 ,1443 ,1978 ,1744 ,1210 ,1370 ,1165 ,1190 ,1912 ,1430 ,144 ,851 , + 1978 ,251 ,1736 ,1406 ,692 ,1067 ,1497 ,1586 ,724 ,591 ,47 ,951 ,1930 ,1688 ,1047 ,1833 , + 384 ,1211 ,1622 ,1562 ,1836 ,555 ,477 ,1648 ,1938 ,278 ,675 ,253 ,194 ,650 ,1336 ,1511 , + 32 ,1829 ,618 ,82 ,123 ,154 ,1332 ,1277 ,172 ,1040 ,1538 ,1124 ,2008 ,1688 ,332 ,1209 , + 420 ,1260 ,573 ,183 ,877 ,27 ,1797 ,1879 ,1200 ,546 ,100 ,1093 ,1300 ,80 ,1060 ,1640 , + 1789 ,168 ,1725 ,1579 ,2046 ,1469 ,1888 ,1990 ,76 ,1692 ,839 ,1116 ,692 ,551 ,686 ,1650 , + 1738 ,693 ,311 ,1527 ,623 ,1339 ,755 ,335 ,753 ,357 ,1456 ,1304 ,761 ,1770 ,1377 ,695 , + 447 ,1584 ,1529 ,1330 ,942 ,776 ,1362 ,986 ,1611 ,1332 ,429 ,44 ,1704 ,1718 ,1190 ,1004 , + 293 ,1700 ,1178 ,97 ,1736 ,1572 ,666 ,2008 ,1972 ,1992 ,415 ,436 ,1123 ,113 ,644 ,851 , + 1978 ,646 ,290 ,1428 ,423 ,1067 ,377 ,963 ,1423 ,1343 ,47 ,483 ,1995 ,97 ,1047 ,1562 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,739 ,374 ,1165 ,436 ,1101 ,113 ,644 ,1448 , + 1978 ,646 ,1736 ,1406 ,692 ,1897 ,377 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,1688 ,1047 ,1956 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,1210 ,374 ,1165 ,436 ,1912 ,1714 ,644 ,1448 , + 1978 ,251 ,1736 ,1428 ,692 ,1897 ,377 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,1688 ,1047 ,1833 , + 384 ,243 ,1559 ,1348 ,1736 ,1443 ,1978 ,1744 ,1210 ,1992 ,1165 ,1190 ,1912 ,650 ,144 ,851 , + 32 ,251 ,290 ,1428 ,423 ,1953 ,1497 ,1586 ,724 ,591 ,47 ,1388 ,422 ,1688 ,1047 ,1752 , + 1208 ,1172 ,595 ,321 ,1594 ,614 ,387 ,1684 ,401 ,1656 ,1055 ,1638 ,2 ,1161 ,988 ,734 , + 1243 ,569 ,515 ,1627 ,1985 ,1226 ,742 ,1862 ,1994 ,1461 ,913 ,1032 ,419 ,1784 ,1220 ,1771 , + 1382 ,1949 ,2008 ,89 ,1099 ,1158 ,185 ,1016 ,298 ,874 ,1630 ,1599 ,968 ,1069 ,380 ,957 , + 1583 ,856 ,1522 ,1681 ,1855 ,945 ,1388 ,1974 ,1848 ,1825 ,658 ,866 ,1248 ,741 ,1999 ,628 , + 982 ,1879 ,1663 ,1021 ,1638 ,261 ,1718 ,834 ,1809 ,449 ,1462 ,1438 ,968 ,1144 ,405 ,1910 , + 1138 ,1395 ,498 ,224 ,1334 ,1143 ,169 ,1911 ,1876 ,396 ,1367 ,1522 ,1794 ,273 ,553 ,557 , + 1635 ,520 ,328 ,1034 ,610 ,657 ,2031 ,663 ,1594 ,47 ,358 ,1109 ,1735 ,1722 ,1806 ,1868 , + 844 ,974 ,438 ,493 ,144 ,1997 ,1612 ,1668 ,1141 ,822 ,487 ,420 ,769 ,1529 ,1963 ,1117 , + 1263 ,1515 ,1626 ,164 ,1736 ,347 ,666 ,2008 ,739 ,1370 ,1165 ,436 ,1908 ,1430 ,144 ,667 , + 748 ,1642 ,1736 ,1402 ,1908 ,1897 ,377 ,1396 ,604 ,973 ,693 ,951 ,422 ,1688 ,427 ,1417 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,739 ,1542 ,1165 ,253 ,1123 ,650 ,644 ,851 , + 32 ,1419 ,290 ,1406 ,692 ,1897 ,377 ,1586 ,204 ,591 ,1538 ,951 ,422 ,1688 ,332 ,1833 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,1210 ,1542 ,1165 ,436 ,1736 ,1112 ,144 ,1448 , + 1978 ,1642 ,1736 ,1853 ,610 ,1897 ,1332 ,1774 ,724 ,591 ,1538 ,951 ,422 ,1688 ,1047 ,1752 , + 1850 ,243 ,1697 ,164 ,1736 ,1030 ,976 ,1744 ,739 ,1992 ,1165 ,253 ,1101 ,1714 ,644 ,1684 , + 32 ,1419 ,1490 ,1406 ,692 ,1067 ,377 ,717 ,204 ,847 ,1538 ,951 ,1930 ,97 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1912 ,1714 ,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1953 ,377 ,717 ,724 ,591 ,1538 ,1124 ,1995 ,1688 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1912 ,1714 ,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1953 ,377 ,717 ,724 ,591 ,1538 ,951 ,1995 ,1688 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,2008 ,1210 ,374 ,1165 ,436 ,1912 ,1714 ,644 ,1448 , + 1978 ,251 ,1736 ,1428 ,692 ,1897 ,377 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,1688 ,1047 ,1833 , + 1850 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,1210 ,1370 ,1165 ,253 ,1912 ,1714 ,644 ,667 , + 1978 ,251 ,618 ,1428 ,1255 ,1626 ,377 ,1774 ,604 ,973 ,1538 ,1388 ,35 ,1688 ,332 ,1833 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1912 ,1714 ,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1626 ,377 ,717 ,724 ,591 ,1538 ,1124 ,422 ,1688 ,332 ,1752 , + 1850 ,243 ,783 ,1348 ,267 ,1030 ,825 ,1744 ,1210 ,374 ,1165 ,253 ,1101 ,113 ,644 ,851 , + 1978 ,1642 ,618 ,1406 ,1255 ,1067 ,377 ,1586 ,724 ,591 ,47 ,1388 ,422 ,1688 ,1047 ,1956 , + 1850 ,243 ,1178 ,1348 ,267 ,1443 ,666 ,1744 ,1210 ,1370 ,1165 ,439 ,1912 ,1430 ,144 ,851 , + 1978 ,646 ,290 ,1428 ,1255 ,1067 ,1332 ,717 ,724 ,847 ,899 ,1124 ,422 ,97 ,332 ,1956 , + 481 ,243 ,783 ,142 ,481 ,1030 ,976 ,1744 ,1210 ,1992 ,1165 ,253 ,1912 ,1714 ,644 ,1684 , + 1978 ,251 ,1490 ,1428 ,1255 ,1953 ,377 ,717 ,724 ,973 ,1538 ,1388 ,1995 ,97 ,1047 ,1752 , + 1850 ,243 ,1178 ,546 ,267 ,1030 ,976 ,2008 ,739 ,1370 ,1165 ,1383 ,1908 ,1714 ,144 ,851 , + 32 ,1419 ,1736 ,1428 ,692 ,1897 ,1332 ,1586 ,724 ,847 ,569 ,951 ,422 ,1688 ,1047 ,1752 , + 384 ,243 ,1559 ,1348 ,1736 ,555 ,825 ,1648 ,739 ,91 ,1165 ,253 ,1908 ,1112 ,144 ,1684 , + 1978 ,1642 ,1490 ,1402 ,610 ,1626 ,1332 ,1774 ,724 ,591 ,47 ,1388 ,35 ,97 ,427 ,904 , + 1885 ,1216 ,1203 ,1196 ,1317 ,64 ,1893 ,524 ,384 ,240 ,1361 ,1850 ,690 ,1967 ,934 ,1357 , + 172 ,1201 ,583 ,539 ,76 ,964 ,615 ,1090 ,413 ,1384 ,193 ,782 ,1428 ,1873 ,455 ,1827 , + 314 ,143 ,1286 ,1368 ,64 ,1219 ,1012 ,2025 ,1708 ,782 ,1899 ,1802 ,1109 ,107 ,159 ,1807 , + 1425 ,1415 ,1281 ,1949 ,2029 ,158 ,110 ,265 ,1961 ,422 ,1269 ,1214 ,409 ,484 ,1831 ,1972 , + 569 ,1493 ,1743 ,1290 ,791 ,99 ,2009 ,577 ,1766 ,969 ,1919 ,1524 ,364 ,132 ,1865 ,208 , + 1671 ,807 ,1600 ,1765 ,758 ,401 ,565 ,1975 ,621 ,1608 ,119 ,1310 ,823 ,5 ,1143 ,242 , + 963 ,400 ,1743 ,1907 ,1610 ,1972 ,1713 ,749 ,1806 ,1272 ,985 ,1699 ,1321 ,1992 ,507 ,590 , + 1018 ,1615 ,179 ,867 ,1067 ,637 ,1646 ,1077 ,800 ,30 ,1084 ,1429 ,823 ,918 ,461 ,321 , + 476 ,971 ,984 ,1290 ,1902 ,1504 ,194 ,1107 ,1322 ,1648 ,2041 ,1683 ,1048 ,105 ,1901 ,1627 , + 151 ,1287 ,731 ,1440 ,687 ,17 ,545 ,762 ,1170 ,1867 ,389 ,544 ,489 ,1455 ,800 ,350 , + 1035 ,669 ,1149 ,417 ,1067 ,716 ,1752 ,479 ,1571 ,1316 ,1162 ,271 ,666 ,1783 ,170 ,966 , + 518 ,1764 ,52 ,1139 ,35 ,670 ,1064 ,874 ,1668 ,169 ,1272 ,957 ,644 ,237 ,107 ,963 , + 1696 ,348 ,420 ,1057 ,411 ,718 ,164 ,318 ,755 ,551 ,1756 ,497 ,739 ,807 ,1524 ,176 , + 82 ,197 ,1179 ,1389 ,151 ,293 ,176 ,1089 ,1757 ,976 ,1175 ,268 ,442 ,1234 ,1518 ,542 , + 1268 ,1172 ,1362 ,1857 ,691 ,817 ,1138 ,515 ,1566 ,1071 ,1989 ,1732 ,1419 ,2033 ,1210 ,1457 , + 1937 ,1322 ,315 ,473 ,1527 ,1141 ,1766 ,957 ,236 ,652 ,1382 ,2021 ,1231 ,1221 ,1224 ,471 , + 443 ,160 ,762 ,981 ,735 ,957 ,12 ,1342 ,1053 ,1380 ,602 ,784 ,1434 ,1472 ,1665 ,1469 , + 1682 ,1008 ,311 ,1184 ,1292 ,1189 ,1219 ,1425 ,1214 ,978 ,337 ,89 ,1542 ,1360 ,1443 ,1736 , + 707 ,774 ,1870 ,33 ,1110 ,490 ,1826 ,1419 ,394 ,1172 ,96 ,743 ,1169 ,227 ,822 ,247 , + 1779 ,280 ,1580 ,993 ,1933 ,1639 ,33 ,923 ,1036 ,851 ,1254 ,1576 ,1882 ,243 ,1369 ,354 , + 1038 ,1398 ,984 ,1648 ,1765 ,1932 ,1995 ,744 ,600 ,462 ,588 ,478 ,637 ,1041 ,1633 ,474 , + 740 ,1839 ,44 ,665 ,1661 ,1774 ,1306 ,912 ,552 ,689 ,828 ,926 ,1729 ,1943 ,665 ,1743 , + 1401 ,453 ,1647 ,1796 ,357 ,1109 ,1566 ,1755 ,1296 ,1266 ,313 ,1419 ,702 ,1571 ,27 ,1184 , + 1842 ,620 ,757 ,1734 ,1870 ,1124 ,1117 ,1133 ,862 ,219 ,1690 ,1316 ,1231 ,848 ,1177 ,653 , + 409 ,1256 ,1915 ,1716 ,1244 ,75 ,1514 ,1677 ,1333 ,290 ,111 ,378 ,1112 ,437 ,1848 ,754 , + 1028 ,208 ,727 ,1368 ,1453 ,759 ,1230 ,1326 ,1344 ,557 ,404 ,40 ,817 ,1531 ,1681 ,80 , + 390 ,134 ,906 ,1749 ,554 ,695 ,719 ,514 ,478 ,1593 ,1955 ,1189 ,1348 ,1494 ,1503 ,1513 , + 1579 ,1104 ,1660 ,1362 ,1985 ,1814 ,1579 ,793 ,62 ,1979 ,843 ,1868 ,514 ,1919 ,1396 ,785 , + 1635 ,618 ,124 ,21 ,791 ,225 ,895 ,774 ,1167 ,1658 ,1421 ,1494 ,793 ,1582 ,368 ,1755 , + 16 ,1571 ,824 ,1604 ,1201 ,585 ,1867 ,514 ,314 ,1097 ,1667 ,1703 ,2047 ,851 ,539 ,139 , + 609 ,1663 ,1534 ,1846 ,701 ,54 ,1067 ,243 ,1860 ,1500 ,1179 ,1618 ,1700 ,1722 ,788 ,700 , + 684 ,1829 ,1441 ,241 ,1405 ,1410 ,1353 ,1086 ,848 ,2 ,1778 ,49 ,1431 ,1076 ,1477 ,1420 , + 1345 ,1216 ,1235 ,1453 ,978 ,737 ,1613 ,1134 ,1657 ,965 ,33 ,1694 ,1560 ,1483 ,228 ,1341 , + 275 ,660 ,266 ,956 ,1279 ,1086 ,486 ,1858 ,1690 ,500 ,696 ,237 ,1920 ,322 ,625 ,1798 , + 1384 ,1558 ,78 ,406 ,1050 ,208 ,548 ,1462 ,1016 ,19 ,179 ,1099 ,152 ,267 ,1883 ,1041 , + 1457 ,735 ,1541 ,311 ,2010 ,1566 ,1500 ,1743 ,1507 ,1045 ,37 ,1470 ,679 ,589 ,410 ,268 , + 1382 ,538 ,1300 ,218 ,1446 ,109 ,524 ,1327 ,591 ,496 ,215 ,678 ,862 ,1521 ,1394 ,1350 , + 43 ,772 ,1721 ,129 ,1756 ,1485 ,372 ,932 ,1556 ,1537 ,1598 ,243 ,409 ,1619 ,397 ,1814 , + 1401 ,1976 ,876 ,1796 ,1735 ,706 ,515 ,1530 ,138 ,988 ,760 ,1217 ,552 ,123 ,975 ,45 , + 798 ,1644 ,484 ,1166 ,576 ,1772 ,1438 ,250 ,822 ,1761 ,2006 ,351 ,1013 ,584 ,414 ,4 , + 1332 ,359 ,1311 ,48 ,192 ,1853 ,567 ,544 ,1071 ,1110 ,24 ,967 ,1366 ,1022 ,420 ,1086 , + 566 ,1598 ,1723 ,240 ,327 ,554 ,1111 ,1288 ,1827 ,784 ,1661 ,1102 ,962 ,1183 ,1218 ,1497 , + 1086 ,388 ,1906 ,523 ,400 ,641 ,1106 ,1930 ,761 ,1702 ,763 ,1054 ,510 ,570 ,1781 ,1344 , + 1154 ,1867 ,378 ,663 ,1801 ,1213 ,971 ,795 ,786 ,631 ,1488 ,248 ,271 ,198 ,852 ,1373 , + 1498 ,1935 ,1280 ,1924 ,1886 ,475 ,423 ,1300 ,1103 ,940 ,1410 ,484 ,1464 ,1732 ,1438 ,188 , + 1913 ,292 ,67 ,708 ,670 ,484 ,832 ,166 ,2036 ,1945 ,2012 ,151 ,171 ,117 ,1614 ,729 , + 1498 ,964 ,393 ,1464 ,557 ,693 ,552 ,1804 ,1161 ,2032 ,713 ,243 ,1671 ,342 ,172 ,2036 , + 1652 ,1796 ,1746 ,5 ,1270 ,542 ,112 ,1559 ,810 ,1325 ,2020 ,1177 ,624 ,983 ,1318 ,225 , + 670 ,127 ,774 ,410 ,114 ,531 ,1068 ,927 ,600 ,1738 ,1604 ,619 ,180 ,83 ,1443 ,1662 , + 542 ,980 ,641 ,385 ,919 ,712 ,357 ,672 ,886 ,1900 ,476 ,1268 ,369 ,1687 ,1603 ,807 , + 1346 ,411 ,180 ,1893 ,668 ,161 ,1692 ,394 ,1454 ,661 ,787 ,1310 ,71 ,509 ,822 ,532 , + 1886 ,158 ,1301 ,1260 ,964 ,1602 ,1195 ,1711 ,694 ,1728 ,1193 ,1989 ,1036 ,254 ,1952 ,920 , + 409 ,1366 ,1993 ,617 ,1279 ,1244 ,722 ,950 ,185 ,304 ,395 ,661 ,1976 ,1176 ,379 ,1622 , + 648 ,1873 ,1152 ,463 ,1565 ,1265 ,1626 ,984 ,1892 ,501 ,154 ,1139 ,1765 ,1775 ,1441 ,1704 , + 1750 ,1257 ,1595 ,1181 ,234 ,942 ,1209 ,512 ,1046 ,7 ,360 ,1809 ,1592 ,294 ,417 ,1428 , + 1617 ,781 ,1717 ,1929 ,1317 ,261 ,1161 ,609 ,370 ,895 ,1289 ,1908 ,970 ,577 ,800 ,1235 , + 1434 ,728 ,1681 ,166 ,131 ,1354 ,723 ,1516 ,68 ,729 ,634 ,512 ,1020 ,583 ,707 ,1306 , + 394 ,1539 ,913 ,619 ,1905 ,391 ,2034 ,1664 ,1702 ,1785 ,1858 ,291 ,1945 ,237 ,1068 ,1117 , + 13 ,808 ,1544 ,183 ,45 ,1522 ,577 ,479 ,1347 ,1477 ,1938 ,393 ,882 ,993 ,172 ,980 , + 268 ,415 ,1138 ,472 ,38 ,1213 ,328 ,509 ,1461 ,503 ,1995 ,744 ,2008 ,1707 ,204 ,633 , + 1378 ,756 ,661 ,26 ,2036 ,333 ,337 ,992 ,1942 ,527 ,1294 ,1854 ,1765 ,1291 ,1711 ,800 , + 1503 ,1752 ,1538 ,101 ,1838 ,715 ,1298 ,487 ,1641 ,1464 ,1671 ,1874 ,1499 ,1878 ,372 ,1974 , + 1139 ,1221 ,1343 ,1498 ,639 ,382 ,901 ,1531 ,258 ,1991 ,703 ,1020 ,1994 ,2018 ,1603 ,81 , + 66 ,99 ,1852 ,1012 ,327 ,566 ,1414 ,456 ,2021 ,1063 ,239 ,2045 ,42 ,445 ,1225 ,368 , + 1346 ,924 ,684 ,519 ,521 ,679 ,1700 ,2012 ,978 ,82 ,970 ,691 ,900 ,1539 ,258 ,1428 , + 1612 ,1882 ,1432 ,520 ,429 ,27 ,230 ,806 ,1062 ,696 ,1877 ,433 ,1314 ,341 ,252 ,474 , + 871 ,2043 ,689 ,465 ,361 ,1801 ,938 ,312 ,24 ,64 ,460 ,206 ,1442 ,1128 ,904 ,1107 , + 141 ,382 ,1492 ,35 ,1428 ,751 ,525 ,1694 ,1332 ,1790 ,1665 ,722 ,313 ,685 ,147 ,72 , + 1315 ,889 ,893 ,1431 ,1401 ,1260 ,18 ,1172 ,1111 ,1110 ,423 ,1855 ,575 ,1117 ,141 ,871 , + 1877 ,1754 ,1657 ,938 ,1991 ,1633 ,308 ,1444 ,1360 ,359 ,1493 ,299 ,1223 ,1888 ,1414 ,1933 , + 599 ,593 ,167 ,1352 ,1953 ,1914 ,24 ,1703 ,838 ,139 ,650 ,1239 ,63 ,203 ,1654 ,1959 , + 1486 ,980 ,837 ,1697 ,570 ,626 ,1630 ,1590 ,1247 ,1950 ,1147 ,1841 ,327 ,1203 ,1703 ,983 , + 1797 ,1367 ,1535 ,1348 ,1532 ,716 ,1002 ,1180 ,1548 ,345 ,524 ,1716 ,1256 ,1977 ,802 ,667 , + 168 ,1874 ,1288 ,1047 ,1780 ,531 ,1056 ,1849 ,711 ,2002 ,133 ,906 ,254 ,1447 ,989 ,1718 , + 541 ,1700 ,1178 ,142 ,481 ,1443 ,976 ,1744 ,1437 ,374 ,415 ,1190 ,1123 ,113 ,144 ,851 , + 1978 ,1642 ,1736 ,1402 ,1255 ,1897 ,1332 ,717 ,1423 ,973 ,47 ,1124 ,422 ,1688 ,1047 ,1562 , + 192 ,203 ,265 ,556 ,1736 ,395 ,666 ,13 ,739 ,1338 ,1704 ,515 ,86 ,1492 ,2003 ,1555 , + 1316 ,86 ,1923 ,440 ,143 ,219 ,116 ,977 ,1673 ,1913 ,1847 ,929 ,987 ,957 ,1216 ,1718 , + 1883 ,710 ,186 ,1055 ,544 ,1275 ,1707 ,236 ,1534 ,543 ,247 ,1950 ,1540 ,886 ,464 ,1666 , + 577 ,2029 ,419 ,152 ,1791 ,1154 ,859 ,677 ,1814 ,1380 ,951 ,471 ,136 ,1361 ,341 ,1464 , + 772 ,1879 ,1922 ,1826 ,832 ,1027 ,229 ,1295 ,515 ,1392 ,60 ,310 ,1205 ,55 ,1306 ,1503 , + 580 ,1322 ,934 ,555 ,1807 ,317 ,32 ,358 ,575 ,1494 ,731 ,1526 ,691 ,1637 ,1086 ,1781 , + 68 ,1661 ,1360 ,698 ,1508 ,852 ,1970 ,1429 ,1500 ,979 ,1843 ,444 ,1526 ,678 ,1390 ,1503 , + 1583 ,1693 ,605 ,1117 ,1760 ,1846 ,777 ,731 ,1808 ,1860 ,557 ,378 ,1261 ,971 ,341 ,1590 , + 1024 ,62 ,1611 ,454 ,865 ,1642 ,685 ,1175 ,1882 ,1035 ,1859 ,1589 ,250 ,1101 ,671 ,1608 , + 205 ,1812 ,1548 ,1784 ,898 ,168 ,603 ,54 ,1288 ,1957 ,1645 ,36 ,1017 ,840 ,1683 ,448 , + 1734 ,291 ,1914 ,1804 ,976 ,449 ,319 ,1940 ,2019 ,1632 ,674 ,567 ,788 ,1646 ,347 ,963 , + 1963 ,227 ,1618 ,130 ,1104 ,1888 ,1884 ,973 ,1576 ,1465 ,1066 ,741 ,884 ,837 ,1338 ,1343 , + 409 ,1662 ,1421 ,780 ,578 ,24 ,290 ,1691 ,616 ,240 ,1929 ,497 ,1391 ,1517 ,1455 ,596 , + 601 ,969 ,1713 ,1291 ,543 ,1673 ,291 ,785 ,1386 ,1707 ,520 ,1320 ,1179 ,984 ,742 ,441 , + 1507 ,1671 ,1605 ,312 ,635 ,685 ,1776 ,220 ,1528 ,378 ,744 ,969 ,1544 ,949 ,1614 ,1413 , + 1905 ,227 ,328 ,522 ,160 ,33 ,418 ,736 ,533 ,269 ,1797 ,109 ,218 ,1075 ,884 ,1468 , + 718 ,1578 ,213 ,477 ,1187 ,1871 ,399 ,927 ,639 ,1921 ,348 ,1890 ,1246 ,1229 ,501 ,709 , + 1963 ,658 ,1305 ,1398 ,602 ,163 ,1762 ,539 ,34 ,1540 ,2013 ,134 ,293 ,223 ,1317 ,1442 , + 706 ,670 ,327 ,18 ,372 ,1426 ,274 ,439 ,1371 ,308 ,1331 ,1606 ,1647 ,1656 ,1549 ,1950 , + 288 ,1033 ,1483 ,1959 ,200 ,935 ,725 ,465 ,1213 ,321 ,1786 ,1762 ,2025 ,1151 ,970 ,853 , + 231 ,1433 ,19 ,1458 ,35 ,1938 ,1677 ,738 ,859 ,1157 ,1602 ,1501 ,1172 ,1834 ,643 ,1085 , + 1376 ,1570 ,1317 ,1162 ,612 ,1275 ,637 ,302 ,156 ,1300 ,896 ,1663 ,284 ,753 ,1739 ,638 , + 1817 ,1515 ,1325 ,291 ,1642 ,1981 ,477 ,1551 ,1639 ,376 ,2040 ,1259 ,650 ,355 ,1691 ,1938 , + 530 ,1692 ,858 ,1139 ,1870 ,402 ,1928 ,804 ,1192 ,1179 ,133 ,1139 ,2047 ,357 ,127 ,310 , + 1697 ,138 ,291 ,1176 ,1595 ,1524 ,1495 ,433 ,1757 ,84 ,10 ,972 ,1556 ,962 ,279 ,1325 , + 1505 ,1308 ,1993 ,290 ,930 ,1975 ,242 ,782 ,987 ,601 ,312 ,457 ,471 ,1528 ,40 ,107 , + 802 ,936 ,597 ,1398 ,144 ,30 ,189 ,487 ,1003 ,1256 ,252 ,1286 ,934 ,1020 ,1242 ,1741 , + 506 ,1976 ,1550 ,422 ,508 ,319 ,2041 ,1126 ,2021 ,1284 ,1762 ,898 ,1948 ,1380 ,1776 ,1800 , + 1312 ,9 ,1825 ,921 ,459 ,553 ,422 ,630 ,435 ,1023 ,1024 ,520 ,1704 ,1631 ,198 ,213 , + 1852 ,177 ,1647 ,1084 ,1433 ,989 ,116 ,1704 ,1088 ,1608 ,1041 ,1820 ,228 ,1244 ,383 ,1199 , + 1046 ,494 ,1175 ,1536 ,799 ,5 ,170 ,364 ,1357 ,97 ,1394 ,2038 ,461 ,1581 ,1086 ,805 , + 1252 ,191 ,1826 ,594 ,1636 ,1189 ,674 ,295 ,1544 ,520 ,1449 ,1065 ,30 ,1402 ,509 ,619 , + 1650 ,656 ,1369 ,812 ,1380 ,39 ,1452 ,1457 ,637 ,1600 ,455 ,1931 ,1464 ,231 ,965 ,1547 , + 1627 ,1654 ,245 ,1383 ,129 ,1596 ,1918 ,1069 ,71 ,496 ,1054 ,798 ,490 ,1592 ,472 ,3 , + 1751 ,90 ,1323 ,1057 ,604 ,1644 ,271 ,507 ,926 ,723 ,314 ,1915 ,970 ,627 ,330 ,1319 , + 1389 ,934 ,1304 ,1375 ,407 ,1771 ,882 ,1555 ,1356 ,2033 ,785 ,909 ,1364 ,1939 ,1474 ,2025 , + 504 ,10 ,678 ,1891 ,1292 ,1001 ,1173 ,1117 ,1661 ,134 ,593 ,536 ,2026 ,34 ,1316 ,489 , + 277 ,1768 ,590 ,1319 ,180 ,1940 ,675 ,218 ,1832 ,457 ,203 ,444 ,1958 ,1932 ,1139 ,479 , + 1199 ,364 ,1344 ,479 ,1390 ,413 ,1074 ,41 ,32 ,1335 ,1646 ,775 ,395 ,1106 ,160 ,980 , + 398 ,1802 ,1127 ,217 ,1406 ,338 ,185 ,1683 ,1465 ,260 ,806 ,1443 ,2023 ,1278 ,1677 ,1239 , + 415 ,1425 ,382 ,1632 ,749 ,201 ,1592 ,2038 ,1296 ,1080 ,1060 ,1306 ,1208 ,307 ,192 ,1801 , + 540 ,1414 ,1010 ,984 ,1897 ,1362 ,9 ,2023 ,1814 ,376 ,477 ,903 ,571 ,1821 ,248 ,139 , + 1378 ,1603 ,1427 ,1335 ,132 ,1086 ,1838 ,1986 ,1172 ,748 ,1000 ,481 ,276 ,1827 ,1309 ,1064 , + 1507 ,904 ,1213 ,196 ,1019 ,1189 ,1619 ,574 ,1222 ,1750 ,493 ,1786 ,985 ,1866 ,276 ,1598 , + 454 ,464 ,1235 ,1452 ,196 ,1454 ,1237 ,1152 ,1463 ,1973 ,569 ,1041 ,740 ,1829 ,804 ,295 , + 1739 ,1123 ,1248 ,556 ,1777 ,1453 ,1350 ,2047 ,149 ,1211 ,575 ,410 ,152 ,1836 ,1010 ,913 , + 706 ,670 ,732 ,1385 ,1344 ,489 ,73 ,1590 ,1438 ,1663 ,1020 ,887 ,193 ,117 ,1268 ,1730 , + 186 ,1611 ,721 ,1897 ,1594 ,338 ,448 ,463 ,1083 ,1187 ,618 ,1651 ,1218 ,825 ,814 ,242 , + 1290 ,1157 ,1836 ,656 ,714 ,1525 ,829 ,946 ,1346 ,95 ,688 ,181 ,993 ,778 ,780 ,84 , + 139 ,1688 ,452 ,1383 ,1326 ,18 ,1086 ,1443 ,1761 ,1860 ,851 ,1835 ,1850 ,1094 ,317 ,595 , + 1280 ,1196 ,1490 ,237 ,1231 ,1026 ,1846 ,1817 ,347 ,1011 ,1609 ,1382 ,203 ,1724 ,965 ,1683 , + 1653 ,223 ,1726 ,1520 ,223 ,443 ,1868 ,791 ,1703 ,759 ,1755 ,1529 ,1078 ,1175 ,1150 ,856 , + 1280 ,1885 ,1288 ,55 ,1778 ,1520 ,824 ,1945 ,671 ,590 ,1720 ,93 ,1888 ,6 ,1311 ,1795 , + 1218 ,825 ,465 ,1163 ,247 ,301 ,1192 ,968 ,414 ,482 ,712 ,1799 ,744 ,793 ,1291 ,1170 , + 1228 ,1185 ,1024 ,250 ,1097 ,837 ,115 ,1178 ,1113 ,976 ,420 ,311 ,1391 ,1793 ,900 ,1848 , + 947 ,1739 ,1038 ,1005 ,1364 ,171 ,1612 ,127 ,1938 ,1891 ,682 ,993 ,196 ,290 ,330 ,294 , + 13 ,1974 ,1726 ,989 ,1745 ,1652 ,1607 ,865 ,858 ,336 ,534 ,1665 ,438 ,224 ,608 ,1591 , + 200 ,1644 ,1373 ,5 ,291 ,407 ,687 ,1849 ,1939 ,1332 ,342 ,57 ,1520 ,1820 ,2043 ,1855 , + 1331 ,1159 ,1501 ,1334 ,711 ,1926 ,850 ,1439 ,386 ,374 ,1325 ,1757 ,1912 ,1369 ,1707 ,425 , + 2040 ,473 ,1544 ,725 ,728 ,670 ,357 ,1004 ,249 ,1340 ,421 ,1033 ,1609 ,14 ,1706 ,805 , + 541 ,243 ,783 ,546 ,267 ,1030 ,976 ,2008 ,1437 ,1542 ,662 ,436 ,194 ,1430 ,644 ,851 , + 1978 ,1642 ,1923 ,1853 ,423 ,1626 ,377 ,1586 ,1423 ,591 ,47 ,1388 ,35 ,1688 ,332 ,1562 , + 666 ,243 ,783 ,142 ,481 ,1030 ,666 ,1744 ,1437 ,1370 ,662 ,1190 ,1908 ,1714 ,802 ,1448 , + 1978 ,1642 ,1490 ,1402 ,692 ,1067 ,1353 ,1774 ,1423 ,847 ,1538 ,1388 ,35 ,1440 ,1047 ,1752 , + 448 ,243 ,1697 ,164 ,1736 ,1030 ,666 ,1744 ,1437 ,1370 ,1165 ,1190 ,1912 ,1714 ,644 ,1448 , + 1978 ,646 ,1736 ,1428 ,1255 ,1897 ,377 ,1774 ,1423 ,847 ,1538 ,1124 ,35 ,97 ,332 ,1562 , + 192 ,243 ,1697 ,1348 ,1736 ,1030 ,1978 ,1744 ,739 ,1992 ,1165 ,439 ,1101 ,1430 ,802 ,1448 , + 32 ,646 ,1736 ,1406 ,1255 ,1067 ,1353 ,717 ,604 ,591 ,47 ,1388 ,422 ,97 ,427 ,1956 , + 675 ,153 ,1546 ,818 ,1052 ,948 ,1790 ,462 ,477 ,64 ,807 ,1863 ,1936 ,872 ,384 ,615 , + 74 ,1996 ,1935 ,1445 ,166 ,1798 ,1344 ,569 ,286 ,58 ,1716 ,506 ,357 ,13 ,381 ,974 , + 780 ,1949 ,1620 ,810 ,153 ,697 ,650 ,1851 ,199 ,69 ,1434 ,1458 ,1402 ,1265 ,89 ,1720 , + 58 ,1167 ,1433 ,883 ,1086 ,1253 ,629 ,1613 ,1573 ,1653 ,178 ,19 ,713 ,1079 ,1321 ,363 , + 1315 ,1697 ,1547 ,1696 ,139 ,814 ,878 ,855 ,256 ,1826 ,948 ,1838 ,1928 ,727 ,1600 ,1022 , + 333 ,918 ,1712 ,1508 ,498 ,1577 ,877 ,1159 ,492 ,1208 ,529 ,279 ,1300 ,1796 ,287 ,1329 , + 976 ,419 ,756 ,67 ,1742 ,2029 ,449 ,1617 ,520 ,1256 ,922 ,1234 ,1490 ,476 ,1983 ,697 , + 497 ,1570 ,794 ,1888 ,1307 ,10 ,23 ,1313 ,1799 ,684 ,157 ,1036 ,1419 ,1377 ,129 ,958 , + 297 ,106 ,1944 ,500 ,1734 ,247 ,934 ,472 ,1357 ,940 ,1344 ,1016 ,1161 ,133 ,86 ,627 , + 1940 ,1460 ,1500 ,1827 ,1936 ,468 ,1340 ,538 ,909 ,1958 ,1765 ,1518 ,1405 ,250 ,1200 ,992 , + 846 ,596 ,1819 ,1450 ,2005 ,1569 ,733 ,1190 ,469 ,1992 ,1048 ,605 ,1912 ,837 ,853 ,1938 , + 1050 ,1331 ,77 ,1858 ,1169 ,511 ,1093 ,1774 ,699 ,1438 ,569 ,559 ,207 ,369 ,1783 ,1709 , + 420 ,1828 ,1206 ,1543 ,18 ,1006 ,93 ,101 ,28 ,103 ,7 ,1029 ,978 ,472 ,1353 ,2024 , + 282 ,1410 ,67 ,1973 ,1751 ,676 ,1271 ,1922 ,897 ,1130 ,704 ,941 ,1438 ,788 ,1897 ,871 , + 235 ,199 ,1592 ,1796 ,1802 ,511 ,1317 ,1832 ,754 ,1543 ,1517 ,970 ,1869 ,1570 ,1319 ,541 , + 862 ,1639 ,1973 ,442 ,333 ,1903 ,889 ,221 ,1351 ,25 ,1367 ,1020 ,1936 ,1567 ,902 ,734 , + 1382 ,364 ,1257 ,676 ,1967 ,99 ,829 ,1440 ,600 ,584 ,936 ,592 ,1304 ,1011 ,1864 ,412 , + 471 ,1517 ,958 ,650 ,300 ,1269 ,1246 ,1198 ,1451 ,1497 ,273 ,1828 ,1553 ,615 ,688 ,649 , + 1150 ,1499 ,602 ,538 ,173 ,1370 ,1054 ,322 ,1332 ,327 ,1446 ,1622 ,876 ,1780 ,1471 ,706 , + 672 ,1170 ,1150 ,1301 ,1162 ,2023 ,1810 ,1504 ,865 ,1088 ,1185 ,1900 ,719 ,821 ,418 ,630 , + 1220 ,1478 ,1902 ,940 ,139 ,546 ,642 ,400 ,998 ,272 ,614 ,1283 ,342 ,470 ,432 ,1000 , + 826 ,1772 ,1857 ,2 ,1177 ,294 ,358 ,1815 ,63 ,1130 ,830 ,538 ,2046 ,1477 ,1260 ,1725 , + 1760 ,1532 ,1379 ,1072 ,1010 ,1794 ,1324 ,1488 ,1663 ,1856 ,49 ,1084 ,83 ,1969 ,259 ,1292 , + 431 ,952 ,740 ,1700 ,234 ,827 ,722 ,1112 ,444 ,481 ,1446 ,415 ,1074 ,379 ,1992 ,388 , + 1327 ,1234 ,676 ,780 ,1538 ,1033 ,1941 ,1630 ,303 ,879 ,430 ,25 ,2037 ,1839 ,173 ,206 , + 1161 ,1346 ,793 ,1260 ,628 ,1884 ,1470 ,803 ,69 ,471 ,1431 ,1848 ,519 ,1906 ,1852 ,699 , + 1928 ,1700 ,1559 ,1562 ,340 ,1443 ,976 ,1744 ,1210 ,1542 ,662 ,1757 ,1912 ,1714 ,853 ,241 , + 1978 ,251 ,290 ,1402 ,610 ,1897 ,1497 ,1774 ,204 ,591 ,899 ,1124 ,35 ,1440 ,481 ,970 , + 666 ,243 ,783 ,142 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,143 ,1912 ,1714 ,644 ,1448 , + 32 ,1642 ,1490 ,1428 ,692 ,1067 ,1497 ,1774 ,204 ,591 ,1538 ,1124 ,422 ,97 ,1047 ,1752 , + 384 ,243 ,1697 ,1348 ,1736 ,1030 ,1978 ,1744 ,739 ,1992 ,1165 ,1190 ,1101 ,1430 ,144 ,1448 , + 1978 ,1642 ,1490 ,1406 ,1255 ,1897 ,1332 ,1774 ,724 ,591 ,47 ,1388 ,422 ,97 ,1140 ,1956 , + 1913 ,1123 ,1568 ,48 ,1380 ,1374 ,1027 ,99 ,947 ,74 ,1780 ,874 ,1170 ,828 ,1792 ,882 , + 1431 ,557 ,1477 ,1049 ,132 ,846 ,895 ,573 ,552 ,759 ,1064 ,1987 ,1541 ,671 ,1844 ,993 , + 854 ,1855 ,54 ,1418 ,830 ,1315 ,988 ,1612 ,1923 ,713 ,1902 ,1564 ,1535 ,1076 ,1781 ,1474 , + 287 ,856 ,1841 ,1246 ,1014 ,812 ,325 ,1397 ,431 ,568 ,1932 ,1376 ,331 ,1743 ,1098 ,50 , + 974 ,1652 ,108 ,719 ,1196 ,306 ,1677 ,728 ,1498 ,832 ,773 ,1973 ,1140 ,1870 ,1939 ,1522 , + 203 ,1590 ,448 ,1014 ,1915 ,470 ,1124 ,1472 ,1369 ,1870 ,1595 ,717 ,994 ,1498 ,619 ,2024 , + 1173 ,1154 ,1076 ,331 ,950 ,1869 ,1603 ,302 ,1381 ,130 ,239 ,1174 ,1852 ,758 ,48 ,648 , + 1496 ,676 ,733 ,1431 ,1240 ,81 ,1873 ,1051 ,278 ,442 ,1282 ,1175 ,1055 ,1474 ,1548 ,62 , + 1797 ,1515 ,1535 ,1450 ,1019 ,1030 ,666 ,1744 ,1106 ,605 ,1165 ,436 ,1101 ,1430 ,144 ,851 , + 1978 ,646 ,618 ,1853 ,610 ,1953 ,377 ,717 ,1423 ,17 ,1538 ,513 ,1930 ,1640 ,1140 ,1752 , + 666 ,243 ,783 ,142 ,267 ,1030 ,976 ,2008 ,739 ,374 ,1165 ,439 ,1101 ,650 ,644 ,1942 , + 1978 ,646 ,618 ,1406 ,692 ,1953 ,377 ,717 ,204 ,591 ,115 ,1124 ,422 ,1688 ,1140 ,1562 , + 448 ,243 ,1697 ,164 ,1736 ,1443 ,666 ,1868 ,739 ,1370 ,1165 ,143 ,1101 ,650 ,644 ,1448 , + 32 ,1829 ,290 ,1406 ,1255 ,1067 ,377 ,1774 ,204 ,591 ,1538 ,951 ,35 ,97 ,427 ,1562 , + 503 ,113 ,963 ,880 ,1070 ,187 ,93 ,344 ,500 ,514 ,1271 ,852 ,1858 ,670 ,1202 ,1010 , + 297 ,1772 ,793 ,1197 ,296 ,1477 ,569 ,999 ,1171 ,1855 ,283 ,1090 ,1479 ,1792 ,1279 ,842 , + 1554 ,643 ,66 ,53 ,1760 ,887 ,1905 ,534 ,1398 ,654 ,1564 ,246 ,1609 ,420 ,1622 ,199 , + 484 ,556 ,924 ,889 ,1432 ,1132 ,129 ,437 ,709 ,469 ,332 ,855 ,1676 ,1738 ,279 ,150 , + 207 ,1845 ,1327 ,584 ,1347 ,249 ,436 ,1111 ,219 ,309 ,267 ,50 ,1647 ,1274 ,1686 ,183 , + 484 ,368 ,103 ,448 ,857 ,955 ,499 ,41 ,1121 ,1181 ,1134 ,878 ,491 ,1619 ,1190 ,1705 , + 1078 ,1290 ,1234 ,1462 ,2 ,1088 ,2012 ,956 ,1749 ,761 ,1664 ,806 ,829 ,646 ,1745 ,1362 , + 1117 ,1545 ,1712 ,298 ,505 ,1921 ,772 ,1431 ,2016 ,1903 ,1300 ,447 ,2000 ,1869 ,358 ,1019 , + 955 ,1516 ,555 ,629 ,796 ,1931 ,1855 ,181 ,1245 ,2020 ,998 ,1157 ,727 ,390 ,263 ,1369 , + 1490 ,746 ,1830 ,1951 ,820 ,1401 ,400 ,1505 ,1715 ,1349 ,627 ,303 ,284 ,894 ,442 ,2043 , + 1789 ,545 ,496 ,1025 ,832 ,1973 ,670 ,158 ,1603 ,672 ,15 ,1183 ,1848 ,204 ,2044 ,1194 , + 604 ,498 ,1454 ,1786 ,1952 ,560 ,400 ,1355 ,641 ,778 ,631 ,1596 ,888 ,392 ,135 ,599 , + 853 ,393 ,159 ,1085 ,696 ,865 ,1492 ,1915 ,944 ,2035 ,1951 ,1775 ,1526 ,1376 ,526 ,1483 , + 579 ,754 ,590 ,1520 ,680 ,1881 ,1501 ,1111 ,1611 ,1395 ,744 ,1826 ,1229 ,1587 ,1770 ,1071 , + 1676 ,136 ,307 ,519 ,1355 ,69 ,452 ,1546 ,871 ,1396 ,433 ,292 ,1895 ,1382 ,1193 ,1004 , + 397 ,900 ,1363 ,1791 ,1384 ,508 ,1597 ,1708 ,719 ,756 ,51 ,1127 ,497 ,1124 ,1465 ,1251 , + 138 ,65 ,424 ,636 ,635 ,874 ,1818 ,1399 ,983 ,736 ,1358 ,975 ,1818 ,250 ,751 ,1094 , + 1151 ,1751 ,1157 ,1091 ,1007 ,500 ,1387 ,1331 ,322 ,1998 ,631 ,1743 ,324 ,1696 ,1591 ,985 , + 1401 ,1355 ,1669 ,677 ,1364 ,1346 ,911 ,1340 ,967 ,1979 ,1139 ,162 ,1221 ,1601 ,316 ,273 , + 1857 ,708 ,777 ,766 ,1279 ,1408 ,1787 ,1137 ,1472 ,866 ,1790 ,1608 ,618 ,461 ,1353 ,1704 , + 1308 ,1031 ,1302 ,570 ,278 ,1981 ,229 ,1520 ,1226 ,708 ,754 ,1874 ,550 ,360 ,3 ,355 , + 1304 ,1609 ,122 ,1161 ,1236 ,28 ,1732 ,1053 ,1786 ,429 ,1454 ,1439 ,1358 ,7 ,645 ,1050 , + 409 ,653 ,244 ,632 ,1699 ,1644 ,1242 ,944 ,386 ,337 ,2028 ,1731 ,1252 ,636 ,788 ,1765 , + 1844 ,1616 ,641 ,1373 ,2036 ,185 ,1832 ,1183 ,73 ,267 ,1886 ,844 ,781 ,1586 ,606 ,1871 , + 2007 ,983 ,1899 ,547 ,1073 ,1592 ,2014 ,1529 ,1031 ,1251 ,1805 ,1040 ,719 ,349 ,1079 ,1943 , + 940 ,1903 ,1028 ,615 ,446 ,1409 ,1778 ,638 ,431 ,341 ,1186 ,792 ,1585 ,1670 ,1557 ,1879 , + 998 ,839 ,1284 ,696 ,213 ,339 ,1564 ,689 ,2003 ,1299 ,685 ,1573 ,888 ,293 ,1715 ,948 , + 1378 ,903 ,1837 ,1025 ,1264 ,877 ,230 ,899 ,4 ,612 ,38 ,1579 ,1977 ,593 ,241 ,260 , + 450 ,929 ,321 ,1387 ,1427 ,360 ,1711 ,667 ,451 ,109 ,1162 ,1704 ,1874 ,1358 ,837 ,1862 , + 372 ,714 ,207 ,440 ,1535 ,591 ,1056 ,1582 ,1667 ,1354 ,1405 ,279 ,1904 ,1712 ,1382 ,540 , + 1608 ,383 ,973 ,320 ,597 ,1584 ,1419 ,1703 ,269 ,1842 ,544 ,1059 ,1090 ,1348 ,297 ,1088 , + 1676 ,305 ,1960 ,1542 ,888 ,1808 ,1830 ,760 ,1906 ,491 ,870 ,21 ,1061 ,267 ,278 ,845 , +}; + +// https://huggingface.co/spaces/sesame/csm-1b/blob/main/prompts/conversational_b.wav +const char * default_speaker_b_text = "[1]like a super Mario level. Like it's very like high detail. And like, once you get into the park, it just like, everything looks like a computer game and they have all these, like, you know, if, if there's like a, you know, like in a Mario game, they will have like a question block. And if you like, you know, punch it, a coin will come out. So like everyone, when they come into the park, they get like this little bracelet and then you can go punching question blocks around."; +std::initializer_list default_speaker_b_codes = { + 1049 ,1864 ,658 ,896 ,819 ,515 ,641 ,1248 ,53 ,278 ,1037 ,141 ,1423 ,565 ,828 ,986 , + 1993 ,1692 ,170 ,1357 ,1780 ,1845 ,967 ,1253 ,1587 ,1854 ,1778 ,1165 ,58 ,575 ,1499 ,491 , + 919 ,934 ,1446 ,392 ,328 ,2020 ,1418 ,1652 ,1117 ,291 ,488 ,1168 ,1989 ,931 ,894 ,140 , + 1820 ,1666 ,1655 ,2038 ,1092 ,1370 ,826 ,1499 ,176 ,554 ,188 ,708 ,1548 ,224 ,437 ,1884 , + 599 ,1960 ,976 ,1150 ,826 ,860 ,287 ,723 ,1818 ,533 ,1790 ,1859 ,1919 ,393 ,652 ,1792 , + 783 ,539 ,252 ,1414 ,1035 ,340 ,1448 ,1506 ,194 ,132 ,109 ,1425 ,741 ,366 ,1157 ,1659 , + 510 ,32 ,759 ,458 ,1226 ,359 ,484 ,540 ,68 ,592 ,975 ,789 ,905 ,1556 ,1323 ,1601 , + 1079 ,1516 ,858 ,170 ,1971 ,1674 ,376 ,1190 ,1346 ,1617 ,368 ,1488 ,308 ,1712 ,423 ,1834 , + 1273 ,1719 ,1334 ,1293 ,1420 ,430 ,1324 ,427 ,1081 ,927 ,1214 ,191 ,1400 ,1292 ,777 ,622 , + 895 ,956 ,2012 ,1040 ,509 ,198 ,821 ,1365 ,1597 ,978 ,548 ,1608 ,1341 ,1148 ,380 ,1511 , + 887 ,879 ,1368 ,305 ,161 ,1121 ,1191 ,1839 ,1986 ,507 ,1540 ,1206 ,1511 ,1948 ,1549 ,1340 , + 583 ,1124 ,515 ,1691 ,1224 ,1357 ,446 ,1070 ,2011 ,1301 ,971 ,789 ,2002 ,1502 ,851 ,193 , + 1295 ,1132 ,2041 ,1522 ,1753 ,869 ,588 ,555 ,1012 ,492 ,48 ,1274 ,1701 ,1733 ,1185 ,635 , + 1881 ,1916 ,1964 ,1907 ,1296 ,467 ,94 ,1245 ,350 ,293 ,476 ,1537 ,689 ,2028 ,1684 ,819 , + 1764 ,1684 ,2002 ,1017 ,1485 ,633 ,1064 ,626 ,1287 ,499 ,131 ,470 ,581 ,1930 ,1585 ,1957 , + 1078 ,830 ,1664 ,1405 ,1471 ,1697 ,942 ,599 ,510 ,75 ,1118 ,992 ,1435 ,756 ,1021 ,1048 , + 1407 ,1158 ,534 ,1168 ,1501 ,1105 ,697 ,602 ,1626 ,1479 ,1187 ,361 ,1651 ,1426 ,557 ,334 , + 1157 ,76 ,877 ,1501 ,321 ,1122 ,597 ,1359 ,1507 ,1344 ,1894 ,984 ,209 ,2043 ,821 ,1230 , + 1610 ,135 ,877 ,662 ,956 ,1905 ,746 ,1324 ,1610 ,1476 ,1091 ,781 ,1734 ,216 ,1595 ,1619 , + 1242 ,1857 ,1331 ,122 ,1415 ,679 ,1437 ,502 ,899 ,546 ,1377 ,353 ,1835 ,1312 ,1333 ,1798 , + 1141 ,795 ,288 ,2020 ,884 ,1734 ,407 ,1357 ,2035 ,1645 ,1609 ,1224 ,1651 ,2025 ,1874 ,1776 , + 494 ,553 ,1859 ,297 ,1451 ,199 ,944 ,391 ,1481 ,621 ,108 ,1837 ,1079 ,845 ,1964 ,1153 , + 719 ,611 ,941 ,1020 ,476 ,1582 ,1413 ,979 ,1224 ,170 ,1747 ,1550 ,530 ,80 ,1982 ,230 , + 1715 ,732 ,1806 ,755 ,1844 ,114 ,476 ,247 ,1772 ,838 ,445 ,1916 ,564 ,263 ,1367 ,938 , + 1914 ,1090 ,1334 ,920 ,1072 ,810 ,176 ,1539 ,1385 ,877 ,1750 ,1422 ,1431 ,1806 ,1950 ,445 , + 430 ,495 ,1691 ,1634 ,1505 ,1201 ,1014 ,72 ,203 ,478 ,593 ,1895 ,657 ,1343 ,1432 ,967 , + 1005 ,448 ,318 ,1583 ,376 ,1303 ,1009 ,1238 ,1130 ,447 ,1604 ,553 ,107 ,142 ,795 ,277 , + 1109 ,1718 ,389 ,1012 ,1475 ,1054 ,1741 ,1366 ,1140 ,1851 ,527 ,1929 ,1186 ,1544 ,792 ,1870 , + 1473 ,1745 ,1309 ,859 ,1138 ,1582 ,177 ,1518 ,260 ,1483 ,1866 ,1873 ,491 ,780 ,1015 ,1967 , + 624 ,2004 ,530 ,2015 ,75 ,313 ,223 ,1627 ,1635 ,693 ,322 ,1843 ,474 ,1114 ,1613 ,1561 , + 1358 ,975 ,68 ,20 ,1056 ,975 ,13 ,1095 ,1754 ,949 ,58 ,1791 ,1560 ,1116 ,668 ,1398 , + 886 ,403 ,441 ,1945 ,2002 ,564 ,1671 ,591 ,1913 ,1076 ,687 ,1789 ,1235 ,684 ,1914 ,170 , + 126 ,960 ,323 ,390 ,1200 ,1069 ,1710 ,169 ,421 ,1008 ,615 ,1322 ,115 ,1973 ,474 ,1099 , + 712 ,1658 ,1344 ,1333 ,1850 ,745 ,1112 ,231 ,1905 ,59 ,1227 ,1834 ,612 ,558 ,492 ,555 , + 1895 ,397 ,156 ,316 ,592 ,1652 ,1334 ,1538 ,1936 ,1521 ,1709 ,705 ,645 ,226 ,851 ,715 , + 63 ,272 ,749 ,282 ,908 ,1950 ,1154 ,696 ,699 ,270 ,1351 ,41 ,1934 ,1431 ,994 ,272 , + 1557 ,1168 ,373 ,386 ,340 ,1707 ,845 ,1665 ,1353 ,1416 ,1867 ,439 ,442 ,1705 ,272 ,458 , + 210 ,1419 ,258 ,786 ,469 ,507 ,78 ,753 ,604 ,531 ,902 ,1388 ,170 ,1030 ,489 ,492 , + 1743 ,477 ,1178 ,1348 ,481 ,347 ,825 ,1665 ,1353 ,91 ,1165 ,439 ,1123 ,1204 ,853 ,1942 , + 200 ,1829 ,1736 ,1668 ,692 ,1897 ,1497 ,1396 ,204 ,847 ,1538 ,483 ,1995 ,14 ,271 ,1752 , + 716 ,1056 ,1029 ,546 ,340 ,347 ,976 ,1665 ,1210 ,605 ,415 ,439 ,1908 ,1204 ,220 ,1684 , + 32 ,251 ,1736 ,1853 ,692 ,1067 ,1353 ,1774 ,724 ,847 ,569 ,1388 ,422 ,97 ,427 ,1833 , + 1850 ,1056 ,1029 ,142 ,481 ,1443 ,666 ,1665 ,1210 ,1370 ,1165 ,439 ,1123 ,113 ,144 ,851 , + 1978 ,1829 ,618 ,1428 ,1908 ,1897 ,1332 ,1586 ,1423 ,847 ,1538 ,1388 ,1995 ,14 ,427 ,1833 , + 752 ,1056 ,1029 ,142 ,1736 ,1443 ,976 ,1665 ,739 ,91 ,415 ,143 ,1908 ,113 ,220 ,1684 , + 32 ,1642 ,290 ,1853 ,1908 ,1897 ,1353 ,1774 ,724 ,847 ,569 ,483 ,1995 ,1440 ,1047 ,1833 , + 919 ,184 ,503 ,2040 ,1509 ,1253 ,1209 ,7 ,484 ,354 ,872 ,792 ,1345 ,351 ,1874 ,139 , + 894 ,1177 ,203 ,2045 ,1663 ,587 ,1735 ,1451 ,1285 ,1283 ,633 ,1487 ,395 ,1255 ,1978 ,1546 , + 854 ,737 ,2002 ,343 ,235 ,985 ,1636 ,1391 ,515 ,1192 ,1290 ,16 ,1114 ,331 ,1475 ,1679 , + 1255 ,816 ,1872 ,512 ,1931 ,1124 ,479 ,863 ,414 ,1401 ,42 ,1938 ,95 ,238 ,455 ,875 , + 979 ,1538 ,319 ,1950 ,1107 ,488 ,750 ,1691 ,1611 ,1273 ,724 ,930 ,1816 ,331 ,1081 ,796 , + 510 ,937 ,943 ,1607 ,323 ,214 ,568 ,458 ,826 ,799 ,1833 ,1843 ,1008 ,1525 ,1183 ,11 , + 946 ,836 ,1539 ,847 ,820 ,1902 ,1728 ,634 ,1150 ,644 ,1376 ,400 ,120 ,1304 ,1891 ,1963 , + 1509 ,1081 ,1361 ,1246 ,1178 ,887 ,401 ,1190 ,1471 ,358 ,206 ,960 ,1569 ,520 ,1761 ,1353 , + 561 ,817 ,274 ,1883 ,1420 ,430 ,82 ,212 ,1379 ,2009 ,1472 ,1441 ,1481 ,1222 ,1501 ,1215 , + 1298 ,11 ,1023 ,605 ,1674 ,1660 ,1519 ,584 ,1587 ,175 ,436 ,388 ,95 ,99 ,1795 ,1677 , + 655 ,543 ,1761 ,790 ,1983 ,961 ,662 ,129 ,1458 ,523 ,1838 ,599 ,1902 ,1010 ,1598 ,128 , + 877 ,527 ,1077 ,1228 ,545 ,1338 ,1980 ,792 ,530 ,987 ,1444 ,595 ,1369 ,1601 ,1425 ,496 , + 1917 ,695 ,1192 ,954 ,1419 ,118 ,567 ,1334 ,142 ,372 ,1200 ,1715 ,1607 ,606 ,1277 ,749 , + 1570 ,399 ,422 ,962 ,2009 ,772 ,46 ,1583 ,685 ,80 ,1578 ,1123 ,342 ,476 ,1491 ,993 , + 1460 ,657 ,557 ,262 ,583 ,1090 ,1418 ,355 ,1275 ,395 ,1074 ,498 ,468 ,1173 ,314 ,1411 , + 947 ,300 ,1935 ,1587 ,1608 ,207 ,725 ,333 ,587 ,1927 ,1256 ,180 ,1534 ,24 ,1904 ,385 , + 1460 ,1249 ,443 ,178 ,433 ,1132 ,382 ,990 ,866 ,1703 ,1092 ,2013 ,2021 ,480 ,727 ,694 , + 510 ,1733 ,1022 ,1706 ,1437 ,1024 ,997 ,283 ,326 ,1694 ,347 ,708 ,1428 ,83 ,1461 ,46 , + 862 ,1739 ,508 ,1976 ,1506 ,168 ,83 ,1854 ,270 ,1110 ,612 ,873 ,339 ,1211 ,1709 ,799 , + 1860 ,1158 ,1307 ,813 ,989 ,1278 ,434 ,766 ,1334 ,1506 ,1726 ,405 ,143 ,806 ,713 ,49 , + 854 ,291 ,1922 ,1982 ,328 ,151 ,724 ,774 ,1345 ,1986 ,13 ,1445 ,310 ,280 ,1123 ,1913 , + 894 ,1131 ,1241 ,386 ,361 ,1332 ,709 ,143 ,1085 ,1645 ,98 ,266 ,1406 ,305 ,1158 ,1978 , + 398 ,986 ,881 ,1523 ,1338 ,1060 ,1138 ,1748 ,1239 ,842 ,198 ,1155 ,1791 ,555 ,1746 ,1706 , + 450 ,493 ,568 ,781 ,2044 ,159 ,1487 ,1135 ,743 ,282 ,1857 ,965 ,1203 ,1452 ,838 ,604 , + 977 ,1224 ,343 ,1393 ,1770 ,430 ,1713 ,567 ,1998 ,627 ,11 ,348 ,1167 ,291 ,1577 ,337 , + 728 ,589 ,971 ,173 ,159 ,245 ,471 ,1982 ,393 ,1219 ,2039 ,421 ,215 ,1350 ,29 ,805 , + 1792 ,307 ,1530 ,303 ,898 ,560 ,463 ,488 ,375 ,217 ,1705 ,569 ,114 ,1895 ,654 ,32 , + 232 ,1259 ,492 ,1980 ,449 ,1940 ,1330 ,1462 ,1627 ,993 ,1782 ,1013 ,791 ,1734 ,1446 ,250 , + 1322 ,1058 ,1334 ,1615 ,1183 ,1850 ,1858 ,862 ,1687 ,760 ,1241 ,1520 ,779 ,1096 ,276 ,175 , + 398 ,1069 ,333 ,1857 ,646 ,1521 ,984 ,115 ,1655 ,122 ,810 ,170 ,52 ,1720 ,127 ,1062 , + 1412 ,26 ,464 ,2002 ,420 ,1097 ,1860 ,1031 ,496 ,153 ,1770 ,1178 ,542 ,65 ,1001 ,1923 , + 1434 ,365 ,443 ,1280 ,893 ,1591 ,1941 ,119 ,1775 ,1278 ,620 ,1829 ,469 ,46 ,952 ,1716 , + 1158 ,700 ,1881 ,1098 ,503 ,922 ,1472 ,1382 ,1088 ,1965 ,1127 ,1093 ,632 ,1128 ,1787 ,1267 , + 1875 ,675 ,2043 ,236 ,1433 ,543 ,1609 ,1061 ,1598 ,887 ,1212 ,425 ,393 ,1775 ,1552 ,1384 , + 1623 ,1941 ,1264 ,1223 ,2045 ,851 ,1495 ,109 ,496 ,582 ,1959 ,1460 ,355 ,1343 ,1442 ,620 , + 1125 ,1984 ,1385 ,352 ,1443 ,1030 ,11 ,454 ,135 ,309 ,1085 ,1259 ,1118 ,1159 ,28 ,646 , + 336 ,1465 ,659 ,1321 ,726 ,851 ,422 ,1226 ,633 ,127 ,1748 ,1704 ,169 ,1980 ,1631 ,1889 , + 990 ,1867 ,793 ,1803 ,1251 ,449 ,1470 ,336 ,513 ,1066 ,438 ,1429 ,211 ,941 ,1154 ,1075 , + 847 ,662 ,465 ,1643 ,1181 ,1478 ,204 ,773 ,113 ,179 ,1914 ,274 ,486 ,843 ,636 ,260 , + 995 ,1464 ,646 ,1953 ,162 ,1003 ,244 ,1952 ,610 ,1813 ,228 ,135 ,889 ,1136 ,1198 ,868 , + 1938 ,1838 ,622 ,458 ,64 ,440 ,1309 ,276 ,1331 ,580 ,1839 ,217 ,70 ,994 ,340 ,2020 , + 1772 ,1761 ,759 ,1712 ,1220 ,1303 ,1876 ,1916 ,589 ,65 ,1054 ,1661 ,266 ,1997 ,1949 ,2017 , + 1938 ,135 ,784 ,1694 ,296 ,358 ,1310 ,1148 ,253 ,1958 ,1779 ,1663 ,849 ,349 ,871 ,317 , + 484 ,214 ,1226 ,203 ,2006 ,955 ,824 ,867 ,1130 ,1243 ,1537 ,1619 ,1905 ,107 ,1867 ,1163 , + 1590 ,1944 ,105 ,1774 ,1696 ,440 ,1952 ,15 ,734 ,700 ,1695 ,982 ,341 ,1712 ,1258 ,1596 , + 1538 ,1276 ,1794 ,1785 ,1192 ,1235 ,1433 ,1241 ,1526 ,496 ,821 ,878 ,551 ,1699 ,1912 ,648 , + 1718 ,1735 ,1446 ,1419 ,700 ,33 ,3 ,404 ,516 ,687 ,1118 ,756 ,741 ,1347 ,1600 ,892 , + 1661 ,877 ,1683 ,673 ,1632 ,489 ,1284 ,1281 ,656 ,263 ,1123 ,1413 ,1219 ,372 ,706 ,1413 , + 860 ,1768 ,307 ,1709 ,1239 ,1234 ,90 ,944 ,432 ,860 ,1220 ,1637 ,1939 ,131 ,1606 ,648 , + 1805 ,755 ,1291 ,1512 ,509 ,1472 ,699 ,616 ,977 ,51 ,1311 ,284 ,1782 ,1018 ,519 ,205 , + 1537 ,2031 ,0 ,966 ,251 ,1340 ,1254 ,1458 ,781 ,927 ,541 ,1896 ,526 ,736 ,513 ,446 , + 1141 ,921 ,819 ,1654 ,546 ,924 ,1307 ,237 ,4 ,1446 ,847 ,1974 ,1947 ,619 ,1245 ,1623 , + 2042 ,477 ,1697 ,142 ,340 ,478 ,2015 ,1665 ,1866 ,1370 ,2040 ,1107 ,1783 ,113 ,1821 ,1189 , + 937 ,1796 ,77 ,1047 ,1218 ,154 ,1453 ,1480 ,765 ,777 ,569 ,1388 ,285 ,1218 ,1249 ,764 , + 666 ,316 ,1559 ,164 ,267 ,1443 ,1978 ,1665 ,1210 ,91 ,1165 ,1190 ,1123 ,1430 ,144 ,1684 , + 32 ,1829 ,1736 ,1402 ,692 ,1953 ,1332 ,1774 ,1423 ,847 ,1538 ,951 ,35 ,1440 ,427 ,1833 , + 835 ,366 ,214 ,2002 ,1383 ,1571 ,280 ,1534 ,1539 ,1058 ,1871 ,1836 ,1242 ,790 ,1923 ,339 , + 312 ,394 ,1304 ,1289 ,1817 ,900 ,1585 ,1400 ,941 ,625 ,393 ,1645 ,346 ,264 ,1353 ,830 , + 685 ,1178 ,1529 ,1623 ,1045 ,926 ,1688 ,184 ,1558 ,1366 ,259 ,631 ,489 ,994 ,263 ,1857 , + 1494 ,488 ,1453 ,844 ,1511 ,636 ,1308 ,271 ,1436 ,751 ,1131 ,1813 ,1281 ,2020 ,66 ,48 , + 873 ,472 ,353 ,253 ,1915 ,1554 ,401 ,1076 ,1954 ,932 ,293 ,1251 ,1266 ,264 ,1875 ,265 , + 1674 ,506 ,1291 ,1275 ,1690 ,1038 ,826 ,1390 ,1235 ,702 ,839 ,1080 ,816 ,868 ,340 ,1975 , + 993 ,1607 ,1046 ,1902 ,1238 ,1756 ,971 ,2010 ,920 ,1137 ,39 ,694 ,1903 ,1056 ,125 ,984 , + 144 ,883 ,1899 ,1025 ,1540 ,1356 ,866 ,1427 ,799 ,1072 ,488 ,1546 ,1194 ,1156 ,1587 ,1008 , + 2000 ,617 ,1212 ,1471 ,1906 ,1237 ,1196 ,516 ,1733 ,849 ,1467 ,1451 ,556 ,1379 ,1229 ,1150 , + 602 ,1431 ,23 ,979 ,1702 ,291 ,1095 ,1549 ,1402 ,1153 ,786 ,1093 ,469 ,967 ,1758 ,720 , + 1039 ,156 ,659 ,1817 ,381 ,1197 ,1046 ,981 ,1770 ,1769 ,1017 ,286 ,304 ,644 ,418 ,44 , + 704 ,1696 ,178 ,1832 ,1786 ,989 ,359 ,638 ,1067 ,501 ,229 ,909 ,1869 ,1200 ,1614 ,176 , + 1320 ,1168 ,1315 ,1620 ,1175 ,1273 ,1254 ,594 ,1599 ,774 ,1246 ,350 ,582 ,1977 ,369 ,447 , + 1184 ,97 ,1012 ,5 ,1764 ,1456 ,526 ,179 ,1816 ,1218 ,2020 ,761 ,171 ,834 ,17 ,324 , + 1331 ,477 ,1178 ,691 ,481 ,1273 ,183 ,1744 ,1106 ,1601 ,415 ,436 ,1912 ,1024 ,144 ,851 , + 210 ,1642 ,618 ,1428 ,692 ,1953 ,1332 ,1774 ,1423 ,1591 ,569 ,483 ,1609 ,1658 ,1047 ,1086 , + 1928 ,1056 ,1029 ,142 ,1736 ,1443 ,976 ,2008 ,739 ,91 ,1028 ,143 ,194 ,1112 ,853 ,851 , + 571 ,473 ,290 ,1406 ,1255 ,1623 ,377 ,1774 ,724 ,847 ,1887 ,1388 ,254 ,436 ,332 ,1780 , + 1771 ,1056 ,480 ,546 ,581 ,11 ,326 ,101 ,1743 ,1521 ,1791 ,229 ,1485 ,650 ,853 ,1860 , + 1531 ,1829 ,618 ,1428 ,1003 ,251 ,377 ,717 ,604 ,847 ,883 ,760 ,1930 ,1924 ,905 ,1833 , + 657 ,1549 ,627 ,1264 ,503 ,30 ,397 ,1671 ,228 ,1160 ,1115 ,1471 ,618 ,361 ,2044 ,631 , + 1652 ,1735 ,679 ,1953 ,251 ,253 ,867 ,83 ,273 ,1876 ,335 ,1335 ,680 ,614 ,1585 ,1875 , + 919 ,296 ,1277 ,753 ,617 ,164 ,1213 ,905 ,590 ,628 ,548 ,45 ,1132 ,1585 ,1350 ,1160 , + 84 ,1724 ,1223 ,851 ,704 ,1060 ,87 ,35 ,924 ,1554 ,1258 ,278 ,992 ,1240 ,158 ,1821 , + 1187 ,1626 ,1103 ,1473 ,1489 ,1600 ,1504 ,1011 ,72 ,296 ,908 ,1271 ,707 ,1774 ,1755 ,932 , + 1694 ,1513 ,1829 ,48 ,989 ,835 ,640 ,1925 ,121 ,1905 ,1320 ,1158 ,577 ,1060 ,441 ,258 , + 392 ,1955 ,497 ,727 ,1216 ,629 ,1139 ,1171 ,400 ,2013 ,1689 ,983 ,1027 ,273 ,1189 ,852 , + 1763 ,550 ,1405 ,1787 ,594 ,1323 ,1121 ,1190 ,1825 ,1190 ,1806 ,1082 ,2004 ,813 ,1768 ,1591 , + 1941 ,496 ,1274 ,1821 ,494 ,127 ,1304 ,1244 ,1113 ,1283 ,1135 ,1932 ,683 ,906 ,1671 ,767 , + 502 ,2041 ,450 ,977 ,1772 ,929 ,1747 ,440 ,669 ,1581 ,1677 ,1877 ,341 ,1730 ,842 ,975 , + 713 ,1145 ,1487 ,1875 ,689 ,549 ,1182 ,17 ,1744 ,499 ,453 ,789 ,573 ,1867 ,1728 ,575 , + 1818 ,1538 ,40 ,1288 ,1011 ,1061 ,685 ,241 ,1589 ,1928 ,701 ,1211 ,835 ,409 ,1109 ,755 , + 783 ,1135 ,1485 ,776 ,1840 ,908 ,202 ,1364 ,653 ,1040 ,451 ,1834 ,867 ,331 ,613 ,172 , + 350 ,1076 ,1647 ,876 ,644 ,1968 ,1863 ,1262 ,1080 ,1101 ,1435 ,166 ,467 ,1620 ,2010 ,972 , + 368 ,777 ,911 ,1293 ,678 ,498 ,1412 ,132 ,1032 ,216 ,672 ,1446 ,79 ,1930 ,1062 ,1110 , + 1179 ,1822 ,108 ,224 ,682 ,433 ,1729 ,1062 ,1469 ,1263 ,824 ,1673 ,971 ,1183 ,1224 ,727 , + 70 ,657 ,1030 ,507 ,511 ,1410 ,1775 ,1100 ,518 ,544 ,957 ,565 ,464 ,973 ,1127 ,1806 , + 1005 ,213 ,437 ,475 ,1248 ,1158 ,600 ,121 ,1290 ,915 ,685 ,265 ,1976 ,1376 ,721 ,551 , + 1164 ,676 ,1046 ,1832 ,33 ,327 ,567 ,1754 ,1073 ,1815 ,326 ,1717 ,1883 ,1034 ,1759 ,1191 , + 209 ,1788 ,1255 ,647 ,1577 ,548 ,569 ,189 ,1997 ,159 ,672 ,865 ,1014 ,1855 ,813 ,1390 , + 1689 ,1383 ,1248 ,1207 ,869 ,985 ,1749 ,1762 ,906 ,32 ,261 ,869 ,1347 ,313 ,1438 ,754 , + 552 ,753 ,497 ,1655 ,1269 ,693 ,1714 ,1250 ,689 ,1924 ,263 ,641 ,1346 ,1632 ,1721 ,1724 , + 577 ,1783 ,498 ,1709 ,985 ,182 ,574 ,1481 ,952 ,497 ,1769 ,252 ,1188 ,1848 ,1447 ,868 , + 607 ,8 ,1975 ,1471 ,487 ,1466 ,369 ,1365 ,1273 ,462 ,776 ,318 ,1208 ,868 ,1416 ,14 , + 577 ,397 ,1819 ,959 ,823 ,1867 ,837 ,1893 ,2044 ,345 ,427 ,1040 ,1428 ,1745 ,513 ,1937 , + 200 ,1881 ,890 ,768 ,334 ,507 ,655 ,1918 ,972 ,667 ,1041 ,1954 ,1312 ,1881 ,1440 ,523 , + 754 ,363 ,328 ,1969 ,1797 ,1033 ,904 ,582 ,904 ,2023 ,118 ,295 ,1495 ,1977 ,1014 ,411 , + 1993 ,1106 ,77 ,493 ,1513 ,843 ,526 ,509 ,844 ,164 ,1841 ,69 ,58 ,1585 ,850 ,94 , + 206 ,1821 ,981 ,487 ,19 ,1744 ,352 ,187 ,123 ,511 ,250 ,462 ,1629 ,282 ,1421 ,1084 , + 121 ,873 ,739 ,68 ,691 ,378 ,221 ,1754 ,1557 ,1041 ,1120 ,228 ,295 ,1725 ,591 ,2008 , + 1235 ,737 ,315 ,153 ,1729 ,1381 ,42 ,1968 ,298 ,864 ,194 ,78 ,1603 ,1464 ,1140 ,588 , + 166 ,31 ,264 ,347 ,395 ,1619 ,1417 ,1064 ,8 ,489 ,1255 ,307 ,567 ,1222 ,752 ,1739 , + 1300 ,1600 ,981 ,571 ,2026 ,1420 ,1439 ,408 ,412 ,1279 ,435 ,942 ,512 ,1304 ,1312 ,406 , + 2019 ,1914 ,1455 ,1460 ,1930 ,1271 ,1926 ,215 ,0 ,608 ,1880 ,1226 ,1556 ,142 ,808 ,500 , + 1479 ,482 ,2 ,1399 ,842 ,233 ,1564 ,698 ,206 ,1930 ,1768 ,1349 ,1740 ,320 ,1651 ,182 , + 1967 ,871 ,249 ,676 ,1026 ,770 ,1500 ,1046 ,1695 ,614 ,1829 ,341 ,1564 ,1399 ,1138 ,1142 , + 926 ,198 ,149 ,435 ,402 ,682 ,1622 ,1015 ,2018 ,1681 ,616 ,642 ,1330 ,1198 ,1745 ,810 , + 33 ,600 ,1430 ,264 ,611 ,797 ,120 ,1311 ,573 ,1344 ,1196 ,1083 ,2046 ,1655 ,229 ,1635 , + 1143 ,1014 ,1037 ,465 ,1336 ,951 ,20 ,441 ,1892 ,587 ,744 ,290 ,1978 ,499 ,1987 ,526 , + 169 ,461 ,1471 ,22 ,632 ,819 ,346 ,2040 ,422 ,29 ,933 ,1191 ,972 ,223 ,1568 ,875 , + 1143 ,1223 ,1883 ,421 ,1235 ,1470 ,798 ,118 ,48 ,557 ,256 ,737 ,1205 ,377 ,690 ,494 , + 1928 ,1848 ,1854 ,1447 ,759 ,718 ,1781 ,1243 ,1759 ,1415 ,1899 ,73 ,836 ,648 ,729 ,264 , + 1066 ,1706 ,1672 ,936 ,403 ,335 ,135 ,1077 ,440 ,1681 ,596 ,1565 ,2002 ,191 ,878 ,1212 , + 69 ,886 ,137 ,1495 ,1684 ,761 ,478 ,1271 ,1 ,326 ,119 ,1746 ,1095 ,484 ,1178 ,1786 , + 272 ,184 ,941 ,493 ,820 ,351 ,654 ,435 ,1445 ,1175 ,195 ,330 ,1272 ,1401 ,1330 ,672 , + 562 ,1482 ,7 ,1968 ,1903 ,1927 ,1606 ,759 ,1425 ,537 ,1735 ,116 ,1674 ,1446 ,1769 ,1492 , + 366 ,820 ,593 ,715 ,393 ,1499 ,927 ,963 ,1424 ,1416 ,1768 ,788 ,1900 ,1721 ,1760 ,1036 , + 775 ,317 ,1062 ,35 ,124 ,1938 ,971 ,859 ,656 ,2028 ,970 ,644 ,2039 ,1529 ,1290 ,1584 , + 1056 ,1500 ,1234 ,467 ,1116 ,671 ,1481 ,372 ,1672 ,2046 ,2039 ,1188 ,169 ,1947 ,9 ,1840 , + 503 ,642 ,919 ,687 ,1952 ,530 ,445 ,1364 ,253 ,1127 ,790 ,1756 ,1422 ,1557 ,716 ,686 , + 114 ,429 ,1902 ,712 ,1921 ,1534 ,273 ,925 ,2002 ,709 ,191 ,255 ,410 ,613 ,1554 ,1798 , + 1665 ,645 ,887 ,218 ,1892 ,16 ,1665 ,447 ,1725 ,475 ,803 ,1274 ,581 ,32 ,398 ,1068 , + 929 ,1461 ,408 ,458 ,406 ,172 ,1420 ,305 ,1 ,884 ,1399 ,1544 ,221 ,557 ,540 ,1055 , + 1351 ,1706 ,1174 ,313 ,367 ,341 ,1955 ,164 ,87 ,1429 ,1434 ,1156 ,1117 ,393 ,629 ,679 , + 1894 ,2026 ,1269 ,1420 ,137 ,24 ,494 ,1970 ,404 ,525 ,862 ,598 ,1615 ,1172 ,318 ,159 , + 955 ,127 ,149 ,1383 ,1806 ,371 ,1490 ,1986 ,644 ,158 ,882 ,1308 ,1547 ,1183 ,957 ,595 , + 1020 ,869 ,485 ,2021 ,1517 ,265 ,168 ,1967 ,892 ,1741 ,1881 ,998 ,1362 ,1271 ,149 ,986 , + 1966 ,1204 ,1304 ,1188 ,473 ,1703 ,1761 ,798 ,1332 ,574 ,631 ,1593 ,19 ,1245 ,1097 ,1977 , + 2042 ,1625 ,1029 ,97 ,1542 ,1335 ,976 ,972 ,1972 ,1187 ,1176 ,760 ,108 ,1722 ,1653 ,1218 , + 32 ,1410 ,581 ,82 ,1908 ,866 ,526 ,601 ,73 ,503 ,1984 ,895 ,987 ,907 ,1706 ,1517 , + 1327 ,1370 ,953 ,870 ,670 ,1858 ,1608 ,890 ,1794 ,1955 ,1082 ,89 ,1362 ,1600 ,708 ,369 , + 1685 ,201 ,438 ,1263 ,1764 ,524 ,1504 ,1022 ,805 ,1919 ,747 ,76 ,1031 ,1820 ,482 ,1877 , + 857 ,1744 ,666 ,1450 ,475 ,701 ,666 ,1190 ,1938 ,1187 ,112 ,1912 ,1123 ,1430 ,1821 ,750 , + 748 ,153 ,1219 ,973 ,407 ,348 ,2024 ,963 ,482 ,2002 ,1538 ,1954 ,389 ,14 ,1496 ,1833 , + 716 ,316 ,1559 ,290 ,1736 ,555 ,819 ,1665 ,739 ,91 ,1165 ,36 ,812 ,1430 ,220 ,1448 , + 200 ,646 ,1490 ,1402 ,1908 ,1897 ,1353 ,643 ,1423 ,1291 ,1263 ,951 ,35 ,1440 ,1140 ,1256 , + 84 ,1056 ,783 ,290 ,267 ,1443 ,976 ,1744 ,1210 ,1992 ,1165 ,439 ,1101 ,1430 ,644 ,1684 , + 32 ,646 ,1736 ,1406 ,1908 ,1626 ,377 ,717 ,204 ,973 ,1538 ,951 ,35 ,97 ,1047 ,1752 , + 84 ,243 ,1697 ,290 ,1736 ,1572 ,976 ,1648 ,1210 ,91 ,415 ,143 ,812 ,650 ,144 ,851 , + 1978 ,646 ,290 ,1406 ,1255 ,1067 ,1497 ,1586 ,724 ,973 ,569 ,951 ,35 ,1440 ,332 ,1562 , + 752 ,243 ,1697 ,1348 ,1335 ,1572 ,976 ,2008 ,739 ,91 ,415 ,436 ,1912 ,1430 ,144 ,241 , + 32 ,251 ,1736 ,1402 ,1908 ,1953 ,1332 ,1774 ,724 ,973 ,569 ,951 ,35 ,1688 ,427 ,1752 , + 322 ,522 ,844 ,1274 ,464 ,1280 ,806 ,460 ,1961 ,1242 ,898 ,916 ,1536 ,134 ,910 ,753 , + 1422 ,1795 ,1809 ,1175 ,1095 ,104 ,1746 ,1602 ,402 ,1328 ,938 ,280 ,1414 ,1062 ,33 ,97 , + 801 ,368 ,33 ,634 ,1692 ,1922 ,289 ,1915 ,1547 ,1629 ,440 ,176 ,1011 ,705 ,189 ,45 , + 151 ,191 ,630 ,34 ,1732 ,577 ,115 ,2021 ,606 ,1286 ,522 ,5 ,1005 ,1013 ,204 ,615 , + 856 ,383 ,1749 ,805 ,1778 ,1094 ,603 ,53 ,1218 ,1411 ,159 ,162 ,667 ,1345 ,586 ,837 , + 1933 ,889 ,122 ,1210 ,1296 ,1486 ,1648 ,496 ,1842 ,856 ,1531 ,1706 ,74 ,1588 ,1036 ,191 , + 256 ,1939 ,1544 ,1746 ,194 ,2008 ,1505 ,1427 ,1281 ,414 ,1276 ,1487 ,1303 ,1633 ,471 ,428 , + 254 ,1187 ,1530 ,516 ,1063 ,771 ,1435 ,1156 ,303 ,1935 ,1864 ,237 ,1336 ,1600 ,238 ,1026 , + 1599 ,1526 ,726 ,121 ,1020 ,507 ,87 ,1235 ,1494 ,1182 ,901 ,479 ,2024 ,33 ,678 ,1457 , + 807 ,714 ,325 ,1238 ,704 ,876 ,1656 ,1440 ,904 ,1791 ,46 ,329 ,1390 ,1036 ,1995 ,342 , + 390 ,739 ,1428 ,1364 ,463 ,776 ,83 ,1766 ,357 ,834 ,2019 ,1970 ,944 ,318 ,895 ,457 , + 1538 ,470 ,1656 ,1356 ,345 ,953 ,494 ,1380 ,299 ,1682 ,1733 ,533 ,1239 ,1265 ,363 ,1974 , + 599 ,1552 ,713 ,343 ,1835 ,1425 ,35 ,574 ,1203 ,904 ,703 ,1776 ,1683 ,1907 ,1567 ,418 , + 1103 ,1541 ,456 ,1938 ,1675 ,1651 ,1119 ,271 ,61 ,1905 ,519 ,278 ,1462 ,975 ,350 ,1074 , + 1508 ,99 ,784 ,1958 ,2007 ,1704 ,1375 ,1831 ,1215 ,1608 ,1960 ,1496 ,1274 ,1594 ,1333 ,1291 , + 617 ,257 ,215 ,2 ,466 ,1166 ,544 ,1430 ,1859 ,1903 ,1361 ,1428 ,1872 ,1379 ,1949 ,503 , + 552 ,1657 ,1384 ,1592 ,1643 ,678 ,1017 ,1480 ,1148 ,294 ,1149 ,1971 ,475 ,1317 ,110 ,501 , + 1905 ,1093 ,1599 ,1848 ,1489 ,1260 ,1475 ,65 ,1849 ,1687 ,1068 ,704 ,1385 ,351 ,1616 ,269 , + 1600 ,1081 ,1602 ,1222 ,215 ,477 ,35 ,135 ,612 ,1462 ,1754 ,1493 ,853 ,386 ,347 ,61 , + 1112 ,1289 ,926 ,1011 ,1468 ,1408 ,1105 ,1390 ,779 ,1289 ,455 ,424 ,11 ,1298 ,1818 ,1097 , + 2001 ,1802 ,939 ,1850 ,1629 ,911 ,76 ,801 ,691 ,655 ,255 ,264 ,1020 ,86 ,1329 ,1223 , + 1733 ,1292 ,1193 ,743 ,820 ,373 ,830 ,1956 ,1894 ,113 ,840 ,728 ,705 ,1981 ,974 ,457 , + 1029 ,472 ,202 ,1009 ,852 ,1391 ,1377 ,1058 ,655 ,1089 ,1333 ,459 ,1098 ,1996 ,37 ,594 , + 849 ,466 ,26 ,1746 ,187 ,1112 ,325 ,1888 ,1560 ,742 ,139 ,1373 ,893 ,547 ,1440 ,333 , + 877 ,1624 ,1606 ,1965 ,1572 ,477 ,456 ,226 ,1042 ,956 ,711 ,1341 ,497 ,167 ,215 ,1995 , + 1350 ,1975 ,1810 ,558 ,1897 ,3 ,602 ,2035 ,1294 ,755 ,1196 ,869 ,472 ,614 ,403 ,1756 , + 1015 ,1244 ,1583 ,1336 ,1708 ,1399 ,914 ,782 ,1152 ,18 ,1895 ,1869 ,1389 ,1605 ,1618 ,1973 , + 156 ,1068 ,2016 ,828 ,1285 ,1970 ,1503 ,561 ,1506 ,501 ,1684 ,581 ,759 ,394 ,2002 ,989 , + 984 ,1442 ,1529 ,1944 ,589 ,601 ,2015 ,1840 ,1051 ,568 ,1965 ,1633 ,2006 ,338 ,530 ,1694 , + 144 ,127 ,1893 ,1003 ,276 ,990 ,2033 ,2045 ,1635 ,1099 ,694 ,246 ,1434 ,606 ,736 ,2047 , + 1428 ,1599 ,615 ,1294 ,281 ,1894 ,1639 ,409 ,443 ,218 ,2046 ,679 ,1673 ,1274 ,139 ,1986 , + 1968 ,1649 ,1542 ,354 ,66 ,584 ,645 ,1558 ,1116 ,775 ,680 ,1557 ,1254 ,256 ,1037 ,961 , + 854 ,1841 ,643 ,1874 ,1897 ,1363 ,687 ,1747 ,1460 ,329 ,595 ,1371 ,880 ,55 ,1889 ,1184 , + 1120 ,1808 ,1700 ,1396 ,1750 ,1208 ,1416 ,204 ,1900 ,426 ,1785 ,770 ,1052 ,173 ,1256 ,2030 , + 527 ,963 ,1273 ,499 ,1983 ,844 ,667 ,1127 ,1079 ,168 ,726 ,1487 ,1772 ,91 ,1571 ,1453 , + 1691 ,1926 ,1561 ,895 ,1869 ,809 ,1782 ,770 ,1265 ,820 ,889 ,755 ,833 ,901 ,1494 ,985 , + 1108 ,286 ,1212 ,1034 ,1837 ,1335 ,410 ,1602 ,1770 ,122 ,1422 ,240 ,1875 ,1600 ,1121 ,583 , + 959 ,510 ,715 ,150 ,1951 ,918 ,1357 ,1574 ,273 ,130 ,1886 ,732 ,1521 ,883 ,1275 ,643 , + 1327 ,725 ,710 ,744 ,1994 ,1773 ,540 ,172 ,156 ,1534 ,1406 ,183 ,1352 ,665 ,131 ,641 , + 769 ,392 ,388 ,1626 ,425 ,652 ,1105 ,582 ,1339 ,32 ,979 ,2004 ,1744 ,1054 ,1761 ,710 , + 1640 ,669 ,1121 ,1287 ,1837 ,317 ,2041 ,1744 ,1186 ,605 ,1223 ,1107 ,812 ,1722 ,1844 ,1785 , + 1385 ,1670 ,75 ,973 ,555 ,1067 ,526 ,659 ,1415 ,49 ,794 ,483 ,285 ,2015 ,989 ,1199 , + 666 ,1056 ,1029 ,546 ,267 ,1030 ,825 ,2008 ,1437 ,91 ,662 ,436 ,1912 ,1430 ,644 ,241 , + 1978 ,251 ,1736 ,1406 ,610 ,1626 ,1353 ,1586 ,1423 ,847 ,1538 ,951 ,35 ,1440 ,332 ,1562 , + 1140 ,255 ,376 ,596 ,278 ,529 ,945 ,1142 ,637 ,950 ,1461 ,741 ,495 ,1965 ,128 ,1190 , + 531 ,529 ,712 ,152 ,877 ,1056 ,500 ,501 ,473 ,1963 ,1910 ,601 ,1616 ,1229 ,130 ,438 , + 510 ,54 ,402 ,1289 ,1696 ,823 ,894 ,275 ,1195 ,1943 ,1220 ,904 ,1933 ,1290 ,876 ,1556 , + 1985 ,2016 ,1176 ,877 ,23 ,1774 ,1257 ,536 ,118 ,1615 ,297 ,1890 ,83 ,888 ,513 ,1894 , + 1697 ,1582 ,375 ,1654 ,1363 ,821 ,1647 ,309 ,1362 ,1927 ,1741 ,1777 ,902 ,2036 ,846 ,1867 , + 1862 ,498 ,1431 ,1028 ,1612 ,1629 ,1405 ,1125 ,874 ,265 ,139 ,314 ,1693 ,1313 ,359 ,682 , + 985 ,1425 ,1278 ,1822 ,388 ,430 ,943 ,1064 ,1670 ,199 ,1247 ,461 ,865 ,366 ,439 ,1024 , + 1117 ,510 ,950 ,1333 ,1241 ,1902 ,1290 ,1412 ,1145 ,1166 ,2012 ,1839 ,1149 ,1100 ,51 ,1877 , + 1540 ,665 ,1220 ,853 ,8 ,1143 ,1868 ,1553 ,1724 ,1781 ,1982 ,1104 ,495 ,832 ,1510 ,1145 , + 1229 ,1549 ,730 ,1438 ,1453 ,1734 ,841 ,1378 ,203 ,1720 ,286 ,1204 ,1680 ,1266 ,584 ,994 , + 1495 ,580 ,228 ,1984 ,1947 ,813 ,1795 ,1936 ,1201 ,1106 ,1762 ,524 ,883 ,1179 ,1223 ,1657 , + 1018 ,404 ,1619 ,937 ,394 ,1267 ,1759 ,679 ,1997 ,279 ,456 ,609 ,1895 ,157 ,267 ,1354 , + 1373 ,903 ,1399 ,1966 ,298 ,256 ,1773 ,1467 ,1485 ,1352 ,1882 ,1140 ,523 ,1827 ,504 ,878 , + 1467 ,567 ,2028 ,96 ,830 ,1679 ,144 ,281 ,1720 ,195 ,1886 ,391 ,136 ,1216 ,1542 ,1786 , + 1929 ,124 ,1766 ,1623 ,765 ,671 ,372 ,1304 ,549 ,626 ,1534 ,1384 ,42 ,1656 ,1714 ,1171 , + 684 ,1422 ,209 ,1767 ,1862 ,684 ,1989 ,961 ,993 ,1869 ,728 ,873 ,1413 ,1502 ,1545 ,581 , + 1575 ,1905 ,1329 ,1381 ,290 ,1305 ,1236 ,735 ,312 ,1128 ,1058 ,1435 ,789 ,137 ,444 ,1444 , + 1729 ,64 ,1185 ,1745 ,355 ,1057 ,44 ,2025 ,814 ,19 ,1118 ,141 ,892 ,874 ,1391 ,422 , + 535 ,1632 ,497 ,1070 ,1403 ,1548 ,42 ,250 ,635 ,956 ,43 ,113 ,334 ,332 ,1949 ,59 , + 1678 ,1916 ,1535 ,108 ,373 ,1659 ,608 ,1908 ,934 ,1400 ,561 ,21 ,958 ,1138 ,1720 ,763 , + 1120 ,1383 ,728 ,1110 ,1044 ,1330 ,1646 ,1172 ,1765 ,1223 ,626 ,1094 ,1195 ,1731 ,1512 ,1093 , + 1196 ,173 ,447 ,271 ,1433 ,92 ,1976 ,1907 ,1157 ,151 ,479 ,1936 ,1960 ,1643 ,1698 ,963 , + 1307 ,397 ,886 ,1287 ,1168 ,906 ,837 ,1084 ,1858 ,671 ,1867 ,889 ,681 ,1094 ,1821 ,1196 , + 1999 ,562 ,1012 ,45 ,281 ,111 ,1093 ,582 ,306 ,253 ,1272 ,107 ,1745 ,1511 ,1194 ,1211 , + 1557 ,1520 ,919 ,309 ,1168 ,680 ,2041 ,156 ,739 ,671 ,1704 ,253 ,1101 ,1024 ,1630 ,851 , + 445 ,1642 ,766 ,1402 ,1908 ,1491 ,1453 ,717 ,2044 ,1772 ,47 ,597 ,285 ,97 ,489 ,1562 , + 222 ,1168 ,1456 ,1034 ,1873 ,347 ,712 ,1648 ,1866 ,345 ,1458 ,240 ,1908 ,650 ,144 ,241 , + 306 ,1821 ,655 ,1406 ,811 ,39 ,1154 ,1774 ,204 ,973 ,1016 ,929 ,987 ,97 ,1477 ,492 , + 1737 ,901 ,1531 ,691 ,1019 ,964 ,1359 ,549 ,1407 ,605 ,516 ,240 ,1824 ,1527 ,496 ,44 , + 1663 ,1829 ,1928 ,1148 ,11 ,348 ,101 ,804 ,1088 ,459 ,2034 ,1680 ,239 ,1864 ,171 ,1082 , + 1737 ,222 ,573 ,421 ,490 ,91 ,183 ,382 ,1423 ,1868 ,680 ,497 ,1867 ,1872 ,792 ,1698 , + 968 ,1043 ,1096 ,223 ,461 ,1503 ,297 ,1567 ,1739 ,517 ,542 ,1752 ,57 ,1240 ,261 ,1163 , + 1863 ,1245 ,552 ,217 ,1763 ,2044 ,523 ,1245 ,1975 ,269 ,819 ,25 ,1921 ,1102 ,1224 ,1424 , + 908 ,1436 ,943 ,526 ,1327 ,1781 ,596 ,1427 ,725 ,1616 ,1335 ,1982 ,1109 ,1468 ,1060 ,1477 , + 45 ,750 ,920 ,1964 ,81 ,757 ,866 ,754 ,1476 ,1779 ,1995 ,1964 ,1362 ,136 ,167 ,721 , + 669 ,1730 ,568 ,1678 ,551 ,2018 ,653 ,1450 ,570 ,1471 ,19 ,354 ,1043 ,1234 ,929 ,19 , + 411 ,1851 ,1626 ,921 ,932 ,1540 ,1607 ,101 ,1629 ,1439 ,9 ,497 ,1717 ,1076 ,381 ,1848 , + 960 ,2029 ,902 ,493 ,533 ,2030 ,624 ,516 ,880 ,215 ,29 ,845 ,500 ,1377 ,1335 ,1126 , + 639 ,1295 ,1586 ,596 ,382 ,744 ,840 ,1204 ,747 ,1239 ,1846 ,1118 ,1143 ,996 ,510 ,991 , + 464 ,1072 ,1514 ,893 ,656 ,1512 ,1473 ,1691 ,312 ,830 ,703 ,482 ,815 ,801 ,1074 ,741 , + 337 ,856 ,509 ,284 ,1609 ,853 ,377 ,1986 ,1350 ,530 ,1138 ,1663 ,788 ,1792 ,706 ,812 , + 364 ,387 ,50 ,1009 ,969 ,50 ,1292 ,770 ,583 ,432 ,846 ,1383 ,1699 ,752 ,1624 ,2010 , + 1539 ,531 ,1649 ,1294 ,665 ,241 ,2007 ,1268 ,1843 ,363 ,1581 ,1764 ,1891 ,1774 ,1913 ,1536 , + 788 ,817 ,469 ,138 ,343 ,1752 ,201 ,1855 ,791 ,975 ,1380 ,550 ,1727 ,336 ,999 ,298 , + 1144 ,468 ,435 ,1385 ,311 ,664 ,856 ,440 ,682 ,890 ,1463 ,935 ,1919 ,999 ,1382 ,1408 , + 1053 ,1705 ,1789 ,1360 ,1563 ,601 ,523 ,973 ,987 ,1244 ,775 ,1519 ,1697 ,1764 ,1896 ,624 , + 376 ,1829 ,348 ,734 ,1062 ,486 ,1358 ,1370 ,1729 ,724 ,1877 ,1087 ,946 ,325 ,1359 ,1847 , + 868 ,1682 ,1259 ,1502 ,1470 ,1718 ,1927 ,1123 ,1131 ,279 ,388 ,1576 ,1575 ,1910 ,119 ,784 , + 1335 ,661 ,277 ,1504 ,2018 ,1846 ,811 ,470 ,231 ,300 ,1868 ,924 ,1189 ,1613 ,102 ,138 , + 1925 ,611 ,1611 ,1588 ,868 ,239 ,422 ,1123 ,871 ,270 ,560 ,1016 ,215 ,541 ,305 ,126 , + 1849 ,1985 ,130 ,1528 ,1625 ,1968 ,987 ,1999 ,1714 ,1703 ,1822 ,1922 ,1522 ,519 ,374 ,1648 , + 1732 ,1102 ,138 ,1097 ,658 ,70 ,1518 ,256 ,914 ,1886 ,147 ,1926 ,99 ,2034 ,2023 ,1857 , + 1543 ,707 ,1889 ,727 ,66 ,567 ,1446 ,1318 ,1885 ,389 ,442 ,702 ,1882 ,298 ,822 ,803 , + 1132 ,1076 ,6 ,2002 ,643 ,1394 ,757 ,1252 ,1454 ,491 ,203 ,109 ,923 ,1818 ,1295 ,1507 , + 871 ,306 ,640 ,1975 ,1918 ,1842 ,959 ,1035 ,1495 ,411 ,1640 ,1083 ,904 ,1583 ,675 ,1096 , + 59 ,1541 ,1316 ,1262 ,292 ,1142 ,2037 ,482 ,1625 ,931 ,296 ,175 ,1721 ,461 ,1156 ,810 , + 1240 ,1462 ,402 ,482 ,1623 ,1468 ,960 ,356 ,148 ,233 ,1333 ,1528 ,1986 ,602 ,50 ,1663 , + 242 ,31 ,1162 ,1706 ,1591 ,706 ,1619 ,1921 ,297 ,2037 ,143 ,852 ,1695 ,516 ,716 ,802 , + 1513 ,46 ,1360 ,1873 ,1330 ,815 ,1066 ,1988 ,74 ,443 ,129 ,1764 ,1923 ,621 ,334 ,504 , + 1252 ,66 ,1495 ,146 ,250 ,854 ,532 ,869 ,1082 ,1019 ,927 ,1544 ,1284 ,104 ,374 ,746 , + 475 ,431 ,1237 ,1318 ,1625 ,207 ,856 ,773 ,1374 ,807 ,549 ,163 ,355 ,605 ,1855 ,33 , + 1716 ,33 ,1748 ,1032 ,1992 ,1614 ,1905 ,931 ,594 ,1745 ,141 ,2019 ,1472 ,583 ,708 ,1892 , + 110 ,1878 ,1062 ,1796 ,1824 ,1998 ,64 ,887 ,234 ,1770 ,321 ,1004 ,503 ,1066 ,1926 ,1220 , + 31 ,206 ,1350 ,2039 ,1748 ,1258 ,808 ,1185 ,1346 ,1294 ,280 ,1828 ,1216 ,680 ,834 ,598 , + 1428 ,1780 ,1159 ,1295 ,503 ,1691 ,1867 ,680 ,114 ,543 ,803 ,1699 ,678 ,1786 ,1267 ,1740 , + 1933 ,1982 ,1775 ,1150 ,266 ,1826 ,1549 ,1886 ,690 ,617 ,1503 ,1175 ,10 ,12 ,273 ,950 , + 1136 ,905 ,825 ,1640 ,271 ,959 ,1080 ,1973 ,1956 ,1038 ,1735 ,1479 ,1170 ,475 ,1716 ,960 , + 1009 ,1168 ,1541 ,223 ,178 ,1928 ,321 ,1094 ,1796 ,1824 ,1136 ,1932 ,680 ,1868 ,878 ,1333 , + 166 ,1381 ,798 ,205 ,1081 ,1253 ,2017 ,837 ,1092 ,397 ,1205 ,1813 ,1642 ,935 ,698 ,1086 , + 1753 ,1341 ,1064 ,1735 ,147 ,419 ,999 ,525 ,1091 ,423 ,897 ,1442 ,1867 ,662 ,962 ,535 , + 527 ,1433 ,302 ,299 ,1161 ,202 ,13 ,1926 ,1300 ,712 ,2018 ,762 ,583 ,1016 ,1221 ,719 , + 1001 ,824 ,826 ,789 ,1396 ,1063 ,1646 ,427 ,987 ,1738 ,527 ,912 ,1857 ,920 ,647 ,954 , + 702 ,54 ,1276 ,746 ,1114 ,882 ,41 ,1739 ,959 ,861 ,1588 ,1555 ,1544 ,1999 ,1399 ,45 , + 683 ,634 ,1072 ,835 ,1634 ,396 ,1100 ,1363 ,1709 ,1600 ,1117 ,280 ,1837 ,521 ,394 ,1221 , + 1327 ,1966 ,530 ,1918 ,1142 ,436 ,1000 ,1771 ,453 ,1825 ,1515 ,124 ,374 ,990 ,329 ,1031 , + 835 ,1342 ,637 ,1739 ,498 ,858 ,1288 ,1189 ,879 ,1085 ,1810 ,1796 ,79 ,1737 ,1166 ,222 , + 2034 ,783 ,1028 ,1941 ,1053 ,689 ,641 ,1422 ,875 ,815 ,881 ,943 ,1979 ,342 ,2016 ,859 , + 1611 ,1178 ,75 ,1415 ,1990 ,1710 ,352 ,284 ,1451 ,1626 ,1880 ,1395 ,1427 ,627 ,798 ,1014 , + 1155 ,1119 ,552 ,1917 ,1837 ,1992 ,1874 ,488 ,887 ,624 ,1246 ,2034 ,1059 ,381 ,921 ,1814 , + 1050 ,177 ,279 ,157 ,143 ,1884 ,1397 ,1915 ,850 ,1537 ,920 ,302 ,1052 ,397 ,1964 ,1417 , + 1648 ,600 ,2040 ,1712 ,840 ,210 ,257 ,507 ,2021 ,954 ,1028 ,1625 ,1352 ,968 ,672 ,923 , + 1917 ,642 ,1017 ,1853 ,553 ,1621 ,1019 ,1216 ,1769 ,485 ,133 ,1845 ,769 ,1687 ,895 ,1454 , + 45 ,600 ,552 ,1136 ,946 ,482 ,825 ,1291 ,6 ,1238 ,1017 ,1625 ,122 ,1213 ,788 ,849 , + 402 ,1798 ,1518 ,1116 ,1487 ,1488 ,526 ,874 ,848 ,1370 ,421 ,1873 ,1351 ,1864 ,271 ,1276 , + 1569 ,1625 ,1873 ,290 ,378 ,1030 ,819 ,1648 ,785 ,91 ,2006 ,88 ,1908 ,1714 ,1653 ,741 , + 1906 ,963 ,359 ,1406 ,123 ,128 ,107 ,753 ,1205 ,1932 ,2034 ,69 ,369 ,1983 ,1196 ,1276 , + 148 ,477 ,1178 ,1334 ,481 ,555 ,1978 ,1648 ,1210 ,278 ,34 ,77 ,194 ,1839 ,469 ,1684 , + 32 ,646 ,1490 ,1853 ,813 ,154 ,377 ,963 ,724 ,1275 ,115 ,1845 ,35 ,77 ,332 ,573 , + 969 ,1625 ,1535 ,1334 ,998 ,1198 ,666 ,151 ,1866 ,1992 ,2006 ,253 ,1736 ,1969 ,1653 ,1684 , + 1531 ,1429 ,1065 ,1116 ,1908 ,1626 ,1497 ,1614 ,711 ,847 ,1538 ,483 ,1252 ,1999 ,1216 ,1454 , + 969 ,1625 ,234 ,1334 ,1736 ,2010 ,160 ,1180 ,469 ,1542 ,2006 ,77 ,1485 ,2034 ,1653 ,1448 , + 32 ,554 ,1758 ,1402 ,407 ,1897 ,1497 ,1774 ,724 ,973 ,1016 ,1431 ,559 ,436 ,489 ,1562 , + 148 ,477 ,1178 ,1334 ,481 ,1030 ,1978 ,1665 ,1686 ,1370 ,1165 ,77 ,1123 ,2034 ,144 ,1684 , + 1978 ,1419 ,290 ,443 ,692 ,1135 ,377 ,1774 ,711 ,1343 ,569 ,1458 ,1034 ,1864 ,1349 ,1956 , + 293 ,1056 ,1697 ,290 ,481 ,1572 ,976 ,1665 ,1437 ,1542 ,118 ,1383 ,1736 ,1714 ,853 ,667 , + 1978 ,1419 ,1490 ,1406 ,692 ,1626 ,1332 ,1774 ,1423 ,847 ,1538 ,483 ,35 ,1688 ,1140 ,1752 , + 1850 ,1056 ,1697 ,1348 ,481 ,347 ,666 ,2008 ,1437 ,374 ,662 ,439 ,1101 ,650 ,144 ,851 , + 32 ,646 ,618 ,1406 ,610 ,1626 ,1332 ,1586 ,724 ,973 ,569 ,483 ,1995 ,14 ,427 ,1562 , + 771 ,1192 ,1877 ,2030 ,1598 ,1469 ,2002 ,714 ,1881 ,1379 ,1847 ,467 ,705 ,1287 ,1946 ,1117 , + 712 ,533 ,698 ,1437 ,1005 ,839 ,1050 ,1457 ,2006 ,1447 ,1053 ,1791 ,989 ,1240 ,1716 ,1009 , + 361 ,86 ,1430 ,330 ,646 ,789 ,710 ,529 ,831 ,215 ,213 ,463 ,1375 ,1370 ,491 ,1733 , + 1485 ,1040 ,197 ,1811 ,1810 ,1346 ,419 ,1819 ,270 ,249 ,1552 ,1885 ,977 ,547 ,778 ,381 , + 1748 ,1985 ,547 ,1400 ,909 ,1326 ,1992 ,1715 ,296 ,1498 ,120 ,356 ,964 ,1097 ,813 ,799 , + 1536 ,35 ,38 ,1889 ,1155 ,1127 ,117 ,1743 ,446 ,913 ,318 ,1979 ,810 ,273 ,1515 ,1123 , + 1507 ,823 ,932 ,549 ,1910 ,1024 ,1097 ,1019 ,1585 ,1756 ,1107 ,1405 ,1892 ,1143 ,1749 ,1335 , + 1394 ,1158 ,590 ,332 ,1826 ,1062 ,441 ,532 ,1094 ,902 ,1177 ,347 ,1274 ,655 ,2002 ,2037 , + 1020 ,1987 ,869 ,1276 ,2000 ,1782 ,216 ,1266 ,2000 ,1034 ,1348 ,1121 ,1859 ,484 ,1420 ,1367 , + 322 ,755 ,1021 ,830 ,1259 ,1093 ,365 ,2012 ,853 ,1013 ,1445 ,1882 ,1349 ,475 ,448 ,345 , + 577 ,1482 ,323 ,417 ,1935 ,716 ,1067 ,977 ,1431 ,1170 ,738 ,889 ,1978 ,1213 ,1776 ,721 , + 673 ,2014 ,633 ,2018 ,789 ,1771 ,101 ,646 ,1105 ,1310 ,280 ,119 ,397 ,1999 ,1611 ,672 , + 863 ,1922 ,742 ,1628 ,1677 ,1520 ,1741 ,800 ,336 ,471 ,111 ,242 ,1132 ,510 ,1537 ,984 , + 720 ,1922 ,2047 ,353 ,815 ,1329 ,608 ,1048 ,931 ,1136 ,169 ,1044 ,1466 ,833 ,928 ,791 , + 939 ,320 ,71 ,1243 ,920 ,702 ,118 ,1355 ,376 ,1176 ,742 ,1663 ,1722 ,665 ,403 ,1306 , + 329 ,214 ,772 ,1844 ,1374 ,620 ,798 ,1185 ,1823 ,1916 ,83 ,279 ,1653 ,1613 ,1621 ,1728 , + 939 ,1635 ,147 ,1947 ,1624 ,683 ,1293 ,1361 ,995 ,1836 ,935 ,201 ,1041 ,1126 ,1490 ,1779 , + 970 ,633 ,1043 ,495 ,839 ,79 ,598 ,965 ,980 ,1876 ,1249 ,20 ,1900 ,1347 ,1864 ,1753 , + 939 ,636 ,1953 ,1377 ,493 ,1292 ,1457 ,940 ,2029 ,942 ,528 ,1066 ,1112 ,1275 ,699 ,654 , + 1756 ,835 ,1012 ,1039 ,665 ,1306 ,44 ,1372 ,347 ,1583 ,857 ,240 ,260 ,521 ,1568 ,260 , + 1196 ,1912 ,953 ,414 ,1628 ,964 ,1759 ,833 ,1483 ,883 ,40 ,1887 ,1339 ,1246 ,1088 ,1532 , + 1583 ,1295 ,1903 ,706 ,2009 ,1025 ,1194 ,1857 ,1571 ,330 ,1636 ,1898 ,0 ,916 ,1870 ,519 , + 801 ,1122 ,249 ,634 ,1842 ,541 ,866 ,1842 ,1956 ,169 ,1009 ,267 ,1724 ,383 ,102 ,1213 , + 1747 ,307 ,1872 ,683 ,1053 ,1460 ,71 ,366 ,1331 ,1482 ,917 ,1346 ,1351 ,240 ,1070 ,1052 , + 976 ,1851 ,156 ,139 ,264 ,726 ,379 ,836 ,868 ,193 ,198 ,1764 ,1268 ,1399 ,1672 ,1915 , + 1195 ,1757 ,407 ,1361 ,1060 ,1163 ,435 ,1104 ,1424 ,1765 ,1040 ,1227 ,394 ,1302 ,466 ,1822 , + 322 ,398 ,503 ,1527 ,1607 ,1012 ,824 ,306 ,30 ,1057 ,973 ,347 ,472 ,1560 ,959 ,963 , + 1229 ,362 ,595 ,1888 ,234 ,1800 ,1197 ,1969 ,1595 ,1734 ,1514 ,1942 ,1102 ,295 ,186 ,393 , + 1475 ,1688 ,716 ,1263 ,2015 ,535 ,1004 ,717 ,540 ,1642 ,951 ,1858 ,555 ,1647 ,1433 ,96 , + 1822 ,1843 ,1444 ,1435 ,352 ,71 ,136 ,399 ,1278 ,1734 ,788 ,810 ,345 ,490 ,536 ,743 , + 801 ,963 ,588 ,1753 ,1044 ,345 ,1431 ,1181 ,1023 ,1102 ,398 ,593 ,1086 ,846 ,1756 ,776 , + 1647 ,898 ,28 ,2047 ,1559 ,118 ,1063 ,13 ,1422 ,133 ,684 ,1143 ,130 ,33 ,261 ,532 , + 801 ,575 ,385 ,297 ,1843 ,437 ,474 ,1854 ,661 ,957 ,1566 ,816 ,834 ,1114 ,677 ,778 , + 262 ,1356 ,185 ,1375 ,2023 ,589 ,1815 ,1156 ,1747 ,41 ,46 ,1294 ,1850 ,693 ,1607 ,1860 , + 1475 ,1186 ,1909 ,549 ,656 ,1139 ,479 ,1026 ,1452 ,1677 ,1410 ,1226 ,955 ,1524 ,1730 ,303 , + 1949 ,1267 ,561 ,1923 ,2007 ,1656 ,125 ,763 ,106 ,1695 ,494 ,1894 ,846 ,13 ,1763 ,676 , + 1212 ,1005 ,514 ,1055 ,186 ,246 ,822 ,397 ,517 ,732 ,5 ,1005 ,1354 ,1730 ,777 ,176 , + 732 ,278 ,730 ,1350 ,437 ,2011 ,680 ,769 ,310 ,1506 ,955 ,242 ,1323 ,224 ,315 ,640 , + 947 ,973 ,1401 ,706 ,1566 ,418 ,95 ,1818 ,106 ,1791 ,488 ,1668 ,1682 ,629 ,1845 ,340 , + 425 ,1605 ,1616 ,189 ,107 ,1375 ,1437 ,1391 ,414 ,915 ,1832 ,364 ,222 ,1051 ,474 ,1500 , + 1174 ,821 ,1368 ,549 ,1894 ,1527 ,908 ,993 ,991 ,1183 ,724 ,773 ,1591 ,170 ,1999 ,1813 , + 404 ,299 ,1731 ,1799 ,55 ,349 ,43 ,729 ,230 ,1318 ,104 ,1050 ,471 ,325 ,1217 ,622 , + 390 ,1567 ,478 ,1394 ,1379 ,1229 ,803 ,1093 ,677 ,245 ,1182 ,1838 ,265 ,876 ,121 ,1205 , + 223 ,1666 ,1706 ,2011 ,391 ,417 ,244 ,1530 ,1355 ,1651 ,767 ,886 ,181 ,1317 ,1417 ,838 , + 599 ,909 ,648 ,1386 ,809 ,1958 ,807 ,1838 ,1215 ,298 ,258 ,1522 ,997 ,1632 ,1257 ,1538 , + 1694 ,257 ,112 ,1963 ,1466 ,656 ,1739 ,1441 ,1197 ,1522 ,898 ,11 ,547 ,508 ,55 ,912 , + 974 ,845 ,1389 ,1821 ,869 ,1371 ,1229 ,1747 ,501 ,1452 ,1879 ,806 ,1674 ,205 ,1372 ,1959 , + 1146 ,1182 ,1256 ,127 ,269 ,111 ,327 ,1792 ,285 ,693 ,1495 ,160 ,128 ,980 ,1376 ,667 , + 734 ,905 ,705 ,1309 ,328 ,928 ,1605 ,851 ,227 ,1677 ,1108 ,1403 ,239 ,281 ,671 ,547 , + 465 ,791 ,2022 ,1919 ,1727 ,826 ,474 ,1698 ,691 ,923 ,599 ,1444 ,975 ,1973 ,216 ,735 , + 1990 ,563 ,1853 ,1714 ,1024 ,1036 ,1299 ,1376 ,1231 ,206 ,1252 ,165 ,1551 ,1613 ,1643 ,1108 , + 132 ,564 ,1593 ,1419 ,289 ,1925 ,1910 ,202 ,1963 ,987 ,1918 ,9 ,1653 ,630 ,1859 ,985 , + 306 ,523 ,1593 ,1358 ,1509 ,48 ,1740 ,875 ,327 ,1933 ,250 ,194 ,500 ,701 ,1242 ,1715 , + 1712 ,809 ,1056 ,398 ,764 ,1116 ,322 ,1644 ,287 ,1048 ,288 ,1313 ,1398 ,1738 ,552 ,2025 , + 866 ,682 ,1125 ,1921 ,61 ,1706 ,366 ,1081 ,172 ,1120 ,615 ,470 ,412 ,982 ,2008 ,1514 , + 430 ,1840 ,1803 ,180 ,802 ,1292 ,1694 ,816 ,1609 ,2011 ,104 ,1336 ,1683 ,1421 ,397 ,1960 , + 1020 ,286 ,1616 ,184 ,1197 ,1697 ,1613 ,727 ,1288 ,505 ,922 ,334 ,1738 ,100 ,1719 ,585 , + 1207 ,214 ,168 ,1636 ,1503 ,1779 ,1977 ,1770 ,644 ,782 ,183 ,931 ,962 ,738 ,859 ,632 , + 1255 ,91 ,537 ,1894 ,1801 ,1697 ,837 ,944 ,1186 ,1384 ,1037 ,1062 ,1300 ,1932 ,1821 ,1591 , + 920 ,1491 ,1736 ,955 ,608 ,585 ,743 ,1093 ,1205 ,531 ,133 ,1672 ,1571 ,1115 ,1561 ,1759 , + 771 ,572 ,2 ,436 ,1427 ,195 ,148 ,1172 ,1158 ,1420 ,1557 ,1284 ,750 ,1069 ,1406 ,707 , + 969 ,1262 ,989 ,527 ,1633 ,43 ,2022 ,2002 ,1175 ,1192 ,1733 ,1186 ,1958 ,1358 ,338 ,512 , + 361 ,1797 ,417 ,1887 ,1678 ,1974 ,1015 ,1578 ,1944 ,1375 ,1286 ,206 ,504 ,599 ,690 ,76 , + 1863 ,335 ,1159 ,201 ,1826 ,654 ,1479 ,1840 ,471 ,142 ,2003 ,1244 ,1476 ,1043 ,2033 ,102 , + 1748 ,1745 ,1836 ,1484 ,1221 ,792 ,1860 ,1256 ,449 ,1550 ,630 ,340 ,436 ,1371 ,188 ,779 , + 422 ,1117 ,1241 ,690 ,1244 ,63 ,1315 ,1746 ,1069 ,1475 ,1975 ,1301 ,1531 ,383 ,1871 ,1179 , + 383 ,1936 ,443 ,211 ,1106 ,39 ,934 ,359 ,1138 ,942 ,1412 ,1240 ,160 ,1483 ,348 ,431 , + 853 ,283 ,754 ,1648 ,1321 ,44 ,989 ,1913 ,262 ,197 ,993 ,1318 ,1973 ,946 ,449 ,1352 , + 854 ,250 ,816 ,1309 ,1670 ,572 ,736 ,1815 ,797 ,1611 ,1441 ,114 ,1985 ,196 ,1416 ,186 , + 46 ,1806 ,162 ,433 ,1302 ,1125 ,1368 ,629 ,115 ,44 ,165 ,1831 ,1865 ,1537 ,762 ,724 , + 974 ,1667 ,349 ,725 ,486 ,1169 ,221 ,1753 ,201 ,1306 ,913 ,1791 ,1910 ,742 ,525 ,787 , + 1919 ,253 ,97 ,1316 ,1394 ,906 ,56 ,1620 ,24 ,1556 ,177 ,389 ,2011 ,12 ,757 ,940 , + 649 ,1936 ,1214 ,789 ,509 ,631 ,922 ,1221 ,744 ,1451 ,1024 ,172 ,1610 ,43 ,1102 ,807 , + 57 ,1137 ,1452 ,403 ,861 ,510 ,391 ,1209 ,827 ,880 ,390 ,579 ,52 ,859 ,1662 ,1178 , + 429 ,675 ,1817 ,211 ,211 ,1580 ,1500 ,1599 ,24 ,511 ,409 ,983 ,1339 ,1880 ,1037 ,288 , + 1585 ,1589 ,1174 ,1252 ,111 ,1850 ,1230 ,1562 ,2032 ,1861 ,823 ,572 ,578 ,90 ,47 ,583 , + 944 ,1971 ,705 ,543 ,72 ,411 ,1701 ,867 ,814 ,34 ,747 ,1250 ,1472 ,1014 ,184 ,988 , + 1890 ,919 ,1695 ,1787 ,1958 ,724 ,424 ,556 ,51 ,1556 ,1312 ,1968 ,11 ,1566 ,342 ,1195 , + 427 ,496 ,1401 ,1858 ,337 ,1474 ,882 ,968 ,1172 ,890 ,1572 ,19 ,112 ,1613 ,1273 ,1197 , + 62 ,379 ,52 ,1232 ,1867 ,897 ,985 ,985 ,941 ,1344 ,372 ,1660 ,2006 ,841 ,157 ,1868 , + 229 ,1246 ,1399 ,1875 ,566 ,1713 ,384 ,203 ,291 ,1872 ,1423 ,762 ,608 ,1748 ,1094 ,1259 , + 336 ,1729 ,440 ,1128 ,922 ,1733 ,1254 ,248 ,750 ,984 ,1240 ,218 ,1547 ,38 ,2027 ,1663 , + 1622 ,1212 ,1419 ,1515 ,717 ,364 ,1867 ,163 ,816 ,1245 ,1982 ,2005 ,875 ,1964 ,1279 ,60 , + 95 ,722 ,1263 ,1481 ,659 ,1171 ,27 ,320 ,513 ,1578 ,1617 ,767 ,1170 ,1085 ,1162 ,228 , + 1332 ,313 ,1439 ,1630 ,245 ,1609 ,1370 ,1686 ,1840 ,633 ,215 ,1718 ,1726 ,306 ,1563 ,90 , + 112 ,1118 ,1394 ,1765 ,1115 ,725 ,1895 ,1430 ,1220 ,1084 ,1831 ,231 ,161 ,867 ,1303 ,403 , + 1999 ,910 ,1388 ,634 ,846 ,612 ,1662 ,334 ,1919 ,1996 ,1095 ,1002 ,1855 ,1759 ,1376 ,22 , + 1471 ,1918 ,1250 ,1292 ,806 ,1774 ,1221 ,1160 ,1044 ,293 ,39 ,168 ,1340 ,668 ,582 ,314 , + 871 ,677 ,2031 ,888 ,495 ,946 ,182 ,1642 ,616 ,982 ,1398 ,1974 ,74 ,20 ,692 ,1338 , + 513 ,1564 ,76 ,1470 ,61 ,617 ,1221 ,370 ,501 ,7 ,1976 ,1258 ,606 ,372 ,344 ,1303 , + 1864 ,324 ,1745 ,952 ,1563 ,1824 ,6 ,1389 ,2036 ,1827 ,806 ,1483 ,1977 ,1882 ,37 ,413 , + 1476 ,843 ,1853 ,1337 ,1762 ,972 ,491 ,1852 ,673 ,724 ,808 ,447 ,121 ,395 ,1792 ,1026 , + 235 ,307 ,1575 ,194 ,1687 ,1795 ,1921 ,1895 ,1178 ,1556 ,429 ,850 ,613 ,660 ,1471 ,576 , + 250 ,1555 ,1268 ,1024 ,1163 ,1680 ,1706 ,319 ,1655 ,1156 ,610 ,766 ,1549 ,1915 ,1667 ,568 , + 138 ,936 ,50 ,444 ,1258 ,1163 ,318 ,195 ,1829 ,123 ,893 ,2025 ,1637 ,902 ,807 ,1745 , + 1789 ,1229 ,1007 ,1970 ,1402 ,1948 ,414 ,906 ,756 ,1560 ,277 ,1269 ,807 ,864 ,142 ,2002 , + 599 ,1797 ,1419 ,1745 ,1944 ,1377 ,270 ,18 ,1880 ,1950 ,1591 ,70 ,1853 ,1022 ,2035 ,979 , + 1846 ,688 ,856 ,160 ,1627 ,1262 ,300 ,151 ,1054 ,1129 ,1448 ,451 ,712 ,1555 ,86 ,801 , + 1173 ,815 ,1456 ,1218 ,1783 ,1420 ,686 ,1775 ,1343 ,396 ,701 ,441 ,1080 ,647 ,694 ,1720 , + 1883 ,758 ,235 ,1493 ,86 ,505 ,1915 ,1206 ,385 ,1619 ,442 ,1038 ,190 ,717 ,984 ,1432 , + 324 ,1046 ,277 ,1858 ,419 ,1299 ,2000 ,311 ,735 ,1975 ,1491 ,305 ,1264 ,739 ,1143 ,414 , + 606 ,305 ,1077 ,1951 ,1258 ,1443 ,935 ,194 ,1628 ,1906 ,382 ,591 ,1682 ,211 ,1048 ,1435 , + 309 ,1349 ,932 ,671 ,893 ,1828 ,839 ,999 ,1644 ,774 ,1273 ,264 ,1550 ,253 ,234 ,426 , + 1032 ,2009 ,1477 ,1972 ,705 ,1047 ,253 ,1756 ,1732 ,333 ,1245 ,513 ,1978 ,1990 ,1531 ,722 , + 1520 ,1406 ,1549 ,1850 ,66 ,1878 ,660 ,1985 ,44 ,656 ,1344 ,1141 ,335 ,419 ,1488 ,548 , + 709 ,1003 ,1195 ,147 ,1766 ,1916 ,431 ,1831 ,1833 ,97 ,634 ,1244 ,133 ,1448 ,191 ,281 , + 760 ,1421 ,66 ,1519 ,1771 ,1122 ,67 ,1625 ,902 ,1093 ,176 ,2041 ,865 ,1434 ,1486 ,302 , + 1818 ,70 ,181 ,790 ,1724 ,1417 ,1316 ,2004 ,919 ,35 ,1098 ,1545 ,1959 ,322 ,761 ,1651 , + 422 ,828 ,1773 ,1105 ,816 ,1513 ,1143 ,1280 ,213 ,763 ,1681 ,106 ,1643 ,322 ,1158 ,1446 , + 888 ,672 ,1239 ,400 ,1019 ,64 ,891 ,59 ,1964 ,1844 ,240 ,1608 ,433 ,141 ,975 ,1916 , + 1925 ,858 ,1923 ,1691 ,216 ,1317 ,45 ,877 ,1428 ,1411 ,1354 ,1774 ,430 ,1769 ,1088 ,374 , + 167 ,655 ,1348 ,301 ,1240 ,1611 ,1587 ,1421 ,554 ,1429 ,718 ,1855 ,1077 ,1948 ,1463 ,1952 , + 680 ,989 ,382 ,1955 ,1695 ,326 ,972 ,1286 ,1419 ,225 ,981 ,898 ,409 ,161 ,192 ,1242 , + 521 ,991 ,1114 ,1335 ,92 ,837 ,2041 ,923 ,1411 ,1467 ,1422 ,973 ,1818 ,739 ,635 ,234 , + 1991 ,1454 ,699 ,1332 ,131 ,1258 ,1431 ,12 ,759 ,87 ,1817 ,1615 ,1325 ,1780 ,704 ,1599 , + 149 ,918 ,1117 ,336 ,480 ,1418 ,609 ,578 ,941 ,1987 ,1692 ,1847 ,787 ,1946 ,114 ,584 , + 140 ,286 ,1856 ,184 ,933 ,198 ,179 ,1407 ,232 ,1044 ,1256 ,1639 ,1901 ,1165 ,1041 ,369 , + 1949 ,668 ,130 ,95 ,883 ,358 ,1117 ,800 ,294 ,1934 ,1718 ,1651 ,750 ,124 ,864 ,139 , + 808 ,11 ,1830 ,325 ,1199 ,1285 ,1224 ,1785 ,2016 ,2007 ,488 ,789 ,1257 ,947 ,437 ,387 , + 227 ,740 ,43 ,969 ,165 ,504 ,1148 ,499 ,209 ,956 ,1278 ,1075 ,1395 ,1056 ,1702 ,1365 , + 1948 ,1587 ,134 ,936 ,753 ,1850 ,1802 ,1210 ,708 ,1361 ,811 ,1799 ,276 ,847 ,1499 ,616 , + 1934 ,1262 ,128 ,1971 ,1335 ,1996 ,607 ,680 ,1315 ,1878 ,1042 ,612 ,1399 ,683 ,1018 ,1535 , + 1441 ,726 ,1405 ,249 ,1382 ,1244 ,2041 ,1337 ,370 ,537 ,1183 ,895 ,636 ,556 ,1148 ,1656 , + 508 ,113 ,926 ,1701 ,1713 ,1294 ,1677 ,904 ,666 ,44 ,259 ,102 ,509 ,670 ,1128 ,1601 , + 386 ,586 ,263 ,343 ,125 ,456 ,2020 ,1673 ,1417 ,1230 ,1608 ,1669 ,1004 ,333 ,1167 ,786 , + 78 ,206 ,972 ,1657 ,1834 ,972 ,1799 ,777 ,63 ,89 ,1909 ,1235 ,566 ,1109 ,1230 ,1094 , + 1687 ,1694 ,889 ,1051 ,721 ,378 ,750 ,1839 ,1753 ,1913 ,67 ,1662 ,1913 ,674 ,1956 ,925 , + 639 ,619 ,21 ,381 ,965 ,603 ,1888 ,1719 ,1098 ,1641 ,1387 ,1182 ,1388 ,958 ,222 ,919 , + 725 ,1013 ,1789 ,870 ,303 ,414 ,1818 ,95 ,76 ,239 ,6 ,8 ,1329 ,1766 ,1136 ,1995 , + 1052 ,220 ,1505 ,45 ,885 ,736 ,897 ,1599 ,767 ,1105 ,183 ,674 ,1008 ,1483 ,101 ,326 , + 599 ,1544 ,173 ,78 ,132 ,1032 ,847 ,1941 ,1787 ,397 ,1660 ,1166 ,1379 ,1343 ,1437 ,364 , + 1733 ,728 ,1970 ,1804 ,1627 ,802 ,324 ,510 ,847 ,1940 ,657 ,2017 ,10 ,1980 ,1467 ,1865 , + 1817 ,605 ,1465 ,1296 ,1082 ,1697 ,1142 ,1301 ,500 ,1663 ,2014 ,1000 ,1349 ,785 ,201 ,1775 , + 1022 ,1218 ,354 ,1881 ,294 ,1977 ,330 ,447 ,1662 ,1667 ,404 ,1944 ,633 ,300 ,190 ,1613 , + 394 ,1229 ,1878 ,249 ,819 ,251 ,1589 ,1601 ,1909 ,1637 ,757 ,1133 ,1175 ,900 ,1168 ,448 , + 797 ,646 ,52 ,1525 ,1133 ,1456 ,1199 ,1004 ,222 ,125 ,1435 ,1343 ,1064 ,1356 ,1394 ,523 , + 192 ,477 ,1697 ,97 ,589 ,714 ,1871 ,1744 ,469 ,345 ,34 ,102 ,1736 ,811 ,166 ,1032 , + 200 ,1419 ,1736 ,1428 ,610 ,251 ,1353 ,74 ,792 ,847 ,1272 ,191 ,1160 ,1864 ,1047 ,1752 , + 448 ,316 ,1559 ,1348 ,481 ,1443 ,825 ,1744 ,1561 ,91 ,1048 ,436 ,1736 ,1430 ,802 ,851 , + 1978 ,251 ,1736 ,1428 ,1255 ,1897 ,655 ,1586 ,204 ,591 ,569 ,951 ,35 ,14 ,427 ,1562 , + 1850 ,316 ,783 ,164 ,267 ,1572 ,976 ,1744 ,1210 ,374 ,118 ,439 ,1908 ,1714 ,853 ,851 , + 210 ,646 ,1736 ,1406 ,1908 ,1897 ,1497 ,1586 ,1423 ,847 ,1538 ,951 ,1930 ,97 ,1047 ,1752 , + 1850 ,1700 ,1178 ,290 ,1736 ,1572 ,976 ,1744 ,1210 ,91 ,662 ,436 ,1343 ,1430 ,644 ,851 , + 1978 ,251 ,1490 ,1428 ,1908 ,1067 ,377 ,1586 ,1423 ,973 ,1538 ,1388 ,1995 ,14 ,427 ,1956 , + 7 ,1056 ,1456 ,390 ,340 ,1572 ,1978 ,1744 ,739 ,459 ,1165 ,1190 ,1912 ,1714 ,853 ,241 , + 929 ,1419 ,1736 ,385 ,423 ,1953 ,1332 ,717 ,604 ,591 ,1911 ,906 ,1930 ,14 ,1140 ,1752 , + 1833 ,811 ,480 ,290 ,481 ,1443 ,1978 ,1648 ,739 ,1370 ,415 ,253 ,202 ,2043 ,644 ,851 , + 1900 ,1829 ,1490 ,1853 ,1218 ,1953 ,1332 ,1004 ,724 ,1704 ,899 ,559 ,1995 ,1688 ,271 ,1956 , + 293 ,1056 ,1029 ,546 ,267 ,1443 ,825 ,2008 ,1437 ,1542 ,662 ,253 ,1736 ,1714 ,802 ,851 , + 1978 ,251 ,618 ,1402 ,1908 ,1626 ,1497 ,1586 ,724 ,973 ,1538 ,1124 ,35 ,1688 ,1047 ,1562 , + 1850 ,1056 ,1697 ,1348 ,481 ,1443 ,666 ,1744 ,1972 ,374 ,415 ,253 ,1736 ,1714 ,802 ,1684 , + 32 ,1642 ,290 ,1047 ,1908 ,1626 ,1353 ,1774 ,1423 ,973 ,1538 ,483 ,1995 ,97 ,1047 ,1562 , + 1890 ,2038 ,1668 ,939 ,1684 ,1799 ,1286 ,82 ,2029 ,1696 ,1587 ,428 ,437 ,1711 ,322 ,1514 , + 615 ,1571 ,1396 ,1859 ,509 ,1163 ,5 ,697 ,85 ,201 ,1109 ,1921 ,162 ,21 ,186 ,852 , + 361 ,133 ,645 ,1929 ,1446 ,230 ,1688 ,494 ,1446 ,890 ,1264 ,1689 ,824 ,1345 ,1942 ,1783 , + 752 ,1549 ,1579 ,1799 ,477 ,384 ,253 ,945 ,429 ,487 ,855 ,610 ,970 ,335 ,1390 ,1365 , + 1748 ,656 ,1060 ,175 ,2036 ,627 ,1827 ,1540 ,461 ,1517 ,913 ,60 ,973 ,1265 ,693 ,301 , + 1795 ,147 ,1826 ,365 ,1505 ,1250 ,184 ,975 ,81 ,1953 ,259 ,784 ,179 ,486 ,1254 ,77 , + 1631 ,1518 ,1448 ,2026 ,1502 ,54 ,617 ,963 ,904 ,790 ,1295 ,676 ,1009 ,201 ,898 ,1869 , + 638 ,876 ,2013 ,32 ,1952 ,1007 ,160 ,1303 ,1365 ,833 ,242 ,1219 ,213 ,1484 ,1514 ,851 , + 1255 ,994 ,1016 ,1673 ,623 ,1737 ,469 ,2016 ,1639 ,1500 ,1176 ,350 ,1783 ,1863 ,394 ,1492 , + 1224 ,1810 ,1884 ,1369 ,358 ,1843 ,1658 ,1314 ,1390 ,668 ,1938 ,235 ,1543 ,876 ,757 ,933 , + 577 ,82 ,658 ,966 ,863 ,1007 ,89 ,612 ,887 ,1182 ,2040 ,375 ,1084 ,2007 ,1311 ,1028 , + 835 ,1612 ,637 ,677 ,608 ,555 ,813 ,179 ,1344 ,910 ,766 ,1682 ,1904 ,351 ,1730 ,1871 , + 1362 ,1520 ,1456 ,183 ,811 ,1652 ,904 ,1300 ,1210 ,1769 ,1516 ,1383 ,1343 ,1492 ,170 ,485 , + 1063 ,1642 ,1490 ,172 ,1218 ,1667 ,634 ,1163 ,804 ,726 ,47 ,1388 ,866 ,1459 ,475 ,573 , + 1195 ,1376 ,127 ,408 ,910 ,1385 ,640 ,1747 ,1381 ,1841 ,454 ,677 ,1572 ,772 ,1543 ,1798 , + 55 ,1908 ,1140 ,1330 ,1500 ,1591 ,538 ,1262 ,492 ,1016 ,181 ,569 ,1018 ,1516 ,1536 ,1739 , + 214 ,1083 ,1309 ,1271 ,104 ,1213 ,1722 ,375 ,928 ,363 ,19 ,1984 ,1538 ,629 ,1621 ,334 , + 234 ,193 ,1616 ,408 ,2029 ,1365 ,688 ,1675 ,432 ,1485 ,326 ,805 ,1170 ,2044 ,1271 ,920 , + 214 ,278 ,705 ,1994 ,1341 ,867 ,885 ,440 ,1390 ,113 ,1866 ,982 ,1800 ,2023 ,1965 ,1628 , + 851 ,1556 ,233 ,1945 ,930 ,526 ,1851 ,1090 ,1160 ,770 ,608 ,1751 ,641 ,1835 ,1486 ,918 , + 805 ,351 ,996 ,1671 ,56 ,1907 ,229 ,980 ,984 ,1283 ,1256 ,1957 ,985 ,1748 ,698 ,1527 , + 734 ,1471 ,1369 ,581 ,215 ,369 ,476 ,1666 ,439 ,635 ,1374 ,1446 ,5 ,1605 ,337 ,53 , + 872 ,1666 ,432 ,1673 ,353 ,769 ,577 ,159 ,568 ,974 ,1777 ,1413 ,870 ,766 ,89 ,670 , + 1689 ,1077 ,1404 ,108 ,55 ,1064 ,649 ,111 ,1975 ,1406 ,1121 ,724 ,253 ,1938 ,14 ,1185 , + 2000 ,809 ,750 ,1767 ,795 ,1020 ,1414 ,165 ,1506 ,659 ,802 ,1646 ,1643 ,1164 ,630 ,1349 , + 1013 ,354 ,49 ,1226 ,1225 ,351 ,951 ,1236 ,1144 ,833 ,450 ,137 ,831 ,217 ,1026 ,1287 , + 1285 ,1783 ,1314 ,309 ,121 ,1856 ,401 ,1529 ,836 ,519 ,1162 ,286 ,886 ,916 ,1844 ,878 , + 707 ,690 ,1758 ,20 ,368 ,818 ,950 ,638 ,345 ,433 ,1090 ,1713 ,1580 ,1017 ,628 ,1086 , + 574 ,1873 ,1574 ,1736 ,690 ,1661 ,203 ,512 ,607 ,1853 ,1631 ,536 ,1182 ,243 ,1892 ,573 , + 1787 ,533 ,1024 ,962 ,71 ,596 ,1442 ,1694 ,856 ,661 ,1236 ,1635 ,1650 ,1474 ,1867 ,1392 , + 1367 ,2019 ,465 ,1306 ,681 ,1791 ,1540 ,1523 ,1984 ,1827 ,285 ,1282 ,912 ,466 ,294 ,357 , + 693 ,1179 ,117 ,1492 ,1566 ,702 ,698 ,966 ,695 ,1365 ,478 ,1148 ,617 ,1375 ,143 ,1907 , + 443 ,948 ,1883 ,550 ,1545 ,1777 ,1956 ,1570 ,1652 ,1925 ,1840 ,173 ,1287 ,664 ,1267 ,1337 , + 628 ,147 ,1218 ,542 ,950 ,1393 ,778 ,1341 ,1613 ,1833 ,783 ,531 ,1702 ,198 ,1615 ,932 , + 1076 ,1642 ,1388 ,1693 ,276 ,1012 ,493 ,1543 ,1505 ,775 ,543 ,1976 ,1529 ,1558 ,41 ,1914 , + 1641 ,161 ,1605 ,230 ,1710 ,1162 ,987 ,1669 ,1951 ,1212 ,975 ,1154 ,867 ,1138 ,470 ,212 , + 38 ,795 ,1238 ,1723 ,1507 ,1077 ,1409 ,1982 ,2043 ,1102 ,2047 ,269 ,629 ,197 ,1524 ,1160 , + 770 ,260 ,1081 ,776 ,825 ,37 ,1805 ,1061 ,1622 ,438 ,352 ,736 ,1203 ,1351 ,175 ,1313 , + 736 ,813 ,1463 ,40 ,1095 ,927 ,977 ,1756 ,1045 ,1872 ,457 ,1937 ,563 ,1929 ,1884 ,1162 , + 186 ,1464 ,46 ,74 ,1372 ,625 ,849 ,1842 ,846 ,1533 ,2046 ,1385 ,1870 ,90 ,1941 ,587 , + 965 ,1253 ,1156 ,1618 ,524 ,1147 ,422 ,376 ,1384 ,581 ,405 ,943 ,1483 ,1648 ,772 ,1556 , + 1354 ,380 ,1904 ,1697 ,691 ,637 ,1730 ,1189 ,1092 ,1379 ,1584 ,104 ,604 ,937 ,1427 ,574 , + 577 ,1520 ,1016 ,309 ,942 ,1522 ,1524 ,1628 ,836 ,1170 ,1310 ,1610 ,139 ,36 ,446 ,241 , + 1041 ,104 ,1665 ,1900 ,469 ,1909 ,433 ,1612 ,1671 ,1591 ,1076 ,784 ,1992 ,1640 ,712 ,1937 , + 1696 ,427 ,161 ,1697 ,200 ,1652 ,50 ,390 ,1852 ,697 ,209 ,769 ,908 ,914 ,787 ,959 , + 530 ,287 ,731 ,518 ,1120 ,77 ,378 ,1170 ,482 ,459 ,115 ,906 ,1730 ,258 ,1587 ,1274 , + 1266 ,1423 ,1668 ,139 ,888 ,2043 ,1972 ,1113 ,963 ,1460 ,319 ,408 ,1560 ,998 ,1955 ,1040 , + 382 ,517 ,492 ,1865 ,650 ,1184 ,1547 ,609 ,1744 ,674 ,1839 ,910 ,511 ,7 ,883 ,861 , + 1154 ,869 ,1856 ,843 ,1460 ,969 ,1401 ,1074 ,29 ,1561 ,1737 ,283 ,1075 ,1750 ,265 ,527 , + 1394 ,1649 ,1185 ,366 ,797 ,284 ,687 ,775 ,1257 ,1527 ,143 ,1956 ,1895 ,627 ,1005 ,720 , + 1618 ,1436 ,1590 ,1439 ,2028 ,484 ,1123 ,1280 ,1470 ,1142 ,481 ,1569 ,1176 ,1997 ,1321 ,1023 , + 1954 ,1622 ,369 ,979 ,429 ,1186 ,1445 ,1085 ,1556 ,1216 ,63 ,748 ,1474 ,117 ,402 ,1519 , + 1181 ,594 ,1812 ,1297 ,365 ,39 ,822 ,1690 ,1385 ,1550 ,528 ,1519 ,1853 ,431 ,1648 ,321 , + 366 ,474 ,62 ,1984 ,1038 ,516 ,1775 ,754 ,1249 ,109 ,474 ,982 ,1790 ,1590 ,1853 ,1287 , + 1866 ,1760 ,1866 ,53 ,741 ,600 ,1841 ,2024 ,889 ,323 ,1257 ,1575 ,153 ,1205 ,1400 ,898 , + 1211 ,480 ,689 ,1995 ,1727 ,29 ,1887 ,1710 ,119 ,1623 ,1833 ,1952 ,305 ,373 ,1421 ,1914 , + 1864 ,1268 ,443 ,1881 ,377 ,1616 ,68 ,1669 ,2022 ,1097 ,24 ,345 ,790 ,1235 ,980 ,1660 , + 1620 ,240 ,1370 ,1894 ,204 ,200 ,309 ,1350 ,283 ,316 ,150 ,1283 ,133 ,1358 ,1103 ,1318 , + 1240 ,1626 ,1349 ,1321 ,401 ,319 ,2020 ,513 ,1264 ,1083 ,217 ,287 ,1375 ,1047 ,1052 ,304 , + 780 ,150 ,40 ,507 ,1773 ,266 ,1689 ,1655 ,31 ,1402 ,711 ,482 ,920 ,211 ,981 ,1524 , + 734 ,376 ,752 ,1397 ,40 ,219 ,1378 ,482 ,1948 ,200 ,918 ,842 ,849 ,1779 ,68 ,2000 , + 437 ,1624 ,1737 ,1289 ,432 ,1847 ,1104 ,174 ,1393 ,1467 ,483 ,996 ,1308 ,1407 ,1544 ,414 , + 429 ,656 ,728 ,187 ,1224 ,1230 ,1223 ,622 ,551 ,1410 ,687 ,193 ,1741 ,620 ,389 ,1397 , + 1804 ,1609 ,965 ,558 ,828 ,1718 ,1776 ,599 ,957 ,1100 ,110 ,1594 ,705 ,686 ,528 ,1577 , + 932 ,1541 ,983 ,607 ,1398 ,1753 ,1634 ,1767 ,1513 ,1278 ,163 ,928 ,319 ,828 ,1241 ,1357 , + 1778 ,1041 ,679 ,1471 ,1789 ,1089 ,285 ,481 ,1697 ,143 ,438 ,1244 ,790 ,1402 ,630 ,96 , + 866 ,1498 ,45 ,2002 ,75 ,1600 ,492 ,647 ,604 ,1825 ,1681 ,2003 ,652 ,1232 ,1687 ,1826 , + 980 ,1199 ,1520 ,403 ,957 ,1249 ,264 ,1827 ,587 ,1318 ,1596 ,542 ,1087 ,564 ,1212 ,26 , + 1035 ,1535 ,1945 ,1021 ,1929 ,1554 ,1792 ,904 ,4 ,471 ,1640 ,434 ,1349 ,1281 ,1038 ,1054 , + 646 ,1730 ,1557 ,211 ,1449 ,569 ,790 ,5 ,934 ,1608 ,1275 ,1141 ,1295 ,930 ,1682 ,1290 , + 1020 ,363 ,537 ,1673 ,1801 ,1356 ,417 ,1538 ,904 ,278 ,1966 ,652 ,1475 ,1705 ,1149 ,715 , + 937 ,1692 ,1896 ,808 ,1256 ,531 ,561 ,508 ,11 ,810 ,1091 ,1808 ,213 ,77 ,901 ,1021 , + 1160 ,1646 ,709 ,1264 ,16 ,1652 ,1524 ,1156 ,505 ,611 ,605 ,860 ,2014 ,1705 ,1246 ,1218 , + 646 ,1644 ,841 ,1989 ,705 ,507 ,254 ,1386 ,1205 ,669 ,1343 ,432 ,1365 ,251 ,1704 ,1256 , + 1160 ,996 ,1466 ,825 ,1792 ,1514 ,168 ,948 ,1961 ,1625 ,1029 ,1455 ,508 ,807 ,702 ,1604 , + 1902 ,1689 ,1479 ,1916 ,728 ,1291 ,1142 ,1248 ,1875 ,152 ,1587 ,1528 ,1809 ,1244 ,1705 ,1475 , + 1046 ,523 ,1817 ,859 ,1502 ,1003 ,2001 ,428 ,184 ,79 ,571 ,305 ,225 ,1461 ,659 ,821 , + 1265 ,1356 ,1495 ,1920 ,1866 ,344 ,593 ,276 ,1342 ,525 ,433 ,526 ,1289 ,766 ,1871 ,941 , + 1674 ,892 ,969 ,14 ,1761 ,1765 ,1718 ,1509 ,236 ,411 ,114 ,419 ,1276 ,1574 ,1626 ,109 , + 1076 ,1653 ,432 ,1681 ,1657 ,519 ,1431 ,1064 ,1882 ,1329 ,1397 ,854 ,1387 ,1355 ,348 ,1132 , + 1693 ,1768 ,1448 ,655 ,1992 ,1080 ,125 ,1262 ,74 ,1425 ,1006 ,496 ,1871 ,920 ,1623 ,181 , + 894 ,475 ,533 ,1808 ,258 ,1960 ,677 ,781 ,1662 ,628 ,1446 ,916 ,166 ,1806 ,1642 ,1573 , + 1824 ,820 ,340 ,309 ,1761 ,515 ,1660 ,1370 ,953 ,1259 ,784 ,1985 ,1080 ,479 ,1427 ,1560 , + 1069 ,1431 ,823 ,1472 ,335 ,239 ,762 ,1077 ,523 ,54 ,535 ,827 ,1913 ,1012 ,1447 ,265 , + 599 ,737 ,1938 ,1089 ,1852 ,451 ,1144 ,1721 ,863 ,552 ,125 ,1398 ,610 ,1304 ,1879 ,177 , + 1582 ,1015 ,686 ,1978 ,1599 ,601 ,1465 ,206 ,102 ,740 ,1474 ,350 ,1451 ,1710 ,962 ,909 , + 130 ,986 ,896 ,1271 ,1167 ,1152 ,2014 ,1494 ,1710 ,354 ,227 ,2004 ,1272 ,178 ,1157 ,249 , + 1214 ,1421 ,1165 ,1073 ,315 ,1884 ,187 ,1589 ,895 ,1728 ,1945 ,1834 ,868 ,904 ,1599 ,1670 , + 812 ,774 ,1549 ,343 ,1273 ,1325 ,1898 ,1097 ,474 ,1241 ,1144 ,374 ,1315 ,1040 ,2 ,1138 , + 740 ,1295 ,127 ,1585 ,85 ,429 ,856 ,1932 ,1694 ,2041 ,479 ,74 ,569 ,897 ,804 ,1559 , + 964 ,908 ,1380 ,771 ,1658 ,379 ,154 ,1118 ,1946 ,1849 ,196 ,1084 ,853 ,1209 ,1307 ,1441 , + 1205 ,678 ,1827 ,1073 ,1364 ,101 ,1756 ,1437 ,483 ,242 ,148 ,675 ,1338 ,669 ,1457 ,1601 , + 450 ,1432 ,580 ,306 ,1783 ,493 ,955 ,458 ,136 ,1903 ,1065 ,176 ,1622 ,425 ,190 ,746 , + 960 ,1534 ,1036 ,1107 ,808 ,456 ,1601 ,497 ,1018 ,1140 ,148 ,1627 ,1176 ,954 ,1819 ,1493 , + 2023 ,1374 ,791 ,733 ,2020 ,934 ,715 ,1139 ,2013 ,1988 ,1616 ,1384 ,548 ,306 ,1681 ,599 , + 1558 ,50 ,193 ,371 ,1460 ,650 ,165 ,1129 ,541 ,1875 ,603 ,1281 ,853 ,282 ,1104 ,1582 , + 592 ,897 ,1542 ,1921 ,1690 ,1817 ,1416 ,453 ,1034 ,665 ,846 ,1755 ,596 ,433 ,1095 ,109 , + 1925 ,1973 ,10 ,998 ,1410 ,2002 ,1874 ,1187 ,1475 ,5 ,1821 ,213 ,1766 ,1232 ,1114 ,33 , + 1452 ,1439 ,999 ,1603 ,1109 ,1424 ,1818 ,1813 ,1342 ,90 ,238 ,962 ,481 ,1251 ,1643 ,666 , + 1573 ,332 ,451 ,262 ,1640 ,1085 ,1100 ,143 ,1523 ,1928 ,1419 ,561 ,1148 ,1659 ,153 ,1578 , + 491 ,250 ,1480 ,124 ,1855 ,1105 ,1969 ,1725 ,1386 ,155 ,39 ,332 ,152 ,1323 ,706 ,1212 , + 1443 ,1085 ,2030 ,1733 ,853 ,497 ,1773 ,1329 ,1568 ,1663 ,516 ,1113 ,200 ,1489 ,1951 ,1540 , + 1024 ,578 ,493 ,2014 ,307 ,1513 ,944 ,1892 ,610 ,121 ,972 ,658 ,1551 ,940 ,1744 ,1059 , + 1733 ,1628 ,1487 ,558 ,327 ,1812 ,1978 ,1740 ,1591 ,959 ,916 ,218 ,1975 ,216 ,578 ,1175 , + 231 ,872 ,139 ,887 ,1552 ,1420 ,1506 ,451 ,1674 ,941 ,261 ,651 ,319 ,1451 ,1479 ,1530 , + 1523 ,1950 ,284 ,183 ,124 ,228 ,414 ,1049 ,1102 ,504 ,1193 ,575 ,1506 ,1749 ,790 ,523 , + 1705 ,54 ,1873 ,547 ,932 ,862 ,2000 ,1142 ,927 ,1182 ,687 ,1534 ,1223 ,469 ,2038 ,1212 , + 945 ,736 ,1152 ,420 ,18 ,1960 ,656 ,1030 ,1364 ,429 ,579 ,108 ,354 ,875 ,1998 ,939 , + 1980 ,138 ,690 ,1469 ,745 ,822 ,1665 ,148 ,1634 ,225 ,1027 ,1141 ,1789 ,894 ,1756 ,728 , + 455 ,1986 ,517 ,1162 ,2030 ,1139 ,1309 ,91 ,1553 ,194 ,1616 ,824 ,163 ,49 ,244 ,1593 , + 1729 ,399 ,1990 ,1921 ,887 ,1272 ,1274 ,1619 ,971 ,1703 ,1974 ,1420 ,1127 ,766 ,103 ,1296 , + 762 ,927 ,455 ,1830 ,678 ,349 ,1606 ,1790 ,1479 ,613 ,2002 ,1208 ,214 ,186 ,426 ,1407 , + 1033 ,857 ,307 ,658 ,1081 ,509 ,811 ,1198 ,741 ,1682 ,816 ,1630 ,598 ,1498 ,1519 ,382 , + 907 ,935 ,257 ,1138 ,432 ,1397 ,1587 ,979 ,988 ,1747 ,1720 ,874 ,985 ,1342 ,1268 ,169 , + 970 ,871 ,1902 ,1116 ,531 ,1961 ,1773 ,1207 ,1075 ,663 ,888 ,1435 ,2038 ,1262 ,235 ,102 , + 1742 ,357 ,1881 ,1404 ,944 ,41 ,946 ,1080 ,1199 ,296 ,1072 ,335 ,1480 ,22 ,527 ,1363 , + 1289 ,1014 ,717 ,1020 ,1926 ,1700 ,957 ,848 ,516 ,1436 ,1272 ,725 ,1923 ,101 ,1044 ,462 , + 1854 ,1958 ,964 ,363 ,1955 ,858 ,1619 ,1659 ,1203 ,919 ,1299 ,431 ,1917 ,1045 ,193 ,330 , + 11 ,1065 ,549 ,1266 ,1526 ,1001 ,773 ,80 ,1337 ,1831 ,1745 ,731 ,958 ,463 ,150 ,308 , + 245 ,230 ,1152 ,1866 ,1181 ,401 ,472 ,1267 ,372 ,1372 ,1169 ,327 ,201 ,1547 ,1030 ,1755 , + 1066 ,748 ,1615 ,48 ,403 ,1127 ,1638 ,947 ,1532 ,1531 ,1286 ,346 ,700 ,1616 ,1632 ,1497 , + 523 ,803 ,584 ,686 ,1094 ,231 ,1620 ,26 ,1323 ,2016 ,1363 ,1046 ,1306 ,1216 ,98 ,1498 , + 366 ,341 ,1158 ,115 ,1270 ,1966 ,465 ,874 ,911 ,874 ,1923 ,1608 ,2017 ,1696 ,936 ,1161 , + 749 ,733 ,493 ,1872 ,432 ,2 ,1126 ,1858 ,596 ,357 ,1138 ,1718 ,869 ,125 ,295 ,608 , + 1056 ,1505 ,1900 ,274 ,1101 ,1636 ,1654 ,146 ,1181 ,1072 ,329 ,12 ,1926 ,1410 ,958 ,796 , + 18 ,222 ,1453 ,1467 ,959 ,587 ,1247 ,952 ,1627 ,1240 ,78 ,1543 ,884 ,1132 ,426 ,1771 , + 929 ,1329 ,1011 ,1314 ,202 ,1034 ,795 ,1522 ,618 ,736 ,566 ,1670 ,1424 ,1565 ,1485 ,1657 , + 1734 ,1191 ,339 ,1190 ,894 ,1536 ,1011 ,633 ,1149 ,856 ,1193 ,1746 ,543 ,1421 ,1641 ,1197 , + 1610 ,656 ,1103 ,1178 ,268 ,718 ,464 ,503 ,1742 ,1758 ,558 ,1761 ,951 ,164 ,823 ,1487 , + 1183 ,1939 ,821 ,194 ,1806 ,243 ,1649 ,1220 ,211 ,1935 ,1848 ,1310 ,1720 ,993 ,303 ,1504 , + 1845 ,286 ,1081 ,1461 ,844 ,1335 ,1285 ,491 ,1381 ,916 ,531 ,173 ,820 ,831 ,1472 ,1206 , + 2002 ,1931 ,1650 ,1780 ,684 ,293 ,335 ,848 ,445 ,699 ,261 ,18 ,170 ,1286 ,105 ,1124 , + 1746 ,1005 ,1867 ,725 ,648 ,1534 ,571 ,226 ,361 ,712 ,1659 ,1457 ,1778 ,846 ,697 ,72 , + 1804 ,1559 ,499 ,680 ,1728 ,1982 ,1879 ,1696 ,1397 ,1219 ,289 ,1574 ,213 ,1152 ,1658 ,61 , + 134 ,979 ,237 ,309 ,653 ,1564 ,1216 ,1509 ,1306 ,1569 ,2038 ,1911 ,774 ,1304 ,1667 ,1034 , + 173 ,128 ,1334 ,955 ,1317 ,649 ,1609 ,307 ,68 ,1379 ,424 ,1865 ,226 ,1539 ,624 ,955 , + 890 ,155 ,468 ,1834 ,1135 ,1220 ,1198 ,606 ,677 ,1517 ,1920 ,1210 ,562 ,1716 ,4 ,372 , + 21 ,1785 ,529 ,1829 ,275 ,980 ,1792 ,459 ,609 ,1044 ,1312 ,1193 ,1859 ,1534 ,865 ,1372 , + 1601 ,246 ,718 ,1785 ,932 ,192 ,1475 ,63 ,1689 ,399 ,115 ,1923 ,1903 ,1665 ,27 ,1299 , + 1363 ,396 ,700 ,388 ,1737 ,1095 ,364 ,1046 ,753 ,136 ,1310 ,765 ,261 ,716 ,1266 ,57 , + 599 ,261 ,1251 ,510 ,2032 ,25 ,450 ,1755 ,863 ,1371 ,935 ,569 ,36 ,1638 ,491 ,936 , + 180 ,952 ,1485 ,823 ,1771 ,1838 ,419 ,2018 ,243 ,1036 ,1406 ,948 ,1180 ,1517 ,240 ,839 , + 1635 ,1294 ,1531 ,1794 ,1829 ,2013 ,1355 ,1775 ,345 ,1365 ,669 ,237 ,217 ,764 ,326 ,1524 , + 1305 ,397 ,1255 ,1273 ,1132 ,1289 ,1881 ,990 ,573 ,1217 ,1579 ,458 ,795 ,1053 ,1923 ,107 , + 134 ,1444 ,1702 ,561 ,618 ,1533 ,270 ,1192 ,376 ,149 ,141 ,1644 ,1261 ,543 ,298 ,2019 , + 672 ,1109 ,1172 ,3 ,212 ,139 ,1879 ,1207 ,257 ,172 ,58 ,1317 ,644 ,67 ,1335 ,1146 , + 409 ,1496 ,1484 ,324 ,1727 ,901 ,222 ,522 ,1842 ,1174 ,257 ,1354 ,724 ,1725 ,522 ,670 , + 1033 ,577 ,1521 ,1170 ,1606 ,554 ,39 ,1153 ,1583 ,435 ,41 ,1573 ,389 ,7 ,1888 ,1214 , + 1428 ,1466 ,543 ,1723 ,1832 ,1799 ,1020 ,215 ,334 ,1284 ,1445 ,250 ,1197 ,515 ,1673 ,352 , + 620 ,1001 ,1997 ,147 ,1396 ,458 ,197 ,444 ,283 ,183 ,1790 ,1228 ,751 ,1824 ,1853 ,1577 , + 689 ,1859 ,645 ,1389 ,1581 ,68 ,713 ,1072 ,984 ,1733 ,1285 ,1407 ,19 ,1693 ,1494 ,1678 , + 323 ,1130 ,665 ,382 ,73 ,1133 ,74 ,1699 ,1706 ,303 ,692 ,119 ,1575 ,208 ,1961 ,1903 , + 609 ,1783 ,1220 ,1653 ,581 ,78 ,589 ,1450 ,962 ,1318 ,100 ,398 ,1992 ,330 ,1047 ,1991 , + 1063 ,926 ,1758 ,344 ,624 ,721 ,1453 ,1756 ,142 ,461 ,259 ,1613 ,1682 ,1305 ,632 ,1050 , + 1276 ,1871 ,1111 ,925 ,1439 ,321 ,1423 ,1142 ,1446 ,1673 ,781 ,302 ,19 ,661 ,1238 ,684 , + 1511 ,1855 ,2009 ,76 ,1983 ,387 ,298 ,1785 ,1071 ,407 ,979 ,1718 ,320 ,16 ,186 ,886 , + 1446 ,1474 ,887 ,343 ,1733 ,14 ,788 ,1075 ,1004 ,169 ,817 ,1822 ,1179 ,620 ,178 ,998 , + 1732 ,1322 ,26 ,1259 ,1861 ,1517 ,548 ,414 ,1486 ,1929 ,315 ,35 ,1306 ,1304 ,1910 ,132 , + 438 ,1688 ,981 ,637 ,1939 ,1190 ,1506 ,142 ,1247 ,1205 ,884 ,1209 ,54 ,1812 ,2004 ,571 , + 2042 ,860 ,310 ,859 ,1116 ,467 ,1990 ,43 ,952 ,215 ,352 ,633 ,251 ,909 ,1554 ,785 , + 523 ,328 ,1116 ,1136 ,819 ,1858 ,1807 ,249 ,557 ,570 ,854 ,667 ,137 ,669 ,745 ,1810 , + 676 ,217 ,982 ,1728 ,1234 ,1196 ,553 ,663 ,999 ,1953 ,1415 ,1237 ,1628 ,1093 ,965 ,734 , + 390 ,1051 ,375 ,1802 ,717 ,1543 ,1950 ,179 ,1457 ,713 ,39 ,1517 ,973 ,1028 ,1349 ,1164 , + 965 ,1709 ,162 ,1909 ,1100 ,1717 ,860 ,684 ,319 ,1731 ,738 ,347 ,2042 ,1610 ,1525 ,1455 , + 208 ,1656 ,1334 ,1178 ,1034 ,507 ,1703 ,1886 ,1423 ,362 ,1068 ,14 ,1855 ,1099 ,231 ,1245 , + 1395 ,371 ,621 ,203 ,872 ,841 ,1673 ,1976 ,1584 ,675 ,1174 ,1986 ,643 ,848 ,1354 ,1212 , + 749 ,549 ,849 ,1271 ,1444 ,969 ,1102 ,1949 ,1412 ,482 ,245 ,133 ,1373 ,1011 ,1717 ,1848 , + 1739 ,1513 ,712 ,1519 ,965 ,1042 ,1298 ,278 ,199 ,2020 ,549 ,1251 ,1918 ,1334 ,1978 ,1784 , + 1020 ,1625 ,552 ,135 ,242 ,936 ,624 ,388 ,904 ,820 ,1704 ,242 ,1300 ,914 ,901 ,1119 , + 1669 ,1950 ,1138 ,725 ,608 ,873 ,254 ,270 ,91 ,2 ,1959 ,1446 ,1608 ,559 ,1477 ,1454 , + 1928 ,477 ,1873 ,1411 ,420 ,317 ,819 ,1648 ,1791 ,2000 ,161 ,439 ,1912 ,565 ,220 ,1343 , + 210 ,646 ,1490 ,725 ,1802 ,39 ,1497 ,1774 ,2044 ,1343 ,762 ,1808 ,422 ,1440 ,481 ,1956 , + 973 ,800 ,552 ,1917 ,946 ,1652 ,160 ,582 ,356 ,891 ,1165 ,1655 ,812 ,811 ,644 ,1523 , + 748 ,1874 ,1518 ,228 ,423 ,1905 ,1683 ,1093 ,619 ,599 ,133 ,513 ,941 ,1260 ,989 ,1086 , + 1968 ,1625 ,552 ,1334 ,946 ,300 ,1512 ,1321 ,1686 ,604 ,504 ,1631 ,1275 ,1112 ,1246 ,1448 , + 1562 ,1593 ,1544 ,443 ,1504 ,873 ,749 ,249 ,619 ,1664 ,342 ,139 ,170 ,195 ,1911 ,643 , + 758 ,724 ,552 ,1289 ,894 ,1735 ,1394 ,13 ,1429 ,1903 ,1633 ,667 ,2037 ,1169 ,1816 ,553 , + 1445 ,1364 ,42 ,1551 ,1370 ,57 ,1351 ,1414 ,942 ,1340 ,488 ,1702 ,141 ,1502 ,1308 ,174 , + 1584 ,710 ,214 ,1996 ,1021 ,560 ,1648 ,1341 ,1951 ,1902 ,1372 ,2047 ,967 ,814 ,1238 ,1322 , + 1317 ,299 ,679 ,659 ,1849 ,1822 ,716 ,1656 ,1089 ,94 ,705 ,985 ,787 ,569 ,744 ,1899 , + 1332 ,1738 ,13 ,845 ,1010 ,158 ,194 ,1965 ,889 ,386 ,1343 ,1886 ,134 ,332 ,1567 ,1960 , + 635 ,1993 ,797 ,357 ,1517 ,114 ,1397 ,1844 ,1687 ,1703 ,418 ,116 ,280 ,66 ,965 ,555 , + 348 ,1603 ,1284 ,121 ,1824 ,104 ,720 ,1415 ,351 ,880 ,1106 ,1845 ,697 ,132 ,670 ,1572 , + 678 ,485 ,988 ,934 ,1451 ,1050 ,1953 ,499 ,129 ,1625 ,1192 ,1924 ,1668 ,108 ,891 ,576 , + 20 ,665 ,1146 ,1509 ,389 ,1077 ,493 ,110 ,48 ,535 ,1187 ,1970 ,418 ,1869 ,1548 ,139 , + 478 ,665 ,391 ,1030 ,182 ,1120 ,1984 ,1095 ,1540 ,1637 ,532 ,527 ,1077 ,482 ,696 ,972 , + 1870 ,707 ,1392 ,1113 ,1469 ,1116 ,1436 ,1010 ,768 ,1606 ,1051 ,1745 ,95 ,101 ,135 ,964 , + 1750 ,722 ,1894 ,497 ,1541 ,862 ,1863 ,1500 ,1977 ,1906 ,1435 ,1785 ,644 ,1112 ,938 ,668 , + 853 ,1871 ,444 ,1105 ,1670 ,61 ,240 ,1162 ,154 ,1764 ,1404 ,585 ,1160 ,1796 ,199 ,1161 , + 1356 ,1942 ,1227 ,416 ,1994 ,419 ,174 ,1512 ,1619 ,751 ,1758 ,1892 ,1607 ,1154 ,1200 ,616 , + 1674 ,1899 ,1082 ,597 ,544 ,619 ,1605 ,1055 ,257 ,1584 ,743 ,935 ,1879 ,316 ,1621 ,248 , + 1772 ,697 ,1451 ,1696 ,1488 ,162 ,485 ,1261 ,799 ,1019 ,1689 ,874 ,971 ,605 ,367 ,1532 , + 1674 ,1942 ,1328 ,47 ,1632 ,1343 ,1177 ,1914 ,1428 ,735 ,980 ,322 ,834 ,872 ,1044 ,595 , + 783 ,835 ,1171 ,432 ,880 ,1274 ,966 ,814 ,1732 ,1740 ,359 ,1157 ,868 ,832 ,1721 ,681 , + 1824 ,773 ,1878 ,1370 ,365 ,1605 ,437 ,1932 ,1832 ,821 ,1700 ,1075 ,47 ,1153 ,1724 ,1585 , + 1111 ,213 ,154 ,1094 ,1892 ,279 ,1534 ,811 ,1182 ,466 ,1728 ,1985 ,222 ,274 ,112 ,943 , + 1056 ,150 ,818 ,1300 ,1749 ,274 ,598 ,408 ,947 ,465 ,1510 ,372 ,762 ,1120 ,533 ,68 , + 36 ,1022 ,746 ,1125 ,251 ,1560 ,1205 ,927 ,1582 ,484 ,985 ,633 ,1876 ,143 ,1317 ,1344 , + 1754 ,79 ,1647 ,1462 ,1997 ,1157 ,1507 ,1013 ,1460 ,444 ,435 ,1801 ,48 ,2025 ,1049 ,1379 , + 331 ,675 ,64 ,953 ,1034 ,771 ,1275 ,728 ,2001 ,1966 ,443 ,1752 ,975 ,787 ,1432 ,596 , + 149 ,1441 ,930 ,680 ,925 ,1790 ,1082 ,1746 ,1316 ,1907 ,473 ,37 ,220 ,1512 ,1824 ,1837 , + 1117 ,629 ,306 ,29 ,2037 ,71 ,901 ,1276 ,1144 ,1984 ,564 ,781 ,1693 ,1615 ,2000 ,1540 , + 43 ,1190 ,2039 ,358 ,1468 ,1371 ,1132 ,412 ,826 ,556 ,1174 ,1089 ,649 ,997 ,1476 ,1924 , + 114 ,1981 ,2004 ,1575 ,1562 ,689 ,1445 ,324 ,1835 ,904 ,1500 ,713 ,1785 ,1397 ,757 ,1528 , + 389 ,284 ,959 ,1218 ,752 ,1370 ,1374 ,1077 ,879 ,491 ,1697 ,491 ,19 ,315 ,275 ,970 , + 544 ,1716 ,454 ,1541 ,1317 ,353 ,1622 ,2041 ,479 ,342 ,79 ,1603 ,133 ,1340 ,1050 ,681 , + 609 ,979 ,1676 ,1400 ,187 ,1564 ,1860 ,1954 ,666 ,1581 ,1804 ,1451 ,1415 ,189 ,298 ,1962 , + 624 ,1114 ,2036 ,1941 ,467 ,468 ,101 ,1462 ,1138 ,177 ,349 ,376 ,425 ,130 ,1838 ,63 , + 309 ,809 ,1676 ,885 ,144 ,1722 ,21 ,338 ,630 ,668 ,1691 ,798 ,1310 ,1893 ,429 ,755 , + 191 ,512 ,798 ,685 ,453 ,955 ,2012 ,1253 ,1560 ,1129 ,1275 ,591 ,977 ,1474 ,1662 ,1392 , + 1920 ,142 ,1809 ,1178 ,343 ,1363 ,885 ,1241 ,794 ,1092 ,277 ,151 ,956 ,1976 ,1188 ,528 , + 1152 ,526 ,1957 ,269 ,648 ,1051 ,894 ,219 ,1292 ,1812 ,28 ,825 ,463 ,315 ,476 ,406 , + 760 ,1121 ,337 ,1886 ,503 ,248 ,1023 ,769 ,1549 ,219 ,571 ,1545 ,453 ,1115 ,1039 ,130 , + 1436 ,750 ,1870 ,1455 ,485 ,1850 ,1010 ,1852 ,1324 ,574 ,1941 ,554 ,1741 ,1455 ,163 ,630 , + 950 ,1933 ,1168 ,2004 ,822 ,130 ,1247 ,1318 ,1451 ,392 ,1901 ,805 ,1207 ,114 ,417 ,1733 , + 1082 ,799 ,1076 ,1911 ,781 ,1633 ,1883 ,676 ,1060 ,1943 ,2015 ,451 ,231 ,622 ,1275 ,654 , + 1870 ,972 ,786 ,407 ,1069 ,1834 ,781 ,1573 ,1506 ,1200 ,318 ,1616 ,792 ,1950 ,1507 ,13 , + 1752 ,563 ,1044 ,846 ,1806 ,755 ,1132 ,253 ,1810 ,1506 ,492 ,560 ,917 ,684 ,200 ,1932 , + 1404 ,412 ,2012 ,487 ,935 ,687 ,571 ,1442 ,1252 ,215 ,578 ,577 ,22 ,1538 ,910 ,1148 , + 1333 ,413 ,18 ,532 ,1823 ,1689 ,1786 ,1984 ,37 ,859 ,1316 ,1008 ,1136 ,2026 ,1290 ,743 , + 277 ,1051 ,547 ,1178 ,296 ,297 ,1645 ,105 ,494 ,329 ,622 ,493 ,459 ,887 ,720 ,341 , + 594 ,814 ,724 ,1951 ,1455 ,1402 ,1674 ,1357 ,1712 ,829 ,1467 ,722 ,1145 ,307 ,309 ,126 , + 140 ,218 ,402 ,1307 ,1474 ,1919 ,1375 ,1568 ,93 ,420 ,93 ,401 ,622 ,1092 ,1637 ,1101 , + 110 ,1402 ,55 ,1495 ,1676 ,1913 ,1751 ,195 ,454 ,1681 ,826 ,539 ,1503 ,261 ,387 ,1654 , + 1556 ,1414 ,1251 ,1461 ,1695 ,531 ,1155 ,1552 ,1843 ,1987 ,1758 ,18 ,169 ,1906 ,1156 ,181 , + 438 ,1394 ,1659 ,1811 ,277 ,2031 ,478 ,1620 ,964 ,424 ,858 ,304 ,283 ,1568 ,517 ,2040 , + 1080 ,501 ,1799 ,856 ,1610 ,1257 ,377 ,723 ,408 ,1599 ,913 ,1688 ,952 ,1972 ,1654 ,1999 , + 1884 ,617 ,432 ,1754 ,1486 ,1873 ,123 ,1722 ,1247 ,515 ,1470 ,923 ,1984 ,1446 ,280 ,687 , + 1765 ,1955 ,46 ,81 ,1600 ,1077 ,1325 ,147 ,1138 ,1173 ,1401 ,1101 ,1116 ,1826 ,868 ,1542 , + 531 ,218 ,459 ,1583 ,8 ,1616 ,1355 ,1458 ,1970 ,715 ,1167 ,1219 ,1726 ,1137 ,1174 ,1166 , + 1004 ,349 ,1264 ,1178 ,1574 ,1635 ,342 ,1163 ,337 ,1149 ,1068 ,1965 ,838 ,1937 ,903 ,1190 , + 1173 ,200 ,910 ,377 ,429 ,631 ,776 ,1106 ,126 ,142 ,143 ,1723 ,566 ,1904 ,531 ,2038 , + 1262 ,1739 ,1264 ,1870 ,1878 ,1904 ,475 ,1333 ,873 ,1362 ,1852 ,719 ,490 ,1838 ,1587 ,1213 , + 1380 ,399 ,2016 ,133 ,1784 ,1612 ,1818 ,117 ,95 ,625 ,1239 ,1894 ,585 ,1567 ,1591 ,1984 , + 1660 ,809 ,1527 ,1887 ,1689 ,657 ,946 ,1211 ,449 ,1613 ,1328 ,781 ,447 ,680 ,1074 ,1078 , + 1239 ,481 ,1620 ,1299 ,1780 ,354 ,1779 ,1386 ,863 ,1188 ,1959 ,107 ,390 ,875 ,421 ,436 , + 30 ,1706 ,127 ,252 ,1434 ,927 ,806 ,1705 ,1749 ,672 ,1072 ,1562 ,966 ,1182 ,324 ,1916 , + 1076 ,1150 ,267 ,396 ,1010 ,79 ,1444 ,1316 ,1387 ,899 ,1087 ,875 ,367 ,575 ,1982 ,798 , + 1879 ,1572 ,319 ,1089 ,848 ,380 ,1235 ,293 ,1418 ,982 ,587 ,821 ,1090 ,1752 ,350 ,1398 , + 173 ,50 ,1066 ,94 ,857 ,1462 ,1664 ,1671 ,1305 ,564 ,1334 ,423 ,193 ,1545 ,1215 ,919 , + 1000 ,1552 ,1805 ,1490 ,1577 ,337 ,636 ,555 ,1072 ,898 ,808 ,1989 ,256 ,1034 ,561 ,1202 , + 803 ,122 ,177 ,1903 ,1397 ,1552 ,1750 ,91 ,1543 ,794 ,1632 ,577 ,996 ,153 ,595 ,654 , + 862 ,1859 ,1144 ,1326 ,1485 ,539 ,1709 ,1583 ,19 ,1242 ,838 ,871 ,1622 ,911 ,347 ,310 , + 1423 ,1173 ,192 ,678 ,1085 ,1395 ,1173 ,1283 ,849 ,1946 ,303 ,999 ,1290 ,1579 ,472 ,1491 , + 1026 ,112 ,791 ,1381 ,1390 ,83 ,365 ,1065 ,1047 ,8 ,1660 ,1717 ,1787 ,1554 ,689 ,565 , + 552 ,262 ,502 ,1270 ,964 ,1276 ,1707 ,1763 ,1610 ,265 ,597 ,423 ,1824 ,1522 ,18 ,1424 , + 609 ,520 ,549 ,1321 ,1568 ,1839 ,688 ,1664 ,1929 ,1109 ,328 ,519 ,1499 ,914 ,1411 ,1815 , + 889 ,515 ,1665 ,826 ,841 ,693 ,763 ,1633 ,1091 ,1636 ,1682 ,1610 ,752 ,1566 ,1774 ,1159 , + 919 ,1132 ,1175 ,19 ,456 ,1828 ,967 ,603 ,795 ,828 ,188 ,654 ,291 ,510 ,1740 ,787 , + 513 ,1576 ,293 ,1399 ,973 ,751 ,1101 ,1344 ,1649 ,1699 ,1478 ,554 ,112 ,1411 ,817 ,312 , + 599 ,756 ,802 ,343 ,438 ,1358 ,1376 ,1970 ,449 ,1180 ,1544 ,1674 ,955 ,1030 ,1627 ,1779 , + 638 ,567 ,1191 ,2022 ,1636 ,840 ,353 ,45 ,592 ,1118 ,1711 ,1884 ,1396 ,1928 ,349 ,545 , + 1705 ,54 ,617 ,959 ,250 ,1728 ,2030 ,565 ,1505 ,158 ,2045 ,1393 ,1242 ,1767 ,378 ,1502 , + 521 ,1695 ,1466 ,149 ,959 ,687 ,914 ,1776 ,960 ,1029 ,661 ,788 ,1557 ,1027 ,1721 ,586 , + 2047 ,130 ,1902 ,1283 ,403 ,1225 ,460 ,105 ,489 ,1293 ,1846 ,1499 ,608 ,244 ,1131 ,401 , + 1985 ,1844 ,660 ,1259 ,1586 ,1195 ,1782 ,572 ,455 ,1427 ,1989 ,1905 ,412 ,1784 ,746 ,1060 , + 90 ,1636 ,19 ,914 ,1176 ,1496 ,154 ,168 ,771 ,1722 ,1158 ,1174 ,2022 ,1806 ,1344 ,759 , + 1279 ,1467 ,945 ,666 ,487 ,1409 ,1999 ,1259 ,930 ,45 ,273 ,540 ,1014 ,272 ,1108 ,1605 , + 1223 ,1961 ,401 ,655 ,1065 ,80 ,1652 ,1075 ,1103 ,150 ,949 ,579 ,465 ,1678 ,657 ,1298 , + 702 ,1800 ,396 ,1583 ,1296 ,1974 ,306 ,1366 ,492 ,911 ,1346 ,1259 ,1343 ,1109 ,329 ,1589 , + 1565 ,1882 ,1314 ,353 ,1773 ,251 ,30 ,510 ,781 ,1187 ,961 ,1473 ,1550 ,381 ,63 ,1826 , + 1275 ,1842 ,1138 ,1747 ,751 ,402 ,602 ,167 ,1189 ,54 ,576 ,1974 ,466 ,537 ,805 ,1117 , +}; diff --git a/examples/tts/tts-csm.cpp b/examples/tts/tts-csm.cpp new file mode 100644 index 0000000000000..d9a5ef1102d89 --- /dev/null +++ b/examples/tts/tts-csm.cpp @@ -0,0 +1,479 @@ +#include "ggml.h" +#include "llama.h" +#include "common.h" +#include "log.h" +#include "arg.h" +#include "mimi-model.h" +#include "tts-csm-data.h" + +#include +#include +#include +#include +#include +#include // memcpy and strcmp +#include + +// For more details on how this works, see: https://github.com/ggml-org/llama.cpp/pull/12648 + +static void print_usage(int, char ** argv) { + LOG("\nExample usage:\n"); + LOG("\n By default, model will be downloaded from https://huggingface.co/ggml-org/sesame-csm-1b-GGUF"); + LOG("\n %s -p \"[0]I have a dream that one day every valley shall be exalted\" -o output.wav", argv[0]); + LOG("\n"); + LOG("\n To use a local model, specify the path to the model file:"); + LOG("\n %s -p ... -m sesame-csm-backbone.gguf -mv kyutai-mimi.gguf -o output.wav", argv[0]); + LOG("\n"); + LOG("\n Note: the model need 2 files to run, one ends with '-backbone-.gguf' and the other ends with '-decoder.gguf'"); + LOG("\n"); + LOG("\nPrompt format:"); + LOG("\n Each line must start with speaker ID in square brackets, followed by the text. One turn per line. A full stop is recommended at the end of each turn"); + LOG("\n Example:"); + LOG("\n [0]Hey how are you doing."); + LOG("\n [1]Pretty good, pretty good."); + LOG("\n If you want to enter long text, use -f file.txt to read from file"); + LOG("\n"); +} + +struct speaker_turn { + std::string text; + std::vector audio_embd; // only used for system prompt (speaker reference) processing + size_t n_embd_tokens = 0; +}; + +// split text containing "[N]..." into speaker turns +static std::vector get_speaker_turns(const std::string & input) { + if (input.empty()) { + LOG_ERR("Empty input\n"); + return {}; + } + if (input[0] != '[') { + LOG_ERR("Invalid input format: missing speaker ID\n"); + return {}; + } + std::regex re(R"((\[\d+\][\s\S]*?)(?=\[\d+\]|$))"); + std::smatch match; + std::vector turns; + std::string::const_iterator searchStart(input.cbegin()); + while (std::regex_search(searchStart, input.cend(), match, re)) { + std::string turn_text = match[1].str(); + if (turn_text.empty()) { + continue; + } + // clean up newline, the model is quite sensitive to this + string_replace_all(turn_text, "\n", " "); + turn_text = string_strip(turn_text); + // add turn + speaker_turn turn; + turn.text = turn_text; + turns.push_back(turn); + searchStart = match.suffix().first; + } + return turns; +} + +static speaker_turn get_ref_speaker_turn(const char * text, std::initializer_list & codes, std::vector & codebook) { + const size_t n_embd = 2048; + const size_t n_codes_per_codebook = 2051; + const size_t n_codebooks = 32; + GGML_ASSERT(codebook.size() == n_embd * n_codes_per_codebook * n_codebooks); + GGML_ASSERT(codes.size() % 32 == 0); + + // 1 frame = 32 codes + size_t n_frames = codes.size() / n_codebooks; + speaker_turn turn; + turn.text = text; + turn.audio_embd.reserve((n_frames+1) * n_embd); + turn.n_embd_tokens = n_frames+1; // +1 for EOS frame + + for (size_t i_fr = 0; i_fr <= n_frames; i_fr++) { + std::vector frame_embd_sum(n_embd, 0.0f); + + for (size_t i_cb = 0; i_cb < n_codebooks; i_cb++) { + const size_t code = i_fr == n_frames + ? 0 // insert audio EOS for last pseudo-frame + : codes.begin()[i_fr*n_codebooks + i_cb]; + printf("code %zu: %zu, codebook entry %zu\n", i_cb, code, i_cb*n_codes_per_codebook + code); + float * entry = codebook.data() + i_cb*n_codes_per_codebook*n_embd + code*n_embd; + for (size_t i_embd = 0; i_embd < n_embd; i_embd++) { + frame_embd_sum[i_embd] += entry[i_embd]; + } + } + + turn.audio_embd.insert(turn.audio_embd.end(), frame_embd_sum.begin(), frame_embd_sum.end()); + } + + GGML_ASSERT(turn.audio_embd.size() == (n_frames+1) * n_embd); + return turn; +} + +// sampling with custom n_vocab +// modified version of llama_sampler_sample() +static llama_token sample_token(struct llama_sampler * smpl, const float * logits, int n_vocab) { + std::vector cur; + cur.reserve(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); + } + + llama_token_data_array cur_p = { + /* .data = */ cur.data(), + /* .size = */ cur.size(), + /* .selected = */ -1, + /* .sorted = */ false, + }; + + llama_sampler_apply(smpl, &cur_p); + GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size); + auto token = cur_p.data[cur_p.selected].id; + llama_sampler_accept(smpl, token); + return token; +} + +struct hook_data { + std::vector embd; + std::vector codebook; +}; + +// hook to retrieve the embeddings +static bool ggml_callback(struct ggml_tensor * t, bool ask, void * user_data) { + hook_data * data = (hook_data *) user_data; + + // output_csm_proj is the embeddings output from backbone + // output_audio_embd is the embeddings output from decoder + if (t && (strcmp(t->name, "output_csm_proj") == 0 || strcmp(t->name, "output_audio_embd") == 0)) { + if (ask) return true; + + GGML_ASSERT(t->type == GGML_TYPE_F32); + data->embd.resize(ggml_nelements(t)); + ggml_backend_tensor_get(t, data->embd.data(), 0, ggml_nbytes(t)); + // printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]); + return true; + } + + if (t && strncmp(t->name, "audio_embd.weight", 18) == 0) { + if (ask) return true; + + // printf("%s tensor size: %lld, %lld\n", t->name, t->ne[0], t->ne[1]); + GGML_ASSERT(t->type == GGML_TYPE_F32); + GGML_ASSERT(t->ne[0] == 2048); // backbone embd size + data->codebook.resize(ggml_nelements(t)); + ggml_backend_tensor_get(t, data->codebook.data(), 0, ggml_nbytes(t)); + return true; + } + + return false; +} + +// convenience wrapper around llama_batch to handle memory allocation +struct decode_embd_batch { + std::vector pos; + std::vector n_seq_id; + std::vector seq_id_0; + std::vector seq_ids; + std::vector logits; + llama_batch batch; + decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { + pos .resize(n_tokens); + n_seq_id.resize(n_tokens); + seq_ids .resize(n_tokens + 1); + logits .resize(n_tokens); + seq_id_0.resize(1); + seq_id_0[0] = seq_id; + seq_ids [n_tokens] = nullptr; + batch = { + /*n_tokens =*/ n_tokens, + /*tokens =*/ nullptr, + /*embd =*/ embd, + /*pos =*/ pos.data(), + /*n_seq_id =*/ n_seq_id.data(), + /*seq_id =*/ seq_ids.data(), + /*logits =*/ logits.data(), + }; + for (int i = 0; i < n_tokens; i++) { + batch.pos [i] = pos_0 + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i] = seq_id_0.data(); + batch.logits [i] = false; + } + } +}; + +int main(int argc, char ** argv) { + common_params params; + + params.model.path = "sesame-csm-backbone.gguf"; + params.vocoder.model.path = "kyutai-mimi.gguf"; + params.out_file = "output.wav"; + params.prompt = ""; + params.n_predict = 2048; // CSM's max trained seq length + params.sampling.top_k = 50; // default param from CSM python code + params.sampling.temp = 0.9; // default param from CSM python code + + // HF model (hack: we temporary reuse speculative.model as the decoder model, only to get it downloaded) + params.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-backbone.gguf"; + params.speculative.model.path = "sesame-csm-decoder.gguf"; + params.speculative.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/sesame-csm-decoder.gguf"; + params.vocoder.model.url = "https://huggingface.co/ggml-org/sesame-csm-1b-GGUF/resolve/main/kyutai-mimi.gguf"; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, print_usage)) { + return 1; + } + + llama_backend_init(); + llama_numa_init(params.numa); + + if (params.prompt.empty()) { + LOG_ERR("prompt is empty\n"); + return 1; + } + + hook_data cb_data; + params.cb_eval = ggml_callback; + params.cb_eval_user_data = &cb_data; + + common_params params_decoder(params); // duplicate the params + params_decoder.n_ctx = 64; // we never use more than this + string_replace_all(params_decoder.model.path, "-backbone", "-decoder"); + string_replace_all(params_decoder.model.url, "-backbone", "-decoder"); + + common_init_result llama_backbone = common_init_from_params(params); + llama_model * model_bb = llama_backbone.model.get(); + llama_context * ctx_bb = llama_backbone.context.get(); + + common_init_result llama_decoder = common_init_from_params(params_decoder); + llama_model * model_dc = llama_decoder.model.get(); + llama_context * ctx_dc = llama_decoder.context.get(); + + if (model_bb == nullptr || ctx_bb == nullptr) { + return ENOENT; + } + + if (model_dc == nullptr || ctx_dc == nullptr) { + return ENOENT; + } + + mimi_model mimi(params.vocoder.model.path.c_str(), true); + + // init sampler + // the python implementation only has top-k and temperature sampling, so we'll use just that + llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_top_k(params.sampling.top_k)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(params.sampling.temp)); + llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(params.sampling.seed)); + + llama_batch batch_prompt = llama_batch_init(params.n_batch, 0, 1); + llama_pos n_past_bb = 0; + + // inp_past_embd is the "squashed" embeddings from the decoder + std::vector inp_past_embd(2048, 0.0f); + llama_batch batch_past_embd = llama_batch_init(1, inp_past_embd.size(), 1); + + int64_t t_gb_start = ggml_time_ms(); // global start time + int64_t t_bb = 0; // backbone time + int64_t n_bb_gen = 0; // backbone generation count + int64_t t_dc = 0; // decoder time + int64_t n_dc_gen = 0; // decoder generation count + + std::vector generated_codes; + + std::vector turns; + // speaker reference + turns.push_back(get_ref_speaker_turn(default_speaker_a_text, default_speaker_a_codes, cb_data.codebook)); + turns.push_back(get_ref_speaker_turn(default_speaker_b_text, default_speaker_b_codes, cb_data.codebook)); + + // user input + auto custom_turns = get_speaker_turns(params.prompt); + turns.insert(turns.end(), custom_turns.begin(), custom_turns.end()); + + for (speaker_turn & turn : turns) { + // tokenize the turn + llama_tokens prompt_tokens; + { + printf("\n---\n\nturn: %s\n\n", turn.text.c_str()); + const llama_vocab * vocab = llama_model_get_vocab(model_bb); + prompt_tokens = common_tokenize(vocab, turn.text, false, true); + prompt_tokens.insert(prompt_tokens.begin(), llama_vocab_bos(vocab)); + prompt_tokens.insert(prompt_tokens.end(), llama_vocab_eos(vocab)); + + printf("prompt (%zu tokens): \n", prompt_tokens.size()); + for (size_t i = 0; i < prompt_tokens.size(); ++i) { + printf("%d, ", prompt_tokens[i]); + } + printf("\n\n"); + + common_batch_clear(batch_prompt); + for (size_t i = 0; i < prompt_tokens.size(); ++i) { + common_batch_add(batch_prompt, prompt_tokens[i], n_past_bb++, { 0 }, false); + } + batch_prompt.logits[batch_prompt.n_tokens - 1] = true; + + if (llama_decode(ctx_bb, batch_prompt) != 0) { + LOG_ERR("%s: backbone llama_decode(text) failed\n", __func__); + return 1; + } + } + + // optionally process the system prompt (speaker reference) + if (turn.n_embd_tokens) { + decode_embd_batch batch_embd(turn.audio_embd.data(), turn.n_embd_tokens, n_past_bb, 0); + if (llama_decode(ctx_bb, batch_embd.batch) != 0) { + LOG_ERR("%s: backbone llama_decode(embeddings) failed\n", __func__); + return 1; + } + LOG_INF("%s: backbone done decoding %zu audio codes\n\n", __func__, turn.n_embd_tokens); + n_past_bb += turn.n_embd_tokens; + continue; // no need to generate the audio + } + + // backbone generation loop + bool is_end_of_turn = false; + for (int k = 0; k < params.n_predict; ++k) { + bool is_first_tok = k == 0; + + if (!is_first_tok) { + // generate the next RVQ semantic token + batch_past_embd.n_tokens = 1; + batch_past_embd.pos[0] = n_past_bb++; + batch_past_embd.seq_id[0][0] = 0; + batch_past_embd.n_seq_id[0] = 1; + batch_past_embd.logits[0] = true; + std::memcpy(batch_past_embd.embd, inp_past_embd.data(), inp_past_embd.size() * sizeof(float)); + + int64_t t_bb_start = ggml_time_ms(); + if (llama_decode(ctx_bb, batch_past_embd) != 0) { + LOG_ERR("%s: backbone llama_decode() failed\n", __func__); + return 1; + } + n_bb_gen++; + t_bb += ggml_time_ms() - t_bb_start; + } + + if (is_end_of_turn) { + // done decoding audio's EOS token + break; + } + + auto vocab_dc = llama_model_get_vocab(model_dc); + auto logits = llama_get_logits_ith(ctx_bb, is_first_tok ? (batch_prompt.n_tokens - 1) : 0); + // for (size_t i = 0; i < 10; ++i) { + // printf("%4.2f, ", logits[i]); + // } + // printf("\n"); + + llama_token semantic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); + printf("Sem token %5d : %d,", 1+(int)generated_codes.size()/32, semantic_tok); + generated_codes.push_back(semantic_tok); + + // for (size_t i = 0; i < 10; ++i) { + // printf("%4.2f, ", embd[i]); + // } + // printf("\n"); + + + // decoder generation loop + inp_past_embd = std::vector(inp_past_embd.size(), 0.0f); + { + llama_kv_self_clear(ctx_dc); + llama_batch batch_embd = llama_batch_init(1, cb_data.embd.size(), 1); + llama_batch batch_token = llama_batch_init(1, 0, 1); + + // first "token" is the latent embeddings from backbone + { + batch_embd.n_tokens = 1; + batch_embd.pos[0] = 0; + batch_embd.seq_id[0][0] = 0; + batch_embd.n_seq_id[0] = 1; + batch_embd.logits[0] = false; + std::memcpy(batch_embd.embd, cb_data.embd.data(), cb_data.embd.size() * sizeof(float)); + } + if (llama_decode(ctx_dc, batch_embd) != 0) { + LOG_ERR("%s: decoder llama_decode(embd) failed\n", __func__); + return 1; + } + + // then, decode the semantic_tok to generate acoustic tokens + llama_token tok = semantic_tok; + int n_codes = 32; + int sum_codes = semantic_tok; // to check if all codes are 0 + for (int i = 0; i < n_codes; ++i) { + common_batch_clear(batch_token); + // encoder vocab is further divided into 32 codebooks, each with 2051 entries + llama_token inp_tok = tok + 2051*i; + common_batch_add(batch_token, inp_tok, i+1, { 0 }, true); + + int64_t t_bb_start = ggml_time_ms(); + if (llama_decode(ctx_dc, batch_token) != 0) { + LOG_ERR("%s: decoder llama_decode(token) failed\n", __func__); + return 1; + } + n_dc_gen++; + t_dc += ggml_time_ms() - t_bb_start; + + // sample the acoustic token + auto logits = llama_get_logits_ith(ctx_dc, 0); + llama_token acoustic_tok = sample_token(sampler.get(), logits, llama_vocab_n_tokens(vocab_dc)); + + // discard last code (only for embeddings) + if (i < n_codes - 1) { + printf("%d,", acoustic_tok); + tok = acoustic_tok; // next input token + sum_codes += acoustic_tok; + generated_codes.push_back(acoustic_tok); + } + + // do progressive hsum of embeddings + GGML_ASSERT(inp_past_embd.size() == cb_data.embd.size()); + for (size_t i = 0; i < inp_past_embd.size(); ++i) { + inp_past_embd[i] += cb_data.embd[i]; + } + } + printf("\n"); + + llama_batch_free(batch_embd); + llama_batch_free(batch_token); + + // if all codes are 0, then we are done (got audio EOS token) + // note: we still need to run backbone decode one more time to decode the audio's EOS token + is_end_of_turn = sum_codes == 0; + if (is_end_of_turn) { + // remove last 32 codes since they will be all zeros + generated_codes.resize(generated_codes.size() - 32); + } + } + + // printf("inp_past_embd, n_past_bb = %d\n", n_past_bb); + // for (size_t i = 0; i < inp_past_embd.size(); ++i) { + // printf("%4.4f, ", inp_past_embd[i]); + // if (i == 2) { + // printf("... "); + // i = inp_past_embd.size() - 4; + // } + // } + // printf("\n"); + } + } + + // print timing info + printf("\ntimings:\n"); + printf(" backbone: %" PRId64 " ms, %" PRId64 " generated token (%.2f tok/s)\n", t_bb, n_bb_gen, (float)n_bb_gen*1000/(float)t_bb); + printf(" decoder: %" PRId64 " ms, %" PRId64 " generated token (%.2f tok/s)\n", t_dc, n_dc_gen, (float)n_dc_gen*1000/(float)t_dc); + printf(" total: %" PRId64 " ms\n\n", ggml_time_ms() - t_gb_start); + + llama_batch_free(batch_prompt); + llama_batch_free(batch_past_embd); + + printf("decode %zu RVQ tokens into wav...\n", generated_codes.size()); + std::vector wav_data = mimi.decode(generated_codes); + + printf("output wav file: %s\n", params.out_file.c_str()); + + if (!save_wav16(params.out_file.c_str(), wav_data, mimi.get_sample_rate())) { + LOG_ERR("Failed to save wav file\n"); + return 1; + } + + printf("\n"); + + return 0; +} diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 0f047986965f8..e5e0dd4573fda 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -71,46 +71,6 @@ static void print_usage(int, char ** argv) { LOG("\n"); } -struct wav_header { - char riff[4] = {'R', 'I', 'F', 'F'}; - uint32_t chunk_size; - char wave[4] = {'W', 'A', 'V', 'E'}; - char fmt[4] = {'f', 'm', 't', ' '}; - uint32_t fmt_chunk_size = 16; - uint16_t audio_format = 1; // PCM - uint16_t num_channels = 1; // Mono - uint32_t sample_rate; - uint32_t byte_rate; - uint16_t block_align; - uint16_t bits_per_sample = 16; - char data[4] = {'d', 'a', 't', 'a'}; - uint32_t data_size; -}; - -static bool save_wav16(const std::string & fname, const std::vector & data, int sample_rate) { - std::ofstream file(fname, std::ios::binary); - if (!file) { - LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str()); - return false; - } - - wav_header header; - header.sample_rate = sample_rate; - header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8); - header.block_align = header.num_channels * (header.bits_per_sample / 8); - header.data_size = data.size() * (header.bits_per_sample / 8); - header.chunk_size = 36 + header.data_size; - - file.write(reinterpret_cast(&header), sizeof(header)); - - for (const auto & sample : data) { - int16_t pcm_sample = static_cast(std::clamp(sample * 32767.0, -32768.0, 32767.0)); - file.write(reinterpret_cast(&pcm_sample), sizeof(pcm_sample)); - } - - return file.good(); -} - static void fill_hann_window(int length, bool periodic, float * output) { int offset = -1; if (periodic) { diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 62e1480bb5881..c3885c41c1fa1 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -6,6 +6,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_LLAMA, "llama" }, + { LLM_ARCH_LLAMA_CSM, "llama-csm" }, { LLM_ARCH_LLAMA4, "llama4" }, { LLM_ARCH_DECI, "deci" }, { LLM_ARCH_FALCON, "falcon" }, @@ -217,27 +218,57 @@ static const std::map> LLM_TENSOR_N { LLM_ARCH_LLAMA, { - { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, - { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, - { LLM_TENSOR_OUTPUT, "output" }, - { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, - { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, - { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, - { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, - { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, - { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, - { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, - { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, - { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, - { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, - { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, - { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, - { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, - { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, - { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, - { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, - { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, - { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, + { + LLM_ARCH_LLAMA_CSM, // like LLM_ARCH_LLAMA, but with extra tensors for Sesame CSM + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" }, + { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, + { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_CSM_AUDIO_EMBD, "audio_embd" }, + { LLM_TENSOR_CSM_CBOOK_OUTPUT, "codebook0_head" }, + { LLM_TENSOR_CSM_AUDIO_OUTPUT, "audio_head" }, + { LLM_TENSOR_CSM_PROJ, "csm_proj" }, }, }, { @@ -1676,6 +1707,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_CSM_AUDIO_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_CSM_CBOOK_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CSM_AUDIO_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CSM_PROJ, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL_MAT}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} diff --git a/src/llama-arch.h b/src/llama-arch.h index 98ca00a1bd0b0..cb6ebd50ff377 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -10,6 +10,7 @@ enum llm_arch { LLM_ARCH_LLAMA, + LLM_ARCH_LLAMA_CSM, LLM_ARCH_LLAMA4, LLM_ARCH_DECI, LLM_ARCH_FALCON, @@ -360,6 +361,10 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_CSM_AUDIO_EMBD, + LLM_TENSOR_CSM_CBOOK_OUTPUT, + LLM_TENSOR_CSM_AUDIO_OUTPUT, + LLM_TENSOR_CSM_PROJ, }; enum llm_tensor_layer { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6b7bfecf3a1cf..cd549e986c2a9 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -508,7 +508,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); - if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON) { + if (arch == LLM_ARCH_LLAMA + || arch == LLM_ARCH_LLAMA_CSM + || arch == LLM_ARCH_DECI + || arch == LLM_ARCH_FALCON + ) { if (hparams.n_rot != hparams.n_embd_head_k) { throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); } @@ -526,6 +530,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // arch-specific KVs switch (arch) { case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA_CSM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1738,6 +1743,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } } break; + case LLM_ARCH_LLAMA_CSM: + { + // TODO: maybe store these in gguf metadata + int64_t csm_audio_cbook_size = 2051; // audio codebook size + int64_t csm_audio_tokens = 32; // equal to number of audio tokens for Mimi + //int64_t csm_n_audio_vocab = csm_audio_cbook_size*csm_acoustic_tokens; + + csm_output_cbook = create_tensor(tn(LLM_TENSOR_CSM_CBOOK_OUTPUT, "weight"), {n_embd, csm_audio_cbook_size}, TENSOR_NOT_REQUIRED); + + bool is_backbone = csm_output_cbook != nullptr; + + csm_output_audio = is_backbone ? nullptr + : create_tensor(tn(LLM_TENSOR_CSM_AUDIO_OUTPUT, "weight"), {n_embd, csm_audio_cbook_size, csm_audio_tokens+1}, 0); + + tok_embd = is_backbone + ? create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0) + : create_tensor(tn(LLM_TENSOR_CSM_AUDIO_EMBD, "weight"), {n_embd*2, n_vocab}, 0); + + csm_proj = is_backbone + ? create_tensor(tn(LLM_TENSOR_CSM_PROJ, "weight"), {n_embd, n_embd/2}, 0) + : create_tensor(tn(LLM_TENSOR_CSM_PROJ, "weight"), {n_embd*2, n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output tensor is either audio or code depends on backbone / decoder + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_LLAMA4: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -1765,6 +1812,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); @@ -4625,7 +4675,7 @@ struct llm_build_llama : public llm_graph_context { cb(cur, "result_norm", -1); res->t_embd = cur; - // lm_head + // lm_head (normal case) cur = build_lora_mm(model.output, cur); // For Granite architecture @@ -4640,6 +4690,192 @@ struct llm_build_llama : public llm_graph_context { } }; +// llama used by Sesame CSM +struct llm_build_llama_csm : public llm_graph_context { + llm_build_llama_csm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + bool is_backbone = model.csm_output_cbook; + bool is_decoder = !is_backbone; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // hacky way to get the audio embedding from user code (used in prompt processing) + // this will be triggered during warmup + if (is_decoder && n_tokens == 2) { + ggml_tensor * tmp = ggml_cast(ctx0, model.tok_embd, GGML_TYPE_F32); + cb(tmp, "audio_embd.weight", -1); + ggml_build_forward_expand(gf, tmp); + } + + ggml_tensor * input_embd = inpL; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + if (is_decoder && inpL->ne[0] != hparams.n_embd) { + inpL = build_lora_mm(model.csm_proj, inpL); + } + + auto * inp_attn = build_attn_inp_kv_unified(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, gf, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // For Granite architecture + if (hparams.f_residual_scale) { + cur = ggml_scale(ctx0, cur, hparams.f_residual_scale); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + if (model.csm_output_cbook) { + // Sesame csm backbone + // hack: because n_cbook < n_vocab, we use the first logits for the output + int64_t n_vocab = model.tok_embd->ne[1]; + int64_t n_codes = model.csm_output_cbook->ne[1]; + ggml_tensor * last_h = cur; + cur = build_lora_mm(model.csm_output_cbook, cur); + cur = ggml_pad(ctx0, cur, n_vocab - n_codes, 0, 0, 0); + + // project to csm decoder dim + last_h = build_lora_mm(model.csm_proj, last_h); + cb(last_h, "output_csm_proj", -1); // use callback to retrieve the result + ggml_build_forward_expand(gf, last_h); + + } else if (model.csm_output_audio && ggml_nelements(cur)) { + // Sesame csm decoder + // hack: because n_audio < n_vocab, we use the first logits for the output + cur = build_lora_mm_id(model.csm_output_audio, cur, inp_pos); + int64_t n_vocab = model.tok_embd->ne[1]; + int64_t n_codes = cur->ne[0]; + cur = ggml_pad(ctx0, cur, n_vocab - n_codes, cur->ne[1], 0, 0); + + // also get audio embeddings, which will be passed back to backbone to keep track of generation progress + if (ubatch.token) { + cb(input_embd, "output_audio_embd", -1); + ggml_build_forward_expand(gf, input_embd); + } + + } else { + // otherwise, dummy output + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_deci : public llm_graph_context { llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -12815,6 +13051,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_LLAMA_CSM: + { + llm = std::make_unique(*this, params, gf); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params, gf); @@ -13170,6 +13410,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { // use what we call a normal RoPE, operating on pairs of consecutive head values case LLM_ARCH_LLAMA: + case LLM_ARCH_LLAMA_CSM: case LLM_ARCH_LLAMA4: case LLM_ARCH_DECI: case LLM_ARCH_BAICHUAN: diff --git a/src/llama-model.h b/src/llama-model.h index fd82d106ccda8..1527c1ea7705c 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -342,6 +342,11 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + // sesame csm + struct ggml_tensor * csm_output_cbook = nullptr; // backbone output codebook + struct ggml_tensor * csm_output_audio = nullptr; // audio decoder output + struct ggml_tensor * csm_proj = nullptr; // to convert backbone dim to decoder dim + std::vector layers; llama_model_params params;