Model Definitions and Core Snippets
This page highlights the concrete model classes and the key places where they are instantiated and aggregated in benchmark runs.
Models
Baseline / Ensemble CNN Definition
Source: src/lesionshiftai/models/cnn.py. The baseline model and every ensemble member share the same ResNet50 backbone, topped with a single-logit classification head.
import torch
from torch import nn
from torchvision.models import ResNet50_Weights, resnet50

class BaselineCNN(nn.Module):
    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.backbone = resnet50(weights=weights)
        in_features = self.backbone.fc.in_features
        # replace the 1000-class ImageNet head with a single logit for BCE
        self.backbone.fc = nn.Linear(in_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, 1) -> (batch,)
        return self.backbone(x).squeeze(1)
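The single-logit head is meant to be paired with a sigmoid at inference time (the in-code comment points at BCE training). A minimal standalone sketch of that inference step, with the logit values made up for illustration:

```python
import math

def logit_to_prob(logit: float) -> float:
    """Map the model's single raw logit to a malignancy probability."""
    return 1.0 / (1.0 + math.exp(-logit))

print(logit_to_prob(0.0))   # 0.5: maximally uncertain
print(logit_to_prob(2.0) > logit_to_prob(-2.0))  # True: higher logit, higher risk
```

One logit with a sigmoid is equivalent to two logits with a softmax for binary problems, but halves the head's parameters and matches `BCEWithLogitsLoss`-style targets.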
Vision Transformer Definition
Source: src/lesionshiftai/models/vit.py. The wrapper is architecture-agnostic: any timm model name can be passed via model_name, and the head always emits a single binary logit.
import timm
import torch
from torch import nn

class ViTBinaryClassifier(nn.Module):
    def __init__(
        self,
        model_name: str = "vit_base_patch16_224",
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        # timm builds the backbone and a fresh single-logit head in one call
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x).squeeze(-1)
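A quick shape note on the final squeeze: timm's num_classes=1 head emits (batch, 1), and squeeze(-1) drops the trailing dimension so outputs line up with (batch,) float targets. Illustrated on plain lists as stand-ins for tensors (no timm needed):

```python
# Stand-in for a (3, 1) logit tensor produced by the num_classes=1 head.
batch_logits = [[0.7], [-1.2], [0.3]]

# What squeeze(-1) achieves here: (3, 1) -> (3,).
squeezed = [row[0] for row in batch_logits]
print(squeezed)  # [0.7, -1.2, 0.3]
```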
Training
Where Model Variants Are Selected
Source: scripts/train_baseline_cnn.py, scripts/train_ensemble_member_cnn.py, and scripts/train_vit.py. This is where run scripts pick CNN vs ViT backbones and specific pretrained ViT variants.
# baseline and ensemble members
model = BaselineCNN(pretrained=True).to(device)

# ViT experiment variant
model = ViTBinaryClassifier(
    model_name="vit_large_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)
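Because each entry point hard-codes its model class, switching backbones means switching scripts. A hedged sketch of an alternative config-driven dispatch (MODEL_BUILDERS, select_model, and the cfg keys are illustrative names, not from the repo):

```python
# Illustrative registry mapping a config string to a (class name, kwargs)
# spec; in the real repo each script constructs its model class directly.
MODEL_BUILDERS = {
    "cnn": lambda cfg: ("BaselineCNN", {"pretrained": True}),
    "vit": lambda cfg: ("ViTBinaryClassifier",
                        {"model_name": cfg["model_name"], "pretrained": True}),
}

def select_model(cfg: dict):
    """Return the (class_name, kwargs) spec for the requested architecture."""
    return MODEL_BUILDERS[cfg["arch"]](cfg)

print(select_model({"arch": "vit", "model_name": "vit_base_patch16_224"}))
```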
Core Fine-Tuning and Checkpoint Loop
Source: scripts/train_vit.py. Each epoch performs train/validation passes, applies warmup+cosine scheduling, and tracks the best checkpoint by validation PR AUC.
optimizer = AdamW(
    model.parameters(),
    lr=cfg.train.lr,
    weight_decay=cfg.train.weight_decay,
)
scheduler = _build_scheduler(optimizer, cfg)

for epoch in range(start_epoch, cfg.train.epochs + 1):
    train_metrics = train_one_epoch(...)
    val_metrics, val_preds = evaluate_loader(...)
    scheduler.step()

    ckpt_payload = _build_checkpoint_payload(...)
    torch.save(ckpt_payload, run_dir / "checkpoints" / "last.pt")
    if best_pr_auc == float("-inf") or val_metrics["pr_auc"] > best_pr_auc:
        best_pr_auc = float(val_metrics["pr_auc"])
        torch.save(ckpt_payload, run_dir / "checkpoints" / "best.pt")
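_build_scheduler itself is not shown above. A common warmup + cosine shape it could implement looks like the following; the warmup_epochs/total_epochs values and the linear-warmup-into-cosine form are assumptions for illustration, not the repo's actual code:

```python
import math

def lr_factor(epoch: int, warmup_epochs: int = 5, total_epochs: int = 50) -> float:
    """Multiplier applied to the base lr at a given (0-indexed) epoch."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs  # linear warmup to 1.0
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay toward 0

print(lr_factor(0), lr_factor(4), round(lr_factor(49), 4))
```

Calling scheduler.step() once per epoch, as the loop above does, matches a schedule defined in epochs rather than steps.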
Ensemble
How Ensemble Predictions Are Combined
Source: scripts/train_ensemble_member_cnn.py. Aggregate HAM10000 predictions are formed by averaging the malignancy probability across fold members; the standard deviation across folds is retained as a disagreement signal.
test_aggregate_df = (
    all_test_preds_df
    .groupby(["dataset", "sample_id"], as_index=False)
    .agg(
        label=("label", "first"),
        prob_malignant=("prob_malignant", "mean"),
        prob_malignant_std=("prob_malignant", "std"),
        member_predictions=("member_fold", "nunique"),
    )
)
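What that groupby computes for a single test sample can be spelled out without pandas. The probabilities below are made up; note that pandas' .std() defaults to the sample standard deviation (ddof=1), which statistics.stdev matches:

```python
import statistics

# Per-fold malignancy probabilities for one (dataset, sample_id) group.
fold_probs = [0.82, 0.74, 0.79, 0.88, 0.77]

prob_malignant = statistics.mean(fold_probs)        # ensemble prediction
prob_malignant_std = statistics.stdev(fold_probs)   # spread across fold members
member_predictions = len(fold_probs)                # analogous to the fold count

print(round(prob_malignant, 2), member_predictions)  # 0.8 5
```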