summaryrefslogtreecommitdiffstats
path: root/src/borg/helpers/msgpack.py
blob: 197d2debd8a9fcb27dff7a600d0a9bc493c7a393 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
from .datastruct import StableDict
from ..constants import *  # NOQA

# wrapping msgpack ---------------------------------------------------------------------------------------------------
#
# Due to the planned breaking API changes in upstream msgpack, we wrap it the way we need it,
# to avoid having lots of clutter in the calling code. See tickets #968 and #3632.
#
# Packing
# -------
# use_bin_type = False is needed to generate the old msgpack format (not the msgpack 2.0 spec), as borg always did.
# unicode_errors = None is required because using this parameter is deprecated.
#
# Unpacking
# ---------
# raw = True is needed to unpack the old msgpack format to bytes (not str; for the decoding, see item.pyx).
# unicode_errors = None is required because using this parameter is deprecated.

from msgpack import Packer as mp_Packer
from msgpack import packb as mp_packb
from msgpack import pack as mp_pack
from msgpack import Unpacker as mp_Unpacker
from msgpack import unpackb as mp_unpackb
from msgpack import unpack as mp_unpack
from msgpack import version as mp_version

from msgpack import ExtType
from msgpack import OutOfData


# Version tuple of the installed msgpack, re-exported for callers of this module.
version = mp_version

# True on msgpack >= 1.0.0; gates passing strict_map_key in the unpack wrappers below.
_post_100 = (1, 0, 0) <= version


class PackException(Exception):
    """Raised by the packing wrappers in this module; carries the original msgpack error as its argument."""


class UnpackException(Exception):
    """Raised by the unpacking wrappers in this module; carries the original msgpack error as its argument."""


class Packer(mp_Packer):
    """msgpack Packer preconfigured for borg's legacy wire format.

    use_bin_type defaults to False so the old msgpack format (as borg always wrote it) is
    produced. unicode_errors must stay None. Any exception raised while packing is re-raised
    as PackException, with the original exception as its argument and cause.
    """
    def __init__(self, *, default=None, unicode_errors=None,
                 use_single_float=False, autoreset=True, use_bin_type=False,
                 strict_types=False):
        # only the default value is allowed for this deprecated parameter
        assert unicode_errors is None
        super().__init__(default=default, unicode_errors=unicode_errors,
                         use_single_float=use_single_float, autoreset=autoreset, use_bin_type=use_bin_type,
                         strict_types=strict_types)

    def pack(self, obj):
        """Pack *obj* and return whatever the underlying msgpack Packer.pack returns.

        :raises PackException: wrapping (and chained to) any exception raised by msgpack
        """
        try:
            return super().pack(obj)
        except Exception as e:
            raise PackException(e) from e


def packb(o, *, use_bin_type=False, unicode_errors=None, **kwargs):
    """Pack *o* with msgpack.packb using borg's legacy format defaults.

    :param o: object to serialize
    :param use_bin_type: defaults to False to emit the old msgpack format
    :param unicode_errors: must be None
    :param kwargs: passed through unchanged to msgpack.packb
    :raises PackException: wrapping (and chained to) any exception raised by msgpack
    """
    assert unicode_errors is None
    try:
        return mp_packb(o, use_bin_type=use_bin_type, unicode_errors=unicode_errors, **kwargs)
    except Exception as e:
        raise PackException(e) from e


def pack(o, stream, *, use_bin_type=False, unicode_errors=None, **kwargs):
    """Pack *o* into *stream* with msgpack.pack using borg's legacy format defaults.

    :param o: object to serialize
    :param stream: file-like object handed to msgpack.pack
    :param use_bin_type: defaults to False to emit the old msgpack format
    :param unicode_errors: must be None
    :param kwargs: passed through unchanged to msgpack.pack
    :returns: whatever msgpack.pack returns
    :raises PackException: wrapping (and chained to) any exception raised by msgpack
    """
    assert unicode_errors is None
    try:
        return mp_pack(o, stream, use_bin_type=use_bin_type, unicode_errors=unicode_errors, **kwargs)
    except Exception as e:
        raise PackException(e) from e


# Note: after requiring msgpack >= 0.6.1 we can remove the max_*_len args and
#       rely on msgpack auto-computing DoS-safe max values from len(data) for
#       unpack(data) or from max_buffer_len for Unpacker(max_buffer_len=N).
#       maybe we can also use that to simplify get_limited_unpacker().

class Unpacker(mp_Unpacker):
    """msgpack Unpacker preconfigured for borg.

    raw must be True (strings of the old msgpack format are returned as bytes) and
    unicode_errors must be None. The max_*_len defaults are 2**31-1; callers handling
    untrusted data should pass tighter limits (see get_limited_unpacker()).
    strict_map_key is only forwarded when msgpack >= 1.0.0 (_post_100).
    Errors while unpacking are re-raised as UnpackException, except the end-of-data
    signals OutOfData (from unpack) and StopIteration (from iteration).
    """
    def __init__(self, file_like=None, *, read_size=0, use_list=True, raw=True,
                 object_hook=None, object_pairs_hook=None, list_hook=None,
                 unicode_errors=None, max_buffer_size=0,
                 ext_hook=ExtType,
                 strict_map_key=False,
                 max_str_len=2147483647,  # 2**31-1
                 max_bin_len=2147483647,
                 max_array_len=2147483647,
                 max_map_len=2147483647,
                 max_ext_len=2147483647):
        assert raw is True
        assert unicode_errors is None
        kw = dict(file_like=file_like, read_size=read_size, use_list=use_list, raw=raw,
                  object_hook=object_hook, object_pairs_hook=object_pairs_hook, list_hook=list_hook,
                  unicode_errors=unicode_errors, max_buffer_size=max_buffer_size,
                  ext_hook=ext_hook,
                  max_str_len=max_str_len,
                  max_bin_len=max_bin_len,
                  max_array_len=max_array_len,
                  max_map_len=max_map_len,
                  max_ext_len=max_ext_len)
        # NOTE(review): presumably older msgpack releases do not accept strict_map_key
        if _post_100:
            kw['strict_map_key'] = strict_map_key
        super().__init__(**kw)

    def unpack(self):
        """Unpack the next object; OutOfData propagates, other errors become UnpackException."""
        try:
            return super().unpack()
        except OutOfData:
            raise
        except Exception as e:
            raise UnpackException(e) from e

    def __next__(self):
        """Iterate objects; StopIteration propagates, other errors become UnpackException."""
        try:
            return super().__next__()
        except StopIteration:
            raise
        except Exception as e:
            raise UnpackException(e) from e

    next = __next__


def unpackb(packed, *, raw=True, unicode_errors=None,
            strict_map_key=False,
            max_str_len=2147483647,  # 2**31-1
            max_bin_len=2147483647,
            max_array_len=2147483647,
            max_map_len=2147483647,
            max_ext_len=2147483647,
            **kwargs):
    """Unpack one object from *packed* with msgpack.unpackb using borg's defaults.

    raw defaults to True (old msgpack format strings come back as bytes); unicode_errors must
    be None. The max_*_len defaults are 2**31-1. strict_map_key is only forwarded when
    msgpack >= 1.0.0 (_post_100). Extra keyword arguments are passed through to msgpack.unpackb.

    :raises UnpackException: wrapping (and chained to) any exception raised by msgpack
    """
    assert unicode_errors is None
    kw = dict(raw=raw, unicode_errors=unicode_errors,
              max_str_len=max_str_len,
              max_bin_len=max_bin_len,
              max_array_len=max_array_len,
              max_map_len=max_map_len,
              max_ext_len=max_ext_len)
    kw.update(kwargs)
    # NOTE(review): presumably older msgpack releases do not accept strict_map_key
    if _post_100:
        kw['strict_map_key'] = strict_map_key
    try:
        return mp_unpackb(packed, **kw)
    except Exception as e:
        raise UnpackException(e) from e


def unpack(stream, *, raw=True, unicode_errors=None,
           strict_map_key=False,
           max_str_len=2147483647,  # 2**31-1
           max_bin_len=2147483647,
           max_array_len=2147483647,
           max_map_len=2147483647,
           max_ext_len=2147483647,
           **kwargs):
    """Unpack one object from *stream* with msgpack.unpack using borg's defaults.

    raw defaults to True (old msgpack format strings come back as bytes); unicode_errors must
    be None. The max_*_len defaults are 2**31-1. strict_map_key is only forwarded when
    msgpack >= 1.0.0 (_post_100). Extra keyword arguments are passed through to msgpack.unpack.

    :raises UnpackException: wrapping (and chained to) any exception raised by msgpack
    """
    assert unicode_errors is None
    kw = dict(raw=raw, unicode_errors=unicode_errors,
              max_str_len=max_str_len,
              max_bin_len=max_bin_len,
              max_array_len=max_array_len,
              max_map_len=max_map_len,
              max_ext_len=max_ext_len)
    kw.update(kwargs)
    # NOTE(review): presumably older msgpack releases do not accept strict_map_key
    if _post_100:
        kw['strict_map_key'] = strict_map_key
    try:
        return mp_unpack(stream, **kw)
    except Exception as e:
        raise UnpackException(e) from e


# msgpacking related utilities -----------------------------------------------

def is_slow_msgpack():
    """Return True when msgpack's Packer is the one from msgpack.fallback (the pure-Python implementation)."""
    import msgpack
    from msgpack import fallback
    return fallback.Packer is msgpack.Packer


def is_supported_msgpack():
    """Return True if the installed msgpack version is accepted.

    Accepted versions range from 0.5.6 up to and including 1.0.7, minus any release on the
    deny list in the expression below (currently 1.0.1).
    """
    # DO NOT CHANGE OR REMOVE! See also requirements and comments in setup.py.
    import msgpack
    return (0, 5, 6) <= msgpack.version <= (1, 0, 7) and \
           msgpack.version not in [(1, 0, 1), ]  # < add bad releases here to deny list


def get_limited_unpacker(kind):
    """Return a limited Unpacker, because we should not trust msgpack data received from remote.

    The limits (buffer size, max string/array/map lengths) are chosen per *kind* of data.

    :param kind: one of "server", "client", "manifest", "archive" or "key"
    :raises ValueError: if *kind* is not one of the above
    """
    # validate first, so an unknown kind fails fast with a message listing every accepted value
    if kind not in ('server', 'client', 'manifest', 'archive', 'key'):
        raise ValueError('kind must be "server", "client", "manifest", "archive" or "key"')
    args = dict(use_list=False,  # return tuples, not lists
                max_bin_len=0,  # not used
                max_ext_len=0,  # not used
                max_buffer_size=3 * max(BUFSIZE, MAX_OBJECT_SIZE),
                max_str_len=MAX_OBJECT_SIZE,  # a chunk or other repo object
                )
    if kind == 'server':
        args.update(dict(max_array_len=100,  # misc. cmd tuples
                         max_map_len=100,  # misc. cmd dicts
                         ))
    elif kind == 'client':
        args.update(dict(max_array_len=LIST_SCAN_LIMIT,  # result list from repo.list() / .scan()
                         max_map_len=100,  # misc. result dicts
                         ))
    elif kind == 'manifest':
        args.update(dict(use_list=True,  # default value
                         max_array_len=100,  # ITEM_KEYS ~= 22
                         max_map_len=MAX_ARCHIVES,  # list of archives
                         max_str_len=255,  # archive name
                         object_hook=StableDict,
                         ))
    elif kind == 'archive':
        # max_array_len is not limited here, so the Unpacker default applies
        args.update(dict(use_list=True,  # default value
                         max_map_len=100,  # ARCHIVE_KEYS ~= 20
                         max_str_len=10000,  # comment
                         object_hook=StableDict,
                         ))
    elif kind == 'key':
        args.update(dict(use_list=True,  # default value
                         max_array_len=0,  # not used
                         max_map_len=10,  # EncryptedKey dict
                         max_str_len=4000,  # inner key data
                         object_hook=StableDict,
                         ))
    return Unpacker(**args)


def bigint_to_int(mtime):
    """Convert a byte-encoded big integer back to an int.

    bytes or bytearray values are decoded as little-endian two's-complement signed integers,
    which is the encoding produced by int_to_bigint(). Any other value (normally already an
    int) is returned unchanged.
    """
    # bytearray is accepted too: previously it slipped through undecoded despite the docstring
    if isinstance(mtime, (bytes, bytearray)):
        return int.from_bytes(mtime, 'little', signed=True)
    return mtime


def int_to_bigint(value):
    """Convert ints with a magnitude of 2**63 or more to bytes.

    Ints whose bit_length() exceeds 63 are encoded as little-endian two's-complement
    ``bytes`` (decode with bigint_to_int()). All smaller ints are returned unchanged.
    """
    bits = value.bit_length()
    if bits > 63:
        # (bits + 9) // 8 leaves room for the sign bit and can be one byte longer than
        # strictly necessary. NOTE(review): presumably kept as-is so serialized output stays stable.
        return value.to_bytes((bits + 9) // 8, 'little', signed=True)
    return value