From 4432e43fa02f59ea8a287f961681f5ab564a9306 Mon Sep 17 00:00:00 2001 From: Luke Craig Date: Thu, 28 May 2026 23:15:38 -0400 Subject: [PATCH] Add PyGhidra snapshot loader --- README.md | 17 ++ src/dwarffi/dffi.py | 34 ++- src/dwarffi/ghidra.py | 481 +++++++++++++++++++++++++++++ tests/test_pyghidra_loader.py | 548 ++++++++++++++++++++++++++++++++++ 4 files changed, 1079 insertions(+), 1 deletion(-) create mode 100644 src/dwarffi/ghidra.py create mode 100644 tests/test_pyghidra_loader.py diff --git a/README.md b/README.md index b4202a6..90af5f5 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,23 @@ analyzeHeadless /tmp/ghidra-project DffiExport \ The exporter writes `base_types`, `user_types`, `enums`, `typedefs`, `symbols`, and `functions` in the same ISF shape consumed by `DFFI`. +Inside PyGhidra, you can snapshot the currently open program directly without +writing an intermediate JSON file: + +```python +from dwarffi import DFFI + +ffi = DFFI.from_ghidra(currentProgram) +print(ffi.sizeof("Packet")) +print(ffi.get_symbol("global_counter")) +``` + +If `currentProgram` is in the caller's PyGhidra or GhidraScript scope, the +argument can be omitted. The snapshot captures types, symbols, and functions at +call time; call `DFFI.from_ghidra(...)` again after editing Ghidra data types. +Use `types_only=True`, `include_symbols=False`, or `include_functions=False` to +limit exported sections. PyGhidra is not an install dependency for `dwarffi`. + Import an ISF into the active Ghidra program: ```bash diff --git a/src/dwarffi/dffi.py b/src/dwarffi/dffi.py index 028b0f0..c50be6b 100644 --- a/src/dwarffi/dffi.py +++ b/src/dwarffi/dffi.py @@ -114,6 +114,38 @@ def __init__( self.load_isf(item) else: self.load_isf(isf_input) + + @classmethod + def from_ghidra( + cls, + program: Any = None, + *, + include_symbols: bool = True, + include_functions: bool = True, + types_only: bool = False, + source_name: Optional[str] = None, + ) -> "DFFI": + """ + Create a snapshot-backed DFFI instance from a PyGhidra/Ghidra Program. + + If ``program`` is omitted, this looks for ``currentProgram`` in the + caller's PyGhidra or GhidraScript context. Type, symbol, and function + data are captured at call time; later Ghidra edits require calling this + method again. + """ + from .ghidra import current_program_from_context, program_to_isf + + if program is None: + program = current_program_from_context() + return cls( + program_to_isf( + program, + include_symbols=include_symbols, + include_functions=include_functions, + types_only=types_only, + source_name=source_name, + ) + ) def _add_vtypejson(self, source: str, vtype_obj: VtypeJson) -> None: """Internal helper to add a VtypeJson instance to the engine.""" @@ -1456,4 +1488,4 @@ def load_elf_bytes( ) finally: if os.path.exists(tmp_elf_path): - os.remove(tmp_elf_path) \ No newline at end of file + os.remove(tmp_elf_path) diff --git a/src/dwarffi/ghidra.py b/src/dwarffi/ghidra.py new file mode 100644 index 0000000..bb24719 --- /dev/null +++ b/src/dwarffi/ghidra.py @@ -0,0 +1,481 @@ +"""Snapshot Ghidra programs into DFFI's ISF dictionary format. + +This module intentionally avoids importing Ghidra or PyGhidra at import time. +The converter uses the small subset of the Ghidra Java API that is also easy to +fake in unit tests, so regular dwarffi usage does not gain a Ghidra dependency. +""" + +from __future__ import annotations + +import inspect +from typing import Any, Dict, Iterable, List, Optional, Set + +FORMAT_VERSION = "6.2.0" +TOOL_NAME = "pyghidra2isf" +TOOL_VERSION = "0.1.0" + + +def current_program_from_context() -> Any: + """Return a PyGhidra/GhidraScript ``currentProgram`` from the call stack.""" + for frame_info in inspect.stack()[1:]: + frame = frame_info.frame + if "currentProgram" in frame.f_locals: + return frame.f_locals["currentProgram"] + if "currentProgram" in frame.f_globals: + return frame.f_globals["currentProgram"] + raise RuntimeError( + "No active Ghidra program was found. Call DFFI.from_ghidra(currentProgram) " + "from your PyGhidra or GhidraScript context." + ) + + +def program_to_isf( + program: Any, + *, + include_symbols: bool = True, + include_functions: bool = True, + types_only: bool = False, + source_name: Optional[str] = None, +) -> Dict[str, Any]: + """Build an ISF dictionary snapshot from a Ghidra ``Program``.""" + if program is None: + raise ValueError("program must be a Ghidra Program object, not None") + exporter = _ProgramExporter(program, source_name=source_name) + return exporter.export( + include_symbols=include_symbols and not types_only, + include_functions=include_functions and not types_only, + ) + + +class _ProgramExporter: + def __init__(self, program: Any, *, source_name: Optional[str] = None) -> None: + self.program = program + self.source_name = source_name or str(_call(program, "getName", default="ghidra_program")) + language = _call(program, "getLanguage", default=None) + is_big = bool(_call(language, "isBigEndian", default=False)) if language is not None else False + self.endian = "big" if is_big else "little" + self.pointer_size = max(int(_call(program, "getDefaultPointerSize", default=8) or 8), 1) + + self.base_types: Dict[str, Any] = {} + self.user_types: Dict[str, Any] = {} + self.enums: Dict[str, Any] = {} + self.typedefs: Dict[str, Any] = {} + self.symbols: Dict[str, Any] = {} + self.functions: Dict[str, Any] = {} + + def export(self, *, include_symbols: bool, include_functions: bool) -> Dict[str, Any]: + self._export_types() + if include_symbols: + self._export_symbols() + if include_functions: + self._export_functions() + return { + "metadata": self._metadata(), + "base_types": dict(sorted(self.base_types.items())), + "user_types": dict(sorted(self.user_types.items())), + "enums": dict(sorted(self.enums.items())), + "symbols": dict(sorted(self.symbols.items())), + "functions": dict(sorted(self.functions.items())), + "typedefs": dict(sorted(self.typedefs.items())), + } + + def _metadata(self) -> Dict[str, Any]: + source = {"kind": "ghidra_program", "name": self.source_name} + return { + "producer": { + "name": TOOL_NAME, + "version": TOOL_VERSION, + "ghidra_version": _ghidra_version(), + }, + "format": FORMAT_VERSION, + "ghidra": {"types": [source], "symbols": [source]}, + } + + def _export_types(self) -> None: + self._ensure_void() + self._ensure_pointer() + manager = _call(self.program, "getDataTypeManager", default=None) + for data_type in _iter(_call(manager, "getAllDataTypes", default=[])): + if data_type is None or _is_default(data_type): + continue + self._export_data_type(data_type) + + def _export_data_type(self, data_type: Any) -> None: + if _is_structure(data_type): + self._export_composite(data_type, "struct") + elif _is_union(data_type): + self._export_composite(data_type, "union") + elif _is_enum(data_type): + self._export_enum(data_type) + elif _is_typedef(data_type): + self._export_typedef(data_type) + elif _is_pointer(data_type): + self._ensure_pointer() + self._type_ref(_call(data_type, "getDataType", default=None)) + elif _is_array(data_type): + self._type_ref(_call(data_type, "getDataType", default=None)) + elif _is_function_definition(data_type): + return + else: + self._ensure_base(data_type) + + def _export_composite(self, data_type: Any, kind: str) -> None: + name = _type_name(data_type) + if name in self.user_types: + return + + fields: Dict[str, Any] = {} + anonymous_count = 0 + for component in _iter(_call(data_type, "getComponents", default=[])): + field_type = _call(component, "getDataType", default=None) + if field_type is None or _is_default(field_type): + continue + field_name = _call(component, "getFieldName", default=None) + anonymous = field_name is None or str(field_name) == "" + if anonymous: + field_name = f"unnamed_field_{anonymous_count}" + anonymous_count += 1 + + field = { + "type": self._field_type_ref(component), + "offset": max(int(_call(component, "getOffset", default=0) or 0), 0), + } + if anonymous: + field["anonymous"] = True + fields[str(field_name)] = field + + self.user_types[name] = { + "size": max(int(_call(data_type, "getLength", default=0) or 0), 0), + "fields": dict(sorted(fields.items())), + "kind": kind, + } + + def _field_type_ref(self, component: Any) -> Dict[str, Any]: + data_type = _call(component, "getDataType", default=None) + if _is_bitfield(data_type): + return { + "kind": "bitfield", + "bit_length": _first_int(data_type, 0, "getDeclaredBitSize", "getBitSize"), + "bit_position": _first_int(component, 0, "getBitOffset"), + "type": self._type_ref(_call(data_type, "getBaseDataType", default=None)), + } + return self._type_ref(data_type) + + def _export_enum(self, enum_type: Any) -> None: + name = _type_name(enum_type) + if name in self.enums: + return + constants = { + str(const_name): int(_call(enum_type, "getValue", const_name, default=0) or 0) + for const_name in _iter(_call(enum_type, "getNames", default=[])) + } + self.enums[name] = { + "size": max(int(_call(enum_type, "getLength", default=0) or 0), 0), + "base": self._enum_base_name(enum_type, constants), + "constants": dict(sorted(constants.items())), + } + + def _enum_base_name(self, enum_type: Any, constants: Dict[str, int]) -> str: + length = max(int(_call(enum_type, "getLength", default=0) or 0), 0) + signed = any(value < 0 for value in constants.values()) + name = f"{'int' if signed else 'uint'}{length * 8}_t" + self._ensure_synthetic_base(name, length, "int", signed) + return name + + def _export_typedef(self, type_def: Any) -> None: + name = _type_name(type_def) + if name not in self.typedefs: + self.typedefs[name] = self._type_ref(_call(type_def, "getBaseDataType", default=None)) + + def _export_symbols(self) -> None: + symbol_table = _call(self.program, "getSymbolTable", default=None) + listing = _call(self.program, "getListing", default=None) + memory = _call(self.program, "getMemory", default=None) + for symbol in _iter(_call(symbol_table, "getAllSymbols", True, default=[])): + if symbol is None or bool(_call(symbol, "isExternal", default=False)): + continue + if _is_function_symbol(symbol): + continue + address = _call(symbol, "getAddress", default=None) + if address is None: + continue + if memory is not None and not bool(_call(memory, "contains", address, default=True)): + continue + + record: Dict[str, Any] = {"address": _address_offset(address)} + data = _call(listing, "getDataAt", address, default=None) if listing is not None else None + data_type = _call(data, "getDataType", default=None) + if data_type is not None: + record["type"] = self._type_ref(data_type) + self.symbols[str(_call(symbol, "getName", True, default=_call(symbol, "getName", default="")))] = record + + def _export_functions(self) -> None: + listing = _call(self.program, "getListing", default=None) + for function in _iter(_call(listing, "getFunctions", True, default=[])): + if function is None or bool(_call(function, "isExternal", default=False)): + continue + parameters: List[Dict[str, Any]] = [] + for parameter in _iter(_call(function, "getParameters", default=[])): + parameters.append( + { + "name": str(_call(parameter, "getName", default="")), + "type": self._type_ref(_call(parameter, "getDataType", default=None)), + } + ) + self.functions[ + str(_call(function, "getName", True, default=_call(function, "getName", default=""))) + ] = { + "address": _address_offset(_call(function, "getEntryPoint", default=0)), + "return_type": self._type_ref(_call(function, "getReturnType", default=None)), + "parameters": parameters, + } + + def _type_ref(self, data_type: Any) -> Dict[str, Any]: + if data_type is None or _is_void(data_type): + self._ensure_void() + return {"kind": "base", "name": "void"} + if _is_typedef(data_type): + self._export_typedef(data_type) + return {"kind": "typedef", "name": _type_name(data_type)} + if _is_pointer(data_type): + self._ensure_pointer() + return { + "kind": "pointer", + "subtype": self._type_ref(_call(data_type, "getDataType", default=None)), + } + if _is_array(data_type): + return { + "kind": "array", + "count": max(int(_call(data_type, "getNumElements", default=0) or 0), 0), + "subtype": self._type_ref(_call(data_type, "getDataType", default=None)), + } + if _is_structure(data_type): + self._export_composite(data_type, "struct") + return {"kind": "struct", "name": _type_name(data_type)} + if _is_union(data_type): + self._export_composite(data_type, "union") + return {"kind": "union", "name": _type_name(data_type)} + if _is_enum(data_type): + self._export_enum(data_type) + return {"kind": "enum", "name": _type_name(data_type)} + if _is_function_definition(data_type): + return { + "kind": "function", + "return_type": self._type_ref(_call(data_type, "getReturnType", default=None)), + "parameters": [ + { + "name": str(_call(param, "getName", default="")), + "type": self._type_ref(_call(param, "getDataType", default=None)), + } + for param in _iter(_call(data_type, "getArguments", default=[])) + ], + } + + return {"kind": "base", "name": self._ensure_base(data_type)} + + def _ensure_base(self, data_type: Any) -> str: + if data_type is None or _is_void(data_type): + self._ensure_void() + return "void" + name = _type_name(data_type) + size = max(int(_call(data_type, "getLength", default=0) or 0), 0) + if size == 0 and name != "void": + name = "opaque_0" + self._ensure_synthetic_base(name, size, _base_kind(data_type), _is_signed_base(data_type)) + return name + + def _ensure_void(self) -> None: + self._ensure_synthetic_base("void", 0, "void", False) + + def _ensure_pointer(self) -> None: + self._ensure_synthetic_base("pointer", self.pointer_size, "pointer", False) + + def _ensure_synthetic_base(self, name: str, size: int, kind: str, signed: bool) -> None: + if name not in self.base_types: + self.base_types[name] = { + "size": max(size, 0), + "signed": signed, + "kind": kind, + "endian": self.endian, + } + + +def _call(obj: Any, method: str, *args: Any, default: Any = None) -> Any: + if obj is None: + return default + attr = getattr(obj, method, None) + if attr is None: + return default + if not callable(attr): + return attr + try: + return attr(*args) + except TypeError: + if args: + try: + return attr() + except TypeError: + return default + return default + + +def _iter(value: Any) -> Iterable[Any]: + if value is None: + return () + if hasattr(value, "hasNext") and hasattr(value, "next"): + return _java_iterator(value) + return value + + +def _java_iterator(value: Any) -> Iterable[Any]: + while bool(_call(value, "hasNext", default=False)): + yield _call(value, "next", default=None) + + +def _class_names(obj: Any) -> List[str]: + names = [cls.__name__ for cls in type(obj).__mro__] + java_class = _call(obj, "getClass", default=None) + names.extend(_java_class_names(java_class)) + return names + + +def _has_class(obj: Any, name: str) -> bool: + target = name.lower() + return any( + target == cls_name.lower() + or cls_name.lower().endswith(target) + or target in cls_name.lower() + for cls_name in _class_names(obj) + ) + + +def _java_class_names(java_class: Any, seen: Optional[Set[int]] = None) -> List[str]: + if java_class is None: + return [] + if seen is None: + seen = set() + identity = id(java_class) + if identity in seen: + return [] + seen.add(identity) + + names: List[str] = [] + for method in ("getSimpleName", "getName"): + value = _call(java_class, method, default=None) + if value: + text = str(value) + names.append(text) + names.append(text.split(".")[-1]) + + for interface in _iter(_call(java_class, "getInterfaces", default=[])): + names.extend(_java_class_names(interface, seen)) + names.extend(_java_class_names(_call(java_class, "getSuperclass", default=None), seen)) + return names + + +def _is_default(data_type: Any) -> bool: + return _has_class(data_type, "DefaultDataType") or _type_name(data_type).startswith("undefined") + + +def _is_void(data_type: Any) -> bool: + return _has_class(data_type, "VoidDataType") or ( + str(_call(data_type, "getName", default="")).lower() == "void" + and int(_call(data_type, "getLength", default=0) or 0) == 0 + ) + + +def _is_typedef(data_type: Any) -> bool: + return _has_class(data_type, "TypeDef") or _has_class(data_type, "Typedef") + + +def _is_pointer(data_type: Any) -> bool: + return _has_class(data_type, "Pointer") and hasattr(data_type, "getDataType") + + +def _is_array(data_type: Any) -> bool: + return _has_class(data_type, "Array") and hasattr(data_type, "getNumElements") + + +def _is_structure(data_type: Any) -> bool: + return _has_class(data_type, "Structure") and hasattr(data_type, "getComponents") + + +def _is_union(data_type: Any) -> bool: + return _has_class(data_type, "Union") and hasattr(data_type, "getComponents") + + +def _is_enum(data_type: Any) -> bool: + return _has_class(data_type, "Enum") and hasattr(data_type, "getNames") + + +def _is_bitfield(data_type: Any) -> bool: + return _has_class(data_type, "BitFieldDataType") or ( + hasattr(data_type, "getBaseDataType") + and (hasattr(data_type, "getDeclaredBitSize") or hasattr(data_type, "getBitSize")) + ) + + +def _is_function_definition(data_type: Any) -> bool: + return _has_class(data_type, "FunctionDefinition") or ( + hasattr(data_type, "getArguments") and hasattr(data_type, "getReturnType") + ) + + +def _is_function_symbol(symbol: Any) -> bool: + symbol_type = _call(symbol, "getSymbolType", default=None) + return str(symbol_type).split(".")[-1].upper() == "FUNCTION" + + +def _type_name(data_type: Any) -> str: + name = _call(data_type, "getName", default=None) + if name is None or str(name) == "" or str(name).startswith("undefined"): + category = _call(_call(data_type, "getCategoryPath", default=None), "getPath", default="root") + path = _call(data_type, "getPathName", default=str(id(data_type))) + name = f"unnamed_{abs(hash(f'{category}:{path}')):x}" + return str(name) + + +def _base_kind(data_type: Any) -> str: + lower = _type_name(data_type).lower() + if _is_void(data_type): + return "void" + if _has_class(data_type, "BooleanDataType") or "bool" in lower: + return "bool" + if _has_class(data_type, "CharDataType") or lower == "char" or lower.endswith(" char"): + return "char" + if _has_class(data_type, "FloatDataType") or "float" in lower or "double" in lower: + return "float" + return "int" + + +def _is_signed_base(data_type: Any) -> bool: + lower = _type_name(data_type).lower() + if "unsigned" in lower or lower.startswith("u") or lower.startswith("uint") or "byte" in lower: + return False + kind = _base_kind(data_type) + return kind not in ("bool", "void") + + +def _first_int(target: Any, fallback: int, *method_names: str) -> int: + for method_name in method_names: + value = _call(target, method_name, default=None) + if isinstance(value, int) and value >= 0: + return value + return fallback + + +def _address_offset(address: Any) -> int: + offset = _call(address, "getOffset", default=None) + return int(address if offset is None else offset) + + +def _ghidra_version() -> str: + try: + import jpype # type: ignore[import-not-found] + + if jpype.isJVMStarted(): + system = jpype.JClass("java.lang.System") + return str(system.getProperty("application.version", "unknown")) + except Exception: + pass + return "unknown" diff --git a/tests/test_pyghidra_loader.py b/tests/test_pyghidra_loader.py new file mode 100644 index 0000000..a8aa805 --- /dev/null +++ b/tests/test_pyghidra_loader.py @@ -0,0 +1,548 @@ +from __future__ import annotations + +import os +import shutil +import subprocess +import tempfile +from pathlib import Path + +import pytest + +from dwarffi import DFFI +from dwarffi.ghidra import current_program_from_context, program_to_isf + + +class Language: + def __init__(self, *, big_endian: bool = False) -> None: + self._big_endian = big_endian + + def isBigEndian(self) -> bool: + return self._big_endian + + +class BaseType: + def __init__(self, name: str, length: int) -> None: + self._name = name + self._length = length + + def getName(self) -> str: + return self._name + + def getLength(self) -> int: + return self._length + + +class CharDataType(BaseType): + pass + + +class FloatDataType(BaseType): + pass + + +class TypeDef: + def __init__(self, name: str, base: object) -> None: + self._name = name + self._base = base + + def getName(self) -> str: + return self._name + + def getBaseDataType(self) -> object: + return self._base + + +class Pointer: + def __init__(self, subtype: object) -> None: + self._subtype = subtype + + def getName(self) -> str: + return f"{self._subtype.getName()} *" + + def getLength(self) -> int: + return 8 + + def getDataType(self) -> object: + return self._subtype + + +class Array: + def __init__(self, subtype: object, count: int) -> None: + self._subtype = subtype + self._count = count + + def getName(self) -> str: + return f"{self._subtype.getName()}[{self._count}]" + + def getLength(self) -> int: + return self._subtype.getLength() * self._count + + def getDataType(self) -> object: + return self._subtype + + def getNumElements(self) -> int: + return self._count + + +class BitFieldDataType: + def __init__(self, base: object, bit_size: int) -> None: + self._base = base + self._bit_size = bit_size + + def getName(self) -> str: + return f"{self._base.getName()}:{self._bit_size}" + + def getBaseDataType(self) -> object: + return self._base + + def getDeclaredBitSize(self) -> int: + return self._bit_size + + +class Component: + def __init__( + self, + field_name: str | None, + data_type: object, + offset: int, + *, + bit_offset: int = 0, + ) -> None: + self._field_name = field_name + self._data_type = data_type + self._offset = offset + self._bit_offset = bit_offset + + def getFieldName(self) -> str | None: + return self._field_name + + def getDataType(self) -> object: + return self._data_type + + def getOffset(self) -> int: + return self._offset + + def getBitOffset(self) -> int: + return self._bit_offset + + +class Structure: + def __init__(self, name: str, length: int, components: list[Component]) -> None: + self._name = name + self._length = length + self._components = components + + def getName(self) -> str: + return self._name + + def getLength(self) -> int: + return self._length + + def getComponents(self) -> list[Component]: + return self._components + + +class Union: + def __init__(self, name: str, length: int, components: list[Component]) -> None: + self._name = name + self._length = length + self._components = components + + def getName(self) -> str: + return self._name + + def getLength(self) -> int: + return self._length + + def getComponents(self) -> list[Component]: + return self._components + + +class Enum: + def __init__(self, name: str, length: int, constants: dict[str, int]) -> None: + self._name = name + self._length = length + self._constants = constants + + def getName(self) -> str: + return self._name + + def getLength(self) -> int: + return self._length + + def getNames(self) -> list[str]: + return list(self._constants) + + def getValue(self, name: str) -> int: + return self._constants[name] + + +class ParameterDefinition: + def __init__(self, name: str, data_type: object) -> None: + self._name = name + self._data_type = data_type + + def getName(self) -> str: + return self._name + + def getDataType(self) -> object: + return self._data_type + + +class FunctionDefinition: + def __init__(self, return_type: object, args: list[ParameterDefinition]) -> None: + self._return_type = return_type + self._args = args + + def getName(self) -> str: + return "callback" + + def getReturnType(self) -> object: + return self._return_type + + def getArguments(self) -> list[ParameterDefinition]: + return self._args + + +class Address: + def __init__(self, offset: int) -> None: + self._offset = offset + + def getOffset(self) -> int: + return self._offset + + +class Data: + def __init__(self, data_type: object) -> None: + self._data_type = data_type + + def getDataType(self) -> object: + return self._data_type + + +class Symbol: + def __init__(self, name: str, address: int, *, external: bool = False, symbol_type: str = "LABEL") -> None: + self._name = name + self._address = Address(address) + self._external = external + self._symbol_type = symbol_type + + def isExternal(self) -> bool: + return self._external + + def getAddress(self) -> Address: + return self._address + + def getSymbolType(self) -> str: + return self._symbol_type + + def getName(self, include_namespace: bool = False) -> str: + return self._name + + +class Function: + def __init__(self, name: str, address: int, return_type: object, parameters: list[ParameterDefinition]) -> None: + self._name = name + self._address = Address(address) + self._return_type = return_type + self._parameters = parameters + + def isExternal(self) -> bool: + return False + + def getName(self, include_namespace: bool = False) -> str: + return self._name + + def getEntryPoint(self) -> Address: + return self._address + + def getReturnType(self) -> object: + return self._return_type + + def getParameters(self) -> list[ParameterDefinition]: + return self._parameters + + +class DataTypeManager: + def __init__(self, data_types: list[object]) -> None: + self._data_types = data_types + + def getAllDataTypes(self) -> list[object]: + return self._data_types + + +class SymbolTable: + def __init__(self, symbols: list[Symbol]) -> None: + self._symbols = symbols + + def getAllSymbols(self, forward: bool = True) -> list[Symbol]: + return self._symbols + + +class Listing: + def __init__(self, data_by_address: dict[int, Data], functions: list[Function]) -> None: + self._data_by_address = data_by_address + self._functions = functions + + def getDataAt(self, address: Address) -> Data | None: + return self._data_by_address.get(address.getOffset()) + + def getFunctions(self, forward: bool = True) -> list[Function]: + return self._functions + + +class Memory: + def contains(self, address: Address) -> bool: + return address.getOffset() >= 0x1000 + + +class Program: + def __init__( + self, + data_types: list[object], + symbols: list[Symbol], + data_by_address: dict[int, Data], + functions: list[Function], + ) -> None: + self._data_types = data_types + self._symbols = symbols + self._data_by_address = data_by_address + self._functions = functions + + def getName(self) -> str: + return "fake_program" + + def getLanguage(self) -> Language: + return Language(big_endian=False) + + def getDefaultPointerSize(self) -> int: + return 8 + + def getDataTypeManager(self) -> DataTypeManager: + return DataTypeManager(self._data_types) + + def getSymbolTable(self) -> SymbolTable: + return SymbolTable(self._symbols) + + def getListing(self) -> Listing: + return Listing(self._data_by_address, self._functions) + + def getMemory(self) -> Memory: + return Memory() + + +def _fixture_program() -> Program: + u8 = BaseType("uint8_t", 1) + u16 = BaseType("uint16_t", 2) + u32 = BaseType("uint32_t", 4) + int_t = BaseType("int", 4) + my_u32 = TypeDef("my_u32", u32) + inner = Structure( + "Inner", + 4, + [ + Component("a", u16, 0), + Component("b", u8, 2), + ], + ) + value = Union( + "Value", + 4, + [ + Component("word", u32, 0), + Component("bytes", Array(u8, 4), 0), + ], + ) + callback = FunctionDefinition(int_t, [ParameterDefinition("x", int_t)]) + packet = Structure( + "Packet", + 24, + [ + Component("id", my_u32, 0), + Component("inner", inner, 4), + Component("value", value, 8), + Component("flags", BitFieldDataType(u8, 3), 12, bit_offset=1), + Component("next", Pointer(inner), 16), + Component("cb", Pointer(callback), 20), + ], + ) + color = Enum("Color", 4, {"RED": 1, "BLUE": 2}) + add_packet = Function( + "add_packet", + 0x1100, + int_t, + [ParameterDefinition("p", Pointer(packet)), ParameterDefinition("x", int_t)], + ) + return Program( + [u8, u16, u32, int_t, my_u32, inner, value, packet, color], + [ + Symbol("global_counter", 0x4010), + Symbol("external_data", 0x5000, external=True), + Symbol("function_label", 0x1100, symbol_type="FUNCTION"), + ], + {0x4010: Data(int_t)}, + [add_packet], + ) + + +def test_program_to_isf_exports_types_symbols_and_functions() -> None: + isf = program_to_isf(_fixture_program()) + + assert isf["metadata"]["producer"]["name"] == "pyghidra2isf" + assert isf["metadata"]["format"] == "6.2.0" + assert isf["base_types"]["pointer"]["size"] == 8 + assert isf["base_types"]["uint8_t"]["signed"] is False + assert isf["typedefs"]["my_u32"] == {"kind": "base", "name": "uint32_t"} + + packet = isf["user_types"]["Packet"] + assert packet["kind"] == "struct" + assert packet["size"] == 24 + assert packet["fields"]["id"]["type"] == {"kind": "typedef", "name": "my_u32"} + assert packet["fields"]["inner"]["type"] == {"kind": "struct", "name": "Inner"} + assert packet["fields"]["value"]["type"] == {"kind": "union", "name": "Value"} + assert packet["fields"]["flags"]["type"] == { + "kind": "bitfield", + "bit_length": 3, + "bit_position": 1, + "type": {"kind": "base", "name": "uint8_t"}, + } + assert packet["fields"]["next"]["type"] == { + "kind": "pointer", + "subtype": {"kind": "struct", "name": "Inner"}, + } + assert packet["fields"]["cb"]["type"]["subtype"]["kind"] == "function" + + assert isf["user_types"]["Value"]["fields"]["bytes"]["type"] == { + "kind": "array", + "count": 4, + "subtype": {"kind": "base", "name": "uint8_t"}, + } + assert isf["enums"]["Color"]["constants"] == {"BLUE": 2, "RED": 1} + assert isf["symbols"]["global_counter"]["address"] == 0x4010 + assert "external_data" not in isf["symbols"] + assert "function_label" not in isf["symbols"] + assert isf["functions"]["add_packet"]["parameters"][0]["type"] == { + "kind": "pointer", + "subtype": {"kind": "struct", "name": "Packet"}, + } + + +def test_from_ghidra_returns_regular_dffi_instance() -> None: + ffi = DFFI.from_ghidra(_fixture_program()) + + assert ffi.sizeof("Packet") == 24 + assert ffi.sizeof("my_u32") == 4 + assert ffi.typeof("my_u32").name == "uint32_t" + assert ffi.get_symbol("global_counter").address == 0x4010 + assert ffi.get_function("add_packet").address == 0x1100 + + +def test_types_only_omits_symbols_and_functions() -> None: + isf = program_to_isf(_fixture_program(), types_only=True) + + assert isf["user_types"]["Packet"] + assert isf["symbols"] == {} + assert isf["functions"] == {} + + +def test_current_program_from_context_finds_ghidra_script_name() -> None: + currentProgram = _fixture_program() + + assert current_program_from_context() is currentProgram + + +def test_from_ghidra_without_context_has_clear_error() -> None: + with pytest.raises(RuntimeError, match=r"DFFI\.from_ghidra\(currentProgram\)"): + DFFI.from_ghidra() + + +@pytest.mark.skipif( + os.environ.get("DFFI_GHIDRA_TEST") != "1", + reason="set DFFI_GHIDRA_TEST=1 to run the PyGhidra integration test", +) +def test_from_ghidra_with_pyghidra_fixture() -> None: + pyghidra = pytest.importorskip("pyghidra") + gcc = shutil.which("gcc") + if gcc is None: + pytest.skip("gcc is required for the PyGhidra integration fixture") + + with tempfile.TemporaryDirectory() as tmp: + tmp_path = Path(tmp) + source_path = tmp_path / "fixture.c" + binary_path = tmp_path / "fixture" + source_path.write_text( + """ + #include + + typedef uint32_t my_u32; + + struct Packet { + my_u32 id; + uint16_t tag; + uint8_t flags; + }; + + int global_counter = 7; + + int add_packet(struct Packet *p, int x) { + return (int)p->id + x + global_counter; + } + """, + encoding="utf-8", + ) + subprocess.run( + [ + gcc, + "-g", + "-O0", + "-fno-eliminate-unused-debug-types", + "-c", + str(source_path), + "-o", + str(binary_path), + ], + check=True, + text=True, + capture_output=True, + ) + + open_program = getattr(pyghidra, "open_program", None) + if open_program is None: + pytest.skip("pyghidra.open_program is unavailable") + + kwargs = {"analyze": True} + try: + context = open_program(str(binary_path), **kwargs) + except TypeError: + try: + pyghidra.start() + except Exception: + pass + try: + context = open_program(str(binary_path)) + except Exception as exc: + pytest.skip(f"could not open fixture through PyGhidra: {exc}") + except Exception as exc: + pytest.skip(f"could not open fixture through PyGhidra: {exc}") + + with context as opened: + program = opened + if hasattr(opened, "getCurrentProgram"): + program = opened.getCurrentProgram() + elif hasattr(opened, "currentProgram"): + program = opened.currentProgram + + try: + ffi = DFFI.from_ghidra(program) + except Exception as exc: + pytest.skip(f"PyGhidra fixture opened but could not be converted: {exc}") + + assert ffi.sizeof("Packet") == ffi.get_type("Packet").size + assert ffi.sizeof("my_u32") == 4 + assert ffi.get_symbol("global_counter") is not None + assert ffi.get_function("add_packet") is not None + + types_only = DFFI.from_ghidra(program, types_only=True) + assert types_only.sizeof("Packet") == ffi.sizeof("Packet") + assert types_only.symbols == {} + assert types_only.functions == {}