[clean] Linter / typing fixes
This commit is contained in:
@@ -3,7 +3,7 @@ import json
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
from typing import Callable
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from ledger_app_clients.ethereum import keychain
|
||||
from ledger_app_clients.ethereum.client import EthAppClient, EIP712FieldType
|
||||
@@ -11,11 +11,16 @@ from ledger_app_clients.ethereum.client import EthAppClient, EIP712FieldType
|
||||
|
||||
# global variables
|
||||
app_client: EthAppClient = None
|
||||
filtering_paths = None
|
||||
current_path = list()
|
||||
sig_ctx = {}
|
||||
filtering_paths: Dict = {}
|
||||
current_path: List[str] = list()
|
||||
sig_ctx: Dict[str, Any] = {}
|
||||
|
||||
autonext_handler: Callable = None
|
||||
|
||||
def default_handler():
|
||||
raise RuntimeError("Uninitialized handler")
|
||||
|
||||
|
||||
autonext_handler: Callable = default_handler
|
||||
|
||||
|
||||
# From a string typename, extract the type and all the array depth
|
||||
@@ -55,29 +60,34 @@ def get_typesize(typename):
|
||||
return (typename, typesize)
|
||||
|
||||
|
||||
|
||||
def parse_int(typesize):
|
||||
return (EIP712FieldType.INT, int(typesize / 8))
|
||||
|
||||
|
||||
def parse_uint(typesize):
|
||||
return (EIP712FieldType.UINT, int(typesize / 8))
|
||||
|
||||
|
||||
def parse_address(typesize):
|
||||
return (EIP712FieldType.ADDRESS, None)
|
||||
|
||||
|
||||
def parse_bool(typesize):
|
||||
return (EIP712FieldType.BOOL, None)
|
||||
|
||||
|
||||
def parse_string(typesize):
|
||||
return (EIP712FieldType.STRING, None)
|
||||
|
||||
|
||||
def parse_bytes(typesize):
|
||||
if typesize != None:
|
||||
if typesize is not None:
|
||||
return (EIP712FieldType.FIX_BYTES, typesize)
|
||||
return (EIP712FieldType.DYN_BYTES, None)
|
||||
|
||||
|
||||
# set functions for each type
|
||||
parsing_type_functions = {};
|
||||
parsing_type_functions = {}
|
||||
parsing_type_functions["int"] = parse_int
|
||||
parsing_type_functions["uint"] = parse_uint
|
||||
parsing_type_functions["address"] = parse_address
|
||||
@@ -86,7 +96,6 @@ parsing_type_functions["string"] = parse_string
|
||||
parsing_type_functions["bytes"] = parse_bytes
|
||||
|
||||
|
||||
|
||||
def send_struct_def_field(typename, keyname):
|
||||
type_enum = None
|
||||
|
||||
@@ -108,7 +117,6 @@ def send_struct_def_field(typename, keyname):
|
||||
return (typename, type_enum, typesize, array_lvls)
|
||||
|
||||
|
||||
|
||||
def encode_integer(value, typesize):
|
||||
data = bytearray()
|
||||
|
||||
@@ -122,9 +130,9 @@ def encode_integer(value, typesize):
|
||||
if value == 0:
|
||||
data.append(0)
|
||||
else:
|
||||
if value < 0: # negative number, send it as unsigned
|
||||
if value < 0: # negative number, send it as unsigned
|
||||
mask = 0
|
||||
for i in range(typesize): # make a mask as big as the typesize
|
||||
for i in range(typesize): # make a mask as big as the typesize
|
||||
mask = (mask << 8) | 0xff
|
||||
value &= mask
|
||||
while value > 0:
|
||||
@@ -133,42 +141,51 @@ def encode_integer(value, typesize):
|
||||
data.reverse()
|
||||
return data
|
||||
|
||||
|
||||
def encode_int(value, typesize):
|
||||
return encode_integer(value, typesize)
|
||||
|
||||
|
||||
def encode_uint(value, typesize):
|
||||
return encode_integer(value, typesize)
|
||||
|
||||
|
||||
def encode_hex_string(value, size):
|
||||
data = bytearray()
|
||||
value = value[2:] # skip 0x
|
||||
value = value[2:] # skip 0x
|
||||
byte_idx = 0
|
||||
while byte_idx < size:
|
||||
data.append(int(value[(byte_idx * 2):(byte_idx * 2 + 2)], 16))
|
||||
byte_idx += 1
|
||||
return data
|
||||
|
||||
|
||||
def encode_address(value, typesize):
|
||||
return encode_hex_string(value, 20)
|
||||
|
||||
|
||||
def encode_bool(value, typesize):
|
||||
return encode_integer(value, typesize)
|
||||
|
||||
|
||||
def encode_string(value, typesize):
|
||||
data = bytearray()
|
||||
for char in value:
|
||||
data.append(ord(char))
|
||||
return data
|
||||
|
||||
|
||||
def encode_bytes_fix(value, typesize):
|
||||
return encode_hex_string(value, typesize)
|
||||
|
||||
|
||||
def encode_bytes_dyn(value, typesize):
|
||||
# length of the value string
|
||||
# - the length of 0x (2)
|
||||
# / by the length of one byte in a hex string (2)
|
||||
return encode_hex_string(value, int((len(value) - 2) / 2))
|
||||
|
||||
|
||||
# set functions for each type
|
||||
encoding_functions = {}
|
||||
encoding_functions[EIP712FieldType.INT] = encode_int
|
||||
@@ -180,7 +197,6 @@ encoding_functions[EIP712FieldType.FIX_BYTES] = encode_bytes_fix
|
||||
encoding_functions[EIP712FieldType.DYN_BYTES] = encode_bytes_dyn
|
||||
|
||||
|
||||
|
||||
def send_struct_impl_field(value, field):
|
||||
# Something wrong happened if this triggers
|
||||
if isinstance(value, list) or (field["enum"] == EIP712FieldType.CUSTOM):
|
||||
@@ -188,7 +204,6 @@ def send_struct_impl_field(value, field):
|
||||
|
||||
data = encoding_functions[field["enum"]](value, field["typesize"])
|
||||
|
||||
|
||||
if filtering_paths:
|
||||
path = ".".join(current_path)
|
||||
if path in filtering_paths.keys():
|
||||
@@ -199,8 +214,7 @@ def send_struct_impl_field(value, field):
|
||||
disable_autonext()
|
||||
|
||||
|
||||
|
||||
def evaluate_field(structs, data, field, lvls_left, new_level = True):
|
||||
def evaluate_field(structs, data, field, lvls_left, new_level=True):
|
||||
array_lvls = field["array_lvls"]
|
||||
|
||||
if new_level:
|
||||
@@ -215,7 +229,7 @@ def evaluate_field(structs, data, field, lvls_left, new_level = True):
|
||||
return False
|
||||
current_path.pop()
|
||||
idx += 1
|
||||
if array_lvls[lvls_left - 1] != None:
|
||||
if array_lvls[lvls_left - 1] is not None:
|
||||
if array_lvls[lvls_left - 1] != idx:
|
||||
print("Mismatch in array size! Got %d, expected %d\n" %
|
||||
(idx, array_lvls[lvls_left - 1]),
|
||||
@@ -232,7 +246,6 @@ def evaluate_field(structs, data, field, lvls_left, new_level = True):
|
||||
return True
|
||||
|
||||
|
||||
|
||||
def send_struct_impl(structs, data, structname):
|
||||
# Check if it is a struct we don't known
|
||||
if structname not in structs.keys():
|
||||
@@ -244,6 +257,7 @@ def send_struct_impl(structs, data, structname):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# ledgerjs doesn't actually sign anything, and instead uses already pre-computed signatures
|
||||
def send_filtering_message_info(display_name: str, filters_count: int):
|
||||
global sig_ctx
|
||||
@@ -262,6 +276,7 @@ def send_filtering_message_info(display_name: str, filters_count: int):
|
||||
enable_autonext()
|
||||
disable_autonext()
|
||||
|
||||
|
||||
# ledgerjs doesn't actually sign anything, and instead uses already pre-computed signatures
|
||||
def send_filtering_show_field(display_name):
|
||||
global sig_ctx
|
||||
@@ -281,12 +296,14 @@ def send_filtering_show_field(display_name):
|
||||
with app_client.eip712_filtering_show_field(display_name, sig):
|
||||
pass
|
||||
|
||||
def read_filtering_file(domain, message, filtering_file_path):
|
||||
|
||||
def read_filtering_file(filtering_file_path: str):
|
||||
data_json = None
|
||||
with open(filtering_file_path) as data:
|
||||
data_json = json.load(data)
|
||||
return data_json
|
||||
|
||||
|
||||
def prepare_filtering(filtr_data, message):
|
||||
global filtering_paths
|
||||
|
||||
@@ -295,12 +312,14 @@ def prepare_filtering(filtr_data, message):
|
||||
else:
|
||||
filtering_paths = {}
|
||||
|
||||
|
||||
def handle_optional_domain_values(domain):
|
||||
if "chainId" not in domain.keys():
|
||||
domain["chainId"] = 0
|
||||
if "verifyingContract" not in domain.keys():
|
||||
domain["verifyingContract"] = "0x0000000000000000000000000000000000000000"
|
||||
|
||||
|
||||
def init_signature_context(types, domain):
|
||||
global sig_ctx
|
||||
|
||||
@@ -314,7 +333,7 @@ def init_signature_context(types, domain):
|
||||
for i in range(8):
|
||||
sig_ctx["chainid"].append(chainid & (0xff << (i * 8)))
|
||||
sig_ctx["chainid"].reverse()
|
||||
schema_str = json.dumps(types).replace(" ","")
|
||||
schema_str = json.dumps(types).replace(" ", "")
|
||||
schema_hash = hashlib.sha224(schema_str.encode())
|
||||
sig_ctx["schema_hash"] = bytearray.fromhex(schema_hash.hexdigest())
|
||||
|
||||
@@ -322,22 +341,24 @@ def init_signature_context(types, domain):
|
||||
def next_timeout(_signum: int, _frame):
|
||||
autonext_handler()
|
||||
|
||||
|
||||
def enable_autonext():
|
||||
seconds = 1/4
|
||||
if app_client._client.firmware.device == 'stax': # Stax Speculos is slow
|
||||
if app_client._client.firmware.device == 'stax': # Stax Speculos is slow
|
||||
interval = seconds * 3
|
||||
else:
|
||||
interval = seconds
|
||||
signal.setitimer(signal.ITIMER_REAL, seconds, interval)
|
||||
|
||||
|
||||
def disable_autonext():
|
||||
signal.setitimer(signal.ITIMER_REAL, 0, 0)
|
||||
|
||||
|
||||
def process_file(aclient: EthAppClient,
|
||||
input_file_path: str,
|
||||
filtering_file_path = None,
|
||||
autonext: Callable = None) -> bool:
|
||||
filtering_file_path: Optional[str] = None,
|
||||
autonext: Optional[Callable] = None) -> bool:
|
||||
global sig_ctx
|
||||
global app_client
|
||||
global autonext_handler
|
||||
@@ -357,7 +378,7 @@ def process_file(aclient: EthAppClient,
|
||||
|
||||
if filtering_file_path:
|
||||
init_signature_context(types, domain)
|
||||
filtr = read_filtering_file(domain, message, filtering_file_path)
|
||||
filtr = read_filtering_file(filtering_file_path)
|
||||
|
||||
# send types definition
|
||||
for key in types.keys():
|
||||
@@ -365,7 +386,7 @@ def process_file(aclient: EthAppClient,
|
||||
pass
|
||||
for f in types[key]:
|
||||
(f["type"], f["enum"], f["typesize"], f["array_lvls"]) = \
|
||||
send_struct_def_field(f["type"], f["name"])
|
||||
send_struct_def_field(f["type"], f["name"])
|
||||
|
||||
if filtering_file_path:
|
||||
with app_client.eip712_filtering_activate():
|
||||
|
||||
Reference in New Issue
Block a user