/*
 * Copyright DataStax, Inc.
 *
 * Please see the included license file for details.
 */

package com.datastax.spark.connector.datasource

import com.datastax.spark.connector._
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach

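/**
 * Verifies that the connector's Codahale/Dropwizard read meters are updated when data is
 * read through the DataSource V2 [[CassandraCatalog]].
 */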
class CassandraCatalogMetricsSpec extends SparkCassandraITFlatSpecBase with DefaultCluster with BeforeAndAfterEach {

  override lazy val conn = CassandraConnector(defaultConf)

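  // Session with the Cassandra catalog as the default catalog (V1 sources disabled) and the
  // connector metrics source registered for both the driver and the executors.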
  override lazy val spark = SparkSession.builder()
    .config(sparkConf
      // Enable Codahale/Dropwizard metrics
      .set("spark.metrics.conf.executor.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
      .set("spark.metrics.conf.driver.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
      .set("spark.sql.sources.useV1SourceList", "")
      .set("spark.sql.defaultCatalog", "cassandra")
      .set("spark.sql.catalog.cassandra", classOf[CassandraCatalog].getCanonicalName)
    )
    .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()

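  // Create the test keyspace and populate a single-key table with 10,000 integer rows.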
  override def beforeClass: Unit = {
    conn.withSessionDo { session =>
      session.execute(s"CREATE KEYSPACE IF NOT EXISTS $ks WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }")
      session.execute(s"CREATE TABLE IF NOT EXISTS $ks.leftjoin (key INT, x INT, PRIMARY KEY (key))")
      for (i <- 1 to 1000 * 10) {
        session.execute(s"INSERT INTO $ks.leftjoin (key, x) values ($i, $i)")
      }
    }
  }

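  // The connector's read meters are cumulative, so each test stores the totals it observed
  // and the next test subtracts them to get the delta produced by its own query.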
  var readRowCount: Long = 0
  var readByteCount: Long = 0

  it should "update Codahale read metrics for SELECT queries" in {
    val df = spark.sql(s"SELECT x FROM $ks.leftjoin LIMIT 2")
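    // Look up the connector's metrics source for the running task (via MetricsUpdater) and
    // emit its read meter counts, one (rows, bytes) pair per partition.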
    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
      val tc = org.apache.spark.TaskContext.get()
      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
    }

    val metrics = metricsRDD.collect()
    readRowCount = metrics.map(_._1).sum - readRowCount
    readByteCount = metrics.map(_._2).sum - readByteCount

    assert(readRowCount > 0)
    assert(readByteCount == readRowCount * 4) // 4 bytes per INT result
  }

  it should "update Codahale read metrics for COUNT queries" in {
    val df = spark.sql(s"SELECT COUNT(*) FROM $ks.leftjoin")
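    // Same pattern as above for an aggregate query; each value read back for the count is an
    // 8-byte BIGINT, which the byte-count assertion below relies on.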
    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
      val tc = org.apache.spark.TaskContext.get()
      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
    }

    val metrics = metricsRDD.collect()
    readRowCount = metrics.map(_._1).sum - readRowCount
    readByteCount = metrics.map(_._2).sum - readByteCount

    assert(readRowCount > 0)
    assert(readByteCount == readRowCount * 8) // 8 bytes per COUNT result
  }
}