What You'll Build
You'll build a Python pipeline that intelligently selects satellite imagery patches for labeling by combining uncertainty-based active learning with concept-guided relevance weighting. Instead of randomly labeling thousands of images, your system will ask "what should I learn next?" while considering domain knowledge like "areas near industrial sites are more likely to contain contamination."
The final system will process synthetic four-band multispectral imagery patches (mimicking a common RGB + near-infrared sensor configuration), maintain an incremental meta-learning loop that adapts to new labels, and prioritize samples based on both model uncertainty AND domain-relevant concepts like land cover type or proximity to known contamination sources. This approach is critical for real geospatial discovery tasks—whether hunting for rare minerals, tracking deforestation, or detecting pollution hotspots—where labeling budgets are limited and domain knowledge matters.
You'll walk away with:
- A working concept-weighted active learning pipeline
- A lightweight CNN classifier for multispectral imagery
- An incremental meta-learning component for rapid adaptation
- Reusable code patterns for geospatial ML projects
Prefer not to copy code by hand? Clone the companion repository instead: git clone https://github.com/klarson3k1o/owl-gps-active-learning.git
Prerequisites
- Python 3.10+ (tested on 3.11.5)
- PyTorch 2.0+ with CUDA support optional (CPU works but slower)
- pip or conda for package management
- Basic understanding of:
- Python classes and NumPy arrays
- Neural network training loops
- Active learning concepts (helpful but not required)
- Estimated time: 90-120 minutes with a CUDA-capable GPU; 4-6 hours on CPU
- Disk space: ~2GB for dependencies, ~50MB for synthetic data
Step-by-Step Instructions
Step 1: Set Up Your Project Structure and Environment
Create a clean working directory with an isolated Python environment:
mkdir owl-gps-tutorial
cd owl-gps-tutorial
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
Install PyTorch and supporting libraries:
# For CUDA 11.8 (adjust URL for your CUDA version or use CPU)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
# Core dependencies
pip install scikit-learn numpy matplotlib
For CPU-only installation: Use pip install torch torchvision without the index-url flag.
Apple Silicon (M1/M2/M3): Do not use the --index-url flag above — it will fail. Use pip install torch torchvision and PyTorch will automatically use the MPS backend where available.
Verify installation:
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}')"
Step 2: Generate Synthetic Geospatial Data
Create synthetic data that mimics multispectral satellite patches with associated land cover concepts. This lets you focus on the active learning logic without infrastructure overhead.
Create generate_data.py:
import numpy as np
import pickle
from pathlib import Path
def generate_synthetic_geospatial_data(n_samples=500, image_size=32, n_bands=4, seed=42):
"""
Generate synthetic satellite-like imagery with associated concepts.
Args:
n_samples: Number of image patches to generate
image_size: Spatial dimensions (height and width in pixels)
n_bands: Number of spectral bands (4 mimics RGB + NIR)
Returns:
images: Array of shape (n_samples, n_bands, image_size, image_size)
labels: Binary labels (1 = target present, 0 = absent)
concepts: List of dicts with domain features per sample
"""
np.random.seed(seed)
images = []
labels = [] # Binary: 1 = target present (e.g., contamination hotspot), 0 = absent
concepts = [] # Land cover type, distance to industrial sites, etc.
for i in range(n_samples):
# Generate random multispectral patch (values in [0, 1])
img = np.random.rand(n_bands, image_size, image_size).astype(np.float32)
# Simulate land cover concepts: 0=forest, 1=urban, 2=water, 3=agricultural
land_cover = np.random.randint(0, 4)
# Distance to industrial site (normalized 0-1, where 0 is close)
dist_industrial = np.random.rand()
# Target probability influenced by concepts
# Urban areas close to industrial sites more likely to have contamination
target_prob = 0.1 # Base rate
if land_cover == 1: # Urban
target_prob += 0.3
if dist_industrial < 0.3: # Close to industrial
target_prob += 0.4
label = 1 if np.random.rand() < target_prob else 0
images.append(img)
labels.append(label)
concepts.append({
'land_cover': land_cover,
'dist_industrial': dist_industrial,
'urban': 1 if land_cover == 1 else 0,
'near_industrial': 1 if dist_industrial < 0.3 else 0
})
return np.array(images), np.array(labels), concepts
# Generate dataset
print("Generating synthetic geospatial dataset...")
images, labels, concepts = generate_synthetic_geospatial_data(n_samples=500)
# Save to disk
Path("data").mkdir(exist_ok=True)
np.save("data/images.npy", images)
np.save("data/labels.npy", labels)
with open("data/concepts.pkl", "wb") as f:
pickle.dump(concepts, f)
print(f"Generated {len(images)} samples")
print(f"Positive class ratio: {labels.mean():.3f}")
print(f"Image shape: {images[0].shape}")
print(f"Data saved to data/ directory")
Run the data generation script:
python generate_data.py
Expected output:
Generating synthetic geospatial dataset...
Generated 500 samples
Positive class ratio: 0.314
Image shape: (4, 32, 32)
Data saved to data/ directory
What just happened: You created 500 synthetic "satellite patches" with 4 spectral bands (mimicking RGB + near-infrared). Each patch has associated concept features (land cover type, distance to industrial sites) that influence whether a target is present. This simulates real-world scenarios where domain knowledge correlates with target presence—urban areas near industrial sites have higher contamination probability in this simulation.
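Optional sanity check: matplotlib was installed in Step 1 but hasn't been used yet. Here is a minimal sketch (a hypothetical helper script, visualize_patch.py) that treats the first three bands of a patch as RGB so you can confirm the array layout. Since the synthetic patches are random noise, expect a static-like image rather than recognizable terrain.
import numpy as np
import matplotlib.pyplot as plt

images = np.load("data/images.npy")
labels = np.load("data/labels.npy")

# (bands, H, W) -> (H, W, bands) for imshow; keep only the first three bands as RGB
rgb = np.transpose(images[0][:3], (1, 2, 0))

plt.imshow(rgb)
plt.title(f"Patch 0, label = {labels[0]}")
plt.axis("off")
plt.savefig("patch_preview.png")  # or plt.show() in an interactive session
print("Saved patch_preview.png")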
Step 3: Build the Concept-Weighted Uncertainty Sampler
Implement the core innovation: combining model uncertainty with concept relevance for intelligent sample selection.
Create active_learner.py:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import roc_auc_score
class GeospatialDataset(Dataset):
"""PyTorch dataset wrapper for geospatial patches with concepts."""
def __init__(self, images, labels, concepts):
self.images = torch.FloatTensor(images)
self.labels = torch.LongTensor(labels)
self.concepts = concepts
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
return self.images[idx], self.labels[idx], idx
class SimpleCNN(nn.Module):
"""Lightweight CNN for multispectral patch classification."""
def __init__(self, n_bands=4, n_classes=2):
super().__init__()
# Convolutional layers
self.conv1 = nn.Conv2d(n_bands, 32, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
# Fully connected layers
# After two pooling layers: 32x32 -> 16x16 -> 8x8
self.fc1 = nn.Linear(64 * 8 * 8, 128)
self.fc2 = nn.Linear(128, n_classes)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
# First conv block
x = self.pool(F.relu(self.conv1(x))) # -> [batch, 32, 16, 16]
# Second conv block
x = self.pool(F.relu(self.conv2(x))) # -> [batch, 64, 8, 8]
# Flatten and classify
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.dropout(x)
x = self.fc2(x)
return x
class ConceptWeightedActiveLearner:
"""
Active learner that combines uncertainty sampling with concept relevance.
This is the core innovation: instead of purely uncertainty-based sampling,
we weight samples by domain-relevant concepts (e.g., urban areas near
industrial sites are more relevant for contamination detection).
"""
def __init__(self, model, device='cuda' if torch.cuda.is_available() else 'cpu',
concept_weights=None):
"""
Args:
model: PyTorch model to train and query
device: 'cuda' or 'cpu'
concept_weights: Dict mapping concept names to importance multipliers.
Defaults to urban=2.0, near_industrial=3.0.
Pass your own dict to tune for a different task.
"""
self.model = model.to(device)
self.device = device
self.labeled_indices = set()
# Domain-specific concept weights — override via constructor for new tasks
self.concept_weights = concept_weights or {
'urban': 2.0, # Urban areas more relevant for contamination
'near_industrial': 3.0 # Proximity to industrial sites very relevant
}
def compute_uncertainty(self, unlabeled_loader):
"""
Compute prediction uncertainty using entropy.
Higher entropy = model is more uncertain = more informative sample.
"""
self.model.eval()
uncertainties = []
indices = []
with torch.no_grad():
for images, _, idx in unlabeled_loader:
images = images.to(self.device)
logits = self.model(images)
probs = F.softmax(logits, dim=1)
# Entropy-based uncertainty: -sum(p * log(p))
entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=1)
uncertainties.extend(entropy.cpu().numpy())
indices.extend(idx.numpy())
return np.array(uncertainties), np.array(indices)
def compute_concept_relevance(self, indices, concepts):
"""
Compute relevance score based on domain concepts.
Higher score = more relevant to our search task based on domain knowledge.
Args:
indices: Array of indices into the concepts list (original dataset indices
returned by __getitem__ via Subset)
concepts: The full train_dataset.concepts list (length = size of full training set).
Indices returned by __getitem__ via Subset are original dataset indices,
so this must be the complete list — not a subset-aligned list.
"""
relevance_scores = []
for idx in indices:
concept = concepts[idx]
score = 1.0 # Base relevance
# Multiply by concept importance weights
if concept['urban']:
score *= self.concept_weights['urban']
if concept['near_industrial']:
score *= self.concept_weights['near_industrial']
relevance_scores.append(score)
return np.array(relevance_scores)
def select_samples(self, unlabeled_loader, concepts, budget=10, alpha=0.5):
"""
Select samples using concept-weighted uncertainty.
Args:
unlabeled_loader: DataLoader for unlabeled samples
concepts: The full train_dataset.concepts list (length = size of full training set).
Must be the complete list because indices returned by the DataLoader
(via Subset.__getitem__) are original dataset indices, not subset positions.
budget: Number of samples to select
alpha: float in [0, 1]. Controls the uncertainty/relevance balance.
1.0 = pure uncertainty sampling (ignores domain concepts),
0.0 = pure concept-relevance sampling (ignores model uncertainty),
0.5 = equal weight between both signals.
Returns:
top_indices: Indices of selected samples (relative to original dataset)
top_scores: Combined scores for selected samples
"""
# Compute both uncertainty and relevance
uncertainties, indices = self.compute_uncertainty(unlabeled_loader)
relevance = self.compute_concept_relevance(indices, concepts)
# Normalize both scores to [0, 1] range for fair combination
uncertainties = (uncertainties - uncertainties.min()) / (uncertainties.max() - uncertainties.min() + 1e-10)
relevance = (relevance - relevance.min()) / (relevance.max() - relevance.min() + 1e-10)
# Combined score: weighted sum of uncertainty and relevance
scores = alpha * uncertainties + (1 - alpha) * relevance
# Guard against budget larger than available samples
budget = min(budget, len(scores))
# Select top-k samples with highest combined scores
top_k_positions = np.argsort(scores)[-budget:]
top_indices = indices[top_k_positions]
top_scores = scores[top_k_positions]
return top_indices, top_scores
def train_epoch(self, train_loader, optimizer, criterion):
"""Train model for one epoch."""
self.model.train()
total_loss = 0
for images, labels, _ in train_loader:
images, labels = images.to(self.device), labels.to(self.device)
optimizer.zero_grad()
outputs = self.model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(train_loader)
def evaluate(self, test_loader):
"""Evaluate model performance using AUC-ROC."""
self.model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
for images, labels, _ in test_loader:
images = images.to(self.device)
outputs = self.model(images)
probs = F.softmax(outputs, dim=1)[:, 1] # Probability of positive class
all_preds.extend(probs.cpu().numpy())
all_labels.extend(labels.numpy())
# Guard against single-class splits (can happen with small or skewed test sets)
if len(set(all_labels)) < 2:
return float('nan')
auc = roc_auc_score(all_labels, all_preds)
return auc
if __name__ == "__main__":
print("Active learner module loaded successfully")
Test the module:
python active_learner.py
Expected output:
Active learner module loaded successfully
What just happened: You built the core machinery—a CNN classifier for multispectral imagery, and an active learning loop that computes both model uncertainty (via entropy) and concept relevance (via domain features), then combines them. The alpha parameter lets you tune how much you trust model uncertainty versus domain knowledge: alpha=1.0 is pure uncertainty sampling, alpha=0.0 is pure concept-based sampling, and alpha=0.5 balances both.
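If you adapt the learner to a different task, the concept weights are the first knob to turn. A minimal sketch of overriding them via the constructor, keeping the same 'urban' and 'near_industrial' keys (compute_concept_relevance looks those names up directly, so introducing new concepts also means editing that method):
from active_learner import SimpleCNN, ConceptWeightedActiveLearner

model = SimpleCNN(n_bands=4, n_classes=2)
learner = ConceptWeightedActiveLearner(
    model,
    concept_weights={
        'urban': 1.5,           # soften the urban preference
        'near_industrial': 5.0  # make industrial proximity the dominant signal
    }
)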
Step 4: Implement the Incremental Meta-Learning Loop
Add a Reptile meta-learning component that nudges the model toward fast adaptability as newly labeled batches arrive.
Create meta_learning.py:
import torch
import torch.nn as nn
from copy import deepcopy
class OnlineMetaLearner:
"""
Incremental meta-learner using the Reptile algorithm (Nichol et al., 2018).
Reptile is a first-order meta-learning algorithm that updates the meta-model
by interpolating its weights toward a task-adapted copy. Unlike MAML it
requires no second-order gradients and works with any standard inner-loop
optimizer — making it a practical drop-in for sequential active learning.
On each call to meta_update():
1. A copy of the meta-model is fine-tuned on the support set (newly
labeled samples) for `inner_steps` gradient steps.
2. The meta-model weights are nudged toward the fine-tuned weights:
theta_meta = theta_meta + meta_lr * (theta_adapted - theta_meta)
3. The query set is evaluated for monitoring only (no meta-gradient needed).
"""
def __init__(self, base_model, inner_lr=0.01, meta_lr=0.1, inner_steps=5):
"""
Args:
base_model: PyTorch model to meta-learn (shared with active learner)
inner_lr: SGD learning rate for inner-loop adaptation
meta_lr: Reptile step size — how far to interpolate toward
the adapted weights (0 < meta_lr <= 1)
inner_steps: Maximum gradient steps in the inner loop
"""
self.meta_model = base_model
self.inner_lr = inner_lr
self.meta_lr = meta_lr
self.inner_steps = inner_steps
def meta_update(self, support_loader, query_loader, criterion):
"""
Perform one Reptile meta-update.
Args:
support_loader: DataLoader with newly labeled samples for adaptation
query_loader: DataLoader with held-out samples for loss monitoring
criterion: Loss function (e.g. CrossEntropyLoss)
Returns:
query_loss: Average loss on query set after adaptation (float),
or None if query_loader is empty.
"""
device = next(self.meta_model.parameters()).device
# ── Inner loop ────────────────────────────────────────────────────────
# Fine-tune a disconnected clone on the support set.
# Reptile intentionally uses deepcopy — the gradient never needs to flow
# back through the inner loop (unlike MAML).
adapted_model = deepcopy(self.meta_model)
inner_optimizer = torch.optim.SGD(adapted_model.parameters(), lr=self.inner_lr)
adapted_model.train()
for step, (images, labels, _) in enumerate(support_loader):
if step >= self.inner_steps:
break
images, labels = images.to(device), labels.to(device)
inner_optimizer.zero_grad()
loss = criterion(adapted_model(images), labels)
loss.backward()
inner_optimizer.step()
# ── Reptile weight update ─────────────────────────────────────────────
# Move meta-model weights toward the adapted weights.
# theta_meta = theta_meta + meta_lr * (theta_adapted - theta_meta)
with torch.no_grad():
for meta_p, adapted_p in zip(self.meta_model.parameters(),
adapted_model.parameters()):
meta_p.data = meta_p.data + self.meta_lr * (adapted_p.data - meta_p.data)
# ── Query evaluation (monitoring only) ────────────────────────────────
adapted_model.eval()
query_loss = 0.0
n_batches = 0
with torch.no_grad():
for images, labels, _ in query_loader:
images, labels = images.to(device), labels.to(device)
query_loss += criterion(adapted_model(images), labels).item()
n_batches += 1
if n_batches == 0:
return None
return query_loss / n_batches
def get_model(self):
"""Return current meta-model."""
return self.meta_model
if __name__ == "__main__":
print("Meta-learning module loaded successfully")
Test the module:
python meta_learning.py
Expected output:
Meta-learning module loaded successfully
What just happened: You implemented the Reptile meta-learning algorithm. After each new labeled batch arrives, Reptile fine-tunes a temporary copy of the model on those samples, then nudges the main model's weights toward the copy — no second-order gradients required. The meta_lr controls how aggressively the model adapts: larger values favor the new batch, smaller values preserve knowledge from all previous rounds. The query set is evaluated after adaptation to give you a signal of how well the model generalises to unseen samples from the same distribution shift.
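To make the meta_lr intuition concrete, here is a toy illustration of the interpolation rule on a single scalar weight (not part of the pipeline):
meta_w, adapted_w, meta_lr = 1.0, 2.0, 0.1
meta_w = meta_w + meta_lr * (adapted_w - meta_w)
print(meta_w)  # 1.1: the meta weight moves 10% of the way toward the adapted weight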
Step 5: Build the Main Training Loop
This step reads from the data/ directory created in Step 2. If you see FileNotFoundError: data/images.npy, run python generate_data.py first.
Tie everything together into a complete active learning pipeline.
Create train_owl_gps.py:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import pickle # safe: file generated locally by generate_data.py
from active_learner import GeospatialDataset, SimpleCNN, ConceptWeightedActiveLearner
from meta_learning import OnlineMetaLearner
def main():
# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(42) # Sets seed for all GPUs
# Load synthetic data
print("Loading synthetic geospatial data...")
images = np.load("data/images.npy")
labels = np.load("data/labels.npy")
with open("data/concepts.pkl", "rb") as f:
concepts = pickle.load(f)
print(f"Loaded {len(images)} samples with {images.shape[1]} spectral bands")
# Split into train/test (80/20 split)
n_samples = len(images)
indices = np.random.permutation(n_samples)
train_idx = indices[:400]
test_idx = indices[400:]
# Create datasets
train_dataset = GeospatialDataset(
images[train_idx],
labels[train_idx],
[concepts[i] for i in train_idx]
)
test_dataset = GeospatialDataset(
images[test_idx],
labels[test_idx],
[concepts[i] for i in test_idx]
)
print(f"Train set: {len(train_dataset)} samples")
print(f"Test set: {len(test_dataset)} samples")
# Initialize model and learner
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
model = SimpleCNN(n_bands=4, n_classes=2)
active_learner = ConceptWeightedActiveLearner(model, device=device)
meta_learner = OnlineMetaLearner(model, inner_lr=0.01, meta_lr=0.1)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Active learning configuration
initial_budget = 20 # Start with 20 randomly labeled samples
query_budget = 10 # Query 10 new samples per round
n_rounds = 10 # Run for 10 active learning rounds
alpha = 0.5 # Balance uncertainty and relevance equally
# Initialize with random labeled samples
all_train_indices = set(range(len(train_dataset)))
labeled_indices = set(np.random.choice(list(all_train_indices), initial_budget, replace=False))
unlabeled_indices = all_train_indices - labeled_indices
print(f"\n{'='*60}")
print(f"Starting active learning with {len(labeled_indices)} initial labeled samples")
print(f"{'='*60}\n")
# Track performance over rounds
results = {
'round': [],
'n_labeled': [],
'test_auc': []
}
# Create test loader (once, since test set doesn't change)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
# Active learning loop
for round_num in range(n_rounds):
print(f"\n--- Round {round_num + 1}/{n_rounds} ---")
print(f"Labeled samples: {len(labeled_indices)}")
print(f"Unlabeled samples: {len(unlabeled_indices)}")
# Create data loaders for current labeled set
labeled_subset = Subset(train_dataset, list(labeled_indices))
labeled_loader = DataLoader(labeled_subset, batch_size=16, shuffle=True)
# Train for several epochs on current labeled set
loss = float('nan') # Initialize to handle empty loader edge case
for epoch in range(5):
loss = active_learner.train_epoch(labeled_loader, optimizer, criterion)
print(f"Training loss (final epoch): {loss:.4f}")
# Evaluate on test set
test_auc = active_learner.evaluate(test_loader)
print(f"Test AUC: {test_auc:.4f}")
results['round'].append(round_num + 1)
results['n_labeled'].append(len(labeled_indices))
results['test_auc'].append(test_auc)
# Skip querying on the last round
if round_num == n_rounds - 1:
break
# Query new samples using concept-weighted uncertainty
unlabeled_subset = Subset(train_dataset, list(unlabeled_indices))
unlabeled_loader = DataLoader(unlabeled_subset, batch_size=32, shuffle=False)
new_indices, scores = active_learner.select_samples(
unlabeled_loader,
train_dataset.concepts, # Full list: __getitem__ returns dataset indices (0..399)
budget=query_budget,
alpha=alpha
)
print(f"Selected {len(new_indices)} new samples for labeling")
print(f" Score range: [{scores.min():.3f}, {scores.max():.3f}]")
# "Label" the selected samples (in real world, a human annotator does this)
labeled_indices.update(new_indices.tolist())
unlabeled_indices -= set(new_indices.tolist())
# Meta-learning update: adapt to newly labeled batch using Reptile
# Note: Adam optimizer momentum buffers become stale after Reptile's in-place
# weight update. This is acceptable for tutorial code but production systems
# may want to reset optimizer state or use SGD.
if len(new_indices) >= 4:
# 50/50 split: support set for inner-loop adaptation, query set for evaluation
n_support = len(new_indices) // 2
support_subset = Subset(train_dataset, new_indices[:n_support].tolist())
query_subset = Subset(train_dataset, new_indices[n_support:].tolist())
support_loader = DataLoader(support_subset, batch_size=8, shuffle=True)
query_loader = DataLoader(query_subset, batch_size=8, shuffle=False)
meta_loss = meta_learner.meta_update(support_loader, query_loader, criterion)
if meta_loss is not None:
print(f" Reptile meta-update query loss: {meta_loss:.4f}")
# Print final summary
print(f"\n{'='*60}")
print("Active Learning Complete!")
print(f"{'='*60}")
print(f"\nRound | Labeled | Test AUC")
print(f"------|---------|----------")
for r, n, auc in zip(results['round'], results['n_labeled'], results['test_auc']):
print(f" {r:2d} | {n:3d} | {auc:.4f}")
print(f"\nFinal Test AUC: {results['test_auc'][-1]:.4f}")
print(f"Labeled {len(labeled_indices)}/{len(train_dataset)} available training samples")
if __name__ == "__main__":
main()
Run the full pipeline:
python train_owl_gps.py
Expected output (final summary table):
============================================================
Active Learning Complete!
============================================================
Round | Labeled | Test AUC
------|---------|----------
1 | 20 | 0.5234
2 | 30 | 0.4876
3 | 40 | 0.5123
4 | 50 | 0.4987
5 | 60 | 0.5345
6 | 70 | 0.5012
7 | 80 | 0.5456
8 | 90 | 0.5189
9 | 100 | 0.5567
10 | 110 | 0.5234
Final Test AUC: 0.5234
Labeled 110/400 available training samples
What just happened: The active learning loop runs for 10 rounds. Each round trains the CNN on the current labeled set, evaluates on the held-out test set, selects the next most informative samples using concept-weighted uncertainty (combining model entropy with domain relevance), and optionally runs a Reptile meta-update to help the model stay adaptable as new batches arrive. By the final round you'll see the pipeline mechanics working correctly — samples selected, meta-updates applied, results logged — even though AUC hovers near chance. That's by design with this synthetic data.
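If you want the learning curve as a picture rather than a table, here is a minimal sketch you could append to the end of main(). It only uses the results dict that main() already builds, plus matplotlib from Step 1.
import matplotlib.pyplot as plt

plt.plot(results['n_labeled'], results['test_auc'], marker='o')
plt.xlabel("Number of labeled samples")
plt.ylabel("Test AUC")
plt.title("Concept-weighted active learning curve")
plt.savefig("learning_curve.png")
print("Saved learning_curve.png")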
Where to Go Next
The pipeline you built is a working foundation. Here are the natural extensions, roughly in order of difficulty.
1. Swap in Real Satellite Data
The biggest jump in learning value. Two free sources work well as direct replacements for the synthetic data:
- Sentinel-2 (ESA) — 13 spectral bands, 10m resolution, free global coverage via the Copernicus Data Space. Use bands B02, B03, B04, B08 (RGB + NIR) to match the 4-band setup in this tutorial.
- Landsat-8/9 (USGS) — free via EarthExplorer, coarser resolution but longer historical archive. Good for change detection tasks.
To swap in real data: replace generate_data.py with a script that tiles your GeoTIFF into 32×32 patches and extracts concept features (land cover from OpenStreetMap, distance to industrial zones from OSM or national databases). The rest of the pipeline — active_learner.py, meta_learning.py, train_owl_gps.py — works unchanged.
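A minimal sketch of what that replacement script could look like, assuming rasterio is installed (pip install rasterio) and a 4-band GeoTIFF named scene.tif. The file name, reflectance scaling, and band order are placeholders for your own data, and labels and concepts still have to come from your own sources.
import numpy as np
import rasterio
from rasterio.windows import Window
from pathlib import Path

PATCH = 32

with rasterio.open("scene.tif") as src:  # hypothetical 4-band GeoTIFF (e.g. B02, B03, B04, B08)
    patches = []
    for row in range(0, src.height - PATCH + 1, PATCH):
        for col in range(0, src.width - PATCH + 1, PATCH):
            patch = src.read(window=Window(col, row, PATCH, PATCH)).astype(np.float32)
            patch = np.clip(patch / 10000.0, 0, 1)  # rough Sentinel-2 L2A reflectance scaling
            patches.append(patch)

Path("data").mkdir(exist_ok=True)
np.save("data/images.npy", np.stack(patches))
# Labels and concept dicts (same keys as the tutorial) come from field data, land cover maps,
# or OSM queries; save them as data/labels.npy and data/concepts.pkl to match the pipeline.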
2. Add a Random Sampling Baseline
Right now you have no way to know if concept-weighted active learning is actually helping. Run the same pipeline with alpha=1.0 (pure uncertainty) and a second run where you replace select_samples with random selection. Plot all three AUC curves over rounds. If your concept weights are good, the concept-weighted curve should pull ahead after round 3-4 — that gap is the value you're adding over a naive approach.
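A minimal sketch of the random baseline, intended as a drop-in replacement for the select_samples call in train_owl_gps.py (the function name and rng seed are my own; unlabeled_indices is the set maintained in main()):
import numpy as np

def select_random(unlabeled_indices, budget=10, rng=None):
    """Random baseline: pick `budget` samples uniformly from the unlabeled pool."""
    if rng is None:
        rng = np.random.default_rng(0)
    pool = np.array(sorted(unlabeled_indices))
    budget = min(budget, len(pool))
    chosen = rng.choice(pool, size=budget, replace=False)
    return chosen, np.zeros(len(chosen))  # scores carry no meaning for random selection

# In the active learning loop, swap in:
# new_indices, scores = select_random(unlabeled_indices, budget=query_budget)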
3. Replace the CNN Backbone
The SimpleCNN in this tutorial is intentionally minimal. For real imagery, swap it for a pretrained backbone:
- ResNet-18 — modify the first conv layer to accept 4 bands instead of 3: nn.Conv2d(4, 64, kernel_size=7, ...), then load ImageNet weights for all other layers. Strong baseline with minimal effort (see the sketch below).
- EfficientNet-B0 — smaller and faster than ResNet for the same accuracy range. Available via torchvision.models.
The active learner and meta-learner are model-agnostic — swap the backbone and nothing else needs to change.
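A minimal sketch of the ResNet-18 swap, assuming torchvision 0.13+ (the weights= API). The first conv layer is re-created for 4 bands and therefore trains from scratch, while every other layer keeps its ImageNet weights. Note that ResNet-18 downsamples aggressively, so 32×32 patches shrink to 1×1 before pooling; it still runs, but larger patches make better use of the backbone.
import torch.nn as nn
from torchvision import models
from active_learner import ConceptWeightedActiveLearner

def build_resnet18_4band(n_classes=2):
    backbone = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    # Replace the 3-band stem with a 4-band one (randomly initialized)
    backbone.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Replace the classification head for binary output
    backbone.fc = nn.Linear(backbone.fc.in_features, n_classes)
    return backbone

model = build_resnet18_4band()
active_learner = ConceptWeightedActiveLearner(model)  # the rest of the pipeline is unchanged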
4. Improve Uncertainty Estimation
Entropy over a single forward pass is a weak uncertainty signal because a confident but wrong model will produce low entropy. Two better approaches:
- Monte Carlo Dropout — keep dropout active at inference time, run 10-20 forward passes per sample, and measure variance across predictions. Requires calling self.model.train() during uncertainty computation and running multiple passes. Significantly better calibration (see the sketch below).
- Deep Ensembles — train 3-5 independent models with different seeds and measure disagreement between them. More compute, but the most reliable uncertainty estimate available without Bayesian methods.
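A minimal sketch of Monte Carlo Dropout as a standalone variant of compute_uncertainty (the function name and n_passes default are my own; the variance of the positive-class probability serves as the uncertainty signal):
import numpy as np
import torch
import torch.nn.functional as F

def mc_dropout_uncertainty(model, loader, device, n_passes=20):
    """Variance of the positive-class probability over repeated stochastic forward passes."""
    model.train()  # keeps dropout active; for models with BatchNorm, enable train mode on dropout layers only
    uncertainties, indices = [], []
    with torch.no_grad():
        for images, _, idx in loader:
            images = images.to(device)
            probs = torch.stack([
                F.softmax(model(images), dim=1)[:, 1] for _ in range(n_passes)
            ])  # shape: (n_passes, batch)
            uncertainties.extend(probs.var(dim=0).cpu().numpy())
            indices.extend(idx.numpy())
    return np.array(uncertainties), np.array(indices)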
5. Tune the Key Hyperparameters
Three parameters have the most impact on real-world performance:
- alpha — start at 0.5, then shift toward 0.0 (more concept-driven) early in training when the model is poorly calibrated, and toward 1.0 (more uncertainty-driven) as the model improves. A simple schedule: alpha = min(1.0, round_num / n_rounds + 0.3).
- meta_lr — if you see the model forgetting earlier rounds (AUC drops between rounds), lower it toward 0.05. If adaptation to new batches is too slow, raise it toward 0.2.
- query_budget — in real annotation workflows, the budget is usually fixed by cost per label. Set it to match your actual annotation cost: if a human annotator labels 20 patches per hour and you have 2 hours per round, set query_budget=40.
6. Move Toward Production
When you're ready to run this against a real dataset at scale:
- Save model checkpoints — add torch.save(model.state_dict(), f"checkpoint_round_{round_num}.pt") after each round so you can resume without retraining from scratch (see the sketch below).
- Replace the in-memory labeled set — labeled_indices is a Python set that lives in RAM. For large datasets, store labeled indices and their annotations in a database or a simple CSV so human annotators can work asynchronously.
- Decouple annotation from training — the tutorial simulates instant labeling. In production, select_samples writes a query batch to a queue, human annotators label it over hours or days, and training resumes when labels arrive. The pipeline structure supports this naturally.
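A minimal sketch of the checkpoint pattern: save inside the loop after evaluation, resume in a later session by loading the latest file. The variable names match train_owl_gps.py; the resume block and the extra fields in the saved dict are my own additions.
# Inside the active learning loop, after evaluation:
torch.save({
    'round': round_num,
    'model_state': model.state_dict(),
    'optimizer_state': optimizer.state_dict(),
    'labeled_indices': list(labeled_indices),
}, f"checkpoint_round_{round_num}.pt")

# To resume later:
ckpt = torch.load("checkpoint_round_5.pt", map_location=device)
model.load_state_dict(ckpt['model_state'])
optimizer.load_state_dict(ckpt['optimizer_state'])
labeled_indices = set(ckpt['labeled_indices'])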
Written by: Keith Larson
klarson@3k1o.com | klarson@planet-ai.net