Source code for pyflare.sdk.utils.generic_utils

import base64
import json
import os
import re
import pkg_resources

from functools import wraps

from py4j.java_gateway import java_import

from pyflare.sdk.config.read_config import ReadConfig
from pyflare.sdk.utils import pyflare_logger
from pyflare.sdk.config.constants import DEPOT_SECRETS_KV_REGEX, DATAOS_DEFAULT_SECRET_DIRECTORY, S3_ACCESS_KEY_ID, \
    S3_ACCESS_SECRET_KEY, S3_SPARK_CONFS, GCS_AUTH_ACCOUNT_ENABLED, GCS_ACCOUNT_EMAIL, GCS_PROJECT_ID, \
    GCS_ACCOUNT_PRIVATE_KEY, GCS_ACCOUNT_PRIVATE_KEY_ID, AZURE_ACCOUNT_KEY_PREFIX, AZURE_ACCOUNT_KEY, \
    DATAOS_ADDRESS_RESOLVER_REGEX


# import builtins
#
#
# def my_print(*args, **kwargs):
#     # Do something with the arguments
#     # Replace sensitive strings with a placeholder value
#     redacted_text = re.sub('(?i)secret|password|key|abfss|dfs|apikey', '*****', " ".join(str(arg) for arg in args))
#     # Print the redacted text
#     builtins.print(redacted_text)
from pyflare.sdk.utils.pyflare_exceptions import MissingEnvironmentVariable


def decorate_logger(fn):
    """Decorator that logs entry to and exit from the wrapped function at DEBUG level."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        log = pyflare_logger.get_pyflare_logger(name=__name__)
        log.debug("About to run %s", fn.__name__)
        out = fn(*args, **kwargs)
        log.debug("Done running %s", fn.__name__)
        return out
    return wrapper

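# Illustrative usage (a sketch, not part of the original module): wrap any
# function to get DEBUG-level entry/exit logs around each call.
#
#   @decorate_logger
#   def build_dataframe(session, dataset):   # hypothetical function
#       ...
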
def append_properties(dict1: dict, dict2: dict) -> dict:
    """Merge the key-value pairs of dict2 into dict1 and return the merged dict."""
    dict1.update(dict2)
    return dict1

def safe_assignment(val1, val2):
    """Return val2 if it is set (truthy); otherwise fall back to val1."""
    if val2:
        return val2
    return val1

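# Illustrative examples (derived from the truthiness check above, not part of
# the original module): any falsy val2 -- None, "", {}, 0 -- falls back to val1.
#
#   >>> safe_assignment("default", None)
#   'default'
#   >>> safe_assignment("default", "override")
#   'override'
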
def get_jars_path():
    """Return a comma-separated list of the jar paths bundled with the SDK."""
    flare_sdk_jar_path = pkg_resources.resource_filename('pyflare.jars', 'flare_2.12-3.3.1-0.0.14.1-javadoc.jar')
    heimdall_jar_path = pkg_resources.resource_filename('pyflare.jars', 'heimdall-0.1.9.jar')
    commons_jar_path = pkg_resources.resource_filename('pyflare.jars', 'commons-0.1.9.jar')
    spark_jar_path = pkg_resources.resource_filename('pyflare.jars', 'spark-authz-0.1.9.jar')
    json4s_jar_path = pkg_resources.resource_filename('pyflare.jars', 'json4s-jackson_2.12-4.0.6.jar')
    return f"{commons_jar_path},{heimdall_jar_path},{flare_sdk_jar_path},{json4s_jar_path},{spark_jar_path}"

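# Illustrative sketch (an assumption, not shown in this module): the
# comma-separated jar list returned by get_jars_path() is typically handed to
# Spark through the "spark.jars" configuration when the session is built:
#
#   from pyspark.sql import SparkSession
#   spark = (
#       SparkSession.builder
#       .config("spark.jars", get_jars_path())
#       .getOrCreate()
#   )
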
def get_abfss_spark_conf(rw_config: ReadConfig):
    """Build the Spark conf entry holding the Azure storage account key for an ABFSS dataset."""
    dataset_absolute_path = rw_config.dataset_absolute_path()
    dataset_auth_token = get_secret_token(rw_config.depot_details)
    account = rw_config.depot_details.get("connection", {}).get("account", "")
    endpoint_suffix = dataset_absolute_path.split(account)[1].split("/")[0].strip(". ")
    dataset_auth_key = "{}.{}.{}".format(AZURE_ACCOUNT_KEY_PREFIX, account, endpoint_suffix)
    return [(dataset_auth_key, dataset_auth_token)]

def get_s3_spark_conf(rw_config: ReadConfig):
    """Build the Spark conf entries (credentials plus static S3 confs) for an S3 dataset."""
    access_key_id = rw_config.depot_details.get("secrets", {}).get("accesskeyid", "")
    access_key_secret = rw_config.depot_details.get("secrets", {}).get("awssecretaccesskey", "")
    aws_access_key_id = (S3_ACCESS_KEY_ID, access_key_id)
    aws_access_key_secret = (S3_ACCESS_SECRET_KEY, access_key_secret)
    spark_conf = [aws_access_key_id, aws_access_key_secret]
    spark_conf.extend(S3_SPARK_CONFS)
    return spark_conf

def get_gcs_spark_conf(rw_config: ReadConfig):
    client_email = rw_config.depot_details.get("secrets", {}).get("client_email", "")
    project_id = rw_config.depot_details.get("secrets", {}).get("project_id", "")
    private_key = rw_config.depot_details.get("secrets", {}).get("private_key", "")
    private_key_id = rw_config.depot_details.get("secrets", {}).get("private_key_id", "")
    private_key_file_path = rw_config.depot_details.get("secrets", {}).get(
        f"{rw_config.depot_name()}_secrets_file_path", "")
    return [
        # ("spark.hadoop.google.cloud.auth.service.account.json.keyfile", "/etc/dataos/secret/depot.*.json"),
        (GCS_AUTH_ACCOUNT_ENABLED, "true"),
        (GCS_ACCOUNT_EMAIL, client_email),
        (GCS_PROJECT_ID, project_id),
        (GCS_ACCOUNT_PRIVATE_KEY, private_key_file_path),
        (GCS_ACCOUNT_PRIVATE_KEY_ID, private_key_id),
    ]

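# Illustrative sketch (an assumption about how these helpers are consumed, not
# shown in this module): each *_spark_conf helper returns (conf_key, conf_value)
# pairs that can be applied to an existing Spark session:
#
#   for key, value in get_s3_spark_conf(read_config):   # read_config: a ReadConfig
#       spark.conf.set(key, value)
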
def get_secret_token(depot_details) -> str:
    return depot_details.get("secrets", {}).get(AZURE_ACCOUNT_KEY, "")

def get_dataset_path(depot_config) -> str:
    return "{}.{}.{}".format(depot_config.depot_name(), depot_config.collection(), depot_config.dataset_name())

def decode_base64_string(encoded_string: str, type: str) -> dict:
    decoded_string = base64.b64decode(encoded_string).decode('utf-8')
    if type.casefold() == "json":
        key_value_pairs = json.loads(decoded_string)
    else:
        key_value_pairs = re.findall(DEPOT_SECRETS_KV_REGEX, decoded_string)
    return dict(key_value_pairs)

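# Illustrative example (made-up values): decoding a base64-encoded JSON secret.
# Any other `type` is parsed with DEPOT_SECRETS_KV_REGEX instead.
#
#   >>> import base64
#   >>> encoded = base64.b64encode(b'{"accesskeyid": "abc", "awssecretaccesskey": "xyz"}').decode()
#   >>> decode_base64_string(encoded, "json")
#   {'accesskeyid': 'abc', 'awssecretaccesskey': 'xyz'}
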
def get_secret_file_path() -> str:
    secret_dir = os.getenv("DATAOS_SECRET_DIR")
    return DATAOS_DEFAULT_SECRET_DIRECTORY if secret_dir is None else secret_dir.rstrip('/')

def write_string_to_file(file_path: str, string_data: str, overwrite: bool = True) -> None:
    log = pyflare_logger.get_pyflare_logger()
    if not overwrite and os.path.exists(file_path) and os.path.getsize(file_path) > 0:
        log.info("File exists and is not empty; skipping write")
    else:
        log.info("Creating file at path: %s", file_path)
        try:
            with open(file_path, "w") as file:
                file.write(string_data)
            log.info(f"Data written successfully to: {file_path}")
        except Exception as e:
            log.error(f"Error writing data to the file: {str(e)}")

def write_dict_to_file(file_path: str, data_dict: dict, overwrite: bool = True) -> None:
    log = pyflare_logger.get_pyflare_logger()
    if not overwrite and os.path.exists(file_path) and os.path.getsize(file_path) > 0:
        log.info("File exists and is not empty; skipping write")
    else:
        log.info("Creating file at path: %s", file_path)
        try:
            with open(file_path, "w") as file:
                json.dump(data_dict, file)
            log.info(f"Dictionary data written successfully to: {file_path}")
        except Exception as e:
            log.error(f"Error writing data dictionary to the file: {str(e)}")

def resolve_dataos_address(dataos_address: str) -> dict:
    matches = re.match(DATAOS_ADDRESS_RESOLVER_REGEX, dataos_address)
    parsed_address = {}
    if matches:
        parsed_address["depot"] = matches.groups()[0]
        parsed_address["collection"] = matches.groups()[2]
        parsed_address["dataset"] = matches.groups()[4]
    return parsed_address

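# Illustrative example (hypothetical address; the exact forms accepted and the
# group positions are defined by DATAOS_ADDRESS_RESOLVER_REGEX):
#
#   >>> resolve_dataos_address("dataos://mydepot:mycollection/mydataset")
#   {'depot': 'mydepot', 'collection': 'mycollection', 'dataset': 'mydataset'}
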
def get_env_variable(env_variable: str) -> str:
    value = os.environ.get(env_variable, "")
    if not value:
        raise MissingEnvironmentVariable(f"{env_variable} is not set")
    return value

def enhance_connection_url(connection_url: str, collection: str, dataset: str) -> str:
    if collection and collection.casefold() != "none":
        connection_url += f"/{collection}"
    if dataset:
        connection_url += f"/{dataset}"
    return connection_url

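# Illustrative examples (derived from the checks above; the URL is made up).
# A collection of "none" (any case) or an empty collection is skipped:
#
#   >>> enhance_connection_url("jdbc:postgresql://host:5432", "public", "orders")
#   'jdbc:postgresql://host:5432/public/orders'
#   >>> enhance_connection_url("jdbc:postgresql://host:5432", "none", "orders")
#   'jdbc:postgresql://host:5432/orders'
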
def authorize_user(spark, heimdallClient, apikey):
    """Authorize the given API key against Heimdall and return the user id, or None on failure."""
    log = pyflare_logger.get_pyflare_logger()
    response = heimdallClient.getAuthorizeApi().authorize(apikey).execute()
    if response.isSuccessful():
        json_response = response.body()
        user_id = json_response.getResult().getId()
        return user_id
    else:
        log.error(f"Error: {response.code()}, {response.message()}")
        return None