Source code for dagster_polars.io_managers.delta

import json
from enum import Enum
from pprint import pformat
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union, overload

import dagster._check as check
import polars as pl
from dagster import InputContext, MetadataValue, OutputContext
from dagster._annotations import experimental
from dagster._core.storage.upath_io_manager import is_dict_type

from dagster_polars.io_managers.base import BasePolarsUPathIOManager
from dagster_polars.types import DataFrameWithMetadata, LazyFrameWithMetadata, StorageMetadata

try:
    from deltalake import DeltaTable
    from deltalake.exceptions import TableNotFoundError
except ImportError as e:
    raise ImportError("Install 'dagster-polars[deltalake]' to use DeltaLake functionality") from e

if TYPE_CHECKING:
    from upath import UPath


DAGSTER_POLARS_STORAGE_METADATA_SUBDIR = ".dagster_polars_metadata"

SINGLE_LOADING_TYPES = (pl.DataFrame, pl.LazyFrame, LazyFrameWithMetadata, DataFrameWithMetadata)


class DeltaWriteMode(str, Enum):
    error = "error"
    append = "append"
    overwrite = "overwrite"
    ignore = "ignore"


@experimental
class PolarsDeltaIOManager(BasePolarsUPathIOManager):
    """Implements writing and reading DeltaLake tables.

    Features:
     - All features provided by :py:class:`~dagster_polars.BasePolarsUPathIOManager`.
     - All read/write options can be set via corresponding metadata or config parameters (metadata takes precedence).
     - Supports native DeltaLake partitioning by storing different asset partitions in the same DeltaLake table.
       To enable this behavior, set the `partition_by` metadata value or config parameter
       (it's passed to `delta_write_options` of `pl.DataFrame.write_delta`).
       Loaded partitions are filtered automatically, unless `MultiPartitionsDefinition` is used.
       With `MultiPartitionsDefinition` you are responsible for filtering the partitions in the downstream asset,
       as it's non-trivial to do so in the IOManager.
       When loading all available asset partitions, the whole table can be loaded in one go
       by using type annotations like `pl.DataFrame` and `pl.LazyFrame`.
     - Supports writing/reading custom metadata to/from a `.dagster_polars_metadata/<version>.json` file
       in the DeltaLake table directory.

    Install `dagster-polars[delta]` to use this IOManager.

    Examples:
        .. code-block:: python

            from dagster import Definitions, asset
            from dagster_polars import PolarsDeltaIOManager
            import polars as pl

            @asset(
                io_manager_key="polars_delta_io_manager",
                key_prefix=["my_dataset"]
            )
            def my_asset() -> pl.DataFrame:  # data will be stored at <base_dir>/my_dataset/my_asset.delta
                ...

            defs = Definitions(
                assets=[my_asset],
                resources={
                    "polars_delta_io_manager": PolarsDeltaIOManager(base_dir="s3://my-bucket/my-dir")
                }
            )

        Appending to a DeltaLake table:

        .. code-block:: python

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "mode": "append"
                },
            )
            def my_table() -> pl.DataFrame:
                ...

        Using native DeltaLake partitioning by storing different asset partitions in the same DeltaLake table:

        .. code-block:: python

            from dagster import AssetExecutionContext, DailyPartitionsDefinition
            from dagster_polars import LazyFramePartitions

            @asset(
                io_manager_key="polars_delta_io_manager",
                metadata={
                    "partition_by": "partition_col"
                },
                partitions_def=...
            )
            def upstream(context: AssetExecutionContext) -> pl.DataFrame:
                df = ...

                # add the partition key as a column to the DataFrame
                df = df.with_columns(pl.lit(context.partition_key).alias("partition_col"))
                return df

            @asset
            def downstream(upstream: LazyFramePartitions) -> pl.DataFrame:
                # concat the LazyFrames, filter the required partitions and call .collect()
                ...
""" extension: str = ".delta" mode: DeltaWriteMode = DeltaWriteMode.overwrite.value # type: ignore overwrite_schema: bool = False version: Optional[int] = None # tmp fix until UPathIOManager supports this: added special handling for loading all partitions of an asset def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]: # If no asset key, we are dealing with an op output which is always non-partitioned if not context.has_asset_key or not context.has_asset_partitions: path = self._get_path(context) return self._load_single_input(path, context) else: asset_partition_keys = context.asset_partition_keys if len(asset_partition_keys) == 0: return None elif len(asset_partition_keys) == 1: paths = self._get_paths_for_partitions(context) check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}") path = next(iter(paths.values())) backcompat_paths = self._get_multipartition_backcompat_paths(context) backcompat_path = ( None if not backcompat_paths else next(iter(backcompat_paths.values())) ) return self._load_partition_from_path( context=context, partition_key=asset_partition_keys[0], path=path, backcompat_path=backcompat_path, ) else: # we are dealing with multiple partitions of an asset type_annotation = context.dagster_type.typing_type if type_annotation == Any or is_dict_type(type_annotation): return self._load_multiple_inputs(context) # special case of loading the whole DeltaLake table at once # when using AllPartitionMappings and native DeltaLake partitioning elif ( context.upstream_output is not None and context.upstream_output.metadata is not None and context.upstream_output.metadata.get("partition_by") is not None and type_annotation in SINGLE_LOADING_TYPES and context.upstream_output.asset_info is not None and context.upstream_output.asset_info.partitions_def is not None and set(asset_partition_keys) == set( context.upstream_output.asset_info.partitions_def.get_partition_keys( dynamic_partitions_store=context.instance ) ) ): # load all partitions at once return self.load_from_path( context=context, path=self.get_path_for_partition( context=context, partition=asset_partition_keys[0], # 0 would work, path=self._get_paths_for_partitions(context)[ asset_partition_keys[0] ], # 0 would work, ), partition_key=None, ) else: check.failed( "Loading an input that corresponds to multiple partitions, but the" f" type annotation on the op input is not a dict, Dict, Mapping, one of {SINGLE_LOADING_TYPES}," " or Any: is '{type_annotation}'." 
                    )

    def sink_df_to_path(
        self,
        context: OutputContext,
        df: pl.LazyFrame,
        path: "UPath",
        metadata: Optional[StorageMetadata] = None,
    ):
        context_metadata = context.metadata or {}
        streaming = context_metadata.get("streaming", False)
        return self.write_df_to_path(context, df.collect(streaming=streaming), path, metadata)

    def write_df_to_path(
        self,
        context: OutputContext,
        df: pl.DataFrame,
        path: "UPath",
        metadata: Optional[StorageMetadata] = None,  # why is metadata passed
    ):
        context_metadata = context.metadata or {}
        delta_write_options = context_metadata.get(
            "delta_write_options"
        )  # This needs to be gone and just only key value on the metadata

        if context.has_asset_partitions:
            delta_write_options = delta_write_options or {}
            partition_by = context_metadata.get(
                "partition_by"
            )  # this could be wrong, you could have partition_by in delta_write_options and in the metadata

            if partition_by is not None:
                assert (
                    context.partition_key is not None
                ), 'can\'t set "partition_by" for an asset without partitions'

                delta_write_options["partition_by"] = partition_by
                delta_write_options["partition_filters"] = [
                    (partition_by, "=", context.partition_key)
                ]

        if delta_write_options is not None:
            context.log.debug(f"Writing with delta_write_options: {pformat(delta_write_options)}")

        storage_options = self.storage_options

        try:
            dt = DeltaTable(str(path), storage_options=storage_options)
        except TableNotFoundError:
            dt = str(path)

        df.write_delta(
            dt,
            mode=context_metadata.get("mode") or self.mode.value,
            overwrite_schema=context_metadata.get("overwrite_schema") or self.overwrite_schema,
            storage_options=storage_options,
            delta_write_options=delta_write_options,
        )

        if isinstance(dt, DeltaTable):
            current_version = dt.version()
        else:
            current_version = DeltaTable(
                str(path), storage_options=storage_options, without_files=True
            ).version()

        context.add_output_metadata({"version": current_version})

        if metadata is not None:
            metadata_path = self.get_storage_metadata_path(path, current_version)
            metadata_path.parent.mkdir(parents=True, exist_ok=True)
            metadata_path.write_text(json.dumps(metadata))

    @overload
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Literal[None, False]
    ) -> pl.LazyFrame: ...

    @overload
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Literal[True]
    ) -> LazyFrameWithMetadata: ...
    def scan_df_from_path(
        self,
        path: "UPath",
        context: InputContext,
        with_metadata: Optional[bool] = False,
    ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]:
        context_metadata = context.metadata or {}
        version = self.get_delta_version_to_load(path, context)
        context.log.debug(f"Reading Delta table with version: {version}")
        ldf = pl.scan_delta(
            str(path),
            version=version,
            delta_table_options=context_metadata.get("delta_table_options"),
            pyarrow_options=context_metadata.get("pyarrow_options"),
            storage_options=self.storage_options,
        )

        if with_metadata:
            version = self.get_delta_version_to_load(path, context)
            metadata_path = self.get_storage_metadata_path(path, version)
            if metadata_path.exists():
                metadata = json.loads(metadata_path.read_text())
            else:
                metadata = {}
            return ldf, metadata

        else:
            return ldf

    def get_path_for_partition(
        self, context: Union[InputContext, OutputContext], path: "UPath", partition: str
    ) -> "UPath":
        if isinstance(context, InputContext):
            if (
                context.upstream_output is not None
                and context.upstream_output.metadata is not None
                and context.upstream_output.metadata.get("partition_by") is not None
            ):
                # upstream asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
                return path

        if isinstance(context, OutputContext):
            if context.metadata is not None and context.metadata.get("partition_by") is not None:
                # this asset has "partition_by" metadata set, so partitioning for it is handled by DeltaLake itself
                return path

        return path / partition  # partitioning is handled by the IOManager

    def get_metadata(
        self, context: OutputContext, obj: Union[pl.DataFrame, pl.LazyFrame, None]
    ) -> Dict[str, MetadataValue]:
        context_metadata = context.metadata or {}

        metadata = super().get_metadata(context, obj)

        if context.has_asset_partitions:
            partition_by = context_metadata.get("partition_by")
            if partition_by is not None:
                metadata["partition_by"] = partition_by

        if context_metadata.get("mode") == "append":
            # modify the metadata to reflect the fact that we are appending to the table

            if context.has_asset_partitions:
                # paths = self._get_paths_for_partitions(context)
                # assert len(paths) == 1
                # path = list(paths.values())[0]
                # FIXME: what to do about row_count metadata if we are appending to a partitioned table?
                # we should not be using the full table length,
                # but it's unclear how to get the length of the partition we are appending to
                pass
            else:
                metadata["append_row_count"] = metadata["row_count"]

                path = self._get_path(context)
                # we need to get row_count from the full table
                metadata["row_count"] = MetadataValue.int(
                    DeltaTable(str(path), storage_options=self.storage_options)
                    .to_pyarrow_dataset()
                    .count_rows()
                )

        return metadata

    def get_delta_version_to_load(self, path: "UPath", context: InputContext) -> int:
        context_metadata = context.metadata or {}
        version_from_metadata = context_metadata.get("version")

        version_from_config = self.version

        version: Optional[int] = None

        if version_from_metadata is not None and version_from_config is not None:
            context.log.warning(
                f"Both version from metadata ({version_from_metadata}) "
                f"and config ({version_from_config}) are set. Using version from metadata."
            )
            version = int(version_from_metadata)

        elif version_from_metadata is None and version_from_config is not None:
            version = int(version_from_config)

        elif version_from_metadata is not None and version_from_config is None:
            version = int(version_from_metadata)

        if version is None:
            return DeltaTable(
                str(path), storage_options=self.storage_options, without_files=True
            ).version()
        else:
            return version

    def get_storage_metadata_path(self, path: "UPath", version: int) -> "UPath":
        return path / DAGSTER_POLARS_STORAGE_METADATA_SUBDIR / f"{version}.json"
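As the source above shows, most behavior is driven by metadata: `write_df_to_path` reads `mode`, `overwrite_schema`, `partition_by` and `delta_write_options` from the output metadata (falling back to the IOManager's config fields), while `get_delta_version_to_load` and `scan_df_from_path` read `version`, `delta_table_options` and `pyarrow_options` from the input metadata. Below is a minimal usage sketch of that flow; it is not part of the module source, and the asset names, base directory, and version number are hypothetical:

.. code-block:: python

    from dagster import AssetIn, Definitions, asset
    import polars as pl

    from dagster_polars import PolarsDeltaIOManager


    @asset(
        io_manager_key="polars_delta_io_manager",
        # "mode" in the output metadata overrides the IOManager's `mode` field in write_df_to_path
        metadata={"mode": "append"},
    )
    def events() -> pl.DataFrame:
        return pl.DataFrame({"user_id": [1, 2], "value": [0.1, 0.2]})


    @asset(
        io_manager_key="polars_delta_io_manager",
        ins={
            # "version" in the input metadata pins the Delta table version to read;
            # get_delta_version_to_load prefers it over the IOManager's `version` config
            "events": AssetIn(metadata={"version": 0}),
        },
    )
    def events_first_version(events: pl.DataFrame) -> pl.DataFrame:
        return events


    defs = Definitions(
        assets=[events, events_first_version],
        resources={"polars_delta_io_manager": PolarsDeltaIOManager(base_dir="/tmp/delta")},
    )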