在Apache Spark 2.4中引入了一个新的内置数据源, 图像数据源.用户可以通过DataFrame API加载指定目录的中图像文件,生成一个DataFrame对象.通过该DataFrame对象,用户可以对图像数据进行简单的处理,然后使用MLlib进行特定的训练和分类计算. 本文将介绍图像数据源的实现细节和使用方法.
简单使用先通过一个例子来简单的了解下图像数据源使用方法. 本例设定有一组图像文件存放在阿里云的OSS上, 需要对这组图像加水印,并压缩存储到parquet文件中. 废话不说,先上代码:
// 为了突出重点,代码简化图像格式相关的处理逻辑
def main(args: Array[String]): Unit = {
val conf = new SparkConf().setMaster("local[*]")
val spark = SparkSession.builder()
.config(conf)
.getOrCreate()
val imageDF = spark.read.format("image").load("oss:///path/to/src/dir")
imageDF.select("image.origin", "image.width", "image.height", "image.nChannels", "image.mode", "image.data")
.map(row => {
val origin = row.getAs[String]("origin")
val width = row.getAs[Int]("width")
val height = row.getAs[Int]("height")
val mode = row.getAs[Int]("mode")
val nChannels = row.getAs[Int]("nChannels")
val data = row.getAs[Array[Byte]]("data")
Row(Row(origin, height, width, nChannels, mode,
markWithText(width, height, BufferedImage.TYPE_3BYTE_BGR, data, "EMR")))
}).write.format("parquet").save("oss:///path/to/dst/dir")
}
def markWithText(width: Int, height: Int, imageType: Int, data: Array[Byte], text: String): Array[Byte] = {
val image = new BufferedImage(width, height, imageType)
val raster = image.getData.asInstanceOf[WritableRaster]
val pixels = data.map(_.toInt)
raster.setPixels(0, 0, width, height, pixels)
image.setData(raster)
val buffImg = new BufferedImage(width, height, imageType)
val g = buffImg.createGraphics
g.drawImage(image, 0, 0, null)
g.setColor(Color.red)
g.setFont(new Font("宋体", Font.BOLD, 30))
g.drawString(text, width/2, height/2)
g.dispose()
val buffer = new ByteArrayOutputStream
ImageIO.write(buffImg, "JPG", buffer)
buffer.toByteArray
}
从生成的parquet文件中抽取一条图像二进制数据,保存为本地jpg,效果如下:
图1 左图为原始图像,右图为处理后的图像
你可能注意到两个图像到颜色并不相同,这是因为Spark的图像数据将图像解码为BGR顺序的数据,而示例程序在保存的时候,没有处理这个变换,导致颜色出现了反差.
实现初窥下面我们深入到spark源码中来看一下实现细节.Apache Spark内置图像数据源的实现代码在spark-mllib这个模块中.主要包括两个类:
- org.apache.spark.ml.image.ImageSchema
- org.apache.spark.ml.source.image.ImageFileFormat
其中,ImageSchema定义了图像文件加载为DataFrame的Row的格式和解码方法.ImageFileFormat提供了面向存储层的读写接口.
格式定义一个图像文件被加载为DataFrame后,对应的如下:
val columnSchema = StructType(
StructField("origin", StringType, true) ::
StructField("height", IntegerType, false) ::
StructField("width", IntegerType, false) ::
StructField("nChannels", IntegerType, false) ::
// OpenCV-compatible type: CV_8UC3 in most cases
StructField("mode", IntegerType, false) ::
// Bytes in OpenCV-compatible order: row-wise BGR in most cases
StructField("data", BinaryType, false) :: Nil)
val imageFields: Array[String] = columnSchema.fieldNames
val imageSchema = StructType(StructField("image", columnSchema, true) :: Nil)
如果将该DataFrame打印出来,可以得到如下形式的表:
+--------------------+-----------+------------+---------------+----------+-------------------+
|image.origin |image.width|image.height|image.nChannels|image.mode|image.data |
+--------------------+-----------+------------+---------------+----------+-------------------+
|oss://.../dir/1.jpg |600 |343 |3 |16 |55 45 21 56 ... |
+--------------------+-----------+------------+---------------+----------+-------------------+
其中:
- origin: 原始图像文件的路径
- width: 图像的宽度,单位像素
- height: 图像的高度,单位像素
- nChannels: 图像的通道数, 如常见的RGB位图为通道数为3
- mode: 像素矩阵(data)中元素的数值类型和通道顺序, 与OpenCV的类型兼容
- data: 解码后的像素矩阵
提示: 关于图像的基础支持,可以参考如下文档: Image file reading and writing
加载和解码图像文件通过ImageFileFormat加载为一个Row对象.
// 文件: ImageFileFormat.scala
// 为了简化说明起见,代码有删减和改动
private[image] class ImageFileFormat extends FileFormat with DataSourceRegister {
......
override def prepareWrite(
sparkSession: SparkSession,
job: Job,
options: Map[String, String],
dataSchema: StructType): OutputWriterFactory = {
throw new UnsupportedOperationException("Write is not supported for image data source")
}
override protected def buildReader(
sparkSession: SparkSession,
dataSchema: StructType,
partitionSchema: StructType,
requiredSchema: StructType,
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): (PartitionedFile) => Iterator[InternalRow] = {
......
(file: PartitionedFile) => {
......
val path = new Path(origin)
val stream = fs.open(path)
val bytes = ByteStreams.toByteArray(stream)
val resultOpt = ImageSchema.decode(origin, bytes) // converter.toRow(row))
......
}
}
}
}
从上可以看出:
- 当前的图像数据源实现并不支持保存操作;
- 图像数据的解码工作在ImageSchema中完成.
下面来看一下具体的解码过程:
// 文件: ImageSchema.scala
// 为了简化说明起见,代码有删减和改动
private[spark] def decode(origin: String, bytes: Array[Byte]): Option[Row] = {
// 使用ImageIO加载原始图像数据
val img = ImageIO.read(new ByteArrayInputStream(bytes))
if (img != null) {
// 获取图像的基本属性
val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
val hasAlpha = img.getColorModel.hasAlpha
val height = img.getHeight
val width = img.getWidth
// ImageIO::ImageType -> OpenCV Type
val (nChannels, mode) = if (isGray) {
(1, ocvTypes("CV_8UC1"))
} else if (hasAlpha) {
(4, ocvTypes("CV_8UC4"))
} else {
(3, ocvTypes("CV_8UC3"))
}
// 解码
val imageSize = height * width * nChannels
// 用于存储解码后的像素矩阵
val decoded = Array.ofDim[Byte](imageSize)
if (isGray) {
// 处理单通道图像
...
} else {
// 处理多通道图像
var offset = 0
for (h
关注
打赏
最近更新
- 深拷贝和浅拷贝的区别(重点)
- 【Vue】走进Vue框架世界
- 【云服务器】项目部署—搭建网站—vue电商后台管理系统
- 【React介绍】 一文带你深入React
- 【React】React组件实例的三大属性之state,props,refs(你学废了吗)
- 【脚手架VueCLI】从零开始,创建一个VUE项目
- 【React】深入理解React组件生命周期----图文详解(含代码)
- 【React】DOM的Diffing算法是什么?以及DOM中key的作用----经典面试题
- 【React】1_使用React脚手架创建项目步骤--------详解(含项目结构说明)
- 【React】2_如何使用react脚手架写一个简单的页面?