Source code for tactus.tasks.discover_task

"""Discover tasks."""

import contextlib
import importlib
import inspect
import json
import os
import pkgutil
import sys
import types
from pathlib import Path

from ..logs import LoggerHandlers, logger
from ..os_utils import tactusmakedirs
from ..plugin import TactusPluginRegistry, TactusPluginRegistryFromConfig
from ..toolbox import Platform
from .base import Task, _get_name


def discover_modules(package, what="plugin"):
    """Import every top-level module found on a package's search path.

    Module names are built as ``<package.__name__>.tasks.<module>``, so those
    dotted names must be importable from ``sys.path`` (``available_tasks``
    calls ``add_plugins_to_sys_path`` before discovering plugin modules).

    Args:
        package (types.ModuleType): Package whose ``__path__`` is scanned.
        what (str, optional): Label describing what is being discovered,
            used only in log messages. Defaults to "plugin".

    Yields:
        tuple: ``(fullname, module)`` where ``fullname`` (str) is the dotted
        name that was imported and ``module`` (types.ModuleType) is the
        imported module.

    Raises:
        RuntimeError: If a discovered module raises ImportError on import.
    """
    search_path = package.__path__
    module_prefix = package.__name__ + ".tasks."
    logger.info("{} search path: {}", what.capitalize(), search_path)
    for module_info in pkgutil.iter_modules(search_path):
        qualified_name = module_prefix + module_info.name
        logger.info("Loading module {}", qualified_name)
        try:
            imported = importlib.import_module(qualified_name)
        except ImportError as err:
            logger.error("Could not load {}: {}", qualified_name, repr(err))
            raise RuntimeError("Failed to load module") from err
        yield qualified_name, imported
def _task_index_file(config):
    """Return the location of the task index file.

    The index is a file named ``tasks_index.json`` inside the directory that
    the platform reports as its ``casedir`` system value.

    Args:
        config: Configuration object handed to ``Platform``.

    Returns:
        pathlib.Path: Full path to the task index file.
    """
    casedir = Platform(config).get_system_value("casedir")
    return Path(casedir).joinpath("tasks_index.json")
def load_task_index(config):
    """Load the task index, (re)creating it when it is missing or unreadable.

    Args:
        config: Configuration used to locate the index file and, if needed,
            to rebuild it via ``create_task_index``.

    Returns:
        dict: Mapping of task name to the dotted ``"module.QualName"`` path of
        the implementing class.
    """
    task_index_file = _task_index_file(config)
    if os.path.isfile(task_index_file):
        logger.info("Read task index from {}", task_index_file)
        try:
            with open(task_index_file, "r", encoding="utf-8") as infile:
                known_types = json.load(infile)
        except ValueError as exc:
            # JSONDecodeError / UnicodeDecodeError: a truncated or corrupt index
            # (e.g. an interrupted write) would otherwise break every task until
            # the file is removed by hand, so rebuild it instead.
            logger.warning(
                "Task index {} is unreadable ({}), recreating it", task_index_file, exc
            )
        else:
            if isinstance(known_types, dict):
                return known_types
            logger.warning(
                "Task index {} does not contain a mapping, recreating it",
                task_index_file,
            )
    else:
        logger.info("Create task index file {}", task_index_file)
    return create_task_index(config)
def create_task_index(config):
    """Build the task index from the active plugins and store it on disk.

    The index is written to a uniquely named temporary sibling file and then
    moved over the target with ``os.replace``. A concurrent reader therefore
    sees either the previous complete file or the new complete file, never a
    truncated one.

    Args:
        config: Configuration used to locate the index file, build the plugin
            registry and read ``platform.unix_group``.

    Returns:
        dict: Mapping of task name to the dotted ``"module.QualName"`` path of
        the implementing class.
    """
    import uuid  # local import: only needed for the temporary file name

    task_index_file = Path(_task_index_file(config))
    reg = TactusPluginRegistryFromConfig(config)
    known_types = {
        k: f"{v.__module__}.{v.__qualname__}"
        for k, v in available_tasks(reg).items()
    }
    task_index_file_dir = os.path.dirname(task_index_file)
    unix_group = config.get("platform.unix_group")
    tactusmakedirs(task_index_file_dir, unixgroup=unix_group)

    # Same directory as the target, so os.replace stays on one filesystem
    # (atomic rename); uuid avoids clashes between concurrent writers,
    # including ones on other hosts sharing the filesystem.
    tmp_file = task_index_file.with_name(
        f"{task_index_file.name}.{uuid.uuid4().hex}.tmp"
    )
    try:
        with open(tmp_file, mode="w", encoding="utf8") as outfile:
            # indent=1 produces the same output as the former indent=True.
            json.dump(known_types, outfile, indent=1)
        os.replace(tmp_file, task_index_file)
    except BaseException:
        # Do not leave stray temporary files behind on failure.
        with contextlib.suppress(OSError):
            os.remove(tmp_file)
        raise
    logger.info("Stored task index in {}", task_index_file)
    return known_types
def add_plugins_to_sys_path(reg: TactusPluginRegistry) -> None:
    """Make the task packages of active plugins importable.

    For every plugin whose ``tasks_path`` exists on disk, the plugin root
    (``plugin.path``) is prepended to ``sys.path`` unless it is already
    present. Plugins without a tasks directory are left untouched.

    Args:
        reg (TactusPluginRegistry): Registry of active tactus plugins.
    """
    for plugin in reg.plugins:
        if not os.path.exists(plugin.tasks_path):
            continue
        root = str(plugin.path)
        if root in sys.path:
            continue
        # Front of the path so plugin packages take precedence on import.
        sys.path.insert(0, root)
def get_task(name, config) -> Task:
    """Instantiate the task class registered under ``name``.

    The log level is first re-applied from ``general.loglevel`` if present.
    The class is then looked up case-insensitively in the task index. If the
    name is missing, the index is rebuilt once before giving up. Index entries
    stored as dotted ``"module.ClassName"`` strings are imported after the
    plugin roots have been added to ``sys.path``.

    Args:
        name (str): Task name; matched in lower case against the index.
        config: Configuration passed to the index helpers and to the task.

    Returns:
        Task: A new instance of the task class, constructed as ``cls(config)``.

    Raises:
        NotImplementedError: If ``name`` is not in the task index, even after
            the index has been rebuilt.
    """
    with contextlib.suppress(KeyError):
        # The log level may have been changed after start-up (e.g. via the
        # ecFlow UI), so reconfigure the logger from the current config.
        level = config["general.loglevel"]
        logger.configure(handlers=LoggerHandlers(default_level=level))
        logger.debug("Logger reset to level {}", level)

    index = load_task_index(config)
    lookup = name.lower()
    try:
        target = index[lookup]
    except KeyError:
        # The task may have been added after the index was written.
        index = create_task_index(config)
        try:
            target = index[lookup]
        except KeyError as error:
            raise NotImplementedError(f'Task "{name}" not implemented') from error

    if isinstance(target, str):
        add_plugins_to_sys_path(TactusPluginRegistryFromConfig(config))
        module_path, class_name = target.rsplit(".", 1)
        target = getattr(importlib.import_module(module_path), class_name)
    return target(config)
def available_tasks(reg: TactusPluginRegistry):
    """Collect the task classes provided by all active plugins.

    Args:
        reg (TactusPluginRegistry): Registry of active tactus plugins.

    Returns:
        dict: Mapping of task name (str) to task class. If several plugins
        provide the same task name, the one appearing later in
        ``reg.plugins`` wins and a warning is logged.
    """
    # Discovered names treated as abstract bases and excluded from the result.
    abstract_classes = {"pysurfexbase"}
    known_types = {}
    add_plugins_to_sys_path(reg)
    for plugin in reg.plugins:
        if not os.path.exists(plugin.tasks_path):
            logger.warning("Plug-in task {} not found", plugin.tasks_path)
            continue
        # Synthetic package named after the plugin whose search path is the
        # plugin's tasks directory; discover_modules derives module names
        # from this package name.
        tasks = types.ModuleType(plugin.name)
        tasks.__path__ = [str(plugin.tasks_path)]
        for ftype, cls in discover(tasks, Task).items():
            if ftype in abstract_classes:
                continue
            if ftype in known_types:
                logger.warning("Overriding task {}", ftype)
            known_types[ftype] = cls
    return known_types
def discover(package, base):
    """Find subclasses of ``base`` defined in the modules of ``package``.

    Each module yielded by ``discover_modules`` is inspected for classes that
    derive from ``base`` (``base`` itself excluded). A class is skipped when
    its ``__module__`` differs from the module being inspected, i.e. it was
    only imported there. If a name is found more than once, the first
    definition is kept and a warning is logged.

    Keys are computed by ``_get_name`` from the class name, the class and the
    lower-cased base-class name. As originally described, the intent is the
    lower-case class name with the base-class name stripped as a suffix.

    Args:
        package (types.ModuleType): Package whose modules are searched.
        base (type): Base class the discovered classes must derive from.

    Returns:
        dict of str: type: Discovered classes keyed by their derived name.
    """
    what = base.__name__
    label = what.lower()

    def is_candidate(obj):
        return inspect.isclass(obj) and issubclass(obj, base) and obj is not base

    found = {}
    for fullname, mod in discover_modules(package, what=what):
        for cname, cls in inspect.getmembers(mod, is_candidate):
            tname = _get_name(cname, cls, label)
            if cls.__module__ != fullname:
                logger.info("Skipping {} {} imported by {}", label, tname, fullname)
            elif tname in found:
                logger.warning(
                    "{} type {} is defined more than once", what.capitalize(), tname
                )
            else:
                found[tname] = cls
    return found