Source code for pyflare.sdk.writers.iceberg_writer

import ast
import base64
import time
import warnings
import pyspark.sql.functions as F
from pyflare.sdk.config import constants
from pyflare.sdk.config.constants import S3_ICEBERG_FILE_IO
from pyflare.sdk.config.write_config import WriteConfig
from pyflare.sdk.utils import pyflare_logger, generic_utils
from pyflare.sdk.writers.file_writer import FileOutputWriter
from pyspark.sql.readwriter import DataFrameWriterV2

warnings.filterwarnings("ignore", message=".*version-hint.text.*")


class IcebergOutputWriter(FileOutputWriter):
    ICEBERG_CONF = '''[
        ("spark.sql.catalog.{catalog_name}", "org.apache.iceberg.spark.SparkCatalog"),
        ("spark.sql.catalog.{catalog_name}.type", "hadoop"),
        ("spark.sql.catalog.{catalog_name}.warehouse", "{depot_base_path}")
    ]'''

    def __init__(self, write_config: WriteConfig):
        super().__init__(write_config)
        self.log = pyflare_logger.get_pyflare_logger(name=__name__)
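    # Example (hypothetical values): once formatted and literal_eval'd by the
    # _<depot>_iceberg helpers below, ICEBERG_CONF becomes a list of Spark conf
    # tuples such as:
    #   [
    #       ("spark.sql.catalog.mydepot", "org.apache.iceberg.spark.SparkCatalog"),
    #       ("spark.sql.catalog.mydepot.type", "hadoop"),
    #       ("spark.sql.catalog.mydepot.warehouse", "abfss://container@account.dfs.core.windows.net/mydepot"),
    #   ]
    # where "mydepot" and the warehouse path are placeholders, not values fixed by the SDK.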
    def write(self, df):
        if "merge" in self.write_config.extra_options:
            depot = self.write_config.depot_details.get("depot")
            collection = self.write_config.depot_details.get("collection")
            dataset = self.write_config.depot_details.get("dataset")
            view_name = f"{depot}_{collection}_{dataset}_{int(time.time() * 1e9)}"
            df.createOrReplaceTempView(view_name)
            self.spark.sql(self.__merge_into_query(view_name, depot, collection, dataset))
        else:
            spark_options = self.write_config.spark_options
            table_properties = self.write_config.extra_options.get("table_properties", {})
            io_format = self.write_config.io_format
            dataset_path = generic_utils.get_dataset_path(self.write_config)
            df_writer = df.writeTo(dataset_path).using(io_format)
            if spark_options:
                df_writer = df_writer.options(**spark_options)
                self.log.info(f"spark options: {spark_options}")
            if table_properties:
                self.log.info(f"table_properties: {table_properties}")
                # DataFrameWriterV2.tableProperty expects a single (property, value) pair,
                # so set each table property individually.
                for prop, value in table_properties.items():
                    df_writer = df_writer.tableProperty(prop, value)
            df_writer = self.__process_partition_conf(df_writer)
            self.__write_mode(df_writer)
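    # Example (hypothetical clause strings): a "merge" entry in extra_options routes
    # write() through __merge_into_query() instead of the writeTo() path:
    #
    #   write_config.extra_options["merge"] = {
    #       "onClause": "target.id = source.id",
    #       "whenClause": "WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *",
    #   }
    #
    # which yields SQL shaped like:
    #
    #   MERGE INTO <depot>.<collection>.<dataset> as target
    #   USING (select * from <temp view>) as source
    #   ON target.id = source.id
    #   WHEN MATCHED THEN UPDATE SET * WHEN NOT MATCHED THEN INSERT *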
    def write_stream(self):
        pass
    def get_conf(self):
        return getattr(self, f"_{self.write_config.depot_type()}_{self.write_config.io_format}")()
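    # Example: get_conf() dispatches on "_{depot_type}_{io_format}", so a depot of
    # type "s3" with io_format "iceberg" resolves to self._s3_iceberg() below.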
    def _abfss_iceberg(self):
        dataset_absolute_path = self.write_config.dataset_absolute_path()
        iceberg_conf = ast.literal_eval(
            self.ICEBERG_CONF.format(catalog_name=self.write_config.depot_name(),
                                      depot_base_path=dataset_absolute_path))
        iceberg_conf.extend(generic_utils.get_abfss_spark_conf(self.write_config))
        return iceberg_conf

    def _s3_iceberg(self):
        dataset_absolute_path = self.write_config.dataset_absolute_path()
        iceberg_conf = ast.literal_eval(
            self.ICEBERG_CONF.format(catalog_name=self.write_config.depot_name(),
                                      depot_base_path=dataset_absolute_path))
        iceberg_conf.append(S3_ICEBERG_FILE_IO)
        iceberg_conf.extend(generic_utils.get_s3_spark_conf(self.write_config))
        return iceberg_conf

    def _gcs_iceberg(self):
        dataset_absolute_path = self.write_config.dataset_absolute_path()
        iceberg_conf = ast.literal_eval(
            self.ICEBERG_CONF.format(catalog_name=self.write_config.depot_name(),
                                      depot_base_path=dataset_absolute_path))
        iceberg_conf.extend(generic_utils.get_gcs_spark_conf(self.write_config))
        return iceberg_conf

    def __process_partition_conf(self, df_writer: DataFrameWriterV2) -> DataFrameWriterV2:
        partition_column_list = []
        for temp_dict in self.write_config.extra_options.get("partition", []):
            partition_scheme: str = temp_dict.get("type", "").casefold()
            partition_column: str = temp_dict.get("column", "")
            if partition_scheme in ["year", "month", "day", "hour"]:
                self.log.info(f"partition scheme: {partition_scheme}, partition column: {partition_column}")
                partition_column_list.append(getattr(F, f"{partition_scheme}s")(partition_column))
            elif partition_scheme == "bucket":
                bucket_count: int = temp_dict.get("bucket_count", 8)
                self.log.info(
                    f"partition scheme: {partition_scheme}, partition column: {partition_column}, "
                    f"bucket_count: {bucket_count}")
                self.log.info(f"F.bucket({bucket_count}, {partition_column})")
                partition_column_list.append(F.bucket(bucket_count, F.col(partition_column)))
            elif partition_scheme == "identity":
                self.log.info(f"partition column: {partition_column}")
                partition_column_list.append(F.col(partition_column))
            else:
                self.log.warning(f"Invalid partition scheme: {partition_scheme}")
        if partition_column_list:
            df_writer = df_writer.partitionedBy(*partition_column_list)
        return df_writer

    def __write_mode(self, df_writer: DataFrameWriterV2):
        if self.write_config.mode in ["create", "overwrite", "write"]:
            df_writer.createOrReplace()
        elif self.write_config.mode in ["overwriteByPartition"]:
            df_writer.overwritePartitions()
        else:
            df_writer.append()

    def __merge_into_query(self, source_view: str, depot, collection, dataset):
        merge_clauses = self.write_config.extra_options.get("merge", {})
        query = f"MERGE INTO {depot}.{collection}.{dataset} as target \n"
        query += f"USING (select * from {source_view}) as source \n"
        query += f"ON {merge_clauses.get('onClause', '')} \n"
        query += f"{merge_clauses.get('whenClause', '')} \n"
        return query
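
# Example (hypothetical column names): the "partition" entry in extra_options is a
# list of dicts consumed by __process_partition_conf(). Supported "type" values are
# year/month/day/hour, bucket (with an optional bucket_count, default 8), and identity:
#
#   write_config.extra_options["partition"] = [
#       {"type": "day", "column": "event_ts"},
#       {"type": "bucket", "column": "customer_id", "bucket_count": 16},
#       {"type": "identity", "column": "country"},
#   ]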