
layers

pytorch_lattice.layers.CategoricalCalibrator

Bases: ConstrainedModule

A categorical calibrator.

This module takes an input of shape (batch_size, 1) and calibrates it by mapping a given category to its learned output value. The output will have the same shape as the input.

Attributes:

Name Type Description
All

__init__ arguments.

kernel

torch.nn.Parameter that stores the categorical mapping weights.

Example:

inputs = torch.tensor(...)  # shape: (batch_size, 1)
calibrator = CategoricalCalibrator(
    num_categories=5,
    missing_input_value=-1,
    output_min=0.0,
    output_max=1.0,
    monotonicity_pairs=[(0, 1), (1, 2)],
    kernel_init=CategoricalCalibratorInit.UNIFORM,
)
outputs = calibrator(inputs)
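
The mapping can also be exercised directly; a minimal runnable sketch (the category values and the `-1` missing-value sentinel below are illustrative):

```python
import torch

from pytorch_lattice.layers import CategoricalCalibrator

calibrator = CategoricalCalibrator(
    num_categories=5,
    missing_input_value=-1,
    output_min=0.0,
    output_max=1.0,
)
x = torch.tensor([[0], [3], [-1]])  # -1 is routed to the extra "missing" category
y = calibrator(x)                   # shape: (3, 1), one learned output value per row
```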

Source code in pytorch_lattice/layers/categorical_calibrator.py
class CategoricalCalibrator(ConstrainedModule):
    """A categorical calibrator.

    This module takes an input of shape `(batch_size, 1)` and calibrates it by mapping a
    given category to its learned output value. The output will have the same shape as
    the input.

    Attributes:
        All: `__init__` arguments.
        kernel: `torch.nn.Parameter` that stores the categorical mapping weights.

    Example:
    ```python
    inputs = torch.tensor(...)  # shape: (batch_size, 1)
    calibrator = CategoricalCalibrator(
        num_categories=5,
        missing_input_value=-1,
        output_min=0.0,
        output_max=1.0,
        monotonicity_pairs=[(0, 1), (1, 2)],
        kernel_init=CategoricalCalibratorInit.UNIFORM,
    )
    outputs = calibrator(inputs)
    ```
    """

    def __init__(
        self,
        num_categories: int,
        missing_input_value: Optional[float] = None,
        output_min: Optional[float] = None,
        output_max: Optional[float] = None,
        monotonicity_pairs: Optional[list[tuple[int, int]]] = None,
        kernel_init: CategoricalCalibratorInit = CategoricalCalibratorInit.UNIFORM,
    ) -> None:
        """Initializes an instance of `CategoricalCalibrator`.

        Args:
            num_categories: The number of known categories.
            missing_input_value: If provided, the calibrator will learn to map all
                instances of this missing input value to a learned output value just
                the same as it does for known categories. Note that `num_categories`
                will be one greater to include this missing category.
            output_min: Minimum output value. If `None`, the minimum output value will
                be unbounded.
            output_max: Maximum output value. If `None`, the maximum output value will
                be unbounded.
            monotonicity_pairs: List of pairs of indices `(i,j)` indicating that the
                calibrator output for index `j` should be greater than or equal to that
                of index `i`.
            kernel_init: Initialization scheme to use for the kernel.

        Raises:
            ValueError: If `monotonicity_pairs` is cyclic.
            ValueError: If `kernel_init` is invalid.
        """
        super().__init__()

        self.num_categories = (
            num_categories + 1 if missing_input_value is not None else num_categories
        )
        self.missing_input_value = missing_input_value
        self.output_min = output_min
        self.output_max = output_max
        self.monotonicity_pairs = monotonicity_pairs
        if monotonicity_pairs:
            self._monotonicity_graph = defaultdict(list)
            self._reverse_monotonicity_graph = defaultdict(list)
            for i, j in monotonicity_pairs:
                self._monotonicity_graph[i].append(j)
                self._reverse_monotonicity_graph[j].append(i)
            try:
                self._monotonically_sorted_indices = [
                    *TopologicalSorter(self._reverse_monotonicity_graph).static_order()
                ]
            except CycleError as exc:
                raise ValueError("monotonicity_pairs is cyclic") from exc
        self.kernel_init = kernel_init

        self.kernel = torch.nn.Parameter(torch.Tensor(self.num_categories, 1).double())
        if kernel_init == CategoricalCalibratorInit.CONSTANT:
            if output_min is not None and output_max is not None:
                init_value = (output_min + output_max) / 2
            elif output_min is not None:
                init_value = output_min
            elif output_max is not None:
                init_value = output_max
            else:
                init_value = 0.0
            torch.nn.init.constant_(self.kernel, init_value)
        elif kernel_init == CategoricalCalibratorInit.UNIFORM:
            if output_min is not None and output_max is not None:
                low, high = output_min, output_max
            elif output_min is None and output_max is not None:
                low, high = output_max - 0.05, output_max
            elif output_min is not None and output_max is None:
                low, high = output_min, output_min + 0.05
            else:
                low, high = -0.05, 0.05
            torch.nn.init.uniform_(self.kernel, low, high)
        else:
            raise ValueError(f"Unknown kernel init: {kernel_init}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calibrates categorical inputs through a learned mapping.

        Args:
            x: The input tensor of category indices of shape `(batch_size, 1)`.

        Returns:
            torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
        """
        if self.missing_input_value is not None:
            missing_category_tensor = torch.zeros_like(x) + (self.num_categories - 1)
            x = torch.where(x == self.missing_input_value, missing_category_tensor, x)
        # TODO: test if using torch.gather is faster than one-hot matmul.
        one_hot = torch.nn.functional.one_hot(
            torch.squeeze(x, -1).long(), num_classes=self.num_categories
        ).double()
        return torch.mm(one_hot, self.kernel)

    @torch.no_grad()
    def apply_constraints(self) -> None:
        """Projects kernel into desired constraints."""
        projected_kernel_data = self.kernel.data
        if self.monotonicity_pairs:
            projected_kernel_data = self._approximately_project_monotonicity_pairs(
                projected_kernel_data
            )
        if self.output_min is not None:
            projected_kernel_data = torch.maximum(
                projected_kernel_data, torch.tensor(self.output_min)
            )
        if self.output_max is not None:
            projected_kernel_data = torch.minimum(
                projected_kernel_data, torch.tensor(self.output_max)
            )
        self.kernel.data = projected_kernel_data

    @torch.no_grad()
    def assert_constraints(self, eps: float = 1e-6) -> list[str]:
        """Asserts that layer satisfies specified constraints.

        This checks that weights at the indexes of monotonicity pairs are in the correct
        order and that the output is within bounds.

        Args:
            eps: the margin of error allowed

        Returns:
            A list of messages describing violated constraints including violated
            monotonicity pairs. If no constraints are violated, the list will be empty.
        """
        weights = torch.squeeze(self.kernel.data)
        messages = []

        if self.output_max is not None and torch.max(weights) > self.output_max + eps:
            messages.append("Max weight greater than output_max.")
        if self.output_min is not None and torch.min(weights) < self.output_min - eps:
            messages.append("Min weight less than output_min.")

        if self.monotonicity_pairs:
            violation_indices = [
                (i, j)
                for (i, j) in self.monotonicity_pairs
                if weights[i] - weights[j] > eps
            ]
            if violation_indices:
                messages.append(f"Monotonicity violated at: {str(violation_indices)}.")

        return messages

    @torch.no_grad()
    def keypoints_inputs(self) -> torch.Tensor:
        """Returns a tensor of keypoint inputs (category indices)."""
        if self.missing_input_value is not None:
            return torch.cat(
                (
                    torch.arange(self.num_categories - 1),
                    torch.tensor([self.missing_input_value]),
                ),
                0,
            )
        return torch.arange(self.num_categories)

    @torch.no_grad()
    def keypoints_outputs(self) -> torch.Tensor:
        """Returns a tensor of keypoint outputs."""
        return torch.squeeze(self.kernel.data, -1)

    ################################################################################
    ############################## PRIVATE METHODS #################################
    ################################################################################

    def _approximately_project_monotonicity_pairs(self, kernel_data) -> torch.Tensor:
        """Projects kernel such that the monotonicity pairs are satisfied.

        The kernel will be projected such that `kernel_data[i] <= kernel_data[j]`. This
        results in calibrated outputs that adhere to the desired constraints.

        Args:
            kernel_data: The tensor of shape `(self.num_categories, 1)` to be projected
                into the constraints specified by `self.monotonicity_pairs`.

        Returns:
            Projected kernel data. To prevent the kernel from drifting in one direction,
            the data returned is the average of the min/max and max/min projections.
        """
        projected_kernel_data = torch.unbind(kernel_data, 0)

        def project(data, monotonicity_graph, step, minimum):
            projected_data = list(data)
            sorted_indices = self._monotonically_sorted_indices
            if minimum:
                sorted_indices = sorted_indices[::-1]
            for i in sorted_indices:
                if i in monotonicity_graph:
                    projection = projected_data[i]
                    for j in monotonicity_graph[i]:
                        if minimum:
                            projection = torch.minimum(projection, projected_data[j])
                        else:
                            projection = torch.maximum(projection, projected_data[j])
                        if step == 1.0:
                            projected_data[i] = projection
                        else:
                            projected_data[i] = (
                                step * projection + (1 - step) * projected_data[i]
                            )
            return projected_data

        projected_kernel_min_max = project(
            projected_kernel_data, self._monotonicity_graph, 0.5, minimum=True
        )
        projected_kernel_min_max = project(
            projected_kernel_min_max,
            self._reverse_monotonicity_graph,
            1.0,
            minimum=False,
        )
        projected_kernel_min_max = torch.stack(projected_kernel_min_max)

        projected_kernel_max_min = project(
            projected_kernel_data, self._reverse_monotonicity_graph, 0.5, minimum=False
        )
        projected_kernel_max_min = project(
            projected_kernel_max_min, self._monotonicity_graph, 1.0, minimum=True
        )
        projected_kernel_max_min = torch.stack(projected_kernel_max_min)

        return (projected_kernel_min_max + projected_kernel_max_min) / 2

__init__(num_categories, missing_input_value=None, output_min=None, output_max=None, monotonicity_pairs=None, kernel_init=CategoricalCalibratorInit.UNIFORM)

Initializes an instance of CategoricalCalibrator.

Parameters:

Name Type Description Default
num_categories int

The number of known categories.

required
missing_input_value Optional[float]

If provided, the calibrator will learn to map all instances of this missing input value to a learned output value just the same as it does for known categories. Note that num_categories will be one greater to include this missing category.

None
output_min Optional[float]

Minimum output value. If None, the minimum output value will be unbounded.

None
output_max Optional[float]

Maximum output value. If None, the maximum output value will be unbounded.

None
monotonicity_pairs Optional[list[tuple[int, int]]]

List of pairs of indices (i,j) indicating that the calibrator output for index j should be greater than or equal to that of index i.

None
kernel_init CategoricalCalibratorInit

Initialization scheme to use for the kernel.

UNIFORM

Raises:

Type Description
ValueError

If monotonicity_pairs is cyclic.

ValueError

If kernel_init is invalid.

Source code in pytorch_lattice/layers/categorical_calibrator.py
def __init__(
    self,
    num_categories: int,
    missing_input_value: Optional[float] = None,
    output_min: Optional[float] = None,
    output_max: Optional[float] = None,
    monotonicity_pairs: Optional[list[tuple[int, int]]] = None,
    kernel_init: CategoricalCalibratorInit = CategoricalCalibratorInit.UNIFORM,
) -> None:
    """Initializes an instance of `CategoricalCalibrator`.

    Args:
        num_categories: The number of known categories.
        missing_input_value: If provided, the calibrator will learn to map all
            instances of this missing input value to a learned output value just
            the same as it does for known categories. Note that `num_categories`
            will be one greater to include this missing category.
        output_min: Minimum output value. If `None`, the minimum output value will
            be unbounded.
        output_max: Maximum output value. If `None`, the maximum output value will
            be unbounded.
        monotonicity_pairs: List of pairs of indices `(i,j)` indicating that the
            calibrator output for index `j` should be greater than or equal to that
            of index `i`.
        kernel_init: Initialization scheme to use for the kernel.

    Raises:
        ValueError: If `monotonicity_pairs` is cyclic.
        ValueError: If `kernel_init` is invalid.
    """
    super().__init__()

    self.num_categories = (
        num_categories + 1 if missing_input_value is not None else num_categories
    )
    self.missing_input_value = missing_input_value
    self.output_min = output_min
    self.output_max = output_max
    self.monotonicity_pairs = monotonicity_pairs
    if monotonicity_pairs:
        self._monotonicity_graph = defaultdict(list)
        self._reverse_monotonicity_graph = defaultdict(list)
        for i, j in monotonicity_pairs:
            self._monotonicity_graph[i].append(j)
            self._reverse_monotonicity_graph[j].append(i)
        try:
            self._monotonically_sorted_indices = [
                *TopologicalSorter(self._reverse_monotonicity_graph).static_order()
            ]
        except CycleError as exc:
            raise ValueError("monotonicity_pairs is cyclic") from exc
    self.kernel_init = kernel_init

    self.kernel = torch.nn.Parameter(torch.Tensor(self.num_categories, 1).double())
    if kernel_init == CategoricalCalibratorInit.CONSTANT:
        if output_min is not None and output_max is not None:
            init_value = (output_min + output_max) / 2
        elif output_min is not None:
            init_value = output_min
        elif output_max is not None:
            init_value = output_max
        else:
            init_value = 0.0
        torch.nn.init.constant_(self.kernel, init_value)
    elif kernel_init == CategoricalCalibratorInit.UNIFORM:
        if output_min is not None and output_max is not None:
            low, high = output_min, output_max
        elif output_min is None and output_max is not None:
            low, high = output_max - 0.05, output_max
        elif output_min is not None and output_max is None:
            low, high = output_min, output_min + 0.05
        else:
            low, high = -0.05, 0.05
        torch.nn.init.uniform_(self.kernel, low, high)
    else:
        raise ValueError(f"Unknown kernel init: {kernel_init}")

apply_constraints()

Projects kernel into desired constraints.

Source code in pytorch_lattice/layers/categorical_calibrator.py
@torch.no_grad()
def apply_constraints(self) -> None:
    """Projects kernel into desired constraints."""
    projected_kernel_data = self.kernel.data
    if self.monotonicity_pairs:
        projected_kernel_data = self._approximately_project_monotonicity_pairs(
            projected_kernel_data
        )
    if self.output_min is not None:
        projected_kernel_data = torch.maximum(
            projected_kernel_data, torch.tensor(self.output_min)
        )
    if self.output_max is not None:
        projected_kernel_data = torch.minimum(
            projected_kernel_data, torch.tensor(self.output_max)
        )
    self.kernel.data = projected_kernel_data
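
A sketch of how this projection is typically used in training: call it after each optimizer step so the kernel stays inside the configured bounds and monotonicity constraints. The calibrator, data loader, and loss below are illustrative placeholders, not part of the library's API.

```python
import torch

optimizer = torch.optim.Adam(calibrator.parameters(), lr=1e-3)  # hypothetical setup
loss_fn = torch.nn.MSELoss()

for x, y in train_loader:  # train_loader is a placeholder DataLoader
    optimizer.zero_grad()
    loss = loss_fn(calibrator(x), y.double())  # calibrator outputs are double
    loss.backward()
    optimizer.step()
    calibrator.apply_constraints()  # re-project weights onto the constraints
```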

assert_constraints(eps=1e-06)

Asserts that layer satisfies specified constraints.

This checks that weights at the indexes of monotonicity pairs are in the correct order and that the output is within bounds.

Parameters:

Name Type Description Default
eps float

the margin of error allowed

1e-06

Returns:

Type Description
list[str]

A list of messages describing violated constraints including violated monotonicity pairs. If no constraints are violated, the list will be empty.

Source code in pytorch_lattice/layers/categorical_calibrator.py
@torch.no_grad()
def assert_constraints(self, eps: float = 1e-6) -> list[str]:
    """Asserts that layer satisfies specified constraints.

    This checks that weights at the indexes of monotonicity pairs are in the correct
    order and that the output is within bounds.

    Args:
        eps: the margin of error allowed

    Returns:
        A list of messages describing violated constraints including violated
        monotonicity pairs. If no constraints are violated, the list will be empty.
    """
    weights = torch.squeeze(self.kernel.data)
    messages = []

    if self.output_max is not None and torch.max(weights) > self.output_max + eps:
        messages.append("Max weight greater than output_max.")
    if self.output_min is not None and torch.min(weights) < self.output_min - eps:
        messages.append("Min weight less than output_min.")

    if self.monotonicity_pairs:
        violation_indices = [
            (i, j)
            for (i, j) in self.monotonicity_pairs
            if weights[i] - weights[j] > eps
        ]
        if violation_indices:
            messages.append(f"Monotonicity violated at: {str(violation_indices)}.")

    return messages
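
For example, a quick post-training check (a sketch; `calibrator` is assumed to be an already-constructed instance):

```python
violations = calibrator.assert_constraints(eps=1e-6)
if violations:
    print("Constraint violations found:", violations)
else:
    print("All constraints satisfied.")
```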

forward(x)

Calibrates categorical inputs through a learned mapping.

Parameters:

Name Type Description Default
x Tensor

The input tensor of category indices of shape (batch_size, 1).

required

Returns:

Type Description
Tensor

torch.Tensor of shape (batch_size, 1) containing calibrated input values.

Source code in pytorch_lattice/layers/categorical_calibrator.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Calibrates categorical inputs through a learned mapping.

    Args:
        x: The input tensor of category indices of shape `(batch_size, 1)`.

    Returns:
        torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
    """
    if self.missing_input_value is not None:
        missing_category_tensor = torch.zeros_like(x) + (self.num_categories - 1)
        x = torch.where(x == self.missing_input_value, missing_category_tensor, x)
    # TODO: test if using torch.gather is faster than one-hot matmul.
    one_hot = torch.nn.functional.one_hot(
        torch.squeeze(x, -1).long(), num_classes=self.num_categories
    ).double()
    return torch.mm(one_hot, self.kernel)
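
As a sanity check on the lookup described above: the one-hot matmul is equivalent to indexing the kernel directly by (already-mapped) category index. A sketch, assuming a `calibrator` configured as in the class example (5 known categories plus a missing-value bucket):

```python
import torch

idx = torch.tensor([[0], [3], [4]])                      # valid category indices only
direct_lookup = calibrator.kernel.data[idx.squeeze(-1)]  # shape: (3, 1)
assert torch.allclose(calibrator(idx), direct_lookup)
```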

keypoints_inputs()

Returns a tensor of keypoint inputs (category indices).

Source code in pytorch_lattice/layers/categorical_calibrator.py
@torch.no_grad()
def keypoints_inputs(self) -> torch.Tensor:
    """Returns a tensor of keypoint inputs (category indices)."""
    if self.missing_input_value is not None:
        return torch.cat(
            (
                torch.arange(self.num_categories - 1),
                torch.tensor([self.missing_input_value]),
            ),
            0,
        )
    return torch.arange(self.num_categories)

keypoints_outputs()

Returns a tensor of keypoint outputs.

Source code in pytorch_lattice/layers/categorical_calibrator.py
@torch.no_grad()
def keypoints_outputs(self) -> torch.Tensor:
    """Returns a tensor of keypoint outputs."""
    return torch.squeeze(self.kernel.data, -1)
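
Together, the two keypoint methods expose the full learned mapping; a small sketch for inspecting it (`calibrator` is assumed to be an existing instance):

```python
categories = calibrator.keypoints_inputs()  # category indices (+ missing sentinel, if any)
values = calibrator.keypoints_outputs()     # learned output per category
for category, value in zip(categories.tolist(), values.tolist()):
    print(f"category {category} -> {value:.4f}")
```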

pytorch_lattice.layers.Lattice

Bases: ConstrainedModule

A Lattice Module.

The layer performs interpolation using one of 'units' d-dimensional lattices with an arbitrary number of keypoints per dimension. Each lattice vertex has a trainable weight, and each input is treated as a d-dimensional point within the lattice.

Attributes:

Name Type Description
All

__init__ arguments.

kernel

torch.nn.Parameter of shape (prod(lattice_sizes), units) which stores weights at each vertex of lattice.

Example:

lattice_sizes = [2, 2, 4, 3]
inputs = torch.tensor(...)  # shape: (batch_size, len(lattice_sizes))
lattice = Lattice(
    lattice_sizes,
    clip_inputs=True,
    interpolation=Interpolation.HYPERCUBE,
    units=1,
)
outputs = lattice(inputs)
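
A minimal runnable sketch, assuming the inputs are already scaled to each dimension's keypoint range [0, size - 1] (for example by upstream calibrators); the input values are illustrative:

```python
import torch

from pytorch_lattice.layers import Lattice

lattice = Lattice(lattice_sizes=[2, 2, 4, 3], units=1)
x = torch.tensor([[0.5, 1.0, 2.3, 0.7]])  # shape: (batch_size, len(lattice_sizes))
y = lattice(x)                            # shape: (batch_size, units) == (1, 1)
```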

Source code in pytorch_lattice/layers/lattice.py
class Lattice(ConstrainedModule):
    """A Lattice Module.

    The layer performs interpolation using one of 'units' d-dimensional lattices with
    an arbitrary number of keypoints per dimension. Each lattice vertex has a trainable
    weight, and each input is treated as a d-dimensional point within the lattice.

    Attributes:
        All: `__init__` arguments.
        kernel: `torch.nn.Parameter` of shape `(prod(lattice_sizes), units)` which
            stores weights at each vertex of lattice.

    Example:
    ```python
    lattice_sizes = [2, 2, 4, 3]
    inputs = torch.tensor(...)  # shape: (batch_size, len(lattice_sizes))
    lattice = Lattice(
        lattice_sizes,
        clip_inputs=True,
        interpolation=Interpolation.HYPERCUBE,
        units=1,
    )
    outputs = lattice(inputs)
    ```
    """

    def __init__(
        self,
        lattice_sizes: Union[list[int], tuple[int]],
        output_min: Optional[float] = None,
        output_max: Optional[float] = None,
        kernel_init: LatticeInit = LatticeInit.LINEAR,
        monotonicities: Optional[list[Optional[Monotonicity]]] = None,
        clip_inputs: bool = True,
        interpolation: Interpolation = Interpolation.HYPERCUBE,
        units: int = 1,
    ) -> None:
        """Initializes an instance of 'Lattice'.

        Args:
            lattice_sizes: List or tuple of size of lattice along each dimension.
            output_min: Minimum output value for weights at vertices of lattice.
            output_max: Maximum output value for weights at vertices of lattice.
            kernel_init: Initialization scheme to use for the kernel.
            monotonicities: `None` or list of `None` or
                `Monotonicity.INCREASING` of length `len(lattice_sizes)` specifying
                monotonicity of each feature of lattice. A monotonically decreasing
                feature should use `Monotonicity.INCREASING` in the lattice layer but
                `Monotonicity.DECREASING` in the calibrator.
            clip_inputs: Whether input points should be clipped to the range of lattice.
            interpolation: Interpolation scheme for a given input.
            units: Dimensionality of weights stored at each vertex of lattice.

        Raises:
            ValueError: if `kernel_init` is invalid.
            NotImplementedError: Random monotonic initialization not yet implemented.
        """
        super().__init__()

        self.lattice_sizes = list(lattice_sizes)
        self.output_min = output_min
        self.output_max = output_max
        self.kernel_init = kernel_init
        self.clip_inputs = clip_inputs
        self.interpolation = interpolation
        self.units = units

        if monotonicities is not None:
            self.monotonicities = monotonicities
        else:
            self.monotonicities = [None] * len(lattice_sizes)

        if output_min is not None and output_max is not None:
            output_init_min, output_init_max = output_min, output_max
        elif output_min is not None:
            output_init_min, output_init_max = output_min, output_min + 4.0
        elif output_max is not None:
            output_init_min, output_init_max = output_max - 4.0, output_max
        else:
            output_init_min, output_init_max = -2.0, 2.0
        self._output_init_min, self._output_init_max = output_init_min, output_init_max

        @torch.no_grad()
        def initialize_kernel() -> torch.Tensor:
            if self.kernel_init == LatticeInit.LINEAR:
                return self._linear_initializer()
            if self.kernel_init == LatticeInit.RANDOM_MONOTONIC:
                raise NotImplementedError(
                    "Random monotonic initialization not yet implemented."
                )
            raise ValueError(f"Unknown kernel init: {self.kernel_init}")

        self.kernel = torch.nn.Parameter(initialize_kernel())

    def forward(self, x: Union[torch.Tensor, list[torch.Tensor]]) -> torch.Tensor:
        """Calculates interpolation from input, using method of self.interpolation.

        Args:
            x: input tensor. If `units == 1`, tensor of shape:
                `(batch_size, ..., len(lattice_size))` or list of `len(lattice_sizes)`
                tensors of same shape: `(batch_size, ..., 1)`. If `units > 1`, tensor of
                shape `(batch_size, ..., units, len(lattice_sizes))` or list of
                `len(lattice_sizes)` tensors of same shape `(batch_size, ..., units, 1)`

        Returns:
            torch.Tensor of shape `(batch_size, ..., units)` containing interpolated
            values.

        Raises:
            ValueError: If the type of interpolation is unknown.
        """
        x = [xi.double() for xi in x] if isinstance(x, list) else x.double()
        if self.interpolation == Interpolation.HYPERCUBE:
            return self._compute_hypercube_interpolation(x)
        if self.interpolation == Interpolation.SIMPLEX:
            return self._compute_simplex_interpolation(x)
        raise ValueError(f"Unknown interpolation type: {self.interpolation}")

    @torch.no_grad()
    def apply_constraints(self) -> None:
        """Aggregate function for enforcing constraints of lattice."""
        weights = self.kernel.clone()

        if self._count_non_zeros(self.monotonicities):
            lattice_sizes = self.lattice_sizes
            monotonicities = self.monotonicities
            if self.units > 1:
                lattice_sizes = lattice_sizes + [int(self.units)]
                if self.monotonicities:
                    monotonicities = monotonicities + [None]

            weights = weights.reshape(*lattice_sizes)
            weights = self._approximately_project_monotonicity(
                weights, lattice_sizes, monotonicities
            )

        if self.output_min is not None:
            weights = torch.clamp_min(weights, self.output_min)
        if self.output_max is not None:
            weights = torch.clamp_max(weights, self.output_max)

        self.kernel.data = weights.view(-1, self.units)

    @torch.no_grad()
    def assert_constraints(self, eps: float = 1e-6) -> list[str]:
        """Asserts that layer satisfies specified constraints.

        This checks that weights follow monotonicity and bounds constraints.

        Args:
            eps: the margin of error allowed

        Returns:
            A list of messages describing violated constraints including indices of
            monotonicity violations. If no constraints are violated, the list will be empty.
        """
        messages = []
        lattice_sizes = self.lattice_sizes
        monotonicities = self.monotonicities
        weights = self.kernel.data.clone()

        if weights.shape[1] > 1:
            lattice_sizes = lattice_sizes + [int(weights.shape[1])]
            if monotonicities:
                monotonicities = monotonicities + [None]

        # Reshape weights to match lattice sizes
        weights = weights.reshape(*lattice_sizes)

        for i in range(len(monotonicities or [])):
            if monotonicities[i] != Monotonicity.INCREASING:
                continue
            weights_layers = torch.unbind(weights, dim=i)

            for j in range(1, len(weights_layers)):
                diff = torch.min(weights_layers[j] - weights_layers[j - 1])
                if diff.item() < -eps:
                    messages.append(f"Monotonicity violated at feature index {i}.")

        if self.output_max is not None and torch.max(weights) > self.output_max + eps:
            messages.append("Max weight greater than output_max.")
        if self.output_min is not None and torch.min(weights) < self.output_min - eps:
            messages.append("Min weight less than output_min.")

        return messages

    ################################################################################
    ############################## PRIVATE METHODS #################################
    ################################################################################

    def _linear_initializer(self) -> torch.Tensor:
        """Creates initial weights tensor for linear initialization.

        Returns:
            `torch.Tensor` of shape `(prod(lattice_sizes), units)`
        """
        monotonicities = self.monotonicities[:]

        if monotonicities is None:
            monotonicities = [None] * len(self.lattice_sizes)

        num_constraint_dims = self._count_non_zeros(monotonicities)
        if num_constraint_dims == 0:
            monotonicities = [Monotonicity.INCREASING] * len(self.lattice_sizes)
            num_constraint_dims = len(self.lattice_sizes)

        dim_range = (
            float(self._output_init_max - self._output_init_min) / num_constraint_dims
        )
        one_d_weights = []

        for monotonicity, dim_size in zip(monotonicities, self.lattice_sizes):
            if monotonicity is not None:
                one_d = np.linspace(start=0.0, stop=dim_range, num=dim_size)
            else:
                one_d = np.array([0.0] * dim_size)

            one_d_weights.append(torch.tensor(one_d, dtype=torch.double).unsqueeze(0))

        weights = self._batch_outer_operation(one_d_weights, operation=torch.add)
        weights = (weights + self._output_init_min).view(-1, 1)
        if self.units > 1:
            weights = weights.repeat(1, self.units)

        return weights

    @staticmethod
    def _count_non_zeros(*iterables) -> int:
        """Returns total number of non 0/None enum elements in given iterables.

        Args:
            *iterables: Any number of the value `None` or iterables of `None` or
                `Monotonicity` enum values.
        """
        result = 0
        for iterable in iterables:
            if iterable is not None:
                for element in iterable:
                    if element is not None:
                        result += 1
        return result

    def _compute_simplex_interpolation(
        self, inputs: Union[torch.Tensor, list[torch.Tensor]]
    ) -> torch.Tensor:
        """Evaluates a lattice using simplex interpolation.

        Each `d`-dimensional unit hypercube of the lattice can be partitioned into `d!`
        disjoint simplices with `d+1` vertices. `S` is the unique simplex which contains
        input point `P`, and `S` has vertices `ABCD...`. For any vertex such as `A`, a
        new simplex `S'` can be created using the vertices `PBCD...`. The weight of `A`
        within the interpolation is then `vol(S')/vol(S)`. This process is repeated
        for every vertex in `S`, and the resulting values are summed.

        This interpolation can be computed in `O(D log(D))` time because it is only
        necessary to compute the volume of the simplex containing input point `P`. For
        context, the unit hypercube can be partitioned into `d!` simplices by starting
        at `(0,0,...,0)` and incrementing `0` to `1` dimension-by-dimension until one
        reaches `(1,1,...,1)`. There are `d!` possible paths from `(0,0,...,0)` to
        `(1,1,...,1)`, which account for the number of unique, disjoint simplices
        created by the method. There are `d` steps for each possible path where each
        step comprises the vertices of one simplex. Thus, one can find the containing
        simplex for input `P` by argsorting the coordinates of `P` in descending order
        and pathing along said order. To compute the interpolation weights, simply take
        the deltas from `[1, desc_sort(P_coords), 0]`.

        Args:
            inputs: input tensor. If `units == 1`, tensor of shape:
                `(batch_size, ..., len(lattice_size))` or list of `len(lattice_sizes)`
                tensors of same shape: `(batch_size, ..., 1)`. If `units > 1`, tensor of
                shape `(batch_size, ..., units, len(lattice_sizes))` or list of
                `len(lattice_sizes)` tensors of same shape `(batch_size, ..., units, 1)`

        Returns:
            `torch.Tensor` of shape `(batch_size, ..., units)` containing interpolated
            values.
        """
        if isinstance(inputs, list):
            inputs = torch.cat(inputs, dim=-1)

        if self.clip_inputs:
            inputs = self._clip_onto_lattice_range(inputs)

        lattice_rank = len(self.lattice_sizes)
        input_dim = len(inputs.shape)
        all_size_2 = all(size == 2 for size in self.lattice_sizes)

        # Strides are the index shift (with respect to flattened kernel data) of each
        # dimension, which can be used in a dot product with multi-dimensional
        # coordinates to give an index for the flattened lattice weights.
        # Ex): for lattice_sizes = [4, 3, 2], we get strides = [6, 2, 1]: when looking
        # at lattice coords (i, j, k) and kernel data flattened into 1-D, incrementing i
        # corresponds to a shift of 6 in flattened kernel data, j corresponds to a shift
        # of 2, and k corresponds to a shift of 1. Consequently, we can do
        # (coords * strides) for any coordinates to obtain the flattened index.
        strides = torch.tensor(
            np.cumprod([1] + self.lattice_sizes[::-1][:-1])[::-1].copy()
        )
        if not all_size_2:
            lower_corner_coordinates = inputs.int()
            lower_corner_coordinates = torch.min(
                lower_corner_coordinates, torch.tensor(self.lattice_sizes) - 2
            )
            inputs = inputs - lower_corner_coordinates.float()

        sorted_indices = torch.argsort(inputs, descending=True)
        sorted_inputs = torch.sort(inputs, descending=True).values

        # Pad the 1 and 0 onto the ends of sorted coordinates and compute deltas.
        no_padding_dims = [(0, 0)] * (input_dim - 1)
        flat_no_padding = [item for sublist in no_padding_dims for item in sublist]
        sorted_inputs_padded_left = torch.nn.functional.pad(
            sorted_inputs, [1, 0] + flat_no_padding, value=1.0
        )
        sorted_inputs_padded_right = torch.nn.functional.pad(
            sorted_inputs, [0, 1] + flat_no_padding, value=0.0
        )
        weights = sorted_inputs_padded_left - sorted_inputs_padded_right

        # Use strides to find indices of simplex vertices in flattened form.
        sorted_strides = torch.gather(strides, 0, sorted_indices.view(-1)).view(
            sorted_indices.shape
        )
        if all_size_2:
            corner_offset_and_sorted_strides = torch.nn.functional.pad(
                sorted_strides, [1, 0] + flat_no_padding
            )
        else:
            lower_corner_offset = (lower_corner_coordinates * strides).sum(
                dim=-1, keepdim=True
            )
            corner_offset_and_sorted_strides = torch.cat(
                [lower_corner_offset, sorted_strides], dim=-1
            )
        indices = torch.cumsum(corner_offset_and_sorted_strides, dim=-1)

        # Get kernel data from corresponding simplex vertices.
        if self.units == 1:
            gathered_params = torch.index_select(
                self.kernel.view(-1), 0, indices.view(-1)
            ).view(indices.shape)
        else:
            unit_offset = torch.tensor(
                [[i] * (lattice_rank + 1) for i in range(self.units)]
            )
            flat_indices = indices * self.units + unit_offset
            gathered_params = torch.index_select(
                self.kernel.view(-1), 0, flat_indices.view(-1)
            ).view(indices.shape)

        return (gathered_params * weights).sum(dim=-1, keepdim=self.units == 1)

    def _compute_hypercube_interpolation(
        self,
        inputs: Union[torch.Tensor, list[torch.Tensor]],
    ) -> torch.Tensor:
        """Performs hypercube interpolation using the surrounding unit hypercube.

        Args:
            inputs: input tensor. If `units == 1`, tensor of shape:
                `(batch_size, ..., len(lattice_size))` or list of `len(lattice_sizes)`
                tensors of same shape: `(batch_size, ..., 1)`. If `units > 1`, tensor of
                shape `(batch_size, ..., units, len(lattice_sizes))` or list of
                `len(lattice_sizes)` tensors of same shape `(batch_size, ..., units, 1)`

        Returns:
            `torch.Tensor` of shape `(batch_size, ..., units)` containing interpolated
            value(s).
        """
        interpolation_weights = self._compute_hypercube_interpolation_weights(
            inputs=inputs, clip_inputs=self.clip_inputs
        )
        if self.units == 1:
            return torch.matmul(interpolation_weights, self.kernel)

        return torch.sum(interpolation_weights * self.kernel.t(), dim=-1)

    def _compute_hypercube_interpolation_weights(
        self, inputs: Union[torch.Tensor, list[torch.Tensor]], clip_inputs: bool = True
    ) -> torch.Tensor:
        """Computes weights for hypercube lattice interpolation.

        For each n-dim unit in "inputs," the weights matrix will generate the weights
        corresponding to the unit's location within its surrounding hypercube. These
        weights can then be multiplied by the lattice layer's kernel to compute the
        actual hypercube interpolation. Specifically, the outer product of the set
        `(1-x_i, x_i)` for all x_i in input unit x calculates the weights for each
        vertex in the surrounding hypercube, and every other vertex in the lattice is
        set to zero since it is not used. In addition, for consecutive dimensions of
        equal size in the lattice, broadcasting is used to speed up calculations.

        Args:
            inputs: torch.Tensor of shape `(batch_size, ..., len(lattice_sizes))` or list
                of `len(lattice_sizes)` tensors of same shape `(batch_size, ..., 1)`
            clip_inputs: Boolean to determine whether input values outside lattice
                bounds should be clipped to the min or max supported values.

        Returns:
            `torch.Tensor` of shape `(batch_size, ..., prod(lattice_sizes))` containing
            the weights which can be matrix multiplied with the kernel to perform
            hypercube interpolation.
        """
        if isinstance(inputs, list):
            input_dtype = inputs[0].dtype
        else:
            input_dtype = inputs.dtype

        # Special case: 2^d lattice with input passed in as a single tensor
        if all(size == 2 for size in self.lattice_sizes) and not isinstance(
            inputs, list
        ):
            w = torch.stack([(1.0 - inputs), inputs], dim=-1)
            if clip_inputs:
                w = torch.clamp(w, min=0, max=1)
            one_d_interpolation_weights = list(torch.unbind(w, dim=-2))
            return self._batch_outer_operation(one_d_interpolation_weights)

        if clip_inputs:
            inputs = self._clip_onto_lattice_range(inputs)

        # Set up buckets of consecutive equal dimensions for broadcasting later
        dim_keypoints = {}
        for dim_size in set(self.lattice_sizes):
            dim_keypoints[dim_size] = torch.tensor(
                list(range(dim_size)), dtype=input_dtype
            )
        bucketized_inputs = self._bucketize_consecutive_equal_dims(inputs)
        one_d_interpolation_weights = []

        for tensor, bucket_size, dim_size in bucketized_inputs:
            if bucket_size > 1:
                tensor = torch.unsqueeze(tensor, dim=-1)
            distance = torch.abs(tensor - dim_keypoints[dim_size])
            weights = 1.0 - torch.minimum(
                distance, torch.tensor(1.0, dtype=distance.dtype)
            )
            if bucket_size == 1:
                one_d_interpolation_weights.append(weights)
            else:
                one_d_interpolation_weights.extend(torch.unbind(weights, dim=-2))

        return self._batch_outer_operation(one_d_interpolation_weights)

    @staticmethod
    def _batch_outer_operation(
        list_of_tensors: list[torch.Tensor],
        operation: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Computes the flattened outer product of a list of tensors.

        Args:
            list_of_tensors: List of tensors of same shape `(batch_size, ..., k[i])`
                where everything except `k_i` matches.
            operation: A torch operation which supports broadcasting to be applied. If
                `None` is provided, this will apply `torch.mul` for the first several
                tensors and `torch.matmul` for the remaining tensors.

        Returns:
            `torch.Tensor` of shape `(batch_size, ..., k_i * k_j * ...)` containing a
            flattened version of the outer product.
        """
        if len(list_of_tensors) == 1:
            return list_of_tensors[0]

        result = torch.unsqueeze(list_of_tensors[0], dim=-1)

        for i, tensor in enumerate(list_of_tensors[1:]):
            if not operation:
                op = torch.mul if i < 6 else torch.matmul
            else:
                op = operation

            result = op(result, torch.unsqueeze(tensor, dim=-2))
            shape = [-1] + [int(size) for size in result.shape[1:]]
            new_shape = shape[:-2] + [shape[-2] * shape[-1]]
            if i < len(list_of_tensors) - 2:
                new_shape.append(1)
            result = torch.reshape(result, new_shape)

        return result

    @overload
    def _clip_onto_lattice_range(self, inputs: torch.Tensor) -> torch.Tensor:
        ...

    @overload
    def _clip_onto_lattice_range(
        self, inputs: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        ...

    def _clip_onto_lattice_range(
        self,
        inputs: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Clips inputs onto valid input range for given lattice_sizes.

        Args:
            inputs: `inputs` argument of `_compute_interpolation_weights()`.

        Returns:
            `torch.Tensor` of shape `inputs` with values within range
            `[0, dim_size - 1]`.
        """
        clipped_inputs: Union[torch.Tensor, list[torch.Tensor]]
        if not isinstance(inputs, list):
            upper_bounds = torch.tensor(
                [dim_size - 1.0 for dim_size in self.lattice_sizes]
            ).double()
            clipped_inputs = torch.clamp(
                inputs, min=torch.zeros_like(upper_bounds), max=upper_bounds
            )
        else:
            dim_upper_bounds = {}
            for dim_size in set(self.lattice_sizes):
                dim_upper_bounds[dim_size] = torch.tensor(
                    dim_size - 1.0, dtype=inputs[0].dtype
                )
            dim_lower_bound = torch.zeros(1, dtype=inputs[0].dtype)

            clipped_inputs = [
                torch.clamp(
                    one_d_input, min=dim_lower_bound, max=dim_upper_bounds[dim_size]
                )
                for one_d_input, dim_size in zip(inputs, self.lattice_sizes)
            ]

        return clipped_inputs

    def _bucketize_consecutive_equal_dims(
        self,
        inputs: Union[torch.Tensor, list[torch.Tensor]],
    ) -> Iterator[tuple[torch.Tensor, int, int]]:
        """Creates buckets of equal sized dimensions for broadcasting ops.

        Args:
            inputs: `inputs` argument of `_compute_interpolation_weights()`.

        Returns:
            An `Iterable` containing `(torch.Tensor, int, int)` where the tensor
            contains individual values from "inputs" corresponding to its bucket, the
            first `int` is bucket size, and the second `int` is size of the dimension of
            the bucket.
        """
        if not isinstance(inputs, list):
            bucket_sizes = []
            bucket_dim_sizes = []
            current_size = 1
            for i in range(1, len(self.lattice_sizes)):
                if self.lattice_sizes[i] != self.lattice_sizes[i - 1]:
                    bucket_sizes.append(current_size)
                    bucket_dim_sizes.append(self.lattice_sizes[i - 1])
                    current_size = 1
                else:
                    current_size += 1
            bucket_sizes.append(current_size)
            bucket_dim_sizes.append(self.lattice_sizes[-1])
            inputs = torch.split(inputs, split_size_or_sections=bucket_sizes, dim=-1)
        else:
            bucket_sizes = [1] * len(self.lattice_sizes)
            bucket_dim_sizes = self.lattice_sizes

        return zip(inputs, bucket_sizes, bucket_dim_sizes)

    def _approximately_project_monotonicity(
        self,
        weights: torch.Tensor,
        lattice_sizes: list[int],
        monotonicities: list[Optional[Monotonicity]],
    ) -> torch.Tensor:
        """Projects weights of lattice to meet monotonicity constraints.

        Note that this projection is an approximation which guarantees monotonicity
        constraints but is not an exact projection with respect to the L2 norm.

        Algorithm:
        1. `max_projection`: For each vertex V in the lattice, the weight is adjusted to
        be the maximum of all weights of vertices X such that X has all coordinates
        less than or equal to V in monotonic dimensions.

        2. `half_projection`: We adjust the weights to be the average of the original
        weights and the `max_projection` weights.

        3. `min_projection`: For each vertex V in the lattice, the weight is adjusted
        based on the `half_projection` to be the minimum of all weights of vertices X
        such that V has all coordinates less than or equal to X in monotonic dimensions.

        This algorithm ensures that weights conform to the monotonicity constraints
        while getting closer to a true projection by adjusting both up/downwards.

        Args:
            weights: `torch.Tensor` of kernel data reshaped into `(lattice_sizes)` if
                `units == 1` or `(lattice_sizes, units)` if `units > 1`.
            lattice_sizes: List of size of each dimension of lattice, but for
                `units > 1`, `units` is appended to the end for computation purposes.
            monotonicities: List of `None` or `Monotonicity.INCREASING`
                of length `len(lattice_sizes)` for `units == 1` or
                `len(lattice_sizes)+1` if `units > 1` specifying monotonicity of each
                feature of lattice.

        Returns:
            `torch.Tensor` of shape `self.kernel` with updated weights which meet
            monotonicity constraints.
        """
        max_projection = weights
        for dim in range(len(lattice_sizes)):
            if monotonicities[dim] is None:
                continue
            layers = list(torch.unbind(max_projection, dim))
            for i in range(1, len(layers)):
                layers[i] = torch.max(layers[i], layers[i - 1])
            max_projection = torch.stack(layers, dim)

        half_projection = (weights + max_projection) / 2.0

        min_projection = half_projection
        for dim in range(len(lattice_sizes)):
            if monotonicities[dim] is None:
                continue
            layers = list(torch.unbind(min_projection, dim))
            for i in range(len(layers) - 2, -1, -1):
                # Compute cumulative minimum in reverse order
                layers[i] = torch.min(layers[i], layers[i + 1])
            min_projection = torch.stack(layers, dim)

        return min_projection
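
The simplex-weight construction described in the `_compute_simplex_interpolation` docstring above can be illustrated with a toy example (a sketch, not library code):

```python
import torch

# For a point in the unit hypercube, sort its coordinates in descending order and
# take consecutive differences of [1, sorted_coords..., 0]; these are the simplex
# interpolation weights and always sum to 1.
p = torch.tensor([0.2, 0.7, 0.5])
sorted_p, _ = torch.sort(p, descending=True)                              # [0.7, 0.5, 0.2]
padded = torch.cat([torch.tensor([1.0]), sorted_p, torch.tensor([0.0])])
weights = padded[:-1] - padded[1:]                                        # [0.3, 0.2, 0.3, 0.2]
assert torch.isclose(weights.sum(), torch.tensor(1.0))
```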

__init__(lattice_sizes, output_min=None, output_max=None, kernel_init=LatticeInit.LINEAR, monotonicities=None, clip_inputs=True, interpolation=Interpolation.HYPERCUBE, units=1)

Initializes an instance of 'Lattice'.

Parameters:

Name Type Description Default
lattice_sizes Union[list[int], tuple[int]]

List or tuple of size of lattice along each dimension.

required
output_min Optional[float]

Minimum output value for weights at vertices of lattice.

None
output_max Optional[float]

Maximum output value for weights at vertices of lattice.

None
kernel_init LatticeInit

Initialization scheme to use for the kernel.

LINEAR
monotonicities Optional[list[Optional[Monotonicity]]]

None or list of None or Monotonicity.INCREASING of length len(lattice_sizes) specifying monotonicity of each feature of lattice. A monotonically decreasing feature should use Monotonicity.INCREASING in the lattice layer but Monotonicity.DECREASING in the calibrator.

None
clip_inputs bool

Whether input points should be clipped to the range of lattice.

True
interpolation Interpolation

Interpolation scheme for a given input.

HYPERCUBE
units int

Dimensionality of weights stored at each vertex of lattice.

1

Raises:

Type Description
ValueError

if kernel_init is invalid.

NotImplementedError

Random monotonic initialization not yet implemented.

Source code in pytorch_lattice/layers/lattice.py
def __init__(
    self,
    lattice_sizes: Union[list[int], tuple[int]],
    output_min: Optional[float] = None,
    output_max: Optional[float] = None,
    kernel_init: LatticeInit = LatticeInit.LINEAR,
    monotonicities: Optional[list[Optional[Monotonicity]]] = None,
    clip_inputs: bool = True,
    interpolation: Interpolation = Interpolation.HYPERCUBE,
    units: int = 1,
) -> None:
    """Initializes an instance of 'Lattice'.

    Args:
        lattice_sizes: List or tuple of size of lattice along each dimension.
        output_min: Minimum output value for weights at vertices of lattice.
        output_max: Maximum output value for weights at vertices of lattice.
        kernel_init: Initialization scheme to use for the kernel.
        monotonicities: `None` or list of `None` or
            `Monotonicity.INCREASING` of length `len(lattice_sizes)` specifying
            monotonicity of each feature of lattice. A monotonically decreasing
            feature should use `Monotonicity.INCREASING` in the lattice layer but
            `Monotonicity.DECREASING` in the calibrator.
        clip_inputs: Whether input points should be clipped to the range of lattice.
        interpolation: Interpolation scheme for a given input.
        units: Dimensionality of weights stored at each vertex of lattice.

    Raises:
        ValueError: if `kernel_init` is invalid.
        NotImplementedError: Random monotonic initialization not yet implemented.
    """
    super().__init__()

    self.lattice_sizes = list(lattice_sizes)
    self.output_min = output_min
    self.output_max = output_max
    self.kernel_init = kernel_init
    self.clip_inputs = clip_inputs
    self.interpolation = interpolation
    self.units = units

    if monotonicities is not None:
        self.monotonicities = monotonicities
    else:
        self.monotonicities = [None] * len(lattice_sizes)

    if output_min is not None and output_max is not None:
        output_init_min, output_init_max = output_min, output_max
    elif output_min is not None:
        output_init_min, output_init_max = output_min, output_min + 4.0
    elif output_max is not None:
        output_init_min, output_init_max = output_max - 4.0, output_max
    else:
        output_init_min, output_init_max = -2.0, 2.0
    self._output_init_min, self._output_init_max = output_init_min, output_init_max

    @torch.no_grad()
    def initialize_kernel() -> torch.Tensor:
        if self.kernel_init == LatticeInit.LINEAR:
            return self._linear_initializer()
        if self.kernel_init == LatticeInit.RANDOM_MONOTONIC:
            raise NotImplementedError(
                "Random monotonic initialization not yet implemented."
            )
        raise ValueError(f"Unknown kernel init: {self.kernel_init}")

    self.kernel = torch.nn.Parameter(initialize_kernel())
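
For instance, a sketch of constructing a bounded lattice in which only the first two features are monotonic; the import path for the `Monotonicity` enum is an assumption:

```python
from pytorch_lattice import Monotonicity  # assumed import path for the enum
from pytorch_lattice.layers import Lattice

lattice = Lattice(
    lattice_sizes=[2, 2, 4, 3],
    output_min=0.0,
    output_max=1.0,
    monotonicities=[Monotonicity.INCREASING, Monotonicity.INCREASING, None, None],
)
```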

apply_constraints()

Aggregate function for enforcing constraints of lattice.

Source code in pytorch_lattice/layers/lattice.py
@torch.no_grad()
def apply_constraints(self) -> None:
    """Aggregate function for enforcing constraints of lattice."""
    weights = self.kernel.clone()

    if self._count_non_zeros(self.monotonicities):
        lattice_sizes = self.lattice_sizes
        monotonicities = self.monotonicities
        if self.units > 1:
            lattice_sizes = lattice_sizes + [int(self.units)]
            if self.monotonicities:
                monotonicities = monotonicities + [None]

        weights = weights.reshape(*lattice_sizes)
        weights = self._approximately_project_monotonicity(
            weights, lattice_sizes, monotonicities
        )

    if self.output_min is not None:
        weights = torch.clamp_min(weights, self.output_min)
    if self.output_max is not None:
        weights = torch.clamp_max(weights, self.output_max)

    self.kernel.data = weights.view(-1, self.units)

assert_constraints(eps=1e-06)

Asserts that layer satisfies specified constraints.

This checks that weights follow monotonicity and bounds constraints.

Parameters:

Name Type Description Default
eps float

the margin of error allowed

1e-06

Returns:

Type Description
list[str]

A list of messages describing violated constraints including indices of monotonicity violations. If no constraints violated, the list will be empty.

Source code in pytorch_lattice/layers/lattice.py
@torch.no_grad()
def assert_constraints(self, eps: float = 1e-6) -> list[str]:
    """Asserts that layer satisfies specified constraints.

    This checks that weights follow monotonicity and bounds constraints.

    Args:
        eps: the margin of error allowed

    Returns:
        A list of messages describing violated constraints including indices of
        monotonicity violations. If no constraints violated, the list will be empty.
    """
    messages = []
    lattice_sizes = self.lattice_sizes
    monotonicities = self.monotonicities
    weights = self.kernel.data.clone()

    if weights.shape[1] > 1:
        lattice_sizes = lattice_sizes + [int(weights.shape[1])]
        if monotonicities:
            monotonicities = monotonicities + [None]

    # Reshape weights to match lattice sizes
    weights = weights.reshape(*lattice_sizes)

    for i in range(len(monotonicities or [])):
        if monotonicities[i] != Monotonicity.INCREASING:
            continue
        weights_layers = torch.unbind(weights, dim=i)

        for j in range(1, len(weights_layers)):
            diff = torch.min(weights_layers[j] - weights_layers[j - 1])
            if diff.item() < -eps:
                messages.append(f"Monotonicity violated at feature index {i}.")

    if self.output_max is not None and torch.max(weights) > self.output_max + eps:
        messages.append("Max weight greater than output_max.")
    if self.output_min is not None and torch.min(weights) < self.output_min - eps:
        messages.append("Min weight less than output_min.")

    return messages

forward(x)

Calculates interpolation from input, using method of self.interpolation.

Parameters:

Name Type Description Default
x Union[Tensor, list[Tensor]]

input tensor. If units == 1, tensor of shape: (batch_size, ..., len(lattice_sizes)) or list of len(lattice_sizes) tensors of same shape: (batch_size, ..., 1). If units > 1, tensor of shape (batch_size, ..., units, len(lattice_sizes)) or list of len(lattice_sizes) tensors of same shape (batch_size, ..., units, 1)

required

Returns:

Type Description
Tensor

torch.Tensor of shape (batch_size, ..., units) containing interpolated values.

Raises:

Type Description
ValueError

If the type of interpolation is unknown.

Source code in pytorch_lattice/layers/lattice.py
def forward(self, x: Union[torch.Tensor, list[torch.Tensor]]) -> torch.Tensor:
    """Calculates interpolation from input, using method of self.interpolation.

    Args:
        x: input tensor. If `units == 1`, tensor of shape:
            `(batch_size, ..., len(lattice_sizes))` or list of `len(lattice_sizes)`
            tensors of same shape: `(batch_size, ..., 1)`. If `units > 1`, tensor of
            shape `(batch_size, ..., units, len(lattice_sizes))` or list of
            `len(lattice_sizes)` tensors of same shape `(batch_size, ..., units, 1)`.

    Returns:
        torch.Tensor of shape `(batch_size, ..., units)` containing interpolated
        values.

    Raises:
        ValueError: If the type of interpolation is unknown.
    """
    x = [xi.double() for xi in x] if isinstance(x, list) else x.double()
    if self.interpolation == Interpolation.HYPERCUBE:
        return self._compute_hypercube_interpolation(x)
    if self.interpolation == Interpolation.SIMPLEX:
        return self._compute_simplex_interpolation(x)
    raise ValueError(f"Unknown interpolation type: {self.interpolation}")

pytorch_lattice.layers.Linear

Bases: ConstrainedModule

A constrained linear module.

This module takes an input of shape (batch_size, input_dim) and applies a linear transformation. The output will have shape (batch_size, 1).

Attributes:

Name Type Description
All

__init__ arguments.

kernel

torch.nn.Parameter that stores the linear combination weighting.

bias

torch.nn.Parameter that stores the bias term. Only available if use_bias is true.

Example:

input_dim = 3
inputs = torch.tensor(...)  # shape: (batch_size, input_dim)
linear = Linear(
    input_dim,
    monotonicities=[
        None,
        Monotonicity.INCREASING,
        Monotonicity.DECREASING
    ],
    use_bias=False,
    weighted_average=True,
)
outputs = linear(inputs)

Source code in pytorch_lattice/layers/linear.py
class Linear(ConstrainedModule):
    """A constrained linear module.

    This module takes an input of shape `(batch_size, input_dim)` and applies a
    linear transformation. The output will have shape `(batch_size, 1)`.

    Attributes:
        All: `__init__` arguments.
        kernel: `torch.nn.Parameter` that stores the linear combination weighting.
        bias: `torch.nn.Parameter` that stores the bias term. Only available if
            `use_bias` is true.

    Example:
    ```python
    input_dim = 3
    inputs = torch.tensor(...)  # shape: (batch_size, input_dim)
    linear = Linear(
        input_dim,
        monotonicities=[
            None,
            Monotonicity.INCREASING,
            Monotonicity.DECREASING
        ],
        use_bias=False,
        weighted_average=True,
    )
    outputs = linear(inputs)
    ```
    """

    def __init__(
        self,
        input_dim: int,
        monotonicities: Optional[list[Optional[Monotonicity]]] = None,
        use_bias: bool = True,
        weighted_average: bool = False,
    ) -> None:
        """Initializes an instance of `Linear`.

        Args:
            input_dim: The number of inputs that will be combined.
            monotonicities: If provided, specifies the monotonicity of each input
                dimension.
            use_bias: Whether to use a bias term for the linear combination.
            weighted_average: Whether to make the output a weighted average i.e. all
                coefficients are positive and add up to a total of 1.0. No bias term
                will be used, and `use_bias` will be set to false regardless of the
                original value. `monotonicities` will also be set to increasing for all
                input dimensions to ensure that all coefficients are positive.

        Raises:
            ValueError: If monotonicities does not have length input_dim (if provided).
        """
        super().__init__()

        self.input_dim = input_dim
        if monotonicities and len(monotonicities) != input_dim:
            raise ValueError("Monotonicities, if provided, must have length input_dim.")
        self.monotonicities = (
            monotonicities
            if not weighted_average
            else [Monotonicity.INCREASING] * input_dim
        )
        self.use_bias = use_bias if not weighted_average else False
        self.weighted_average = weighted_average

        self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, 1).double())
        torch.nn.init.constant_(self.kernel, 1.0 / input_dim)
        if use_bias:
            self.bias = torch.nn.Parameter(torch.Tensor(1).double())
            torch.nn.init.constant_(self.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Transforms inputs using a linear combination.

        Args:
            x: The input tensor of shape `(batch_size, input_dim)`.

        Returns:
            torch.Tensor of shape `(batch_size, 1)` containing transformed input values.
        """
        result = torch.mm(x, self.kernel)
        if self.use_bias:
            result += self.bias
        return result

    @torch.no_grad()
    def apply_constraints(self) -> None:
        """Projects kernel into desired constraints."""
        projected_kernel_data = self.kernel.data

        if self.monotonicities:
            if Monotonicity.INCREASING in self.monotonicities:
                increasing_mask = torch.tensor(
                    [
                        [0.0] if m == Monotonicity.INCREASING else [1.0]
                        for m in self.monotonicities
                    ]
                )
                projected_kernel_data = torch.maximum(
                    projected_kernel_data, projected_kernel_data * increasing_mask
                )
            if Monotonicity.DECREASING in self.monotonicities:
                decreasing_mask = torch.tensor(
                    [
                        [0.0] if m == Monotonicity.DECREASING else [1.0]
                        for m in self.monotonicities
                    ]
                )
                projected_kernel_data = torch.minimum(
                    projected_kernel_data, projected_kernel_data * decreasing_mask
                )

        if self.weighted_average:
            norm = torch.norm(projected_kernel_data, 1)
            norm = torch.where(norm < 1e-8, 1.0, norm)
            projected_kernel_data /= norm

        self.kernel.data = projected_kernel_data

    @torch.no_grad()
    def assert_constraints(self, eps: float = 1e-6) -> list[str]:
        """Asserts that layer satisfies specified constraints.

        This checks that decreasing monotonicity corresponds to negative weights,
        increasing monotonicity corresponds to positive weights, and weights sum to 1
        for weighted_average=True.

        Args:
            eps: the margin of error allowed

        Returns:
            A list of messages describing violated constraints. If no constraints
            violated, the list will be empty.
        """
        messages = []

        if self.weighted_average:
            total_weight = torch.sum(self.kernel.data)
            if torch.abs(total_weight - 1.0) > eps:
                messages.append("Weights do not sum to 1.")

        if self.monotonicities:
            monotonicities_constant = torch.tensor(
                [
                    1
                    if m == Monotonicity.INCREASING
                    else -1
                    if m == Monotonicity.DECREASING
                    else 0
                    for m in self.monotonicities
                ],
                device=self.kernel.device,
                dtype=self.kernel.dtype,
            ).view(-1, 1)

            violated_monotonicities = (self.kernel * monotonicities_constant) < -eps
            violation_indices = torch.where(violated_monotonicities)
            if violation_indices[0].numel() > 0:
                messages.append(
                    f"Monotonicity violated at: {violation_indices[0].tolist()}"
                )

        return messages

__init__(input_dim, monotonicities=None, use_bias=True, weighted_average=False)

Initializes an instance of Linear.

Parameters:

Name Type Description Default
input_dim int

The number of inputs that will be combined.

required
monotonicities Optional[list[Optional[Monotonicity]]]

If provided, specifies the monotonicity of each input dimension.

None
use_bias bool

Whether to use a bias term for the linear combination.

True
weighted_average bool

Whether to make the output a weighted average i.e. all coefficients are positive and add up to a total of 1.0. No bias term will be used, and use_bias will be set to false regardless of the original value. monotonicities will also be set to increasing for all input dimensions to ensure that all coefficients are positive.

False

Raises:

Type Description
ValueError

If monotonicities does not have length input_dim (if provided).

Source code in pytorch_lattice/layers/linear.py
def __init__(
    self,
    input_dim: int,
    monotonicities: Optional[list[Optional[Monotonicity]]] = None,
    use_bias: bool = True,
    weighted_average: bool = False,
) -> None:
    """Initializes an instance of `Linear`.

    Args:
        input_dim: The number of inputs that will be combined.
        monotonicities: If provided, specifies the monotonicity of each input
            dimension.
        use_bias: Whether to use a bias term for the linear combination.
        weighted_average: Whether to make the output a weighted average i.e. all
            coefficients are positive and add up to a total of 1.0. No bias term
            will be used, and `use_bias` will be set to false regardless of the
            original value. `monotonicities` will also be set to increasing for all
            input dimensions to ensure that all coefficients are positive.

    Raises:
        ValueError: If monotonicities does not have length input_dim (if provided).
    """
    super().__init__()

    self.input_dim = input_dim
    if monotonicities and len(monotonicities) != input_dim:
        raise ValueError("Monotonicities, if provided, must have length input_dim.")
    self.monotonicities = (
        monotonicities
        if not weighted_average
        else [Monotonicity.INCREASING] * input_dim
    )
    self.use_bias = use_bias if not weighted_average else False
    self.weighted_average = weighted_average

    self.kernel = torch.nn.Parameter(torch.Tensor(input_dim, 1).double())
    torch.nn.init.constant_(self.kernel, 1.0 / input_dim)
    if use_bias:
        self.bias = torch.nn.Parameter(torch.Tensor(1).double())
        torch.nn.init.constant_(self.bias, 0.0)

apply_constraints()

Projects kernel into desired constraints.

Source code in pytorch_lattice/layers/linear.py
@torch.no_grad()
def apply_constraints(self) -> None:
    """Projects kernel into desired constraints."""
    projected_kernel_data = self.kernel.data

    if self.monotonicities:
        if Monotonicity.INCREASING in self.monotonicities:
            increasing_mask = torch.tensor(
                [
                    [0.0] if m == Monotonicity.INCREASING else [1.0]
                    for m in self.monotonicities
                ]
            )
            projected_kernel_data = torch.maximum(
                projected_kernel_data, projected_kernel_data * increasing_mask
            )
        if Monotonicity.DECREASING in self.monotonicities:
            decreasing_mask = torch.tensor(
                [
                    [0.0] if m == Monotonicity.DECREASING else [1.0]
                    for m in self.monotonicities
                ]
            )
            projected_kernel_data = torch.minimum(
                projected_kernel_data, projected_kernel_data * decreasing_mask
            )

    if self.weighted_average:
        norm = torch.norm(projected_kernel_data, 1)
        norm = torch.where(norm < 1e-8, 1.0, norm)
        projected_kernel_data /= norm

    self.kernel.data = projected_kernel_data
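
To see why the masking above enforces the sign constraints: at a dimension marked `Monotonicity.INCREASING` the mask entry is 0, so `torch.maximum(w, w * mask)` reduces to `max(w, 0)`, while unconstrained dimensions keep a mask of 1 and pass through unchanged (and symmetrically with `torch.minimum` for decreasing dimensions). A small self-contained illustration with made-up weights:

```python
import torch

# Dims 0 and 2 are constrained increasing; dim 1 is unconstrained.
kernel = torch.tensor([[-0.5], [-0.3], [0.7]])
increasing_mask = torch.tensor([[0.0], [1.0], [0.0]])

projected = torch.maximum(kernel, kernel * increasing_mask)
# Only the negative weight at a constrained dim is clipped to 0:
# rows are now 0.0, -0.3, 0.7.
print(projected)
```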

assert_constraints(eps=1e-06)

Asserts that layer satisfies specified constraints.

This checks that decreasing monotonicity corresponds to negative weights, increasing monotonicity corresponds to positive weights, and weights sum to 1 for weighted_average=True.

Parameters:

Name Type Description Default
eps float

the margin of error allowed

1e-06

Returns:

Type Description
list[str]

A list of messages describing violated constraints. If no constraints violated, the list will be empty.

Source code in pytorch_lattice/layers/linear.py
@torch.no_grad()
def assert_constraints(self, eps: float = 1e-6) -> list[str]:
    """Asserts that layer satisfies specified constraints.

    This checks that decreasing monotonicity corresponds to negative weights,
    increasing monotonicity corresponds to positive weights, and weights sum to 1
    for weighted_average=True.

    Args:
        eps: the margin of error allowed

    Returns:
        A list of messages describing violated constraints. If no constraints
        violated, the list will be empty.
    """
    messages = []

    if self.weighted_average:
        total_weight = torch.sum(self.kernel.data)
        if torch.abs(total_weight - 1.0) > eps:
            messages.append("Weights do not sum to 1.")

    if self.monotonicities:
        monotonicities_constant = torch.tensor(
            [
                1
                if m == Monotonicity.INCREASING
                else -1
                if m == Monotonicity.DECREASING
                else 0
                for m in self.monotonicities
            ],
            device=self.kernel.device,
            dtype=self.kernel.dtype,
        ).view(-1, 1)

        violated_monotonicities = (self.kernel * monotonicities_constant) < -eps
        violation_indices = torch.where(violated_monotonicities)
        if violation_indices[0].numel() > 0:
            messages.append(
                f"Monotonicity violated at: {violation_indices[0].tolist()}"
            )

    return messages
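
As a quick usage sketch of the two constraint methods together (the values follow from the constant initialization shown above, so this is only an illustration): with `weighted_average=True`, the weights should end up non-negative and summing to 1, and `assert_constraints()` should then return an empty list.

```python
linear = Linear(input_dim=3, weighted_average=True)
linear.apply_constraints()

print(linear.kernel.data.flatten())  # three non-negative weights summing to 1.0
print(linear.assert_constraints())   # expected: []
```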

forward(x)

Transforms inputs using a linear combination.

Parameters:

Name Type Description Default
x Tensor

The input tensor of shape (batch_size, input_dim).

required

Returns:

Type Description
Tensor

torch.Tensor of shape (batch_size, 1) containing transformed input values.

Source code in pytorch_lattice/layers/linear.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Transforms inputs using a linear combination.

    Args:
        x: The input tensor of shape `(batch_size, input_dim)`.

    Returns:
        torch.Tensor of shape `(batch_size, 1)` containing transformed input values.
    """
    result = torch.mm(x, self.kernel)
    if self.use_bias:
        result += self.bias
    return result

pytorch_lattice.layers.NumericalCalibrator

Bases: ConstrainedModule

A numerical calibrator.

This module takes an input of shape (batch_size, 1) and calibrates it using a piece-wise linear function that conforms to any provided constraints. The output will have the same shape as the input.

Attributes:

Name Type Description
All

__init__ arguments.

kernel

torch.nn.Parameter that stores the piece-wise linear function weights.

missing_output

torch.nn.Parameter that stores the output learned for any missing inputs. Only available if missing_input_value is provided.

Example:

inputs = torch.tensor(...)  # shape: (batch_size, 1)
calibrator = NumericalCalibrator(
    input_keypoints=np.linspace(1., 5., num=5),
    output_min=0.0,
    output_max=1.0,
    monotonicity=Monotonicity.INCREASING,
    kernel_init=NumericalCalibratorInit.EQUAL_HEIGHTS,
)
outputs = calibrator(inputs)

Source code in pytorch_lattice/layers/numerical_calibrator.py
class NumericalCalibrator(ConstrainedModule):
    """A numerical calibrator.

    This module takes an input of shape `(batch_size, 1)` and calibrates it using a
    piece-wise linear function that conforms to any provided constraints. The output
    will have the same shape as the input.

    Attributes:
        All: `__init__` arguments.
        kernel: `torch.nn.Parameter` that stores the piece-wise linear function weights.
        missing_output: `torch.nn.Parameter` that stores the output learned for any
            missing inputs. Only available if `missing_input_value` is provided.

    Example:
    ```python
    inputs = torch.tensor(...)  # shape: (batch_size, 1)
    calibrator = NumericalCalibrator(
        input_keypoints=np.linspace(1., 5., num=5),
        output_min=0.0,
        output_max=1.0,
        monotonicity=Monotonicity.INCREASING,
        kernel_init=NumericalCalibratorInit.EQUAL_HEIGHTS,
    )
    outputs = calibrator(inputs)
    ```
    """

    def __init__(
        self,
        input_keypoints: np.ndarray,
        missing_input_value: Optional[float] = None,
        output_min: Optional[float] = None,
        output_max: Optional[float] = None,
        monotonicity: Optional[Monotonicity] = None,
        kernel_init: NumericalCalibratorInit = NumericalCalibratorInit.EQUAL_HEIGHTS,
        projection_iterations: int = 8,
        input_keypoints_type: InputKeypointsType = InputKeypointsType.FIXED,
    ) -> None:
        """Initializes an instance of `NumericalCalibrator`.

        Args:
            input_keypoints: Ordered list of float-valued keypoints for the underlying
                piece-wise linear function.
            missing_input_value: If provided, the calibrator will learn to map all
                instances of this missing input value to a learned output value.
            output_min: Minimum output value. If `None`, the minimum output value will
                be unbounded.
            output_max: Maximum output value. If `None`, the maximum output value will
                be unbounded.
            monotonicity: Monotonicity constraint for the underlying piece-wise linear
                function.
            kernel_init: Initialization scheme to use for the kernel.
            projection_iterations: Number of times to run Dykstra's projection
                algorithm when applying constraints.
            input_keypoints_type: `InputKeypointsType` of either `FIXED` or `LEARNED`. If
                `LEARNED`, keypoints other than the first or last will follow
                `input_keypoints` for initialization but adapt during training.

        Raises:
            ValueError: If `kernel_init` is invalid.
        """
        super().__init__()

        self.input_keypoints = input_keypoints
        self.missing_input_value = missing_input_value
        self.output_min = output_min
        self.output_max = output_max
        self.monotonicity = monotonicity
        self.kernel_init = kernel_init
        self.projection_iterations = projection_iterations
        self.input_keypoints_type = input_keypoints_type

        # Determine default output initialization values if bounds are not fully set.
        if output_min is not None and output_max is not None:
            output_init_min, output_init_max = output_min, output_max
        elif output_min is not None:
            output_init_min, output_init_max = output_min, output_min + 4.0
        elif output_max is not None:
            output_init_min, output_init_max = output_max - 4.0, output_max
        else:
            output_init_min, output_init_max = -2.0, 2.0
        self._output_init_min, self._output_init_max = output_init_min, output_init_max

        self._interpolation_keypoints = torch.from_numpy(input_keypoints[:-1])
        self._lengths = torch.from_numpy(input_keypoints[1:] - input_keypoints[:-1])
        if self.input_keypoints_type == InputKeypointsType.LEARNED:
            self._keypoint_min = input_keypoints[0]
            self._keypoint_range = input_keypoints[-1] - input_keypoints[0]
            initial_logits = torch.from_numpy(
                np.log(
                    (input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range
                )
            ).double()
            self._interpolation_logits = torch.nn.Parameter(initial_logits)

        # First row of the kernel represents the bias. The remaining rows represent
        # the y-value delta compared to the previous point i.e. the segment heights.
        @torch.no_grad()
        def initialize_kernel() -> torch.Tensor:
            output_init_range = self._output_init_max - self._output_init_min
            if kernel_init == NumericalCalibratorInit.EQUAL_HEIGHTS:
                num_segments = self._interpolation_keypoints.size()[0]
                segment_height = output_init_range / num_segments
                heights = torch.tensor([[segment_height]] * num_segments)
            elif kernel_init == NumericalCalibratorInit.EQUAL_SLOPES:
                heights = (
                    self._lengths * output_init_range / torch.sum(self._lengths)
                )[:, None]
            else:
                raise ValueError(f"Unknown kernel init: {self.kernel_init}")

            if monotonicity == Monotonicity.DECREASING:
                bias = torch.tensor([[self._output_init_max]])
                heights = -heights
            else:
                bias = torch.tensor([[self._output_init_min]])
            return torch.cat((bias, heights), 0).double()

        self.kernel = torch.nn.Parameter(initialize_kernel())

        if missing_input_value:
            self.missing_output = torch.nn.Parameter(torch.Tensor(1))
            torch.nn.init.constant_(
                self.missing_output,
                (self._output_init_min + self._output_init_max) / 2.0,
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Calibrates numerical inputs through piece-wise linear interpolation.

        Args:
            x: The input tensor of shape `(batch_size, 1)`.

        Returns:
            torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
        """
        if self.input_keypoints_type == InputKeypointsType.LEARNED:
            softmaxed_logits = torch.nn.functional.softmax(
                self._interpolation_logits, dim=-1
            )
            self._lengths = softmaxed_logits * self._keypoint_range
            interior_keypoints = (
                torch.cumsum(self._lengths, dim=-1) + self._keypoint_min
            )
            self._interpolation_keypoints = torch.cat(
                [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]]
            )

        interpolation_weights = (x - self._interpolation_keypoints) / self._lengths
        interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0))
        interpolation_weights = torch.maximum(interpolation_weights, torch.tensor(0.0))
        interpolation_weights = torch.cat(
            (torch.ones_like(x), interpolation_weights), -1
        )
        result = torch.mm(interpolation_weights, self.kernel)

        if self.missing_input_value is not None:
            missing_mask = torch.eq(x, self.missing_input_value).long()
            result = missing_mask * self.missing_output + (1.0 - missing_mask) * result

        return result

    @torch.no_grad()
    def apply_constraints(self) -> None:
        """Jointly projects kernel into desired constraints.

        Uses Dykstra's alternating projection algorithm to jointly project onto all
        given constraints. This algorithm projects with respect to the L2 norm, but it
        approaches the norm from the "wrong" side. To ensure that all constraints are
        strictly met, we do final approximate projections that project strictly into the
        feasible space, but this is not an exact projection with respect to the L2 norm.
        Enough iterations make the impact of this approximation negligible.
        """
        constrain_bounds = self.output_min is not None or self.output_max is not None
        constrain_monotonicity = self.monotonicity is not None
        num_constraints = sum([constrain_bounds, constrain_monotonicity])

        # We do nothing to the weights in this case
        if num_constraints == 0:
            return

        original_bias, original_heights = self.kernel.data[0:1], self.kernel.data[1:]
        previous_bias_delta: dict[str, torch.Tensor] = defaultdict(
            lambda: torch.zeros_like(original_bias)
        )
        previous_heights_delta: dict[str, torch.Tensor] = defaultdict(
            lambda: torch.zeros_like(original_heights)
        )

        def apply_bound_constraints(bias, heights):
            previous_bias = bias - previous_bias_delta["BOUNDS"]
            previous_heights = heights - previous_heights_delta["BOUNDS"]
            if constrain_monotonicity:
                bias, heights = self._project_monotonic_bounds(
                    previous_bias, previous_heights
                )
            else:
                bias, heights = self._approximately_project_bounds_only(
                    previous_bias, previous_heights
                )
            previous_bias_delta["BOUNDS"] = bias - previous_bias
            previous_heights_delta["BOUNDS"] = heights - previous_heights
            return bias, heights

        def apply_monotonicity_constraints(heights):
            previous_heights = heights - previous_bias_delta["MONOTONICITY"]
            heights = self._project_monotonicity(previous_heights)
            previous_heights_delta["MONOTONICITY"] = heights - previous_heights
            return heights

        def apply_dykstras_projection(bias, heights):
            if constrain_bounds:
                bias, heights = apply_bound_constraints(bias, heights)
            if constrain_monotonicity:
                heights = apply_monotonicity_constraints(heights)
            return bias, heights

        def finalize_constraints(bias, heights):
            if constrain_monotonicity:
                heights = self._project_monotonicity(heights)
            if constrain_bounds:
                if constrain_monotonicity:
                    bias, heights = self._squeeze_by_scaling(bias, heights)
                else:
                    bias, heights = self._approximately_project_bounds_only(
                        bias, heights
                    )
            return bias, heights

        projected_bias, projected_heights = apply_dykstras_projection(
            original_bias, original_heights
        )
        if num_constraints > 1:
            for _ in range(self.projection_iterations - 1):
                projected_bias, projected_heights = apply_dykstras_projection(
                    projected_bias, projected_heights
                )
            projected_bias, projected_heights = finalize_constraints(
                projected_bias, projected_heights
            )

        self.kernel.data = torch.cat((projected_bias, projected_heights), 0)

    @torch.no_grad()
    def assert_constraints(self, eps: float = 1e-6) -> list[str]:
        """Asserts that layer satisfies specified constraints.

        This checks that weights follow monotonicity constraints and that the output is
        within bounds.

        Args:
            eps: the margin of error allowed

        Returns:
            A list of messages describing violated constraints including indices of
            monotonicity violations. If no constraints violated, the list will be empty.
        """
        weights = torch.squeeze(self.kernel.data)
        messages = []

        if (
            self.output_max is not None
            and torch.max(self.keypoints_outputs()) > self.output_max + eps
        ):
            messages.append("Max weight greater than output_max.")
        if (
            self.output_min is not None
            and torch.min(self.keypoints_outputs()) < self.output_min - eps
        ):
            messages.append("Min weight less than output_min.")

        diffs = weights[1:]
        violation_indices = []

        if self.monotonicity == Monotonicity.INCREASING:
            violation_indices = (diffs < -eps).nonzero().tolist()
        elif self.monotonicity == Monotonicity.DECREASING:
            violation_indices = (diffs > eps).nonzero().tolist()

        violation_indices = [(i[0], i[0] + 1) for i in violation_indices]
        if violation_indices:
            messages.append(f"Monotonicity violated at: {str(violation_indices)}.")

        return messages

    @torch.no_grad()
    def keypoints_inputs(self) -> torch.Tensor:
        """Returns tensor of keypoint inputs."""
        return torch.cat(
            (
                self._interpolation_keypoints,
                self._interpolation_keypoints[-1:] + self._lengths[-1:],
            ),
            0,
        )

    @torch.no_grad()
    def keypoints_outputs(self) -> torch.Tensor:
        """Returns tensor of keypoint outputs."""
        return torch.cumsum(self.kernel.data, 0).T[0]

    ################################################################################
    ############################## PRIVATE METHODS #################################
    ################################################################################

    def _project_monotonic_bounds(
        self, bias: torch.Tensor, heights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Projects bias and heights into bounds considering monotonicity.

        For computation simplification in the case of decreasing monotonicity, we mirror
        bias and heights and swap-mirror the output bounds. After doing the standard
        projection with respect to increasing monotonicity, we then mirror everything
        back to get the correct projection.

        Args:
            bias: The bias of the underlying piece-wise linear function.
            heights: The heights of each segment of the underlying piece-wise linear
                function.

        Returns:
            A tuple containing the projected bias and projected heights.
        """
        output_min, output_max = self.output_min, self.output_max
        decreasing = self.monotonicity == Monotonicity.DECREASING
        if decreasing:
            bias, heights = -bias, -heights
            output_min = None if self.output_max is None else -1 * self.output_max
            output_max = None if self.output_min is None else -1 * self.output_min
        if output_max is not None:
            num_heights = heights.size()[0]
            output_max_diffs = output_max - (bias + torch.sum(heights, 0))
            bias_delta = output_max_diffs / (num_heights + 1)
            bias_delta = torch.minimum(bias_delta, torch.tensor(0.0))
            if output_min is not None:
                bias = torch.maximum(bias + bias_delta, torch.tensor(output_min))
                heights_delta = output_max_diffs / num_heights
            else:
                bias += bias_delta
                heights_delta = bias_delta
            heights += torch.minimum(heights_delta, torch.tensor(0.0))
        elif output_min is not None:
            bias = torch.maximum(bias, torch.tensor(output_min))
        if decreasing:
            bias, heights = -bias, -heights
        return bias, heights

    def _approximately_project_bounds_only(
        self, bias: torch.Tensor, heights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Projects bias and heights without considering monotonicity.

        It is worth noting that this projection is an approximation and is not an exact
        projection with respect to the L2 norm; however, it is sufficiently accurate and
        efficient in practice for non-monotonic functions.

        Args:
            bias: The bias of the underlying piece-wise linear function.
            heights: The heights of each segment of the underlying piece-wise linear
                function.

        Returns:
            A tuple containing the projected bias and projected heights.
        """
        sums = torch.cumsum(torch.cat((bias, heights), 0), 0)
        if self.output_min is not None:
            sums = torch.maximum(sums, torch.tensor(self.output_min))
        if self.output_max is not None:
            sums = torch.minimum(sums, torch.tensor(self.output_max))
        bias = sums[0:1]
        heights = sums[1:] - sums[:-1]
        return bias, heights

    def _project_monotonicity(self, heights: torch.Tensor) -> torch.Tensor:
        """Returns bias and heights projected into desired monotonicity constraints."""
        if self.monotonicity == Monotonicity.INCREASING:
            return torch.maximum(heights, torch.tensor(0.0))
        if self.monotonicity == Monotonicity.DECREASING:
            return torch.minimum(heights, torch.tensor(0.0))
        return heights

    def _squeeze_by_scaling(
        self, bias: torch.Tensor, heights: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Squeezes monotonic calibrators by scaling them into bound constraints.

        It is worth noting that this is not an exact projection with respect to the L2
        norm; however, it maintains convexity, which projection by shift does not.

        Args:
            bias: The bias of the underlying piece-wise linear function.
            heights: The heights of each segment of the underlying piece-wise linear
                function.

        Returns:
            A tuple containing the projected bias and projected heights.
        """
        decreasing = self.monotonicity == Monotonicity.DECREASING
        output_max = self.output_max
        if decreasing:
            if self.output_min is None:
                return bias, heights
            bias, heights = -bias, -heights
            output_max = None if self.output_min is None else -1 * self.output_min
        if output_max is None:
            return bias, heights
        delta = output_max - bias
        scaling_factor = torch.where(
            delta > 0.0001, torch.sum(heights, 0) / delta, torch.ones_like(delta)
        )
        heights /= torch.maximum(scaling_factor, torch.tensor(1.0))
        if decreasing:
            bias, heights = -bias, -heights
        return bias, heights
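
To make the kernel encoding concrete: the first row is the bias (the output at the first keypoint) and each following row is a segment height, so the calibrated value is the bias plus all fully covered heights plus a fractional share of the segment the input falls in. The sketch below mirrors the `forward` computation for fixed keypoints using made-up keypoints and weights; it is a hand-worked illustration, not library code.

```python
import numpy as np
import torch

keypoints = np.array([0.0, 1.0, 2.0, 3.0])                    # three segments
kernel = torch.tensor([[0.1], [0.2], [0.3], [0.4]]).double()  # bias + 3 heights

x = torch.tensor([[1.5]]).double()                    # halfway through segment 2
left_keypoints = torch.from_numpy(keypoints[:-1])     # [0., 1., 2.]
lengths = torch.from_numpy(keypoints[1:] - keypoints[:-1])  # [1., 1., 1.]

weights = ((x - left_keypoints) / lengths).clamp(0.0, 1.0)  # [[1.0, 0.5, 0.0]]
weights = torch.cat((torch.ones_like(x), weights), -1)      # prepend bias weight
output = torch.mm(weights, kernel)
print(output)  # 0.1 + 1.0*0.2 + 0.5*0.3 + 0.0*0.4 = 0.45
```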

__init__(input_keypoints, missing_input_value=None, output_min=None, output_max=None, monotonicity=None, kernel_init=NumericalCalibratorInit.EQUAL_HEIGHTS, projection_iterations=8, input_keypoints_type=InputKeypointsType.FIXED)

Initializes an instance of NumericalCalibrator.

Parameters:

Name Type Description Default
input_keypoints ndarray

Ordered list of float-valued keypoints for the underlying piece-wise linear function.

required
missing_input_value Optional[float]

If provided, the calibrator will learn to map all instances of this missing input value to a learned output value.

None
output_min Optional[float]

Minimum output value. If None, the minimum output value will be unbounded.

None
output_max Optional[float]

Maximum output value. If None, the maximum output value will be unbounded.

None
monotonicity Optional[Monotonicity]

Monotonicity constraint for the underlying piece-wise linear function.

None
kernel_init NumericalCalibratorInit

Initialization scheme to use for the kernel.

EQUAL_HEIGHTS
projection_iterations int

Number of times to run Dykstra's projection algorithm when applying constraints.

8
input_keypoints_type InputKeypointsType

InputKeypointsType of either FIXED or LEARNED. If LEARNED, keypoints other than the first or last will follow input_keypoints for initialization but adapt during training.

FIXED

Raises:

Type Description
ValueError

If kernel_init is invalid.

Source code in pytorch_lattice/layers/numerical_calibrator.py
def __init__(
    self,
    input_keypoints: np.ndarray,
    missing_input_value: Optional[float] = None,
    output_min: Optional[float] = None,
    output_max: Optional[float] = None,
    monotonicity: Optional[Monotonicity] = None,
    kernel_init: NumericalCalibratorInit = NumericalCalibratorInit.EQUAL_HEIGHTS,
    projection_iterations: int = 8,
    input_keypoints_type: InputKeypointsType = InputKeypointsType.FIXED,
) -> None:
    """Initializes an instance of `NumericalCalibrator`.

    Args:
        input_keypoints: Ordered list of float-valued keypoints for the underlying
            piece-wise linear function.
        missing_input_value: If provided, the calibrator will learn to map all
            instances of this missing input value to a learned output value.
        output_min: Minimum output value. If `None`, the minimum output value will
            be unbounded.
        output_max: Maximum output value. If `None`, the maximum output value will
            be unbounded.
        monotonicity: Monotonicity constraint for the underlying piece-wise linear
            function.
        kernel_init: Initialization scheme to use for the kernel.
        projection_iterations: Number of times to run Dykstra's projection
            algorithm when applying constraints.
        input_keypoints_type: `InputKeypointsType` of either `FIXED` or `LEARNED`. If
            `LEARNED`, keypoints other than the first or last will follow
            `input_keypoints` for initialization but adapt during training.

    Raises:
        ValueError: If `kernel_init` is invalid.
    """
    super().__init__()

    self.input_keypoints = input_keypoints
    self.missing_input_value = missing_input_value
    self.output_min = output_min
    self.output_max = output_max
    self.monotonicity = monotonicity
    self.kernel_init = kernel_init
    self.projection_iterations = projection_iterations
    self.input_keypoints_type = input_keypoints_type

    # Determine default output initialization values if bounds are not fully set.
    if output_min is not None and output_max is not None:
        output_init_min, output_init_max = output_min, output_max
    elif output_min is not None:
        output_init_min, output_init_max = output_min, output_min + 4.0
    elif output_max is not None:
        output_init_min, output_init_max = output_max - 4.0, output_max
    else:
        output_init_min, output_init_max = -2.0, 2.0
    self._output_init_min, self._output_init_max = output_init_min, output_init_max

    self._interpolation_keypoints = torch.from_numpy(input_keypoints[:-1])
    self._lengths = torch.from_numpy(input_keypoints[1:] - input_keypoints[:-1])
    if self.input_keypoints_type == InputKeypointsType.LEARNED:
        self._keypoint_min = input_keypoints[0]
        self._keypoint_range = input_keypoints[-1] - input_keypoints[0]
        initial_logits = torch.from_numpy(
            np.log(
                (input_keypoints[1:] - input_keypoints[:-1]) / self._keypoint_range
            )
        ).double()
        self._interpolation_logits = torch.nn.Parameter(initial_logits)

    # First row of the kernel represents the bias. The remaining rows represent
    # the y-value delta compared to the previous point i.e. the segment heights.
    @torch.no_grad()
    def initialize_kernel() -> torch.Tensor:
        output_init_range = self._output_init_max - self._output_init_min
        if kernel_init == NumericalCalibratorInit.EQUAL_HEIGHTS:
            num_segments = self._interpolation_keypoints.size()[0]
            segment_height = output_init_range / num_segments
            heights = torch.tensor([[segment_height]] * num_segments)
        elif kernel_init == NumericalCalibratorInit.EQUAL_SLOPES:
            heights = (
                self._lengths * output_init_range / torch.sum(self._lengths)
            )[:, None]
        else:
            raise ValueError(f"Unknown kernel init: {self.kernel_init}")

        if monotonicity == Monotonicity.DECREASING:
            bias = torch.tensor([[self._output_init_max]])
            heights = -heights
        else:
            bias = torch.tensor([[self._output_init_min]])
        return torch.cat((bias, heights), 0).double()

    self.kernel = torch.nn.Parameter(initialize_kernel())

    if missing_input_value:
        self.missing_output = torch.nn.Parameter(torch.Tensor(1))
        torch.nn.init.constant_(
            self.missing_output,
            (self._output_init_min + self._output_init_max) / 2.0,
        )

apply_constraints()

Jointly projects kernel into desired constraints.

Uses Dykstra's alternating projection algorithm to jointly project onto all given constraints. This algorithm projects with respect to the L2 norm, but it approaches the norm from the "wrong" side. To ensure that all constraints are strictly met, we do final approximate projections that project strictly into the feasible space, but this is not an exact projection with respect to the L2 norm. Enough iterations make the impact of this approximation negligible.

Source code in pytorch_lattice/layers/numerical_calibrator.py
@torch.no_grad()
def apply_constraints(self) -> None:
    """Jointly projects kernel into desired constraints.

    Uses Dykstra's alternating projection algorithm to jointly project onto all
    given constraints. This algorithm projects with respect to the L2 norm, but it
    approaches the norm from the "wrong" side. To ensure that all constraints are
    strictly met, we do final approximate projections that project strictly into the
    feasible space, but this is not an exact projection with respect to the L2 norm.
    Enough iterations make the impact of this approximation negligible.
    """
    constrain_bounds = self.output_min is not None or self.output_max is not None
    constrain_monotonicity = self.monotonicity is not None
    num_constraints = sum([constrain_bounds, constrain_monotonicity])

    # We do nothing to the weights in this case
    if num_constraints == 0:
        return

    original_bias, original_heights = self.kernel.data[0:1], self.kernel.data[1:]
    previous_bias_delta: dict[str, torch.Tensor] = defaultdict(
        lambda: torch.zeros_like(original_bias)
    )
    previous_heights_delta: dict[str, torch.Tensor] = defaultdict(
        lambda: torch.zeros_like(original_heights)
    )

    def apply_bound_constraints(bias, heights):
        previous_bias = bias - previous_bias_delta["BOUNDS"]
        previous_heights = heights - previous_heights_delta["BOUNDS"]
        if constrain_monotonicity:
            bias, heights = self._project_monotonic_bounds(
                previous_bias, previous_heights
            )
        else:
            bias, heights = self._approximately_project_bounds_only(
                previous_bias, previous_heights
            )
        previous_bias_delta["BOUNDS"] = bias - previous_bias
        previous_heights_delta["BOUNDS"] = heights - previous_heights
        return bias, heights

    def apply_monotonicity_constraints(heights):
        previous_heights = heights - previous_bias_delta["MONOTONICITY"]
        heights = self._project_monotonicity(previous_heights)
        previous_heights_delta["MONOTONICITY"] = heights - previous_heights
        return heights

    def apply_dykstras_projection(bias, heights):
        if constrain_bounds:
            bias, heights = apply_bound_constraints(bias, heights)
        if constrain_monotonicity:
            heights = apply_monotonicity_constraints(heights)
        return bias, heights

    def finalize_constraints(bias, heights):
        if constrain_monotonicity:
            heights = self._project_monotonicity(heights)
        if constrain_bounds:
            if constrain_monotonicity:
                bias, heights = self._squeeze_by_scaling(bias, heights)
            else:
                bias, heights = self._approximately_project_bounds_only(
                    bias, heights
                )
        return bias, heights

    projected_bias, projected_heights = apply_dykstras_projection(
        original_bias, original_heights
    )
    if num_constraints > 1:
        for _ in range(self.projection_iterations - 1):
            projected_bias, projected_heights = apply_dykstras_projection(
                projected_bias, projected_heights
            )
        projected_bias, projected_heights = finalize_constraints(
            projected_bias, projected_heights
        )

    self.kernel.data = torch.cat((projected_bias, projected_heights), 0)
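
For intuition about the projection scheme above, here is a generic, minimal sketch of Dykstra's alternating projections onto two convex sets (non-negativity and an upper bound), written independently of this layer. The correction tensors `p` and `q` play the same role as the `previous_*_delta` dictionaries in the source; this illustrates the algorithm itself, not the layer's exact projections.

```python
import torch

def dykstra(x, project_a, project_b, iterations=8):
    """Generic Dykstra's algorithm for two convex sets (illustrative only)."""
    p = torch.zeros_like(x)  # correction term for set A
    q = torch.zeros_like(x)  # correction term for set B
    for _ in range(iterations):
        y = project_a(x + p)
        p = x + p - y
        x = project_b(y + q)
        q = y + q - x
    return x

v = torch.tensor([-0.5, 0.4, 1.7])
result = dykstra(
    v,
    project_a=lambda t: torch.clamp_min(t, 0.0),  # project onto t >= 0
    project_b=lambda t: torch.clamp_max(t, 1.0),  # project onto t <= 1
)
print(result)  # approximately tensor([0.0, 0.4, 1.0])
```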

assert_constraints(eps=1e-06)

Asserts that layer satisfies specified constraints.

This checks that weights follow monotonicity constraints and that the output is within bounds.

Parameters:

Name Type Description Default
eps float

the margin of error allowed

1e-06

Returns:

Type Description
list[str]

A list of messages describing violated constraints including indices of monotonicity violations. If no constraints violated, the list will be empty.

Source code in pytorch_lattice/layers/numerical_calibrator.py
@torch.no_grad()
def assert_constraints(self, eps: float = 1e-6) -> list[str]:
    """Asserts that layer satisfies specified constraints.

    This checks that weights follow monotonicity constraints and that the output is
    within bounds.

    Args:
        eps: the margin of error allowed

    Returns:
        A list of messages describing violated constraints including indices of
        monotonicity violations. If no constraints violated, the list will be empty.
    """
    weights = torch.squeeze(self.kernel.data)
    messages = []

    if (
        self.output_max is not None
        and torch.max(self.keypoints_outputs()) > self.output_max + eps
    ):
        messages.append("Max weight greater than output_max.")
    if (
        self.output_min is not None
        and torch.min(self.keypoints_outputs()) < self.output_min - eps
    ):
        messages.append("Min weight less than output_min.")

    diffs = weights[1:]
    violation_indices = []

    if self.monotonicity == Monotonicity.INCREASING:
        violation_indices = (diffs < -eps).nonzero().tolist()
    elif self.monotonicity == Monotonicity.DECREASING:
        violation_indices = (diffs > eps).nonzero().tolist()

    violation_indices = [(i[0], i[0] + 1) for i in violation_indices]
    if violation_indices:
        messages.append(f"Monotonicity violated at: {str(violation_indices)}.")

    return messages

forward(x)

Calibrates numerical inputs through piece-wise linear interpolation.

Parameters:

Name Type Description Default
x Tensor

The input tensor of shape (batch_size, 1).

required

Returns:

Type Description
Tensor

torch.Tensor of shape (batch_size, 1) containing calibrated input values.

Source code in pytorch_lattice/layers/numerical_calibrator.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Calibrates numerical inputs through piece-wise linear interpolation.

    Args:
        x: The input tensor of shape `(batch_size, 1)`.

    Returns:
        torch.Tensor of shape `(batch_size, 1)` containing calibrated input values.
    """
    if self.input_keypoints_type == InputKeypointsType.LEARNED:
        softmaxed_logits = torch.nn.functional.softmax(
            self._interpolation_logits, dim=-1
        )
        self._lengths = softmaxed_logits * self._keypoint_range
        interior_keypoints = (
            torch.cumsum(self._lengths, dim=-1) + self._keypoint_min
        )
        self._interpolation_keypoints = torch.cat(
            [torch.tensor([self._keypoint_min]), interior_keypoints[:-1]]
        )

    interpolation_weights = (x - self._interpolation_keypoints) / self._lengths
    interpolation_weights = torch.minimum(interpolation_weights, torch.tensor(1.0))
    interpolation_weights = torch.maximum(interpolation_weights, torch.tensor(0.0))
    interpolation_weights = torch.cat(
        (torch.ones_like(x), interpolation_weights), -1
    )
    result = torch.mm(interpolation_weights, self.kernel)

    if self.missing_input_value is not None:
        missing_mask = torch.eq(x, self.missing_input_value).long()
        result = missing_mask * self.missing_output + (1.0 - missing_mask) * result

    return result

keypoints_inputs()

Returns tensor of keypoint inputs.

Source code in pytorch_lattice/layers/numerical_calibrator.py
@torch.no_grad()
def keypoints_inputs(self) -> torch.Tensor:
    """Returns tensor of keypoint inputs."""
    return torch.cat(
        (
            self._interpolation_keypoints,
            self._interpolation_keypoints[-1:] + self._lengths[-1:],
        ),
        0,
    )

keypoints_outputs()

Returns tensor of keypoint outputs.

Source code in pytorch_lattice/layers/numerical_calibrator.py
@torch.no_grad()
def keypoints_outputs(self) -> torch.Tensor:
    """Returns tensor of keypoint outputs."""
    return torch.cumsum(self.kernel.data, 0).T[0]
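
Together, these two helpers expose the learned piece-wise linear function, which is convenient for inspection or plotting after training. A minimal sketch, assuming `calibrator` is a trained `NumericalCalibrator`; matplotlib is used purely for illustration and is not part of this library.

```python
import matplotlib.pyplot as plt

x_keypoints = calibrator.keypoints_inputs().numpy()
y_keypoints = calibrator.keypoints_outputs().numpy()

plt.plot(x_keypoints, y_keypoints, marker="o")
plt.xlabel("input keypoints")
plt.ylabel("calibrated outputs")
plt.show()
```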