
Shape Inference

The KynML compiler includes a typed intermediate representation (IR) with a shape-inference pass that runs automatically as part of compile_to_ir. It catches dimension mismatches, invalid loss/output-unit combinations, and activation footguns before the generated Python script ever runs.


How It Works

Compilation follows this pipeline:

parse_text / parse_file
    → apply_composition   (import resolution + param substitution)
    → validate_program    (AST-level semantic checks)
    → lower_program       (AST → IRModule, inferred=False)
    → run_passes          (→ infer_module, sets inferred=True)

compile_to_ir in kynml/pipeline.py is the canonical entry point. Given a parsed program, it runs the four stages that follow parsing. validate_program alone does not run shape inference. To get shape errors, call compile_to_ir, or call infer_module directly on an IRModule.

The inference pass (kynml/ir/infer.py) threads a TypeShape through the ops in each graph, left to right:

| Op | What inference fills in | Output shape |
|---|---|---|
| InputOp(size=N) | out_type = [?,N]:float32 | [?,N] |
| LinearOp(out_features=K) | in_features from the preceding feature_dim; out_type = [?,K] | [?,K] |
| DropoutOp | in_type = out_type = preceding shape | unchanged |
| BatchNorm1dOp | num_features from the preceding feature_dim (or the explicit value) | unchanged |

The batch axis (?) is always dynamic (represented as None in TypeShape.dims).
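The table can be read as one left-to-right loop that carries the current shape from op to op. The sketch below models that loop with hypothetical stand-in dataclasses: TypeShape, InputOp, LinearOp, DropoutOp, BatchNorm1dOp and the function infer_ops are illustrative only, not the real classes in kynml/ir. It assumes each op's fields are filled in place, and it leaves out the batchnorm error checks described under What It Catches.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union


# Hypothetical stand-ins for the IR types in the table; KynML's real classes
# carry more fields (activation, etc.) and may behave differently.
@dataclass(frozen=True)
class TypeShape:
    dims: Tuple[Optional[int], ...]  # None marks the dynamic batch axis
    dtype: str = "float32"

    @property
    def feature_dim(self) -> Optional[int]:
        # Last dimension: the width LinearOp and BatchNorm1dOp consume.
        return self.dims[-1]

    def __str__(self) -> str:
        inner = ",".join("?" if d is None else str(d) for d in self.dims)
        return f"[{inner}]:{self.dtype}"


@dataclass
class InputOp:
    size: int
    out_type: Optional[TypeShape] = None


@dataclass
class LinearOp:
    out_features: int
    in_features: Optional[int] = None
    out_type: Optional[TypeShape] = None


@dataclass
class DropoutOp:
    in_type: Optional[TypeShape] = None
    out_type: Optional[TypeShape] = None


@dataclass
class BatchNorm1dOp:
    num_features: Optional[int] = None  # None means "infer from the preceding op"
    out_type: Optional[TypeShape] = None


Op = Union[InputOp, LinearOp, DropoutOp, BatchNorm1dOp]


def infer_ops(ops: List[Op], model: str = "M") -> Optional[TypeShape]:
    """Thread a TypeShape through ops left to right, filling fields in place."""
    shape: Optional[TypeShape] = None  # nothing has set a shape yet
    for op in ops:
        if isinstance(op, InputOp):
            shape = TypeShape((None, op.size))
        elif isinstance(op, LinearOp):
            if shape is None:
                raise ValueError(
                    f"[model {model}]: dense/linear layer has no preceding shape; "
                    "cannot infer in_features. Place an input layer first."
                )
            op.in_features = shape.feature_dim
            shape = TypeShape((None, op.out_features))
        elif isinstance(op, DropoutOp):
            op.in_type = shape  # dropout passes the shape through unchanged
        elif isinstance(op, BatchNorm1dOp):
            if op.num_features is None and shape is not None:
                op.num_features = shape.feature_dim
            # Shape is unchanged; batchnorm error checks are omitted in this sketch.
        op.out_type = shape
    return shape


ops = [InputOp(10), LinearOp(64), DropoutOp(), LinearOp(1)]
print(infer_ops(ops))      # [?,1]:float32
print(ops[1].in_features)  # 10
```

Keeping the running shape in a single variable is what makes the "preceding feature_dim" rule cheap: every op only ever looks one step back.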

After graph inference, the pass reconciles the model's final output shape against the train block's loss function.


What It Catches

Missing input layer

A LinearOp with no preceding op that sets a shape (such as an InputOp) raises:

<source> [model M]: dense/linear layer has no preceding shape;
cannot infer in_features. Place an input layer first.

batchnorm with no inferable size

A BatchNorm1dOp placed before any InputOp or LinearOp, and with no explicit size, raises:

<source> [model M]: batchnorm has no feature dim;
place it after an input or dense layer, or specify explicit features.

batchnorm explicit size mismatch

An explicit batchnorm 16 placed after dense 8 relu raises:

<source> [model M]: batchnorm explicit features 16 != incoming feature dim 8
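Both batchnorm rules come down to comparing an optional explicit size with an optional incoming width. The function below is a hypothetical helper, not part of KynML; it assumes the incoming feature dim is passed as an int, or None when no op has set a shape yet. It also assumes an explicit size with no preceding shape is accepted, since the first error above is described only for the no-explicit-size case.

```python
from typing import Optional


def resolve_batchnorm_features(
    incoming: Optional[int], explicit: Optional[int], model: str = "M"
) -> int:
    """Hypothetical helper mirroring the two batchnorm checks above.

    incoming: feature dim of the preceding op, or None if no op has set a shape.
    explicit: size written in the source (16 in `batchnorm 16`), or None if omitted.
    """
    if incoming is None and explicit is None:
        raise ValueError(
            f"[model {model}]: batchnorm has no feature dim; "
            "place it after an input or dense layer, or specify explicit features."
        )
    if explicit is not None and incoming is not None and explicit != incoming:
        raise ValueError(
            f"[model {model}]: batchnorm explicit features {explicit} "
            f"!= incoming feature dim {incoming}"
        )
    # Prefer the explicit size when given; otherwise inherit the incoming width.
    return explicit if explicit is not None else incoming


print(resolve_batchnorm_features(8, None))  # 8
print(resolve_batchnorm_features(8, 8))     # 8
```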

bce with wrong output unit count

bce expects exactly 1 output unit. A final dense 4 sigmoid with loss = bce raises:

<source> [train]: loss 'bce' expects a final dense layer with 1 output
unit (got 4). Adjust the last dense layer.

cross_entropy / nll with fewer than 2 output units

Both multiclass losses require at least 2 output units:

<source> [train]: loss 'cross_entropy' requires a final dense layer with
≥2 output units (got 1).
Add 'dense N linear' as the final layer where N is the class count.
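The two unit-count rules can be sketched as one check on the width of the final dense layer. check_output_units is a hypothetical name, and the messages are modelled on the ones above rather than taken from KynML's source; regression losses (mse, mae, l1, huber) are left unconstrained here because this section lists no unit rule for them.

```python
MULTICLASS_LOSSES = {"cross_entropy", "nll"}


def check_output_units(loss: str, units: int) -> None:
    """Hypothetical check of the final dense layer's width against the loss."""
    if loss == "bce" and units != 1:
        raise ValueError(
            f"[train]: loss 'bce' expects a final dense layer with 1 output "
            f"unit (got {units}). Adjust the last dense layer."
        )
    if loss in MULTICLASS_LOSSES and units < 2:
        raise ValueError(
            f"[train]: loss '{loss}' requires a final dense layer with "
            f"≥2 output units (got {units}). "
            "Add 'dense N linear' as the final layer where N is the class count."
        )
    # Any other loss/width combination passes in this sketch.


check_output_units("bce", 1)            # passes silently
check_output_units("cross_entropy", 3)  # passes silently
```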

Warnings (Not Errors)

The following are emitted as warnings on IRModule.warnings — they do not raise and do not block codegen. They surface as severity="warning" entries via kynml.lsp.diagnostics.diagnose.

cross_entropy + softmax

cross_entropy applies log-softmax internally, so a final softmax activation means softmax is effectively applied twice:

<source> [train]: cross_entropy applies log-softmax internally;
a final 'softmax' activation double-applies it.
Consider using 'linear' activation with cross_entropy.

cross_entropy + log_softmax

<source> [train]: cross_entropy with 'log_softmax' is a footgun
(double log-softmax). Use 'nll' loss with log_softmax,
or use cross_entropy with 'linear' activation.
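Because these are warnings, a natural way to model them is a function that returns messages instead of raising. The sketch below is hypothetical (activation_warnings is not a KynML function); it only assumes what this section states: the check looks at the loss and the final activation, and its output is collected rather than blocking codegen.

```python
from typing import List


def activation_warnings(loss: str, final_activation: str) -> List[str]:
    """Hypothetical sketch: collect warnings for a loss/final-activation pair.

    Messages are returned, not raised, so callers can record them and carry on.
    """
    found: List[str] = []
    if loss == "cross_entropy" and final_activation == "softmax":
        found.append(
            "[train]: cross_entropy applies log-softmax internally; "
            "a final 'softmax' activation double-applies it. "
            "Consider using 'linear' activation with cross_entropy."
        )
    elif loss == "cross_entropy" and final_activation == "log_softmax":
        found.append(
            "[train]: cross_entropy with 'log_softmax' is a footgun "
            "(double log-softmax). Use 'nll' loss with log_softmax, "
            "or use cross_entropy with 'linear' activation."
        )
    return found


for message in activation_warnings("cross_entropy", "softmax"):
    print("warning:", message)
```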

Loss / Target Type Reconciliation

After shape inference, n_classes and target_type are written onto IRTrain:

| Loss family | n_classes | target_type (dtype) | Notes |
|---|---|---|---|
| cross_entropy, nll | final output feature dim | int64 | target tensor is torch.long |
| bce | | float32 | shape [?,1] |
| mse, mae, l1, huber | | float32 | shape [?,K], where K = output dim |

The PyTorch backend reads n_classes and target_type directly from the IR. It does not re-derive them from the AST.
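The reconciliation table is a small lookup keyed on the loss family. The sketch below is hypothetical (reconcile_targets is not KynML's API) and returns (n_classes, target dtype); representing the blank n_classes cells as None is an assumption about how "not set" is stored.

```python
from typing import Optional, Tuple

MULTICLASS_LOSSES = {"cross_entropy", "nll"}
FLOAT_TARGET_LOSSES = {"bce", "mse", "mae", "l1", "huber"}


def reconcile_targets(loss: str, out_dim: int) -> Tuple[Optional[int], str]:
    """Hypothetical sketch of the table above: return (n_classes, target dtype)."""
    if loss in MULTICLASS_LOSSES:
        # Class count comes from the final output feature dim; targets are
        # class indices, hence int64 (torch.long on the PyTorch side).
        return out_dim, "int64"
    if loss in FLOAT_TARGET_LOSSES:
        # n_classes is left unset (None here, by assumption) for these losses.
        return None, "float32"
    raise ValueError(f"unknown loss: {loss!r}")


print(reconcile_targets("cross_entropy", 3))  # (3, 'int64')
print(reconcile_targets("mse", 1))            # (None, 'float32')
```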


Validate vs Compile Timing

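| Check | When it fires | Entry points |
|---|---|---|
| Loss keyword valid | validate_program | kynml validate, kynml compile, kynml train |
| Epochs/batch positive | validate_program | kynml validate, kynml compile, kynml train |
| in_features auto-fill | infer_module | compile_to_ir, kynml compile, kynml train |
| bce unit count | infer_module | compile_to_ir, kynml compile, kynml train |
| cross_entropy/nll unit count | infer_module | compile_to_ir, kynml compile, kynml train |
| batchnorm feature mismatch | infer_module | compile_to_ir, kynml compile, kynml train |
| Activation footgun warnings | infer_module | compile_to_ir, kynml compile, kynml train (warnings only) |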

kynml validate runs only validate_program — shape checks are not run. Use kynml compile or compile_to_ir to get shape errors.


Working with the IR Directly

from kynml.parser import parse_text
from kynml.semantic import validate_program
from kynml.ir.builder import lower_program
from kynml.ir.infer import infer_module

src = """
dataset D:
    source = csv("data/x.csv")
    target = "y"

model M:
    input 10
    dense 64 relu
    dense 1 linear

train:
    model = M
    data  = D
    loss  = mse
    optimizer = adam(lr=0.001)
    epochs = 10
    batch  = 32
"""

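prog = parse_text(src)
# This manual path omits apply_composition; src uses no imports or params.
validate_program(prog)           # AST-level checks only; no shape inference
module = lower_program(prog)     # AST -> IRModule, inferred=False
inferred = infer_module(module)  # shapes threaded through each graph; inferred=True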

graph = inferred.graph("M")
print(graph.input_type)   # [?,10]:float32
print(graph.output_type)  # [?,1]:float32

for op in graph.ops:
    print(op)
# InputOp(size=10, out_type=TypeShape((None, 10), ...))
# LinearOp(in_features=10, out_features=64, activation='relu', ...)
# LinearOp(in_features=64, out_features=1,  activation='linear', ...)

Or use compile_to_ir, which runs every post-parse stage in a single call:

from kynml.parser import parse_text
from kynml.pipeline import compile_to_ir

module = compile_to_ir(parse_text(src))
assert module.inferred

TypeShape Format

TypeShape is printed as [dims]:dtype. The batch axis is ? (dynamic):

[?,10]:float32   # input layer with 10 features
[?,64]:float32   # after dense 64
[?,3]:float32    # after dense 3 (3-class output)

feature_dim returns the last (static) dimension — the width that LinearOp and BatchNorm1dOp care about.
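When checking printed shapes in tests or logs, it can help to turn the [dims]:dtype form back into structured values. parse_typeshape below is a hypothetical helper, not part of KynML; it assumes the printed form contains only ? or integer dims separated by commas, with no spaces around the colon.

```python
import re
from typing import Optional, Tuple

# Matches the printed form "[dims]:dtype", e.g. "[?,10]:float32".
_TYPESHAPE_RE = re.compile(r"^\[([^\]]+)\]:(\w+)$")


def parse_typeshape(text: str) -> Tuple[Tuple[Optional[int], ...], str]:
    """Hypothetical helper: turn a printed TypeShape back into (dims, dtype).

    '?' becomes None (the dynamic batch axis); every other dim must be an int.
    """
    match = _TYPESHAPE_RE.match(text.strip())
    if match is None:
        raise ValueError(f"not a TypeShape string: {text!r}")
    dims = tuple(
        None if part.strip() == "?" else int(part)
        for part in match.group(1).split(",")
    )
    return dims, match.group(2)


dims, dtype = parse_typeshape("[?,64]:float32")
print(dims, dtype)  # (None, 64) float32
print(dims[-1])     # 64, i.e. the feature_dim
```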


See Also