🧠 ModelLib API Specification
Neural Network Layers and Training Framework
Overview
ModelLib provides a comprehensive framework for building, training, and deploying neural networks in STARK. It offers high-level abstractions for common architectures while maintaining performance and flexibility.
Core Features
- Layer Abstractions - Pre-built neural network layers with automatic gradient computation
- Model Composition - Easy model building with sequential and functional APIs
- Training Framework - Built-in optimizers, loss functions, and training loops
- Pre-trained Models - Access to popular model architectures and weights
- Transfer Learning - Simple fine-tuning and feature extraction workflows
- Multi-Framework Support - Interoperability with PyTorch, TensorFlow, and ONNX
Quick Start Example
use std::model::{Sequential, Dense, ReLU, Softmax};
use std::tensor::Tensor;
use std::optim::Adam;
use std::loss::CrossEntropyLoss;
// Define a simple neural network
let model = Sequential::new()
.add(Dense::new(784, 128))
.add(ReLU::new())
.add(Dense::new(128, 64))
.add(ReLU::new())
.add(Dense::new(64, 10))
.add(Softmax::new());
// Set up training
let optimizer = Adam::new(learning_rate: 0.001);
let loss_fn = CrossEntropyLoss::new();
// Training step
fn training_step(
model: &mut Sequential,
optimizer: &mut Adam,
loss_fn: &CrossEntropyLoss,
inputs: Tensor,
targets: Tensor
) -> f32 {
// Forward pass
let predictions = model.forward(inputs);
let loss = loss_fn.compute(predictions, targets);
// Backward pass
let gradients = loss.backward();
optimizer.step(model, gradients);
loss.item()
}
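A full training run simply repeats this step over batches. A minimal sketch of the surrounding epoch loop, assuming mutable `model`/`optimizer` bindings and a `DataLoader` named `train_loader` that yields `(inputs, targets)` batches (the Trainer section later in this document provides a higher-level version of the same loop):
// Sketch: manual epoch loop built on training_step above
for epoch in 0..10 {
    let mut running_loss = 0.0;
    let mut batches = 0;
    for (inputs, targets) in train_loader {        // train_loader: DataLoader (assumed)
        running_loss += training_step(&mut model, &mut optimizer, &loss_fn, inputs, targets);
        batches += 1;
    }
    let epoch_loss = running_loss / batches as f32; // average training loss for the epoch
}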
Core Layer Types
Linear Layers
Dense (Fully Connected) Layer
struct Dense {
weight: Tensor,
bias: ?Tensor,
use_bias: bool
}
impl Dense {
// Create new dense layer
fn new(in_features: i32, out_features: i32) -> Self;
fn new_no_bias(in_features: i32, out_features: i32) -> Self;
// Weight initialization
fn with_weight_init(mut self, init: WeightInit) -> Self;
fn with_bias_init(mut self, init: BiasInit) -> Self;
// Forward pass
fn forward(&self, input: Tensor) -> Tensor;
// Parameter access
fn parameters(&self) -> Vec<&Tensor>;
fn named_parameters(&self) -> Vec<(String, &Tensor)>;
// Device management
fn to_device(mut self, device: Device) -> Self;
fn device(&self) -> Device;
}
// Usage example
let dense = Dense::new(512, 256)
.with_weight_init(WeightInit::Xavier)
.with_bias_init(BiasInit::Zero);
let input = Tensor::randn([32, 512]);
let output = dense.forward(input); // Shape: [32, 256]
Convolutional Layers
Conv2D Layer
struct Conv2D {
weight: Tensor, // [out_channels, in_channels, height, width]
bias: ?Tensor,
stride: [i32; 2],
padding: Padding,
dilation: [i32; 2],
groups: i32
}
impl Conv2D {
// Create new conv2d layer
fn new(
in_channels: i32,
out_channels: i32,
kernel_size: [i32; 2]
) -> Self;
// Builder methods
fn with_stride(mut self, stride: [i32; 2]) -> Self;
fn with_padding(mut self, padding: Padding) -> Self;
fn with_dilation(mut self, dilation: [i32; 2]) -> Self;
fn with_groups(mut self, groups: i32) -> Self;
// Forward pass
fn forward(&self, input: Tensor) -> Tensor;
}
enum Padding {
Valid, // No padding
Same, // Pad so output spatial size equals input size (at stride 1)
Explicit([i32; 4]) // [left, right, top, bottom]
}
// Usage example
let conv = Conv2D::new(3, 64, [3, 3])
.with_stride([1, 1])
.with_padding(Padding::Same);
let input = Tensor::randn([32, 3, 224, 224]); // Batch of RGB images
let output = conv.forward(input); // Shape: [32, 64, 224, 224]
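The output spatial size follows standard convolution arithmetic. A sketch of the computation for one dimension (`conv_out_size` is an illustrative helper, not part of the API):
// out = floor((in + pad_total - dilation * (kernel - 1) - 1) / stride) + 1
fn conv_out_size(in_size: i32, kernel: i32, stride: i32, pad_total: i32, dilation: i32) -> i32 {
    (in_size + pad_total - dilation * (kernel - 1) - 1) / stride + 1
}
// Padding::Valid: conv_out_size(224, 3, 1, 0, 1) == 222
// Padding::Same chooses pad_total so the output stays 224 at stride 1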
Pooling Layers
// Max pooling
struct MaxPool2D {
kernel_size: [i32; 2],
stride: [i32; 2],
padding: [i32; 2]
}
impl MaxPool2D {
fn new(kernel_size: [i32; 2]) -> Self;
fn with_stride(mut self, stride: [i32; 2]) -> Self;
fn with_padding(mut self, padding: [i32; 2]) -> Self;
fn forward(&self, input: Tensor) -> Tensor;
}
// Average pooling
struct AvgPool2D {
kernel_size: [i32; 2],
stride: [i32; 2],
padding: [i32; 2]
}
// Adaptive pooling
struct AdaptiveAvgPool2D {
output_size: [i32; 2]
}
impl AdaptiveAvgPool2D {
fn new(output_size: [i32; 2]) -> Self;
fn forward(&self, input: Tensor) -> Tensor;
}
// Usage examples
let maxpool = MaxPool2D::new([2, 2]).with_stride([2, 2]);
let avgpool = AvgPool2D::new([2, 2]).with_stride([2, 2]);
let adaptive = AdaptiveAvgPool2D::new([7, 7]);
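For reference, the shapes these layers produce on a typical feature map (a sketch using the configurations above):
// Sketch: shapes on a [batch, channels, height, width] feature map
let features = Tensor::randn([32, 64, 224, 224]);
let pooled = maxpool.forward(features);   // [32, 64, 112, 112] (2x2 window, stride 2)
let resized = adaptive.forward(pooled);   // [32, 64, 7, 7] regardless of input spatial size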
Recurrent Layers
LSTM Layer
struct LSTM {
input_size: i32,
hidden_size: i32,
num_layers: i32,
bias: bool,
batch_first: bool,
dropout: f32,
bidirectional: bool
}
impl LSTM {
fn new(input_size: i32, hidden_size: i32) -> Self;
fn with_num_layers(mut self, num_layers: i32) -> Self;
fn with_dropout(mut self, dropout: f32) -> Self;
fn with_bidirectional(mut self, bidirectional: bool) -> Self;
fn with_batch_first(mut self, batch_first: bool) -> Self;
// Forward pass returns (output, (hidden, cell))
fn forward(
&self,
input: Tensor,
hidden: ?LSTMState
) -> (Tensor, LSTMState);
// Initialize hidden state
fn init_hidden(&self, batch_size: i32) -> LSTMState;
}
struct LSTMState {
hidden: Tensor, // [num_layers, batch, hidden_size]
cell: Tensor // [num_layers, batch, hidden_size]
}
// Usage example
let lstm = LSTM::new(256, 512)
.with_num_layers(2)
.with_dropout(0.1)
.with_bidirectional(true)
.with_batch_first(true);
let input = Tensor::randn([32, 100, 256]); // [batch, seq_len, input_size]
let (output, final_state) = lstm.forward(input, None);
// output: [32, 100, 1024] (bidirectional doubles hidden size)
Attention Layers
Multi-Head Self-Attention
struct MultiHeadSelfAttention {
embed_dim: i32,
num_heads: i32,
dropout: f32,
bias: bool
}
impl MultiHeadSelfAttention {
fn new(embed_dim: i32, num_heads: i32) -> Self;
fn with_dropout(mut self, dropout: f32) -> Self;
fn with_bias(mut self, bias: bool) -> Self;
// Forward pass with optional attention mask
fn forward(
&self,
query: Tensor,
key: Tensor,
value: Tensor,
mask: ?Tensor
) -> (Tensor, Tensor); // (output, attention_weights)
// Self-attention (Q=K=V)
fn self_attention(
&self,
input: Tensor,
mask: ?Tensor
) -> (Tensor, Tensor);
}
// Usage example
let attention = MultiHeadSelfAttention::new(512, 8)
.with_dropout(0.1);
let input = Tensor::randn([32, 100, 512]); // [batch, seq_len, embed_dim]
let (output, weights) = attention.self_attention(input, None);
// output: [32, 100, 512], weights: [32, 8, 100, 100]
Activation Functions
// Common activation functions
struct ReLU;
struct LeakyReLU { negative_slope: f32 }
struct ELU { alpha: f32 }
struct Swish;
struct GELU;
struct Tanh;
struct Sigmoid;
struct Softmax { dim: i32 }
struct LogSoftmax { dim: i32 }
// All activations implement the Activation trait
trait Activation {
fn forward(&self, input: Tensor) -> Tensor;
fn backward(&self, grad_output: Tensor, input: Tensor) -> Tensor;
}
// Usage examples
let relu = ReLU;
let leaky_relu = LeakyReLU { negative_slope: 0.01 };
let swish = Swish;
let softmax = Softmax { dim: -1 };
let input = Tensor::randn([32, 1000]);
let activated = relu.forward(input);
let probabilities = softmax.forward(activated);
Model Building APIs
Sequential Model
// Sequential model for simple layer stacking
struct Sequential {
layers: Vec<Box<dyn Layer>>
}
impl Sequential {
fn new() -> Self;
// Add layers
fn add<L: Layer>(mut self, layer: L) -> Self;
fn add_boxed(mut self, layer: Box<dyn Layer>) -> Self;
// Forward pass
fn forward(&self, input: Tensor) -> Tensor;
// Model introspection
fn layers(&self) -> &[Box<dyn Layer>];
fn num_parameters(&self) -> usize;
fn summary(&self) -> ModelSummary;
}
// Example: Build a CNN for image classification
let cnn = Sequential::new()
.add(Conv2D::new(3, 32, [3, 3]).with_padding(Padding::Same))
.add(ReLU)
.add(MaxPool2D::new([2, 2]))
.add(Conv2D::new(32, 64, [3, 3]).with_padding(Padding::Same))
.add(ReLU)
.add(MaxPool2D::new([2, 2]))
.add(Conv2D::new(64, 128, [3, 3]).with_padding(Padding::Same))
.add(ReLU)
.add(AdaptiveAvgPool2D::new([1, 1]))
.add(Flatten)
.add(Dense::new(128, 10))
.add(Softmax::new());
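The introspection helpers from the Sequential API above can be used to sanity-check the architecture before training (a sketch; the exact summary format is not specified here):
// Inspect the model and run a test forward pass
let summary = cnn.summary();          // per-layer shapes and parameter counts
let total = cnn.num_parameters();     // total trainable parameters
let probs = cnn.forward(Tensor::randn([8, 3, 64, 64])); // [8, 10] class probabilities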
Functional Model
// Functional model for complex architectures
struct FunctionalModel {
graph: ComputationGraph,
inputs: Vec<NodeId>,
outputs: Vec<NodeId>
}
impl FunctionalModel {
fn new() -> ModelBuilder;
fn forward(&self, inputs: Vec<Tensor>) -> Vec<Tensor>;
fn forward_single(&self, input: Tensor) -> Tensor;
}
struct ModelBuilder {
graph: ComputationGraph
}
impl ModelBuilder {
// Define inputs
fn input(&mut self, shape: &[i32]) -> NodeId;
// Add layers and operations
fn add_layer<L: Layer>(&mut self, layer: L, inputs: &[NodeId]) -> NodeId;
fn concat(&mut self, inputs: &[NodeId], dim: i32) -> NodeId;
fn add(&mut self, a: NodeId, b: NodeId) -> NodeId;
fn multiply(&mut self, a: NodeId, b: NodeId) -> NodeId;
// Build final model
fn build(self, outputs: &[NodeId]) -> FunctionalModel;
}
// Example: ResNet-style residual connection
let mut builder = FunctionalModel::new();
let input = builder.input(&[3, 224, 224]);
// Main path (Same padding keeps spatial size so the residual add is shape-compatible)
let conv1 = builder.add_layer(Conv2D::new(3, 64, [3, 3]).with_padding(Padding::Same), &[input]);
let relu1 = builder.add_layer(ReLU, &[conv1]);
let conv2 = builder.add_layer(Conv2D::new(64, 64, [3, 3]).with_padding(Padding::Same), &[relu1]);
// Shortcut path: 1x1 convolution projects the 3-channel input to 64 channels
let shortcut = builder.add_layer(Conv2D::new(3, 64, [1, 1]), &[input]);
// Residual connection
let residual = builder.add(shortcut, conv2);
let output = builder.add_layer(ReLU, &[residual]);
let model = builder.build(&[output]);
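Running the built model works the same way as for Sequential (a sketch; it assumes the batch dimension is added on top of the declared input shape):
// Sketch: forward pass through the functional model
let image = Tensor::randn([1, 3, 224, 224]);
let features = model.forward_single(image); // [1, 64, 224, 224] residual feature map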
Pre-trained Models
Model Zoo
use std::model::pretrained;
// Image classification models
let resnet50 = pretrained::resnet50(pretrained: true)?;
let efficientnet_b0 = pretrained::efficientnet_b0(pretrained: true)?;
let vision_transformer = pretrained::vit_base_patch16_224(pretrained: true)?;
// Object detection models
let yolov5 = pretrained::yolov5s(pretrained: true)?;
let faster_rcnn = pretrained::faster_rcnn_resnet50_fpn(pretrained: true)?;
// NLP models
let bert = pretrained::bert_base_uncased(pretrained: true)?;
let gpt2 = pretrained::gpt2(pretrained: true)?;
let t5 = pretrained::t5_base(pretrained: true)?;
// Multimodal models
let clip = pretrained::clip_vit_base_patch32(pretrained: true)?;
// Usage example: Fine-tuning ResNet for custom classification
let mut model = pretrained::resnet50(pretrained: true)?;
// Replace final classification layer
model.replace_classifier(Dense::new(2048, 10)); // 10 custom classes
// Freeze backbone layers for transfer learning
model.freeze_backbone();
// Fine-tune only the classifier
let optimizer = Adam::new(learning_rate: 0.001)
.with_params(model.classifier_parameters());
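Fine-tuning then proceeds like ordinary training, just over the smaller parameter set. A sketch reusing the pattern from the quick start, assuming mutable `model`/`optimizer` bindings and a `custom_train_loader` DataLoader over the 10-class dataset:
// Sketch: fine-tune only the new classifier head
let loss_fn = CrossEntropyLoss::new();
for (images, labels) in custom_train_loader {
    let predictions = model.forward(images);
    let loss = loss_fn.compute(predictions, labels);
    let gradients = loss.backward();
    optimizer.step(&mut model, gradients); // frozen backbone parameters receive no updates
}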
Model Loading and Saving
// Model serialization
trait Serializable {
fn save(&self, path: &str) -> Result<(), ModelError>;
fn load(path: &str) -> Result<Self, ModelError> where Self: Sized;
// Save/load just weights
fn save_weights(&self, path: &str) -> Result<(), ModelError>;
fn load_weights(&mut self, path: &str) -> Result<(), ModelError>;
// Export to different formats
fn to_onnx(&self, path: &str, input_shapes: &[&[i32]]) -> Result<(), ModelError>;
fn to_torchscript(&self, path: &str) -> Result<(), ModelError>;
}
// Cross-framework loading
struct ModelLoader;
impl ModelLoader {
// Load PyTorch models
fn from_pytorch(path: &str) -> Result<Box<dyn Model>, ModelError>;
fn from_pytorch_state_dict(path: &str, architecture: Box<dyn Model>) -> Result<Box<dyn Model>, ModelError>;
// Load TensorFlow models
fn from_tensorflow(path: &str) -> Result<Box<dyn Model>, ModelError>;
fn from_keras(path: &str) -> Result<Box<dyn Model>, ModelError>;
// Load ONNX models
fn from_onnx(path: &str) -> Result<Box<dyn Model>, ModelError>;
// Load Hugging Face models
fn from_huggingface(model_name: &str, revision: ?&str) -> Result<Box<dyn Model>, ModelError>;
}
// Usage examples
// Save STARK model
model.save("my_model.stark")?;
// Load PyTorch model
let pytorch_model = ModelLoader::from_pytorch("resnet50.pth")?;
// Load from Hugging Face Hub
let bert = ModelLoader::from_huggingface("bert-base-uncased", None)?;
// Export to ONNX for deployment
model.to_onnx("model.onnx", &[&[1, 3, 224, 224]])?;
Training Framework
Optimizers
// Optimizer trait
trait Optimizer {
fn step(&mut self, model: &mut dyn Model, gradients: &[Tensor]);
fn zero_grad(&mut self);
fn get_lr(&self) -> f32;
fn set_lr(&mut self, lr: f32);
}
// SGD optimizer
struct SGD {
learning_rate: f32,
momentum: f32,
weight_decay: f32,
nesterov: bool
}
impl SGD {
fn new(learning_rate: f32) -> Self;
fn with_momentum(mut self, momentum: f32) -> Self;
fn with_weight_decay(mut self, weight_decay: f32) -> Self;
fn with_nesterov(mut self, nesterov: bool) -> Self;
}
// Adam optimizer
struct Adam {
learning_rate: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32,
amsgrad: bool
}
impl Adam {
fn new(learning_rate: f32) -> Self;
fn with_betas(mut self, beta1: f32, beta2: f32) -> Self;
fn with_eps(mut self, eps: f32) -> Self;
fn with_weight_decay(mut self, weight_decay: f32) -> Self;
fn with_amsgrad(mut self, amsgrad: bool) -> Self;
}
// AdamW optimizer (Adam with decoupled weight decay)
struct AdamW {
learning_rate: f32,
beta1: f32,
beta2: f32,
eps: f32,
weight_decay: f32
}
// Learning rate schedulers
trait LRScheduler {
fn step(&mut self, optimizer: &mut dyn Optimizer);
fn get_last_lr(&self) -> f32;
}
struct StepLR {
step_size: u32,
gamma: f32,
last_epoch: u32
}
struct ExponentialLR {
gamma: f32,
last_epoch: u32
}
struct CosineAnnealingLR {
t_max: u32,
eta_min: f32,
last_epoch: u32
}
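A scheduler composes with any optimizer through the trait above. A minimal sketch, assuming a `StepLR::new(step_size, gamma)` constructor (implied by the struct fields) and one scheduler step per epoch:
// Sketch: SGD with step decay, stepping the scheduler once per epoch
let mut optimizer = SGD::new(0.1)
    .with_momentum(0.9)
    .with_weight_decay(1e-4);
let mut scheduler = StepLR::new(30, 0.1); // multiply lr by 0.1 every 30 epochs
for epoch in 0..90 {
    // ... run the training epoch with `optimizer` ...
    scheduler.step(&mut optimizer);
    let current_lr = scheduler.get_last_lr(); // 0.1 -> 0.01 -> 0.001
}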
Loss Functions
// Loss function trait
trait Loss {
type Input;
type Target;
fn compute(&self, input: Self::Input, target: Self::Target) -> LossValue;
fn reduction(&self) -> Reduction;
}
enum Reduction {
None, // No reduction
Mean, // Average over all elements
Sum // Sum all elements
}
struct LossValue {
value: f32,
gradient: Tensor
}
// Classification losses
struct CrossEntropyLoss {
reduction: Reduction,
ignore_index: ?i32,
label_smoothing: f32
}
struct NLLLoss {
reduction: Reduction,
ignore_index: ?i32
}
struct BCELoss {
reduction: Reduction
}
struct BCEWithLogitsLoss {
reduction: Reduction,
pos_weight: ?Tensor
}
// Regression losses
struct MSELoss {
reduction: Reduction
}
struct MAELoss {
reduction: Reduction
}
struct SmoothL1Loss {
reduction: Reduction,
beta: f32
}
struct HuberLoss {
reduction: Reduction,
delta: f32
}
// Usage examples
let cross_entropy = CrossEntropyLoss::new()
.with_label_smoothing(0.1);
let mse = MSELoss::new();
let focal_loss = FocalLoss::new()
.with_alpha(0.25)
.with_gamma(2.0);
Training Loop
// High-level training API
struct Trainer {
model: Box<dyn Model>,
optimizer: Box<dyn Optimizer>,
loss_fn: Box<dyn Loss>,
scheduler: ?Box<dyn LRScheduler>,
metrics: Vec<Box<dyn Metric>>,
device: Device
}
impl Trainer {
fn new(
model: Box<dyn Model>,
optimizer: Box<dyn Optimizer>,
loss_fn: Box<dyn Loss>
) -> Self;
fn with_scheduler(mut self, scheduler: Box<dyn LRScheduler>) -> Self;
fn with_metrics(mut self, metrics: Vec<Box<dyn Metric>>) -> Self;
fn with_device(mut self, device: Device) -> Self;
// Training methods
async fn fit(
&mut self,
train_loader: DataLoader,
val_loader: ?DataLoader,
epochs: u32
) -> TrainingResult;
async fn train_epoch(&mut self, data_loader: DataLoader) -> EpochMetrics;
async fn validate(&mut self, data_loader: DataLoader) -> ValidationMetrics;
// Callbacks
fn on_epoch_start(&mut self, epoch: u32) -> CallbackAction;
fn on_epoch_end(&mut self, epoch: u32, metrics: &EpochMetrics) -> CallbackAction;
fn on_batch_start(&mut self, batch_idx: u32) -> CallbackAction;
fn on_batch_end(&mut self, batch_idx: u32, loss: f32) -> CallbackAction;
}
enum CallbackAction {
Continue,
EarlyStopping,
ReduceLR,
SaveCheckpoint
}
// Usage example
let mut trainer = Trainer::new(
Box::new(model),
Box::new(Adam::new(0.001)),
Box::new(CrossEntropyLoss::new())
)
.with_scheduler(Box::new(CosineAnnealingLR::new(100)))
.with_metrics(vec![
Box::new(Accuracy::new()),
Box::new(F1Score::new()),
Box::new(TopKAccuracy::new(5))
])
.with_device(Device::CUDA(0));
let result = trainer.fit(train_loader, Some(val_loader), epochs: 100).await?;
Key Benefits
🔧 High-Level Abstractions
Pre-built layers and models with automatic gradient computation and optimization
🔄 Framework Interop
Load and save models from PyTorch, TensorFlow, ONNX, and Hugging Face
⚡ Performance
Optimized implementations with GPU acceleration and memory efficiency
🏗️ Flexible Architecture
Both sequential and functional APIs for simple and complex model architectures
📚 Pre-trained Models
Access to popular architectures and weights for transfer learning
🎯 Training Framework
Complete training pipeline with optimizers, loss functions, and metrics