Source code for atlas_schema.utils

from __future__ import annotations

from enum import Enum
from typing import TypeVar, Union, cast

import awkward as ak
import dask_awkward as dak

# TypeVar bounded by Enum; not referenced by the code visible here
# (presumably used by enum-related helpers elsewhere -- verify before removing).
_E = TypeVar("_E", bound=Enum)

# TypeVar covering both eager awkward arrays and lazy dask-awkward arrays.
# Functions annotated ``(x: Array) -> Array`` (e.g. ``isin``) thereby declare
# that they return the same kind of array they were given.
Array = TypeVar("Array", bound=Union[dak.Array, ak.Array])


def isin(element: Array, test_elements: dak.Array | ak.Array, axis: int = -1) -> Array:
    """
    Find test_elements in element. Similar in API as :func:`numpy.isin`.

    Calculates `element in test_elements`, broadcasting over *element elements only*.
    Returns a boolean array of the same shape as *element* that is `True` where an
    element of *element* is in *test_elements* and `False` otherwise.

    This works by first transforming *test_elements* to an array with one more
    dimension than the *element*, placing the *test_elements* at *axis*, then
    comparing, and finally collapsing that same *axis* with an ``any`` reduction.

    Args:
        element (dak.Array or ak.Array): input array of values.
        test_elements (dak.Array or ak.Array): one-dimensional set of values against
            which to test each value of *element*.
        axis (int): position at which the *test_elements* dimension is inserted
            (and later reduced), counted in the ``element.ndim + 1``-dimensional
            comparison array, following :func:`numpy.expand_dims` conventions.
            ``-1`` (the default) places it after the last dimension of *element*;
            other negative values count from the end.

    Returns:
        dak.Array or ak.Array: result of comparison for test_elements in *element*

    Raises:
        AssertionError: (when assertions are enabled) if *test_elements* is not
            one-dimensional or *axis* is outside
            ``[-(element.ndim + 1), element.ndim]``.

    Note:
        Non-default *axis* values put the test dimension in front of some of
        *element*'s dimensions. Whether the comparison broadcasts then depends on
        awkward's broadcasting rules (regular vs. jagged dimensions). TODO confirm
        for jagged and dask-backed inputs.

    Example:
        >>> import awkward as ak
        >>> import atlas_schema as ats
        >>> truth_origins = ak.Array([[1, 2, 3], [4], [5, 6, 7], [1]])
        >>> prompt_origins = ak.Array([1, 2, 7])
        >>> ats.isin(truth_origins, prompt_origins).to_list()
        [[True, True, False], [False], [False, False, True], [True]]
    """
    assert test_elements.ndim == 1, "test_elements must be one-dimensional"

    # Number of dimensions of the comparison array once test_elements is inserted.
    n_out = element.ndim + 1
    assert axis >= -n_out, "axis too small for the element"
    assert axis < n_out, "axis too large for the element"

    # Normalise negative axes to a position in the (ndim + 1)-dimensional result;
    # -1 maps to element.ndim, i.e. after the last dimension of element.
    if axis < 0:
        axis += n_out

    # First, build up the transformation, with slice(None) indicating where to
    # stick the test_elements and None adding a length-1 axis everywhere else.
    reshaper: list[None | slice] = [None] * element.ndim
    reshaper.insert(axis, slice(None))

    # Note: reshaper needs to be a tuple for indexing purposes
    matches = element == test_elements[tuple(reshaper)]

    # Reduce over the dimension that actually holds test_elements (previously this
    # was always the last one, which is wrong for any non-default axis). Expressed
    # as a negative axis so the default case still reduces with axis=-1.
    return cast(Array, ak.any(matches, axis=axis - n_out))