# Copyright 2022-2026 The Ramble Authors
#
# Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
# https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
# <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
# option. This file may not be copied, modified, or distributed
# except according to those terms.
import argparse
import inspect
from ramble.util.logger import logger
from spack.util.pattern import Args
# Names exported by ``from ... import *``. The ``@arg`` generator functions in
# this module are not listed here; they are looked up by name through
# ``add_common_arguments()``.
__all__ = ["add_common_arguments", "allows_unknown_args", "validate_unknown_args"]
#: dictionary of argument-generating functions, keyed by name
#: (filled in by the ``@arg`` decorator, read by ``add_common_arguments()``)
_arguments = {}
def arg(fn):
    """Register ``fn`` as a generator for a shared command-line argument.

    The function is recorded under its own ``__name__`` so that
    ``add_common_arguments()`` can find it by name. It is returned unchanged,
    which keeps argument bunches lazy: nothing is built until the generator
    itself is called.
    """
    key = fn.__name__
    _arguments[key] = fn
    return fn
def add_common_arguments(parser, list_of_arguments):
    """Extend a parser with extra arguments.

    Args:
        parser: parser to be extended
        list_of_arguments: names of ``@arg``-registered generator functions
            whose arguments should be added to the parser

    Raises:
        KeyError: if a name in ``list_of_arguments`` was never registered
            with ``@arg``.
    """
    for argument in list_of_arguments:
        try:
            generator = _arguments[argument]
        except KeyError:
            message = 'Trying to add non existing argument "{0}" to a command'
            # ``from None`` keeps the traceback focused on the friendly message
            # instead of the bare dictionary-lookup failure.
            raise KeyError(message.format(argument)) from None
        # Generators build their argument bunch on demand, so definitions are
        # only materialized for the commands that actually request them.
        spec = generator()
        parser.add_argument(*spec.flags, **spec.kwargs)
def allows_unknown_args(command):
    """Report whether ``command`` declares an ``unknown_args`` parameter.

    Commands may add a positional parameter named ``unknown_args`` to
    indicate that they can handle arguments the parser did not recognize.

    Args:
        command: the command callable to inspect.

    Returns:
        bool: True when ``command`` takes at least two positional parameters
        and one of them is named ``unknown_args``. False otherwise, including
        for callables that expose no ``__code__`` object.
    """
    code = getattr(command, "__code__", None)
    if code is None:
        return False
    # ``co_varnames`` lists the parameters first and then every other local
    # variable, so only the first ``co_argcount`` names are positional
    # parameters. Scanning the whole tuple would wrongly accept a command
    # that merely assigns a local variable called ``unknown_args``.
    params = code.co_varnames[: code.co_argcount]
    return len(params) >= 2 and "unknown_args" in params
def validate_unknown_args(command, unknown_args):
    """Reject unrecognized arguments for commands that cannot accept them.

    Args:
        command: the command callable that will receive the parsed arguments.
        unknown_args: arguments the parser did not recognize (may be empty).

    When ``allows_unknown_args(command)`` is False and ``unknown_args`` is
    non-empty, the offending arguments are reported through ``logger.die``.
    Otherwise nothing happens.
    """
    # ``allows_unknown_args`` is still consulted first, as before, so any
    # error it raises for an unusual ``command`` surfaces in the same place.
    if not allows_unknown_args(command) and unknown_args:
        logger.die(f'unrecognized arguments: {" ".join(unknown_args)}')
@arg
def yes_to_all():
    """Declare the ``-y/--yes-to-all`` boolean flag, stored as ``yes_to_all``."""
    settings = dict(
        action="store_true",
        dest="yes_to_all",
        help='assume "yes" is the answer to every confirmation request',
    )
    return Args("-y", "--yes-to-all", **settings)
@arg
def tags():
    """Declare the ``-t/--tags`` option; ``action="append"`` lets it repeat."""
    flags = ("-t", "--tags")
    return Args(*flags, action="append", help="filter a package query by tags")
@arg
def application():
    """Declare the positional ``application`` argument."""
    description = "application name"
    return Args("application", help=description)
@arg
def workspace():
    """Declare the positional ``workspace`` argument."""
    description = "workspace name"
    return Args("workspace", help=description)
@arg
def specs():
    """Declare the positional ``specs`` argument.

    ``nargs=argparse.REMAINDER`` makes argparse gather every remaining
    command-line token into this argument.
    """
    remainder = argparse.REMAINDER
    return Args("specs", nargs=remainder, help="one or more workload specs")
@arg
def obj_type():
    """Declare the ``--type`` option selecting which kind of object to act on.

    The default value and the list of accepted names quoted in the help text
    are read from ``ramble.repository`` each time the option is built.
    """
    from ramble.repository import OBJECT_NAMES, default_type

    default_name = format(default_type.name)
    allowed = ", ".join(OBJECT_NAMES)
    return Args(
        "--type",
        default=default_name,
        help=f"type of objects. Defaults to '{default_name}'. "
        f"Allowed types are {allowed}",
    )
@arg
def repo_type():
    """Declare the ``-t/--type`` option selecting which repository type to manage.

    The argparse default is the literal string ``"any"``, and the help text
    now says so. It previously claimed the default was
    ``ramble.repository.default_type.name``, which did not match the value
    argparse actually stores when the flag is omitted.
    """
    from ramble.repository import OBJECT_NAMES

    return Args(
        "-t",
        "--type",
        default="any",
        help="type of repositories to manage. Defaults to 'any'. "
        f"Allowed types are {', '.join(OBJECT_NAMES)}, or any",
    )
@arg
def phases():
    """Declare the optional ``--phases`` list; it defaults to ``["*"]``."""
    help_text = (
        "select phases to execute when performing setup. "
        "Phase names support globbing"
    )
    return Args(
        "--phases",
        dest="phases",
        nargs="+",
        default=["*"],
        help=help_text,
        required=False,
    )
@arg
def include_phase_dependencies():
    """Declare the optional ``--include-phase-dependencies`` boolean flag."""
    flag = "--include-phase-dependencies"
    message = (
        "if set, phase dependencies are automatically added to "
        "the list of executed phases"
    )
    return Args(
        flag,
        dest="include_phase_dependencies",
        action="store_true",
        help=message,
        required=False,
    )
@arg
def profile_phases():
    """Declare ``--profile-phase``, whose values are collected into ``profile_phases``.

    With ``nargs="+"`` plus ``action="append"``, each occurrence of the flag
    appends its own list, so the parsed value is a list of lists (or ``None``
    when the flag is absent).
    """
    options = {
        "nargs": "+",
        "action": "append",
        "default": None,
        "dest": "profile_phases",
        "help": "phases to be profiled by line_profiler",
        "required": False,
    }
    return Args("--profile-phase", **options)
@arg
def profile_phase_output():
    """Declare the optional ``--profile-phase-output`` path option (default ``None``)."""
    destination = "profile_phase_output"
    description = "file path to save the phase line_profiler output"
    return Args(
        "--profile-phase-output",
        default=None,
        dest=destination,
        help=description,
        required=False,
    )
@arg
def where():
    """Declare the repeatable ``--where`` experiment filter.

    ``nargs="+"`` with ``action="append"`` yields one list per occurrence.
    """
    options = dict(
        dest="where",
        nargs="+",
        action="append",
        help="inclusive filter on experiments where the provided logical statement is True",
        required=False,
    )
    return Args("--where", **options)
@arg
def exclude_where():
    """Declare the repeatable ``--exclude-where`` experiment filter.

    ``nargs="+"`` with ``action="append"`` yields one list per occurrence.
    """
    options = dict(
        dest="exclude_where",
        nargs="+",
        action="append",
        help="exclusive filter experiments where the provided logical statement is True",
        required=False,
    )
    return Args("--exclude-where", **options)
@arg
def filter_tags():
    """Declare the repeatable ``--filter-tags`` option (one list per occurrence)."""
    flag = "--filter-tags"
    description = "filter experiments to only those that include the provided tags"
    return Args(
        flag,
        action="append",
        nargs="+",
        help=description,
        required=False,
    )
@arg
def no_checksum():
    """Declare the ``-n/--no-checksum`` boolean flag (``False`` unless given)."""
    options = {
        "action": "store_true",
        "default": False,
        "help": "do not use checksums to verify downloaded files (unsafe)",
    }
    return Args("-n", "--no-checksum", **options)