ImageFolder + DataLoader
```python
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader

tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
ds = ImageFolder("data/train", transform=tf)
loader = DataLoader(ds, batch_size=32, shuffle=True, num_workers=4)
```
Directory layout: one subfolder per class name; `ds.classes` lists the labels (sorted alphabetically), and `ds.class_to_idx` maps each name to its integer index.
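That class discovery is just "sorted subdirectory names become consecutive indices"; a stdlib-only sketch of the step (the `find_classes` helper here is hypothetical, mimicking ImageFolder's behavior, not the library code itself):

```python
import os
import tempfile

def find_classes(root):
    # Mimic ImageFolder: each subdirectory is a class, sorted
    # alphabetically, mapped to consecutive integer indices.
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx

# Build a toy layout with two class folders.
root = tempfile.mkdtemp()
for name in ["dog", "cat"]:
    os.makedirs(os.path.join(root, name))

classes, class_to_idx = find_classes(root)
print(classes)       # ['cat', 'dog'] -- sorted, so indices are stable
print(class_to_idx)  # {'cat': 0, 'dog': 1}
```

Because the mapping is sorted, adding a new class folder can shift every index; keep that in mind when reusing a trained head.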
Models and Weights
```python
import torch
from torchvision.models import resnet50, ResNet50_Weights

weights = ResNet50_Weights.IMAGENET1K_V2
model = resnet50(weights=weights).eval()
preprocess = weights.transforms()

# x: batch from preprocess(PIL) or equivalent
with torch.no_grad():
    logits = model(x)
```
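To turn logits into readable predictions you typically softmax and take the top-k against `weights.meta["categories"]`. A dependency-free sketch of that post-processing (plain lists stand in for tensors; `topk` is a hypothetical helper, not a torchvision function):

```python
import math

def softmax(logits):
    # Numerically stable softmax for one example's logits.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def topk(probs, categories, k=3):
    # Pair probabilities with class names, keep the k largest.
    return sorted(zip(probs, categories), reverse=True)[:k]

# Toy 4-class logits standing in for model(x)[0].
logits = [1.0, 3.0, 0.5, 2.0]
categories = ["cat", "dog", "fox", "owl"]
probs = softmax(logits)
print(topk(probs, categories, k=2))  # 'dog' first, 'owl' second
```

With real torchvision weights you would use `weights.meta["categories"]` in place of the toy list.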
Transforms v2 (brief)
The newer APIs under `torchvision.transforms.v2` accept tensors or PIL images and transform bounding boxes and masks consistently alongside the image. Prefer them for detection/segmentation training when your torchvision version includes them.
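The key idea behind v2 is that one transform call updates the image and its targets together. A library-free sketch of that coupling for a horizontal flip (the `hflip_with_boxes` helper is hypothetical and illustrates the geometry, not the v2 API):

```python
def hflip_with_boxes(width, boxes):
    # Flipping an image of the given width must also mirror each
    # xyxy box: new_x = width - old_x, with x1/x2 swapped so the
    # box stays well-formed (x1 < x2).
    flipped = []
    for x1, y1, x2, y2 in boxes:
        flipped.append((width - x2, y1, width - x1, y2))
    return flipped

print(hflip_with_boxes(100, [(10, 5, 30, 25)]))  # [(70, 5, 90, 25)]
```

Doing this by hand for every augmentation is exactly the bookkeeping v2 removes.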
Takeaways
- Always match preprocessing to the weights you load.
- Use `torchvision.ops` for NMS, RoI Align, etc., in detection heads.
- Pin `torch` and `torchvision` versions that are tested together.
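The NMS mentioned above (what `torchvision.ops.nms` implements on tensors) greedily keeps the highest-scoring box and suppresses boxes that overlap a kept one too much. A minimal pure-Python sketch of the algorithm, assuming xyxy boxes:

```python
def iou(a, b):
    # Intersection-over-union of two xyxy boxes.
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area = lambda r: (r[2] - r[0]) * (r[3] - r[1])
    union = area(a) + area(b) - inter
    return inter / union if union else 0.0

def nms(boxes, scores, iou_thresh=0.5):
    # Greedy NMS: visit boxes by descending score; keep a box only
    # if it does not overlap an already-kept box above the threshold.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_thresh for j in keep):
            keep.append(i)
    return keep

boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (50, 50, 60, 60)]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores))  # [0, 2]: box 1 overlaps box 0 and is dropped
```

In practice use the library op: it runs on GPU tensors and is heavily optimized.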
Quick FAQ
- Out of GPU memory? Use mixed precision (`torch.cuda.amp`), gradient accumulation, or smaller models.
- Variable-size images or detection targets? Pass a custom `collate_fn` to `DataLoader` to pad variable-size images or merge dict targets for detection.
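For detection, the usual custom `collate_fn` simply keeps samples as lists instead of stacking them into one tensor, since images (and their target dicts) differ in shape. A torch-free sketch of the pattern, with plain strings and dicts standing in for tensors:

```python
def detection_collate(batch):
    # Each sample is (image, target_dict). Images vary in size, so
    # return lists rather than stacking; the model iterates over them.
    images, targets = zip(*batch)
    return list(images), list(targets)

batch = [
    ("img_a", {"boxes": [(0, 0, 5, 5)], "labels": [1]}),
    ("img_b", {"boxes": [], "labels": []}),
]
images, targets = detection_collate(batch)
print(len(images), targets[0]["labels"])  # 2 [1]
```

You would pass it as `DataLoader(ds, batch_size=..., collate_fn=detection_collate)`.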