# -*- coding: utf-8 -*-
"""Utility methods that are used across the frkl-suite (https://frkl.io) of tools."""
import copy
import json
import logging
import os
import pprint
import re
import subprocess
from collections import OrderedDict

try:
    # Mapping lives in collections.abc; the alias in ``collections`` was
    # removed in Python 3.10.
    from collections.abc import Mapping
except ImportError:  # Python 2
    from collections import Mapping

from jinja2 import Environment, Undefined
from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO
from six import string_types

from .defaults import *  # noqa
# Module-wide logger; helpers in this module log under the "frutils" name.
log = logging.getLogger("frutils")
# Shared, default-configured YAML instance used by ordered_load().
yaml = YAML()
# utility methods/classes
class StringYAML(YAML):
    """A :class:`~YAML` whose ``dump`` can return the output as a string.

    When ``dump`` is called without a stream, the YAML is written into an
    in-memory buffer and the buffer's contents are returned. When a stream is
    given, it is written to as usual and nothing is returned.

    More details: http://yaml.readthedocs.io/en/latest/example.html#output-of-dump-as-a-string

    Args:
        **kwargs (dict): arguments for the underlying :class:`~YAML` class
    """

    def __init__(self, **kwargs):
        super(StringYAML, self).__init__(**kwargs)

    def dump(self, data, stream=None, **kw):
        """Dump ``data`` to ``stream``, or return it as a string if no stream is given."""
        capture_output = stream is None
        destination = StringIO() if capture_output else stream
        YAML.dump(self, data, destination, **kw)
        if capture_output:
            return destination.getvalue()
def ordered_load(text):
    """Parse a YAML string/stream with the module-level ``yaml`` instance.

    NOTE(review): the result type is whatever that YAML instance produces;
    the name suggests an insertion-ordered mapping, but this is not
    necessarily a :class:`collections.OrderedDict` -- confirm against the
    ruamel.yaml version in use.

    Args:
        text: the YAML text (or stream) to parse
    Returns:
        object: the parsed document
    """
    parsed = yaml.load(text)
    return parsed
def special_dict_to_dict(value):
    """Converts any 'special' dict (like CommentedMap, OrderedDict) to a 'normal' one.

    Nested dict values are converted recursively; non-dict values are kept
    as-is (same objects). The input is left untouched -- a new plain ``dict``
    is built at every level.

    Args:
        value (dict): the 'special' dict
    Returns:
        dict: the 'normal' dict
    """
    # Build a fresh dict instead of assigning back into ``value``, so the
    # caller's (possibly comment-carrying) object is not rewritten in place.
    return {
        key: special_dict_to_dict(item) if isinstance(item, dict) else item
        for key, item in value.items()
    }
def list_of_special_dicts_to_list_of_dicts(value):
    """Apply :func:`special_dict_to_dict` to every element of a sequence.

    Args:
        value (iterable): the 'special' dicts
    Returns:
        list: the converted 'normal' dicts, in the same order
    """
    return [special_dict_to_dict(special) for special in value]
def is_list_of_strings(input_obj):
    """Tell whether an object is a non-empty list or tuple containing only strings.

    Args:
        input_obj (object): the object in question
    Returns:
        bool: True only for a non-empty list/tuple whose items are all
        ``string_types`` instances
    """
    if not input_obj:
        return False
    if not isinstance(input_obj, (list, tuple)) or isinstance(input_obj, string_types):
        return False
    for item in input_obj:
        if not isinstance(item, string_types):
            return False
    return True
def dict_merge(dct, merge_dct, copy_dct=True):
    """Recursively merge ``merge_dct`` into ``dct``.

    Like :meth:`dict.update`, but when both sides hold a dict under the same
    key the two are merged recursively instead of the right-hand one simply
    replacing the left-hand one. Any other value from ``merge_dct`` overwrites
    the corresponding entry in ``dct``.

    Copied from: https://gist.github.com/angstwad/bf22d1822c38a92ec0a9

    Args:
        dct (dict): dict onto which the merge is executed
        merge_dct (dict): dict merged into dct
        copy_dct (bool): whether to (deep-)copy dct before merging (and leaving it unchanged), or not (default: copy)
    Returns:
        dict: the merged dict (original or copied)
    """
    target = copy.deepcopy(dct) if copy_dct else dct
    for key, incoming in merge_dct.items():
        recurse = (
            key in target
            and isinstance(target[key], dict)
            and isinstance(incoming, Mapping)
        )
        if recurse:
            # Nested merge is always in place on the (possibly copied) target.
            dict_merge(target[key], incoming, copy_dct=False)
        else:
            target[key] = incoming
    return target
def merge_list_of_dicts(dicts, starting_dict=None):
    """Merges a list of dicts, in order, into ``starting_dict``.

    Later dicts win on conflicting non-dict values; nested dicts are merged
    recursively (via :func:`dict_merge`).

    The input dicts are never modified. Each one is deep-copied before it is
    merged, because :func:`dict_merge` stores merged-in values by reference:
    without the copy, a nested dict from an earlier input would end up inside
    the result, and merging a later input would then write into that earlier
    input.

    Args:
        dicts (list): list of dicts to be merged in order
        starting_dict (dict): (optional) existing dict where the others are merged into
    Returns:
        dict: the merged dict (same object as starting_dict, if one was given)
    """
    if starting_dict is None:
        starting_dict = {}
    for d in dicts:
        dict_merge(starting_dict, copy.deepcopy(d), copy_dct=False)
    return starting_dict
def is_url_or_abbrev(url, abbrevs=DEFAULT_URL_ABBREVIATIONS_FILE):
    """Check whether ``url`` is an http(s) URL or starts with a known abbreviation.

    Args:
        url (str): the string to check
        abbrevs: mapping-like object (anything with ``.keys()``) whose keys are
            abbreviation names; ``url`` matches if it starts with ``"<key>:"``
    Returns:
        bool: True for ``http://``/``https://`` URLs or abbreviated URLs
    """
    if url.startswith(("http://", "https://")):
        return True
    return any(url.startswith("{}:".format(abbrev)) for abbrev in abbrevs.keys())
def get_key_path_value(source_dict, key_path, split_token=".", default_value=None):
    """Look up a nested value in ``source_dict`` using a delimited key path.

    Args:
        source_dict (dict): the source dictionary
        key_path (str): a key-path, e.g. 'key1.child_key1.test_key'
        split_token (str): the separator between keys in ``key_path``
        default_value: the value returned when the path can't be followed
    Returns:
        object: the value at ``key_path``, or ``default_value`` if an
        intermediate key is missing/None/not a dict or the final key is missing
    """
    parts = key_path.split(split_token)
    parents, leaf = parts[:-1], parts[-1]
    current = source_dict
    for part in parents:
        current = current.get(part, None)
        # Both a missing key (None) and a non-dict value end the walk.
        if not isinstance(current, dict):
            return default_value
    return current.get(leaf, default_value)
def add_key_to_dict(target_dict, key_path, value, split_token=".", ordered=True):
    """Insert ``value`` into ``target_dict`` at a delimited key path.

    A nested dict holding only ``value`` at ``key_path`` is built and then
    merged into ``target_dict`` in place with :func:`dict_merge`, so existing
    sibling keys along the path are kept.

    Args:
        target_dict (dict): the dictionary to add to
        key_path (str): the path to the value (e.g. key1.child_key.test)
        value: the value to insert
        split_token (str): the character to split the key_path with.
        ordered (bool): whether to use OrderedDicts instead of 'normal' ones
    """
    log.debug("Adding value '{}' as key '{}' to dict.".format(value, key_path))
    make_container = OrderedDict if ordered else dict
    tokens = key_path.split(split_token)

    # Build the nested structure from the innermost key outwards.
    branch = make_container()
    branch[tokens[-1]] = value
    for key in reversed(tokens[:-1]):
        wrapper = make_container()
        wrapper[key] = branch
        branch = wrapper

    dict_merge(target_dict, branch, copy_dct=False)
def append_key_to_dict(target_dict, key_path, value, split_token="."):
    """Appends a value to the list/tuple stored at a key path of a dictionary.

    The value at ``key_path`` either needs to be a list or tuple, or
    nonexistent/None:

    * nonexistent or None: a one-element list ``[value]`` is stored there,
      creating any missing intermediate dicts;
    * list: it is replaced by a shallow copy with ``value`` appended;
    * tuple: it is replaced by a new tuple with ``value`` at the end.

    The previously stored list/tuple object itself is never modified. Only the
    dicts along ``key_path`` are touched; all other entries of ``target_dict``
    keep their identity.

    Args:
        target_dict (dict): the dictionary to add to
        key_path (str): the path to the value (e.g. key1.child_key.test)
        value: the value to append
        split_token (str): the character to split the key_path with.
    Raises:
        Exception: if the existing value is not a list or tuple, or an
            intermediate key holds something other than a mapping
    """
    log.debug("Adding value '{}' as key '{}' to dict.".format(value, key_path))
    tokens = key_path.split(split_token)

    # Walk (and, where missing, create) the intermediate dicts in place
    # instead of deep-copying and re-merging the whole target.
    node = target_dict
    for key in tokens[:-1]:
        child = node.get(key, None)
        if child is None:
            child = {}
            node[key] = child
        elif not isinstance(child, Mapping):
            raise Exception(
                "Value for key '{}' in key_path '{}' is not a dict, can't append.".format(
                    key, key_path
                )
            )
        node = child

    leaf = tokens[-1]
    old_value = node.get(leaf, None)
    if old_value is None:
        node[leaf] = [value]
    elif isinstance(old_value, list):
        # copy.copy keeps list subclasses (e.g. YAML sequence types) intact.
        new_value = copy.copy(old_value)
        new_value.append(value)
        node[leaf] = new_value
    elif isinstance(old_value, tuple):
        node[leaf] = old_value + (value,)
    else:
        raise Exception(
            "Value for key_path '{}' is not a list or tuple, can't append.".format(
                key_path
            )
        )
def flatten_lists(lists):
    """Concatenate the items of each sublist into a single list.

    Only one level is flattened; any deeper nesting inside the sublists is
    preserved as-is.

    Args:
        lists (list): a list of lists
    Returns:
        list: the flattened list
    """
    flattened = []
    for sublist in lists:
        flattened.extend(sublist)
    return flattened
def is_templated(text, jinja_delimiter_profile=JINJA_DELIMITER_PROFILES["default"]):
    """Check whether a string contains any template marker.

    This is a plain substring test: it only checks whether one of the marker
    strings (e.g. '}}', or '{%') occurs in the text. It doesn't check for
    matching opening/closing brackets etc.

    Args:
        text (str): the text in question
        jinja_delimiter_profile (dict): a dictionary whose *values* are the
            delimiter strings to look for (its keys are ignored)
    Returns:
        bool: whether the text is templated or not
    """
    return any(marker in text for marker in jinja_delimiter_profile.values())
class IgnoreUndefinedJinjaVariable(Undefined):
    """Jinja ``Undefined`` type that logs instead of raising on undefined use.

    Passed as ``undefined=`` to the Jinja ``Environment`` by ``replace_string``
    when ``ignore_undefined`` is True.
    """

    # The hook must be named exactly ``_fail_with_undefined_error`` to
    # override jinja2's ``Undefined``. The previous name
    # (``__fail__with_undefined_error``) did not match, and its leading double
    # underscore also triggered name mangling, so it never overrode anything.
    # NOTE(review): jinja2 binds some operators (e.g. ``+``) to the base-class
    # implementation at class creation, so those may still raise; attribute
    # access on an undefined value goes through this hook. Returning None may
    # render as the text 'None' (e.g. ``{{ missing.attr }}``) -- confirm the
    # intended output.
    def _fail_with_undefined_error(self, *args, **kwargs):
        log.debug("Missing jinja var")
        return None
def replace_string(
    template_string,
    replacement_dict=None,
    block_start_string=DEFAULT_BLOCK_START_STRING,
    block_end_string=DEFAULT_BLOCK_END_STRING,
    variable_start_string=DEFAULT_VARIABLE_START_STRING,
    variable_end_string=DEFAULT_VARIABLE_END_STRING,
    additional_jinja_extensions=None,
    local_env_vars_key=DEFAULT_LOCAL_ENV_VARS_KEY,
    ignore_undefined=False
):
    """Render a Jinja template string using values from a replacement dictionary.

    Args:
        template_string (str): the template string
        replacement_dict (dict): the dictionary with the replacement values
        block_start_string (str): the string to indicate a template block start
        block_end_string (str): the string to indicate a template block end
        variable_start_string (str): the string to indicate a template variable start
        variable_end_string (str): the string to indicate a template variable end
        additional_jinja_extensions (list): a list of jinja extensions to use
        local_env_vars_key (str): if truthy, a deep copy of ``os.environ`` is
            made available to the template under this key (values from
            ``replacement_dict`` are merged on top of it)
        ignore_undefined (bool): if True, render with
            :class:`IgnoreUndefinedJinjaVariable` as the undefined type;
            otherwise Jinja's default undefined handling applies
    Returns:
        str: the rendered template
    """
    extensions = (
        [] if additional_jinja_extensions is None else additional_jinja_extensions
    )
    variables = {} if replacement_dict is None else replacement_dict

    if local_env_vars_key:
        context = copy.deepcopy({local_env_vars_key: os.environ})
        dict_merge(context, variables, copy_dct=False)
    else:
        context = variables

    env_options = {
        "extensions": extensions,
        "trim_blocks": True,
        "block_start_string": block_start_string,
        "block_end_string": block_end_string,
        "variable_start_string": variable_start_string,
        "variable_end_string": variable_end_string,
    }
    if ignore_undefined:
        env_options["undefined"] = IgnoreUndefinedJinjaVariable

    template = Environment(**env_options).from_string(template_string)
    return template.render(context)
def reindent(s, numSpaces, keep_current=True):
    """Prefix every line of a string with ``numSpaces`` spaces.

    Args:
        s (str): the string
        numSpaces (int): the number of spaces to prepend to each line
        keep_current (bool): if True, existing leading whitespace is kept and
            the new indent is added in front of it; if False, existing leading
            whitespace is stripped first
    Returns:
        str: the indented string
    """
    padding = " " * numSpaces
    lines = s.split("\n")
    if not keep_current:
        lines = [line.lstrip() for line in lines]
    return "\n".join(padding + line for line in lines)
def readable(python_object, out="raw", safe=True, indent=0):
    """Utility method to print out readable strings from python objects (mostly dicts).

    Args:
        python_object (obj): the object to print
        out (str): the format of the output (available: 'yaml', 'json', 'raw', and 'pformat'); None means 'raw'
        safe (bool): whether to use a 'safe' way of converting to string (if available in the output format type)
        indent (int): the indentation (optional)
    Returns:
        str: the formatted string, indented by ``indent`` spaces if non-zero
    Raises:
        Exception: if ``out`` is not one of the supported formats
    """
    fmt = "raw" if out is None else out

    if fmt == "yaml":
        dumper = StringYAML(typ="safe") if safe else StringYAML()
        dumper.default_flow_style = False
        text = dumper.dump(python_object)
    elif fmt == "json":
        text = json.dumps(python_object, sort_keys=True, indent=4)
    elif fmt == "raw":
        text = str(python_object)
    elif fmt == "pformat":
        text = pprint.pformat(python_object)
    else:
        raise Exception(
            "No valid output format provided. Supported: 'yaml', 'json', 'raw', 'pformat'"
        )

    return reindent(text, indent) if indent != 0 else text
def readable_raw(python_object, indent=0):
    """Shortcut for :func:`readable` with the 'raw' (``str()``) format."""
    return readable(python_object, indent=indent, out="raw")
def readable_json(python_object, indent=0):
    """Shortcut for :func:`readable` with the 'json' format."""
    return readable(python_object, indent=indent, out="json")
def readable_yaml(python_object, indent=0, safe=True):
    """Shortcut for :func:`readable` with the 'yaml' format (``safe`` is passed through)."""
    return readable(python_object, out="yaml", safe=safe, indent=indent)
def ensure_parent_dir(path):
    """Makes sure the parent directory of ``path`` exists, creating it if needed.

    Args:
        path (str): the path to a file
    Returns:
        str: the parent dir (``""`` if ``path`` has no directory component)
    Raises:
        Exception: if the parent path exists but is not a directory
    """
    parent = os.path.dirname(path)
    if not parent:
        # A bare file name lives in the current directory: nothing to create,
        # and os.makedirs("") would fail.
        return parent
    # Check the parent itself, not the file: the file usually doesn't exist
    # yet while its parent often does (os.makedirs would then raise).
    if not os.path.exists(parent):
        try:
            os.makedirs(parent)
        except OSError:
            # Another process may have created it between the check and here.
            if not os.path.isdir(parent):
                raise
    elif not os.path.isdir(os.path.realpath(parent)):
        raise Exception("Can't create parent dir, file already exists: {}".format(parent))
    return parent
def calculate_cache_location_for_url(url):
    """Utility method to derive a relative path that can be used for caching a download.

    Every run of characters other than ASCII letters, digits, ``_``, ``-`` and
    ``.`` is collapsed into a single ``os.sep``. Note that distinct URLs can
    map to the same path (e.g. 'a/b' and 'a:b').

    Args:
        url (str): the url
    Returns:
        str: the derived path
    """
    # Raw string: the backslashes are regex escapes, not Python ones (a plain
    # string raises invalid-escape warnings on newer Pythons).
    repl_chars = r"[^_\-A-Za-z0-9\.]+"
    # Callable replacement so os.sep is inserted literally; as a replacement
    # *string*, Windows' "\\" would be parsed as a (broken) regex escape.
    return re.sub(repl_chars, lambda _match: os.sep, url)
def can_passwordless_sudo():
    """Checks if the user can use passwordless sudo on this host.

    Returns:
        bool: True if the effective uid is 0, or if ``sudo -k -n true`` exits
        with status 0; False otherwise
    """
    if os.geteuid() == 0:
        return True
    # ``with`` guarantees the devnull handle is closed (it used to leak).
    with open(os.devnull, "w") as devnull:
        # use -k to ignore any existing sudo token
        proc = subprocess.Popen(
            "sudo -k -n true",
            shell=True,
            stdout=devnull,
            stderr=subprocess.STDOUT,
            close_fds=True,
        )
        return_code = proc.wait()
    return return_code == 0