Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[Log] -> CometLog,
classOf[Log2] -> CometLog2,
classOf[Log10] -> CometLog10,
classOf[Logarithm] -> CometLogarithm,
classOf[Multiply] -> CometMultiply,
classOf[Pow] -> CometScalarFunction("pow"),
classOf[Rand] -> CometRand,
Expand Down
16 changes: 15 additions & 1 deletion spark/src/main/scala/org/apache/comet/serde/math.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Tan, Unhex}
import org.apache.spark.sql.catalyst.expressions.{Abs, Atan2, Attribute, Ceil, CheckOverflow, Expression, Floor, Hex, If, LessThanOrEqual, Literal, Log, Log10, Log2, Logarithm, Tan, Unhex}
import org.apache.spark.sql.types.{DecimalType, NumericType}

import org.apache.comet.CometSparkSessionExtensions.withInfo
Expand Down Expand Up @@ -138,6 +138,20 @@ object CometLog2 extends CometExpressionSerde[Log2] with MathExprBase {
}
}

object CometLogarithm extends CometExpressionSerde[Logarithm] with MathExprBase {

  /**
   * Serializes Spark's two-argument `Logarithm(base, value)` to the native `log` scalar
   * function.
   *
   * Spark returns null where the mathematical result would be NaN (non-positive base or
   * value), so both children are wrapped with `nullIfNegative` before serialization to
   * preserve that behavior natively.
   */
  override def convert(
      expr: Logarithm,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // expr.left is the base, expr.right is the value being logged.
    val baseProto = exprToProtoInternal(nullIfNegative(expr.left), inputs, binding)
    val valueProto = exprToProtoInternal(nullIfNegative(expr.right), inputs, binding)
    optExprWithInfo(
      scalarFunctionExprToProto("log", baseProto, valueProto),
      expr,
      expr.left,
      expr.right)
  }
}

object CometHex extends CometExpressionSerde[Hex] with MathExprBase {
override def convert(
expr: Hex,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
-- under the License.

-- ConfigMatrix: parquet.enable.dictionary=false,true
-- Config: spark.comet.expression.Tan.allowIncompatible=true

statement
CREATE TABLE test_tan(d double) USING parquet
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ class CometSqlFileTestSuite extends CometTestBase with AdaptiveSparkPlanHelper {
case SparkAnswerOnly =>
checkSparkAnswer(sql)
case WithTolerance(tol) =>
checkSparkAnswerWithTolerance(sql, tol)
checkSparkAnswerAndOperatorWithTolerance(sql, tol)
case ExpectFallback(reason) =>
checkSparkAnswerAndFallbackReason(sql, reason)
case Ignore(reason) =>
Expand Down
11 changes: 11 additions & 0 deletions spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,17 @@ abstract class CometTestBase
internalCheckSparkAnswer(df, assertCometNative = false, withTol = Some(absTol))
}

/**
 * Check that the query returns the correct results when Comet is enabled and that Comet
 * replaced all possible operators, using the provided `absTol` as the absolute tolerance
 * when comparing floating-point results.
 *
 * @param query
 *   the SQL query text to execute
 * @param absTol
 *   absolute tolerance applied to floating-point comparisons (defaults to 1e-6)
 * @return
 *   the pair of plans produced by `internalCheckSparkAnswer` (presumably the Spark and
 *   Comet plans — confirm against its definition)
 */
protected def checkSparkAnswerAndOperatorWithTolerance(
    query: String,
    absTol: Double = 1e-6): (SparkPlan, SparkPlan) = {
  // assertCometNative = true is what distinguishes this from checkSparkAnswerWithTolerance:
  // it additionally requires that Comet replaced all eligible operators.
  internalCheckSparkAnswer(sql(query), assertCometNative = true, withTol = Some(absTol))
}

/**
* Check that the query returns the correct results when Comet is enabled and that Comet
* replaced all possible operators except for those specified in the excluded list.
Expand Down
Loading