Skip to content

models

pytorch_lattice.models.CalibratedLattice

Bases: ConstrainedModule

PyTorch Calibrated Lattice Model.

Creates a torch.nn.Module representing a calibrated lattice model, which will be constructed using the provided model configuration. Note that the model inputs should match the order in which they are defined in the feature_configs.

Attributes:

Name Type Description
All

__init__ arguments.

calibrators

A dictionary that maps feature names to their calibrators.

lattice

The Lattice layer of the model.

output_calibrator

The output NumericalCalibrator calibration layer. This will be None if no output calibration is desired.

Example:

feature_configs = [...]
calibrated_model = CalibratedLattice(feature_configs, ...)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(calibrated_model.parameters(recurse=True), lr=1e-1)

dataset = pyl.utils.data.Dataset(...)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
for epoch in range(100):
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = calibrated_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        calibrated_model.apply_constraints()
Source code in pytorch_lattice/models/calibrated_lattice.py
class CalibratedLattice(ConstrainedModule):
    """PyTorch Calibrated Lattice Model.

    Creates a `torch.nn.Module` representing a calibrated lattice model, which will be
    constructed using the provided model configuration. Note that the model inputs
    should match the order in which they are defined in the `feature_configs`.

    Attributes:
        All: `__init__` arguments.
        monotonicities: Per-feature monotonicities derived from `features`; they are
            passed to the `Lattice` layer. (`None` entries presumably mean "no
            monotonicity constraint" — confirm against `initialize_monotonicities`.)
        calibrators: A dictionary that maps feature names to their calibrators.
        lattice: The `Lattice` layer of the model.
        output_calibrator: The output `NumericalCalibrator` calibration layer. This
            will be `None` if no output calibration is desired.

    Example:

    ```python
    feature_configs = [...]
    calibrated_model = CalibratedLattice(feature_configs, ...)

    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(calibrated_model.parameters(recurse=True), lr=1e-1)

    dataset = pyl.utils.data.Dataset(...)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    for epoch in range(100):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = calibrated_model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            calibrated_model.apply_constraints()
    ```
    """

    def __init__(
        self,
        features: list[Union[NumericalFeature, CategoricalFeature]],
        clip_inputs: bool = True,
        output_min: Optional[float] = None,
        output_max: Optional[float] = None,
        kernel_init: LatticeInit = LatticeInit.LINEAR,
        interpolation: Interpolation = Interpolation.HYPERCUBE,
        output_calibration_num_keypoints: Optional[int] = None,
    ) -> None:
        """Initializes an instance of `CalibratedLattice`.

        Args:
            features: A list of numerical and/or categorical feature configs.
            clip_inputs: Whether to restrict inputs to the bounds of lattice.
            output_min: The minimum output value for the model. If `None`, the minimum
                output value will be unbounded.
            output_max: The maximum output value for the model. If `None`, the maximum
                output value will be unbounded.
            kernel_init: the method of initializing kernel weights. If otherwise
                unspecified, will default to `LatticeInit.LINEAR`.
            interpolation: the method of interpolation in the lattice's forward pass.
                If otherwise unspecified, will default to `Interpolation.HYPERCUBE`.
            output_calibration_num_keypoints: The number of keypoints to use for the
                output calibrator. If `None`, no output calibration will be used.

        Raises:
            ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`.
                (Raised by the initialization helpers called here, not directly.)
        """
        super().__init__()

        self.features = features
        self.clip_inputs = clip_inputs
        self.output_min = output_min
        self.output_max = output_max
        self.kernel_init = kernel_init
        self.interpolation = interpolation
        self.output_calibration_num_keypoints = output_calibration_num_keypoints
        self.monotonicities = initialize_monotonicities(features)
        # Each feature calibrator outputs into its lattice dimension's input domain,
        # i.e. [0, lattice_size - 1].
        self.calibrators = initialize_feature_calibrators(
            features=features,
            output_min=0,
            output_max=[feature.lattice_size - 1 for feature in features],
        )

        self.lattice = Lattice(
            lattice_sizes=[feature.lattice_size for feature in features],
            monotonicities=self.monotonicities,
            clip_inputs=self.clip_inputs,
            output_min=self.output_min,
            output_max=self.output_max,
            interpolation=interpolation,
            kernel_init=kernel_init,
        )

        # The output calibrator is made monotonic whenever at least one feature has a
        # monotonicity constraint.
        self.output_calibrator = initialize_output_calibrator(
            output_calibration_num_keypoints=output_calibration_num_keypoints,
            monotonic=not all(m is None for m in self.monotonicities),
            output_min=output_min,
            output_max=output_max,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs an input through the network to produce a calibrated lattice output.

        Args:
            x: The input tensor of feature values of shape `(batch_size, num_features)`.

        Returns:
            torch.Tensor of shape `(batch_size, 1)` containing the model output result.
        """
        result = calibrate_and_stack(x, self.calibrators)
        result = self.lattice(result)
        if self.output_calibrator is not None:
            result = self.output_calibrator(result)

        return result

    @torch.no_grad()
    def apply_constraints(self) -> None:
        """Constrains the model into desired constraints specified by the config.

        Applies constraints to every feature calibrator, then to the lattice, then to
        the output calibrator if one is configured.
        """
        for calibrator in self.calibrators.values():
            calibrator.apply_constraints()
        self.lattice.apply_constraints()
        if self.output_calibrator is not None:
            self.output_calibrator.apply_constraints()

    @torch.no_grad()
    def assert_constraints(self, eps: float = 1e-6) -> dict[str, list[str]]:
        """Asserts all layers within model satisfied specified constraints.

        Calls `assert_constraints(eps)` on every feature calibrator, on the lattice
        layer, and on the output calibrator (if present), collecting the error
        messages each of them reports.

        Args:
            eps: the margin of error allowed

        Returns:
            A dict mapping a layer key to that layer's error messages. Keys are
            `"<feature_name>_calibrator"` for feature calibrators, `"lattice"` for the
            lattice layer, and `"output_calibrator"` for the output calibrator. Layers
            with no error messages are not present in the dictionary.
        """
        messages: dict[str, list[str]] = {}

        for name, calibrator in self.calibrators.items():
            calibrator_messages = calibrator.assert_constraints(eps)
            if calibrator_messages:
                messages[f"{name}_calibrator"] = calibrator_messages
        lattice_messages = self.lattice.assert_constraints(eps)
        if lattice_messages:
            messages["lattice"] = lattice_messages
        if self.output_calibrator is not None:
            output_calibrator_messages = self.output_calibrator.assert_constraints(eps)
            if output_calibrator_messages:
                messages["output_calibrator"] = output_calibrator_messages

        return messages

__init__(features, clip_inputs=True, output_min=None, output_max=None, kernel_init=LatticeInit.LINEAR, interpolation=Interpolation.HYPERCUBE, output_calibration_num_keypoints=None)

Initializes an instance of CalibratedLattice.

Parameters:

Name Type Description Default
features list[Union[NumericalFeature, CategoricalFeature]]

A list of numerical and/or categorical feature configs.

required
clip_inputs bool

Whether to restrict inputs to the bounds of lattice.

True
output_min Optional[float]

The minimum output value for the model. If None, the minimum output value will be unbounded.

None
output_max Optional[float]

The maximum output value for the model. If None, the maximum output value will be unbounded.

None
kernel_init LatticeInit

the method of initializing kernel weights. If otherwise unspecified, will default to LatticeInit.LINEAR.

LINEAR
interpolation Interpolation

the method of interpolation in the lattice's forward pass. If otherwise unspecified, will default to Interpolation.HYPERCUBE.

HYPERCUBE
output_calibration_num_keypoints Optional[int]

The number of keypoints to use for the output calibrator. If None, no output calibration will be used.

None

Raises:

Type Description
ValueError

If any feature configs are not NUMERICAL or CATEGORICAL.

Source code in pytorch_lattice/models/calibrated_lattice.py
def __init__(
    self,
    features: list[Union[NumericalFeature, CategoricalFeature]],
    clip_inputs: bool = True,
    output_min: Optional[float] = None,
    output_max: Optional[float] = None,
    kernel_init: LatticeInit = LatticeInit.LINEAR,
    interpolation: Interpolation = Interpolation.HYPERCUBE,
    output_calibration_num_keypoints: Optional[int] = None,
) -> None:
    """Builds the feature calibrators, the lattice, and the optional output calibrator.

    Args:
        features: A list of numerical and/or categorical feature configs.
        clip_inputs: Whether to restrict inputs to the bounds of lattice.
        output_min: The minimum output value for the model. If `None`, the minimum
            output value will be unbounded.
        output_max: The maximum output value for the model. If `None`, the maximum
            output value will be unbounded.
        kernel_init: How the lattice kernel weights are initialized; defaults to
            `LatticeInit.LINEAR`.
        interpolation: The interpolation scheme used in the lattice's forward pass;
            defaults to `Interpolation.HYPERCUBE`.
        output_calibration_num_keypoints: The number of keypoints to use for the
            output calibrator. If `None`, no output calibration will be used.

    Raises:
        ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`
            (documented contract; raised inside the helpers called here).
    """
    super().__init__()

    # Record every constructor argument on the instance.
    self.features = features
    self.clip_inputs = clip_inputs
    self.output_min, self.output_max = output_min, output_max
    self.kernel_init = kernel_init
    self.interpolation = interpolation
    self.output_calibration_num_keypoints = output_calibration_num_keypoints

    self.monotonicities = initialize_monotonicities(features)
    sizes = [feature.lattice_size for feature in features]

    # Each calibrator maps its feature into the lattice input domain [0, size - 1].
    self.calibrators = initialize_feature_calibrators(
        features=features,
        output_min=0,
        output_max=[size - 1 for size in sizes],
    )

    self.lattice = Lattice(
        lattice_sizes=sizes,
        monotonicities=self.monotonicities,
        clip_inputs=clip_inputs,
        output_min=output_min,
        output_max=output_max,
        interpolation=interpolation,
        kernel_init=kernel_init,
    )

    # Monotonic output calibration as soon as any feature is constrained.
    has_constrained_feature = any(m is not None for m in self.monotonicities)
    self.output_calibrator = initialize_output_calibrator(
        output_calibration_num_keypoints=output_calibration_num_keypoints,
        monotonic=has_constrained_feature,
        output_min=output_min,
        output_max=output_max,
    )

apply_constraints()

Constrains the model into desired constraints specified by the config.

Source code in pytorch_lattice/models/calibrated_lattice.py
@torch.no_grad()
def apply_constraints(self) -> None:
    """Projects each sub-layer back onto its configured constraints.

    The feature calibrators are handled first (in dictionary order), then the
    lattice, and finally the output calibrator when one is configured. Runs without
    gradient tracking.
    """
    for layer in [*self.calibrators.values(), self.lattice]:
        layer.apply_constraints()
    if self.output_calibrator:
        self.output_calibrator.apply_constraints()

assert_constraints(eps=1e-06)

Asserts all layers within model satisfied specified constraints.

Asserts that every feature calibrator, the lattice layer, and the output calibrator (if present) satisfy their constraints, within the margin of error eps. For the calibrators this covers monotonicity pairs and output bounds (categorical) and monotonicity and output bounds (numerical).

Parameters:

Name Type Description Default
eps float

the margin of error allowed

1e-06

Returns:

Type Description
dict[str, list[str]]

A dict mapping "<feature_name>_calibrator" (for each feature calibrator), "lattice", and "output_calibrator" to the error messages for that layer. Layers with no error messages are not present in the dictionary.

Source code in pytorch_lattice/models/calibrated_lattice.py
@torch.no_grad()
def assert_constraints(self, eps: float = 1e-6) -> dict[str, list[str]]:
    """Asserts all layers within model satisfied specified constraints.

    Calls `assert_constraints(eps)` on every feature calibrator, on the lattice
    layer, and on the output calibrator (if present), collecting the error messages
    each of them reports.

    Args:
        eps: the margin of error allowed

    Returns:
        A dict mapping a layer key to that layer's error messages. Keys are
        `"<feature_name>_calibrator"` for feature calibrators, `"lattice"` for the
        lattice layer, and `"output_calibrator"` for the output calibrator. Layers
        with no error messages are not present in the dictionary.
    """
    messages: dict[str, list[str]] = {}

    for name, calibrator in self.calibrators.items():
        calibrator_messages = calibrator.assert_constraints(eps)
        if calibrator_messages:
            messages[f"{name}_calibrator"] = calibrator_messages
    lattice_messages = self.lattice.assert_constraints(eps)
    if lattice_messages:
        messages["lattice"] = lattice_messages
    # Same `is not None` check as `forward` uses for the optional output calibrator.
    if self.output_calibrator is not None:
        output_calibrator_messages = self.output_calibrator.assert_constraints(eps)
        if output_calibrator_messages:
            messages["output_calibrator"] = output_calibrator_messages

    return messages

forward(x)

Runs an input through the network to produce a calibrated lattice output.

Parameters:

Name Type Description Default
x Tensor

The input tensor of feature values of shape (batch_size, num_features).

required

Returns:

Type Description
Tensor

torch.Tensor of shape (batch_size, 1) containing the model output result.

Source code in pytorch_lattice/models/calibrated_lattice.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Maps a batch of raw feature values to calibrated lattice outputs.

    The inputs are calibrated and stacked by `calibrate_and_stack` using the
    per-feature calibrators, evaluated by the lattice, and finally passed through the
    output calibrator when one is configured.

    Args:
        x: The input tensor of feature values of shape `(batch_size, num_features)`.

    Returns:
        torch.Tensor of shape `(batch_size, 1)` containing the model output result.
    """
    calibrated = calibrate_and_stack(x, self.calibrators)
    lattice_out = self.lattice(calibrated)
    if self.output_calibrator is None:
        return lattice_out
    return self.output_calibrator(lattice_out)

pytorch_lattice.models.CalibratedLinear

Bases: ConstrainedModule

PyTorch Calibrated Linear Model.

Creates a torch.nn.Module representing a calibrated linear model, which will be constructed using the provided model configuration. Note that the model inputs should match the order in which they are defined in the feature_configs.

Attributes:

Name Type Description
All

__init__ arguments.

calibrators

A dictionary that maps feature names to their calibrators.

linear

The Linear layer of the model.

output_calibrator

The output NumericalCalibrator calibration layer. This will be None if no output calibration is desired.

Example:

feature_configs = [...]
calibrated_model = pyl.models.CalibratedLinear(feature_configs, ...)

loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(calibrated_model.parameters(recurse=True), lr=1e-1)

dataset = pyl.utils.data.Dataset(...)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
for epoch in range(100):
    for inputs, labels in dataloader:
        optimizer.zero_grad()
        outputs = calibrated_model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        calibrated_model.apply_constraints()
Source code in pytorch_lattice/models/calibrated_linear.py
class CalibratedLinear(ConstrainedModule):
    """PyTorch Calibrated Linear Model.

    Creates a `torch.nn.Module` representing a calibrated linear model, which will be
    constructed using the provided model configuration. Note that the model inputs
    should match the order in which they are defined in the `feature_configs`.

    Attributes:
        All: `__init__` arguments.
        monotonicities: Per-feature monotonicities derived from `features`; they are
            passed to the `Linear` layer.
        calibrators: A dictionary that maps feature names to their calibrators.
        linear: The `Linear` layer of the model.
        output_calibrator: The output `NumericalCalibrator` calibration layer. This
            will be `None` if no output calibration is desired.

    Example:

    ```python
    feature_configs = [...]
    calibrated_model = pyl.models.CalibratedLinear(feature_configs, ...)

    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(calibrated_model.parameters(recurse=True), lr=1e-1)

    dataset = pyl.utils.data.Dataset(...)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
    for epoch in range(100):
        for inputs, labels in dataloader:
            optimizer.zero_grad()
            outputs = calibrated_model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            calibrated_model.apply_constraints()
    ```
    """

    def __init__(
        self,
        features: list[Union[NumericalFeature, CategoricalFeature]],
        output_min: Optional[float] = None,
        output_max: Optional[float] = None,
        use_bias: bool = True,
        output_calibration_num_keypoints: Optional[int] = None,
    ) -> None:
        """Initializes an instance of `CalibratedLinear`.

        Args:
            features: A list of numerical and/or categorical feature configs.
            output_min: The minimum output value for the model. If `None`, the minimum
                output value will be unbounded.
            output_max: The maximum output value for the model. If `None`, the maximum
                output value will be unbounded.
            use_bias: Whether to use a bias term for the linear combination. This flag
                is forwarded to `Linear` unchanged. When any of `output_min`,
                `output_max`, or `output_calibration_num_keypoints` is set, `Linear` is
                built with `weighted_average=True`, in which case the bias is
                documented as unused regardless of this flag (enforced inside `Linear`
                — confirm).
            output_calibration_num_keypoints: The number of keypoints to use for the
                output calibrator. If `None`, no output calibration will be used.

        Raises:
            ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`.
                (Raised by the initialization helpers called here, not directly.)
        """
        super().__init__()

        self.features = features
        self.output_min = output_min
        self.output_max = output_max
        self.use_bias = use_bias
        self.output_calibration_num_keypoints = output_calibration_num_keypoints
        self.monotonicities = initialize_monotonicities(features)
        # Feature calibrators share the model's output bounds.
        self.calibrators = initialize_feature_calibrators(
            features=features, output_min=output_min, output_max=output_max
        )

        # A weighted average is used whenever the output is bounded or calibrated.
        self.linear = Linear(
            input_dim=len(features),
            monotonicities=self.monotonicities,
            use_bias=use_bias,
            weighted_average=bool(
                output_min is not None
                or output_max is not None
                or output_calibration_num_keypoints
            ),
        )

        # The output calibrator is made monotonic whenever at least one feature has a
        # monotonicity constraint.
        self.output_calibrator = initialize_output_calibrator(
            output_calibration_num_keypoints=output_calibration_num_keypoints,
            monotonic=not all(m is None for m in self.monotonicities),
            output_min=output_min,
            output_max=output_max,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Runs an input through the network to produce a calibrated linear output.

        Args:
            x: The input tensor of feature values of shape `(batch_size, num_features)`.

        Returns:
            torch.Tensor of shape `(batch_size, 1)` containing the model output result.
        """
        result = calibrate_and_stack(x, self.calibrators)
        result = self.linear(result)
        if self.output_calibrator is not None:
            result = self.output_calibrator(result)

        return result

    @torch.no_grad()
    def apply_constraints(self) -> None:
        """Constrains the model into desired constraints specified by the config.

        Applies constraints to every feature calibrator, then to the linear layer,
        then to the output calibrator if one is configured.
        """
        for calibrator in self.calibrators.values():
            calibrator.apply_constraints()
        self.linear.apply_constraints()
        if self.output_calibrator is not None:
            self.output_calibrator.apply_constraints()

    @torch.no_grad()
    def assert_constraints(self, eps: float = 1e-6) -> dict[str, list[str]]:
        """Asserts all layers within model satisfied specified constraints.

        Calls `assert_constraints(eps)` on every feature calibrator, on the linear
        layer, and on the output calibrator (if present). Per the layers' own
        documentation, these check monotonicity pairs and output bounds for
        categorical calibrators, monotonicity and output bounds for numerical
        calibrators, and monotonicity and weights summing to 1 (if
        `weighted_average`) for the linear layer.

        Args:
            eps: the margin of error allowed

        Returns:
            A dict mapping a layer key to that layer's error messages. Keys are
            `"<feature_name>_calibrator"` for feature calibrators, `"linear"` for the
            linear layer, and `"output_calibrator"` for the output calibrator. Layers
            with no error messages are not present in the dictionary.
        """
        messages: dict[str, list[str]] = {}

        for name, calibrator in self.calibrators.items():
            calibrator_messages = calibrator.assert_constraints(eps)
            if calibrator_messages:
                messages[f"{name}_calibrator"] = calibrator_messages
        linear_messages = self.linear.assert_constraints(eps)
        if linear_messages:
            messages["linear"] = linear_messages
        if self.output_calibrator is not None:
            output_calibrator_messages = self.output_calibrator.assert_constraints(eps)
            if output_calibrator_messages:
                messages["output_calibrator"] = output_calibrator_messages

        return messages

__init__(features, output_min=None, output_max=None, use_bias=True, output_calibration_num_keypoints=None)

Initializes an instance of CalibratedLinear.

Parameters:

Name Type Description Default
features list[Union[NumericalFeature, CategoricalFeature]]

A list of numerical and/or categorical feature configs.

required
output_min Optional[float]

The minimum output value for the model. If None, the minimum output value will be unbounded.

None
output_max Optional[float]

The maximum output value for the model. If None, the maximum output value will be unbounded.

None
use_bias bool

Whether to use a bias term for the linear combination. If any of output_min, output_max, or output_calibration_num_keypoints are set, a bias term will not be used regardless of the setting here.

True
output_calibration_num_keypoints Optional[int]

The number of keypoints to use for the output calibrator. If None, no output calibration will be used.

None

Raises:

Type Description
ValueError

If any feature configs are not NUMERICAL or CATEGORICAL.

Source code in pytorch_lattice/models/calibrated_linear.py
def __init__(
    self,
    features: list[Union[NumericalFeature, CategoricalFeature]],
    output_min: Optional[float] = None,
    output_max: Optional[float] = None,
    use_bias: bool = True,
    output_calibration_num_keypoints: Optional[int] = None,
) -> None:
    """Builds the feature calibrators, the linear layer, and the optional output calibrator.

    Args:
        features: A list of numerical and/or categorical feature configs.
        output_min: The minimum output value for the model. If `None`, the minimum
            output value will be unbounded.
        output_max: The maximum output value for the model. If `None`, the maximum
            output value will be unbounded.
        use_bias: Whether to use a bias term for the linear combination. Forwarded to
            `Linear` unchanged; when `output_min`, `output_max`, or
            `output_calibration_num_keypoints` is set, `Linear` runs in
            `weighted_average` mode and the bias is documented as unused (enforced
            inside `Linear` — confirm).
        output_calibration_num_keypoints: The number of keypoints to use for the
            output calibrator. If `None`, no output calibration will be used.

    Raises:
        ValueError: If any feature configs are not `NUMERICAL` or `CATEGORICAL`
            (documented contract; raised inside the helpers called here).
    """
    super().__init__()

    # Record every constructor argument on the instance.
    self.features = features
    self.output_min, self.output_max = output_min, output_max
    self.use_bias = use_bias
    self.output_calibration_num_keypoints = output_calibration_num_keypoints

    self.monotonicities = initialize_monotonicities(features)
    self.calibrators = initialize_feature_calibrators(
        features=features, output_min=output_min, output_max=output_max
    )

    # Bounded or output-calibrated models combine features as a weighted average.
    is_bounded = (
        output_min is not None
        or output_max is not None
        or bool(output_calibration_num_keypoints)
    )
    self.linear = Linear(
        input_dim=len(features),
        monotonicities=self.monotonicities,
        use_bias=use_bias,
        weighted_average=is_bounded,
    )

    # Monotonic output calibration as soon as any feature is constrained.
    has_constrained_feature = any(m is not None for m in self.monotonicities)
    self.output_calibrator = initialize_output_calibrator(
        output_calibration_num_keypoints=output_calibration_num_keypoints,
        monotonic=has_constrained_feature,
        output_min=output_min,
        output_max=output_max,
    )

apply_constraints()

Constrains the model into desired constraints specified by the config.

Source code in pytorch_lattice/models/calibrated_linear.py
@torch.no_grad()
def apply_constraints(self) -> None:
    """Projects each sub-layer back onto its configured constraints.

    The feature calibrators are handled first (in dictionary order), then the linear
    layer, and finally the output calibrator when one is configured. Runs without
    gradient tracking.
    """
    for layer in [*self.calibrators.values(), self.linear]:
        layer.apply_constraints()
    if self.output_calibrator:
        self.output_calibrator.apply_constraints()

assert_constraints(eps=1e-06)

Asserts all layers within model satisfied specified constraints.

Asserts monotonicity pairs and output bounds for categorical calibrators, monotonicity and output bounds for numerical calibrators, and monotonicity and weights summing to 1 if weighted_average for linear layer.

Parameters:

Name Type Description Default
eps float

the margin of error allowed

1e-06

Returns:

Type Description
Union[list[str], dict[str, list[str]]]

A dict mapping "<feature_name>_calibrator" (for each feature calibrator), "linear", and "output_calibrator" to the error messages for that layer. Layers with no error messages are not present in the dictionary. The value returned is always a dict.

Source code in pytorch_lattice/models/calibrated_linear.py
@torch.no_grad()
def assert_constraints(self, eps: float = 1e-6) -> dict[str, list[str]]:
    """Asserts all layers within model satisfied specified constraints.

    Calls `assert_constraints(eps)` on every feature calibrator, on the linear layer,
    and on the output calibrator (if present). Per the layers' own documentation,
    these check monotonicity pairs and output bounds for categorical calibrators,
    monotonicity and output bounds for numerical calibrators, and monotonicity and
    weights summing to 1 (if `weighted_average`) for the linear layer.

    Args:
        eps: the margin of error allowed

    Returns:
        A dict mapping a layer key to that layer's error messages. Keys are
        `"<feature_name>_calibrator"` for feature calibrators, `"linear"` for the
        linear layer, and `"output_calibrator"` for the output calibrator. Layers with
        no error messages are not present in the dictionary.
    """
    messages: dict[str, list[str]] = {}

    for name, calibrator in self.calibrators.items():
        calibrator_messages = calibrator.assert_constraints(eps)
        if calibrator_messages:
            messages[f"{name}_calibrator"] = calibrator_messages
    linear_messages = self.linear.assert_constraints(eps)
    if linear_messages:
        messages["linear"] = linear_messages
    # Same `is not None` check as `forward` uses for the optional output calibrator.
    if self.output_calibrator is not None:
        output_calibrator_messages = self.output_calibrator.assert_constraints(eps)
        if output_calibrator_messages:
            messages["output_calibrator"] = output_calibrator_messages

    return messages

forward(x)

Runs an input through the network to produce a calibrated linear output.

Parameters:

Name Type Description Default
x Tensor

The input tensor of feature values of shape (batch_size, num_features).

required

Returns:

Type Description
Tensor

torch.Tensor of shape (batch_size, 1) containing the model output result.

Source code in pytorch_lattice/models/calibrated_linear.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Maps a batch of raw feature values to calibrated linear outputs.

    The inputs are calibrated and stacked by `calibrate_and_stack` using the
    per-feature calibrators, combined by the linear layer, and finally passed through
    the output calibrator when one is configured.

    Args:
        x: The input tensor of feature values of shape `(batch_size, num_features)`.

    Returns:
        torch.Tensor of shape `(batch_size, 1)` containing the model output result.
    """
    calibrated = calibrate_and_stack(x, self.calibrators)
    combined = self.linear(calibrated)
    if self.output_calibrator is None:
        return combined
    return self.output_calibrator(combined)

pytorch_lattice.models.features.CategoricalFeature

Feature configuration for categorical features.

Attributes:

Name Type Description
All

__init__ arguments.

category_indices

A dictionary mapping string categories to their index.

monotonicity_index_pairs

A conversion of monotonicity_pairs from categories to category indices. This is an empty list if monotonicity_pairs is not provided.

Source code in pytorch_lattice/models/features.py
class CategoricalFeature:
    """Feature configuration for categorical features.

    Attributes:
        All: `__init__` arguments.
        category_indices: A dictionary mapping each category to its index, i.e. its
            position in `categories`. If a category is listed more than once, the
            index of its last occurrence is kept.
        monotonicity_index_pairs: A conversion of `monotonicity_pairs` from categories
            to category indices. This is an empty list when `monotonicity_pairs` is
            not provided.
    """

    def __init__(
        self,
        feature_name: str,
        categories: Union[list[int], list[str]],
        missing_input_value: Optional[float] = None,
        monotonicity_pairs: Optional[
            Union[list[tuple[int, int]], list[tuple[str, str]]]
        ] = None,
        lattice_size: int = 2,
    ) -> None:
        """Initializes a `CategoricalFeature` instance.

        Args:
            feature_name: The name of the feature. This should match the header for the
                column in the dataset representing this feature.
            categories: The categories that should be used for this feature. Any
                categories not contained will be considered missing or unknown. If you
                expect such categories, presumably `missing_input_value` should be set
                so the calibrator can learn an output for them (confirm against the
                calibrator implementation).
            missing_input_value: If provided, this feature's calibrator will learn to
                map all instances of this missing input value to a learned output value.
            monotonicity_pairs: List of pairs of categories `(category_a, category_b)`
                indicating that the calibrator output for `category_b` should be greater
                than or equal to that of `category_a`. Each category must appear in
                `categories`.
            lattice_size: The default number of keypoints outputted by the calibrator.
                Only used within `Lattice` models.

        Raises:
            KeyError: If a category in `monotonicity_pairs` is not in `categories`.
        """
        self.feature_name = feature_name
        self.categories = categories
        self.missing_input_value = missing_input_value
        self.monotonicity_pairs = monotonicity_pairs
        self.lattice_size = lattice_size

        self.category_indices = {category: i for i, category in enumerate(categories)}

        index_pairs = []
        for category_a, category_b in monotonicity_pairs or []:
            unknown = [
                category
                for category in (category_a, category_b)
                if category not in self.category_indices
            ]
            if unknown:
                # Same exception type as a plain dict lookup, but with context.
                raise KeyError(
                    f"monotonicity_pairs references categories {unknown!r} that are "
                    f"not in `categories` for feature {feature_name!r}"
                )
            index_pairs.append(
                (self.category_indices[category_a], self.category_indices[category_b])
            )
        self.monotonicity_index_pairs = index_pairs

__init__(feature_name, categories, missing_input_value=None, monotonicity_pairs=None, lattice_size=2)

Initializes a CategoricalFeature instance.

Parameters:

Name Type Description Default
feature_name str

The name of the feature. This should match the header for the column in the dataset representing this feature.

required
categories Union[list[int], list[str]]

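The categories that should be used for this feature. Any categories not contained will be considered missing or unknown. If you expect to have such missing categories, you presumably need to set missing_input_value so the calibrator can learn an output for them. The original sentence was truncated here; confirm the intended guidance.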

required
missing_input_value Optional[float]

If provided, this feature's calibrator will learn to map all instances of this missing input value to a learned output value.

None
monotonicity_pairs Optional[list[tuple[str, str]]]

List of pairs of categories (category_a, category_b) indicating that the calibrator output for category_b should be greater than or equal to that of category_a.

None
lattice_size int

The default number of keypoints outputted by the calibrator. Only used within Lattice models.

2
Source code in pytorch_lattice/models/features.py
def __init__(
    self,
    feature_name: str,
    categories: Union[list[int], list[str]],
    missing_input_value: Optional[float] = None,
    monotonicity_pairs: Optional[
        list[tuple[Union[int, str], Union[int, str]]]
    ] = None,
    lattice_size: int = 2,
) -> None:
    """Initializes a categorical feature configuration instance.

    Args:
        feature_name: The name of the feature. This should match the header for the
            column in the dataset representing this feature.
        categories: The categories that should be used for this feature. Any
            categories not contained will be considered missing or unknown. If you
            expect to have such missing categories, consider setting
            `missing_input_value`.
        missing_input_value: If provided, this feature's calibrator will learn to
            map all instances of this missing input value to a learned output value.
        monotonicity_pairs: List of pairs of categories `(category_a, category_b)`
            indicating that the calibrator output for `category_b` should be greater
            than or equal to that of `category_a`. Every category referenced here
            must be one of `categories`.
        lattice_size: The default number of keypoints outputted by the calibrator.
            Only used within `Lattice` models.

    Raises:
        KeyError: If a category in `monotonicity_pairs` is not in `categories`.
    """
    self.feature_name = feature_name
    self.categories = categories
    self.missing_input_value = missing_input_value
    self.monotonicity_pairs = monotonicity_pairs
    self.lattice_size = lattice_size

    # Position of each category within `categories`; monotonicity pairs are
    # re-expressed in this index space below.
    self.category_indices = {category: i for i, category in enumerate(categories)}

    def _index_of(category):
        # Re-raise with context: a bare KeyError('x') does not say which
        # feature or argument was at fault.
        try:
            return self.category_indices[category]
        except KeyError:
            raise KeyError(
                f"Monotonicity pair references category {category!r}, which is "
                f"not one of the categories of feature {feature_name!r}."
            ) from None

    self.monotonicity_index_pairs = [
        (_index_of(a), _index_of(b)) for a, b in monotonicity_pairs or []
    ]

pytorch_lattice.models.features.NumericalFeature

Feature configuration for numerical features.

Attributes:

Name Type Description
All

__init__ arguments.

input_keypoints

The input keypoints used for this feature's calibrator. These keypoints will be initialized using the given data under the desired input_keypoints_init scheme.

Source code in pytorch_lattice/models/features.py
class NumericalFeature:
    """Feature configuration for numerical features.

    Attributes:
        All: `__init__` arguments.
        input_keypoints: The input keypoints used for this feature's calibrator. These
            keypoints will be initialized using the given `data` under the desired
            `input_keypoints_init` scheme.
    """

    def __init__(
        self,
        feature_name: str,
        data: np.ndarray,
        num_keypoints: int = 5,
        input_keypoints_init: InputKeypointsInit = InputKeypointsInit.QUANTILES,
        missing_input_value: Optional[float] = None,
        monotonicity: Optional[Monotonicity] = None,
        projection_iterations: int = 8,
        lattice_size: int = 2,
    ) -> None:
        """Initializes a `NumericalFeature` instance.

        Args:
            feature_name: The name of the feature. This should match the header for the
                column in the dataset representing this feature.
            data: Numpy array of float-valued data used for calculating keypoint inputs
                and initializing keypoint outputs.
            num_keypoints: The number of keypoints used by the underlying piece-wise
                linear function of a NumericalCalibrator. There will be
                `num_keypoints - 1` total segments.
            input_keypoints_init: The scheme to use for initializing the input
                keypoints. See `InputKeypointsInit` for more details.
            missing_input_value: If provided, this feature's calibrator will learn to
                map all instances of this missing input value to a learned output value.
            monotonicity: Monotonicity constraint for this feature, if any.
            projection_iterations: Number of times to run Dykstra's projection
                algorithm when applying constraints.
            lattice_size: The default number of keypoints outputted by the
                calibrator. Only used within `Lattice` models.

        Raises:
            ValueError: If `data` is empty.
            ValueError: If `data` contains NaN values.
            ValueError: If `input_keypoints_init` is invalid.
        """
        self.feature_name = feature_name

        # Without observed values there is nothing to place keypoints on:
        # UNIFORM would index into an empty array and QUANTILES would silently
        # yield zero keypoints. Fail early with a clear message instead.
        if np.size(data) == 0:
            raise ValueError("Data is empty.")

        if np.isnan(data).any():
            raise ValueError("Data contains NaN values.")

        self.data = data
        self.num_keypoints = num_keypoints
        self.input_keypoints_init = input_keypoints_init
        self.missing_input_value = missing_input_value
        self.monotonicity = monotonicity
        self.projection_iterations = projection_iterations
        self.lattice_size = lattice_size

        sorted_unique_values = np.unique(data)

        if input_keypoints_init == InputKeypointsInit.QUANTILES:
            if sorted_unique_values.size < num_keypoints:
                logging.info(
                    "Observed fewer unique values for feature %s than %d desired "
                    "keypoints. Using the observed %d unique values as keypoints.",
                    feature_name,
                    num_keypoints,
                    sorted_unique_values.size,
                )
                self.input_keypoints = sorted_unique_values
            else:
                # Quantiles are taken over the unique values (not the raw data),
                # and method="nearest" makes every keypoint an observed value.
                quantiles = np.linspace(0.0, 1.0, num=num_keypoints)
                self.input_keypoints = np.quantile(
                    sorted_unique_values, quantiles, method="nearest"
                )
        elif input_keypoints_init == InputKeypointsInit.UNIFORM:
            # Evenly spaced between the observed minimum and maximum.
            self.input_keypoints = np.linspace(
                sorted_unique_values[0], sorted_unique_values[-1], num=num_keypoints
            )
        else:
            raise ValueError(f"Unknown input keypoints init: {input_keypoints_init}")

__init__(feature_name, data, num_keypoints=5, input_keypoints_init=InputKeypointsInit.QUANTILES, missing_input_value=None, monotonicity=None, projection_iterations=8, lattice_size=2)

Initializes a NumericalFeature instance.

Parameters:

Name Type Description Default
feature_name str

The name of the feature. This should match the header for the column in the dataset representing this feature.

required
data ndarray

Numpy array of float-valued data used for calculating keypoint inputs and initializing keypoint outputs.

required
num_keypoints int

The number of keypoints used by the underlying piece-wise linear function of a NumericalCalibrator. There will be num_keypoints - 1 total segments.

5
input_keypoints_init InputKeypointsInit

The scheme to use for initializing the input keypoints. See InputKeypointsInit for more details.

QUANTILES
missing_input_value Optional[float]

If provided, this feature's calibrator will learn to map all instances of this missing input value to a learned output value.

None
monotonicity Optional[Monotonicity]

Monotonicity constraint for this feature, if any.

None
projection_iterations int

Number of times to run Dykstra's projection algorithm when applying constraints.

8
lattice_size int

The default number of keypoints outputted by the calibrator. Only used within Lattice models.

2

Raises:

Type Description
ValueError

If data contains NaN values.

ValueError

If input_keypoints_init is invalid.

Source code in pytorch_lattice/models/features.py
def __init__(
    self,
    feature_name: str,
    data: np.ndarray,
    num_keypoints: int = 5,
    input_keypoints_init: InputKeypointsInit = InputKeypointsInit.QUANTILES,
    missing_input_value: Optional[float] = None,
    monotonicity: Optional[Monotonicity] = None,
    projection_iterations: int = 8,
    lattice_size: int = 2,
) -> None:
    """Initializes a `NumericalFeature` instance.

    Args:
        feature_name: The name of the feature. This should match the header for the
            column in the dataset representing this feature.
        data: Numpy array of float-valued data used for calculating keypoint inputs
            and initializing keypoint outputs.
        num_keypoints: The number of keypoints used by the underlying piece-wise
            linear function of a NumericalCalibrator. There will be
            `num_keypoints - 1` total segments.
        input_keypoints_init: The scheme to use for initializing the input
            keypoints. See `InputKeypointsInit` for more details.
        missing_input_value: If provided, this feature's calibrator will learn to
            map all instances of this missing input value to a learned output value.
        monotonicity: Monotonicity constraint for this feature, if any.
        projection_iterations: Number of times to run Dykstra's projection
            algorithm when applying constraints.
        lattice_size: The default number of keypoints outputted by the
            calibrator. Only used within `Lattice` models.

    Raises:
        ValueError: If `data` is empty.
        ValueError: If `data` contains NaN values.
        ValueError: If `input_keypoints_init` is invalid.
    """
    self.feature_name = feature_name

    # Without observed values there is nothing to place keypoints on:
    # UNIFORM would index into an empty array and QUANTILES would silently
    # yield zero keypoints. Fail early with a clear message instead.
    if np.size(data) == 0:
        raise ValueError("Data is empty.")

    if np.isnan(data).any():
        raise ValueError("Data contains NaN values.")

    self.data = data
    self.num_keypoints = num_keypoints
    self.input_keypoints_init = input_keypoints_init
    self.missing_input_value = missing_input_value
    self.monotonicity = monotonicity
    self.projection_iterations = projection_iterations
    self.lattice_size = lattice_size

    sorted_unique_values = np.unique(data)

    if input_keypoints_init == InputKeypointsInit.QUANTILES:
        if sorted_unique_values.size < num_keypoints:
            logging.info(
                "Observed fewer unique values for feature %s than %d desired "
                "keypoints. Using the observed %d unique values as keypoints.",
                feature_name,
                num_keypoints,
                sorted_unique_values.size,
            )
            self.input_keypoints = sorted_unique_values
        else:
            # Quantiles are taken over the unique values (not the raw data),
            # and method="nearest" makes every keypoint an observed value.
            quantiles = np.linspace(0.0, 1.0, num=num_keypoints)
            self.input_keypoints = np.quantile(
                sorted_unique_values, quantiles, method="nearest"
            )
    elif input_keypoints_init == InputKeypointsInit.UNIFORM:
        # Evenly spaced between the observed minimum and maximum.
        self.input_keypoints = np.linspace(
            sorted_unique_values[0], sorted_unique_values[-1], num=num_keypoints
        )
    else:
        raise ValueError(f"Unknown input keypoints init: {input_keypoints_init}")