
Commit fa6fca0

Add a couple of integration tests for DSV2 Codahale metrics
1 parent 6ad4d1c commit fa6fca0

File tree

1 file changed: +79 −0
Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
/*
 * Copyright DataStax, Inc.
 *
 * Please see the included license file for details.
 */

package com.datastax.spark.connector.datasource

import com.datastax.spark.connector._
import com.datastax.spark.connector.cluster.DefaultCluster
import com.datastax.spark.connector.cql.CassandraConnector
import org.apache.spark.sql.SparkSession
import org.scalatest.BeforeAndAfterEach

class CassandraCatalogMetricsSpec extends SparkCassandraITFlatSpecBase with DefaultCluster with BeforeAndAfterEach {

  override lazy val conn = CassandraConnector(defaultConf)

  override lazy val spark = SparkSession.builder()
    .config(sparkConf
      // Enable Codahale/Dropwizard metrics on both the driver and the executors
      .set("spark.metrics.conf.executor.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
      .set("spark.metrics.conf.driver.source.cassandra-connector.class", "org.apache.spark.metrics.CassandraConnectorSource")
      // Force the DataSource V2 path and route tables through the Cassandra catalog
      .set("spark.sql.sources.useV1SourceList", "")
      .set("spark.sql.defaultCatalog", "cassandra")
      .set("spark.sql.catalog.cassandra", classOf[CassandraCatalog].getCanonicalName)
    )
    .withExtensions(new CassandraSparkExtensions).getOrCreate().newSession()

  override def beforeClass: Unit = {
    conn.withSessionDo { session =>
      session.execute(s"CREATE KEYSPACE IF NOT EXISTS $ks WITH REPLICATION = { 'class': 'SimpleStrategy', 'replication_factor': 1 }")
      session.execute(s"CREATE TABLE IF NOT EXISTS $ks.leftjoin (key INT, x INT, PRIMARY KEY (key))")
      // Populate the table so reads produce non-trivial meter counts
      for (i <- 1 to 1000 * 10) {
        session.execute(s"INSERT INTO $ks.leftjoin (key, x) values ($i, $i)")
      }
    }
  }

  // Cumulative meter readings seen so far; each test records what it observed
  // so the next test can measure only its own delta
  var readRowCount: Long = 0
  var readByteCount: Long = 0

  it should "update Codahale read metrics for SELECT queries" in {
    val df = spark.sql(s"SELECT x FROM $ks.leftjoin LIMIT 2")

    // Execute the physical plan and, from within each partition, read the
    // task-level Codahale source registered by the connector
    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
      val tc = org.apache.spark.TaskContext.get()
      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
    }

    val metrics = metricsRDD.collect()
    readRowCount = metrics.map(_._1).sum - readRowCount
    readByteCount = metrics.map(_._2).sum - readByteCount

    assert(readRowCount > 0)
    assert(readByteCount == readRowCount * 4) // 4 bytes per INT result
  }

  it should "update Codahale read metrics for COUNT queries" in {
    val df = spark.sql(s"SELECT COUNT(*) FROM $ks.leftjoin")

    val metricsRDD = df.queryExecution.toRdd.mapPartitions { iter =>
      val tc = org.apache.spark.TaskContext.get()
      val source = org.apache.spark.metrics.MetricsUpdater.getSource(tc)
      Iterator((source.get.readRowMeter.getCount, source.get.readByteMeter.getCount))
    }

    // The meters are cumulative, so subtract the totals recorded by the
    // previous test to isolate this query's contribution
    val metrics = metricsRDD.collect()
    readRowCount = metrics.map(_._1).sum - readRowCount
    readByteCount = metrics.map(_._2).sum - readByteCount

    assert(readRowCount > 0)
    assert(readByteCount == readRowCount * 8) // 8 bytes per COUNT result
  }
}
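
For anyone trying this pattern outside the test harness, here is a minimal sketch of sampling the same meters from application code. It assumes a SparkSession built with the spark.metrics.conf.*.source.cassandra-connector.class settings shown above; the table name test.leftjoin is hypothetical.

// Minimal sketch, not part of this commit. Assumes the CassandraConnectorSource
// metrics settings from the test configuration above; `test.leftjoin` is a
// hypothetical table.
val df = spark.sql("SELECT x FROM test.leftjoin")

val meterCounts = df.queryExecution.toRdd.mapPartitions { _ =>
  // The connector registers a per-task Codahale source; getSource returns
  // None when the metrics source was never configured
  val source = org.apache.spark.metrics.MetricsUpdater
    .getSource(org.apache.spark.TaskContext.get())
  Iterator(source.map(s => (s.readRowMeter.getCount, s.readByteMeter.getCount)))
}.collect().flatten

meterCounts.foreach { case (rows, bytes) =>
  println(s"rows read: $rows, bytes read: $bytes")
}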
