Skip to content

constrained_module

pytorch_lattice.constrained_module.ConstrainedModule

Bases: Module

A base class for constrained implementations of a torch.nn.Module.

Source code in pytorch_lattice/constrained_module.py
class ConstrainedModule(torch.nn.Module):
    """Common base for `torch.nn.Module` implementations that carry constraints.

    Declares the two hooks a constrained module is expected to provide:
    `apply_constraints` and `assert_constraints`. Both base implementations
    only raise `NotImplementedError`.

    NOTE(review): `@abstractmethod` is enforced at instantiation time only
    when the class uses an `ABCMeta`-derived metaclass, which is not visible
    here; otherwise a missing override surfaces only when the method is
    called. Likewise, the `torch.no_grad()` decorator wraps these base
    functions only; an overriding method is a new function and must be
    decorated itself if it needs gradients disabled.
    """

    @torch.no_grad()
    @abstractmethod
    def apply_constraints(self) -> None:
        """Bring the module into compliance with its defined constraints.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    @torch.no_grad()
    @abstractmethod
    def assert_constraints(
        self,
        eps: float = 1e-6,
    ) -> Union[list[str], dict[str, list[str]]]:
        """Check that the module satisfies its specified constraints.

        Args:
            eps: Float parameter defaulting to `1e-6`; presumably a numerical
                tolerance used when checking constraints (confirm against
                concrete implementations).

        Returns:
            Either a list of strings or a dict mapping strings to lists of
            strings. The base class does not define their content; presumably
            they describe violated constraints (confirm against concrete
            implementations).

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

apply_constraints() abstractmethod

Applies defined constraints to the module.

Source code in pytorch_lattice/constrained_module.py
@torch.no_grad()
@abstractmethod
def apply_constraints(self) -> None:
    """Bring the module into compliance with its defined constraints.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError

assert_constraints(eps=1e-06) abstractmethod

Asserts that the module satisfies the specified constraints.

Source code in pytorch_lattice/constrained_module.py
@torch.no_grad()
@abstractmethod
def assert_constraints(
    self,
    eps: float = 1e-6,
) -> Union[list[str], dict[str, list[str]]]:
    """Check that the module satisfies its specified constraints.

    Args:
        eps: Float parameter defaulting to `1e-6`; presumably a numerical
            tolerance used when checking constraints (confirm against
            concrete implementations).

    Returns:
        Either a list of strings or a dict mapping strings to lists of
        strings. This base implementation does not define their content;
        presumably they describe violated constraints (confirm against
        concrete implementations).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError