diff --git a/.gitignore b/.gitignore index 706ddde..a1c2435 100755 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ sbt/sbt-launch*.jar target/ .idea/ .idea_modules/ +.DS_Store diff --git a/README.md b/README.md index 189036f..6b052bb 100755 --- a/README.md +++ b/README.md @@ -13,36 +13,37 @@ You can link against this library in your program at the following coordiates: ``` groupId: com.databricks -artifactId: spark-csv_2.10 -version: 1.0.0 +artifactId: spark-csv_2.11 +version: 1.0.3 ``` -The spark-csv assembly jar file can also be added to a Spark using the `--jars` command line option. For example, to include it when starting the spark shell: + +## Using with Apache Spark +This package can be added to Spark using the `--packages` command line option. For example, to include it when starting the spark shell: ``` -$ bin/spark-shell --jars spark-csv-assembly-1.0.0.jar +$ bin/spark-shell --packages com.databricks:spark-csv_2.10:1.0.3 ``` ## Features +This package allows reading CSV files in a local or distributed filesystem as [Spark DataFrames](https://spark.apache.org/docs/1.3.0/sql-programming-guide.html). +When reading files, the API accepts several options: +* path: location of files. Similar to Spark, it can accept standard Hadoop globbing expressions. +* header: when set to true, the first line of the files will be used to name columns and will not be included in the data. All types will be assumed string. Default value is false. +* delimiter: by default columns are delimited using ',', but the delimiter can be set to any character +* quote: by default the quote character is '"', but it can be set to any character. Delimiters inside quotes are ignored +* mode: determines the parsing mode. By default it is PERMISSIVE. Possible values are: + * PERMISSIVE: tries to parse all lines: nulls are inserted for missing tokens and extra tokens are ignored. + * DROPMALFORMED: drops lines which have fewer or more tokens than expected + * FAILFAST: aborts with a RuntimeException if it encounters any malformed line + +The package also supports saving simple (non-nested) DataFrames. When saving, you can specify the delimiter and whether a header row should be generated for the table. See the following examples for more details. + These examples use a CSV file available for download [here](https://github.com/databricks/spark-csv/raw/master/src/test/resources/cars.csv): ``` $ wget https://github.com/databricks/spark-csv/raw/master/src/test/resources/cars.csv ``` -### Scala API - -You can use the library by loading the implicits from `com.databricks.spark.csv._`. - -``` -import org.apache.spark.sql.SQLContext - -val sqlContext = new SQLContext(sc) - -import com.databricks.spark.csv._ - -val cars = sqlContext.csvFile("cars.csv") -``` - ### SQL API CSV data can be queried in pure SQL by registering the data as a (temporary) table. @@ -59,18 +60,65 @@ USING com.databricks.spark.csv OPTIONS (path "cars.csv", header "true") ``` +### Scala API +The recommended way to load CSV data is to use the load/save functions in SQLContext. + +```scala +import org.apache.spark.sql.SQLContext + +val sqlContext = new SQLContext(sc) +val df = sqlContext.load("com.databricks.spark.csv", Map("path" -> "cars.csv", "header" -> "true")) +df.select("year", "model").save("newcars.csv", "com.databricks.spark.csv") +``` + +You can also use the implicits from `com.databricks.spark.csv._`.
+ +```scala +import org.apache.spark.sql.SQLContext +import com.databricks.spark.csv._ + +val sqlContext = new SQLContext(sc) + +val cars = sqlContext.csvFile("cars.csv") +cars.select("year", "model").saveAsCsvFile("newcars.tsv") +``` + ### Java API -CSV files can be read using functions in JavaCsvParser. +Similar to Scala, we recommend using the load/save functions in SQLContext. ```java -import com.databricks.spark.csv.JavaCsvParser; +import org.apache.spark.sql.SQLContext; -DataFrame cars = (new JavaCsvParser()).withUseHeader(true).csvFile(sqlContext, "cars.csv"); +SQLContext sqlContext = new SQLContext(sc); + +HashMap<String, String> options = new HashMap<String, String>(); +options.put("header", "true"); +options.put("path", "cars.csv"); + +DataFrame df = sqlContext.load("com.databricks.spark.csv", options); +df.select("year", "model").save("newcars.csv", "com.databricks.spark.csv"); +``` +See the documentation of load and save for more details. + +In Java (as well as Scala), CSV files can be read using functions in CsvParser. + +```java +import com.databricks.spark.csv.CsvParser; +SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc); + +DataFrame cars = (new CsvParser()).withUseHeader(true).csvFile(sqlContext, "cars.csv"); ``` -### Saving as CSV -You can save your DataFrame using `saveAsCsvFile` function. The function allows you to specify the delimiter and whether we should generate a header row for the table (each header has name `C$i` where `$i` is column index). For example: -```myDataFrame.saveAsCsvFile("/mydir", Map("delimiter" -> "|", "header" -> "true"))``` +### Python API +In Python you can read and save CSV files using the load/save functions. + +```python +from pyspark.sql import SQLContext +sqlContext = SQLContext(sc) + +df = sqlContext.load(source="com.databricks.spark.csv", header="true", path="cars.csv") +df.select("year", "model").save("newcars.csv", "com.databricks.spark.csv") +``` ## Building From Source -This library is built with [SBT](http://www.scala-sbt.org/0.13/docs/Command-Line-Reference.html), which is automatically downloaded by the included shell script. To build a JAR file simply run `sbt/sbt assembly` from the project root. +This library is built with [SBT](http://www.scala-sbt.org/0.13/docs/Command-Line-Reference.html), which is automatically downloaded by the included shell script. To build a JAR file, simply run `sbt/sbt package` from the project root. The build configuration includes support for both Scala 2.10 and 2.11. diff --git a/build.sbt b/build.sbt index 6cb34a0..1f511bb 100755 --- a/build.sbt +++ b/build.sbt @@ -1,10 +1,12 @@ name := "spark-csv" -version := "1.0.0" +version := "1.0.3" organization := "com.databricks" -scalaVersion := "2.10.4" +scalaVersion := "2.11.6" + +crossScalaVersions := Seq("2.10.4", "2.11.6") libraryDependencies += "org.apache.spark" %% "spark-sql" % "1.3.0" % "provided" @@ -55,7 +57,6 @@ sparkVersion := "1.3.0" sparkComponents += "sql" -// Enable Junit testing.
-// libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test" - libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.1" % "test" + +libraryDependencies += "com.novocode" % "junit-interface" % "0.9" % "test" diff --git a/src/main/java/com/databricks/spark/csv/JavaCsvParser.java b/src/main/java/com/databricks/spark/csv/JavaCsvParser.java deleted file mode 100755 index e8e9d18..0000000 --- a/src/main/java/com/databricks/spark/csv/JavaCsvParser.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright 2014 Databricks - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.databricks.spark.csv; - -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.types.StructType; - -/** - * A collection of static functions for working with CSV files in Spark SQL - */ -public class JavaCsvParser { - - private Boolean useHeader = true; - private Character delimiter = ','; - private Character quote = '"'; - private StructType schema = null; - - public JavaCsvParser withUseHeader(Boolean flag) { - this.useHeader = flag; - return this; - } - - public JavaCsvParser withDelimiter(Character delimiter) { - this.delimiter = delimiter; - return this; - } - - public JavaCsvParser withQuoteChar(Character quote) { - this.quote = quote; - return this; - } - - public JavaCsvParser withSchema(StructType schema) { - this.schema = schema; - return this; - } - - /** Returns a Schema RDD for the given CSV path. */ - public DataFrame csvFile(SQLContext sqlContext, String path) { - CsvRelation relation = new - CsvRelation(path, useHeader, delimiter, quote, schema, sqlContext); - return sqlContext.baseRelationToDataFrame(relation); - } -} diff --git a/src/main/scala/com/databricks/spark/csv/CsvParser.scala b/src/main/scala/com/databricks/spark/csv/CsvParser.scala index e75f3b3..0699484 100644 --- a/src/main/scala/com/databricks/spark/csv/CsvParser.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvParser.scala @@ -18,15 +18,19 @@ package com.databricks.spark.csv import org.apache.spark.sql.{SQLContext, DataFrame} import org.apache.spark.sql.types.StructType +import com.databricks.spark.csv.util.ParseModes + /** * A collection of static functions for working with CSV files in Spark SQL */ class CsvParser { - private var useHeader: Boolean = true + private var useHeader: Boolean = false private var delimiter: Character = ',' private var quote: Character = '"' + private var escape: Character = '\\' private var schema: StructType = null + private var parseMode: String = ParseModes.DEFAULT def withUseHeader(flag: Boolean): CsvParser = { this.useHeader = flag @@ -48,9 +52,27 @@ class CsvParser { this } + def withParseMode(mode: String): CsvParser = { + this.parseMode = mode + this + } + + def withEscape(escapeChar: Character): CsvParser = { + this.escape = escapeChar + this + } + /** Returns a Schema RDD for the given CSV path. 
*/ + @throws[RuntimeException] def csvFile(sqlContext: SQLContext, path: String): DataFrame = { - val relation: CsvRelation = CsvRelation(path, useHeader, delimiter, quote, schema)(sqlContext) + val relation: CsvRelation = CsvRelation( + path, + useHeader, + delimiter, + quote, + escape, + parseMode, + schema)(sqlContext) sqlContext.baseRelationToDataFrame(relation) } diff --git a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala index 7b324d4..2f3a43f 100755 --- a/src/main/scala/com/databricks/spark/csv/CsvRelation.scala +++ b/src/main/scala/com/databricks/spark/csv/CsvRelation.scala @@ -28,17 +28,28 @@ import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan import org.apache.spark.sql.types.{StructType, StructField, StringType} import org.slf4j.LoggerFactory +import com.databricks.spark.csv.util.ParseModes case class CsvRelation protected[spark] ( location: String, useHeader: Boolean, delimiter: Char, quote: Char, + escape: Char, + parseMode: String, userSchema: StructType = null)(@transient val sqlContext: SQLContext) extends BaseRelation with TableScan with InsertableRelation { private val logger = LoggerFactory.getLogger(CsvRelation.getClass) + // Parse mode flags + if (!ParseModes.isValidMode(parseMode)) { + logger.warn(s"$parseMode is not a valid parse mode. Using ${ParseModes.DEFAULT}.") + } + private val failFast = ParseModes.isFailFastMode(parseMode) + private val dropMalformed = ParseModes.isDropMalformedMode(parseMode) + private val permissive = ParseModes.isPermissiveMode(parseMode) + val schema = inferSchema() // By making this a lazy val we keep the RDD around, amortizing the cost of locating splits. @@ -53,6 +64,7 @@ case class CsvRelation protected[spark] ( val csvFormat = CSVFormat.DEFAULT .withDelimiter(delimiter) .withQuote(quote) + .withEscape(escape) .withSkipHeaderRecord(false) .withHeader(fieldNames: _*) @@ -78,6 +90,7 @@ case class CsvRelation protected[spark] ( val csvFormat = CSVFormat.DEFAULT .withDelimiter(delimiter) .withQuote(quote) + .withEscape(escape) .withSkipHeaderRecord(false) val firstRow = CSVParser.parse(firstLine, csvFormat).getRecords.head.toList val header = if (useHeader) { @@ -115,6 +128,7 @@ case class CsvRelation protected[spark] ( projection: MutableProjection, row: GenericMutableRow): Iterator[Row] = { iter.flatMap { line => + var index: Int = 0 try { val records = CSVParser.parse(line, csvFormat).getRecords if (records.isEmpty) { @@ -122,15 +136,25 @@ case class CsvRelation protected[spark] ( None } else { val tokens = records.head - var index = 0 - while (index < schemaFields.length) { - row(index) = tokens.get(index) - index = index + 1 + index = 0 + if (dropMalformed && schemaFields.length != tokens.size) { + logger.warn(s"Dropping malformed line: $line") + None + } else if (failFast && schemaFields.length != tokens.size) { + throw new RuntimeException(s"Malformed line in FAILFAST mode: $line") + } else { + while (index < schemaFields.length) { + row(index) = tokens.get(index) + index = index + 1 + } + Some(projection(row)) } - Some(projection(row)) } } catch { - case NonFatal(e) => + case aiob: ArrayIndexOutOfBoundsException if permissive => + (index until schemaFields.length).foreach(ind => row(ind) = null) + Some(projection(row)) + case NonFatal(e) if !failFast => logger.error(s"Exception while parsing line: $line. 
", e) None } diff --git a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala index a0c1ea1..0b18d5a 100755 --- a/src/main/scala/com/databricks/spark/csv/DefaultSource.scala +++ b/src/main/scala/com/databricks/spark/csv/DefaultSource.scala @@ -62,7 +62,16 @@ class DefaultSource throw new Exception("Quotation cannot be more than one character.") } - val useHeader = parameters.getOrElse("header", "true") + val escape = parameters.getOrElse("escape", "\\") + val escapeChar = if (escape.length == 1) { + escape.charAt(0) + } else { + throw new Exception("Escape character cannot be more than one character.") + } + + val parseMode = parameters.getOrElse("mode", "PERMISSIVE") + + val useHeader = parameters.getOrElse("header", "false") val headerFlag = if (useHeader == "true") { true } else if (useHeader == "false") { @@ -71,7 +80,13 @@ class DefaultSource throw new Exception("Header flag can be true or false") } - CsvRelation(path, headerFlag, delimiterChar, quoteChar, schema)(sqlContext) + CsvRelation(path, + headerFlag, + delimiterChar, + quoteChar, + escapeChar, + parseMode, + schema)(sqlContext) } override def createRelation( diff --git a/src/main/scala/com/databricks/spark/csv/package.scala b/src/main/scala/com/databricks/spark/csv/package.scala index 3692ec3..58f503e 100755 --- a/src/main/scala/com/databricks/spark/csv/package.scala +++ b/src/main/scala/com/databricks/spark/csv/package.scala @@ -15,7 +15,10 @@ */ package com.databricks.spark -import org.apache.spark.sql.{SQLContext, DataFrame} +import org.apache.commons.csv.CSVFormat +import org.apache.hadoop.io.compress.CompressionCodec + +import org.apache.spark.sql.{SQLContext, DataFrame, Row} package object csv { @@ -23,52 +26,107 @@ package object csv { * Adds a method, `csvFile`, to SQLContext that allows reading CSV data. */ implicit class CsvContext(sqlContext: SQLContext) { - def csvFile(filePath: String) = { + def csvFile(filePath: String, + useHeader: Boolean = true, + delimiter: Char = ',', + quote: Char = '"', + escape: Char = '\\', + mode: String = "PERMISSIVE") = { val csvRelation = CsvRelation( location = filePath, - useHeader = true, - delimiter = ',', - quote = '"')(sqlContext) + useHeader = useHeader, + delimiter = delimiter, + quote = quote, + escape = escape, + parseMode = mode)(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } - def tsvFile(filePath: String) = { + def tsvFile(filePath: String, useHeader: Boolean = true) = { val csvRelation = CsvRelation( location = filePath, - useHeader = true, + useHeader = useHeader, delimiter = '\t', - quote = '"')(sqlContext) + quote = '"', + escape = '\\', + parseMode = "PERMISSIVE")(sqlContext) sqlContext.baseRelationToDataFrame(csvRelation) } } - + implicit class CsvSchemaRDD(dataFrame: DataFrame) { - def saveAsCsvFile(path: String, parameters: Map[String, String] = Map()): Unit = { + + /** + * Saves DataFrame as csv files. By default uses ',' as delimiter, and includes header line. 
+ */ + def saveAsCsvFile(path: String, parameters: Map[String, String] = Map(), + compressionCodec: Class[_ <: CompressionCodec] = null): Unit = { // TODO(hossein): For nested types, we may want to perform special work val delimiter = parameters.getOrElse("delimiter", ",") + val delimiterChar = if (delimiter.length == 1) { + delimiter.charAt(0) + } else { + throw new Exception("Delimiter cannot be more than one character.") + } + + val escape = parameters.getOrElse("escape", "\\") + val escapeChar = if (escape.length == 1) { + escape.charAt(0) + } else { + throw new Exception("Escape character cannot be more than one character.") + } + + val quoteChar = parameters.get("quote") match { + case Some(s) => { + if (s.length == 1) { + Some(s.charAt(0)) + } else { + throw new Exception("Quotation cannot be more than one character.") + } + } + case None => None + } + + val csvFormatBase = CSVFormat.DEFAULT + .withDelimiter(delimiterChar) + .withEscape(escapeChar) + .withSkipHeaderRecord(false) + .withNullString("null") + + val csvFormat = quoteChar match { + case Some(c) => csvFormatBase.withQuote(c) + case _ => csvFormatBase + } + val generateHeader = parameters.getOrElse("header", "false").toBoolean + //Use format instead of mkString val header = if (generateHeader) { - dataFrame.columns.map(c => s""""$c"""").mkString(delimiter) + csvFormat.format(dataFrame.columns.map(_.asInstanceOf[AnyRef]):_*) } else { "" // There is no need to generate header in this case } val strRDD = dataFrame.rdd.mapPartitions { iter => + new Iterator[String] { var firstRow: Boolean = generateHeader override def hasNext = iter.hasNext override def next: String = { + val row = csvFormat.format(iter.next.toSeq.map(_.asInstanceOf[AnyRef]):_*) if (firstRow) { firstRow = false - header + "\n" + iter.next.mkString(delimiter) + header + "\n" + row } else { - iter.next.mkString(delimiter) + row } } } } - strRDD.saveAsTextFile(path) + compressionCodec match { + case null => strRDD.saveAsTextFile(path) + case codec => strRDD.saveAsTextFile(path, codec) + } } } } diff --git a/src/main/scala/com/databricks/spark/csv/util/ParseModes.scala b/src/main/scala/com/databricks/spark/csv/util/ParseModes.scala new file mode 100644 index 0000000..babad29 --- /dev/null +++ b/src/main/scala/com/databricks/spark/csv/util/ParseModes.scala @@ -0,0 +1,40 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.databricks.spark.csv.util + +private[csv] object ParseModes { + + val PERMISSIVE_MODE = "PERMISSIVE" + val DROP_MALFORMED_MODE = "DROPMALFORMED" + val FAIL_FAST_MODE = "FAILFAST" + + val DEFAULT = PERMISSIVE_MODE + + def isValidMode(mode: String): Boolean = { + mode.toUpperCase match { + case PERMISSIVE_MODE | DROP_MALFORMED_MODE | FAIL_FAST_MODE => true + case _ => false + } + } + + def isDropMalformedMode(mode: String) = mode.toUpperCase == DROP_MALFORMED_MODE + def isFailFastMode(mode: String) = mode.toUpperCase == FAIL_FAST_MODE + def isPermissiveMode(mode: String) = if (isValidMode(mode)) { + mode.toUpperCase == PERMISSIVE_MODE + } else { + true // We default to permissive is the mode string is not valid + } +} diff --git a/src/test/java/com/databricks/spark/csv/JavaCsvSuite.java b/src/test/java/com/databricks/spark/csv/JavaCsvSuite.java new file mode 100644 index 0000000..1e26dbf --- /dev/null +++ b/src/test/java/com/databricks/spark/csv/JavaCsvSuite.java @@ -0,0 +1,63 @@ +package com.databricks.spark.csv; + +import java.io.File; +import java.util.HashMap; +import java.util.Random; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext$; + +public class JavaCsvSuite { + private transient SQLContext sqlContext; + private int numCars = 3; + + String carsFile = "src/test/resources/cars.csv"; + + private String tempDir = "target/test/csvData/"; + + @Before + public void setUp() { + // Trigger static initializer of TestData + sqlContext = TestSQLContext$.MODULE$; + } + + @After + public void tearDown() { + sqlContext = null; + } + + @Test + public void testCsvParser() { + DataFrame df = (new CsvParser()).withUseHeader(true).csvFile(sqlContext, carsFile); + int result = df.select("model").collect().length; + Assert.assertEquals(result, numCars); + } + + @Test + public void testLoad() { + HashMap options = new HashMap(); + options.put("header", "true"); + options.put("path", carsFile); + + DataFrame df = sqlContext.load("com.databricks.spark.csv", options); + int result = df.select("year").collect().length; + Assert.assertEquals(result, numCars); + } + + @Test + public void testSave() { + DataFrame df = (new CsvParser()).withUseHeader(true).csvFile(sqlContext, carsFile); + TestUtils.deleteRecursively(new File(tempDir)); + df.select("year", "model").save(tempDir, "com.databricks.spark.csv"); + + DataFrame newDf = (new CsvParser()).csvFile(sqlContext, tempDir); + int result = newDf.select("C1").collect().length; + Assert.assertEquals(result, numCars); + + } +} diff --git a/src/test/resources/cars-alternative.csv b/src/test/resources/cars-alternative.csv index b7f83c8..2c1285a 100644 --- a/src/test/resources/cars-alternative.csv +++ b/src/test/resources/cars-alternative.csv @@ -2,3 +2,4 @@ year|make|model|comment '2012'|'Tesla'|'S'| 'No comment' 1997|Ford|E350|'Go get one now they are going fast' +2015|Chevy|Volt diff --git a/src/test/resources/cars.csv b/src/test/resources/cars.csv index 86512c1..24d5e11 100644 --- a/src/test/resources/cars.csv +++ b/src/test/resources/cars.csv @@ -1,4 +1,5 @@ year,make,model,comment,blank -"2012","Tesla","S", "No comment", +"2012","Tesla","S","No comment", 1997,Ford,E350,"Go get one now they are going fast", +2015,Chevy,Volt \ No newline at end of file diff --git a/src/test/resources/escape.csv b/src/test/resources/escape.csv new file mode 100644 index 0000000..d9ff81a --- /dev/null +++ 
b/src/test/resources/escape.csv @@ -0,0 +1,2 @@ +"column" +|"thing \ No newline at end of file diff --git a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala index ace3a44..88db433 100755 --- a/src/test/scala/com/databricks/spark/csv/CsvSuite.scala +++ b/src/test/scala/com/databricks/spark/csv/CsvSuite.scala @@ -17,7 +17,9 @@ package com.databricks.spark.csv import java.io.File +import org.apache.hadoop.io.compress.GzipCodec import org.apache.spark.sql.test._ +import org.apache.spark.SparkException import org.apache.spark.sql.types._ import org.scalatest.FunSuite @@ -28,15 +30,18 @@ class CsvSuite extends FunSuite { val carsFile = "src/test/resources/cars.csv" val carsAltFile = "src/test/resources/cars-alternative.csv" val emptyFile = "src/test/resources/empty.csv" + val escapeFile = "src/test/resources/escape.csv" val tempEmptyDir = "target/test/empty/" + val numCars = 3 + test("DSL test") { val results = TestSQLContext .csvFile(carsFile) .select("year") .collect() - assert(results.size === 2) + assert(results.size === numCars) } test("DDL test") { @@ -47,18 +52,72 @@ class CsvSuite extends FunSuite { |OPTIONS (path "$carsFile", header "true") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT year FROM carsTable").collect().size === 2) + assert(sql("SELECT year FROM carsTable").collect().size === numCars) + } + + test("DSL test for DROPMALFORMED parsing mode") { + val results = new CsvParser() + .withParseMode("DROPMALFORMED") + .withUseHeader(true) + .csvFile(TestSQLContext, carsFile) + .select("year") + .collect() + + assert(results.size === numCars - 1) } + test("DSL test for FAILFAST parsing mode") { + val parser = new CsvParser() + .withParseMode("FAILFAST") + .withUseHeader(true) + + val exception = intercept[SparkException]{ + parser.csvFile(TestSQLContext, carsFile) + .select("year") + .collect() + } + + assert(exception.getMessage.contains("Malformed line in FAILFAST mode")) + } + + test("DSL test with alternative delimiter and quote") { val results = new CsvParser() .withDelimiter('|') .withQuoteChar('\'') + .withUseHeader(true) .csvFile(TestSQLContext, carsAltFile) .select("year") .collect() - assert(results.size === 2) + assert(results.size === numCars) + } + + test("DSL test with alternative delimiter and quote using sparkContext.csvFile") { + val results = + TestSQLContext.csvFile(carsAltFile, useHeader = true, delimiter = '|', quote = '\'') + .select("year") + .collect() + + assert(results.size === numCars) + } + + test("Expect parsing error with wrong delimiter settting using sparkContext.csvFile") { + intercept[ org.apache.spark.sql.AnalysisException] { + TestSQLContext.csvFile(carsAltFile, useHeader = true, delimiter = ',', quote = '\'') + .select("year") + .collect() + } + } + + test("Expect wrong parsing results with wrong quote setting using sparkContext.csvFile") { + val results = + TestSQLContext.csvFile(carsAltFile, useHeader = true, delimiter = '|', quote = '"') + .select("year") + .collect() + + assert(results.slice(0, numCars).toSeq.map(_(0).asInstanceOf[String]) == + Seq("'2012'", "1997", "2015")) } test("DDL test with alternative delimiter and quote") { @@ -69,7 +128,7 @@ class CsvSuite extends FunSuite { |OPTIONS (path "$carsAltFile", header "true", quote "'", delimiter "|") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT year FROM carsTable").collect().size === 2) + assert(sql("SELECT year FROM carsTable").collect().size === numCars) } @@ -101,11 +160,12 @@ class CsvSuite 
extends FunSuite { |OPTIONS (path "$carsFile", header "true") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT makeName FROM carsTable").collect().size === 2) - assert(sql("SELECT avg(yearMade) FROM carsTable group by grp").collect().head(0) === 2004.5) + assert(sql("SELECT makeName FROM carsTable").collect().size === numCars) + assert(sql("SELECT avg(yearMade) FROM carsTable where grp = '' group by grp") + .collect().head(0) === 2004.5) } - test("column names test") { + test("DSL column names test") { val cars = new CsvParser() .withUseHeader(false) .csvFile(TestSQLContext, carsFile) @@ -130,7 +190,7 @@ class CsvSuite extends FunSuite { |OPTIONS (path "$tempEmptyDir", header "false") """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT * FROM carsTableIO").collect().size === 3) + assert(sql("SELECT * FROM carsTableIO").collect().size === numCars + 1) assert(sql("SELECT * FROM carsTableEmpty").collect().isEmpty) sql( @@ -138,6 +198,82 @@ class CsvSuite extends FunSuite { |INSERT OVERWRITE TABLE carsTableEmpty |SELECT * FROM carsTableIO """.stripMargin.replaceAll("\n", " ")) - assert(sql("SELECT * FROM carsTableEmpty").collect().size == 3) + assert(sql("SELECT * FROM carsTableEmpty").collect().size == numCars + 1) + } + + test("DSL save") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true")) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet== cars.collect.map(_.toString).toSet) + } + + test("DSL save with a compression codec") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true"), classOf[GzipCodec]) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with quoting") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "\"")) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/") + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with alternate quoting") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + "cars-copy.csv" + + val cars = TestSQLContext.csvFile(carsFile) + cars.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "!")) + + val carsCopy = TestSQLContext.csvFile(copyFilePath + "/", quote = '!') + + assert(carsCopy.count == cars.count) + assert(carsCopy.collect.map(_.toString).toSet == cars.collect.map(_.toString).toSet) + } + + test("DSL save with quoting, escaped quote") { + // Create temp directory + TestUtils.deleteRecursively(new File(tempEmptyDir)) + new File(tempEmptyDir).mkdirs() + val copyFilePath = tempEmptyDir + 
"escape-copy.csv" + + val escape = TestSQLContext.csvFile(escapeFile, escape='|', quote='"') + escape.saveAsCsvFile(copyFilePath, Map("header" -> "true", "quote" -> "\"")) + + val escapeCopy = TestSQLContext.csvFile(copyFilePath + "/") + + assert(escapeCopy.count == escape.count) + assert(escapeCopy.collect.map(_.toString).toSet == escape.collect.map(_.toString).toSet) + assert(escapeCopy.head().getString(0) == "\"thing") } } diff --git a/src/test/scala/com/databricks/spark/csv/TestUtils.scala b/src/test/scala/com/databricks/spark/csv/TestUtils.scala index ac78215..0c32f12 100644 --- a/src/test/scala/com/databricks/spark/csv/TestUtils.scala +++ b/src/test/scala/com/databricks/spark/csv/TestUtils.scala @@ -1,3 +1,18 @@ +/* + * Copyright 2014 Databricks + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package com.databricks.spark.csv import java.io.{File, IOException}