Tutorial: Multiclass Classification
End-to-end walkthrough: classify inputs into N > 2 categories using cross_entropy loss.
What you will build
A model that maps tabular features to one of N discrete classes. KynML detects loss = cross_entropy and switches to the multiclass codegen path: targets become torch.long integers, the final layer emits raw logits (nn.CrossEntropyLoss applies log_softmax internally), and accuracy is computed via argmax. A log_softmax output layer paired with nll loss is the equivalent manual form, covered in section 3.
Prerequisites
pip install kynml
1. Data format
Multiclass labels must be integer class indices starting from 0. Column type in the CSV must be castable to int64. Example — three-class iris-style dataset:
sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,0
4.9,3.0,1.4,0.2,0
7.0,3.2,4.7,1.4,1
6.3,3.3,6.0,2.5,2
Labels: 0, 1, 2 — not strings, not one-hot vectors.
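If your raw labels are strings, they need to be remapped to integer indices before the CSV is handed to KynML. The following is a minimal sketch using only the standard library; the function name encode_labels and the sample species names are hypothetical and not part of KynML.

```python
def encode_labels(raw_labels):
    """Map arbitrary label values to contiguous integer indices 0..N-1."""
    # Sort the unique values so the mapping is deterministic across runs.
    classes = sorted(set(raw_labels))
    index_of = {label: i for i, label in enumerate(classes)}
    return [index_of[label] for label in raw_labels], index_of


species = ["setosa", "setosa", "versicolor", "virginica"]
encoded, mapping = encode_labels(species)
print(encoded)   # [0, 0, 1, 2]
print(mapping)   # {'setosa': 0, 'versicolor': 1, 'virginica': 2}
```

Keep the mapping around: you will need it to turn predicted class indices back into readable names at inference time.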
2. Write the spec
Save as iris_classifier.kyn:
dataset IrisData:
    source = csv("data/iris.csv")
    target = "species"
    split = 0.8
    normalize = true
model IrisNet:
    input 4
    dense 32 relu
    dense 16 relu
    dense 3 linear
train:
    model = IrisNet
    data = IrisData
    loss = cross_entropy
    optimizer = adam(lr=0.001)
    epochs = 50
    batch = 16
    device = auto
evaluate:
    metrics = [accuracy]
export:
    format = torch
    path = "models/iris_classifier.pt"
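To see how the layer sizes in IrisNet chain together, it can help to count the trainable parameters. This sketch assumes each dense layer maps to a fully connected layer with a bias term, as nn.Linear does by default; that is an assumption about the generated code, and the helper name dense_param_count is hypothetical.

```python
def dense_param_count(layer_sizes):
    """Weights plus biases for a stack of fully connected layers."""
    total = 0
    for fan_in, fan_out in zip(layer_sizes, layer_sizes[1:]):
        total += fan_in * fan_out + fan_out  # weight matrix + bias vector
    return total


# IrisNet: input 4 -> dense 32 -> dense 16 -> dense 3
print(dense_param_count([4, 32, 16, 3]))  # 739
```

The first entry must match the number of feature columns and the last entry the number of classes; everything in between is a free design choice.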
Critical details
dense 3 linear — the output size must equal the number of classes. activation = linear means no activation function is appended; nn.CrossEntropyLoss expects raw logits, not probabilities.
loss = cross_entropy — triggers the multiclass codegen path. Internally this maps to nn.CrossEntropyLoss().
Do not use sigmoid or softmax on the output layer with cross_entropy — CrossEntropyLoss applies log_softmax internally. Using log_softmax output and nll loss is the manual equivalent.
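The relationship between cross-entropy, log-softmax and NLL can be checked by hand. The sketch below is plain Python rather than PyTorch, with illustrative logits; it shows that cross-entropy on raw logits equals NLL on log-softmax outputs, and what happens to the loss when softmax is mistakenly applied before the loss.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def nll(log_probs, target):
    """Negative log-likelihood of the target class."""
    return -log_probs[target]


def cross_entropy(logits, target):
    """Cross-entropy on raw logits: log-softmax followed by NLL."""
    return nll(log_softmax(logits), target)


logits = [2.0, 1.0, 0.1]  # illustrative values
target = 0

loss_from_logits = cross_entropy(logits, target)

# Mistake: apply softmax first, then feed the probabilities in as "logits".
probs = [math.exp(lp) for lp in log_softmax(logits)]
loss_double_softmax = cross_entropy(probs, target)

print(f"correct:        {loss_from_logits:.4f}")
print(f"double softmax: {loss_double_softmax:.4f}")
```

In this example the double-softmax loss is noticeably larger even though the prediction is correct. Probabilities are confined to the range 0 to 1, so treating them as logits limits how confident the model can ever appear to the loss.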
3. Alternative: log_softmax + nll
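If you want the model itself to output log-probabilities (for example, so the exported network produces them without a separate softmax step at inference time), use this pattern instead: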
model IrisNet:
    input 4
    dense 32 relu
    dense 16 relu
    dense 3 log_softmax
train:
    ...
    loss = nll
log_softmax → nn.LogSoftmax(dim=1). nll → nn.NLLLoss(). The outputs are log-probabilities; apply torch.exp() downstream to get probabilities.
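The conversion from log-probabilities to probabilities is a single exponentiation. A minimal sketch in plain Python, using hypothetical log-probabilities for one sample rather than real model output:

```python
import math

# Hypothetical log-probabilities for one sample, in the form a log_softmax
# output layer would emit them (illustrative values only).
log_probs = [math.log(0.7), math.log(0.2), math.log(0.1)]

# exp() undoes the log and recovers a probability distribution.
probs = [math.exp(lp) for lp in log_probs]

# The predicted class is the same either way because exp() is monotonic.
predicted_class = max(range(len(log_probs)), key=log_probs.__getitem__)

print([round(p, 3) for p in probs])  # [0.7, 0.2, 0.1]
print(predicted_class)               # 0
```

With tensors the same step is torch.exp(output), applied to the whole batch at once.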
4. Larger model with regularisation
dataset FlowerData:
    source = csv("data/flowers.csv")
    target = "label"
    split = 0.8
    normalize = true
    num_workers = 2
    pin_memory = true
model FlowerClassifier:
    input 12
    dense 256 gelu
    batchnorm
    dropout 0.3
    dense 128 gelu
    batchnorm
    dropout 0.2
    dense 64 relu
    dense 5 linear
train:
    model = FlowerClassifier
    data = FlowerData
    loss = cross_entropy
    optimizer = adamw(lr=0.001, weight_decay=0.01)
    epochs = 100
    batch = 64
    device = auto
    scheduler = onecycle(max_lr=0.01)
    early_stop = early_stop(patience=15)
    checkpoint = checkpoint(every_n=10, path="checkpoints/flower.pt", async_save=true)
    precision = fp16
evaluate:
    metrics = [accuracy]
export:
    format = onnx
    path = "models/flower_classifier.onnx"
    input_shape = [1, 12]
    opset = 17
gelu: nn.GELU(). A common choice for deeper networks and transformer-derived architectures.
onecycle(max_lr=0.01): OneCycleLR, a one-cycle policy that ramps the learning rate up to max_lr and then anneals it. The scheduler's epochs parameter is populated at codegen time from the epochs field. PyTorch's OneCycleLR also needs steps_per_epoch (or total_steps); check the compiled script to see how it is set.
precision = fp16: automatic mixed precision (AMP) training. See Speed Guide for GPU-only caveats.
ONNX export: requires input_shape. Here [1, 12] means batch size 1 and 12 features. See Export Formats.
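To get a feel for the shape of a one-cycle schedule, the learning rate can be computed by hand. The sketch below is a plain-Python approximation that follows what are, to our knowledge, the torch.optim.lr_scheduler.OneCycleLR defaults (pct_start=0.3, div_factor=25, final_div_factor=1e4, cosine annealing). It is not the code KynML generates, and the 10 batches per epoch used here is a hypothetical figure.

```python
import math


def _cos_anneal(start, end, pct):
    """Cosine interpolation from start (pct=0) to end (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1.0)


def one_cycle_lr(step, total_steps, max_lr,
                 pct_start=0.3, div_factor=25.0, final_div_factor=1e4):
    """Approximate learning rate at a given optimizer step (0-based)."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1
    if step <= warmup_end:
        # Phase 1: ramp up from initial_lr to max_lr.
        return _cos_anneal(initial_lr, max_lr, step / warmup_end)
    # Phase 2: anneal from max_lr down to min_lr.
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return _cos_anneal(max_lr, min_lr, pct)


# Assumed numbers: 100 epochs as in the spec, and a hypothetical
# 10 batches per epoch (this depends on dataset size and batch = 64).
total_steps = 100 * 10
for step in (0, 150, 299, 650, 999):
    print(step, f"{one_cycle_lr(step, total_steps, max_lr=0.01):.2e}")
```

Note that the schedule advances per optimizer step, not per epoch, which is why the number of batches per epoch matters. The rate starts at max_lr / 25, peaks at max_lr about 30% of the way through, and ends several orders of magnitude lower.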
5. Run it
.venv/bin/python -m kynml.cli validate iris_classifier.kyn
.venv/bin/python -m kynml.cli train iris_classifier.kyn
Expected output:
Epoch 1/50 - loss: 1.0923
Epoch 2/50 - loss: 0.9871
...
Epoch 50/50 - loss: 0.1034
accuracy: 0.9667
Saved model to /path/to/models/iris_classifier.pt
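A quick sanity check for the first epoch: a model that assigns equal probability to all N classes has a cross-entropy of ln(N). The sketch below computes that baseline; the helper name chance_level_loss is hypothetical.

```python
import math


def chance_level_loss(n_classes):
    """Cross-entropy of a uniform prediction over n_classes."""
    return -math.log(1.0 / n_classes)


print(f"{chance_level_loss(3):.4f}")  # 1.0986
```

For three classes the baseline is about 1.0986, so a first-epoch loss near 1.09, as in the sample log, is consistent with a freshly initialised network. Your exact numbers will differ from run to run.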
6. What the generated PyTorch looks like
KynML detects cross_entropy and switches to the multiclass dataset loader, which casts targets to torch.long rather than float32:
IS_MULTICLASS = True
N_CLASSES = 3 # inferred from last dense layer units
def load_dataset() -> tuple[DataLoader, DataLoader]:
    df = pd.read_csv(DATASET_PATH)
    features = df.drop(columns=[TARGET_COLUMN])
    target = df[TARGET_COLUMN]
    numeric_features = pd.get_dummies(features, drop_first=False)
    x = numeric_features.astype("float32").to_numpy()
    y = target.astype("int64").to_numpy()  # <-- int64, not float32
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, train_size=0.8, shuffle=True, random_state=42,
    )
    ...
    train_dataset = TensorDataset(
        torch.tensor(x_train, dtype=torch.float32),
        torch.tensor(y_train, dtype=torch.long),  # <-- long, not float32
    )
Metric values use argmax on the raw logits:
def _metric_values(predictions: np.ndarray, targets: np.ndarray) -> dict[str, float]:
    values: dict[str, float] = {}
    pred_classes = np.argmax(predictions, axis=-1)  # logits -> class index
    targets_flat = targets.flatten().astype(np.int64)
    values["accuracy"] = float(np.mean(pred_classes == targets_flat))
    values["mae"] = float(np.mean(np.abs(pred_classes - targets_flat)))
    values["mse"] = float(np.mean((pred_classes - targets_flat) ** 2))
    values["rmse"] = float(np.sqrt(values["mse"]))
    return values
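The accuracy calculation can be reproduced by hand. The sketch below uses plain Python and illustrative logits rather than NumPy; ties resolve to the first maximum, as np.argmax does.

```python
def argmax(row):
    """Index of the largest value; the first one wins on ties."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits, targets):
    """Fraction of rows whose argmax equals the integer target."""
    predicted = [argmax(row) for row in logits]
    return sum(p == t for p, t in zip(predicted, targets)) / len(targets)


# Illustrative logits for four samples and three classes.
logits = [
    [2.1, 0.3, -1.0],   # predicts 0
    [0.2, 1.7, 0.1],    # predicts 1
    [-0.5, 0.4, 0.9],   # predicts 2
    [1.2, 1.1, -0.3],   # predicts 0, but the true class is 1
]
targets = [0, 1, 2, 1]

print(accuracy(logits, targets))  # 0.75
```

The mae, mse and rmse values in the generated function are computed on class indices, so they treat the classes as ordered numbers. They are only meaningful when the labels are ordinal; for unordered categories, rely on accuracy.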
7. Inspect the compiled script before training
.venv/bin/python -m kynml.cli compile iris_classifier.kyn --out generated/iris.py
# then read it
less generated/iris.py
This is the exact code that will run. No hidden magic — edit it freely if you need custom behaviour beyond what the spec exposes.
8. Next steps
- Export for inference: Export Formats
- Optimise throughput: Speed Guide
- Use a HuggingFace dataset as the source: Datasets and Connectors
Troubleshooting
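RuntimeError: Expected target size (N, C), got torch.Size([N]): a target-shape error like this typically means loss = cross_entropy was combined with one-hot encoded targets. Depending on the PyTorch version, the message may instead read "0D or 1D target tensor expected, multi-target not supported". KynML expects a single integer label column, not one-hot vectors.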
output size mismatch — the input N value in the model block must match the number of features in the CSV after pd.get_dummies. Print numeric_features.columns.tolist() in the generated script to debug.
Class imbalance — KynML does not yet expose weight on CrossEntropyLoss. As a workaround, compile the script and add weight=torch.tensor([...]) to the nn.CrossEntropyLoss() call in the generated file.
IndexError: index out of bounds — your target column contains class indices that exceed N_CLASSES - 1. Check that labels run 0 .. N-1 with no gaps.
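Two of the issues above, out-of-range labels and class imbalance, can be checked before training. The sketch below is standard-library Python; the helper names check_labels and inverse_frequency_weights are hypothetical, and inverse-frequency weighting is one common heuristic rather than something KynML prescribes.

```python
from collections import Counter


def check_labels(labels, n_classes):
    """Raise ValueError if any label falls outside 0 .. n_classes - 1."""
    bad = sorted({y for y in labels if not 0 <= y < n_classes})
    if bad:
        raise ValueError(f"labels outside 0..{n_classes - 1}: {bad}")


def inverse_frequency_weights(labels, n_classes):
    """Per-class weights n_samples / (n_classes * count)."""
    counts = Counter(labels)
    missing = [c for c in range(n_classes) if counts[c] == 0]
    if missing:
        raise ValueError(f"classes with no samples: {missing}")
    n = len(labels)
    return [n / (n_classes * counts[c]) for c in range(n_classes)]


# Illustrative, imbalanced label column: six 0s, three 1s, one 2.
labels = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]
check_labels(labels, n_classes=3)
weights = inverse_frequency_weights(labels, n_classes=3)
print([round(w, 3) for w in weights])  # [0.556, 1.111, 3.333]
```

The resulting list is what you would paste into weight=torch.tensor([...]) in the compiled script. Rare classes receive larger weights, so their errors contribute more to the loss.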