An implementation of the CSM(Conversation Speech Model) for Apple Silicon using MLX.
And fine-tuning too!
Note that the following installs the Python library only, not the command-line tool. Please refer to the CLI section below for that.
Recommendation: Give uv a try. It's truly magical.
```shell
uv add git+https://github.com/senstella/csm-mlx --upgrade
```

Or, you can install it via pip:

```shell
pip install git+https://github.com/senstella/csm-mlx --upgrade
```

Make sure to use Python < 3.13; sentencepiece tends to hit a compiler error on newer versions.
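Since the constraint is just the interpreter version, a quick stdlib check (a hypothetical helper, not part of csm-mlx) can confirm compatibility before you install:

```python
import sys

def supported(version_info: tuple) -> bool:
    # sentencepiece tends to fail to build on Python 3.13+, so stay below it.
    return tuple(version_info[:2]) < (3, 13)

# Check the running interpreter
print("OK for csm-mlx" if supported(sys.version_info) else "Use Python < 3.13")
```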
Basic generation:

```python
from mlx_lm.sample_utils import make_sampler
from huggingface_hub import hf_hub_download
from csm_mlx import CSM, csm_1b, generate
import audiofile
import numpy as np

# Initialize the model
csm = CSM(csm_1b())  # csm_1b() is a configuration for the CSM model.
weight = hf_hub_download(repo_id="senstella/csm-1b-mlx", filename="ckpt.safetensors")
csm.load_weights(weight)

# Generate audio from text
audio = generate(
    csm,
    text="Hello from Sesame.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
    # Put mlx_lm's sampler here! Supports: temp, top_p, min_p, min_tokens_to_keep, top_k.
    sampler=make_sampler(temp=0.8, top_k=50),
    # Additionally, you can provide a `stream` argument to specify which device to use for generation:
    # https://ml-explore.github.io/mlx/build/html/usage/using_streams.html
)

audiofile.write("./audio.wav", np.asarray(audio), 24000)
```

Generating with conversation context:

```python
from huggingface_hub import hf_hub_download
from csm_mlx import CSM, csm_1b, generate, Segment
import mlx.core as mx

# Initialize the model
csm = CSM(csm_1b())
weight = hf_hub_download(repo_id="senstella/csm-1b-mlx", filename="ckpt.safetensors")
csm.load_weights(weight)

# Create previous conversation segments
context = [
    Segment(
        speaker=0,
        text="How are you doing today?",
        audio=mx.array(...)  # Previous audio for this segment
    ),
    Segment(
        speaker=1,
        text="I'm doing great, thank you!",
        audio_path="~/somewhere_in_the_universe/stuff.wav"  # Or you can specify the audio path too!
    )
]

# Generate a response in the conversation
audio = generate(
    csm,
    text="That's wonderful to hear!",
    speaker=0,
    context=context,
    max_audio_length_ms=5_000
    # If you don't provide any sampler, greedy sampling will be used.
)
```

Quantization:

```python
from mlx import nn
from mlx_lm.sample_utils import make_sampler, make_logits_processors
from huggingface_hub import hf_hub_download
from csm_mlx import CSM, csm_1b, generate
import audiofile
import numpy as np

# Initialize the model
csm = CSM(csm_1b())  # csm_1b() is a configuration for the CSM model.
weight = hf_hub_download(repo_id="senstella/csm-1b-mlx", filename="ckpt.safetensors")
csm.load_weights(weight)

# Quantize. If you want, specify `group_size` and `bits`:
# https://ml-explore.github.io/mlx/build/html/python/_autosummary/mlx.core.quantize.html
nn.quantize(csm)

# Generate audio from text
audio = generate(
    csm,
    text="Hello from Sesame.",
    speaker=0,
    context=[],
    max_audio_length_ms=10_000,
    # Put mlx_lm's sampler here! Supports: temp, top_p, min_p, min_tokens_to_keep, top_k.
    sampler=make_sampler(temp=0.8, top_k=50),
    # If the model doesn't end the sequence properly, you can try adjusting the sampler parameters
    # or setting an eos_token logit bias via logits_processors=make_logits_processors(logit_bias={0: 4}).
    # Additionally, you can provide a `stream` argument to specify which device to use for generation:
    # https://ml-explore.github.io/mlx/build/html/usage/using_streams.html
)

audiofile.write("./audio.wav", np.asarray(audio), 24000)
```

Streaming generation:

```python
from mlx import nn
from mlx_lm.sample_utils import make_sampler
from csm_mlx import CSM, csm_1b
from csm_mlx.generation import stream_generate

# Initialize the model
csm = CSM(csm_1b())
csm.load_weights("ckpt.safetensors", strict=True)
# nn.quantize(csm)  # Uncomment to speed up to nearly real-time on an M2 Air, at some quality cost.

# Stream generated audio chunks
for chunk in stream_generate(
    csm,
    text="This is an example of streaming audio generation.",
    speaker=0,
    context=[],
    max_audio_length_ms=5000,
    # Accumulate codebooks and decode at once (increasing = better speed/stability, more initial delay)
    accumulation_size=1,
    sampler=make_sampler(temp=0.8, top_k=50),
):
    ...  # Process each chunk as it's generated
```

If you want to load an audio file for a segment, you need to resample it to 24000 Hz.
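To see what the resample step does to a clip, here is the sample-count arithmetic with a naive linear-interpolation sketch in NumPy. This is only an illustration: a real resampler like audresample also applies an anti-aliasing filter.

```python
import numpy as np

def naive_resample(signal: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
    # Linear interpolation only; illustrates the sample-count change,
    # not production-quality resampling.
    new_len = int(round(len(signal) * target_sr / orig_sr))
    old_t = np.linspace(0.0, 1.0, num=len(signal), endpoint=False)
    new_t = np.linspace(0.0, 1.0, num=new_len, endpoint=False)
    return np.interp(new_t, old_t, signal)

# Two seconds of audio at 44.1 kHz becomes 48000 samples at 24 kHz
clip = np.sin(np.linspace(0.0, 2 * np.pi * 440 * 2, 2 * 44100))
print(naive_resample(clip, 44100, 24000).shape)  # (48000,)
```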
```python
import mlx.core as mx
import audiofile
import audresample

def read_audio(audio_path, sampling_rate=24000) -> mx.array:
    signal, original_sampling_rate = audiofile.read(audio_path, always_2d=True)
    signal = audresample.resample(signal, original_sampling_rate, sampling_rate)
    signal = mx.array(signal)

    if signal.shape[0] > 1:
        signal = signal.mean(axis=0)  # Downmix multi-channel audio to mono
    else:
        signal = signal.squeeze(0)

    return signal  # (audio_length, )
```

If you're getting a `sentencepiece` error on uv, try adding the `--python 3.12` flag to the `uv tool install` command.
```shell
# Recommendation: uv tool works best!
uv tool install "git+https://github.com/senstella/csm-mlx[cli]" --upgrade

# Or with pipx
pipx install "git+https://github.com/senstella/csm-mlx[cli]" --upgrade
```

Fine-tuning CLI usage is here.
```shell
csm-mlx generate "Hello from Sesame." -o output.wav
```

With more options:

```shell
csm-mlx generate "Hello from Sesame." \
  --output output.wav \
  --model 1b \
  --speaker 0 \
  --temperature 0.8 \
  --min-p 0.05 \
  --max-audio-length 10000
```

You can provide conversation context to make the generated speech more natural, or clone a voice with it.
You must provide audio, text, and speaker together as matched sets: one entry of each per context segment.
```shell
csm-mlx generate "Nice to meet you too!" \
  --output response.wav \
  --input-audios previous.wav \
  --input-texts "Hello, nice to meet you." \
  --input-speakers 1
```

Or run it without installing, via uv:

```shell
uv run --with 'git+https://github.com/senstella/csm-mlx[cli]' --python 3.12 \
  python -m csm_mlx "Hello from Sesame." -o output.wav
```

```
csm-mlx generate [TEXT] [OPTIONS]
```

`TEXT`: The text to convert to speech

- `-o, --output PATH`: Output audio file path [required]
- `-m, --model [1b]`: Model size (default: 1b)
- `-s, --speaker INT`: Speaker ID (default: 0)
- `-l, --max-audio-length INT`: Maximum audio length in milliseconds (default: 10000, i.e. 10 seconds)
- `-t, --temperature, --temp FLOAT`: Sampling temperature (default: 0.8)
- `-p, --top-p FLOAT`: Top-p sampling parameter
- `-m, --min-p FLOAT`: Minimum probability for sampling (default: 0.05)
- `-k, --top-k INT`: Top-k sampling parameter
- `-kt, --min-tokens-to-keep INT`: Minimum tokens to keep during sampling (default: 1)
- `-is, --input-speakers LIST`: List of speaker IDs for context
- `-ia, --input-audios LIST`: List of audio files for context
- `-it, --input-texts LIST`: List of text transcripts for context
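To illustrate the matched-set requirement for `--input-audios`, `--input-texts`, and `--input-speakers`, here is a hypothetical sketch of the pairing logic (not the CLI's actual code):

```python
def build_context(audios: list, texts: list, speakers: list) -> list:
    # Each context segment needs all three pieces; mismatched lengths are an error.
    if not (len(audios) == len(texts) == len(speakers)):
        raise ValueError("--input-audios, --input-texts, and --input-speakers "
                         "must have the same number of entries")
    return [
        {"audio_path": audio, "text": text, "speaker": speaker}
        for audio, text, speaker in zip(audios, texts, speakers)
    ]

segments = build_context(["previous.wav"], ["Hello, nice to meet you."], [1])
print(segments[0]["speaker"])  # 1
```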
- Fix up RoPE
- Implement watermarking
- Add streaming generation
- Optimize performance further for real-time inference
- Thanks to Sesame for the original PyTorch implementation and weights!
- Thanks to the Moshi project for creating the mimi codec and the `mimi_mlx` implementation.
- Thanks to the torchtune project for providing the LLaMA attention implementation.
- Thanks to MLX project for providing the framework that made this implementation possible.
- Thanks to typer for powering the CLI interface.
- Thanks to audiofile and audresample for audio processing.
Apache 2.0