Source code for atlas_schema.schema

from __future__ import annotations

import warnings
from collections.abc import KeysView, ValuesView
from typing import Any, ClassVar

from coffea.nanoevents.schemas.base import BaseSchema, zip_forms

from atlas_schema.typing_compat import Behavior, Self


class NtupleSchema(BaseSchema):  # type: ignore[misc]
    """Ntuple schema builder

    The Ntuple schema is built from all branches found in the supplied file,
    based on the naming pattern of the branches: the text before the first
    underscore of a branch name is used as the name of its collection.

    The following additional arrays are constructed:

    - n/a
    """

    __dask_capable__ = True

    # NOTE(review): not read anywhere in this class; presumably consumed by
    # the base class or kept for parity with other schemas -- confirm.
    warn_missing_crossrefs = True

    # When True, missing event-ID branches raise ``RuntimeError``; otherwise
    # they only trigger a ``RuntimeWarning``.
    error_missing_event_ids = False

    # Event-level branches. This class only reads the union ``event_ids``;
    # the two groups below just feed that union.
    event_ids_data: ClassVar[set[str]] = {
        "lumiBlock",
        "averageInteractionsPerCrossing",
        "actualInteractionsPerCrossing",
        "dataTakingYear",
    }
    event_ids_mc: ClassVar[set[str]] = {
        "mcChannelNumber",
        "runNumber",
        "eventNumber",
        "mcEventWeights",
    }
    event_ids: ClassVar[set[str]] = {*event_ids_data, *event_ids_mc}

    # Collection name (branch-name prefix) -> record name given to
    # ``zip_forms``. Collections not listed here use "NanoCollection".
    mixins: ClassVar[dict[str, str]] = {
        "el": "Electron",
        "jet": "Jet",
        "met": "MissingET",
        "mu": "Muon",
        "pass": "Pass",
        "ph": "Photon",
        "trigPassed": "Trigger",
        "weight": "Weight",
    }

    # These are stored as length-1 vectors unnecessarily.
    # Branches named here are passed through as-is instead of being grouped
    # into a collection.
    singletons: ClassVar[set[str]] = set()

    # ``__doc__`` text attached to any field whose name (with "_NOSYS"
    # removed) matches a key below.
    docstrings: ClassVar[dict[str, str]] = {
        "charge": "charge",
        "eta": "pseudorapidity",
        "met": "missing transverse energy [MeV]",
        "mass": "invariant mass [MeV]",
        "pt": "transverse momentum [MeV]",
        "phi": "azimuthal angle",
    }

    def __init__(self, base_form: dict[str, Any], version: str = "latest"):
        """Build the schema form from ``base_form``.

        Args:
            base_form: form dictionary forwarded to ``BaseSchema.__init__``.
            version: label stored in the form's ``metadata`` parameters. It
                does not currently change how collections are built.
        """
        super().__init__(base_form)
        self._version = version

        self._form["fields"], self._form["contents"] = self._build_collections(
            self._form["fields"], self._form["contents"]
        )
        self._form["parameters"]["metadata"]["version"] = self._version

    @classmethod
    def v1(cls, base_form: dict[str, Any]) -> Self:
        """Build the NtupleEvents with the schema version recorded as ``"1"``

        For example, one can use
        ``NanoEventsFactory.from_root("file.root", schemaclass=NtupleSchema.v1)``
        to have ``"1"`` stored as the version in the form metadata.
        """
        return cls(base_form, version="1")

    def _build_collections(
        self, field_names: list[str], input_contents: list[Any]
    ) -> tuple[KeysView[str], ValuesView[dict[str, Any]]]:
        """Group flat branch forms into per-collection record forms.

        Args:
            field_names: branch names, parallel to ``input_contents``.
            input_contents: the form of each branch.

        Returns:
            The keys and values views of the rebuilt ``name -> form`` mapping.

        Raises:
            RuntimeError: if event-ID branches are missing while
                ``error_missing_event_ids`` is set, or if a built form has a
                ``class`` this method does not handle.
            KeyError: if a name listed in ``singletons`` is not a branch, or
                if a systematic group has no ``NOSYS`` (nominal) member.
        """
        branch_forms = dict(zip(field_names, input_contents))

        # parse into high-level records (collections, list collections, and singletons)
        collections = {k.split("_")[0] for k in branch_forms}
        collections -= self.event_ids
        collections -= set(self.singletons)

        # rename needed because easyjet breaks the AMG assumptions
        # https://gitlab.cern.ch/easyjet/easyjet/-/issues/246
        # (moves any "_NOSYS" marker to the end of the branch name)
        for k in list(branch_forms):
            if "NOSYS" not in k:
                continue
            branch_forms[k.replace("_NOSYS", "") + "_NOSYS"] = branch_forms.pop(k)

        # these are collections with systematic variations
        subcollections = {
            k.split("__")[0].split("_", 1)[1].replace("_NOSYS", "")
            for k in branch_forms
            if "NOSYS" in k
        }

        # Check the presence of the event_ids
        missing_event_ids = [
            event_id for event_id in self.event_ids if event_id not in branch_forms
        ]

        if missing_event_ids:
            if self.error_missing_event_ids:
                # Implicit concatenation keeps source indentation out of the
                # message (backslash continuations inside a literal do not).
                msg = (
                    f"There are missing event ID fields: {missing_event_ids} \n\n"
                    f"The event ID fields {self.event_ids} are necessary to perform sub-run identification "
                    "(e.g. for corrections and sub-dividing data during different detector conditions), "
                    "to cross-validate MC and Data (i.e. matching events for comparison), and to generate event displays. "
                    "It's advised to never drop these branches from the dataformat.\n\n"
                    "This error can be demoted to a warning by setting the class level variable error_missing_event_ids to False."
                )
                raise RuntimeError(msg)
            warnings.warn(
                f"Missing event_ids : {missing_event_ids}",
                RuntimeWarning,
                stacklevel=2,
            )

        output: dict[str, Any] = {}

        # first, register singletons (event-level, others)
        for name in {*self.event_ids, *self.singletons}:
            if name in missing_event_ids:
                continue
            output[name] = branch_forms[name]

        # next, go through and start grouping up collections
        for name in collections:
            mixin = self.mixins.get(name, "NanoCollection")
            content = {}
            used = set()

            for subname in subcollections:
                prefix = f"{name}_{subname}_"
                # every branch of this collection/variable pair, keyed by the
                # text after the prefix (i.e. the systematic name)
                subcontent = {
                    k[len(prefix) :]: branch_forms[k]
                    for k in branch_forms
                    if k.startswith(prefix)
                }
                used.update(prefix + suffix for suffix in subcontent)

                if subcontent:
                    # create the nominal version
                    content[subname] = branch_forms[f"{prefix}NOSYS"]

                    # create a collection of the systematic variations for the given variable
                    content[f"{subname}_syst"] = zip_forms(
                        subcontent, f"{name}_syst", record_name="NanoCollection"
                    )

            # remaining branches of this collection that carry no systematics
            content.update(
                {
                    k[len(name) + 1 :]: branch_forms[k]
                    for k in branch_forms
                    if k.startswith(name + "_") and k not in used
                }
            )

            if not used and not content:
                warnings.warn(
                    f"I identified a branch that likely does not have any leaves: '{name}'. I will treat this as a 'singleton'. To suppress this warning next time, please define your singletons explicitly.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                # Rebind on the instance rather than mutating the shared
                # class-level set: an in-place ``add`` would leak this name to
                # every other schema instance/subclass and make a later build
                # fail with KeyError on a file that lacks the branch.
                self.singletons = {*self.singletons, name}  # type: ignore[misc]
                output[name] = branch_forms[name]
            else:
                output[name] = zip_forms(content, name, record_name=mixin)

            output[name].setdefault("parameters", {})
            output[name]["parameters"].update({"collection_name": name})

            form_class = output[name]["class"]
            if form_class == "ListOffsetArray":
                fields = output[name]["content"]["fields"]
                contents = output[name]["content"]["contents"]
            elif form_class == "RecordArray":
                fields = output[name]["fields"]
                contents = output[name]["contents"]
            elif form_class == "NumpyArray":
                # these are singletons that we just pass through
                continue
            else:
                msg = f"Unhandled class {form_class}"
                raise RuntimeError(msg)

            # update docstrings as needed
            # NB: must be before flattening for easier logic
            for field, field_form in zip(fields, contents):
                if "parameters" not in field_form:
                    continue

                parsed_name = field.replace("_NOSYS", "")
                # prefer the schema docstring, then any existing one
                field_form["parameters"]["__doc__"] = self.docstrings.get(
                    parsed_name,
                    field_form["parameters"].get("__doc__", "no docstring available"),
                )

        return output.keys(), output.values()

    @classmethod
    def behavior(cls) -> Behavior:
        """Behaviors necessary to implement this schema"""
        # Imported inside the method rather than at module level.
        # NOTE(review): presumably to avoid an import cycle -- confirm.
        from atlas_schema.methods import behavior as roaster

        return roaster