diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala index 72c04022..aa79dcad 100644 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala @@ -19,9 +19,10 @@ package za.co.absa.cobrix.spark.cobol.utils import com.fasterxml.jackson.databind.ObjectMapper import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext +import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, SparkSession} +import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession} import za.co.absa.cobrix.cobol.internal.Logging import za.co.absa.cobrix.spark.cobol.parameters.MetadataFields.MAX_ELEMENTS import za.co.absa.cobrix.spark.cobol.utils.impl.HofsWrapper.transform @@ -550,5 +551,48 @@ object SparkUtils extends Logging { fields.toList } + /** + * A UDF that receives the entire record as a [[Row]] and returns a + * human-readable string representation of its contents. 
+ * + * Usage (after columns are combined into a struct): + * {{{ + * df.withColumn("record_dump", printRowUdf(struct(df.columns.map(col): _*))) + * }}} + */ + val printRowUdf: UserDefinedFunction = udf { row: Row => + def rowToString(r: Row): String = { + val schema = r.schema + val fields = schema.fields.zipWithIndex.map { case (field, idx) => + val value = if (r.isNullAt(idx)) { + "null" + } else { + r.get(idx) match { + case nestedRow: Row => + s"{${rowToString(nestedRow)}}" + case seq: Seq[_] => + val items = seq.map { + case nestedRow: Row => s"{${rowToString(nestedRow)}}" + case other => String.valueOf(other) + } + s"[${items.mkString(", ")}]" + case other => + String.valueOf(other) + } + } + s"${field.name}=$value" + } + fields.mkString(", ") + } + + if (row == null) { + null + } else { + val result = rowToString(row) + // Debug output (currently disabled): uncomment the println to dump each record during tests + //println(s"[printRowUdf] $result") + result + } + } } diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/BasicRecordCombiner.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/BasicRecordCombiner.scala deleted file mode 100644 index fdc8554a..00000000 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/BasicRecordCombiner.scala +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Copyright 2018 ABSA Group Limited - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package za.co.absa.cobrix.spark.cobol.writer - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import za.co.absa.cobrix.cobol.parser.Copybook -import za.co.absa.cobrix.cobol.parser.ast.datatype.{Decimal, Integral} -import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive, Statement} -import za.co.absa.cobrix.cobol.parser.recordformats.RecordFormat -import za.co.absa.cobrix.cobol.reader.parameters.ReaderParameters -import za.co.absa.cobrix.cobol.reader.schema.CobolSchema - -class BasicRecordCombiner extends RecordCombiner { - - import BasicRecordCombiner._ - - override def combine(df: DataFrame, cobolSchema: CobolSchema, readerParameters: ReaderParameters): RDD[Array[Byte]] = { - val ast = getAst(cobolSchema) - val copybookFields = ast.children.filter { - case f if f.redefines.nonEmpty => false - case p: Primitive => !p.isFiller - case g: Group => !g.isFiller - case _ => true - } - - validateSchema(df, copybookFields.toSeq) - - val cobolFields = copybookFields.map(_.asInstanceOf[Primitive]) - val sparkFields = df.schema.fields.map(_.name.toLowerCase) - - cobolFields.foreach(cobolField => - if (cobolField.encode.isEmpty) { - val fieldDefinition = getFieldDefinition(cobolField) - throw new IllegalArgumentException(s"Field '${cobolField.name}' does not have an encoding defined in the copybook. 
" + - s"'PIC $fieldDefinition' is not yet supported.") - } - ) - - val sparkFieldPositions = cobolFields.zipWithIndex.map { case (cobolField, idx) => - val fieldName = cobolField.name.toLowerCase - val position = sparkFields.indexOf(fieldName) - - if (position < 0) { - throw new IllegalArgumentException(s"Field '${cobolField.name}' from the copybook is not found in the DataFrame schema.") - } - - (idx, position) - } - - val hasRdw = readerParameters.recordFormat == RecordFormat.VariableLength - val isRdwBigEndian = readerParameters.isRdwBigEndian - val adjustment1 = if (readerParameters.isRdwPartRecLength) 4 else 0 - val adjustment2 = readerParameters.rdwAdjustment - - val size = if (hasRdw) { - cobolSchema.getRecordSize + 4 - } else { - cobolSchema.getRecordSize - } - - val startOffset = if (hasRdw) 4 else 0 - - val recordLengthLong = cobolSchema.getRecordSize.toLong + adjustment1.toLong + adjustment2.toLong - if (recordLengthLong < 0) { - throw new IllegalArgumentException( - s"Invalid RDW length $recordLengthLong. Check 'is_rdw_part_of_record_length' and 'rdw_adjustment'." - ) - } - if (isRdwBigEndian && recordLengthLong > 0xFFFFL) { - throw new IllegalArgumentException( - s"RDW length $recordLengthLong exceeds 65535 and cannot be encoded in big-endian mode." - ) - } - if (!isRdwBigEndian && recordLengthLong > Int.MaxValue.toLong) { - throw new IllegalArgumentException( - s"RDW length $recordLengthLong exceeds ${Int.MaxValue} and cannot be encoded safely." - ) - } - val recordLength = recordLengthLong.toInt - - df.rdd.map { row => - val ar = new Array[Byte](size) - - if (hasRdw) { - if (isRdwBigEndian) { - ar(0) = ((recordLength >> 8) & 0xFF).toByte - ar(1) = (recordLength & 0xFF).toByte - // The last two bytes are reserved and defined by IBM as binary zeros on all platforms. - ar(2) = 0 - ar(3) = 0 - } else { - ar(0) = (recordLength & 0xFF).toByte - ar(1) = ((recordLength >> 8) & 0xFF).toByte - // This is non-standard. But so are little-endian RDW headers. 
- // As an advantage, it has no effect for small records but adds support for big records (> 64KB). - ar(2) = ((recordLength >> 16) & 0xFF).toByte - ar(3) = ((recordLength >> 24) & 0xFF).toByte - } - } - - sparkFieldPositions.foreach { case (cobolIdx, sparkIdx) => - if (!row.isNullAt(sparkIdx)) { - val fieldStr = row.get(sparkIdx) - val cobolField = cobolFields(cobolIdx) - Copybook.setPrimitiveField(cobolField, ar, fieldStr, startOffset) - } - } - - ar - } - } - - private def validateSchema(df: DataFrame, copybookFields: Seq[Statement]): Unit = { - val dfFields = df.schema.fields.map(_.name.toLowerCase).toSet - - val notFoundFields = copybookFields.flatMap { field => - if (dfFields.contains(field.name.toLowerCase)) { - None - } else { - Some(field.name) - } - } - - if (notFoundFields.nonEmpty) { - throw new IllegalArgumentException(s"The following fields from the copybook are not found in the DataFrame: ${notFoundFields.mkString(", ")}") - } - - val unsupportedDataTypeFields = copybookFields.filter { field => - field.isInstanceOf[Group] || - (field.isInstanceOf[Primitive] && field.asInstanceOf[Primitive].occurs.isDefined) || - field.redefines.nonEmpty - } - - if (unsupportedDataTypeFields.nonEmpty) { - throw new IllegalArgumentException(s"The following fields from the copybook are not supported by the 'spark-cobol' at the moment: " + - s"${unsupportedDataTypeFields.map(_.name).mkString(", ")}. 
Only primitive fields without redefines and occurs are supported.") - } - } - - private def getAst(cobolSchema: CobolSchema): Group = { - val rootAst = cobolSchema.copybook.ast - - if (rootAst.children.length == 1 && rootAst.children.head.isInstanceOf[Group]) { - rootAst.children.head.asInstanceOf[Group] - } else { - rootAst - } - } -} - -object BasicRecordCombiner { - def getFieldDefinition(field: Primitive): String = { - val pic = field.dataType.originalPic.getOrElse(field.dataType.pic) - - val usage = field.dataType match { - case dt: Integral => dt.compact.map(_.toString).getOrElse("USAGE IS DISPLAY") - case dt: Decimal => dt.compact.map(_.toString).getOrElse("USAGE IS DISPLAY") - case _ => "" - } - - s"$pic $usage".trim - } -} diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/NestedRecordCombiner.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/NestedRecordCombiner.scala new file mode 100644 index 00000000..44adf6d5 --- /dev/null +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/NestedRecordCombiner.scala @@ -0,0 +1,314 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.cobrix.spark.cobol.writer + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{ArrayType, StructType} +import org.apache.spark.sql.{DataFrame, Row} +import org.slf4j.LoggerFactory +import za.co.absa.cobrix.cobol.parser.Copybook +import za.co.absa.cobrix.cobol.parser.ast.datatype.{Decimal, Integral} +import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive} +import za.co.absa.cobrix.cobol.parser.recordformats.RecordFormat +import za.co.absa.cobrix.cobol.reader.parameters.ReaderParameters +import za.co.absa.cobrix.cobol.reader.schema.CobolSchema +import za.co.absa.cobrix.spark.cobol.writer.WriterAst._ + +import scala.collection.mutable + +class NestedRecordCombiner extends RecordCombiner { + + import NestedRecordCombiner._ + + override def combine(df: DataFrame, cobolSchema: CobolSchema, readerParameters: ReaderParameters): RDD[Array[Byte]] = { + val hasRdw = readerParameters.recordFormat == RecordFormat.VariableLength + val isRdwBigEndian = readerParameters.isRdwBigEndian + val adjustment1 = if (readerParameters.isRdwPartRecLength) 4 else 0 + val adjustment2 = readerParameters.rdwAdjustment + + val size = if (hasRdw) { + cobolSchema.getRecordSize + 4 + } else { + cobolSchema.getRecordSize + } + + val startOffset = if (hasRdw) 4 else 0 + + val recordLengthLong = cobolSchema.getRecordSize.toLong + adjustment1.toLong + adjustment2.toLong + if (recordLengthLong < 0) { + throw new IllegalArgumentException( + s"Invalid RDW length $recordLengthLong. Check 'is_rdw_part_of_record_length' and 'rdw_adjustment'." + ) + } + if (isRdwBigEndian && recordLengthLong > 0xFFFFL) { + throw new IllegalArgumentException( + s"RDW length $recordLengthLong exceeds 65535 and cannot be encoded in big-endian mode." + ) + } + if (!isRdwBigEndian && recordLengthLong > Int.MaxValue.toLong) { + throw new IllegalArgumentException( + s"RDW length $recordLengthLong exceeds ${Int.MaxValue} and cannot be encoded safely." 
+ ) + } + val recordLength = recordLengthLong.toInt + + processRDD(df.rdd, cobolSchema.copybook, df.schema, size, recordLength, startOffset, hasRdw, isRdwBigEndian) + } +} + +object NestedRecordCombiner { + private val log = LoggerFactory.getLogger(this.getClass) + + def getFieldDefinition(field: Primitive): String = { + val pic = field.dataType.originalPic.getOrElse(field.dataType.pic) + + val usage = field.dataType match { + case dt: Integral => dt.compact.map(_.toString).getOrElse("USAGE IS DISPLAY") + case dt: Decimal => dt.compact.map(_.toString).getOrElse("USAGE IS DISPLAY") + case _ => "" + } + + s"$pic $usage".trim + } + + def constructWriterAst(copybook: Copybook, schema: StructType): GroupField = { + buildGroupField(getAst(copybook), schema, row => row) + } + + def processRDD(rdd: RDD[Row], copybook: Copybook, schema: StructType, recordSize: Int, recordLengthHeader: Int, startOffset: Int, hasRdw: Boolean, isRdwBigEndian: Boolean): RDD[Array[Byte]] = { + val writerAst = constructWriterAst(copybook, schema) + + rdd.mapPartitions { rows => + rows.map { row => + val ar = new Array[Byte](recordSize) + + if (hasRdw) { + if (isRdwBigEndian) { + ar(0) = ((recordLengthHeader >> 8) & 0xFF).toByte + ar(1) = (recordLengthHeader & 0xFF).toByte + // The last two bytes are reserved and defined by IBM as binary zeros on all platforms. + ar(2) = 0 + ar(3) = 0 + } else { + ar(0) = (recordLengthHeader & 0xFF).toByte + ar(1) = ((recordLengthHeader >> 8) & 0xFF).toByte + // This is non-standard. But so are little-endian RDW headers. + // As an advantage, it has no effect for small records but adds support for big records (> 64KB). 
+ ar(2) = ((recordLengthHeader >> 16) & 0xFF).toByte + ar(3) = ((recordLengthHeader >> 24) & 0xFF).toByte + } + } + + writeToBytes(writerAst, row, ar, startOffset) + + ar + } + } + } + + def getAst(copybook: Copybook): Group = { + val rootAst = copybook.ast + + if (rootAst.children.length == 1 && rootAst.children.head.isInstanceOf[Group]) { + rootAst.children.head.asInstanceOf[Group] + } else { + rootAst + } + } + + /** + * Recursively walks the copybook group and the Spark StructType in lockstep, producing + * [[WriterAst]] nodes whose getters extract the correct value from a [[org.apache.spark.sql.Row]]. + * + * @param group A copybook Group node whose children will be processed. + * @param schema The Spark StructType that corresponds to `group`. + * @param getter A function that, given the "outer" Row, returns the Row that belongs to this group. + * @param path The path to the field + * @return A [[GroupField]] covering all non-filler, non-redefines children found in both + * the copybook and the Spark schema. + */ + private def buildGroupField(group: Group, schema: StructType, getter: GroupGetter, path: String = ""): GroupField = { + val children = group.children.withFilter { stmt => + stmt.redefines.isEmpty + }.map { + case s if s.isFiller => Filler(s.binaryProperties.actualSize) + case p: Primitive => buildPrimitiveNode(p, schema, path) + case g: Group => buildGroupNode(g, schema, path) + } + GroupField(children.toSeq, group, getter) + } + + /** + * Builds a [[WriterAst]] node for a primitive copybook field, using the field's index in the + * supplied Spark schema to create a getter function. + * + * Returns a filler when the field is absent from the schema (e.g. filtered out during reading). 
+ */ + private def buildPrimitiveNode(p: Primitive, schema: StructType, path: String = ""): WriterAst = { + val fieldName = p.name + val fieldIndexOpt = schema.fields.zipWithIndex.find { case (field, _) => + field.name.equalsIgnoreCase(fieldName) + }.map(_._2) + + fieldIndexOpt.map { idx => + if (p.encode.isEmpty) { + val fieldDefinition = getFieldDefinition(p) + throw new IllegalArgumentException(s"Field '${p.name}' does not have an encoding defined in the copybook. " + + s"'PIC $fieldDefinition' is not yet supported.") + } + if (p.occurs.isDefined) { + // Array of primitives + PrimitiveArray(p, row => row.getAs[mutable.WrappedArray[AnyRef]](idx)) + } else { + PrimitiveField(p, row => row.get(idx)) + } + }.getOrElse { + log.error(s"Field '$path${p.name}' is not found in Spark schema. Will be replaced by filler.") + Filler(p.binaryProperties.actualSize) + } + } + + /** + * Builds a [[WriterAst]] node for a group copybook field. For groups with OCCURS the getter + * extracts an array; for plain groups it extracts the nested Row. In both cases the children + * are built by recursing into the nested Spark StructType. + * + * Returns a filler when the field is absent from the schema. 
+ */ + private def buildGroupNode(g: Group, schema: StructType, path: String = ""): WriterAst = { + val fieldName = g.name + val fieldIndexOpt = schema.fields.zipWithIndex.find { case (field, _) => + field.name.equalsIgnoreCase(fieldName) + }.map(_._2) + + fieldIndexOpt.map { idx => + if (g.occurs.isDefined) { + // Array of structs – the element type must be a StructType + schema(idx).dataType match { + case ArrayType(elementType: StructType, _) => + val childAst = buildGroupField(g, elementType, row => row, s"$path${g.name}.") + GroupArray(childAst, g, row => row.getAs[mutable.WrappedArray[AnyRef]](idx)) + case other => + throw new IllegalArgumentException( + s"Expected ArrayType(StructType) for group field '${g.name}' with OCCURS, but got $other") + } + } else { + // Nested struct + schema(idx).dataType match { + case nestedSchema: StructType => + val childGetter: GroupGetter = row => row.getAs[Row](idx) + val childAst = buildGroupField(g, nestedSchema, childGetter, s"$path${g.name}.") + GroupField(childAst.children, g, childGetter) + case other => + throw new IllegalArgumentException( + s"Expected StructType for group field '${g.name}', but got $other") + } + } + }.getOrElse { + log.error(s"Field '$path${g.name}' is not found in Spark schema. Will be replaced by filler.") + Filler(g.binaryProperties.actualSize) + } + } + + /** + * Recursively walks `ast` and writes every primitive value from `row` into `ar`. + * + * For plain (non-array) fields the `configuredStartOffset` is forwarded directly to + * [[Copybook.setPrimitiveField]], which adds it to `field.binaryProperties.offset`. + * + * For array fields (both primitive and group-of-primitives) each element is written + * using the `fieldStartOffsetOverride` parameter so the exact byte position can be + * supplied. The row array may contain fewer elements than the copybook allows — any + * missing tail elements are silently skipped, leaving those bytes as zeroes. 
+ * + * @param ast The [[WriterAst]] node to process. + * @param row The Spark [[Row]] from which values are read. + * @param ar The target byte array (record buffer). + * @param currentOffset RDW prefix length (0 for fixed-length records, 4 for variable). + */ + private def writeToBytes(ast: WriterAst, row: Row, ar: Array[Byte], currentOffset: Int): Int = { + ast match { + // ── Filler ────────────────────────────────────────────────────── + case Filler(size) => size + + // ── Plain primitive ────────────────────────────────────────────────────── + case PrimitiveField(cobolField, getter) => + val value = getter(row) + if (value != null) { + Copybook.setPrimitiveField(cobolField, ar, value, 0, currentOffset) + } + cobolField.binaryProperties.actualSize + + // ── Plain nested group ─────────────────────────────────────────────────── + case GroupField(children, cobolField, getter) => + val nestedRow = getter(row) + if (nestedRow != null) { + var writtenBytes = 0 + children.foreach(child => + writtenBytes += writeToBytes(child, nestedRow, ar, currentOffset + writtenBytes) + ) + } + cobolField.binaryProperties.actualSize + + // ── Array of primitives (OCCURS on a primitive field) ─────────────────── + case PrimitiveArray(cobolField, arrayGetter) => + val arr = arrayGetter(row) + if (arr != null) { + val maxElements = cobolField.arrayMaxSize // copybook upper bound + val elementSize = cobolField.binaryProperties.dataSize + val baseOffset = currentOffset + val elementsToWrite = math.min(arr.length, maxElements) + + var i = 0 + while (i < elementsToWrite) { + val value = arr(i) + if (value != null) { + val elementOffset = baseOffset + i * elementSize + // fieldStartOffsetOverride is the absolute position; pass it so + // setPrimitiveField does not add binaryProperties.offset on top again. 
+ Copybook.setPrimitiveField(cobolField, ar, value, fieldStartOffsetOverride = elementOffset) + } + i += 1 + } + } + cobolField.binaryProperties.actualSize + + // ── Array of groups (OCCURS on a group field) ─────────────────────────── + case GroupArray(groupField: GroupField, cobolField, arrayGetter) => + val arr = arrayGetter(row) + if (arr != null) { + val maxElements = cobolField.arrayMaxSize // copybook upper bound + val elementSize = cobolField.binaryProperties.dataSize + val baseOffset = currentOffset + val elementsToWrite = math.min(arr.length, maxElements) + + var i = 0 + while (i < elementsToWrite) { + val elementRow = arr(i).asInstanceOf[Row] + if (elementRow != null) { + // Build an adjusted element offset so that each child's base offset + // (which is relative to the group's base) lands at the correct position in ar. + val elementStartOffset = baseOffset + i * elementSize + writeToBytes(groupField, elementRow, ar, elementStartOffset) + } + i += 1 + } + } + cobolField.binaryProperties.actualSize + } + } +} diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/RecordCombinerSelector.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/RecordCombinerSelector.scala index a016ac9c..77703964 100644 --- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/RecordCombinerSelector.scala +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/RecordCombinerSelector.scala @@ -32,7 +32,7 @@ object RecordCombinerSelector { * @return A `RecordCombiner` implementation suitable for combining records based on the given schema and parameters. 
*/ def selectCombiner(cobolSchema: CobolSchema, readerParameters: ReaderParameters): RecordCombiner = { - new BasicRecordCombiner + new NestedRecordCombiner } } diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/WriterAst.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/WriterAst.scala new file mode 100644 index 00000000..8e280a09 --- /dev/null +++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/writer/WriterAst.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package za.co.absa.cobrix.spark.cobol.writer + +import org.apache.spark.sql.Row +import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive} + +import scala.collection.mutable + +sealed trait WriterAst + +object WriterAst { + type PrimitiveGetter = Row => Any + type GroupGetter = Row => Row + type ArrayGetter = Row => mutable.WrappedArray[AnyRef] + + case class Filler(fillerSize: Int) extends WriterAst + case class PrimitiveField(cobolField: Primitive, getter: PrimitiveGetter) extends WriterAst + case class GroupField(children: Seq[WriterAst], cobolField: Group, getter: GroupGetter) extends WriterAst + case class PrimitiveArray(cobolField: Primitive, arrayGetter: ArrayGetter) extends WriterAst + case class GroupArray(groupField: GroupField, cobolField: Group, arrayGetter: ArrayGetter) extends WriterAst +} diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala index 51c8041b..676c9b09 100644 --- a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtilsSuite.scala @@ -16,6 +16,7 @@ package za.co.absa.cobrix.spark.cobol.utils +import org.apache.spark.sql.functions.{col, struct} import org.apache.spark.sql.types._ import org.scalatest.funsuite.AnyFunSuite import org.slf4j.LoggerFactory @@ -867,6 +868,23 @@ class SparkUtilsSuite extends AnyFunSuite with SparkTestBase with BinaryFileFixt } } + test("printRowUdf dumps a Spark record as a single field") { + val expectedData = + """[ { + | "record" : "id=1, value=a" + |}, { + | "record" : "id=2, value=b" + |}, { + | "record" : "id=3, value=c" + |} ]""".stripMargin.replace("\r\n", "\n") + + val df = List((1, "a"), (2, "b"), (3, "c")).toDF("id", "value") + val convertedDf = df.select(SparkUtils.printRowUdf(struct(df.columns.map(col): _*)).as("record")) + val actualData = 
SparkUtils.convertDataFrameToPrettyJSON(convertedDf) + + assertResults(actualData, expectedData) + } + private def assertSchema(actualSchema: String, expectedSchema: String): Unit = { if (actualSchema != expectedSchema) { logger.error(s"EXPECTED:\n$expectedSchema") diff --git a/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/writer/NestedWriterSuite.scala b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/writer/NestedWriterSuite.scala new file mode 100644 index 00000000..92ea8107 --- /dev/null +++ b/spark-cobol/src/test/scala/za/co/absa/cobrix/spark/cobol/writer/NestedWriterSuite.scala @@ -0,0 +1,133 @@ +/* + * Copyright 2018 ABSA Group Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package za.co.absa.cobrix.spark.cobol.writer + +import org.apache.hadoop.fs.Path +import org.apache.spark.sql.SaveMode +import org.scalatest.wordspec.AnyWordSpec +import za.co.absa.cobrix.spark.cobol.source.base.SparkTestBase +import za.co.absa.cobrix.spark.cobol.source.fixtures.{BinaryFileFixture, TextComparisonFixture} + +class NestedWriterSuite extends AnyWordSpec with SparkTestBase with BinaryFileFixture with TextComparisonFixture { + private val copybook = + """ 01 RECORD. + | 05 ID PIC 9(2). + | 05 FILLER PIC 9(1). + | 05 CNT1 PIC 9(1). + | 05 NUMBERS PIC 9(2) + | OCCURS 0 TO 5 DEPENDING ON CNT1. + | 05 PLACE. + | 10 COUNTRY-CODE PIC X(2). + | 10 CITY PIC X(10). + | 05 CNT2 PIC 9(1). + | 05 PEOPLE + | OCCURS 0 TO 3 DEPENDING ON CNT2. 
+ | 10 NAME PIC X(14). + | 10 FILLER PIC X(1). + | 10 PHONE-NUMBER PIC X(12). + |""".stripMargin + + "writer" should { + "write the dataframe according to the copybook" in { + //val parsedCopybook = CopybookParser.parse(copybook) + //println(parsedCopybook.generateRecordLayoutPositions()) + + val exampleJsons = Seq( + """{"ID":1,"cnt1":3,"NUMBERS":[10,20,30],"PLACE":{"COUNTRY_CODE":"US","CITY":"New York"},"CNT2":2,"PEOPLE":[{"NAME":"John Doe","PHONE_NUMBER":"555-1234"},{"NAME": "Jane Smith","PHONE_NUMBER":"555-5678"}]}""", + """{"ID":2,"cnt1":0,"NUMBERS":[],"PLACE":{"COUNTRY_CODE":"ZA","CITY":"Cape Town"},"CNT2":1,"PEOPLE":[{"NAME":"Test User","PHONE_NUMBER":"555-1235"}]}""" + ) + + import spark.implicits._ + + val df = spark.read.json(exampleJsons.toDS()) + .select("ID", "cnt1", "NUMBERS", "PLACE", "CNT2", "PEOPLE") + + // df.printSchema() + // df.show() + // val ast = NestedRecordCombiner.constructWriterAst(parsedCopybook, df.schema) + // println(ast) + // Apply the UDF to the full record by packing all columns into a struct + //val dfWithDump = df.withColumn( + // "record_dump", + // printRowUdf(struct(df.columns.map(col): _*)) + //) + //dfWithDump.select("record_dump").show(truncate = false) + + withTempDirectory("cobol_writer1") { tempDir => + val path = new Path(tempDir, "writer1") + + df.coalesce(1) + .orderBy("id") + .write + .format("cobol") + .mode(SaveMode.Overwrite) + .option("copybook_contents", copybook) + .option("record_format", "V") + .option("is_rdw_big_endian", "false") + .option("is_rdw_part_of_record_length", "true") + .save(path.toString) + + //val df2 = spark.read.format("cobol") + // .option("copybook_contents", copybook) + // .option("record_format", "V") + // .option("is_rdw_big_endian", "false") + // .option("is_rdw_part_of_record_length", "true") + // .load(path.toString) + //println(SparkUtils.convertDataFrameToPrettyJSON(df2)) + + val fs = path.getFileSystem(spark.sparkContext.hadoopConfiguration) + + assert(fs.exists(path), "Output 
directory should exist") + val files = fs.listStatus(path) + .filter(_.getPath.getName.startsWith("part-")) + assert(files.nonEmpty, "Output directory should contain part files") + + val partFile = files.head.getPath + val data = fs.open(partFile) + val bytes = new Array[Byte](files.head.getLen.toInt) + data.readFully(bytes) + data.close() + + // Expected EBCDIC data for sample test data + val expected = Array( + 0x70, 0x00, 0x00, 0x00, // RDW record 0 + 0xF0, 0xF1, 0x00, 0xF3, 0xF1, 0xF0, 0xF2, 0xF0, 0xF3, 0xF0, 0x00, 0x00, 0x00, 0x00, 0xE4, 0xE2, 0xD5, 0x85, + 0xA6, 0x40, 0xE8, 0x96, 0x99, 0x92, 0x40, 0x40, 0xF2, 0xD1, 0x96, 0x88, 0x95, 0x40, 0xC4, 0x96, 0x85, 0x40, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x00, 0xF5, 0xF5, 0xF5, 0xCA, 0xF1, 0xF2, 0xF3, 0xF4, 0x40, 0x40, 0x40, 0x40, + 0xD1, 0x81, 0x95, 0x85, 0x40, 0xE2, 0x94, 0x89, 0xA3, 0x88, 0x40, 0x40, 0x40, 0x40, 0x00, 0xF5, 0xF5, 0xF5, + 0xCA, 0xF5, 0xF6, 0xF7, 0xF8, 0x40, 0x40, 0x40, 0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x70, 0x00, 0x00, 0x00, // RDW record 1 + 0xF0, 0xF2, 0x00, 0xF0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xE9, 0xC1, 0xC3, 0x81, + 0x97, 0x85, 0x40, 0xE3, 0x96, 0xA6, 0x95, 0x40, 0xF1, 0xE3, 0x85, 0xA2, 0xA3, 0x40, 0xE4, 0xA2, 0x85, 0x99, + 0x40, 0x40, 0x40, 0x40, 0x40, 0x00, 0xF5, 0xF5, 0xF5, 0xCA, 0xF1, 0xF2, 0xF3, 0xF5, 0x40, 0x40, 0x40, 0x40, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + ).map(_.toByte) + + if (!bytes.sameElements(expected)) { + println(s"Expected bytes: ${expected.map("%02X" format _).mkString(" ")}") + println(s"Actual bytes: ${bytes.map("%02X" 
format _).mkString(" ")}") + //println(s"Actual bytes: ${bytes.map("0x%02X" format _).mkString(", ")}") + + assert(bytes.sameElements(expected), "Written data should match expected EBCDIC encoding") + } + } + } + } +}