summaryrefslogtreecommitdiff
path: root/tools/proxyclient/m1n1/asm.py
blob: ef1f3af4046982314596a210d22ad02b482fbc19 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# SPDX-License-Identifier: MIT
import os
import shlex
import shutil
import subprocess
import tempfile

__all__ = ["AsmException", "ARMAsm"]

uname = os.uname()

# Host-dependent defaults. USE_CLANG and TOOLCHAIN can be overridden from the
# environment below; USE_CLANG is kept as the string "1"/"0" because it is
# compared against the raw environment-variable text.
if uname.sysname == "Darwin":
    # macOS: use the LLVM toolchain from the prefix matching the host CPU
    # (/opt/homebrew on arm64 hosts, /usr/local otherwise).
    DEFAULT_ARCH = "aarch64-linux-gnu-"
    TOOLCHAIN = ("/opt/homebrew/opt/llvm/bin/" if uname.machine == "arm64"
                 else "/usr/local/opt/llvm/bin/")
    USE_CLANG = "1"
else:
    # Other hosts: GNU tools found on PATH; only non-aarch64 hosts need the
    # aarch64-linux-gnu- cross prefix.
    DEFAULT_ARCH = "" if uname.machine == "aarch64" else "aarch64-linux-gnu-"
    TOOLCHAIN = ""
    USE_CLANG = "0"

use_clang = os.environ.get("USE_CLANG", USE_CLANG).strip() == "1"
toolchain = os.environ.get("TOOLCHAIN", TOOLCHAIN)

# Command templates, in the order CC, LD, OBJCOPY, OBJDUMP, NM. "%ARCH" is a
# placeholder that BaseAsm replaces with the target prefix/triple at call time.
CC, LD, OBJCOPY, OBJDUMP, NM = (
    toolchain + tool
    for tool in (
        ("clang --target=%ARCH", "ld.lld", "llvm-objcopy", "llvm-objdump", "llvm-nm")
        if use_clang
        else ("%ARCHgcc", "%ARCHld", "%ARCHobjcopy", "%ARCHobjdump", "%ARCHnm")
    )
)

class AsmException(Exception):
    """Error type exported by this module for assembly failures.

    NOTE(review): nothing in this file raises it; failing tool invocations
    currently propagate as subprocess.CalledProcessError.
    """

class BaseAsm(object):
    """Assemble a source snippet with an external toolchain and expose the result.

    Subclasses must provide the class attributes ``ARCH`` (substituted for
    ``%ARCH`` in the tool command templates), ``CFLAGS``, ``LDFLAGS``,
    ``HEADER`` and ``FOOTER``. ``HEADER`` must define a ``_start`` symbol,
    whose address becomes ``self.start``.

    After construction the instance has:
      * ``data`` -- raw bytes of the linked ``.text`` section,
      * ``start`` / ``len`` / ``end`` -- load address, size and end address,
      * one attribute per defined symbol reported by ``nm``, set to its address.

    Build artifacts live in a private temporary directory that is removed
    when the object is finalized.
    """

    def __init__(self, source, addr = 0):
        """Assemble ``source`` and link it so that ``.text`` starts at ``addr``."""
        self.source = source
        self._tmp = tempfile.mkdtemp() + os.sep
        self.addr = addr
        self.compile(source)

    def _cmdline(self, program, args):
        """Expand ``%ARCH`` in a tool template and append ``args``."""
        return program.replace("%ARCH", self.ARCH) + " " + args

    def _call(self, program, args):
        """Run a tool through the shell; raise CalledProcessError on failure."""
        # shell=True is required: templates carry their own flags and callers
        # may pass shell syntax such as output redirection.
        subprocess.check_call(self._cmdline(program, args), shell=True)

    def _get(self, program, args):
        """Run a tool through the shell and return its stdout decoded as ASCII."""
        return subprocess.check_output(self._cmdline(program, args), shell=True).decode("ascii")

    def compile(self, source):
        """Assemble, link and extract ``source``; populate ``data`` and symbols.

        Raises subprocess.CalledProcessError if any tool fails.
        """
        self.sfile = self._tmp + "b.S"
        with open(self.sfile, "w") as fd:
            fd.write(self.HEADER + "\n")
            fd.write(source + "\n")
            fd.write(self.FOOTER + "\n")

        self.ofile = self._tmp + "b.o"
        self.elffile = self._tmp + "b.elf"
        self.bfile = self._tmp + "b.b"
        self.nfile = self._tmp + "b.n"

        # The temp dir comes from tempfile.mkdtemp(), which honours $TMPDIR;
        # quote every path so spaces or shell metacharacters in it cannot
        # split or alter the shell command line.
        q = shlex.quote
        self._call(CC, f"{self.CFLAGS} -c -o {q(self.ofile)} {q(self.sfile)}")
        self._call(LD, f"{self.LDFLAGS} --Ttext={self.addr:#x} -o {q(self.elffile)} {q(self.ofile)}")
        self._call(OBJCOPY, f"-j.text -O binary {q(self.elffile)} {q(self.bfile)}")
        self._call(NM, f"{q(self.elffile)} > {q(self.nfile)}")

        with open(self.bfile, "rb") as fd:
            self.data = fd.read()

        with open(self.nfile) as fd:
            for line in fd:
                fields = line.split()
                # Defined symbols are "<addr> <kind> <name>"; undefined ones
                # have no address column, so anything else is skipped.
                if len(fields) != 3:
                    continue
                addr, _kind, name = fields
                setattr(self, name, int(addr, 16))
        self.start = self._start
        self.len = len(self.data)
        self.end = self.start + self.len

    def objdump(self):
        """Print the ELF disassembly, with relocations (-rd), to stdout."""
        self._call(OBJDUMP, f"-rd {shlex.quote(self.elffile)}")

    def disassemble(self):
        """Yield instruction lines from the ELF disassembly.

        ``-z`` asks objdump not to elide blocks of zero bytes. Only lines whose
        first whitespace-separated token ends in ``:`` are yielded; blank
        lines, lines starting with ``/`` and other header lines are dropped.
        """
        output = self._get(OBJDUMP, f"-zd {shlex.quote(self.elffile)}")

        for line in output.split("\n"):
            if not line or line.startswith("/"):
                continue
            sl = line.split()
            if not sl or sl[0][-1] != ":":
                continue
            yield line

    def __del__(self):
        """Remove the temporary build directory (at most once)."""
        # getattr: __init__ may have failed before _tmp was assigned.
        tmp = getattr(self, "_tmp", None)
        if tmp:
            self._tmp = None
            # Exceptions cannot usefully propagate out of __del__.
            shutil.rmtree(tmp, ignore_errors=True)

class ARMAsm(BaseAsm):
    """AArch64 assembler built on BaseAsm.

    The snippet is placed in ``.text`` after a global ``_start`` label (read
    back by BaseAsm as ``self.start``). A literal pool is emitted after the
    snippet, so pseudo-instructions such as ``ldr x0, =imm`` can be used.
    """

    # Target prefix/triple substituted for %ARCH in the tool commands. The
    # ARCH environment variable is read once, when this class is defined.
    ARCH = os.environ.get("ARCH", DEFAULT_ARCH)
    CFLAGS = "-pipe -Wall -march=armv8.4-a"
    # Select the AArch64 ELF linker emulation.
    LDFLAGS = "-maarch64elf"
    HEADER = """
    .text
    .globl _start
_start:
    """
    FOOTER = """
    .pool
    """

if __name__ == "__main__":
    import sys

    # Self-test: assemble a small snippet at a fixed load address, print its
    # disassembly, and check symbol addresses. Extra command-line arguments
    # are spliced into the snippet as additional instructions.
    extra = sys.argv[1:]
    code = """
    ldr x0, =0xDEADBEEF
    b test
    mrs x0, spsel
    svc 1
    %s
test:
    b test
    ret
""" % (" ".join(extra))
    asm = ARMAsm(code, 0x1238)
    asm.objdump()
    assert asm.start == 0x1238
    # With nothing spliced in, `test` follows four 4-byte instructions
    # (0x1238 + 0x10).
    if not extra:
        assert asm.test == 0x1248