Source code for ramble.language.language_helpers

# Copyright 2022-2026 The Ramble Authors
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
# <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
# option. This file may not be copied, modified, or distributed
# except according to those terms.

import fnmatch
import functools
from collections import OrderedDict
from typing import Any, List, Optional, Union

from ramble.error import DirectiveError


def check_definition(
    single_type, multiple_type, single_arg_name, multiple_arg_name, directive_name
):
    """Sanity check the singular / plural type arguments of a directive.

    Falsy values (``None``, empty string, empty list, ...) are treated as
    "not provided" and are skipped. Nothing is returned; the function only
    raises when a provided value has the wrong type.

    Args:
        single_type: Value given for the singular argument; must be a string
            when provided
        multiple_type: Value given for the plural argument; must be a list
            when provided
        single_arg_name: String name of the single_type argument in the directive
        multiple_arg_name: String name of the multiple_type argument in the directive
        directive_name: Name of the directive requiring a type

    Raises:
        DirectiveError: If single_type is provided and is not a string, or
            multiple_type is provided and is not a list.
    """
    # (value, required type, argument name), checked in this order so the
    # singular argument is always reported first.
    expectations = (
        (single_type, str, single_arg_name),
        (multiple_type, list, multiple_arg_name),
    )

    for value, required_type, arg_name in expectations:
        if not value or isinstance(value, required_type):
            continue
        raise DirectiveError(
            f"Directive {directive_name} was given an invalid type "
            f"for the {arg_name} argument. "
            f"Type was {type(value)}"
        )
def merge_definitions(
    single_type,
    multiple_type,
    multiple_pattern_match,
    single_arg_name,
    multiple_arg_name,
    directive_name,
):
    """Merge the singular and plural definitions of a type into one list.

    Both definitions are optional. They are validated with
    ``check_definition`` first, then concatenated (singular value first) and
    passed through ``expand_patterns`` together with multiple_pattern_match.

    Args:
        single_type: Single string for type name
        multiple_type: List of strings for type names, may contain wildcards
        multiple_pattern_match: List of strings to match against patterns in
            the merged names
        single_arg_name: String name of the single_type argument in the directive
        multiple_arg_name: String name of the multiple_type argument in the directive
        directive_name: Name of the directive requiring a type

    Returns:
        List of all type names, as returned by ``expand_patterns``
    """
    check_definition(
        single_type, multiple_type, single_arg_name, multiple_arg_name, directive_name
    )

    combined = [single_type] if single_type else []
    if multiple_type:
        combined += multiple_type

    return expand_patterns(combined, multiple_pattern_match)
def require_definition(
    single_type,
    multiple_type,
    multiple_pattern_match,
    single_arg_name,
    multiple_arg_name,
    directive_name,
):
    """Require at least one definition for a type in a directive.

    When at least one of single_type / multiple_type is truthy, the work is
    delegated to ``merge_definitions`` (which also validates their types).
    Otherwise an error is raised.

    Args:
        single_type: Single string for type name
        multiple_type: List of strings for type names, may contain wildcards
        multiple_pattern_match: List of strings to match against patterns in
            multiple_type
        single_arg_name: String name of the single_type argument in the directive
        multiple_arg_name: String name of the multiple_type argument in the directive
        directive_name: Name of the directive requiring a type

    Returns:
        List of all type names, as returned by ``merge_definitions``

    Raises:
        DirectiveError: If neither single_type nor multiple_type is defined.
    """
    if single_type or multiple_type:
        return merge_definitions(
            single_type,
            multiple_type,
            multiple_pattern_match,
            single_arg_name,
            multiple_arg_name,
            directive_name,
        )

    raise DirectiveError(
        f"Directive {directive_name} requires at least one of "
        f"{single_arg_name} or {multiple_arg_name} to be defined."
    )
def merge_conditions(
    obj,
    directive_name: str,
    single_arg_name: Optional[str] = None,
    multiple_arg_name: Optional[str] = None,
    **kwargs,
) -> List[List[str]]:
    """Merge a directive's conditions into lists of when conditions.

    The ``when`` kwarg is normalized with ``build_when_list``. If the
    singular argument is named ``"mode"`` or the plural one ``"modes"``, the
    given modes are merged (and wildcard-expanded against ``obj.modes``) and
    one list of when conditions is produced per mode, each carrying an extra
    ``<obj.name>_mode=<mode>`` condition. Otherwise a single list holding
    only the ``when`` conditions is produced.

    Args:
        obj: Object instance; ``obj.name`` (and ``obj.modes`` when modes are
            handled) are read
        directive_name: Name of the calling directive
        single_arg_name: Name of the singular kwarg being required
        multiple_arg_name: Name of the plural kwarg being required

    Kwargs:
        when (list | None): List of when conditions to apply to directive
        mode (str | None): Modifier mode to be applied as a when condition
        modes (list(str) | None): List of modifier modes to be applied as when conditions

    Returns:
        List of lists of strings, where each inner list is a list of when
        conditions for a mode (or one mode-less list).
    """
    single_value = kwargs.get(single_arg_name) if single_arg_name else None
    multiple_value = kwargs.get(multiple_arg_name) if multiple_arg_name else None

    shared_conditions = build_when_list(kwargs.get("when"), obj, obj.name, directive_name)

    # Arguments that are not modifier modes add no conditions of their own.
    if single_arg_name != "mode" and multiple_arg_name != "modes":
        return [shared_conditions]

    # No modes given: fall back to a single mode-less entry.
    mode_names = merge_definitions(
        single_value,
        multiple_value,
        obj.modes,
        "mode",
        "modes",
        directive_name,
    ) or [None]

    return [
        shared_conditions + [f"{obj.name}_mode={mode_name}"] if mode_name else shared_conditions
        for mode_name in mode_names
    ]
def expand_patterns(merged_types: list, multiple_pattern_match: Union[list, dict]):
    """Expand wildcard patterns within a list of names.

    Each entry of merged_types is treated as an ``fnmatch`` pattern and is
    replaced by every name it matches. Entries with zero matches are kept
    unchanged. Results are de-duplicated while preserving first-seen order.

    If multiple_pattern_match is a dict whose first key is a ``frozenset``
    (i.e. keyed on when sets), each pattern is checked against the names
    under every when set, without evaluating the when conditions. Any other
    non-empty value is iterated directly as the collection of names (for a
    plain dict, that means its keys). An empty or ``None`` value means there
    are no names to match, so every input is returned as-is.

    Args:
        merged_types: List of strings for type names, may contain wildcards
        multiple_pattern_match: List of strings (optional: nested in a dict
            keyed on when sets) to match against patterns in merged_types

    Returns:
        List of expanded patterns matching the names plus patterns not found
        in the names.
    """
    # Work out the shape of the candidate names once, rather than per pattern.
    if not multiple_pattern_match:
        # Covers None as well as empty containers: nothing can match.
        name_groups = []
    elif isinstance(multiple_pattern_match, dict) and isinstance(
        next(iter(multiple_pattern_match)), frozenset
    ):
        name_groups = list(multiple_pattern_match.values())
    else:
        name_groups = [multiple_pattern_match]

    # Keys of an ordered mapping give an order-preserving, de-duplicated set.
    expanded_patterns = OrderedDict()

    for pattern in merged_types:
        matched_any = False
        for names in name_groups:
            for match in fnmatch.filter(names, pattern):
                matched_any = True
                expanded_patterns[match] = ""

        if not matched_any:
            expanded_patterns[pattern] = ""

    return list(expanded_patterns)
def add_variable_validator(obj, var_name, var_values, when_list, wl_name=None):
    """Register a validator on obj restricting a variable to a set of values.

    The validator is stored in ``obj.validators`` under a frozenset of the
    when conditions (plus a ``workload_name=<wl_name>`` condition when
    wl_name is given). The input when_list is copied, not modified.

    Args:
        obj: Object providing ``name`` and a ``validators`` mapping
        var_name: Name of the variable to validate
        var_values: Allowed values for the variable
        when_list: List of when conditions under which the validator applies
        wl_name: Optional workload name to add as an extra when condition
    """
    conditions = when_list.copy()
    if wl_name is not None:
        conditions.append(f"workload_name={wl_name}")
    condition_key = frozenset(conditions)

    # '{<var_name>}' is left as a literal placeholder in both strings.
    validator = {
        "predicate": f"'{{{var_name}}}' in {var_values!r}",
        "message": (
            f"Value of variable '{var_name}' ('{{{var_name}}}') is not one of the allowed values: "
            f"{var_values}"
        ),
        "fail_on_invalid": True,
    }

    if condition_key not in obj.validators:
        obj.validators[condition_key] = {}
    obj.validators[condition_key][f"validate_values_for_{var_name}_obj_{obj.name}"] = validator
def build_when_list(
    when_arg: Optional[Union[str, List[str]]],
    obj: Any,
    directive_id: str,
    directive_name: str,
) -> List[str]:
    """Construct the list of when conditions from a directive's ``when`` argument.

    Also validates that ``when`` was passed in with the right type.

    Args:
        when_arg (str | list(str)): Single or list of string conditions that
            were input into the calling directive. ``None`` yields an empty list.
        obj: A ramble object (i.e. application, modifier, etc..), or the
            string ``"DirectiveMeta"``
        directive_id (str): Directive identifier, used in error messages to
            help users identify where errors originate from.
        directive_name (str): Name of the calling directive

    Returns:
        New list of strings, for all of the when conditions.

    Raises:
        DirectiveError: If when_arg is neither ``None``, a string, nor a list.
    """
    if when_arg is None:
        conditions = []
    elif isinstance(when_arg, str):
        conditions = [when_arg]
    elif isinstance(when_arg, list):
        # Copy so the caller's list is never modified below.
        conditions = list(when_arg)
    elif obj == "DirectiveMeta":
        raise DirectiveError(
            "DirectiveMeta is unable to process an invalid `when` argument from directive "
            f"{directive_name} {directive_id}. The `when` argument must be input as a "
            "string or list."
        )
    else:
        raise DirectiveError(
            f"Object {obj.name} calls directive {directive_name} {directive_id} "
            f"with an invalid `when` argument. The `when` argument must be input as a "
            "string or list."
        )

    # Enable '@{version}' syntax in `when` clauses
    origin_type = getattr(obj, "origin_type", None)
    if origin_type:
        conditions = [
            f"{origin_type}_version{cond}" if cond.startswith("@") else cond
            for cond in conditions
        ]

    return conditions
@functools.lru_cache(maxsize=None)
def _parse_when(w_set):
    """Parse a set of when conditions into variant and version constraints.

    Entries are visited in the order given by ``when_order`` and split on
    whitespace. Each token is classified, in this priority:

    * ``name=value`` -> variant ``name`` with value ``value``
    * ``+name`` / ``~name`` -> variant ``name`` with value ``"True"`` / ``"False"``
    * ``name@version`` -> version constraint for ``name``

    Any other token is ignored. Results are memoized, so w_set must be
    hashable (e.g. a frozenset) and the returned dicts are shared between
    callers.

    Args:
        w_set: Hashable collection of when condition strings

    Returns:
        Tuple ``(variants, versions, None)`` on success, or
        ``(None, None, message)`` as soon as one name is given two different
        values.
    """
    from ramble.util.format import when_order

    variants = {}
    versions = {}

    tokens = (
        token for entry in sorted(w_set, key=when_order) for token in entry.split()
    )

    for token in tokens:
        if "=" in token:
            kind, known = "variant", variants
            name, value = token.split("=", 1)
        elif token.startswith(("+", "~")):
            kind, known = "variant", variants
            name = token[1:]
            value = "True" if token.startswith("+") else "False"
        elif "@" in token:
            kind, known = "version", versions
            name, value = token.split("@", 1)
        else:
            continue

        if name in known and known[name] != value:
            return (
                None,
                None,
                f"{kind} '{name}' has conflicting values: '{known[name]}' and '{value}'",
            )
        known[name] = value

    return variants, versions, None
def are_when_compatible(when_set1, when_set2):
    """Determine if two sets of when conditions are compatible.

    Both sets are parsed with ``_parse_when``. They are incompatible when
    either set is self-contradictory, or when a variant or version named in
    both sets has different values.

    Args:
        when_set1 (list): First set of when conditions
        when_set2 (list): Second set of when conditions

    Returns:
        (bool): True if they are compatible, False otherwise
    """
    # _parse_when is cached, so it needs hashable input.
    first = when_set1 if isinstance(when_set1, frozenset) else frozenset(when_set1)
    second = when_set2 if isinstance(when_set2, frozenset) else frozenset(when_set2)

    variants_a, versions_a, _ = _parse_when(first)
    variants_b, versions_b, _ = _parse_when(second)

    # A None result means that set conflicts with itself.
    if variants_a is None or variants_b is None:
        return False

    for ours, theirs in ((variants_a, variants_b), (versions_a, versions_b)):
        if any(name in theirs and ours[name] != theirs[name] for name in ours):
            return False

    return True
def is_when_impossible(when_list):
    """Determine if a single list of when conditions is self-contradictory.

    Args:
        when_list (list): list of when conditions. A single string is treated
            as a one-element list, and ``None`` is never impossible.

    Returns:
        (bool, str): True and a message if it is impossible, False and None otherwise
    """
    if when_list is None:
        return False, None

    # _parse_when is cached, so hand it a hashable frozenset.
    if isinstance(when_list, str):
        conditions = frozenset([when_list])
    elif isinstance(when_list, frozenset):
        conditions = when_list
    else:
        conditions = frozenset(when_list)

    conflict = _parse_when(conditions)[2]
    return (True, conflict) if conflict else (False, None)