"""Stream an EIP-712 JSON document to the Ethereum app through ``EthAppClient``.

The entry point is :func:`process_data`. It sends the type definitions, then
the domain and message struct implementations, and - when filters are given -
the signed filtering descriptors that go with individual fields.

State is kept in module-level globals (``app_client``, ``filtering_paths``,
``current_path``, ``sig_ctx``) which :func:`process_data` (re)initialises.
"""
import hashlib
import json
import re
import signal
import sys
import copy
from typing import Any, Callable, Optional, Union
import struct

from client import keychain
from client.client import EthAppClient, EIP712FieldType
from ragger.firmware import Firmware


# global variables
app_client: EthAppClient = None
# filter descriptors keyed by dotted field path (see prepare_filtering)
filtering_paths: dict = {}
# path of the field currently being sent; array levels appear as "[]"
current_path: list[str] = list()
# byte fields shared by every filtering signature (see init_signature_context)
sig_ctx: dict[str, Any] = {}

# Compiled once; both are applied to every field type name.
_ARRAY_LEVEL_RE = re.compile(r"(.*)\[([0-9]*)\]$")
_TYPE_SIZE_RE = re.compile(r"^(\w+?)(\d*)$")


def default_handler():
    """Placeholder auto-next handler used until ``process_data`` sets one."""
    raise RuntimeError("Uninitialized handler")


autonext_handler: Callable = default_handler
# Defaults to False so enable_autonext() is usable before process_data() ran.
is_golden_run: bool = False


# From a string typename, extract the type and all the array depth
# Input  = "uint8[2][][4]"          |   "bool"
# Output = ('uint8', [2, None, 4])  |   ('bool', [])
def get_array_levels(typename: str) -> tuple[str, list[Optional[int]]]:
    """Split trailing ``[N]`` / ``[]`` suffixes off ``typename``.

    Returns the bare type name and the list of array sizes in the order they
    are written; ``None`` stands for an unsized (``[]``) level.
    """
    array_lvls: list[Optional[int]] = []
    while True:
        match = _ARRAY_LEVEL_RE.search(typename)
        if not match:
            break
        typename = match.group(1)
        size_str = match.group(2)
        # suffixes are peeled from the right, so prepend to keep written order
        array_lvls.insert(0, int(size_str) if size_str else None)
    return (typename, array_lvls)


# From a string typename, extract the type and its size
# Input  = "uint64"         |   "string"
# Output = ('uint', 64)     |   ('string', None)
def get_typesize(typename: str) -> tuple[str, Optional[int]]:
    """Split a trailing decimal size off ``typename``.

    A name the pattern cannot match (empty, or containing non-word
    characters) is returned unchanged with a size of ``None``.
    """
    match = _TYPE_SIZE_RE.search(typename)
    if match is None:
        return (typename, None)
    size_str = match.group(2)
    return (match.group(1), int(size_str) if size_str else None)


def parse_int(typesize):
    """Map ``intN`` to its field type and its size in bytes."""
    return (EIP712FieldType.INT, typesize // 8)


def parse_uint(typesize):
    """Map ``uintN`` to its field type and its size in bytes."""
    return (EIP712FieldType.UINT, typesize // 8)


def parse_address(typesize):
    """Map ``address`` to its field type; no size is reported."""
    return (EIP712FieldType.ADDRESS, None)


def parse_bool(typesize):
    """Map ``bool`` to its field type; no size is reported."""
    return (EIP712FieldType.BOOL, None)


def parse_string(typesize):
    """Map ``string`` to its field type; no size is reported."""
    return (EIP712FieldType.STRING, None)


def parse_bytes(typesize):
    """Map ``bytesN`` to FIX_BYTES (size N) and bare ``bytes`` to DYN_BYTES."""
    if typesize is not None:
        return (EIP712FieldType.FIX_BYTES, typesize)
    return (EIP712FieldType.DYN_BYTES, None)


# set functions for each type
parsing_type_functions = {
    "int": parse_int,
    "uint": parse_uint,
    "address": parse_address,
    "bool": parse_bool,
    "string": parse_string,
    "bytes": parse_bytes,
}


def send_struct_def_field(typename, keyname):
    """Send the definition of one struct field and return its parsed type.

    Returns ``(typename, type_enum, typesize, array_lvls)`` where ``typename``
    is the base name for built-in types (e.g. ``"uint"``) and the full struct
    name for custom types.
    """
    (base_type, array_lvls) = get_array_levels(typename)
    (type_name, typesize) = get_typesize(base_type)

    parser = parsing_type_functions.get(type_name)
    if parser is not None:
        (type_enum, typesize) = parser(typesize)
    else:
        type_enum = EIP712FieldType.CUSTOM
        # get_typesize() strips trailing digits, which are part of the name
        # for a struct such as "Order2": keep the name as written so that it
        # can be looked up again in send_struct_impl().
        type_name = base_type
        typesize = None

    with app_client.eip712_send_struct_def_struct_field(type_enum,
                                                        type_name,
                                                        typesize,
                                                        array_lvls,
                                                        keyname):
        pass
    return (type_name, type_enum, typesize, array_lvls)


def encode_integer(value: Union[str, int], typesize: int) -> bytes:
    """Encode ``value`` big-endian on at most ``typesize`` bytes.

    Leading zero bytes are stripped; zero itself is encoded as ``b'\\x00'``.
    Negative values keep their two's complement form truncated to
    ``typesize`` bytes.
    """
    # Some are already represented as integers in the JSON, but most as strings
    if isinstance(value, str):
        value = int(value, 0)

    if value == 0:
        return b'\x00'

    # biggest uint type accepted by struct.pack
    uint64_mask = 0xffffffffffffffff
    data = struct.pack(">QQQQ",
                       (value >> 192) & uint64_mask,
                       (value >> 128) & uint64_mask,
                       (value >> 64) & uint64_mask,
                       value & uint64_mask)
    # keep only the low `typesize` bytes, then drop the zero padding
    data = data[len(data) - typesize:]
    return data.lstrip(b'\x00')


def encode_int(value: str, typesize: int) -> bytes:
    """Encode a signed integer field (see ``encode_integer``)."""
    return encode_integer(value, typesize)


def encode_uint(value: str, typesize: int) -> bytes:
    """Encode an unsigned integer field (see ``encode_integer``)."""
    return encode_integer(value, typesize)


def encode_hex_string(value: str, size: int) -> bytes:
    """Decode a ``0x``-prefixed hex string into exactly ``size`` bytes.

    Shorter values are left-padded with zeros; longer ones are rejected.
    """
    assert value.startswith("0x"), "hex value must start with 0x"
    digits = value[2:].rjust(size * 2, "0")
    assert len(digits) == (size * 2), "hex value does not fit the field size"
    return bytes.fromhex(digits)


def encode_address(value: str, typesize: int) -> bytes:
    """Encode an address as 20 bytes; ``typesize`` is ignored."""
    return encode_hex_string(value, 20)


def encode_bool(value: str, typesize: int) -> bytes:
    """Encode a boolean as a one-byte integer; ``typesize`` is ignored."""
    return encode_integer(value, 1)


def encode_string(value: str, typesize: int) -> bytes:
    """Encode a string with the default (UTF-8) codec."""
    return value.encode()


def encode_bytes_fix(value: str, typesize: int) -> bytes:
    """Encode a fixed-size ``bytesN`` value on ``typesize`` bytes."""
    return encode_hex_string(value, typesize)


def encode_bytes_dyn(value: str, typesize: int) -> bytes:
    """Encode a dynamic ``bytes`` value; its size comes from the string."""
    # length of the value string
    # - the length of 0x (2)
    # / by the length of one byte in a hex string (2)
    return encode_hex_string(value, (len(value) - 2) // 2)


# set functions for each type
encoding_functions = {
    EIP712FieldType.INT: encode_int,
    EIP712FieldType.UINT: encode_uint,
    EIP712FieldType.ADDRESS: encode_address,
    EIP712FieldType.BOOL: encode_bool,
    EIP712FieldType.STRING: encode_string,
    EIP712FieldType.FIX_BYTES: encode_bytes_fix,
    EIP712FieldType.DYN_BYTES: encode_bytes_dyn,
}


def _run_with_autonext(exchange) -> None:
    """Enter ``exchange`` with the auto-next timer armed, then always disarm.

    The timer is disarmed even when the exchange raises, so it cannot keep
    firing after a failure.
    """
    try:
        with exchange:
            enable_autonext()
    finally:
        disable_autonext()


def send_struct_impl_field(value, field):
    """Encode and send one non-array, non-struct field value.

    When a filter is registered for the current path, the matching filtering
    descriptor is sent first.

    Raises:
        RuntimeError: if called with a list or a custom (struct) field, which
            ``evaluate_field`` is expected to have unrolled beforehand.
        AssertionError: if the filter registered for the path has an unknown
            ``"type"``.
    """
    # Something wrong happened if this triggers
    if isinstance(value, list) or (field["enum"] == EIP712FieldType.CUSTOM):
        raise RuntimeError("Cannot send array or struct as a plain field: %s" %
                           ".".join(current_path))

    data = encoding_functions[field["enum"]](value, field["typesize"])

    filt = filtering_paths.get(".".join(current_path)) if filtering_paths else None
    if filt is not None:
        filter_type = filt["type"]
        if filter_type == "amount_join_token":
            send_filtering_amount_join_token(filt["token"])
        elif filter_type == "amount_join_value":
            # no token given: Permit (ERC-2612)
            send_filtering_amount_join_value(filt.get("token", 0xff),
                                             filt["name"])
        elif filter_type == "datetime":
            send_filtering_datetime(filt["name"])
        elif filter_type == "raw":
            send_filtering_raw(filt["name"])
        else:
            # explicit raise: a bare `assert False` is stripped under -O
            raise AssertionError("Unknown filter type: %s" % filter_type)

    _run_with_autonext(app_client.eip712_send_struct_impl_struct_field(data))


def evaluate_field(structs, data, field, lvls_left, new_level=True):
    """Send ``data`` for ``field``, unrolling arrays and nested structs.

    Args:
        structs: mapping of struct name to its list of parsed fields.
        data: the JSON value of the field (list for array levels).
        field: parsed field (``name``, ``type``, ``enum``, ``typesize``,
            ``array_lvls``).
        lvls_left: number of array levels still to unroll.
        new_level: True when entering the field itself (its name is pushed
            on ``current_path``), False when recursing into an array element.

    Returns:
        False on an array size mismatch or an unknown struct, True otherwise.
        ``current_path`` is restored on every exit.
    """
    array_lvls = field["array_lvls"]

    if new_level:
        current_path.append(field["name"])
    try:
        if len(array_lvls) > 0 and lvls_left > 0:
            with app_client.eip712_send_struct_impl_array(len(data)):
                pass
            for subdata in data:
                current_path.append("[]")
                try:
                    if not evaluate_field(structs, subdata, field,
                                          lvls_left - 1, False):
                        return False
                finally:
                    current_path.pop()
            expected = array_lvls[lvls_left - 1]
            count = len(data)
            if expected is not None and expected != count:
                print("Mismatch in array size! Got %d, expected %d\n" %
                      (count, expected),
                      file=sys.stderr)
                return False
        elif field["enum"] == EIP712FieldType.CUSTOM:
            if not send_struct_impl(structs, data, field["type"]):
                return False
        else:
            send_struct_impl_field(data, field)
        return True
    finally:
        if new_level:
            current_path.pop()


def send_struct_impl(structs, data, structname):
    """Send every field of ``data`` as an instance of struct ``structname``.

    Returns False if the struct is unknown or if any field fails.
    """
    # Check if it is a struct we don't know
    if structname not in structs:
        return False

    for f in structs[structname]:
        if not evaluate_field(structs, data[f["name"]], f, len(f["array_lvls"])):
            return False
    return True


def start_signature_payload(ctx: dict, magic: int) -> bytearray:
    """Build the common prefix of a filtering signature payload.

    The prefix is ``magic`` followed by the chain id, contract address and
    schema hash stored in ``ctx``.
    """
    to_sign = bytearray()
    # magic number so that signature for one type of filter can't possibly be
    # valid for another, defined in APDU specs
    to_sign.append(magic)
    to_sign += ctx["chainid"]
    to_sign += ctx["caddr"]
    to_sign += ctx["schema_hash"]
    return to_sign


# ledgerjs doesn't actually sign anything, and instead uses already pre-computed signatures
def send_filtering_message_info(display_name: str, filters_count: int):
    """Sign and send the message name together with the number of filters."""
    to_sign = start_signature_payload(sig_ctx, 183)
    to_sign.append(filters_count)
    to_sign += display_name.encode()

    sig = keychain.sign_data(keychain.Key.CAL, to_sign)
    _run_with_autonext(
        app_client.eip712_filtering_message_info(display_name, filters_count, sig))


def send_filtering_amount_join_token(token_idx: int):
    """Sign and send an amount-join *token* filter for the current path."""
    to_sign = start_signature_payload(sig_ctx, 11)
    to_sign += ".".join(current_path).encode()
    to_sign.append(token_idx)
    sig = keychain.sign_data(keychain.Key.CAL, to_sign)
    with app_client.eip712_filtering_amount_join_token(token_idx, sig):
        pass


def send_filtering_amount_join_value(token_idx: int, display_name: str):
    """Sign and send an amount-join *value* filter for the current path."""
    to_sign = start_signature_payload(sig_ctx, 22)
    to_sign += ".".join(current_path).encode()
    to_sign += display_name.encode()
    to_sign.append(token_idx)
    sig = keychain.sign_data(keychain.Key.CAL, to_sign)
    with app_client.eip712_filtering_amount_join_value(token_idx, display_name, sig):
        pass


def send_filtering_datetime(display_name: str):
    """Sign and send a date-time filter for the current path."""
    to_sign = start_signature_payload(sig_ctx, 33)
    to_sign += ".".join(current_path).encode()
    to_sign += display_name.encode()
    sig = keychain.sign_data(keychain.Key.CAL, to_sign)
    with app_client.eip712_filtering_datetime(display_name, sig):
        pass


# ledgerjs doesn't actually sign anything, and instead uses already pre-computed signatures
def send_filtering_raw(display_name):
    """Sign and send a raw-field filter for the current path."""
    to_sign = start_signature_payload(sig_ctx, 72)
    to_sign += ".".join(current_path).encode()
    to_sign += display_name.encode()
    sig = keychain.sign_data(keychain.Key.CAL, to_sign)
    with app_client.eip712_filtering_raw(display_name, sig):
        pass


def prepare_filtering(filtr_data, message):
    """Load the per-field filters and provide token metadata to the app.

    ``message`` is accepted for interface compatibility and is not used.
    """
    global filtering_paths

    filtering_paths = filtr_data["fields"] if "fields" in filtr_data else {}
    for token in filtr_data.get("tokens", ()):
        app_client.provide_token_metadata(token["ticker"],
                                          bytes.fromhex(token["addr"][2:]),
                                          token["decimals"],
                                          token["chain_id"])


def handle_optional_domain_values(domain):
    """Fill in ``chainId`` and ``verifyingContract`` in place when missing."""
    if "chainId" not in domain:
        domain["chainId"] = 0
    if "verifyingContract" not in domain:
        domain["verifyingContract"] = "0x0000000000000000000000000000000000000000"


def init_signature_context(types, domain):
    """Populate ``sig_ctx`` with the fields shared by all filter signatures.

    Stores the contract address bytes, the chain id as 8 big-endian bytes and
    the SHA-224 of the space-stripped JSON dump of ``types``. ``domain`` gets
    its optional values defaulted in place.
    """
    handle_optional_domain_values(domain)

    caddr = domain["verifyingContract"]
    if caddr.startswith("0x"):
        caddr = caddr[2:]
    sig_ctx["caddr"] = bytearray.fromhex(caddr)

    chainid = domain["chainId"]
    if isinstance(chainid, str):
        chainid = int(chainid, 0)
    # 8 bytes, big-endian; masked so that wider values are truncated instead
    # of raising. (Masking each byte without shifting it back down overflowed
    # bytearray.append() for any chain id above 255.)
    sig_ctx["chainid"] = bytearray(struct.pack(">Q", chainid & 0xffffffffffffffff))

    schema_str = json.dumps(types).replace(" ", "")
    sig_ctx["schema_hash"] = bytearray(hashlib.sha224(schema_str.encode()).digest())


def next_timeout(_signum: int, _frame):
    """SIGALRM handler: forward to the configured auto-next callback."""
    autonext_handler()


def enable_autonext():
    """Arm a repeating real-time timer that fires SIGALRM periodically."""
    if app_client._client.firmware in (Firmware.STAX, Firmware.FLEX):
        delay = 1/3
    else:
        delay = 1/4

    # golden run has to be slower to make sure we take good snapshots
    # and not processing/loading screens
    if is_golden_run:
        delay *= 3

    signal.setitimer(signal.ITIMER_REAL, delay, delay)


def disable_autonext():
    """Disarm the timer armed by ``enable_autonext``."""
    signal.setitimer(signal.ITIMER_REAL, 0, 0)


def process_data(aclient: EthAppClient,
                 data_json: dict,
                 filters: Optional[dict] = None,
                 autonext: Optional[Callable] = None,
                 golden_run: bool = False) -> bool:
    """Send a whole EIP-712 JSON document to the app.

    Args:
        aclient: client used for every exchange; stored in ``app_client``.
        data_json: document with ``primaryType``, ``types``, ``domain`` and
            ``message`` keys. It is deep-copied, the caller's dict is left
            untouched.
        filters: optional filtering description (``fields``, ``tokens``,
            ``name``); enables filtering when truthy.
        autonext: optional callback invoked on each auto-next timer tick.
        golden_run: when True the auto-next timer runs three times slower.

    Returns:
        True when everything was sent, False when a struct implementation
        could not be sent (unknown struct or array size mismatch).
    """
    global app_client
    global autonext_handler
    global is_golden_run

    # deepcopy because this function modifies the dict
    data_json = copy.deepcopy(data_json)
    app_client = aclient
    # start from a clean path even if a previous run aborted midway
    current_path.clear()
    domain_typename = "EIP712Domain"
    message_typename = data_json["primaryType"]
    types = data_json["types"]
    domain = data_json["domain"]
    message = data_json["message"]

    if autonext:
        autonext_handler = autonext
        signal.signal(signal.SIGALRM, next_timeout)

    is_golden_run = golden_run

    # Must happen before the loop below, which stores non-JSON values
    # (the field enum) into `types` while the schema hash dumps it as JSON.
    if filters:
        init_signature_context(types, domain)

    # send types definition
    for key, fields in types.items():
        with app_client.eip712_send_struct_def_struct_name(key):
            pass
        for f in fields:
            (f["type"], f["enum"], f["typesize"], f["array_lvls"]) = \
                send_struct_def_field(f["type"], f["name"])

    if filters:
        with app_client.eip712_filtering_activate():
            pass
        prepare_filtering(filters, message)

    # send domain implementation
    _run_with_autonext(app_client.eip712_send_struct_impl_root_struct(domain_typename))
    if not send_struct_impl(types, domain, domain_typename):
        return False

    if filters:
        # fall back to the domain name when the filters do not carry one
        display_name = filters["name"] if "name" in filters else domain["name"]
        send_filtering_message_info(display_name, len(filtering_paths))

    # send message implementation
    _run_with_autonext(app_client.eip712_send_struct_impl_root_struct(message_typename))
    if not send_struct_impl(types, message, message_typename):
        return False

    return True