Just in time for halloween, I’m trying to do some ...
# questions
i
Just in time for Halloween, I’m trying to do some
Jinja
black magic 🧙 High level: I want to add a global variable to
globals_dict
in
settings.py
and use it in a loop in my catalog. See thread for an example. Any ideas?
settings.py
Copy code
# Values that should be made available to the templated configuration files.
fruit_list = ["apple", "banana", "pear"]

# Select the templated config loader and pass it its constructor arguments.
# NOTE(review): the import of TemplatedConfigLoader is not shown in this
# snippet — presumably it comes from kedro.config; confirm.
CONFIG_LOADER_CLASS = TemplatedConfigLoader
CONFIG_LOADER_ARGS = {
    # Glob pattern for the files that hold global values.
    "globals_pattern": "*globals.yml",
    # NOTE(review): the list is registered under the key "fruit", while the
    # catalog template below refers to `fruit_list` — confirm which name is
    # intended.
    "globals_dict": {"fruit": fruit_list},
}
catalog.yml
Copy code
{% for fruit in ${fruit_list} %}

{{ fruit }}.raw_data:
    type: pandas.CSVDataset
    filepath: data/01_raw/{{ fruit }}.csv

{% endfor %}
This results in:
Copy code
ScannerError: while scanning for the next token
found character that cannot start any token
Of course, the issue is the
${fruit_list}
. Is there some other syntax I’m unaware of?
f
I believe it is not possible to do that out of the box, because the Jinja2 rendering happens before the values from the globals dict are substituted. I would love to see that functionality.
i
I think you’re right! Remembering back, I don’t think this is the first time I’ve tried something like this. Do you know where the yaml gets read in by chance? I was looking to extend the class that does this and try to add this functionality myself but couldn’t really find that piece of code.
f
The
TemplatedConfigLoader
calls
_get_config_from_patterns
which uses _load_config_file. There you can see a call of
anyconfig.load(yml, ac_template=True).items()
You could modify
TemplatedConfigLoader
to load the global variables first and use them as a context dictionary, so you can call
anyconfig.load(yml, ac_template=True, ac_context=context)
, where context is a dictionary with your global variables. Now you will be able to write:
catalog.yml
Copy code
{% for fruit in fruit_list %}

{{ fruit }}.raw_data:
    type: pandas.CSVDataset
    filepath: data/01_raw/{{ fruit }}.csv

{% endfor %}
parameters_globals.yml
Copy code
fruit_list:
  - apple
  - banana
  - pear
Feel free to use this
CustomTemplatedConfigLoader
. I just added the param
ac_context
to a few functions; it carries the dictionary with the global parameters through to anyconfig.load
Copy code
import logging
from pathlib import Path
from typing import AbstractSet, Any, Dict, Iterable, List, Set

from kedro.config import BadConfigException, MissingConfigException
from kedro.config.common import _check_duplicate_keys, _lookup_config_filepaths
from kedro.config.templated_config import TemplatedConfigLoader, _format_object
from yaml.parser import ParserError

_config_logger = logging.getLogger(__name__)


def _get_config_from_patterns(
    conf_paths: Iterable[str],
    patterns: Iterable[str] = None,
    ac_template: bool = False,
    ac_context: Dict[str, Any] = None,
) -> Dict[str, Any]:
    """Recursively scan for configuration files, load and merge them, and
    return them in the form of a config dictionary.

    Args:
        conf_paths: List of configuration paths to directories
        patterns: Glob patterns to match. Files, which names match
            any of the specified patterns, will be processed.
        ac_template: Boolean flag to indicate whether to use the `ac_template`
            argument of the ``anyconfig.load`` method. Used in the context of
            `_load_config_file` function.
        ac_context: Dict[str, Any] Mapping object presents context to
            instantiate template. Argument of the ``anyconfig.load`` method

    Raises:
        ValueError: If 2 or more configuration files inside the same
            config path (or its subdirectories) contain the same
            top-level key.
        MissingConfigException: If no configuration files exist within
            a specified config path.
        BadConfigException: If configuration is poorly formatted and
            cannot be loaded.

    Returns:
        Dict[str, Any]:  A Python dictionary with the combined
            configuration from all configuration files. **Note:** any keys
            that start with `_` will be ignored.
    """

    if not patterns:
        raise ValueError(
            "'patterns' must contain at least one glob "
            "pattern to match config filenames against."
        )

    config = {}  # type: Dict[str, Any]
    processed_files = set()  # type: Set[Path]

    for conf_path in conf_paths:
        if not Path(conf_path).is_dir():
            raise ValueError(
                f"Given configuration path either does not exist "
                f"or is not a valid directory: {conf_path}"
            )

        config_filepaths = _lookup_config_filepaths(
            Path(conf_path), patterns, processed_files, _config_logger
        )
        new_conf = _load_configs(
            config_filepaths=config_filepaths,
            ac_template=ac_template,
            ac_context=ac_context,
        )

        common_keys = config.keys() & new_conf.keys()
        if common_keys:
            sorted_keys = ", ".join(sorted(common_keys))
            msg = (
                "Config from path '%s' will override the following "
                "existing top-level config keys: %s"
            )
            <http://_config_logger.info|_config_logger.info>(msg, conf_path, sorted_keys)

        config.update(new_conf)
        processed_files |= set(config_filepaths)

    if not processed_files:
        raise MissingConfigException(
            f"No files found in {conf_paths} matching the glob "
            f"pattern(s): {patterns}"
        )
    return config


def _load_config_file(
    config_file: Path, ac_template: bool = False, ac_context: Dict[str, Any] = None
) -> Dict[str, Any]:
    """Read a single config file through ``anyconfig`` and return its content.

    Top-level keys that start with ``_`` are left out of the result.

    Args:
        config_file: Path to the config file to read.
        ac_template: Passed on as the `ac_template` argument of
            ``anyconfig.load``.
        ac_context: Passed on as the `ac_context` argument of
            ``anyconfig.load`` (the context used to instantiate the template).

    Raises:
        BadConfigException: If an ``AttributeError`` occurs while loading the
            file or filtering its keys — presumably when the loaded document
            is not a mapping with string keys.
        ParserError: If the file cannot be parsed; re-raised with a message
            that names the file, line and position.

    Returns:
        The parsed configuration without its ``_``-prefixed top-level keys.
    """
    # Imported here rather than at module level for performance reasons.
    import anyconfig  # pylint: disable=import-outside-toplevel

    try:
        # The file is decoded explicitly as UTF-8.
        with open(config_file, encoding="utf8") as stream:
            _config_logger.debug("Loading config file: '%s'", config_file)
            loaded = anyconfig.load(
                stream, ac_template=ac_template, ac_context=ac_context
            )
            public_entries = {}
            for key, value in loaded.items():
                if key.startswith("_"):
                    continue
                public_entries[key] = value
            return public_entries
    except AttributeError as exc:
        raise BadConfigException(f"Couldn't load config file: {config_file}") from exc
    except ParserError as exc:
        assert exc.problem_mark is not None
        mark = exc.problem_mark
        error_line = mark.line
        error_column = mark.column
        raise ParserError(
            f"Invalid YAML file {config_file}, unable to read line {error_line}, position {error_column}."
        ) from exc


def _load_configs(
    config_filepaths: List[Path], ac_template: bool, ac_context: Dict[str, Any] = None
) -> Dict[str, Any]:
    """Recursively load all configuration files, which satisfy
    a given list of glob patterns from a specific path.

    Args:
        config_filepaths: Configuration files sorted in the order of precedence.
        ac_template: Boolean flag to indicate whether to use the `ac_template`
            argument of the ``anyconfig.load`` method. Used in the context of
            `_load_config_file` function.
        ac_context: Dict[str, Any] Mapping object presents context to
            instantiate template. Argument of the ``anyconfig.load`` method

    Raises:
        ValueError: If 2 or more configuration files contain the same key(s).
        BadConfigException: If configuration is poorly formatted and
            cannot be loaded.

    Returns:
        Resulting configuration dictionary.

    """

    aggregate_config = {}
    seen_file_to_keys = {}  # type: Dict[Path, AbstractSet[str]]

    for config_filepath in config_filepaths:
        single_config = _load_config_file(
            config_filepath, ac_template=ac_template, ac_context=ac_context
        )
        _check_duplicate_keys(seen_file_to_keys, config_filepath, single_config)
        seen_file_to_keys[config_filepath] = single_config.keys()
        aggregate_config.update(single_config)

    return aggregate_config


class CustomTemplatedConfigLoader(TemplatedConfigLoader):
    """``TemplatedConfigLoader`` variant that also hands the global values to
    ``anyconfig.load`` as its `ac_context`, so they can be referenced from
    templates inside the config files.
    """

    def get(self, *patterns: str) -> Dict[str, Any]:
        """Load the configuration files matching ``patterns``.

        The files are loaded with `ac_template` switched on and with
        ``self._config_mapping`` as the template context. The loaded result
        is then passed, together with the same mapping, through
        ``_format_object`` — presumably to substitute the remaining
        ``${...}`` placeholders; confirm against kedro.

        Args:
            patterns: Glob patterns of the config files to load.

        Returns:
            The value returned by ``_format_object`` for the loaded config.
        """
        search_paths = self.conf_paths
        template_values = self._config_mapping

        raw_config = _get_config_from_patterns(
            conf_paths=search_paths,
            patterns=patterns,
            ac_template=True,
            ac_context=template_values,
        )
        return _format_object(raw_config, template_values)
i
Cool! Thanks