The Universal Deep Learning Bridge
Shrew is a runtime and DSL that decouples deep learning from any single technology. Define your model once in .sw, then train and deploy it from Python, Rust, JavaScript, C++, or any language in your stack.
The runtime powering the bridge
A Rust-powered runtime with tensors, autograd, GPU acceleration, and a JIT compiler. The engine behind the .sw DSL — so your models run everywhere without being tied to any single technology.
Reverse-Mode Autograd
Eager automatic differentiation — every op records itself into the computation graph as it runs, and backward() performs a topological sort and applies the chain rule. Gradient paths cover matmul, reshape, transpose, affine, cat, and 30+ ops.
Full Neural Network Stack
Linear, Conv1d/2d, RNN, LSTM, GRU, MultiHeadAttention, TransformerBlock, BatchNorm2d, LayerNorm, GroupNorm, RMSNorm, Embedding, Dropout, and 7 loss functions.
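A minimal sketch of how these layers might compose, assuming conventional Linear::new(in_features, out_features, …) and forward signatures; those names are illustrative, not the verified shrew-nn API:

use shrew::prelude::*;

// Sketch only: `Linear::new`, `forward`, and `relu` are ASSUMED names;
// CpuTensor, CpuDevice, and DType appear in the example further below.
fn tiny_mlp(x: &CpuTensor, dev: &CpuDevice) -> shrew::Result<CpuTensor> {
    let fc1 = Linear::new(784, 128, DType::F64, dev)?; // assumed constructor
    let fc2 = Linear::new(128, 10, DType::F64, dev)?;
    let h = fc1.forward(x)?.relu()?;                   // assumed activation method
    fc2.forward(&h)                                    // logits
}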
CUDA GPU Backend
NVIDIA GPU backend via cudarc with cuBLAS matmul, custom PTX kernels, memory pool with allocation reuse, and mixed-precision training (F16/BF16 ↔ F32).
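A minimal device-selection sketch, assuming the CUDA side mirrors the CPU API shown below; CudaDevice::new and CudaTensor are assumed names, and to_dtype stands in for whatever cast the backend actually exposes:

fn on_gpu() -> shrew::Result<()> {
    // ASSUMED: CudaDevice::new / CudaTensor mirror CpuDevice / CpuTensor.
    let gpu = CudaDevice::new(0)?;
    let x = CudaTensor::randn((1024, 1024), DType::F32, &gpu)?;
    let y = x.matmul(&x)?;                 // dispatched to cuBLAS
    let _half = y.to_dtype(DType::F16)?;   // assumed cast, for mixed precision
    Ok(())
}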
SIMD-Accelerated CPU
gemm-accelerated matmul (AVX2/AVX-512/FMA), rayon parallel ops, and contiguous fast paths that bypass stride calculations. Performance comparable to hand-tuned C.
.sw Intermediate Representation
Declarative model specification format with lexer, parser, AST, Graph IR, shape inference, and optimization passes (DCE, CSE, constant folding, operator fusion).
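What driving that pipeline from Rust could look like; every function below is an assumed name that mirrors the documented stages, not the actual shrew-ir signatures:

fn compile_sw(src: &str) -> shrew::Result<()> {
    // ASSUMED names throughout; they track the stage list above.
    let ast = shrew_ir::parse(src)?;             // lexer + parser → AST
    let graph = shrew_ir::lower(&ast)?;          // AST → Graph IR
    shrew_ir::validate(&graph)?;                 // structural checks
    let graph = shrew_ir::infer_shapes(graph)?;  // static/symbolic shapes
    let _graph = shrew_ir::optimize(graph)?;     // DCE, CSE, const folding, fusion
    Ok(())
}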
JIT Compiler
JitExecutor compiles IR graphs into flat instruction tapes with pre-allocated memory slots and value lifetime tracking. No re-interpretation at runtime.
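JitExecutor is the type named above; the compile and run calls in this sketch are assumed method names showing the compile-once, run-many shape of the API:

fn run_batches(graph: &Graph, batches: &[CpuTensor]) -> shrew::Result<()> {
    // ASSUMED: the `Graph` type name and the compile/run signatures.
    let exec = JitExecutor::compile(graph)?;     // IR graph → flat tape, once
    for b in batches {
        let _out = exec.run(&[b.clone()])?;      // replay the tape per batch
    }
    Ok(())
}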
Optimizers & Schedulers
SGD (momentum, weight decay), Adam, AdamW, RAdam, RMSProp. LR schedulers: StepLR, CosineAnnealing, CosineWarmup, ReduceLROnPlateau. Gradient clipping + EMA.
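A rough training-step sketch; the SGD::new and step signatures follow common optimizer conventions and are assumptions here, while backward() and the tensor ops appear in the example further below:

fn train_step(w: &CpuTensor, x: &CpuTensor) -> shrew::Result<()> {
    // ASSUMED: SGD constructor takes (params, lr, momentum) and step
    // applies w ← w − lr·v; only backward() is confirmed by the source.
    let mut opt = SGD::new(vec![w.clone()], 0.1, 0.9)?;
    let loss = x.matmul(w)?.sum_all()?;
    let grads = loss.backward()?;
    opt.step(&grads)?;
    Ok(())
}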
Quantization & ONNX
INT8/INT4 post-training quantization (symmetric/asymmetric, per-tensor/per-channel). ONNX export/import (opset 17) with zero-dependency protobuf.
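Symmetric per-tensor INT8 reduces to one scale plus a rounding step. The sketch below spells out that arithmetic on a plain slice; it uses no Shrew API at all, just the standard scheme scale = max|x|/127, q = round(x/scale):

// Symmetric per-tensor INT8 quantization as plain arithmetic (no Shrew API).
fn quantize_int8(xs: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = xs.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q = xs.iter()
        .map(|&v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

// Dequantize: x̂ = q · scale (off by at most scale/2 per element).
fn dequantize_int8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&v| v as f32 * scale).collect()
}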
Training & Distributed
Trainer with validation, early stopping, metric tracking. DataParallel, PipelineParallel, MixedPrecisionTrainer with dynamic loss scaling.
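Dynamic loss scaling is the standard control loop behind mixed-precision training: grow the scale after a run of overflow-free steps, halve it (and skip the update) on overflow. The sketch below shows that generic policy, not Shrew's exact implementation; the growth interval of 2000 is a common default, not a Shrew setting:

// Generic dynamic loss-scaling policy (not Shrew-specific code).
struct LossScaler {
    scale: f32,
    good_steps: u32,
}

impl LossScaler {
    fn update(&mut self, grads_finite: bool) {
        if grads_finite {
            self.good_steps += 1;
            if self.good_steps >= 2000 {   // assumed growth interval
                self.scale *= 2.0;         // headroom regained: scale up
                self.good_steps = 0;
            }
        } else {
            self.scale = (self.scale / 2.0).max(1.0); // overflow: back off
            self.good_steps = 0;                      // and skip this step
        }
    }
}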
See it in action
From tensor operations and autograd to training transformers and quantizing models. Real API examples from the framework.
use shrew::prelude::*;

fn main() -> shrew::Result<()> {
    let dev = CpuDevice;

    // Broadcasting: [3,1] + [1,2] → [3,2]
    let a = CpuTensor::from_f64_slice(
        &[1.0, 2.0, 3.0], (3, 1), DType::F64, &dev,
    )?;
    let b = CpuTensor::from_f64_slice(
        &[10.0, 20.0], (1, 2), DType::F64, &dev,
    )?;
    let c = a.add(&b)?;

    // Reverse-mode autograd
    let w = CpuTensor::randn((3, 3), DType::F64, &dev)?
        .set_variable();
    let x = CpuTensor::randn((2, 3), DType::F64, &dev)?;
    let loss = x.matmul(&w)?.sum_all()?;
    let grads = loss.backward()?;
    let dw = grads.get(&w).unwrap(); // ∂loss/∂w

    Ok(())
}

Native Rust speed, SIMD-accelerated
gemm provides AVX2/AVX-512/FMA matmul, rayon parallelizes batched ops. Contiguous fast-paths bypass stride calculations. CUDA backend with cuBLAS for GPU acceleration.
7 DTypes supported · 60+ tensor ops · 10 crates
[CPU Benchmark — release mode (AMD/Intel)]
Modular by design
Each concern is its own crate — independently swappable and extendable. The same Tensor<B> code runs on CPU and GPU without changes; a short sketch after the pipeline below illustrates this.
.sw File: declarative model spec (architecture, config, training)
  → shrew-ir: Lexer → Parser → AST → Graph IR → Validate → Shape Infer → Optimize
  → JIT Compiler: flat instruction tape, pre-allocated memory slots, value lifetime tracking
  → Targets: CPU (SIMD + rayon) · CUDA (cuBLAS + PTX) · Python (PyO3 + NumPy)
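A minimal sketch of backend-generic code; Tensor<B> and the Backend trait come from shrew-core, but the exact trait bounds and method surface assumed here are illustrative:

// Same code for every backend. ASSUMED: matmul/add are available on
// Tensor<B> through the Backend trait bound; only the CPU variants are
// shown in the example above.
fn affine<B: Backend>(
    x: &Tensor<B>,
    w: &Tensor<B>,
    b: &Tensor<B>,
) -> shrew::Result<Tensor<B>> {
    x.matmul(w)?.add(b) // identical call sites on CPU and CUDA
}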
shrew-core
Tensor<B>, Shape, DType, Layout, Backend trait, reverse-mode autograd, dynamic symbolic shapes
shrew-ir
.sw format: lexer, parser, AST, Graph IR, lowering, validation, shape inference, optimization passes
shrew-cpu
CPU backend: SIMD matmul via gemm (AVX2/AVX-512/FMA), parallel ops via rayon, broadcasting
shrew-cuda
NVIDIA GPU backend: cuBLAS matmul, custom PTX kernels, memory pool, mixed-precision F16/BF16
shrew-nn
Linear, Conv1d/2d, RNN/LSTM/GRU, MultiHeadAttention, Transformer, BatchNorm, LayerNorm, losses
shrew-optim
SGD, Adam, AdamW, RAdam, RMSProp, LR schedulers, gradient clipping, EMA
shrew-data
Dataset trait, DataLoader, MNIST, image transforms, async prefetch loader
shrew
Facade crate: executor, JIT compiler, trainer, distributed training, quantization, ONNX, profiling
shrew-python
Python bindings via PyO3 with NumPy interop (from_numpy, to_numpy)
Serialization Formats
.shrew
Native binary checkpoints
Safetensors
HuggingFace compatible
ONNX
Opset 17, zero deps
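A hedged persistence sketch; Model, save_shrew, save_safetensors, and export_onnx are all hypothetical names standing in for the three formats listed above:

fn persist(model: &Model) -> shrew::Result<()> {
    // All four names here (Model included) are hypothetical stand-ins.
    model.save_shrew("ckpt.shrew")?;             // native binary checkpoint
    model.save_safetensors("ckpt.safetensors")?; // HuggingFace-compatible
    model.export_onnx("model.onnx", 17)?;        // opset 17, zero-dep protobuf
    Ok(())
}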
From runtime to universal bridge
v0.1.0 shipped with tensors, autograd, a CUDA GPU backend, neural network layers, optimizers, a JIT compiler, quantization, ONNX support, and Python bindings. Next up: JS/WASM and C/C++ bindings.
Documentation
Guides for the .sw DSL syntax, language bindings, API references, and tutorials are coming soon. In the meantime, explore the examples in the repository.
Start building with Shrew
Add Shrew to your Rust project, use the Python bindings, or define models in .sw files. One runtime, every language.
MNIST with MLP
cargo run -p mnist-example
MNIST with CNN
cargo run -p mnist-cnn-example
Char-level GPT
cargo run --release -p char-gpt-example
Python bindings
pip install shrew-python