Description
import numpy as np

# Map each NumPy dtype under test to the Spark schema expected from
# lit(ndarray): (column name, Spark array type) as reported by .dtypes.
arr_dtype_to_spark_dtypes = [
    ("int8", [("b", "array<smallint>")]),
    ("int16", [("b", "array<smallint>")]),
    ("int32", [("b", "array<int>")]),
    ("int64", [("b", "array<bigint>")]),
    ("float32", [("b", "array<float>")]),
    ("float64", [("b", "array<double>")]),
]

# Supported dtypes: lit(ndarray) should produce a column whose Spark
# dtype matches the expectation above.
for np_dtype, expected_spark_dtypes in arr_dtype_to_spark_dtypes:
    arr = np.array([1, 2]).astype(np_dtype)
    self.assertEqual(
        expected_spark_dtypes,
        self.spark.range(1).select(lit(arr).alias("b")).dtypes,
    )

# Unsupported dtype: an unsigned-int array must raise a TypeError
# mentioning the offending scalar type.
arr = np.array([1, 2]).astype(np.uint)
with self.assertRaisesRegex(
    TypeError, "The type of array scalar '%s' is not supported" % arr.dtype
):
    self.spark.range(1).select(lit(arr).alias("b"))
Traceback (most recent call last): File "/Users/s.singh/personal/spark-oss/python/pyspark/sql/tests/test_functions.py", line 1100, in test_ndarray_input expected_spark_dtypes, self.spark.range(1).select(lit(arr).alias("b")).dtypes File "/Users/s.singh/personal/spark-oss/python/pyspark/sql/utils.py", line 332, in wrapped return getattr(functions, f.__name__)(*args, **kwargs) File "/Users/s.singh/personal/spark-oss/python/pyspark/sql/connect/functions.py", line 198, in lit return Column(LiteralExpression._from_value(col)) File "/Users/s.singh/personal/spark-oss/python/pyspark/sql/connect/expressions.py", line 266, in _from_value return LiteralExpression(value=value, dataType=LiteralExpression._infer_type(value)) File "/Users/s.singh/personal/spark-oss/python/pyspark/sql/connect/expressions.py", line 262, in _infer_type raise ValueError(f"Unsupported Data Type {type(value).__name__}") ValueError: Unsupported Data Type ndarray