File size: 4,143 Bytes
3de7bf6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Buffer List Mixin."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


import torch
from torch import nn


class BufferListMixin(nn.Module):
    """Buffer List Mixin.

    This mixin is used to allow registering a list of tensors as buffers in a pytorch module.

    Example:
        >>> class MyModule(BufferListMixin, nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         tensor_list = [torch.ones(3) * i for i in range(3)]
        ...         self.register_buffer_list("my_buffer_list", tensor_list)
        >>> module = MyModule()
        >>> # The buffer list can be accessed as a regular attribute
        >>> module.my_buffer_list
        [
            tensor([0., 0., 0.]),
            tensor([1., 1., 1.]),
            tensor([2., 2., 2.])
        ]
        >>> # We can update the buffer list at any time
        >>> new_tensor_list = [torch.ones(3) * i + 10 for i in range(3)]
        >>> module.register_buffer_list("my_buffer_list", new_tensor_list)
        >>> module.my_buffer_list
        [
            tensor([10., 10., 10.]),
            tensor([11., 11., 11.]),
            tensor([12., 12., 12.])
        ]
        >>> # Move to GPU. Since the tensors are registered as buffers, device placement is handled automatically
        >>> module.cuda()
        >>> module.my_buffer_list
        [
            tensor([10., 10., 10.], device='cuda:0'),
            tensor([11., 11., 11.], device='cuda:0'),
            tensor([12., 12., 12.], device='cuda:0')
        ]
    """

    def register_buffer_list(self, name: str, values: list[torch.Tensor], persistent: bool = True, **kwargs) -> None:
        """Register a list of tensors as buffers in a pytorch module.

        Each tensor is registered as a buffer with the name `_name_i` where `i` is the index of the tensor in the list.
        To update and retrieve the list of tensors, we dynamically assign a descriptor attribute to the class.

        Args:
            name (str): Name of the buffer list.
            values (list[torch.Tensor]): List of tensors to register as buffers.
            persistent (bool, optional): Whether the buffers should be saved as part of the module state_dict.
                Defaults to True.
            **kwargs: Additional keyword arguments to pass to `torch.nn.Module.register_buffer`.
        """
        for i, value in enumerate(values):
            self.register_buffer(f"_{name}_{i}", value, persistent=persistent, **kwargs)

        setattr(BufferListMixin, name, BufferListDescriptor(name, len(values)))


class BufferListDescriptor:
    """Buffer List Descriptor.

    This descriptor is used to allow registering a list of tensors as buffers in a pytorch module.

    Args:
        name (str): Name of the buffer list.
        length (int): Length of the buffer list.
    """

    def __init__(self, name: str, length: int) -> None:
        self.name = name
        self.length = length

    def __get__(self, instance: object, object_type: type | None = None) -> list[torch.Tensor]:
        """Get the list of tensors.

        Each element of the buffer list is stored as a buffer with the name `name_i` where `i` is the index of the
        element in the list. We use list comprehension to retrieve the list of tensors.

        Args:
            instance (object): Instance of the class.
            object_type (Any, optional): Type of the class. Defaults to None.

        Returns:
            list[torch.Tensor]: Contents of the buffer list.
        """
        del object_type
        return [getattr(instance, f"_{self.name}_{i}") for i in range(self.length)]

    def __set__(self, instance: object, values: list[torch.Tensor]) -> None:
        """Set the list of tensors.

        Assigns a new list of tensors to the buffer list by updating the individual buffer attributes.

        Args:
            instance (object): Instance of the class.
            values (list[torch.Tensor]): List of tensors to set.
        """
        for i, value in enumerate(values):
            setattr(instance, f"_{self.name}_{i}", value)