summaryrefslogtreecommitdiff
path: root/tools/proxyclient/m1n1/fw/common.py
blob: 479e2dfc8d7648f4ff31a8d5c284c4b87316d423 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# SPDX-License-Identifier: MIT

from dataclasses import dataclass
from enum import IntEnum
from m1n1.utils import *
from construct import *

# Fixed-width integer aliases named after C types; the construct "l" suffix
# means every one of these is little-endian.
uint8_t = Int8ul
int16_t, uint16_t = Int16sl, Int16ul
int32_t, uint32_t = Int32sl, Int32ul
int64_t, uint64_t = Int64sl, Int64ul

# Short-hand names. `int_` carries a trailing underscore so it does not shadow
# the builtin `int`; `long_` follows the same convention.
uint, int_ = uint32_t, int32_t
ulong, long_ = uint64_t, int64_t

def Bool(c):
    """Adapt the integer construct ``c`` to a Python ``bool``.

    Parsing yields ``True`` exactly when the lowest bit of the raw value is
    set; building writes ``int(value)``.
    """
    def _decode(raw, ctx):
        return bool(raw & 1)

    def _encode(value, ctx):
        return int(value)

    return ExprAdapter(c, _decode, _encode)

def SizedArray(count, svar, subcon):
    """Fixed-capacity array of ``count`` ``subcon`` elements with a runtime length.

    The number of elements actually handled is read from the context field
    named ``svar``. It is looked up in the current context first and in the
    parent context if absent, then clamped to ``count``. The field always
    occupies ``count * subcon.sizeof()`` bytes; ``Padded`` covers the unused
    slots.
    """
    total_bytes = subcon.sizeof() * count

    def _element_count(ctx):
        from_parent = ctx._.get(svar)
        return min(count, ctx.get(svar, from_parent))

    return Padded(total_bytes, Array(_element_count, subcon))

def SizedBytes(count, svar):
    """Lazily parsed byte string stored in a ``count``-byte padded field.

    The number of meaningful bytes is read from the context field named
    ``svar``. It is looked up in the current context first and in the parent
    context only if the current context has no such field.

    Args:
        count: total on-the-wire size of the field in bytes.
        svar: name of the context field holding the payload length.
    """
    def _length(ctx):
        # Test for presence rather than truthiness: a length of 0 is valid and
        # must not fall through to the parent context (where it would resolve
        # to None or to an unrelated field of the same name).
        if svar in ctx:
            return ctx[svar]
        return ctx._.get(svar)

    return Lazy(Padded(count, Bytes(_length)))

def UnkBytes(s):
    """Opaque ``s``-byte field whose meaning is unknown.

    The value is displayed as a hex dump, and when building without a value
    it defaults to ``s`` zero bytes.
    """
    zero_fill = b"\x00" * s
    return Default(HexDump(Bytes(s)), zero_fill)

# One-byte boolean: parses to True when the low bit of the byte is set.
bool_ = Bool(uint8_t)

class OSObject(Construct):
    """Parse-only construct for a tagged, recursively serialized object.

    Each object starts with a one-byte ASCII tag:

    * ``d``: dictionary. A little-endian u32 entry count follows, then that
      many key/value pairs, each of which is itself a tagged object.
    * ``n``: number, stored as a little-endian u64.
    * ``s``: string. A little-endian u32 byte length follows, then that many
      UTF-8 bytes, then a single NUL terminator.

    Subclasses may set ``TYPE`` to require a particular tag on the outermost
    object. Building is not supported.
    """

    TYPE = None  # tag required on the outermost object, or None for any

    def _parse(self, stream, context, path, recurse=False):
        tag = stream.read(1).decode("ascii")
        # Only the outermost call is checked against TYPE; nested objects are
        # parsed with recurse=True and may carry any tag.
        if not recurse and self.TYPE is not None and self.TYPE != tag:
            raise Exception("Object type mismatch")

        if tag == "d":
            count = Int32ul.parse_stream(stream)
            d = {}
            for _ in range(count):
                k = self._parse(stream, context, path, True)
                v = self._parse(stream, context, path, True)
                d[k] = v
            return d
        elif tag == "n":
            return Int64ul.parse_stream(stream)
        elif tag == "s":
            length = Int32ul.parse_stream(stream)
            s = stream.read(length).decode("utf-8")
            # Read the terminator outside any assert. Under ``python -O`` an
            # assert statement (including a read inside it) is stripped, which
            # would leave the NUL unconsumed and desynchronize the stream.
            terminator = stream.read(1)
            if terminator != b"\0":
                raise Exception(
                    f"Expected NUL string terminator, got {terminator!r}")
            return s
        else:
            raise Exception(f"Unknown object tag {tag!r}")

    def _build(self, obj, stream, context, path):
        # Raise explicitly instead of ``assert False``: asserts are stripped
        # under ``python -O``, which would make build() silently emit nothing.
        raise NotImplementedError("OSObject does not support building")

    def _sizeof(self, context, path):
        # The encoded size depends on the data, so no static size is given.
        return None

class OSDictionary(OSObject):
    """OSObject that requires the outermost object to have the ``d`` tag."""

    TYPE = "d"

class OSSerialize(Construct):
    """Parser/builder for a binary serialized object tree.

    Layout (all words little-endian):

    * A u32 header that must equal ``0xd3``.
    * One root object. Each object begins with a u32 tag word, aligned to a
      4-byte boundary of the absolute stream position:

      - bit 31: set on the last object of its enclosing container, and on
        the root object;
      - bits 24-28: object type;
      - bits 0-23: a size whose meaning depends on the type.

    Supported types:

    ====  =======  ===================================================
    1     dict     ``size`` key/value pairs follow
    2     list     ``size`` elements follow
    4     number   a u64 value follows (built with ``size`` = 64)
    9     str      ``size`` bytes of UTF-8 follow
    10    bytes    ``size`` raw bytes follow
    11    bool     value is ``size != 0``; there is no payload
    ====  =======  ===================================================
    """

    def _parse(self, stream, context, path, recurse=False):
        hdr = Int32ul.parse_stream(stream)
        if hdr != 0xd3:
            raise Exception("Bad header")

        obj, last = self.parse_obj(stream)
        # The root object must carry the "last" flag.
        assert last
        return obj

    def parse_obj(self, stream, level=0):
        """Parse one object (recursively) from ``stream``.

        Returns:
            A ``(value, last)`` tuple, where ``last`` is the tag's
            last-in-container flag.

        Raises:
            Exception: on an unsupported object type.
        """
        # Objects start on a 4-byte boundary; skip any padding first.
        pos = stream.tell()
        if pos & 3:
            stream.read(4 - (pos & 3))

        tag = Int32ul.parse_stream(stream)

        last = bool(tag & 0x80000000)
        otype = (tag >> 24) & 0x1f
        size = tag & 0xffffff

        if otype == 1:
            d = {}
            for i in range(size):
                # Keys are never flagged last; only the final value is.
                k, l = self.parse_obj(stream, level + 1)
                assert not l
                v, l = self.parse_obj(stream, level + 1)
                assert l == (i == size - 1)
                d[k] = v
        elif otype == 2:
            d = []
            for i in range(size):
                v, l = self.parse_obj(stream, level + 1)
                assert l == (i == size - 1)
                d.append(v)
        elif otype == 4:
            # The value is always 8 bytes, whatever the size field says.
            d = Int64ul.parse_stream(stream)
        elif otype == 9:
            d = stream.read(size).decode("utf-8")
        elif otype == 10:
            d = stream.read(size)
        elif otype == 11:
            d = bool(size)
        else:
            raise Exception(f"Unknown tag {otype}")

        return d, last

    @staticmethod
    def _make_tag(otype, size, last):
        """Return the u32 tag word for an object of type ``otype``.

        Raises:
            Exception: if ``size`` does not fit in the 24-bit size field.
                Without this check it would silently overwrite the type and
                last bits.
        """
        if not 0 <= size <= 0xffffff:
            raise Exception(f"Object size {size:#x} does not fit in 24 bits")
        tag = (otype << 24) | size
        if last:
            tag |= 0x80000000
        return tag

    def build_obj(self, obj, stream, last=True, level=0):
        """Serialize ``obj`` (and, recursively, its contents) to ``stream``.

        ``last`` sets the last-in-container flag of the tag word. ``level`` is
        the nesting depth and is only passed down. Output is padded to a
        4-byte boundary after every object.

        Raises:
            Exception: if ``obj`` or any nested value has an unsupported type,
                or a length does not fit in the 24-bit size field.
        """
        # bool is a subclass of int, so it must be tested before int.
        # Otherwise True/False would be encoded as type-4 numbers and could
        # never round-trip back to bool through parse_obj.
        if isinstance(obj, bool):
            Int32ul.build_stream(self._make_tag(11, int(obj), last), stream)
        elif isinstance(obj, dict):
            Int32ul.build_stream(self._make_tag(1, len(obj), last), stream)
            for i, (k, v) in enumerate(obj.items()):
                # Keys are never last; the final value closes the dict.
                self.build_obj(k, stream, False, level + 1)
                self.build_obj(v, stream, i == len(obj) - 1, level + 1)
        elif isinstance(obj, list):
            Int32ul.build_stream(self._make_tag(2, len(obj), last), stream)
            for i, v in enumerate(obj):
                self.build_obj(v, stream, i == len(obj) - 1, level + 1)
        elif isinstance(obj, int):
            Int32ul.build_stream(self._make_tag(4, 64, last), stream)
            Int64ul.build_stream(obj, stream)
        elif isinstance(obj, str):
            data = obj.encode("utf-8")
            Int32ul.build_stream(self._make_tag(9, len(data), last), stream)
            stream.write(data)
        elif isinstance(obj, bytes):
            Int32ul.build_stream(self._make_tag(10, len(obj), last), stream)
            stream.write(obj)
        else:
            raise Exception(f"Cannot encode {obj!r}")

        pos = stream.tell()
        if pos & 3:
            stream.write(bytes(4 - (pos & 3)))

    def _build(self, obj, stream, context, path):
        Int32ul.build_stream(0xd3, stream)
        self.build_obj(obj, stream)

    def _sizeof(self, context, path):
        # The encoded size depends on the data, so no static size is given.
        return None

def string(size):
    """Fixed ``size``-byte field holding a NUL-terminated UTF-8 string."""
    text = CString("utf8")
    return Padded(size, text)