@@ -18,7 +18,7 @@ package za.co.absa.cobrix.cobol.parser.asttransform
 
 import org.slf4j.LoggerFactory
 import za.co.absa.cobrix.cobol.parser.CopybookParser.CopybookAST
-import za.co.absa.cobrix.cobol.parser.ast.datatype.Integral
+import za.co.absa.cobrix.cobol.parser.ast.datatype.{Decimal, Integral}
 import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive, Statement}
 
 import scala.collection.mutable
@@ -96,6 +96,7 @@ class DependencyMarker(
     val newPrimitive = if (dependees contains primitive) {
       primitive.dataType match {
         case _: Integral => true
+        case d: Decimal if d.scale == 0 => true
         case dt =>
           for (stmt <- dependees(primitive)) {
             if (stmt.dependingOnHandlers.isEmpty)
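For context, the strict_integral_precision mode (exercised by the new tests below) maps integral PICtures to decimal(n,0) types, so a DEPENDING ON field can now legitimately be a zero-scale Decimal. A minimal sketch of the accept/reject rule, using hypothetical stand-in types (the real Integral and Decimal live in za.co.absa.cobrix.cobol.parser.ast.datatype):

// Hypothetical stand-ins for the parser's AST data types, for illustration only.
sealed trait CobolDataType
final case class Integral(precision: Int) extends CobolDataType
final case class Decimal(precision: Int, scale: Int) extends CobolDataType

object DependeeCheck {
  // A field can drive OCCURS ... DEPENDING ON only if it holds whole numbers:
  // classic integral PICtures, or decimals whose scale is zero.
  def isValidDependee(dt: CobolDataType): Boolean = dt match {
    case _: Integral                => true
    case d: Decimal if d.scale == 0 => true // the case added by this change
    case _                          => false
  }

  def main(args: Array[String]): Unit = {
    println(isValidDependee(Integral(2)))   // true
    println(isValidDependee(Decimal(1, 0))) // true: decimal(1,0) is integral-valued
    println(isValidDependee(Decimal(8, 2))) // false: has a fractional part
  }
}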
@@ -22,6 +22,8 @@ import za.co.absa.cobrix.cobol.reader.iterator.RecordLengthExpression
 import za.co.absa.cobrix.cobol.reader.parameters.ReaderParameters
 import za.co.absa.cobrix.cobol.reader.validator.ReaderParametersValidator
 
+import scala.util.Try
+
 class FixedWithRecordLengthExprRawRecordExtractor(ctx: RawRecordContext,
                                                   readerProperties: ReaderParameters) extends Serializable with RawRecordExtractor {
   private val log = LoggerFactory.getLogger(this.getClass)
@@ -121,19 +123,21 @@ class FixedWithRecordLengthExprRawRecordExtractor(ctx: RawRecordContext,
   final private def getRecordLengthFromField(lengthAST: Primitive, binaryDataStart: Array[Byte]): Int = {
     val length = if (isLengthMapEmpty) {
       ctx.copybook.extractPrimitiveField(lengthAST, binaryDataStart, readerProperties.startOffset) match {
-        case i: Int => i
-        case l: Long => l.toInt
-        case s: String => s.toInt
-        case null => throw new IllegalStateException(s"Null encountered as a record length field (offset: $byteIndex, raw value: ${getBytesAsHexString(binaryDataStart)}).")
-        case _ => throw new IllegalStateException(s"Record length value of the field ${lengthAST.name} must be an integral type.")
+        case i: Int => i
+        case l: Long => l.toInt
+        case s: String => Try{ s.toInt }.getOrElse(throw new IllegalStateException(s"Record length value of the field ${lengthAST.name} must be an integral type, encountered: '$s'."))
+        case d: BigDecimal => d.toInt
Copilot AI commented on May 30, 2025:

The BigDecimal branch unconditionally truncates values, which can silently drop fractional parts. Consider validating that d.scale == 0 or using d.intValueExact() to ensure an error is thrown for non-zero scale.

Suggested change:
-        case d: BigDecimal => d.toInt
+        case d: BigDecimal => Try { d.intValueExact() }.getOrElse(throw new IllegalStateException(s"Record length value of the field ${lengthAST.name} must be an integral type, encountered: '$d'."))

Collaborator (author) replied:

Tested toInt vs toIntExact. The behavior of toInt is still what I want in this context. When decimals are used as record lengths, the fractional part needs to be truncated, not rounded.
+        case null => throw new IllegalStateException(s"Null encountered as a record length field (offset: $byteIndex, raw value: ${getBytesAsHexString(binaryDataStart)}).")
+        case _ => throw new IllegalStateException(s"Record length value of the field ${lengthAST.name} must be an integral type.")
       }
     } else {
       ctx.copybook.extractPrimitiveField(lengthAST, binaryDataStart, readerProperties.startOffset) match {
-        case i: Int => getRecordLengthFromMapping(i.toString)
-        case l: Long => getRecordLengthFromMapping(l.toString)
-        case s: String => getRecordLengthFromMapping(s)
-        case null => defaultRecordLength.getOrElse(throw new IllegalStateException(s"Null encountered as a record length field (offset: $byteIndex, raw value: ${getBytesAsHexString(binaryDataStart)})."))
-        case _ => throw new IllegalStateException(s"Record length value of the field ${lengthAST.name} must be an integral type.")
+        case i: Int => getRecordLengthFromMapping(i.toString)
+        case l: Long => getRecordLengthFromMapping(l.toString)
+        case d: BigDecimal => getRecordLengthFromMapping(d.toString())
+        case s: String => getRecordLengthFromMapping(s)
+        case null => defaultRecordLength.getOrElse(throw new IllegalStateException(s"Null encountered as a record length field (offset: $byteIndex, raw value: ${getBytesAsHexString(binaryDataStart)})."))
+        case _ => throw new IllegalStateException(s"Record length value of the field ${lengthAST.name} must be an integral type.")
       }
     }
     length + recordLengthAdjustment
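To make the trade-off from the review thread above concrete, here is a small standalone sketch of scala.math.BigDecimal conversion semantics (standard-library behavior, not code from this PR): toInt truncates towards zero, while toIntExact throws when any fractional part would be lost.

object BigDecimalTruncation {
  def main(args: Array[String]): Unit = {
    val whole = BigDecimal("12")
    val frac  = BigDecimal("12.75")

    println(whole.toInt)      // 12
    println(frac.toInt)       // 12: the fractional part is truncated, not rounded
    println(whole.toIntExact) // 12: exact conversion succeeds for zero-scale values

    // toIntExact rejects values that cannot be represented exactly as Int:
    try println(frac.toIntExact)
    catch { case e: ArithmeticException => println(s"rejected: ${e.getMessage}") }
  }
}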
@@ -16,6 +16,7 @@
 
 package za.co.absa.cobrix.cobol.reader.validator
 
+import org.slf4j.LoggerFactory
 import za.co.absa.cobrix.cobol.parser.Copybook
 import za.co.absa.cobrix.cobol.parser.ast.Primitive
 import za.co.absa.cobrix.cobol.parser.expression.NumberExprEvaluator
@@ -25,6 +26,7 @@ import za.co.absa.cobrix.cobol.reader.parameters.MultisegmentParameters
 
 import scala.util.Try
 
 object ReaderParametersValidator {
+  private val log = LoggerFactory.getLogger(this.getClass)
 
   def getEitherFieldAndExpression(fieldOrExpressionOpt: Option[String], recordLengthMap: Map[String, Int], cobolSchema: Copybook): (Option[RecordLengthField], Option[RecordLengthExpression]) = {
     fieldOrExpressionOpt match {
@@ -49,7 +51,7 @@
       val astNode = field match {
         case s: Primitive =>
           if (!s.dataType.isInstanceOf[za.co.absa.cobrix.cobol.parser.ast.datatype.Integral] && recordLengthMap.isEmpty) {
-            throw new IllegalStateException(s"The record length field $recordLengthFieldName must be an integral type or a value mapping must be specified.")
+            log.warn(s"The record length field $recordLengthFieldName is not integral. Runtime exceptions could occur if values can't be parsed as numbers.")
           }
           if (s.occurs.isDefined && s.occurs.get > 1) {
             throw new IllegalStateException(s"The record length field '$recordLengthFieldName' cannot be an array.")
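In effect, the hard load-time check above becomes a warning, and errors now surface per record during extraction. A minimal sketch of the resulting two-phase behavior, with hypothetical helper names condensing the logic from this validator and the extractor:

import org.slf4j.LoggerFactory
import scala.util.Try

object LengthFieldBehavior {
  private val log = LoggerFactory.getLogger(this.getClass)

  // Load time: a non-integral length field is now a warning, not an error,
  // since string values like "12" may still parse per record.
  def validate(fieldName: String, isIntegral: Boolean, hasLengthMap: Boolean): Unit =
    if (!isIntegral && !hasLengthMap)
      log.warn(s"The record length field $fieldName is not integral. " +
        "Runtime exceptions could occur if values can't be parsed as numbers.")

  // Read time: throw only when an actual value cannot be converted.
  def parseLength(fieldName: String, rawValue: String): Int =
    Try(rawValue.trim.toInt).getOrElse(
      throw new IllegalStateException(
        s"Record length value of the field $fieldName must be an integral type, encountered: '$rawValue'."))
}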
@@ -200,32 +200,6 @@ class VRLRecordReaderSpec extends AnyWordSpec {
       assert(record2(14) == 0xF8.toByte)
     }
 
-    "throw an exception on a fraction type" in {
-      val copybookWithFieldLength =
-        """      01  RECORD.
-            05  LEN       PIC 9(8)V99.
-            05  N         PIC 9(2).
-            05  A         PIC X(2).
-        """
-
-      val records = Array[Byte](0x00)
-      val streamH = new ByteStreamMock(records)
-      val streamD = new ByteStreamMock(records)
-      val context = RawRecordContext(0, streamH, streamD, CopybookParser.parseSimple(copybookWithFieldLength), null, null, "")
-
-      val readerParameters = ReaderParameters(lengthFieldExpression = Some("LEN"))
-
-      val ex = intercept[IllegalStateException] {
-        getUseCase(
-          copybook = copybookWithFieldLength,
-          records = records,
-          lengthFieldExpression = Some("LEN"),
-          recordExtractor = Some(new FixedWithRecordLengthExprRawRecordExtractor(context, readerParameters)))
-      }
-
-      assert(ex.getMessage == "The record length field LEN must be an integral type or a value mapping must be specified.")
-    }
-
     "the length mapping with default record length" in {
       val copybookWithLenbgthMap =
         """      01  RECORD.
@@ -106,6 +106,27 @@ class Test37RecordLengthMappingSpec extends AnyWordSpec with SparkTestBase with
       }
     }
 
+    "work for numeric mappings and strict integrals" in {
+      withTempBinFile("record_length_mapping", ".tmp", dataNumeric) { tempFile =>
+        val expected = """{"SEG_ID":"1","TEXT":"123"},{"SEG_ID":"2","TEXT":"123456"},{"SEG_ID":"3","TEXT":"1234567"}"""
+
+        val df = spark.read
+          .format("cobol")
+          .option("copybook_contents", copybook)
+          .option("record_format", "F")
+          .option("record_length_field", "SEG-ID")
+          .option("input_split_records", "2")
+          .option("pedantic", "true")
+          .option("record_length_map", """{"1":4,"2":7,"3":8}""")
+          .option("strict_integral_precision", "true")
+          .load(tempFile)
+
+        val actual = df.orderBy("SEG_ID").toJSON.collect().mkString(",")
+
+        assert(actual == expected)
+      }
+    }
+
"work for data with offsets" in {
withTempBinFile("record_length_mapping", ".tmp", dataWithFileOffsets) { tempFile =>
val expected = """{"SEG_ID":"A","TEXT":"123"},{"SEG_ID":"B","TEXT":"123456"},{"SEG_ID":"C","TEXT":"1234567"}"""
@@ -228,5 +228,40 @@ class Test21VariableOccursForTextFiles extends AnyWordSpec with SparkTestBase wi
         assertEqualsMultiline(actualData, expectedData)
       }
     }
+
+    "correctly keep occurs for Cobrix ASCII with variable length extractor and decimal depending field" in {
+      val expectedSchema =
+        """root
+          | |-- COUNT: decimal(1,0) (nullable = true)
+          | |-- GROUP: array (nullable = true)
+          | |    |-- element: struct (containsNull = true)
+          | |    |    |-- INNER_COUNT: decimal(1,0) (nullable = true)
+          | |    |    |-- INNER_GROUP: array (nullable = true)
+          | |    |    |    |-- element: struct (containsNull = true)
+          | |    |    |    |    |-- FIELD: string (nullable = true)
+          | |-- MARKER: string (nullable = true)
+          |""".stripMargin
+
+      withTempTextFile("variable_occurs_ascii", ".dat", StandardCharsets.US_ASCII, data) { tmpFileName =>
+        val df = spark
+          .read
+          .format("cobol")
+          .option("copybook_contents", copybook)
+          .option("record_format", "D")
+          .option("ascii_charset", "utf8")
+          .option("variable_size_occurs", "true")
+          .option("strict_integral_precision", "true")
+          .option("pedantic", "true")
+          .load(tmpFileName)
+
+        val actualSchema = df.schema.treeString
+
+        assertEqualsMultiline(actualSchema, expectedSchema)
+
+        val actualData = SparkUtils.prettyJSON(df.toJSON.collect().mkString("[", ",", "]"))
+
+        assertEqualsMultiline(actualData, expectedData)
+      }
+    }
   }
 }
@@ -16,6 +16,7 @@
 
 package za.co.absa.cobrix.spark.cobol.source.regression
 
+import org.apache.spark.SparkException
 import org.scalatest.wordspec.AnyWordSpec
 import org.slf4j.{Logger, LoggerFactory}
 import za.co.absa.cobrix.spark.cobol.source.base.{SimpleComparisonBase, SparkTestBase}
@@ -143,7 +144,7 @@ class Test26FixLengthWithIdGeneration extends AnyWordSpec with SparkTestBase wit
 
   "EBCDIC files" should {
     "correctly work with segment id generation option with length field" in {
-      withTempBinFile("fix_length_reg", ".dat", binFileContentsLengthField) { tmpFileName =>
+      withTempBinFile("fix_length_reg1", ".dat", binFileContentsLengthField) { tmpFileName =>
         val df = spark
           .read
           .format("cobol")
@@ -168,7 +169,7 @@ class Test26FixLengthWithIdGeneration extends AnyWordSpec with SparkTestBase wit
     }
 
     "correctly work with segment id generation option with length expression" in {
-      withTempBinFile("fix_length_reg", ".dat", binFileContentsLengthExpr) { tmpFileName =>
+      withTempBinFile("fix_length_reg2", ".dat", binFileContentsLengthExpr) { tmpFileName =>
        val df = spark
          .read
          .format("cobol")
@@ -191,5 +192,87 @@
         assertEqualsMultiline(actual, expected)
       }
     }
+
+    "correctly work with segment id generation option with length field and strict integral precision" in {
+      withTempBinFile("fix_length_reg3", ".dat", binFileContentsLengthField) { tmpFileName =>
+        val df = spark
+          .read
+          .format("cobol")
+          .option("copybook_contents", copybook)
+          .option("record_format", "F")
+          .option("record_length_field", "LEN")
+          .option("strict_integral_precision", "true")
+          .option("segment_field", "IND")
+          .option("segment_id_prefix", "ID")
+          .option("segment_id_level0", "A")
+          .option("segment_id_level1", "_")
+          .option("redefine-segment-id-map:0", "SEGMENT1 => A")
+          .option("redefine-segment-id-map:1", "SEGMENT2 => B")
+          .option("redefine-segment-id-map:2", "SEGMENT3 => C")
+          .option("input_split_records", 1)
+          .option("pedantic", "true")
+          .load(tmpFileName)
+
+        val actual = SparkUtils.convertDataFrameToPrettyJSON(df.drop("LEN").orderBy("Seg_Id0", "Seg_Id1"))
+
+        assertEqualsMultiline(actual, expected)
+      }
+    }
+
+    "correctly work when the length field has the string type" in {
+      val copybook =
+        """      01  R.
+            05  LEN     PIC X(1).
+            05  FIELD1  PIC X(1).
+        """
+
+      val binFileContentsLengthField: Array[Byte] = Array[Byte](
+        // A1
+        0xF2.toByte, 0xF3.toByte, 0xF3.toByte, 0xF4.toByte
+      ).map(_.toByte)
+
+      withTempBinFile("fix_length_str", ".dat", binFileContentsLengthField) { tmpFileName =>
+        val df = spark
+          .read
+          .format("cobol")
+          .option("copybook_contents", copybook)
+          .option("record_format", "F")
+          .option("record_length_field", "LEN")
+          .option("pedantic", "true")
+          .load(tmpFileName)
+
+        assert(df.count() == 2)
+      }
+    }
+
+    "fail when the length field has the string type and incorrect string values are encountered" in {
+      val copybook =
+        """      01  R.
+            05  LEN     PIC X(1).
+            05  FIELD1  PIC X(1).
+        """
+
+      val binFileContentsLengthField: Array[Byte] = Array[Byte](
+        // A1
+        0xF2.toByte, 0xF3.toByte, 0xC3.toByte, 0xF4.toByte
+      ).map(_.toByte)
+
+      withTempBinFile("fix_length_str", ".dat", binFileContentsLengthField) { tmpFileName =>
+        val df = spark
+          .read
+          .format("cobol")
+          .option("copybook_contents", copybook)
+          .option("record_format", "F")
+          .option("record_length_field", "LEN")
+          .option("pedantic", "true")
+          .load(tmpFileName)
+
+        val ex = intercept[SparkException] {
+          df.count()
+        }
+
+        assert(ex.getCause.getMessage.contains("Record length value of the field LEN must be an integral type, encountered: 'C'"))
+      }
+    }
   }
 }