MLX Distributed Training (Beta)

A privacy-first distributed training framework built on MLX for Apple Silicon, enabling secure and efficient AI model training across multiple devices while preserving data privacy.


What We're Building

We're training a decoder-only transformer model from scratch, optimized for Apple Silicon:

  • Architecture: Decoder-only transformer (similar to GPT-4/Llama 3); a configuration sketch follows this list

    • 22 transformer layers
    • 2048 embedding dimensions
    • 16 attention heads
    • 8192 max sequence length
    • Training optimizations: Flash Attention, Grouped Query Attention (GQA), RoPE embeddings, SwiGLU activations
  • Goal: Train a competitive 1B parameter model that can match or exceed Llama 3.2's performance using distributed consumer hardware instead of traditional GPU clusters. Overall, we're aiming to push the boundaries of what's possible with Apple Silicon and see how performance scales with increasing model size on consumer hardware.
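
The architecture above maps to a small set of hyperparameters. The sketch below is illustrative only: the dataclass and field names, the GQA grouping, the vocabulary size, and the feed-forward width are assumptions not stated in the spec; only the layer count, embedding dimension, head count, and sequence length come from the list above. The rough parameter count shows that such a configuration lands near the stated 1B target.

# Hypothetical configuration sketch; names, GQA grouping, vocab size, and FFN width are assumptions.
from dataclasses import dataclass

@dataclass
class ModelConfig:
    n_layers: int = 22        # transformer layers (from the spec above)
    d_model: int = 2048       # embedding dimensions
    n_heads: int = 16         # attention heads
    n_kv_heads: int = 4       # assumed GQA grouping; not stated above
    max_seq_len: int = 8192   # maximum sequence length
    vocab_size: int = 32_000  # assumed; depends on the tokenizer
    d_ff: int = 5632          # assumed SwiGLU hidden width (~2.75x d_model, Llama-style)

def approx_params(cfg: ModelConfig) -> float:
    """Back-of-the-envelope parameter count, in billions."""
    head_dim = cfg.d_model // cfg.n_heads
    attn = cfg.d_model * (2 * cfg.d_model + 2 * cfg.n_kv_heads * head_dim)  # Q, O + GQA-shrunk K, V
    mlp = 3 * cfg.d_model * cfg.d_ff                                        # SwiGLU: gate, up, down
    emb = cfg.vocab_size * cfg.d_model                                      # tied embeddings
    return (cfg.n_layers * (attn + mlp) + emb) / 1e9

print(f"~{approx_params(ModelConfig()):.2f}B parameters")  # comes out near the 1B goal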

System Architecture

Features

  • Privacy-First: All training happens on your devices, keeping sensitive data under your control
  • Efficient: Optimized for Apple Silicon using MLX, enabling fast training on consumer hardware
  • Distributed: Scale training across multiple Macs for better performance
  • Flexible: Support for various model architectures and training configurations

Introduction to Distributed Training with MLX

This project explores the potential of distributed training on Apple Silicon, specifically targeting the development of large language models. By leveraging MLX's distributed communication framework, we're pushing the boundaries of what's possible with consumer hardware.

The primary goal is ambitious yet practical: train a 1B parameter model on a network of Mac devices that matches or exceeds state-of-the-art results from comparable models such as Llama 3.2. Traditional approaches to training models at this scale typically require expensive cloud resources or specialized hardware. This implementation demonstrates that, with efficient distributed algorithms and Apple's unified memory architecture, comparable results can be achieved using devices many developers already own.

This framework is designed for ML engineers and researchers interested in:

  • Implementing and optimizing distributed training systems
  • Exploring novel approaches to model parallelism and gradient synchronization
  • Understanding the practical aspects of training large language models
  • Contributing to the advancement of decentralized ML infrastructure

Why MLX for Distributed Training?

After extensive experimentation with various frameworks, MLX emerged as the optimal choice for distributed training on Apple Silicon for several compelling reasons:

  1. Native Silicon Architecture Integration

    • Direct compilation to Metal, maximizing M-series chip performance
    • Seamless utilization of unified memory shared by the CPU and GPU
    • Optimized memory bandwidth and computational throughput
    • Performance that consistently outpaces traditional frameworks on Apple hardware
  2. Advanced Communication Architecture (see the sketch after this list)

    • High-efficiency MPI-based inter-device communication
    • Zero-copy gradient synchronization through optimized all-reduce operations
    • Network stack specifically tuned for Apple's hardware ecosystem
    • Minimal overhead in multi-device coordination
  3. Sophisticated Memory Management

    • Leverages unified memory architecture for optimal resource utilization
    • Implements dynamic batch size adjustment based on device capabilities
    • Advanced gradient checkpointing for memory-constrained scenarios
    • Comprehensive monitoring and profiling capabilities
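
As a concrete illustration of the communication layer described in item 2, the following minimal sketch uses MLX's distributed API (mx.distributed.init and mx.distributed.all_sum) to average an array across processes. It is a generic MLX example, not code from this repository, and the launch command at the bottom assumes an MPI installation such as the one set up in the Installation section.

# Minimal MLX distributed example: average an array across all ranks.
import mlx.core as mx

group = mx.distributed.init()        # joins the process group this script was launched into
rank, size = group.rank(), group.size()

x = mx.ones((4,)) * rank             # each rank contributes a different value
total = mx.distributed.all_sum(x)    # all-reduce: every rank receives the elementwise sum
mean = total / size                  # the same pattern is used to average gradients
print(f"rank {rank}/{size}: {mean}")

# Example launch on a single machine (multi-host launches additionally need a hostfile,
# whose flag depends on the MPI implementation):
#   mpirun -np 2 python hello_distributed.py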

Our research and development focus on several key areas:

  • Scaling transformer architectures to 1B-3B parameters across distributed Mac systems
  • Implementing novel data streaming and caching strategies
  • Exploring hybrid parallelism techniques (data, model, and pipeline)
  • Developing robust distributed training protocols

This project serves as both a practical implementation and a research platform, enabling the ML community to explore distributed training techniques without the traditional barriers to entry. We welcome contributions from engineers and researchers interested in advancing the field of distributed ML training.

Installation

System Requirements

  • macOS Sonoma 14.0+ (Apple Silicon)
  • Python 3.11+
  • MLX 0.20.0+
  • High-speed network connection (10Gbps recommended)
  • SSH access configured between devices

Setup and Installation

# Install system dependencies
xcode-select --install
brew install mpich

# Clone repository
git clone https://github.com/jbarnes850/mlx_distributed
cd mlx_distributed

# Create virtual environment
python3 -m venv .venv
source .venv/bin/activate

# Install dependencies
pip install -e ".[dev]"

# Verify setup
python scripts/verify_setup.py
python scripts/test_network.py
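
If scripts/verify_setup.py reports problems, the core requirements can also be checked by hand with a few lines of generic Python (this is not the project's verification script):

# Quick manual environment check (run inside the activated virtual environment)
import platform
import mlx.core as mx

print("python:", platform.python_version())  # expect 3.11+
print("macOS: ", platform.mac_ver()[0])      # expect 14.0+
print("mlx:   ", mx.__version__)             # expect 0.20.0+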

Start Training

# On primary device (e.g., Mac Studio M2 Ultra)
./scripts/start_training.sh --role primary

# On secondary device (e.g., MacBook M3 Max)
./scripts/start_training.sh --role secondary

Monitor Progress

# Open dashboard
open http://localhost:8050

# Watch logs
tail -f logs/training.log

Network Requirements

  • High-speed connection (10Gbps+ recommended)
  • Low latency (<1ms between devices)
  • SSH access configured between devices

Documentation

Implementation Details

Our distributed training implementation follows MLX's recommended practices:

  1. Data Parallelism (a minimal sketch follows this list):

    • Each device maintains a complete model copy
    • Data is sharded across devices
    • Gradients synchronized using mx.distributed.all_sum
    • Weights broadcast periodically for consistency
  2. Memory Management:

    • Dynamic batch sizing based on device capabilities
    • Gradient accumulation for effective larger batches
    • Activation checkpointing for memory efficiency
    • Streaming data loading to manage memory usage
  3. Performance Optimization:

    • Mixed precision training
    • Separate compute/memory streams
    • Flash Attention implementation
    • Grouped Query Attention (GQA)
    • Optimized memory layout
  4. Monitoring and Recovery:

    • Real-time performance dashboard
    • Automatic error recovery
    • Checkpoint management
    • Network health monitoring
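
The data-parallel step in item 1 can be sketched with MLX's public API as shown below. This is a simplified, self-contained illustration (a stand-in linear model, no gradient accumulation, mixed precision, or checkpointing), not this repository's training loop; the gradient averaging via mx.distributed.all_sum is the key pattern.

# Simplified data-parallel training step: each rank computes gradients on its own
# shard of the batch, gradients are summed across ranks and divided by the world
# size, then every rank applies the same update.
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map

group = mx.distributed.init()
world_size = group.size()

model = nn.Linear(16, 4)                 # stand-in for the real transformer
optimizer = optim.AdamW(learning_rate=1e-3)

def loss_fn(model, inputs, targets):
    return nn.losses.cross_entropy(model(inputs), targets, reduction="mean")

loss_and_grad = nn.value_and_grad(model, loss_fn)

def train_step(inputs, targets):
    loss, grads = loss_and_grad(model, inputs, targets)
    # Average gradients across devices before the optimizer update.
    grads = tree_map(lambda g: mx.distributed.all_sum(g) / world_size, grads)
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state, loss)
    return loss

# Each rank trains on its own shard of the data.
inputs = mx.random.normal((8, 16))       # local shard of the batch
targets = mx.random.randint(0, 4, (8,))  # integer class labels
print(f"rank {group.rank()}: loss = {train_step(inputs, targets).item()}")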

For more details on MLX's distributed capabilities, see the MLX documentation on distributed communication.

Troubleshooting

Common Issues

  1. Network Communication Errors

    • Verify SSH keys are properly configured between devices
    • Check network bandwidth using scripts/test_network.py
    • Ensure all devices are on the same subnet
    • Try reducing batch_size if experiencing timeouts
  2. Memory Issues

    • Enable gradient checkpointing in config
    • Reduce model size or batch size
    • Monitor memory usage with dashboard
    • Use streaming dataset loading
  3. Performance Problems

    • Verify Metal is properly configured (see the diagnostic sketch after this list)
    • Check CPU/GPU utilization
    • Monitor network bandwidth
    • Adjust number of worker processes
  4. Installation Issues

    • Verify Python version compatibility
    • Check MLX installation
    • Review system requirements
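
For the memory and performance checks above, a few generic MLX calls can confirm that Metal is visible and show how much memory MLX is using. This is not the project's dashboard or test scripts, and these metal helpers may move between modules in future MLX releases:

# Quick Metal and memory diagnostics (generic MLX; adjust the workload to taste)
import mlx.core as mx

print("Metal available:", mx.metal.is_available())

a = mx.random.normal((4096, 4096))
mx.eval(a @ a)                                        # run a small GPU workload
print("active memory (bytes):", mx.metal.get_active_memory())
print("peak memory (bytes):  ", mx.metal.get_peak_memory())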

For more detailed troubleshooting:

Performance Tuning

For detailed information about our hardware configuration, training process, and performance optimizations, please see our Performance Tuning Guide. This guide includes:

  • Current hardware specifications and configurations
  • Training time estimates and comparisons
  • Detailed performance optimization strategies
  • Memory management techniques
  • Monitoring and stability measures

Contributing

  1. Fork the repository
  2. Create your feature branch
  3. Commit your changes
  4. Push to the branch
  5. Create a Pull Request

License

MIT License - See LICENSE for details.
