Geomesa-SparkSQL Source Code: The GeomesaSparkSQL Class

  • 京东城市JUST团队
  • 2021-01-18

1. How Spark SQL Supports Custom External Data Sources

Spark SQL is Spark's module for processing structured data. It provides a programming abstraction called DataFrame and serves as a distributed SQL query engine. In real-world development you sometimes need to expose your own data source through Spark SQL's high-level interfaces, and for that purpose Spark SQL provides an extension API for external data sources.

This section uses the demo from https://www.cnblogs.com/QuestionsZhang/p/10430230.html to introduce the custom external data source mechanism, and then uses it as a lead-in to GeoMesa's Spark SQL extension source code.

1.1 The APIs Involved

BaseRelation defines the schema of the data, a tuple-like structure.

TableScan provides the method that scans the data and produces an RDD[Row].

RelationProvider takes a map of parameters and returns a BaseRelation.

The demo below also mixes in PrunedScan and PrunedFilteredScan, which extend the plain table scan with column pruning and filter push-down respectively.

1.2 Implementation

First, define the relation provider.

```scala
package cn.zj.spark.sql.datasource

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

/**
 * Created by rana on 29/9/16.
 */
class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(p) => new CustomDatasourceRelation(sqlContext, p, schema)
      case _ => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
  }

  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String],
                              data: DataFrame): BaseRelation = {
    val path = parameters.getOrElse("path", "./output/") // can throw an exception/error, it's just for this tutorial
    val fsPath = new Path(path)
    val fs = fsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)

    mode match {
      case SaveMode.Append => sys.error("Append mode is not supported by " + this.getClass.getCanonicalName); sys.exit(1)
      case SaveMode.Overwrite => fs.delete(fsPath, true)
      case SaveMode.ErrorIfExists => sys.error("Given path: " + path + " already exists!!"); sys.exit(1)
      case SaveMode.Ignore => sys.exit()
    }

    val formatName = parameters.getOrElse("format", "customFormat")
    formatName match {
      case "customFormat" => saveAsCustomFormat(data, path, mode)
      case "json" => saveAsJson(data, path, mode)
      case _ => throw new IllegalArgumentException(formatName + " is not supported!!!")
    }
    createRelation(sqlContext, parameters, data.schema)
  }

  private def saveAsJson(data: DataFrame, path: String, mode: SaveMode): Unit = {
    /**
     * Here, I am using the dataframe's Api for storing it as json.
     * you can have your own apis and ways for saving!!
     */
    data.write.mode(mode).json(path)
  }

  private def saveAsCustomFormat(data: DataFrame, path: String, mode: SaveMode): Unit = {
    /**
     * Here, I am going to save this as simple text file which has values separated by "|".
     * But you can have your own way to store without any restriction.
     */
    val customFormatRDD = data.rdd.map(row => {
      row.toSeq.map(value => value.toString).mkString("|")
    })
    customFormatRDD.saveAsTextFile(path)
  }
}
```
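A side note on how Spark finds this class: when a caller writes `format("cn.zj.spark.sql.datasource")`, Spark resolves the format string either as a fully-qualified class name or, failing that, by appending `.DefaultSource` to it, which is why the provider class above must be named `DefaultSource`. A minimal read call against this provider would look like the sketch below (the path is a placeholder, and `spark` is an existing SparkSession):

```scala
// Sketch only: "cn.zj.spark.sql.datasource" is resolved by Spark to the class
// cn.zj.spark.sql.datasource.DefaultSource defined above.
val df = spark.read
  .format("cn.zj.spark.sql.datasource")
  .load("/path/to/data")
```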

Next, define the schema and the data-reading code.

```scala
package cn.zj.spark.sql.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

/**
 * Created by rana on 29/9/16.
 */
class CustomDatasourceRelation(override val sqlContext: SQLContext, path: String, userSchema: StructType)
  extends BaseRelation with TableScan with PrunedScan with PrunedFilteredScan with Serializable {

  override def schema: StructType = {
    if (userSchema != null) {
      userSchema
    } else {
      StructType(
        StructField("id", IntegerType, false) ::
        StructField("name", StringType, true) ::
        StructField("gender", StringType, true) ::
        StructField("salary", LongType, true) ::
        StructField("expenses", LongType, true) :: Nil
      )
    }
  }

  override def buildScan(): RDD[Row] = {
    println("TableScan: buildScan called...")
    val schemaFields = schema.fields
    // Reading the file's content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
    val rows = rdd.map(fileContent => {
      val lines = fileContent.split("\n")
      val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
      val tmp = data.map(words => words.zipWithIndex.map {
        case (value, index) =>
          val colName = schemaFields(index).name
          Util.castTo(if (colName.equalsIgnoreCase("gender")) { if (value.toInt == 1) "Male" else "Female" } else value,
            schemaFields(index).dataType)
      })
      tmp.map(s => Row.fromSeq(s))
    })
    rows.flatMap(e => e)
  }

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    println("PrunedScan: buildScan called...")
    val schemaFields = schema.fields
    // Reading the file's content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
    val rows = rdd.map(fileContent => {
      val lines = fileContent.split("\n")
      val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
      val tmp = data.map(words => words.zipWithIndex.map {
        case (value, index) =>
          val colName = schemaFields(index).name
          val castedValue = Util.castTo(if (colName.equalsIgnoreCase("gender")) { if (value.toInt == 1) "Male" else "Female" } else value,
            schemaFields(index).dataType)
          if (requiredColumns.contains(colName)) Some(castedValue) else None
      })
      tmp.map(s => Row.fromSeq(s.filter(_.isDefined).map(value => value.get)))
    })
    rows.flatMap(e => e)
  }

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    println("PrunedFilteredScan: buildScan called...")
    println("Filters: ")
    filters.foreach(f => println(f.toString))

    var customFilters: Map[String, List[CustomFilter]] = Map[String, List[CustomFilter]]()
    filters.foreach(f => f match {
      case EqualTo(attr, value) =>
        println("EqualTo filter is used!! " + "Attribute: " + attr + " Value: " + value)
        /**
         * As we are implementing only one filter for now, keeping a list per attribute may not seem to make
         * much sense, because an attribute can only be equal to one value at a time. But it becomes useful
         * when there is more than one filter on the same attribute, for example:
         *   attr > 5 && attr < 10
         * For such cases it's better to keep a list. You can add more filters to this code and try them;
         * here we implement only the equalTo filter to illustrate the concept.
         */
        customFilters = customFilters ++ Map(attr -> {
          customFilters.getOrElse(attr, List[CustomFilter]()) :+ new CustomFilter(attr, value, "equalTo")
        })
      case _ => println("filter: " + f.toString + " is not implemented by us!!")
    })

    val schemaFields = schema.fields
    // Reading the file's content
    val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
    val rows = rdd.map(file => {
      val lines = file.split("\n")
      val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
      val filteredData = data.map(s => if (customFilters.nonEmpty) {
        var includeInResultSet = true
        s.zipWithIndex.foreach {
          case (value, index) =>
            val attr = schemaFields(index).name
            val filtersList = customFilters.getOrElse(attr, List())
            if (filtersList.nonEmpty) {
              if (CustomFilter.applyFilters(filtersList, value, schema)) {
              } else {
                includeInResultSet = false
              }
            }
        }
        if (includeInResultSet) s else Seq()
      } else s)

      val tmp = filteredData.filter(_.nonEmpty).map(s => s.zipWithIndex.map {
        case (value, index) =>
          val colName = schemaFields(index).name
          val castedValue = Util.castTo(if (colName.equalsIgnoreCase("gender")) {
            if (value.toInt == 1) "Male" else "Female"
          } else value,
            schemaFields(index).dataType)
          if (requiredColumns.contains(colName)) Some(castedValue) else None
      })
      tmp.map(s => Row.fromSeq(s.filter(_.isDefined).map(value => value.get)))
    })
    rows.flatMap(e => e)
  }
}
```
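Note that the listing above references a `CustomFilter` class and a `CustomFilter.applyFilters` helper that the original post does not include, so the code will not compile as shown. A minimal sketch of what they might look like, inferred purely from how they are used here (only the equalTo case is handled, and the names are hypothetical reconstructions):

```scala
package cn.zj.spark.sql.datasource

import org.apache.spark.sql.types.StructType

// Hypothetical helper, reconstructed from its usage in CustomDatasourceRelation:
// it records one pushed-down predicate on a single attribute.
case class CustomFilter(attr: String, value: Any, filterType: String)

object CustomFilter {
  // Returns true if the raw string value satisfies every filter registered for its attribute.
  def applyFilters(filters: List[CustomFilter], value: String, schema: StructType): Boolean = {
    filters.forall { filter =>
      filter.filterType match {
        // Compare after casting to the column's declared type, so "50000" matches 50000L.
        case "equalTo" =>
          val dataType = schema.fields.find(_.name == filter.attr).map(_.dataType)
          dataType.exists(dt => Util.castTo(value, dt) == Util.castTo(filter.value.toString, dt))
        case _ => true // unknown filter types are ignored (kept permissive)
      }
    }
  }
}
```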

The type-conversion helper:

```scala
package cn.zj.spark.sql.datasource

import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType}

/**
 * Created by rana on 30/9/16.
 */
object Util {
  def castTo(value: String, dataType: DataType) = {
    dataType match {
      case _: IntegerType => value.toInt
      case _: LongType => value.toLong
      case _: StringType => value
    }
  }
}
```
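Note that this match only covers the three types used in the demo schema; any other Spark type would throw a MatchError. If you extend the schema, the helper has to grow with it; a sketch of a slightly more defensive version (not part of the original post) might look like this:

```scala
import org.apache.spark.sql.types._

// Sketch: covers a few extra Spark types and fails loudly on anything else.
object UtilExtended {
  def castTo(value: String, dataType: DataType): Any = dataType match {
    case _: IntegerType => value.toInt
    case _: LongType    => value.toLong
    case _: DoubleType  => value.toDouble
    case _: BooleanType => value.toBoolean
    case _: StringType  => value
    case other          => throw new IllegalArgumentException(s"Unsupported type: $other")
  }
}
```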

1.3 Maven Dependencies (pom.xml)

```xml
<properties>
    <maven.compiler.source>1.8</maven.compiler.source>
    <maven.compiler.target>1.8</maven.compiler.target>
    <scala.version>2.11.8</scala.version>
    <spark.version>2.2.0</spark.version>
    <!--<hadoop.version>2.6.0-cdh5.7.0</hadoop.version>-->
    <!--<hbase.version>1.2.0-cdh5.7.0</hbase.version>-->
    <encoding>UTF-8</encoding>
</properties>

<dependencies>
    <!-- Spark core dependency -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <!-- Spark SQL dependency -->
    <!-- https://mvnrepository.com/artifact/org.apache.spark/spark-sql -->
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>2.2.0</version>
    </dependency>
</dependencies>
```

 

1.4 Test Code and Sample Data

```scala
package cn.zj.spark.sql.datasource

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

/**
 * Created by rana on 29/9/16.
 */
object app extends App {
  println("Application started...")

  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local").getOrCreate()

  val df = spark.sqlContext.read.format("cn.zj.spark.sql.datasource").load("1229practice/data/")
  df.createOrReplaceTempView("test")
  spark.sql("select * from test where salary = 50000").show()

  println("Application Ended...")
}
```

```
10002, Alice Heady, 0, 20000, 8000
10003, Jenny Brown, 0, 30000, 120000
10004, Bob Hayden, 1, 40000, 16000
10005, Cindy Heady, 0, 50000, 20000
10006, Doug Brown, 1, 60000, 24000
10007, Carolina Hayden, 0, 70000, 280000
```
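With the sample file above saved under `1229practice/data/`, the query filters on salary = 50000, so only one row survives; the relation maps gender 0 to "Female", and the output of `show()` should look roughly like this:

```
+-----+-----------+------+------+--------+
|   id|       name|gender|salary|expenses|
+-----+-----------+------+------+--------+
|10005|Cindy Heady|Female| 50000|   20000|
+-----+-----------+------+------+--------+
```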

2. Overall Architecture

Building on Spark SQL and its mechanism for registering optimizer rules, GeoMesa exposes its existing capabilities as an external data source and registers them in the Spark SQL session. The main implementation lives in the GeomesaSparkSQL class, and the basic flow is as follows.

3. Components and Their Main Parameters

| Name | Function | Parameter | Description |
|------|----------|-----------|-------------|
| createRelation | Creates a Relation object and hands it to Spark SQL as the relation layer | parameters | geomesa.feature (sftName); hbase.zookeepers (hosts/IPs of the ZooKeeper ensemble); hbase.catalog (catalogName) |
| | | schema | All fields of the table (no projection applied) |
| GeoMesaRelation | Builds the internal state of the Relation object | same as above | same as above |
| buildScan | Builds the structure that actually queries the data | requiredColumns | The fields to project, i.e. the fields that really need to be returned to the layer above |
| | | filters | Spark SQL's native filter objects |
| | | filt | The OpenGIS filter object |
| getExtractors | Performs the final projection | requiredColumns | same as above |
| | | schema | All fields |
| sparkFilterToCQLFilter | Converts Spark filter objects into CQL filter objects | filters | Spark SQL's native filter objects |

4. Source Code Analysis

4.1 createRelation

In this method, DataStoreFinder is used to obtain the corresponding GeoMesa DataStore, e.g. an HBaseDataStore. The method takes two arguments: sqlContext, which is Spark SQL's own context object, and parameters, which carries the GeoMesa-related configuration, such as the catalog, the sft name, and the ZooKeeper hosts and ports required by geomesa-hbase.

```scala
override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
  SQLTypes.init(sqlContext)
  // TODO: Need different ways to retrieve sft
  //  GEOMESA-1643 Add method to lookup SFT to RDD Provider
  //  Below the details of the Converter RDD Provider and Providers which are backed by GT DSes are leaking through
  val ds = DataStoreFinder.getDataStore(parameters)
  val sft = if (ds != null) {
    try { ds.getSchema(parameters(GEOMESA_SQL_FEATURE)) } finally {
      ds.dispose()
    }
  } else {
    if (parameters.contains(GEOMESA_SQL_FEATURE) && parameters.contains("geomesa.sft")) {
      SimpleFeatureTypes.createType(parameters(GEOMESA_SQL_FEATURE), parameters("geomesa.sft"))
    } else {
      SftArgResolver.getArg(SftArgs(parameters(GEOMESA_SQL_FEATURE), parameters(GEOMESA_SQL_FEATURE))) match {
        case Right(s) => s
        case Left(e) => throw new IllegalArgumentException("Could not resolve simple feature type", e)
      }
    }
  }
  logger.trace(s"Creating GeoMesa Relation with sft : $sft")
  val schema = sft2StructType(sft)
  GeoMesaRelation(sqlContext, sft, schema, parameters)
}
```

The first line, SQLTypes.init, registers many of GeoMesa's own types and UDFs with the SQL context. Next comes the DataStore initialization: the code checks the supplied parameters and, if no DataStore can be obtained, falls back to resolving the SimpleFeatureType (sft) from the other parameters. Finally it builds the full schema from the sft and creates the GeoMesaRelation.
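These parameters typically come from the options of the caller's read request. A hedged usage sketch for geomesa-hbase (all values are placeholders, and `spark` is an existing SparkSession): `dsParams` carries the HBase DataStore parameters, while `geomesa.feature` selects the SimpleFeatureType to expose as a table.

```scala
// Sketch only: parameter values are placeholders for your own environment.
val dsParams = Map(
  "hbase.zookeepers" -> "zk1:2181,zk2:2181,zk3:2181", // ZooKeeper ensemble
  "hbase.catalog"    -> "mycatalog"                   // GeoMesa catalog table
)

val df = spark.read
  .format("geomesa")                  // resolved to GeoMesa's Spark SQL data source
  .options(dsParams)
  .option("geomesa.feature", "mysft") // the SimpleFeatureType to expose as a table
  .load()

df.createOrReplaceTempView("mysft")
```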

4.2 GeoMesaRelation

This inner case class takes many query parameters. The most important one is filt, the OpenGIS Filter object that has been optimized by the layer above and pushed down here; it is initialized to INCLUDE, i.e. an unconditional full scan.

```scala
case class GeoMesaRelation(sqlContext: SQLContext,
                           sft: SimpleFeatureType,
                           schema: StructType,
                           params: Map[String, String],
                           filt: org.opengis.filter.Filter = org.opengis.filter.Filter.INCLUDE,
                           props: Option[Seq[String]] = None,
                           var partitionHints: Seq[Int] = null,
                           var indexRDD: RDD[GeoCQEngineDataStore] = null,
                           var partitionedRDD: RDD[(Int, Iterable[SimpleFeature])] = null,
                           var indexPartRDD: RDD[(Int, GeoCQEngineDataStore)] = null)
    extends BaseRelation with PrunedFilteredScan {

  val cache: Boolean = Try(params("cache").toBoolean).getOrElse(false)
  val indexId: Boolean = Try(params("indexId").toBoolean).getOrElse(false)
  val indexGeom: Boolean = Try(params("indexGeom").toBoolean).getOrElse(false)
  val numPartitions: Int = Try(params("partitions").toInt).getOrElse(sqlContext.sparkContext.defaultParallelism)
  val spatiallyPartition: Boolean = Try(params("spatial").toBoolean).getOrElse(false)
  val partitionStrategy: String = Try(params("strategy").toString).getOrElse("EQUAL")
  var partitionEnvelopes: List[Envelope] = null
  val providedBounds: String = Try(params("bounds").toString).getOrElse(null)
  val coverPartition: Boolean = Try(params("cover").toBoolean).getOrElse(false)
  // Control partitioning strategies that require a sample of the data
  val sampleSize: Int = Try(params("sampleSize").toInt).getOrElse(100)
  val thresholdMultiplier: Double = Try(params("threshold").toDouble).getOrElse(0.3)

  val initialQuery: String = Try(params("query").toString).getOrElse("INCLUDE")
  val geometryOrdinal: Int = sft.indexOf(sft.getGeometryDescriptor.getLocalName)

  // ... (remainder of the class omitted)
```
In addition, when this class is initialized it extracts the options carried in params: whether to cache, the index settings, whether to partition spatially, the partitioning strategy, and so on. A particularly important value is geometryOrdinal: at present GeoMesa Spark SQL does not support querying a table that has no spatial column. Such a query fails with an error like "Could not find SpatialRDDProvider", precisely because there is no geometry descriptor here and geometryOrdinal cannot be resolved.
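All of these knobs are plain string options on the read request. For example (a sketch reusing the hypothetical `dsParams` from the earlier snippet), enabling spatial partitioning with eight partitions could look like this:

```scala
// Sketch only: option names come from the params(...) lookups shown above.
val partitionedDf = spark.read
  .format("geomesa")
  .options(dsParams)                  // hypothetical DataStore parameters from the earlier sketch
  .option("geomesa.feature", "mysft")
  .option("spatial", "true")          // spatiallyPartition
  .option("partitions", "8")          // numPartitions
  .option("strategy", "EQUAL")        // partitionStrategy
  .load()
```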

4.3 buildScan

This is the method in GeoMesaRelation that builds the actual scan. By this point Spark SQL has already pushed down the required columns, its own source filters, and the OpenGIS filter. The method first converts the native Spark SQL filters into CQL filters and joins them all together with and.

It then removes the __fid__ field from the required columns, runs the query through the GeoMesaSpark RDD provider, and finally extracts the requested columns from the query results and returns them to the Spark SQL optimizer above.

```scala
def buildScan(requiredColumns: Array[String],
              filters: Array[org.apache.spark.sql.sources.Filter],
              filt: org.opengis.filter.Filter,
              ctx: SparkContext,
              schema: StructType,
              params: Map[String, String]): RDD[Row] = {
  logger.debug(
    s"""Building scan, filt = $filt,
       |filters = ${filters.mkString(",")},
       |requiredColumns = ${requiredColumns.mkString(",")}""".stripMargin)
  val compiledCQL = filters.flatMap(SparkUtils.sparkFilterToCQLFilter).foldLeft[org.opengis.filter.Filter](filt) { (l, r) => ff.and(l, r) }
  val requiredAttributes = requiredColumns.filterNot(_ == "__fid__")
  val rdd = GeoMesaSpark(params).rdd(
    new Configuration(ctx.hadoopConfiguration), ctx, params,
    new Query(params(GEOMESA_SQL_FEATURE), compiledCQL, requiredAttributes))

  val extractors = SparkUtils.getExtractors(requiredColumns, schema)
  val result = rdd.map(SparkUtils.sf2row(schema, _, extractors))
  result.asInstanceOf[RDD[Row]]
}
```
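To make the filter combination concrete, here is a small hedged illustration (attribute names are made up) of what the foldLeft produces for two pushed-down Spark filters, using only the GeoTools filter factory that the converted filters are built with:

```scala
import org.geotools.factory.CommonFactoryFinder // in newer GeoTools versions: org.geotools.util.factory.CommonFactoryFinder
import org.opengis.filter.Filter

val ff = CommonFactoryFinder.getFilterFactory2

// Suppose the pushed-down Spark filters were GreaterThan("salary", 1000) and EqualTo("name", "bob").
// sparkFilterToCQLFilter would convert them to the two GeoTools filters below:
val converted: Seq[Filter] = Seq(
  ff.greater(ff.property("salary"), ff.literal(1000)),
  ff.equals(ff.property("name"), ff.literal("bob"))
)

// Starting from INCLUDE (no condition) and AND-ing each converted filter, as buildScan does:
val compiledCQL = converted.foldLeft[Filter](Filter.INCLUDE)((l, r) => ff.and(l, r))
// compiledCQL is roughly: INCLUDE AND salary > 1000 AND name = 'bob'
```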

4.4 getExtractors

This method builds, for each projected column, a function that extracts the corresponding value from a queried SimpleFeature. There is a hidden risk here: when mapping over requiredColumns, each value is fetched from the feature by positional index. As a data-structure lookup this is O(1) and therefore efficient, but because of the special __fid__ column, the index computed against requiredAttributes can, in some situations, end up pointing at a mismatched attribute and cause a type mismatch.

```scala
def getExtractors(requiredColumns: Array[String], schema: StructType): Array[SimpleFeature => AnyRef] = {
  val requiredAttributes = requiredColumns.filterNot(_ == "__fid__")
  type EXTRACTOR = SimpleFeature => AnyRef
  val IdExtractor: SimpleFeature => AnyRef = sf => sf.getID

  requiredColumns.map {
    case "__fid__" => IdExtractor
    case col =>
      val index = requiredAttributes.indexOf(col)
      val schemaIndex = schema.fieldIndex(col)
      val fieldType = schema.fields(schemaIndex).dataType
      if (fieldType == TimestampType) {
        sf: SimpleFeature => {
          val attr = sf.getAttribute(index)
          if (attr == null) { null } else {
            new Timestamp(attr.asInstanceOf[Date].getTime)
          }
        }
      } else {
        sf: SimpleFeature => sf.getAttribute(index)
      }
  }
}
```
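To see why the positional lookup works, and where the __fid__ special case matters, here is a small hedged illustration of how the indices line up with the Query built in buildScan (the column names are hypothetical):

```scala
// Hypothetical column list for a query like: select __fid__, name, salary from mysft
val requiredColumns    = Array("__fid__", "name", "salary")
val requiredAttributes = requiredColumns.filterNot(_ == "__fid__") // Array("name", "salary")

// The Query in buildScan is constructed with requiredAttributes, so each returned
// SimpleFeature carries exactly those attributes, in that order:
//   extractor("__fid__") -> sf.getID
//   extractor("name")    -> sf.getAttribute(requiredAttributes.indexOf("name"))   == sf.getAttribute(0)
//   extractor("salary")  -> sf.getAttribute(requiredAttributes.indexOf("salary")) == sf.getAttribute(1)
// The alignment only holds because both sides use the same requiredAttributes ordering;
// if they ever diverged, an extractor could read the wrong attribute (the risk noted above).
```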

4.5 sparkFilterToCQLFilter

Finally, sparkFilterToCQLFilter is the utility that converts Spark SQL's own Filter objects into OpenGIS Filter objects.

```scala
def sparkFilterToCQLFilter(filt: org.apache.spark.sql.sources.Filter): Option[org.opengis.filter.Filter] = filt match {
  case GreaterThanOrEqual(attribute, v)      => Some(ff.greaterOrEqual(ff.property(attribute), ff.literal(v)))
  case GreaterThan(attr, v)                  => Some(ff.greater(ff.property(attr), ff.literal(v)))
  case LessThanOrEqual(attr, v)              => Some(ff.lessOrEqual(ff.property(attr), ff.literal(v)))
  case LessThan(attr, v)                     => Some(ff.less(ff.property(attr), ff.literal(v)))
  case EqualTo(attr, v) if attr == "__fid__" => Some(ff.id(ff.featureId(v.toString)))
  case EqualTo(attr, v)                      => Some(ff.equals(ff.property(attr), ff.literal(v)))
  case In(attr, values) if attr == "__fid__" => Some(ff.id(values.map(v => ff.featureId(v.toString)).toSet))
  case In(attr, values) =>
    Some(values.map(v => ff.equals(ff.property(attr), ff.literal(v))).reduce[org.opengis.filter.Filter]((l, r) => ff.or(l, r)))
  case And(left, right) => Some(ff.and(sparkFilterToCQLFilter(left).get, sparkFilterToCQLFilter(right).get)) // TODO: can these be null
  case Or(left, right)  => Some(ff.or(sparkFilterToCQLFilter(left).get, sparkFilterToCQLFilter(right).get))
  case Not(f)           => Some(ff.not(sparkFilterToCQLFilter(f).get))
  case StringStartsWith(a, v) => Some(ff.like(ff.property(a), s"$v%"))
  case StringEndsWith(a, v)   => Some(ff.like(ff.property(a), s"%$v"))
  case StringContains(a, v)   => Some(ff.like(ff.property(a), s"%$v%"))
  case IsNull(attr)    => None
  case IsNotNull(attr) => None
}
```
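A quick hedged illustration of the resulting CQL for a few common cases, using only the GeoTools factory calls from the match above (the ECQL and CommonFactoryFinder import paths may vary by GeoTools version):

```scala
import org.geotools.factory.CommonFactoryFinder
import org.geotools.filter.text.ecql.ECQL

val ff = CommonFactoryFinder.getFilterFactory2

// EqualTo("name", "bob")        -> name = 'bob'
println(ECQL.toCQL(ff.equals(ff.property("name"), ff.literal("bob"))))
// GreaterThan("salary", 1000)   -> salary > 1000
println(ECQL.toCQL(ff.greater(ff.property("salary"), ff.literal(1000))))
// StringStartsWith("name", "b") -> name LIKE 'b%'
println(ECQL.toCQL(ff.like(ff.property("name"), "b%")))
// EqualTo("__fid__", "id-1")    -> an ID (feature-id) filter rather than an attribute comparison
println(ECQL.toCQL(ff.id(ff.featureId("id-1"))))
```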

