Select to view content in your preferred language

Re-Write the arcpy.utils.ArgAdaptor class to use Literal hints on masked args

121
0
3 weeks ago
Status: Open

This is basically a drop-in replacement for some really inscrutable code in the utils module that's used to translate string literal arguments to integer flags using an Adaptor implementation (see _mp.constants).

It is now possible to use inspect to directly override __annotations__ for a function. Since this class will always be converting string literals to integers (or other strings) using the __args__ class attribute, that mapping can be used to generate Literal types on the fly in the __init_subclass__ method of each implemented Adaptor class:

 

def __init_subclass__(cls, **kwargs):
    """Lower-case the parameter names in ``__args__`` and build Literal hints.

    After this runs, ``cls.__args__`` is a new dictionary whose top-level keys
    (the masked parameter names) are lower-cased, and ``cls.__hints__`` maps
    each of those names to a ``typing.Literal`` of its option keys.
    """
    # Lowercase the keys for case insensitivity
    cls.__args__ = {k.lower(): v for k, v in cls.__args__.items()}

    # Build the literals for the subclass.  Subscripting with a tuple yields
    # the same Literal as star-unpacking, but star-unpacking inside a
    # subscript is only valid syntax on Python 3.11+.
    cls.__hints__: ArgAdaptor.HintMap = {
        parameter: Literal[tuple(options)]
        for parameter, options in cls.__args__.items()
    }

 

You can then directly update the __annotations__ attribute of the wrapper function:

 

# Give the wrapper its own annotations mapping instead of updating in place:
# functools.wraps can hand the wrapper the very same dict object as the
# wrapped function, in which case in-place updates would leak the hints onto
# the original function and the manual annotations could never win.
adapt_arguments.__annotations__ = {
    # Apply literal type hints to the adapted function
    **adaptor.__hints__,
    # Manual annotations come last so they override the literal type hints
    **masked_function.__annotations__,
}

 

This will add the option keys as Literal hints for each masked parameter:

 

# Toy example with some basic masking keys
>>> Testing.maskedFunction.__annotations__
>>> ('format', typing.Literal['pdf', 'svg', 'eps'])
>>> ('mode', typing.Literal['read', 'write', 'execute'])
>>> ('filename', <class 'str'>)

 

instead of that __annotations__ attribute only containing the hints written directly into the function header.

This allows for Literal hinting to be applied with no extra work for any functions wrapped by ArgAdaptor. It will also update the hints at runtime after changes to the __args__ attribute of an implementation class.

Code:

Spoiler
import functools
from typing import Callable, Sequence, Any, TypeAlias, Literal
import inspect

class ArgAdaptor:
    ValueMap: TypeAlias = dict[str, str | int]
    ArgumentMap: TypeAlias = dict[str, ValueMap]
    HintMap: TypeAlias = dict[str, TypeAlias]
    
    __args__: ArgumentMap = {}
    
    def __init_subclass__(cls, **kwargs):
        """ Build Literals for the subclass and formats the __args__ dictionary to lowercase """
        
        # Lowercase the keys for case insensitivity
        cls.__args__ = {k.lower(): v for k, v in cls.__args__.items()}

        # Build the literals for the subclass
        cls.__hints__: ArgAdaptor.HintMap = {
            parameter: Literal[*list(options.keys())]
            for parameter, options in cls.__args__.items()
        }
    
    @classmethod
    def maskargs(adaptor: 'ArgAdaptor', masked_function: Callable) -> Callable:
        masked_function_signature = inspect.signature(masked_function)
        function_parameters = masked_function_signature.parameters
        adaptors = adaptor.__args__
        
        @functools.wraps(masked_function)
        def adapt_arguments(*args, **kwargs):
            adapted_arguments: dict[str, Any] = {
                parameter.name: value
                for parameter, value in zip(function_parameters.values(), args)
                if parameter.kind in (
                    inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.VAR_POSITIONAL
                )
            }
            adapted_arguments.update(kwargs)
            
            invalid_args: list[Exception] = []
            for argument, arg_value in adapted_arguments.items():
                if argument not in adaptors:
                    continue
                elif isinstance(arg_value, str):
                    arg_value = arg_value.lower()
                    if arg_value not in adaptors[argument]:
                        invalid_args.append(
                            f'Invalid value for `{argument}`: ' 
                            f"'{arg_value}' "
                            f'(choices are {list(adaptors[argument].keys())})'
                        )
                        continue
                    adapted_arguments[argument] = adaptors[argument][arg_value]
                
                elif isinstance(arg_value, Sequence):
                    arg_values: list[str] = arg_value
                    invalid_values = [
                        arg_val 
                        for arg_val in arg_values 
                        if arg_val not in adaptors[argument]
                    ]
                    if invalid_values:
                        invalid_args.append(
                            f'Invalid value{"s"*(len(invalid_values)>1)} for {argument}:'
                            f'{", ".join(map(str, invalid_values))}'
                            f'(choices are {list(adaptors[argument].keys())})'
                        )
                        continue 
                    adapted_arguments[argument] = [
                        adaptors[argument][arg_val.lower()]
                        for arg_val in arg_values
                    ]
            
            if invalid_args:
                raise ValueError('\n'.join(invalid_args))
            return masked_function(**adapted_arguments)
        
        for att in ('__doc__', '__annotations__', '__esri_toolinfo__'):
            setattr(
                adapt_arguments,
                att,
                (
                    getattr(adapt_arguments, att, None) or
                    getattr(masked_function, att, None) or
                    [
                        f"String::"
                        f"{'|'.join(adaptors[argument].keys())}:"
                        for argument in function_parameters.keys()
                        if argument in adaptors
                    ]
                    if att == '__esri_toolinfo__'
                    else None
                )
            )
        
        adapt_arguments.__annotations__.update(adaptor.__hints__)
        adapt_arguments.__annotations__.update(masked_function.__annotations__)
        return adapt_arguments
    
    @classmethod
    def maskmethods(adaptor: 'ArgAdaptor', other: type) -> None:
        # Grab all non dunder/private methods
        methods_to_mask = {
            method_name: method_object
            for method_name in dir(other)
            
            # Check if the method is callable and not private
            # Use the walrus operator to store the method object
            if callable(method_object := getattr(other, method_name))
            and not method_name.startswith("_")
        }
        
        # Mask the methods using the specified adaptor
        for method_name, method_object in methods_to_mask.items():
            setattr(other, method_name, adaptor.maskargs(method_object))
            print(f'Masked method: {method_name}')

 

This code is not fully compliant with the original implementation as there were some instance specific concatenations and overrides, but I think this is a good starting point for making this adapter class more clear and extensible. This could be an alternative to maintaining a bunch of separate type hints as well.