Source code for frutils.frutils

# -*- coding: utf-8 -*-

"""Utility methods that are used across the frkl-suite (https://frkl.io) of tools."""

import copy
import json
import logging
import os
import pprint
import re
import subprocess
from collections import Mapping, OrderedDict

from jinja2 import Environment, Undefined
from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO
from six import string_types

from .defaults import *  # noqa


# Module-level logger shared by the helper functions in this file.
log = logging.getLogger("frutils")

# Shared ruamel ``YAML`` instance, constructed with default arguments; used by ``ordered_load``.
yaml = YAML()


# utility methods/classes
class StringYAML(YAML):
    """A :class:`~YAML` subclass whose ``dump`` can return the output as a string.

    See http://yaml.readthedocs.io/en/latest/example.html#output-of-dump-as-a-string

    Args:
        **kwargs (dict): forwarded unchanged to the :class:`~YAML` constructor
    """

    def __init__(self, **kwargs):
        super(StringYAML, self).__init__(**kwargs)

    def dump(self, data, stream=None, **kw):
        """Dump ``data`` as YAML.

        Args:
            data: the object to serialize
            stream: target stream; when omitted, an in-memory buffer is used
            **kw: extra keyword arguments for :meth:`YAML.dump`

        Returns:
            str or None: the YAML text when no stream was given, otherwise None
        """
        if stream is not None:
            YAML.dump(self, data, stream, **kw)
            return None
        # No stream supplied: write into a temporary buffer and hand back its text.
        buffer = StringIO()
        YAML.dump(self, data, buffer, **kw)
        return buffer.getvalue()
def ordered_load(text):
    """Parse YAML text with the shared module-level ``yaml`` loader.

    Args:
        text: the YAML source passed straight to ``yaml.load``

    Returns:
        the loaded data. The function name suggests callers expect key order to be
        preserved; with ruamel's default (round-trip) ``YAML()`` the mappings are
        ordered ``CommentedMap``-style objects -- confirm if the loader type changes.
    """
    loaded = yaml.load(text)
    return loaded
def special_dict_to_dict(value):
    """Converts any 'special' dict (like CommentedMap, OrderedDict) to a 'normal' one.

    Nested values that are dicts (including dict subclasses) are converted
    recursively. Other values, including lists containing dicts, are copied over
    as-is. The input mapping is not modified; a new dict is built instead.

    Args:
        value (dict): the 'special' dict

    Returns:
        dict: the 'normal' dict
    """
    # Build a fresh dict rather than writing converted children back into
    # ``value``, so callers' mappings keep their original (special) types.
    return {
        key: special_dict_to_dict(item) if isinstance(item, dict) else item
        for key, item in value.items()
    }
def list_of_special_dicts_to_list_of_dicts(value):
    """Apply :func:`special_dict_to_dict` to every element of ``value``.

    Args:
        value (iterable): the 'special' dicts to convert

    Returns:
        list: the converted 'normal' dicts, in the same order
    """
    return [special_dict_to_dict(entry) for entry in value]
def is_list_of_strings(input_obj):
    """Check whether an object is a non-empty list or tuple containing only strings.

    Args:
        input_obj (object): the object to inspect

    Returns:
        bool: True only for a non-empty list/tuple whose items are all ``string_types``
    """
    # Guard clauses, evaluated in the same order as the checks they replace.
    if not input_obj:
        return False
    if not isinstance(input_obj, (list, tuple)):
        return False
    if isinstance(input_obj, string_types):
        return False
    return all(isinstance(element, string_types) for element in input_obj)
def dict_merge(dct, merge_dct, copy_dct=True):
    """Recursively merge ``merge_dct`` into ``dct``.

    Like ``dict.update()``, but when both sides hold a mapping under the same key
    the two are merged key by key (to arbitrary depth) instead of one replacing
    the other. Adapted from https://gist.github.com/angstwad/bf22d1822c38a92ec0a9

    Values taken from ``merge_dct`` are inserted by reference, not copied, so the
    result may share nested objects with ``merge_dct``.

    Args:
        dct (dict): the dict that receives the merge
        merge_dct (dict): the dict merged into ``dct``
        copy_dct (bool): deep-copy ``dct`` first and leave the original untouched
            (default), or merge in place

    Returns:
        dict: the merged dict (``dct`` itself or its deep copy)
    """
    target = copy.deepcopy(dct) if copy_dct else dct
    for key, incoming in merge_dct.items():
        both_mappings = (
            key in target
            and isinstance(target[key], dict)
            and isinstance(incoming, Mapping)
        )
        if both_mappings:
            # Nested merge always happens in place on the (possibly copied) target.
            dict_merge(target[key], incoming, copy_dct=False)
        else:
            target[key] = incoming
    return target
def merge_list_of_dicts(dicts, starting_dict=None):
    """Merge a sequence of dicts, in order, using :func:`dict_merge`.

    Args:
        dicts (list): the dicts to merge, later ones taking precedence
        starting_dict (dict): optional dict to merge into (modified in place);
            a new empty dict is used when omitted

    Returns:
        dict: the merged dict (``starting_dict`` itself when one was given)

    Note:
        The merge runs with ``copy_dct=False``, so values from ``dicts`` are inserted
        by reference. A nested dict taken from an earlier input can be modified by
        later merges.
    """
    merged = {} if starting_dict is None else starting_dict
    for current in dicts:
        dict_merge(merged, current, copy_dct=False)
    return merged
def is_url_or_abbrev(url, abbrevs=DEFAULT_URL_ABBREVIATIONS_FILE):
    """Check whether ``url`` is an http(s) URL or starts with a known abbreviation.

    Args:
        url (str): the string to inspect
        abbrevs: a mapping whose keys are abbreviation names; ``url`` matches when it
            starts with ``"<key>:"``

    Returns:
        bool: whether ``url`` looks like a URL or an abbreviated URL
    """
    if url.startswith(("http://", "https://")):
        return True
    return any(url.startswith("{}:".format(prefix)) for prefix in abbrevs.keys())
def get_key_path_value(source_dict, key_path, split_token=".", default_value=None):
    """Look up a value in a nested dict via a delimited key path.

    Args:
        source_dict (dict): the dictionary tree to search
        key_path (str): the path of keys, e.g. 'key1.child_key1.test_key'
        split_token (str): the separator used between keys in ``key_path``
        default_value: returned when the path cannot be followed to the end

    Return:
        object: the value at the end of the path, or ``default_value``. Any
        intermediate value that is missing, None, or not a dict also yields
        ``default_value``.
    """
    parts = key_path.split(split_token)
    parents, leaf = parts[:-1], parts[-1]
    node = source_dict
    for part in parents:
        node = node.get(part, None)
        # None and non-dict intermediates both end the walk.
        if not isinstance(node, dict):
            return default_value
    return node.get(leaf, default_value)
def add_key_to_dict(target_dict, key_path, value, split_token=".", ordered=True):
    """Insert ``value`` at a delimited key path inside ``target_dict`` (in place).

    A nested structure holding just the new value is built first and then merged
    into ``target_dict`` with :func:`dict_merge`.

    Args:
        target_dict (dict): the dictionary to modify
        key_path (str): the path to the value (e.g. key1.child_key.test)
        value: the value to insert
        split_token (str): the separator between keys in ``key_path``
        ordered (bool): build the intermediate levels as OrderedDicts instead of dicts
    """
    log.debug("Adding value '{}' as key '{}' to dict.".format(value, key_path))
    mapping_type = OrderedDict if ordered else dict
    parts = key_path.split(split_token)

    root = mapping_type()
    node = root
    for part in parts[:-1]:
        node = node.setdefault(part, mapping_type())
    node[parts[-1]] = value

    dict_merge(target_dict, root, copy_dct=False)
def append_key_to_dict(target_dict, key_path, value, split_token="."):
    """Appends a value to the list/tuple stored at a key path of a dictionary.

    The existing value must be a list, a tuple, missing, or None. A missing or None
    value becomes a new one-element list. A list is appended to. A tuple is replaced
    by a new tuple with the value appended, because tuples cannot be modified.

    Args:
        target_dict (dict): the dictionary to add to (modified in place)
        key_path (str): the path to the value (e.g. key1.child_key.test)
        value: the value to append
        split_token (str): the character to split the key_path with.

    Raises:
        Exception: if the existing value is not a list or tuple.
    """
    log.debug("Adding value '{}' as key '{}' to dict.".format(value, key_path))
    # Work on a deep copy: if we raise below, target_dict has not been touched.
    temp_dict = copy.deepcopy(target_dict)
    orig_temp_dict = temp_dict
    tokens = key_path.split(split_token)
    for key in tokens[0:-1]:
        temp_dict = temp_dict.setdefault(key, {})

    leaf = tokens[-1]
    old_value = temp_dict.get(leaf, None)
    if old_value is None:
        temp_dict[leaf] = [value]
    elif isinstance(old_value, list):
        old_value.append(value)
    elif isinstance(old_value, tuple):
        # tuple has no append(); build a new tuple instead
        temp_dict[leaf] = old_value + (value,)
    else:
        raise Exception(
            "Value for key_path '{}' is not a list or tuple, can't append.".format(
                key_path
            )
        )
    dict_merge(target_dict, orig_temp_dict, copy_dct=False)
def flatten_lists(lists):
    """Flatten one level of nesting from an iterable of iterables.

    Only the outer level is flattened; deeper list structure is preserved.

    Args:
        lists (list): a list of lists

    Returns:
        list: the items of every sublist, concatenated in order
    """
    flattened = []
    for sublist in lists:
        flattened.extend(sublist)
    return flattened
def is_templated(text, jinja_delimiter_profile=JINJA_DELIMITER_PROFILES["default"]):
    """Report whether ``text`` contains any template delimiter string.

    This is a plain substring check. It does not verify that opening and closing
    delimiters are balanced.

    Args:
        text (str): the text in question
        jinja_delimiter_profile (dict): a mapping whose values are the delimiter
            strings to look for (e.g. '{{', '{%'); the keys are ignored

    Returns:
        bool: whether any delimiter occurs in ``text``
    """
    return any(marker in text for marker in jinja_delimiter_profile.values())
class IgnoreUndefinedJinjaVariable(Undefined):
    """Jinja ``Undefined`` that logs and returns None instead of raising.

    Used by ``replace_string`` when ``ignore_undefined`` is True.

    NOTE(review): jinja2 binds several dunder operators (e.g. ``__add__``,
    ``__call__``) directly to the base-class ``_fail_with_undefined_error`` when
    the class is created. Only dynamically dispatched paths such as attribute
    access are expected to pick up this override. Confirm against the installed
    jinja2 version.
    """

    def _fail_with_undefined_error(self, *args, **kwargs):
        log.debug("Missing jinja var")
        return None

    # The previous, misspelled name never overrode jinja's hook (it is also
    # name-mangled). It is kept as an alias so existing references keep working.
    __fail__with_undefined_error = _fail_with_undefined_error
def replace_string(
    template_string,
    replacement_dict=None,
    block_start_string=DEFAULT_BLOCK_START_STRING,
    block_end_string=DEFAULT_BLOCK_END_STRING,
    variable_start_string=DEFAULT_VARIABLE_START_STRING,
    variable_end_string=DEFAULT_VARIABLE_END_STRING,
    additional_jinja_extensions=None,
    local_env_vars_key=DEFAULT_LOCAL_ENV_VARS_KEY,
    ignore_undefined=False
):
    """Replace template markers with values from a replacement dictionary within a string.

    Args:
        template_string (str): the template string
        replacement_dict (dict): the dictionary with the replacement strings
        block_start_string (str): the string to indicate a template block start
        block_end_string (str): the string to indicate a template block end
        variable_start_string (str): the string to indicate a template variable start
        variable_end_string (str): the string to indicate a template variable end
        additional_jinja_extensions (list): a list of jinja extensions to use
        local_env_vars_key (str): the key under which the local environment variables
            are made available to the template (falsy to disable)
        ignore_undefined (bool): use :class:`IgnoreUndefinedJinjaVariable` for
            undefined variables (True) instead of jinja's default ``Undefined`` (False)

    Returns:
        str: the rendered template
    """
    if additional_jinja_extensions is None:
        additional_jinja_extensions = []
    if replacement_dict is None:
        replacement_dict = {}

    if local_env_vars_key:
        # Snapshot the environment into a plain dict. A deepcopy of os.environ is
        # still an os._Environ, and setting items on it calls putenv(), so changes
        # made by a template would leak into the real process environment.
        sub_dict = {local_env_vars_key: dict(os.environ)}
        dict_merge(sub_dict, replacement_dict, copy_dct=False)
    else:
        sub_dict = replacement_dict

    env_kwargs = dict(
        extensions=additional_jinja_extensions,
        trim_blocks=True,
        block_start_string=block_start_string,
        block_end_string=block_end_string,
        variable_start_string=variable_start_string,
        variable_end_string=variable_end_string,
    )
    if ignore_undefined:
        env_kwargs["undefined"] = IgnoreUndefinedJinjaVariable
    env = Environment(**env_kwargs)

    return env.from_string(template_string).render(sub_dict)
def reindent(s, numSpaces, keep_current=True):
    """Prefix every line of ``s`` with ``numSpaces`` spaces.

    Args:
        s (str): the string to indent
        numSpaces (int): the number of spaces to prepend to each line
        keep_current (bool): keep each line's existing leading whitespace and add to
            it (True), or strip it before indenting (False)

    Returns:
        str: the re-indented string
    """
    prefix = " " * numSpaces
    lines = s.split("\n")
    if not keep_current:
        lines = [line.lstrip() for line in lines]
    return "\n".join(prefix + line for line in lines)
def readable(python_object, out="raw", safe=True, indent=0):
    """Utility method to print out readable strings from python objects (mostly dicts).

    Args:
        python_object (obj): the object to print
        out (str): the output format: 'yaml', 'json', 'raw' (``str()``) or 'pformat';
            None is treated as 'raw'
        safe (bool): for 'yaml', dump with ruamel's 'safe' type (True) or the
            default type (False); ignored by the other formats
        indent (int): number of spaces to prefix each output line with (optional)

    Returns:
        str: the formatted string

    Raises:
        Exception: if ``out`` is not one of the supported formats
    """
    if out is None:
        out = "raw"

    if out == "yaml":
        ryaml = StringYAML(typ="safe") if safe else StringYAML()
        ryaml.default_flow_style = False
        output_string = ryaml.dump(python_object)
    elif out == "json":
        output_string = json.dumps(python_object, sort_keys=True, indent=4)
    elif out == "raw":
        output_string = str(python_object)
    elif out == "pformat":
        output_string = pprint.pformat(python_object)
    else:
        raise Exception(
            "No valid output format provided. Supported: 'yaml', 'json', 'raw', 'pformat'"
        )

    if indent != 0:
        output_string = reindent(output_string, indent)
    return output_string
def readable_raw(python_object, indent=0):
    """Render ``python_object`` via :func:`readable` using the 'raw' (``str()``) format.

    Args:
        python_object (object): the object to render
        indent (int): number of spaces to prefix each line with

    Returns:
        str: the rendered string
    """
    return readable(python_object, indent=indent, out="raw")
def readable_json(python_object, indent=0):
    """Render ``python_object`` via :func:`readable` using the 'json' format.

    Args:
        python_object (object): the JSON-serializable object to render
        indent (int): number of spaces to prefix each line with

    Returns:
        str: the rendered JSON string
    """
    return readable(python_object, indent=indent, out="json")
def readable_yaml(python_object, indent=0, safe=True):
    """Render ``python_object`` via :func:`readable` using the 'yaml' format.

    Args:
        python_object (object): the object to render
        indent (int): number of spaces to prefix each line with
        safe (bool): use ruamel's 'safe' dumper (True) or the default one (False)

    Returns:
        str: the rendered YAML string
    """
    return readable(python_object, out="yaml", safe=safe, indent=indent)
def readable_pformat(python_object, indent=0):
    """Render ``python_object`` via :func:`readable` using the 'pformat' format.

    Args:
        python_object (object): the object to render
        indent (int): number of spaces to prefix each line with

    Returns:
        str: the ``pprint.pformat`` rendering
    """
    return readable(python_object, indent=indent, out="pformat")
def ensure_parent_dir(path):
    """Makes sure the parent directory of ``path`` exists, creating it if necessary.

    Args:
        path: the path to a file

    Returns:
        str: the parent dir ('' when ``path`` has no directory component)

    Raises:
        Exception: if the parent path exists but is not a directory
        OSError: if the directory cannot be created
    """
    parent = os.path.dirname(path)
    if not parent:
        # e.g. "file.txt": the parent is the current directory, nothing to create
        return parent

    # Check the parent itself. Checking ``path`` made makedirs() fail whenever the
    # file was missing but its directory already existed.
    if not os.path.exists(parent):
        try:
            os.makedirs(parent)
        except OSError:
            # Tolerate another process having created it in the meantime.
            if not os.path.isdir(parent):
                raise
    elif not os.path.isdir(os.path.realpath(parent)):
        raise Exception("Can't create parent dir, file already exists: {}".format(parent))
    return parent
def calculate_cache_location_for_url(url):
    """Utility method to get a unique path that can be used for caching a download.

    Every run of characters other than letters, digits, '_', '-' and '.' is replaced
    by a single ``os.sep``.

    Args:
        url (str): the url to map to a relative path

    Returns:
        str: the derived path
    """
    # Raw string: "\-" and "\." are invalid escapes in a normal string literal.
    repl_chars = r"[^_\-A-Za-z0-9\.]+"
    # Pass a function as the replacement so os.sep is inserted literally. A string
    # replacement is parsed as a template, so the Windows separator "\\" would be
    # treated as a (broken) escape.
    return re.sub(repl_chars, lambda _match: os.sep, url)
def can_passwordless_sudo():
    """Checks if the user can use passwordless sudo on this host.

    Returns:
        bool: True if running as root, or if ``sudo -k -n true`` exits with status 0;
        False otherwise, including when the ``sudo`` binary cannot be executed.
    """
    if os.geteuid() == 0:
        return True
    try:
        # The with-block closes the devnull handle (it was previously leaked).
        with open(os.devnull, "w") as devnull:
            # -k: ignore any existing sudo token; -n: non-interactive, never prompt
            returncode = subprocess.call(
                ["sudo", "-k", "-n", "true"],
                stdout=devnull,
                stderr=subprocess.STDOUT,
                close_fds=True,
            )
    except OSError:
        # sudo not installed / not executable (the old shell form returned 127 here)
        return False
    return returncode == 0