Code

Model Definitions and Core Snippets

This page highlights the concrete model classes and the key places where they are instantiated, trained, and aggregated in benchmark runs.

Models

Baseline / Ensemble CNN Definition

Source: src/lesionshiftai/models/cnn.py. The baseline and every ensemble member use the same ResNet50 backbone, with the final fully connected layer replaced by a single-logit classification head.

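import torch
import torch.nn as nn
from torchvision.models import ResNet50_Weights, resnet50

class BaselineCNN(nn.Module):
    def __init__(self, pretrained: bool = True) -> None:
        super().__init__()
        # ImageNet-pretrained weights unless pretrained=False
        weights = ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.backbone = resnet50(weights=weights)
        in_features = self.backbone.fc.in_features
        # replace the 1000-class ImageNet head with a single logit for BCE
        self.backbone.fc = nn.Linear(in_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, 1) -> (B,): one logit per image
        return self.backbone(x).squeeze(1)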
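Because the head emits a raw logit rather than a probability, both the malignancy probability and the loss are derived from it downstream. The sketch below shows that relationship in plain Python, assuming training uses a with-logits binary cross-entropy (the "single logit for BCE" comment suggests this, but the loss call is not shown on this page). `sigmoid` and `bce_with_logits` are illustrative helpers, not project functions.

```python
import math


def sigmoid(logit: float) -> float:
    """Numerically stable logistic function: logit -> P(malignant)."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    z = math.exp(logit)
    return z / (1.0 + z)


def bce_with_logits(logit: float, label: int) -> float:
    """Binary cross-entropy computed directly from a logit (hypothetical helper).

    Uses max(x, 0) - x*y + log(1 + exp(-|x|)), which equals
    -(y*log(p) + (1-y)*log(1-p)) with p = sigmoid(x) but never
    evaluates log(0) for large |x|.
    """
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


if __name__ == "__main__":
    for x, y in [(0.0, 1), (2.0, 1), (-3.0, 0)]:
        print(f"logit={x:+.1f} label={y} p={sigmoid(x):.4f} loss={bce_with_logits(x, y):.4f}")
```

In PyTorch, `torch.sigmoid` and `torch.nn.BCEWithLogitsLoss` play these two roles on whole batches of the `(B,)` logits that `forward` returns.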

Models

Vision Transformer Definition

Source: src/lesionshiftai/models/vit.py. The ViT wrapper accepts any timm architecture through model_name and always outputs a single binary logit.

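import timm
import torch
import torch.nn as nn

class ViTBinaryClassifier(nn.Module):
    def __init__(
        self,
        model_name: str = "vit_base_patch16_224",
        pretrained: bool = True,
    ) -> None:
        super().__init__()
        # num_classes=1 replaces the pretrained classifier with a fresh single-logit head
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, 1) -> (B,): one logit per image
        return self.backbone(x).squeeze(-1)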
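The default `model_name` encodes the input geometry: `patch16_224` means 224×224 inputs cut into non-overlapping 16×16 patches. The arithmetic below sketches the resulting token sequence length, assuming a standard ViT that prepends one class token and uses no distillation or register tokens. `vit_sequence_length` is a hypothetical helper for illustration.

```python
def vit_sequence_length(image_size: int, patch_size: int, class_tokens: int = 1) -> int:
    """Tokens seen by a plain ViT: one per non-overlapping patch plus class token(s)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + class_tokens


if __name__ == "__main__":
    # patch16_224: 14 x 14 = 196 patches + 1 class token
    print(vit_sequence_length(224, 16))  # 197
```

The base and large patch16_224 variants share this geometry; they differ in embedding width and depth, not in sequence length.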

Training

Where Model Variants Are Selected

Source: scripts/train_baseline_cnn.py, scripts/train_ensemble_member_cnn.py, and scripts/train_vit.py. These run scripts choose between the CNN and ViT backbones and, for ViT runs, the specific pretrained variant.

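import torch
from lesionshiftai.models.cnn import BaselineCNN
from lesionshiftai.models.vit import ViTBinaryClassifier

# device selection shown for completeness; the scripts' own setup may differ
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# baseline and ensemble members (train_baseline_cnn.py, train_ensemble_member_cnn.py)
model = BaselineCNN(pretrained=True).to(device)

# ViT experiment variant (train_vit.py)
model = ViTBinaryClassifier(
    model_name="vit_large_patch16_224.augreg_in21k_ft_in1k",
    pretrained=True,
).to(device)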
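The ViT string selected above follows timm's `architecture.pretrained_tag` naming: the part before the dot picks the architecture, and the part after it picks a specific pretrained weight set (our reading of this tag: AugReg-recipe weights pretrained on ImageNet-21k and fine-tuned on ImageNet-1k). When no tag is given, timm falls back to the architecture's default weights. The hypothetical helper below, which is not a timm function, splits such a name and reads the patch size and input resolution, which could be used to check that a run's preprocessing matches the chosen variant.

```python
import re
from typing import NamedTuple, Optional


class ViTName(NamedTuple):
    architecture: str
    pretrained_tag: Optional[str]
    patch_size: int
    image_size: int


# matches e.g. "..._patch16_224" at the end of the architecture part
_GEOMETRY = re.compile(r"patch(\d+)_(\d+)$")


def parse_vit_name(model_name: str) -> ViTName:
    """Split 'arch.tag' and read patch size / resolution from the architecture part."""
    architecture, _, tag = model_name.partition(".")
    match = _GEOMETRY.search(architecture)
    if match is None:
        raise ValueError(f"no patch/size suffix in {architecture!r}")
    return ViTName(
        architecture=architecture,
        pretrained_tag=tag or None,  # None means "use timm's default tag"
        patch_size=int(match.group(1)),
        image_size=int(match.group(2)),
    )


if __name__ == "__main__":
    print(parse_vit_name("vit_large_patch16_224.augreg_in21k_ft_in1k"))
```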

Training

Core Fine-Tuning and Checkpoint Loop

Source: scripts/train_vit.py. Each epoch runs a training pass and a validation pass, steps the warmup + cosine learning-rate schedule, saves the latest checkpoint, and keeps a separate copy of the checkpoint with the best validation PR AUC.

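import torch
from torch.optim import AdamW

optimizer = AdamW(
    model.parameters(),
    lr=cfg.train.lr,
    weight_decay=cfg.train.weight_decay,
)
# warmup + cosine schedule, stepped once per epoch below
scheduler = _build_scheduler(optimizer, cfg)

# start_epoch and best_pr_auc are initialized earlier in the script (not shown)
for epoch in range(start_epoch, cfg.train.epochs + 1):
    train_metrics = train_one_epoch(...)
    val_metrics, val_preds = evaluate_loader(...)
    scheduler.step()

    # overwrite the latest checkpoint every epoch
    ckpt_payload = _build_checkpoint_payload(...)
    torch.save(ckpt_payload, run_dir / "checkpoints" / "last.pt")

    # save best.pt on the first evaluation (best_pr_auc still -inf) or when PR AUC improves
    if best_pr_auc == float("-inf") or val_metrics["pr_auc"] > best_pr_auc:
        best_pr_auc = float(val_metrics["pr_auc"])
        torch.save(ckpt_payload, run_dir / "checkpoints" / "best.pt")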
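`_build_scheduler` itself is not shown on this page. A minimal sketch of one common per-epoch warmup + cosine multiplier follows: a linear ramp up to the base learning rate, then a half-cosine decay toward a floor. The parameters `warmup_epochs`, `total_epochs`, and `min_factor` are assumptions, not confirmed config keys, and the exact curve used by the project may differ.

```python
import math


def warmup_cosine_factor(
    epoch: int, warmup_epochs: int, total_epochs: int, min_factor: float = 0.0
) -> float:
    """Multiplier on the base LR for a 0-based epoch index (hypothetical schedule)."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # linear ramp: 1/warmup, 2/warmup, ..., 1.0
        return (epoch + 1) / warmup_epochs
    decay_epochs = max(1, total_epochs - warmup_epochs)
    progress = min(1.0, (epoch - warmup_epochs) / decay_epochs)
    # half-cosine from 1.0 (progress 0) down to 0.0 (progress 1)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_factor + (1.0 - min_factor) * cosine


if __name__ == "__main__":
    factors = [warmup_cosine_factor(e, warmup_epochs=2, total_epochs=10) for e in range(10)]
    print([round(f, 3) for f in factors])
```

In PyTorch such a function could be wrapped as `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda e: warmup_cosine_factor(e, ...))`; because the loop calls `scheduler.step()` once per epoch, the lambda receives an epoch count.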

Ensemble

How Ensemble Predictions Are Combined

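Source: scripts/train_ensemble_member_cnn.py. For HAM10000, the aggregate prediction for each test sample is the mean malignancy probability across fold members; the code also records the across-member standard deviation and the number of distinct members that contributed.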

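# all_test_preds_df: test predictions from every fold member, stacked row-wise;
# collapse them to one row per (dataset, sample_id)
test_aggregate_df = (
    all_test_preds_df
    .groupby(["dataset", "sample_id"], as_index=False)
    .agg(
        label=("label", "first"),  # ground-truth label, same on every member row
        prob_malignant=("prob_malignant", "mean"),  # ensemble probability
        prob_malignant_std=("prob_malignant", "std"),  # member spread; NaN if only one member
        member_predictions=("member_fold", "nunique"),  # distinct folds contributing
    )
)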
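To make the aggregation semantics explicit, here is a dependency-free sketch of what the groupby computes within a single dataset, run on toy inputs rather than project results. It mirrors two pandas details: `std` is the sample standard deviation (ddof=1), so a sample predicted by only one member gets NaN, and `nunique` counts distinct folds rather than rows. `aggregate_members` and its row layout are hypothetical.

```python
import math
import statistics
from collections import defaultdict


def aggregate_members(rows):
    """rows: iterable of (sample_id, member_fold, label, prob_malignant).

    Returns {sample_id: (label, mean_prob, std_prob, n_members)}.
    """
    grouped = defaultdict(list)
    for sample_id, fold, label, prob in rows:
        grouped[sample_id].append((fold, label, prob))

    out = {}
    for sample_id, preds in grouped.items():
        probs = [p for _, _, p in preds]
        label = preds[0][1]  # label from the first row, as with pandas "first"
        # sample std (ddof=1); NaN for a single prediction, as pandas returns
        std = statistics.stdev(probs) if len(probs) > 1 else math.nan
        n_members = len({fold for fold, _, _ in preds})
        out[sample_id] = (label, statistics.fmean(probs), std, n_members)
    return out


if __name__ == "__main__":
    toy = [
        ("img_a", 0, 1, 0.2), ("img_a", 1, 1, 0.4), ("img_a", 2, 1, 0.6),
        ("img_b", 0, 0, 0.1),
    ]
    for sid, agg in aggregate_members(toy).items():
        print(sid, agg)
```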