plots
pytorch_lattice.plots
Plotting functions for PyTorch Lattice calibrated models using matplotlib.
calibrator(model, feature_name)
Plots the calibrator for the given feature and calibrated model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
model | Union[CalibratedLinear, CalibratedLattice] | The calibrated model for which to plot calibrators. | required |
feature_name | str | The name of the feature for which to plot the calibrator. | required |
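To illustrate what a calibrator plot shows, here is a minimal sketch. It assumes a piecewise-linear numerical calibrator described by sorted input keypoints and one learned output value per keypoint. The helpers `calibrate` and `plot_calibrator` and the sample keypoints are hypothetical. They are not the pytorch_lattice implementation; with the library itself you would call `calibrator(model, feature_name)` on a trained model.

```python
from bisect import bisect_right
from typing import Sequence


def calibrate(x: float, input_keypoints: Sequence[float], output_values: Sequence[float]) -> float:
    """Hypothetical piecewise-linear calibrator.

    Assumes input_keypoints is sorted in increasing order with one output value
    per keypoint. Inputs outside the keypoint range are clamped to the end values.
    """
    if len(input_keypoints) != len(output_values) or len(input_keypoints) < 2:
        raise ValueError("need at least two keypoints with matching output values")
    if x <= input_keypoints[0]:
        return output_values[0]
    if x >= input_keypoints[-1]:
        return output_values[-1]
    # Index of the segment [input_keypoints[i], input_keypoints[i + 1]) containing x.
    i = bisect_right(input_keypoints, x) - 1
    x0, x1 = input_keypoints[i], input_keypoints[i + 1]
    y0, y1 = output_values[i], output_values[i + 1]
    t = (x - x0) / (x1 - x0)
    return y0 + t * (y1 - y0)


def plot_calibrator(input_keypoints: Sequence[float], output_values: Sequence[float], feature_name: str):
    """Hypothetical plot: draw the calibrator curve for one feature."""
    # Imported lazily so calibrate() can be used without matplotlib installed.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(5, 3))
    ax = fig.subplots()
    # Between the first and last keypoint, straight segments joining the keypoints
    # trace exactly the curve that calibrate() evaluates.
    ax.plot(input_keypoints, output_values, marker="o")
    ax.set_xlabel(feature_name)
    ax.set_ylabel("calibrated output")
    ax.set_title(f"Calibrator for {feature_name}")
    return fig
```

For example, `plot_calibrator([0.0, 10.0, 20.0], [0.0, 1.0, 0.5], "age").savefig("age.png")` writes the curve for a made-up `age` feature. Building the figure with `matplotlib.figure.Figure` instead of `pyplot` avoids depending on a GUI backend.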
linear_coefficients(model)
Plots the coefficients for the linear layer of a calibrated linear model.
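The following is a rough sketch of a coefficient plot of this kind. It assumes the linear layer's weights have already been read out as plain floats, one per feature, in feature order. Extracting them from a `CalibratedLinear` model is not shown. The helpers `sorted_coefficients` and `plot_linear_coefficients` and the feature names are hypothetical and do not describe how pytorch_lattice draws its own plot.

```python
from typing import List, Sequence, Tuple


def sorted_coefficients(feature_names: Sequence[str], coefficients: Sequence[float]) -> List[Tuple[str, float]]:
    """Pair each feature with its coefficient, largest absolute value first."""
    if len(feature_names) != len(coefficients):
        raise ValueError("one coefficient per feature is required")
    # sorted() is stable, so features with equal magnitude keep their input order.
    return sorted(zip(feature_names, coefficients), key=lambda pair: abs(pair[1]), reverse=True)


def plot_linear_coefficients(feature_names: Sequence[str], coefficients: Sequence[float]):
    """Hypothetical bar chart of linear-layer coefficients."""
    # Imported lazily so sorted_coefficients() works without matplotlib installed.
    from matplotlib.figure import Figure

    pairs = sorted_coefficients(feature_names, coefficients)
    labels = [name for name, _ in pairs]
    values = [value for _, value in pairs]

    fig = Figure(figsize=(6, 3))
    ax = fig.subplots()
    ax.bar(labels, values)
    # A zero line makes positive and negative contributions easy to tell apart.
    ax.axhline(0.0, color="gray", linewidth=0.8)
    ax.set_ylabel("coefficient")
    ax.set_title("Linear layer coefficients")
    return fig
```

Ordering bars by absolute coefficient is a presentational choice made for this sketch only. The library's plot may keep the original feature order.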