One model. Every language. Zero friction.

The Universal Deep Learning Bridge

Shrew is a runtime and DSL that decouples your deep learning models from any single language or framework. Define a model once in .sw, then train and deploy it from Python, Rust, JavaScript, C++, or any other language in your stack.

Python · Rust · C/C++ · JavaScript · WASM · Java

Features

The runtime powering the bridge

A Rust-powered runtime with tensors, autograd, GPU acceleration, and a JIT compiler. The engine behind the .sw DSL — so your models run everywhere without being tied to any single technology.

Reverse-Mode Autograd

Eager automatic differentiation: each op records itself into the graph, and backward() performs a topological sort and applies the chain rule. Gradient paths cover matmul, reshape, transpose, affine, cat, and 30+ ops.

Full Neural Network Stack

Linear, Conv1d/2d, RNN, LSTM, GRU, MultiHeadAttention, TransformerBlock, BatchNorm2d, LayerNorm, GroupNorm, RMSNorm, Embedding, Dropout, and 7 loss functions.
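
To see what these layers reduce to, here is a linear layer written by hand using only the tensor ops from the main.rs example below; shrew-nn's Linear presumably packages the same pattern with managed parameters, though its actual constructor is not shown here.

use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let dev = CpuDevice;

    // Parameters: weight [784, 10] and bias [1, 10], tracked by autograd.
    let w = CpuTensor::randn((784, 10), DType::F64, &dev)?.set_variable();
    let b = CpuTensor::randn((1, 10), DType::F64, &dev)?.set_variable();

    // Forward pass for a batch of 64: logits = x @ w + b (bias broadcasts).
    let x = CpuTensor::randn((64, 784), DType::F64, &dev)?;
    let logits = x.matmul(&w)?.add(&b)?;

    // One backward pass yields gradients for both parameters.
    let grads = logits.sum_all()?.backward()?;
    let _dw = grads.get(&w);
    let _db = grads.get(&b);
    Ok(())
}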

CUDA GPU Backend

NVIDIA GPU backend via cudarc with cuBLAS matmul, custom PTX kernels, memory pool with allocation reuse, and mixed-precision training (F16/BF16 ↔ F32).

SIMD-Accelerated CPU

gemm-accelerated matmul (AVX2/AVX-512/FMA), rayon parallel ops, and contiguous fast paths that bypass stride arithmetic, delivering performance comparable to hand-tuned C.

.sw Intermediate Representation

Declarative model specification format with lexer, parser, AST, Graph IR, shape inference, and optimization passes (DCE, CSE, constant folding, operator fusion).
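
As a toy illustration of one of those passes, here is constant folding over a minimal expression type; the Expr enum is invented for this sketch and is not Shrew's actual Graph IR.

#[derive(Debug, Clone)]
enum Expr {
    Const(f64),
    Input(String),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

/// Recursively replace operations on two constants with their result.
fn fold(e: Expr) -> Expr {
    use Expr::*;
    match e {
        Add(a, b) => match (fold(*a), fold(*b)) {
            (Const(x), Const(y)) => Const(x + y),
            (x, y) => Add(Box::new(x), Box::new(y)),
        },
        Mul(a, b) => match (fold(*a), fold(*b)) {
            (Const(x), Const(y)) => Const(x * y),
            (x, y) => Mul(Box::new(x), Box::new(y)),
        },
        other => other,
    }
}

fn main() {
    // (2 * 3) + x folds to 6 + x before execution.
    let e = Expr::Add(
        Box::new(Expr::Mul(
            Box::new(Expr::Const(2.0)),
            Box::new(Expr::Const(3.0)),
        )),
        Box::new(Expr::Input("x".into())),
    );
    println!("{:?}", fold(e)); // Add(Const(6.0), Input("x"))
}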

JIT Compiler

JitExecutor compiles IR graphs into flat instruction tapes with pre-allocated memory slots and value lifetime tracking. No re-interpretation at runtime.
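
The idea of a flat instruction tape with pre-allocated slots, reduced to a toy sketch (illustrative types, not JitExecutor's actual implementation):

/// One instruction reads from and writes to fixed slot indices.
enum Inst {
    Add { a: usize, b: usize, out: usize },
    Mul { a: usize, b: usize, out: usize },
}

/// Execution is a single pass over the tape: no graph traversal,
/// no dispatch on tensor metadata, no allocation.
fn run(tape: &[Inst], slots: &mut [f64]) {
    for inst in tape {
        match *inst {
            Inst::Add { a, b, out } => slots[out] = slots[a] + slots[b],
            Inst::Mul { a, b, out } => slots[out] = slots[a] * slots[b],
        }
    }
}

fn main() {
    // Compute (x + y) * y with x = 2, y = 3 in pre-allocated slots.
    let tape = [
        Inst::Add { a: 0, b: 1, out: 2 },
        Inst::Mul { a: 2, b: 1, out: 3 },
    ];
    let mut slots = [2.0, 3.0, 0.0, 0.0];
    run(&tape, &mut slots);
    assert_eq!(slots[3], 15.0);
}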

Optimizers & Schedulers

SGD (momentum, weight decay), Adam, AdamW, RAdam, RMSProp. LR schedulers: StepLR, CosineAnnealing, CosineWarmup, ReduceLROnPlateau. Gradient clipping + EMA.
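
For reference, the SGD-with-momentum update rule behind the first optimizer above, written out on plain slices (the standard formulation, not Shrew's internal code):

/// v ← momentum · v + (g + weight_decay · w);  w ← w − lr · v
fn sgd_momentum_step(
    w: &mut [f64],
    v: &mut [f64],
    g: &[f64],
    lr: f64,
    momentum: f64,
    weight_decay: f64,
) {
    for i in 0..w.len() {
        // Classic L2 weight decay folded into the gradient.
        let grad = g[i] + weight_decay * w[i];
        v[i] = momentum * v[i] + grad;
        w[i] -= lr * v[i];
    }
}

fn main() {
    let (mut w, mut v) = (vec![1.0], vec![0.0]);
    sgd_momentum_step(&mut w, &mut v, &[0.5], 0.1, 0.9, 0.0);
    assert!((w[0] - 0.95).abs() < 1e-12); // 1.0 − 0.1 · 0.5
}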

Quantization & ONNX

INT8/INT4 post-training quantization (symmetric/asymmetric, per-tensor/per-channel). ONNX export/import (opset 17) with zero-dependency protobuf.
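
The symmetric per-tensor INT8 scheme named above is standard math: scale = max|x| / 127, q = clamp(round(x / scale), -127, 127). A minimal sketch of that scheme, not Shrew's API:

fn quantize_int8(xs: &[f32]) -> (Vec<i8>, f32) {
    // Symmetric: the zero-point is 0, so only a scale is stored.
    let max_abs = xs.iter().fold(0.0f32, |m, &x| m.max(x.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = xs
        .iter()
        .map(|&x| (x / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

fn dequantize_int8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}

fn main() {
    let xs = [0.10f32, -0.50, 0.25];
    let (q, scale) = quantize_int8(&xs);
    println!("{q:?} scale={scale}"); // [25, -127, 64], scale ≈ 0.0039
    println!("{:?}", dequantize_int8(&q, scale));
}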

Training & Distributed

Trainer with validation, early stopping, metric tracking. DataParallel, PipelineParallel, MixedPrecisionTrainer with dynamic loss scaling.
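
Dynamic loss scaling, mentioned for MixedPrecisionTrainer, follows a standard scheme: multiply the loss by a scale factor before backward, halve the scale when gradients overflow, and grow it back after a streak of clean steps. Sketched below with illustrative types, not Shrew's internals:

/// Tracks the loss scale across training steps.
struct LossScaler {
    scale: f32,
    good_steps: u32,
    growth_interval: u32,
}

impl LossScaler {
    fn update(&mut self, found_overflow: bool) {
        if found_overflow {
            // Inf/NaN in the gradients: shrink the scale and retry the step.
            self.scale = (self.scale / 2.0).max(1.0);
            self.good_steps = 0;
        } else {
            // After enough clean steps, probe a larger scale.
            self.good_steps += 1;
            if self.good_steps >= self.growth_interval {
                self.scale *= 2.0;
                self.good_steps = 0;
            }
        }
    }
}

fn main() {
    let mut s = LossScaler { scale: 65536.0, good_steps: 0, growth_interval: 2000 };
    s.update(true); // overflow: scale halves to 32768
    assert_eq!(s.scale, 32768.0);
}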

Code Examples

See it in action

From tensor operations and autograd to training transformers and quantizing models. Real API examples from the framework.

main.rs
use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let dev = CpuDevice;

    // Broadcasting: [3,1] + [1,2] → [3,2]
    let a = CpuTensor::from_f64_slice(
        &[1.0, 2.0, 3.0], (3, 1), DType::F64, &dev
    )?;
    let b = CpuTensor::from_f64_slice(
        &[10.0, 20.0], (1, 2), DType::F64, &dev
    )?;
    let c = a.add(&b)?;

    // Reverse-mode autograd
    let w = CpuTensor::randn((3, 3), DType::F64, &dev)?
        .set_variable();
    let x = CpuTensor::randn((2, 3), DType::F64, &dev)?;
    let loss = x.matmul(&w)?.sum_all()?;
    let grads = loss.backward()?;
    let dw = grads.get(&w).unwrap();  // ∂loss/∂w

    Ok(())
}
Performance

Native Rust speed, SIMD-accelerated

gemm provides AVX2/AVX-512/FMA matmul, rayon parallelizes batched ops. Contiguous fast-paths bypass stride calculations. CUDA backend with cuBLAS for GPU acceleration.

7 dtypes supported · 60+ tensor ops · 10 crates

CPU Benchmark — release mode (AMD/Intel)

matmul 256×256: ~370 µs
matmul 512×512: ~3.5 ms
add, 1M elements: ~1.3 ms
linear forward [64,512]×[512,512] + bias: ~3.9 ms
Architecture

Modular by design

Each concern is its own crate — independently swappable and extendable. The same Tensor<B> code runs on CPU and GPU without changes.
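
A sketch of what backend-generic code could look like, given the Tensor<B> and Backend names listed under shrew-core; the exact trait bounds are assumptions, and only ops from the main.rs example are used:

use shrew::prelude::*;

/// The same body serves the CPU and CUDA backends (assumed bounds).
fn affine_sum<B: Backend>(
    x: &Tensor<B>,
    w: &Tensor<B>,
    b: &Tensor<B>,
) -> shrew::Result<Tensor<B>> {
    // y = sum(x @ w + b); the bias broadcasts, and gradients flow
    // through the same autograd machinery on any backend.
    x.matmul(w)?.add(b)?.sum_all()
}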

.sw File: declarative model spec (architecture, config, training)
    → shrew-ir: Lexer → Parser → AST → Graph IR → Validate → Shape Infer → Optimize
    → JIT Compiler: flat instruction tape, pre-allocated memory slots, value lifetime tracking
    → Backends: CPU (SIMD + rayon) · CUDA (cuBLAS + PTX) · Python (PyO3 + NumPy)

shrew-core

Tensor<B>, Shape, DType, Layout, Backend trait, reverse-mode autograd, dynamic symbolic shapes

shrew-ir

.sw format: lexer, parser, AST, Graph IR, lowering, validation, shape inference, optimization passes

shrew-cpu

CPU backend: SIMD matmul via gemm (AVX2/AVX-512/FMA), parallel ops via rayon, broadcasting

shrew-cuda

NVIDIA GPU backend: cuBLAS matmul, custom PTX kernels, memory pool, mixed-precision F16/BF16

shrew-nn

Linear, Conv1d/2d, RNN/LSTM/GRU, MultiHeadAttention, Transformer, BatchNorm, LayerNorm, losses

shrew-optim

SGD, Adam, AdamW, RAdam, RMSProp, LR schedulers, gradient clipping, EMA

shrew-data

Dataset trait, DataLoader, MNIST, image transforms, async prefetch loader

shrew

Facade crate: executor, JIT compiler, trainer, distributed training, quantization, ONNX, profiling

shrew-python

Python bindings via PyO3 with NumPy interop (from_numpy, to_numpy)

Serialization Formats

.shrew

Native binary checkpoints

Safetensors

HuggingFace compatible

ONNX

Opset 17, zero deps

Roadmap

From runtime to universal bridge

v0.1.0 shipped with tensors, autograd, a CUDA GPU backend, neural network layers, optimizers, a JIT compiler, quantization, ONNX, and Python bindings. Next up: JS/WASM and C/C++ bindings.

Phase 0: Project scaffolding · Done
Phase 1: Core tensor system + CPU backend · Done
Phase 2: Autograd (reverse-mode AD) · Done
Phase 3: NN layers + optimizers · Done
Phase 3.5: CNN building blocks (Conv1d/2d, MaxPool, AvgPool) · Done
Phase 4: Broadcasting, MHA, TransformerBlock · Done
Phase 5: .sw IR format, full pipeline · Done
Phase 6: Executor, JIT compiler & Trainer · Done
Phase 7: CUDA GPU backend (cuBLAS + PTX kernels) · Done
Phase 8: Data loading, MNIST examples · Done
Phase 8.5: LR schedulers, checkpoints, char-GPT · Done
Phase 8.6: GEMM matmul, comparisons, parallelism · Done
Phase 9: Python bindings (PyO3 + NumPy) · Done
Phase 9.5: Quantization (INT8/INT4), ONNX, distributed training · Done
Phase 10: JS/WASM & C/C++ bindings · Planned

Documentation

Guides for the .sw DSL syntax, language bindings, API references, and tutorials are coming soon. In the meantime, explore the examples in the repository.

Start building with Shrew

Add Shrew to your Rust project, use the Python bindings, or define models in .sw files. One runtime, every language.

MNIST with MLP

cargo run -p mnist-example

MNIST with CNN

cargo run -p mnist-cnn-example

Char-level GPT

cargo run --release -p char-gpt-example

Python bindings

pip install shrew-python