Ask AI

You are viewing an unreleased or outdated version of the documentation

Source code for dagster_polars.io_managers.base

import sys
from abc import abstractmethod
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Literal,
    Mapping,
    Optional,
    Tuple,
    Union,
    cast,
    get_args,
    get_origin,
    overload,
)

import polars as pl
from dagster import (
    ConfigurableIOManager,
    EnvVar,
    InitResourceContext,
    InputContext,
    MetadataValue,
    OutputContext,
    UPathIOManager,
    _check as check,
)
from dagster._core.storage.upath_io_manager import is_dict_type
from pydantic import PrivateAttr
from pydantic.fields import Field

from dagster_polars.io_managers.utils import get_polars_metadata
from dagster_polars.types import (
    DataFramePartitions,
    DataFramePartitionsWithMetadata,
    LazyFramePartitions,
    LazyFramePartitionsWithMetadata,
    LazyFrameWithMetadata,
    StorageMetadata,
)

if TYPE_CHECKING:
    from upath import UPath

# Type annotations that are handled as eager `pl.DataFrame` values.
# `BasePolarsUPathIOManager.dump_to_path` writes outputs annotated with one of
# these via `write_df_to_path`, and `load_from_path` collects the scanned
# LazyFrame before returning it for inputs annotated with one of these.
POLARS_EAGER_FRAME_ANNOTATIONS = [
    pl.DataFrame,
    Optional[pl.DataFrame],
    # common default types
    # NOTE(review): presumably what Dagster reports as `typing_type` when the
    # user did not annotate the input/output - confirm against Dagster.
    Any,
    type(None),
    None,
    # multiple partitions
    Dict[str, pl.DataFrame],
    Mapping[str, pl.DataFrame],
    DataFramePartitions,
    # DataFrame + metadata
    Tuple[pl.DataFrame, StorageMetadata],
    Optional[Tuple[pl.DataFrame, StorageMetadata]],
    # multiple partitions + metadata
    DataFramePartitionsWithMetadata,
]

# Type annotations that are handled as lazy `pl.LazyFrame` values.
# `BasePolarsUPathIOManager.dump_to_path` writes outputs annotated with one of
# these via `sink_df_to_path`, and `load_from_path` returns the scanned
# LazyFrame without collecting it.
POLARS_LAZY_FRAME_ANNOTATIONS = [
    pl.LazyFrame,
    Optional[pl.LazyFrame],
    # multiple partitions
    Dict[str, pl.LazyFrame],
    Mapping[str, pl.LazyFrame],
    LazyFramePartitions,
    # LazyFrame + metadata
    Tuple[pl.LazyFrame, StorageMetadata],
    Optional[Tuple[pl.LazyFrame, StorageMetadata]],
    # multiple partitions + metadata
    LazyFramePartitionsWithMetadata,
]


# Subscripting the builtin `dict` (PEP 585) is only possible on Python 3.9+,
# so these spellings are registered conditionally.
if sys.version_info >= (3, 9):
    POLARS_EAGER_FRAME_ANNOTATIONS.append(dict[str, pl.DataFrame])
    POLARS_EAGER_FRAME_ANNOTATIONS.append(dict[str, Optional[pl.DataFrame]])

    POLARS_LAZY_FRAME_ANNOTATIONS.append(dict[str, pl.LazyFrame])
    POLARS_LAZY_FRAME_ANNOTATIONS.append(dict[str, Optional[pl.LazyFrame]])


def annotation_is_typing_optional(annotation) -> bool:
    """Return whether ``annotation`` is a union type that includes ``None``.

    Both ``typing.Optional[X]`` / ``typing.Union[X, None]`` and the PEP 604
    spelling ``X | None`` are recognized. The latter has ``types.UnionType``
    (Python 3.10+) as its origin instead of ``typing.Union``.

    Args:
        annotation: any type annotation object.

    Returns:
        ``True`` if the annotation is a union with ``None`` among its members.
    """
    import types  # function-scope import: only needed for the PEP 604 check

    origin = get_origin(annotation)
    # `types.UnionType` does not exist before Python 3.10; fall back to
    # `Union` there so the membership test degrades to the classic check.
    union_origins = (Union, getattr(types, "UnionType", Union))
    return origin in union_origins and type(None) in get_args(annotation)


def annotation_is_tuple(annotation) -> bool:
    """Tell whether the origin of ``annotation`` (per ``typing.get_origin``) is a tuple type."""
    tuple_origins = (Tuple, tuple)
    origin = get_origin(annotation)
    return origin in tuple_origins


def annotation_for_multiple_partitions(annotation) -> bool:
    """Return whether ``annotation`` is a mapping of partitions, e.g. ``Dict[str, pl.DataFrame]``.

    A surrounding ``Optional[...]`` is unwrapped first (its first type argument
    is inspected).

    The check used to additionally require the very same annotation to be a
    tuple. An annotation cannot have both a tuple and a dict origin, so the
    function always returned ``False`` and partition mappings were never
    detected.

    Args:
        annotation: any type annotation object.

    Returns:
        ``True`` if the (unwrapped) annotation is a parameterized ``dict`` /
        ``Dict`` / ``Mapping`` with a key type and a value type.
    """
    if annotation_is_typing_optional(annotation):
        annotation = get_args(annotation)[0]

    # Callers index `get_args(annotation)[1]` to reach the per-partition type,
    # so a bare `Dict` without type arguments must not be reported as a match.
    return get_origin(annotation) in (dict, Dict, Mapping) and len(get_args(annotation)) == 2


def annotation_is_tuple_with_metadata(annotation) -> bool:
    """Return whether ``annotation`` is a ``(frame, metadata)`` tuple type.

    A surrounding ``Optional[...]`` is unwrapped first. The annotation matches
    when it is a tuple whose second type argument has a ``dict`` / ``Dict`` /
    ``Mapping`` origin, e.g. ``Tuple[pl.DataFrame, Dict[str, Any]]``.

    Args:
        annotation: any type annotation object.

    Returns:
        ``True`` for tuple annotations carrying a mapping as second element,
        ``False`` otherwise (including tuples with fewer than two arguments).
    """
    if annotation_is_typing_optional(annotation):
        annotation = get_args(annotation)[0]

    if not annotation_is_tuple(annotation):
        return False

    type_args = get_args(annotation)
    # A bare `Tuple` or a one-element `Tuple[X]` has no second argument;
    # indexing it unconditionally would raise IndexError.
    if len(type_args) < 2:
        return False

    return get_origin(type_args[1]) in (dict, Dict, Mapping)


def annotation_for_storage_metadata(annotation) -> bool:
    """Return whether values of type ``annotation`` carry storage metadata.

    The annotation is peeled in two steps - an ``Optional[...]`` wrapper
    first, then a partitions mapping (keeping its value type) - and what is
    left is tested with ``annotation_is_tuple_with_metadata``.
    """
    # step 1: drop the Optional wrapper, if any
    if annotation_is_typing_optional(annotation):
        annotation = get_args(annotation)[0]

    # step 2: for a partitions mapping, look at the per-partition value type
    if annotation_for_multiple_partitions(annotation):
        annotation = get_args(annotation)[1]

    return annotation_is_tuple_with_metadata(annotation)


def _process_env_vars(config: Mapping[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``config`` with ``{"env": NAME}`` values resolved.

    A value is treated as an environment reference only when it is a dict with
    exactly one entry whose key equals ``"env"``; it is then replaced by
    ``EnvVar(NAME).get_value()``. Every other value is copied unchanged.
    """
    resolved = {}
    for option, raw_value in config.items():
        if isinstance(raw_value, dict) and len(raw_value) == 1:
            ((inner_key, inner_value),) = raw_value.items()
            if inner_key == "env":
                resolved[option] = EnvVar(inner_value).get_value()
                continue
        resolved[option] = raw_value
    return resolved


class BasePolarsUPathIOManager(ConfigurableIOManager, UPathIOManager):
    """Base class for `dagster-polars` IOManagers.

    Doesn't define a specific storage format.

    To implement a specific storage format (parquet, csv, etc), inherit from this class and
    implement the `write_df_to_path`, `sink_df_to_path` and `scan_df_from_path` methods.

    Features:
     - All the features of :py:class:`~dagster.UPathIOManager` - works with local and remote
       filesystems (like S3), supports loading multiple partitions with respect to
       :py:class:`~dagster.PartitionMapping`, and more
     - loads the correct type - `polars.DataFrame`, `polars.LazyFrame`, or other types defined in
       :py:mod:`dagster_polars.types` - based on the input type annotation
       (or `dagster.DagsterType`'s `typing_type`)
     - can sink lazy `pl.LazyFrame` DataFrames
     - handles `Nones` with `Optional` types by skipping loading missing inputs or saving
       `None` outputs
     - logs various metadata about the DataFrame - size, schema, sample, stats, ...
     - the `"columns"` input metadata value can be used to select a subset of columns to load
    """

    base_dir: Optional[str] = Field(default=None, description="Base directory for storing files.")
    cloud_storage_options: Optional[Mapping[str, Any]] = Field(
        default=None,
        description="Storage authentication for cloud object store",
        alias="storage_options",
    )
    # resolved in `setup_for_execution` from `base_dir` or the instance storage directory
    _base_path = PrivateAttr()

    def setup_for_execution(self, context: InitResourceContext) -> None:
        """Resolve the base path used for all reads and writes.

        Uses `base_dir` (with `cloud_storage_options` passed as `UPath` keyword
        arguments after `{"env": ...}` values are resolved) when it is set,
        otherwise the storage directory of the Dagster instance.
        """
        from upath import UPath

        sp = (
            _process_env_vars(self.cloud_storage_options)
            if self.cloud_storage_options is not None
            else {}
        )
        self._base_path = (
            UPath(self.base_dir, **sp)
            if self.base_dir is not None
            else UPath(check.not_none(context.instance).storage_directory())
        )

    @abstractmethod
    def write_df_to_path(
        self,
        context: OutputContext,
        df: pl.DataFrame,
        path: "UPath",
        metadata: Optional[StorageMetadata] = None,
    ):
        """Storage-format hook called by `dump_to_path` for eager `pl.DataFrame` outputs."""
        ...

    @abstractmethod
    def sink_df_to_path(
        self,
        context: OutputContext,
        df: pl.LazyFrame,
        path: "UPath",
        metadata: Optional[StorageMetadata] = None,
    ):
        """Storage-format hook called by `dump_to_path` for lazy `pl.LazyFrame` outputs."""
        ...

    @overload
    @abstractmethod
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Literal[None, False]
    ) -> pl.LazyFrame: ...

    @overload
    @abstractmethod
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Literal[True]
    ) -> LazyFrameWithMetadata: ...

    @abstractmethod
    def scan_df_from_path(
        self, path: "UPath", context: InputContext, with_metadata: Optional[bool] = False
    ) -> Union[pl.LazyFrame, LazyFrameWithMetadata]:
        """Storage-format hook called by `load_from_path`.

        Returns a `pl.LazyFrame`, or a `(pl.LazyFrame, metadata)` tuple when
        `with_metadata` is `True`.
        """
        ...

    # tmp fix until https://github.com/dagster-io/dagster/pull/19294 is merged
    def load_input(self, context: InputContext) -> Union[Any, Dict[str, Any]]:
        """Load the input for `context`, dispatching on the number of partitions.

        Returns:
            The single loaded object for non-partitioned inputs and for exactly
            one partition, `None` when there are no partition keys, and the
            result of `_load_multiple_inputs` for several partitions.
        """
        # If no asset key, we are dealing with an op output which is always non-partitioned
        if not context.has_asset_key or not context.has_asset_partitions:
            path = self._get_path(context)
            return self._load_single_input(path, context)

        asset_partition_keys = context.asset_partition_keys
        if len(asset_partition_keys) == 0:
            return None

        if len(asset_partition_keys) == 1:
            paths = self._get_paths_for_partitions(context)
            check.invariant(len(paths) == 1, f"Expected 1 path, but got {len(paths)}")
            path = next(iter(paths.values()))
            backcompat_paths = self._get_multipartition_backcompat_paths(context)
            backcompat_path = (
                None if not backcompat_paths else next(iter(backcompat_paths.values()))
            )

            return self._load_partition_from_path(
                context=context,
                partition_key=asset_partition_keys[0],
                path=path,
                backcompat_path=backcompat_path,
            )

        # we are dealing with multiple partitions of an asset
        type_annotation = context.dagster_type.typing_type
        if type_annotation != Any and not is_dict_type(type_annotation):
            check.failed(
                "Loading an input that corresponds to multiple partitions, but the"
                " type annotation on the op input is not a dict, Dict, Mapping, or"
                f" Any: is '{type_annotation}'."
            )

        return self._load_multiple_inputs(context)

    def dump_to_path(
        self,
        context: OutputContext,
        obj: Union[
            pl.DataFrame,
            Optional[pl.DataFrame],
            Tuple[pl.DataFrame, Dict[str, Any]],
            pl.LazyFrame,
            Optional[pl.LazyFrame],
            Tuple[pl.LazyFrame, Dict[str, Any]],
        ],
        path: "UPath",
        partition_key: Optional[str] = None,
    ):
        """Write `obj` to `path` according to the output's type annotation.

        `None` outputs (or `(None, metadata)` tuples) of `Optional` outputs are
        skipped with a warning. Eager annotations are written with
        `write_df_to_path`, lazy ones with `sink_df_to_path`; for
        `(frame, metadata)` annotations the metadata is forwarded as well.

        Raises:
            NotImplementedError: if the annotation is neither in
                `POLARS_EAGER_FRAME_ANNOTATIONS` nor in
                `POLARS_LAZY_FRAME_ANNOTATIONS`.
        """
        typing_type = context.dagster_type.typing_type

        if annotation_is_typing_optional(typing_type) and (
            obj is None or (annotation_for_storage_metadata(typing_type) and obj[0] is None)
        ):
            context.log.warning(self.get_optional_output_none_log_message(context, path))
            return

        assert obj is not None, "output should not be None if it's type is not Optional"

        if not annotation_for_storage_metadata(typing_type):
            if typing_type in POLARS_EAGER_FRAME_ANNOTATIONS:
                obj = cast(pl.DataFrame, obj)
                df = obj
                self.write_df_to_path(context=context, df=df, path=path)
            elif typing_type in POLARS_LAZY_FRAME_ANNOTATIONS:
                obj = cast(pl.LazyFrame, obj)
                df = obj
                self.sink_df_to_path(context=context, df=df, path=path)
            else:
                raise NotImplementedError(
                    f"dump_df_to_path for {typing_type} is not implemented"
                )
        else:
            # the frame type is the first tuple argument, possibly behind an Optional
            if not annotation_is_typing_optional(typing_type):
                frame_type = get_args(typing_type)[0]
            else:
                frame_type = get_args(get_args(typing_type)[0])[0]

            if frame_type in POLARS_EAGER_FRAME_ANNOTATIONS:
                obj = cast(Tuple[pl.DataFrame, Dict[str, Any]], obj)
                df, metadata = obj
                self.write_df_to_path(context=context, df=df, path=path, metadata=metadata)
            elif frame_type in POLARS_LAZY_FRAME_ANNOTATIONS:
                obj = cast(Tuple[pl.LazyFrame, Dict[str, Any]], obj)
                df, metadata = obj
                self.sink_df_to_path(context=context, df=df, path=path, metadata=metadata)
            else:
                raise NotImplementedError(
                    f"dump_df_to_path for {typing_type} is not implemented"
                )

    def load_from_path(
        self, context: InputContext, path: "UPath", partition_key: Optional[str] = None
    ) -> Union[
        pl.DataFrame,
        pl.LazyFrame,
        Tuple[pl.DataFrame, Dict[str, Any]],
        Tuple[pl.LazyFrame, Dict[str, Any]],
        None,
    ]:
        """Load a frame (optionally with storage metadata) from `path`.

        The frame is scanned lazily, narrowed by the `"columns"` input metadata
        and - for partitioned upstream assets that define a string
        `"partition_by"` output metadata value - filtered to `partition_key`.
        It is collected only when the input annotation is an eager one.

        Returns:
            `None` for a missing `Optional` input, otherwise a `pl.DataFrame` or
            `pl.LazyFrame`, wrapped in a `(frame, metadata)` tuple when the
            annotation asks for storage metadata.

        Raises:
            NotImplementedError: if the annotation is neither in
                `POLARS_EAGER_FRAME_ANNOTATIONS` nor in
                `POLARS_LAZY_FRAME_ANNOTATIONS`.
        """
        if annotation_is_typing_optional(context.dagster_type.typing_type) and not path.exists():
            context.log.warning(self.get_missing_optional_input_log_message(context, path))
            return None

        assert context.metadata is not None

        metadata: Optional[StorageMetadata] = None

        return_storage_metadata = annotation_for_storage_metadata(context.dagster_type.typing_type)

        if not return_storage_metadata:
            ldf = self.scan_df_from_path(path=path, context=context)  # type: ignore
        else:
            ldf, metadata = self.scan_df_from_path(path=path, context=context, with_metadata=True)

        columns = context.metadata.get("columns")
        if columns is not None:
            context.log.debug(f"Loading {columns=}")
            ldf = ldf.select(columns)

        if (
            context.upstream_output is not None
            and context.upstream_output.asset_info is not None
            and context.upstream_output.asset_info.partitions_def is not None
            and context.upstream_output.metadata is not None
            and partition_key is not None
        ):
            partition_by = context.upstream_output.metadata.get("partition_by")

            # we can only support automatically filtering by 1 column
            # otherwise we would have been dealing with a multi-partition key
            # which is not straightforward to filter by
            if partition_by is not None and isinstance(partition_by, str):
                context.log.debug(f"Filtering by {partition_by}=={partition_key}")
                ldf = ldf.filter(pl.col(partition_by) == partition_key)

        if context.dagster_type.typing_type in POLARS_EAGER_FRAME_ANNOTATIONS:
            if not return_storage_metadata:
                return ldf.collect()
            else:
                assert metadata is not None
                return ldf.collect(), metadata

        elif context.dagster_type.typing_type in POLARS_LAZY_FRAME_ANNOTATIONS:
            if not return_storage_metadata:
                return ldf
            else:
                assert metadata is not None
                return ldf, metadata
        else:
            raise NotImplementedError(
                f"Can't load object for type annotation {context.dagster_type.typing_type}"
            )

    def get_metadata(
        self, context: OutputContext, obj: Union[pl.DataFrame, pl.LazyFrame, None]
    ) -> Dict[str, MetadataValue]:
        """Build the output metadata for `obj`.

        Returns `{"missing": True}` when the object (or the frame inside a
        `(frame, metadata)` tuple) is `None`, otherwise the result of
        `get_polars_metadata` for the frame.
        """
        if obj is None:
            return {"missing": MetadataValue.bool(True)}

        # for (frame, metadata) outputs only the frame is described
        if annotation_for_storage_metadata(context.dagster_type.typing_type):
            df = obj[0]
        else:
            df = obj

        return (
            get_polars_metadata(context, df)
            if df is not None
            else {"missing": MetadataValue.bool(True)}
        )

    def get_path_for_partition(
        self, context: Union[InputContext, OutputContext], path: "UPath", partition: str
    ) -> "UPath":
        """Method for accessing the path for a given partition.

        Override this method if you want to use a different partitioning scheme
        (for example, if the saving function handles partitioning instead).
        The extension will be added later.

        :param context:
        :param path: asset path before partitioning
        :param partition: formatted partition key
        :return:
        """
        return path / partition

    def get_missing_optional_input_log_message(self, context: InputContext, path: "UPath") -> str:
        """Return the warning logged when an `Optional` input does not exist at `path`."""
        return f"Optional input {context.name} at {path} doesn't exist in the filesystem and won't be loaded!"

    def get_optional_output_none_log_message(self, context: OutputContext, path: "UPath") -> str:
        """Return the warning logged when an `Optional` output is `None` and is not saved."""
        return f"The object for the optional output {context.name} is None, so it won't be saved to {path}!"

    # this method is overridden because the default one does not pass partition_key to load_from_path
    def _load_partition_from_path(
        self,
        context: InputContext,
        partition_key: str,
        path: "UPath",
        backcompat_path: Optional["UPath"] = None,
    ) -> Any:
        """1. Try to load the partition from the normal path.
        2. If it was not found, try to load it from the backcompat path.
        3. If allow_missing_partitions metadata is True, skip the partition if it was not found in any of the paths.
        Otherwise, raise an error.

        Args:
            context (InputContext): IOManager Input context
            partition_key (str): the partition key corresponding to the partition being loaded
            path (UPath): The path to the partition.
            backcompat_path (Optional[UPath]): The path to the partition in the backcompat location.

        Returns:
            Any: The object loaded from the partition.

        Raises:
            FileNotFoundError: if the partition is found in neither location and
                `allow_missing_partitions` is not enabled.
        """
        allow_missing_partitions = (
            context.metadata.get("allow_missing_partitions", False)
            if context.metadata is not None
            else False
        )

        try:
            context.log.debug(self.get_loading_input_partition_log_message(path, partition_key))
            obj = self.load_from_path(context=context, path=path, partition_key=partition_key)
            return obj
        except FileNotFoundError as primary_error:
            if backcompat_path is not None:
                try:
                    # The fallback has to read the backcompat location; retrying
                    # `path` here would only repeat the lookup that just failed.
                    obj = self.load_from_path(
                        context=context, path=backcompat_path, partition_key=partition_key
                    )
                    context.log.debug(
                        f"File not found at {path}. Loaded instead from backcompat path:"
                        f" {backcompat_path}"
                    )
                    return obj
                except FileNotFoundError as backcompat_error:
                    if allow_missing_partitions:
                        context.log.warning(self.get_missing_partition_log_message(partition_key))
                        return None
                    else:
                        raise backcompat_error

            if allow_missing_partitions:
                context.log.warning(self.get_missing_partition_log_message(partition_key))
                return None
            else:
                raise primary_error