Overview

ModelLib provides a comprehensive framework for building, training, and deploying neural networks in STARK. It offers high-level abstractions for common architectures while maintaining performance and flexibility.

Core Features

  • Layer Abstractions - Pre-built neural network layers with automatic gradient computation
  • Model Composition - Easy model building with sequential and functional APIs
  • Training Framework - Built-in optimizers, loss functions, and training loops
  • Pre-trained Models - Access to popular model architectures and weights
  • Transfer Learning - Simple fine-tuning and feature extraction workflows
  • Multi-Framework Support - Interoperability with PyTorch, TensorFlow, and ONNX

Quick Start Example

use std::model::{Sequential, Dense, ReLU, Softmax};
use std::tensor::Tensor;
use std::optim::Adam;
use std::loss::CrossEntropyLoss;

// Define a simple neural network
let mut model = Sequential::new()
    .add(Dense::new(784, 128))
    .add(ReLU::new())
    .add(Dense::new(128, 64))
    .add(ReLU::new())
    .add(Dense::new(64, 10))
    .add(Softmax::new());

// Set up training
let mut optimizer = Adam::new(learning_rate: 0.001);
let loss_fn = CrossEntropyLoss::new();

// Training step
fn training_step(
    model: &mut Sequential,
    optimizer: &mut Adam,
    loss_fn: &CrossEntropyLoss,
    inputs: Tensor,
    targets: Tensor
) -> f32 {
    // Forward pass
    let predictions = model.forward(inputs);
    let loss = loss_fn.compute(predictions, targets);
    
    // Backward pass
    let gradients = loss.backward();
    optimizer.step(model, gradients);
    
    loss.item()
}
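
In practice, a step function like this is driven by an epoch loop over a data loader. The sketch below is illustrative only: it assumes a DataLoader called train_loader that yields (inputs, targets) batches when iterated, plus a print-style function; the full-featured loop is the Trainer API covered under Training Framework.

// Drive training_step over epochs and batches (illustrative sketch;
// train_loader iteration and print are assumed conveniences)
for epoch in 0..10 {
    let mut total_loss = 0.0;
    let mut num_batches = 0;
    for (inputs, targets) in train_loader {
        total_loss += training_step(&mut model, &mut optimizer, &loss_fn, inputs, targets);
        num_batches += 1;
    }
    print("epoch {}: mean loss {}", epoch, total_loss / num_batches as f32);
}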

Core Layer Types

Linear Layers

Dense (Fully Connected) Layer

struct Dense {
    weight: Tensor,
    bias: ?Tensor,
    use_bias: bool
}

impl Dense {
    // Create new dense layer
    fn new(in_features: i32, out_features: i32) -> Self;
    fn new_no_bias(in_features: i32, out_features: i32) -> Self;
    
    // Weight initialization
    fn with_weight_init(mut self, init: WeightInit) -> Self;
    fn with_bias_init(mut self, init: BiasInit) -> Self;
    
    // Forward pass
    fn forward(&self, input: Tensor) -> Tensor;
    
    // Parameter access
    fn parameters(&self) -> Vec<&Tensor>;
    fn named_parameters(&self) -> Vec<(String, &Tensor)>;
    
    // Device management
    fn to_device(mut self, device: Device) -> Self;
    fn device(&self) -> Device;
}

// Usage example
let dense = Dense::new(512, 256)
    .with_weight_init(WeightInit::Xavier)
    .with_bias_init(BiasInit::Zero);

let input = Tensor::randn([32, 512]);
let output = dense.forward(input);  // Shape: [32, 256]
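
Dense layers expose their state through the parameter and device helpers shown above; the loop body below is purely illustrative.

// Inspect parameters and move the layer to an accelerator
for (name, param) in dense.named_parameters() {
    // e.g. "weight" and "bias" with their current values
}

let dense = dense.to_device(Device::CUDA(0));
let current_device = dense.device();   // Device::CUDA(0)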

Convolutional Layers

Conv2D Layer

struct Conv2D {
    weight: Tensor,  // [out_channels, in_channels, height, width]
    bias: ?Tensor,
    stride: [i32; 2],
    padding: Padding,
    dilation: [i32; 2],
    groups: i32
}

impl Conv2D {
    // Create new conv2d layer
    fn new(
        in_channels: i32,
        out_channels: i32,
        kernel_size: [i32; 2]
    ) -> Self;
    
    // Builder methods
    fn with_stride(mut self, stride: [i32; 2]) -> Self;
    fn with_padding(mut self, padding: Padding) -> Self;
    fn with_dilation(mut self, dilation: [i32; 2]) -> Self;
    fn with_groups(mut self, groups: i32) -> Self;
    
    // Forward pass
    fn forward(&self, input: Tensor) -> Tensor;
}

enum Padding {
    Valid,                    // No padding
    Same,                     // Output size equals input size
    Explicit([i32; 4])       // [left, right, top, bottom]
}

// Usage example
let conv = Conv2D::new(3, 64, [3, 3])
    .with_stride([1, 1])
    .with_padding(Padding::Same);

let input = Tensor::randn([32, 3, 224, 224]);  // Batch of RGB images
let output = conv.forward(input);              // Shape: [32, 64, 224, 224]
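
Strided convolutions are the usual way to downsample inside a CNN. The sketch below assumes the common Same-padding convention where the output spatial size is ceil(input / stride).

// Strided convolution for 2x spatial downsampling
let downsample = Conv2D::new(64, 128, [3, 3])
    .with_stride([2, 2])
    .with_padding(Padding::Same);

let feature_map = Tensor::randn([32, 64, 224, 224]);
let reduced = downsample.forward(feature_map);  // Shape: [32, 128, 112, 112]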

Pooling Layers

// Max pooling
struct MaxPool2D {
    kernel_size: [i32; 2],
    stride: [i32; 2],
    padding: [i32; 2]
}

impl MaxPool2D {
    fn new(kernel_size: [i32; 2]) -> Self;
    fn with_stride(mut self, stride: [i32; 2]) -> Self;
    fn with_padding(mut self, padding: [i32; 2]) -> Self;
    
    fn forward(&self, input: Tensor) -> Tensor;
}

// Average pooling
struct AvgPool2D {
    kernel_size: [i32; 2],
    stride: [i32; 2],
    padding: [i32; 2]
}

// Adaptive pooling
struct AdaptiveAvgPool2D {
    output_size: [i32; 2]
}

impl AdaptiveAvgPool2D {
    fn new(output_size: [i32; 2]) -> Self;
    fn forward(&self, input: Tensor) -> Tensor;
}

// Usage examples
let maxpool = MaxPool2D::new([2, 2]).with_stride([2, 2]);
let avgpool = AvgPool2D::new([2, 2]).with_stride([2, 2]);
let adaptive = AdaptiveAvgPool2D::new([7, 7]);
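
Pooling shrinks the spatial dimensions while leaving batch and channel dimensions untouched, and adaptive pooling always produces the requested output size:

let features = Tensor::randn([32, 64, 224, 224]);
let pooled = maxpool.forward(features);   // Shape: [32, 64, 112, 112]
let fixed = adaptive.forward(pooled);     // Shape: [32, 64, 7, 7], independent of input size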

Recurrent Layers

LSTM Layer

struct LSTM {
    input_size: i32,
    hidden_size: i32,
    num_layers: i32,
    bias: bool,
    batch_first: bool,
    dropout: f32,
    bidirectional: bool
}

impl LSTM {
    fn new(input_size: i32, hidden_size: i32) -> Self;
    fn with_num_layers(mut self, num_layers: i32) -> Self;
    fn with_dropout(mut self, dropout: f32) -> Self;
    fn with_bidirectional(mut self, bidirectional: bool) -> Self;
    fn with_batch_first(mut self, batch_first: bool) -> Self;
    
    // Forward pass returns (output, final_state), where the state holds hidden and cell
    fn forward(
        &self, 
        input: Tensor,
        hidden: ?LSTMState
    ) -> (Tensor, LSTMState);
    
    // Initialize hidden state
    fn init_hidden(&self, batch_size: i32) -> LSTMState;
}

struct LSTMState {
    hidden: Tensor,  // [num_layers, batch, hidden_size]
    cell: Tensor     // [num_layers, batch, hidden_size]
}

// Usage example
let lstm = LSTM::new(256, 512)
    .with_num_layers(2)
    .with_dropout(0.1)
    .with_bidirectional(true)
    .with_batch_first(true);

let input = Tensor::randn([32, 100, 256]);  // [batch, seq_len, input_size]
let (output, final_state) = lstm.forward(input, None);
// output: [32, 100, 1024] (bidirectional doubles hidden size)
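
The returned LSTMState can be fed back into forward to continue a sequence across chunks, which is useful when long inputs are processed in pieces. A minimal unidirectional sketch, with the optional state passed as Some(...)/None as in the call above:

// Carry hidden/cell state across two chunks of the same sequences
let lstm = LSTM::new(256, 512).with_batch_first(true);
let state = lstm.init_hidden(32);             // zero-initialized state for a batch of 32

let chunk1 = Tensor::randn([32, 50, 256]);
let chunk2 = Tensor::randn([32, 50, 256]);

let (out1, state) = lstm.forward(chunk1, Some(state));
let (out2, _) = lstm.forward(chunk2, Some(state));   // continues from chunk1's final state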

Attention Layers

Multi-Head Self-Attention

struct MultiHeadSelfAttention {
    embed_dim: i32,
    num_heads: i32,
    dropout: f32,
    bias: bool
}

impl MultiHeadSelfAttention {
    fn new(embed_dim: i32, num_heads: i32) -> Self;
    fn with_dropout(mut self, dropout: f32) -> Self;
    fn with_bias(mut self, bias: bool) -> Self;
    
    // Forward pass with optional attention mask
    fn forward(
        &self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        mask: ?Tensor
    ) -> (Tensor, Tensor);  // (output, attention_weights)
    
    // Self-attention (Q=K=V)
    fn self_attention(
        &self,
        input: Tensor,
        mask: ?Tensor
    ) -> (Tensor, Tensor);
}

// Usage example
let attention = MultiHeadSelfAttention::new(512, 8)
    .with_dropout(0.1);

let input = Tensor::randn([32, 100, 512]);  // [batch, seq_len, embed_dim]
let (output, weights) = attention.self_attention(input, None);
// output: [32, 100, 512], weights: [32, 8, 100, 100]
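
The same layer also performs cross-attention when query, key, and value come from different sequences, for example a decoder attending over encoder outputs:

// Cross-attention: 20 decoder positions attend over 100 encoder positions
let decoder_states = Tensor::randn([32, 20, 512]);   // query
let encoder_states = Tensor::randn([32, 100, 512]);  // key and value

let (context, attn) = attention.forward(
    decoder_states,
    encoder_states,
    encoder_states,
    None
);
// context: [32, 20, 512], attn: [32, 8, 20, 100]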

Activation Functions

// Common activation functions
struct ReLU;
struct LeakyReLU { negative_slope: f32 }
struct ELU { alpha: f32 }
struct Swish; 
struct GELU;
struct Tanh;
struct Sigmoid;
struct Softmax { dim: i32 }
struct LogSoftmax { dim: i32 }

// All activations implement the Activation trait
trait Activation {
    fn forward(&self, input: Tensor) -> Tensor;
    fn backward(&self, grad_output: Tensor, input: Tensor) -> Tensor;
}

// Usage examples
let relu = ReLU;
let leaky_relu = LeakyReLU { negative_slope: 0.01 };
let swish = Swish;
let softmax = Softmax { dim: -1 };

let input = Tensor::randn([32, 1000]);
let activated = relu.forward(input);
let probabilities = softmax.forward(activated);

Model Building APIs

Sequential Model

// Sequential model for simple layer stacking
struct Sequential {
    layers: Vec<Box<dyn Layer>>
}

impl Sequential {
    fn new() -> Self;
    
    // Add layers
    fn add<L: Layer>(mut self, layer: L) -> Self;
    fn add_boxed(mut self, layer: Box<dyn Layer>) -> Self;
    
    // Forward pass
    fn forward(&self, input: Tensor) -> Tensor;
    
    // Model introspection
    fn layers(&self) -> &[Box<dyn Layer>];
    fn num_parameters(&self) -> usize;
    fn summary(&self) -> ModelSummary;
}

// Example: Build a CNN for image classification
let cnn = Sequential::new()
    .add(Conv2D::new(3, 32, [3, 3]).with_padding(Padding::Same))
    .add(ReLU)
    .add(MaxPool2D::new([2, 2]))
    .add(Conv2D::new(32, 64, [3, 3]).with_padding(Padding::Same))
    .add(ReLU)
    .add(MaxPool2D::new([2, 2]))
    .add(Conv2D::new(64, 128, [3, 3]).with_padding(Padding::Same))
    .add(ReLU)
    .add(AdaptiveAvgPool2D::new([1, 1]))
    .add(Flatten)
    .add(Dense::new(128, 10))
    .add(Softmax::new());
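
Before training, the introspection helpers give a quick sanity check of the assembled model (the print call is illustrative):

// Inspect the stacked model
let n_params = cnn.num_parameters();
let summary = cnn.summary();        // per-layer overview
print("CNN has {} trainable parameters", n_params);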

Functional Model

// Functional model for complex architectures
struct FunctionalModel {
    graph: ComputationGraph,
    inputs: Vec<NodeId>,
    outputs: Vec<NodeId>
}

impl FunctionalModel {
    fn new() -> ModelBuilder;
    
    fn forward(&self, inputs: Vec<Tensor>) -> Vec<Tensor>;
    fn forward_single(&self, input: Tensor) -> Tensor;
}

struct ModelBuilder {
    graph: ComputationGraph
}

impl ModelBuilder {
    // Define inputs
    fn input(&mut self, shape: &[i32]) -> NodeId;
    
    // Add layers and operations
    fn add_layer<L: Layer>(&mut self, layer: L, inputs: &[NodeId]) -> NodeId;
    fn concat(&mut self, inputs: &[NodeId], dim: i32) -> NodeId;
    fn add(&mut self, a: NodeId, b: NodeId) -> NodeId;
    fn multiply(&mut self, a: NodeId, b: NodeId) -> NodeId;
    
    // Build final model
    fn build(self, outputs: &[NodeId]) -> FunctionalModel;
}

// Example: ResNet-style residual connection
let mut builder = FunctionalModel::new();

let input = builder.input(&[3, 224, 224]);

// Main path
let conv1 = builder.add_layer(Conv2D::new(3, 64, [3, 3]), &[input]);
let relu1 = builder.add_layer(ReLU, &[conv1]);
let conv2 = builder.add_layer(Conv2D::new(64, 64, [3, 3]), &[relu1]);

// Residual connection
let residual = builder.add(input, conv2);
let output = builder.add_layer(ReLU, &[residual]);

let model = builder.build(&[output]);
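
The built model is called like any other; forward_single covers the common one-input, one-output case. The sketch below assumes the batch dimension is supplied at call time, since the input node was declared as [3, 224, 224].

// Run the residual model
let image = Tensor::randn([1, 3, 224, 224]);
let features = model.forward_single(image);

// General multi-input / multi-output form
let outputs = model.forward(vec![Tensor::randn([1, 3, 224, 224])]);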

Pre-trained Models

Model Zoo

use std::model::pretrained;

// Image classification models
let resnet50 = pretrained::resnet50(pretrained: true)?;
let efficientnet_b0 = pretrained::efficientnet_b0(pretrained: true)?;
let vision_transformer = pretrained::vit_base_patch16_224(pretrained: true)?;

// Object detection models
let yolov5 = pretrained::yolov5s(pretrained: true)?;
let faster_rcnn = pretrained::faster_rcnn_resnet50_fpn(pretrained: true)?;

// NLP models
let bert = pretrained::bert_base_uncased(pretrained: true)?;
let gpt2 = pretrained::gpt2(pretrained: true)?;
let t5 = pretrained::t5_base(pretrained: true)?;

// Multimodal models
let clip = pretrained::clip_vit_base_patch32(pretrained: true)?;

// Usage example: Fine-tuning ResNet for custom classification
let mut model = pretrained::resnet50(pretrained: true)?;

// Replace final classification layer
model.replace_classifier(Dense::new(2048, 10));  // 10 custom classes

// Freeze backbone layers for transfer learning
model.freeze_backbone();

// Fine-tune only the classifier
let optimizer = Adam::new(learning_rate: 0.001)
    .with_params(model.classifier_parameters());

Model Loading and Saving

// Model serialization
trait Serializable {
    fn save(&self, path: &str) -> Result<(), ModelError>;
    fn load(path: &str) -> Result<Self, ModelError> where Self: Sized;
    
    // Save/load just weights
    fn save_weights(&self, path: &str) -> Result<(), ModelError>;
    fn load_weights(&mut self, path: &str) -> Result<(), ModelError>;
    
    // Export to different formats
    fn to_onnx(&self, path: &str, input_shapes: &[&[i32]]) -> Result<(), ModelError>;
    fn to_torchscript(&self, path: &str) -> Result<(), ModelError>;
}

// Cross-framework loading
struct ModelLoader;

impl ModelLoader {
    // Load PyTorch models
    fn from_pytorch(path: &str) -> Result<Box<dyn Model>, ModelError>;
    fn from_pytorch_state_dict(path: &str, architecture: Box<dyn Model>) -> Result<Box<dyn Model>, ModelError>;
    
    // Load TensorFlow models
    fn from_tensorflow(path: &str) -> Result<Box<dyn Model>, ModelError>;
    fn from_keras(path: &str) -> Result<Box<dyn Model>, ModelError>;
    
    // Load ONNX models
    fn from_onnx(path: &str) -> Result<Box<dyn Model>, ModelError>;
    
    // Load Hugging Face models
    fn from_huggingface(model_name: &str, revision: ?&str) -> Result<Box<dyn Model>, ModelError>;
}

// Usage examples
// Save STARK model
model.save("my_model.stark")?;

// Load PyTorch model
let pytorch_model = ModelLoader::from_pytorch("resnet50.pth")?;

// Load from Hugging Face Hub
let bert = ModelLoader::from_huggingface("bert-base-uncased", None)?;

// Export to ONNX for deployment
model.to_onnx("model.onnx", &[&[1, 3, 224, 224]])?;
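
Weight-only checkpoints are lighter than full model saves and are the usual choice during training. In the sketch below, build_model is a hypothetical helper that reconstructs the same architecture before the weights are restored.

// Checkpoint weights during training, then restore into a fresh model
model.save_weights("checkpoints/epoch_10.weights")?;

let mut restored = build_model();                       // hypothetical: rebuilds the architecture
restored.load_weights("checkpoints/epoch_10.weights")?;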

Training Framework

Optimizers

// Optimizer trait
trait Optimizer {
    fn step(&mut self, model: &mut dyn Model, gradients: &[Tensor]);
    fn zero_grad(&mut self);
    fn get_lr(&self) -> f32;
    fn set_lr(&mut self, lr: f32);
}

// SGD optimizer
struct SGD {
    learning_rate: f32,
    momentum: f32,
    weight_decay: f32,
    nesterov: bool
}

impl SGD {
    fn new(learning_rate: f32) -> Self;
    fn with_momentum(mut self, momentum: f32) -> Self;
    fn with_weight_decay(mut self, weight_decay: f32) -> Self;
    fn with_nesterov(mut self, nesterov: bool) -> Self;
}

// Adam optimizer
struct Adam {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32,
    amsgrad: bool
}

impl Adam {
    fn new(learning_rate: f32) -> Self;
    fn with_betas(mut self, beta1: f32, beta2: f32) -> Self;
    fn with_eps(mut self, eps: f32) -> Self;
    fn with_weight_decay(mut self, weight_decay: f32) -> Self;
    fn with_amsgrad(mut self, amsgrad: bool) -> Self;
}

// AdamW optimizer (Adam with decoupled weight decay)
struct AdamW {
    learning_rate: f32,
    beta1: f32,
    beta2: f32,
    eps: f32,
    weight_decay: f32
}
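
// Illustrative optimizer configurations using the builder methods above
// (hyperparameter values are examples only)
let sgd = SGD::new(0.1)
    .with_momentum(0.9)
    .with_weight_decay(1e-4)
    .with_nesterov(true);

let adam = Adam::new(0.001)
    .with_betas(0.9, 0.999)
    .with_weight_decay(0.01);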

// Learning rate schedulers
trait LRScheduler {
    fn step(&mut self, optimizer: &mut dyn Optimizer);
    fn get_last_lr(&self) -> f32;
}

struct StepLR {
    step_size: u32,
    gamma: f32,
    last_epoch: u32
}

struct ExponentialLR {
    gamma: f32,
    last_epoch: u32
}

struct CosineAnnealingLR {
    t_max: u32,
    eta_min: f32,
    last_epoch: u32
}
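
A scheduler is typically stepped once per epoch, after the optimizer has applied its updates. A minimal sketch, assuming a StepLR::new(step_size, gamma) constructor (only the struct fields are shown above):

let mut optimizer = SGD::new(0.1).with_momentum(0.9);
let mut scheduler = StepLR::new(30, 0.1);   // assumed constructor: decay LR by 10x every 30 epochs

for epoch in 0..90 {
    // ... run the training batches for this epoch ...
    scheduler.step(&mut optimizer);
    let lr = scheduler.get_last_lr();
}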

Loss Functions

// Loss function trait
trait Loss {
    type Input;
    type Target;
    
    fn compute(&self, input: Self::Input, target: Self::Target) -> LossValue;
    fn reduction(&self) -> Reduction;
}

enum Reduction {
    None,      // No reduction
    Mean,      // Average over all elements
    Sum        // Sum all elements
}

struct LossValue {
    value: f32,
    gradient: Tensor
}

// Classification losses
struct CrossEntropyLoss {
    reduction: Reduction,
    ignore_index: ?i32,
    label_smoothing: f32
}

struct NLLLoss {
    reduction: Reduction,
    ignore_index: ?i32
}

struct BCELoss {
    reduction: Reduction
}

struct BCEWithLogitsLoss {
    reduction: Reduction,
    pos_weight: ?Tensor
}

// Regression losses
struct MSELoss {
    reduction: Reduction
}

struct MAELoss {
    reduction: Reduction
}

struct SmoothL1Loss {
    reduction: Reduction,
    beta: f32
}

struct HuberLoss {
    reduction: Reduction,
    delta: f32
}

// Usage examples
let cross_entropy = CrossEntropyLoss::new()
    .with_label_smoothing(0.1);

let mse = MSELoss::new();

// Focal loss for class-imbalanced classification
let focal_loss = FocalLoss::new()
    .with_alpha(0.25)
    .with_gamma(2.0);
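
Every loss returns a LossValue, so the reduced scalar and the gradient for the backward pass come from a single compute call. A minimal sketch, where inputs and targets are placeholders:

let predictions = model.forward(inputs);
let loss = cross_entropy.compute(predictions, targets);

let scalar: f32 = loss.value;      // reduced according to the configured Reduction
let grad = loss.gradient;          // gradient tensor used for the backward pass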

Training Loop

// High-level training API
struct Trainer {
    model: Box<dyn Model>,
    optimizer: Box<dyn Optimizer>,
    loss_fn: Box<dyn Loss>,
    scheduler: ?Box<dyn LRScheduler>,
    metrics: Vec<Box<dyn Metric>>,
    device: Device
}

impl Trainer {
    fn new(
        model: Box<dyn Model>,
        optimizer: Box<dyn Optimizer>,
        loss_fn: Box<dyn Loss>
    ) -> Self;
    
    fn with_scheduler(mut self, scheduler: Box<dyn LRScheduler>) -> Self;
    fn with_metrics(mut self, metrics: Vec<Box<dyn Metric>>) -> Self;
    fn with_device(mut self, device: Device) -> Self;
    
    // Training methods
    async fn fit(
        &mut self,
        train_loader: DataLoader,
        val_loader: ?DataLoader,
        epochs: u32
    ) -> TrainingResult;
    
    async fn train_epoch(&mut self, data_loader: DataLoader) -> EpochMetrics;
    async fn validate(&mut self, data_loader: DataLoader) -> ValidationMetrics;
    
    // Callbacks
    fn on_epoch_start(&mut self, epoch: u32) -> CallbackAction;
    fn on_epoch_end(&mut self, epoch: u32, metrics: &EpochMetrics) -> CallbackAction;
    fn on_batch_start(&mut self, batch_idx: u32) -> CallbackAction;
    fn on_batch_end(&mut self, batch_idx: u32, loss: f32) -> CallbackAction;
}

enum CallbackAction {
    Continue,
    EarlyStopping,
    ReduceLR,
    SaveCheckpoint
}

// Usage example
let mut trainer = Trainer::new(
    Box::new(model),
    Box::new(Adam::new(0.001)),
    Box::new(CrossEntropyLoss::new())
)
.with_scheduler(Box::new(CosineAnnealingLR::new(100)))
.with_metrics(vec![
    Box::new(Accuracy::new()),
    Box::new(F1Score::new()),
    Box::new(TopKAccuracy::new(5))
])
.with_device(Device::CUDA(0));

let result = trainer.fit(train_loader, Some(val_loader), epochs: 100).await?;

Key Benefits

🔧 High-Level Abstractions

Pre-built layers and models with automatic gradient computation and optimization

🔄 Framework Interop

Load and save models from PyTorch, TensorFlow, ONNX, and Hugging Face

⚡ Performance

Optimized implementations with GPU acceleration and memory efficiency

🏗️ Flexible Architecture

Both sequential and functional APIs for simple and complex model architectures

📚 Pre-trained Models

Access to popular architectures and weights for transfer learning

🎯 Training Framework

Complete training pipeline with optimizers, loss functions, and metrics