Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 116 additions & 3 deletions sjsonnet/src/sjsonnet/stdlib/StringModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -498,12 +498,112 @@ object StringModule extends AbstractFunctionModule {
else null
}

// Minimum array length at which evalRhs routes a string join to joinPresizedStringArray
// (the `len >= PresizedStringJoinMinParts` check); shorter arrays use evalRhs's inline
// StringBuilder loop instead.
private final val PresizedStringJoinMinParts = 16

/**
 * Joins the string elements of `arr` with `sep`, skipping nulls, using a builder whose
 * capacity is computed up front.
 *
 * Works in two passes over the first `len` elements: a measuring pass that validates each
 * element and totals the output length, then a copying pass that appends into a
 * `java.lang.StringBuilder` allocated at exactly that length.
 *
 * @param pos position attached to the resulting value
 * @param sep separator placed between consecutive string elements
 * @param arr array whose elements are read through `arr.value(i)`
 * @param len number of leading elements of `arr` to join
 * @return the joined string; built with `Val.Str.asciiSafe` only when every string element,
 *         and the separator if it was emitted at least once, is flagged `_asciiSafe`.
 *         An array with no string elements yields `Val.Str(pos, "")`.
 *
 * Fails via `Error.fail` when an element is neither `Val.Null` nor `Val.Str`, or when the
 * running length exceeds `Int.MaxValue`.
 */
private def joinPresizedStringArray(
    pos: Position,
    sep: Val.Str,
    arr: Val.Arr,
    len: Int): Val.Str = {
  val separator = sep.str
  val separatorLength = separator.length

  // Pass 1: validate elements and measure. A Long accumulator keeps the overflow check
  // itself from wrapping around.
  var requiredChars = 0L
  var sawString = false
  var allAscii = true
  var idx = 0
  while (idx < len) {
    arr.value(idx) match {
      case _: Val.Null => // nulls contribute nothing, not even a separator
      case part: Val.Str =>
        if (sawString) {
          // A separator is only counted between two string elements.
          requiredChars += separatorLength
          if (allAscii && !sep._asciiSafe) allAscii = false
        }
        val text = part.str
        requiredChars += text.length
        if (requiredChars > Int.MaxValue) Error.fail("String is too large to join")
        if (allAscii && !part._asciiSafe) allAscii = false
        sawString = true
      case other => Error.fail("Cannot join " + other.prettyName)
    }
    idx += 1
  }

  if (!sawString) Val.Str(pos, "")
  else {
    // Pass 2: copy. Every element was validated above, so anything that is not a string
    // here is a null and is simply skipped.
    val out = new java.lang.StringBuilder(requiredChars.toInt)
    var emittedAny = false
    idx = 0
    while (idx < len) {
      arr.value(idx) match {
        case part: Val.Str =>
          // Track emission with a flag rather than the builder length: an empty first
          // part must still be followed by a separator.
          if (emittedAny) out.append(separator)
          emittedAny = true
          out.append(part.str)
        case _ =>
      }
      idx += 1
    }
    val joined = out.toString
    if (allAscii) Val.Str.asciiSafe(pos, joined) else Val.Str(pos, joined)
  }
}

/**
 * Joins the string elements of a directly accessible backing array with `sep`, skipping
 * nulls, without going through `Val.Arr.value`.
 *
 * A single scan over the first `len` slots of `direct` collects the underlying strings and
 * totals the output length; the collected strings are then appended into a
 * `java.lang.StringBuilder` allocated at exactly that length.
 *
 * @param pos    position attached to the resulting value
 * @param sep    separator placed between consecutive string elements
 * @param direct backing array of elements, matched as-is against `Val.Null` / `Val.Str`
 * @param len    number of leading slots of `direct` to join
 * @return the joined string, or `null` as soon as a slot holds anything other than a
 *         `Val.Null` or `Val.Str` (the caller checks for `null` and falls back to another
 *         path). The result is built with `Val.Str.asciiSafe` only when every string
 *         element, and the separator if it was emitted at least once, is flagged
 *         `_asciiSafe`. No string elements yields `Val.Str(pos, "")`.
 *
 * Fails via `Error.fail` when the running length exceeds `Int.MaxValue`.
 */
private def joinDirectStringArray(
    pos: Position,
    sep: Val.Str,
    direct: Array[Eval],
    len: Int): Val.Str = {
  val separator = sep.str
  val separatorLength = separator.length

  // Strings are packed densely at the front of `collected`; `count` is how many were kept.
  val collected = new Array[String](len)
  var count = 0
  // Long accumulator so the overflow check itself cannot wrap around.
  var requiredChars = 0L
  var allAscii = true
  var idx = 0
  while (idx < len) {
    direct(idx) match {
      case _: Val.Null => // nulls contribute nothing, not even a separator
      case part: Val.Str =>
        if (count > 0) {
          // A separator is only counted between two string elements.
          requiredChars += separatorLength
          if (allAscii && !sep._asciiSafe) allAscii = false
        }
        val text = part.str
        collected(count) = text
        count += 1
        requiredChars += text.length
        if (requiredChars > Int.MaxValue) Error.fail("String is too large to join")
        if (allAscii && !part._asciiSafe) allAscii = false
      // Anything else cannot be handled on this path; signal the caller to fall back.
      case _ => return null
    }
    idx += 1
  }

  if (count == 0) Val.Str(pos, "")
  else {
    val out = new java.lang.StringBuilder(requiredChars.toInt)
    out.append(collected(0))
    var k = 1
    while (k < count) {
      out.append(separator)
      out.append(collected(k))
      k += 1
    }
    val joined = out.toString
    if (allAscii) Val.Str.asciiSafe(pos, joined) else Val.Str(pos, joined)
  }
}

def evalRhs(sep: Eval, _arr: Eval, ev: EvalScope, pos: Position): Val = {
val arr = implicitly[ReadWriter[Val.Arr]].apply(_arr.value)
sep.value match {
case sepStr: Val.Str =>
val s = sepStr.str
val len = arr.length
val s = sepStr.str
val repeatedConst = joinRepeatedStringEval(pos, s, arr.constantEval, len)
if (repeatedConst != null) return repeatedConst

Expand All @@ -513,23 +613,36 @@ object StringModule extends AbstractFunctionModule {
if (direct != null) {
val repeated = joinRepeatedDirectString(pos, s, direct, len)
if (repeated != null) return repeated

val joined = joinDirectStringArray(pos, sepStr, direct, len)
if (joined != null) return joined
}

if (len >= PresizedStringJoinMinParts) {
return joinPresizedStringArray(pos, sepStr, arr, len)
}

val b = new java.lang.StringBuilder()
var i = 0
var added = false
var asciiSafe = true
while (i < len) {
arr.value(i) match {
case _: Val.Null =>
case x: Val.Str =>
if (added) b.append(s)
if (added) {
b.append(s)
asciiSafe &&= sepStr._asciiSafe
}
added = true
b.append(x.str)
asciiSafe &&= x._asciiSafe
case x => Error.fail("Cannot join " + x.prettyName)
}
i += 1
}
Val.Str(pos, b.toString)
val result = b.toString
if (asciiSafe) Val.Str.asciiSafe(pos, result) else Val.Str(pos, result)
case sep: Val.Arr =>
val len = arr.length
if (len > PresizedCopyMaxParts) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Directional coverage for std.join string paths:
//   - small array fallback (inline StringBuilder, len < 16)
//   - direct backing array path (joinDirectStringArray)
//   - presized path (len >= 16, joinPresizedStringArray)
//   - asciiSafe propagation (separator and parts both ASCII)
//   - non-ASCII parts, and a non-ASCII separator, that should still join correctly
//   - null skipping at leading, interior, and trailing positions
//   - all-null and empty arrays return the empty string
//   - a single part emits no separator; empty-string parts still get separators
//
// The whole file evaluates to `true`; any failing std.assertEqual aborts evaluation.

local small = std.join("-", ["a", "bb", null, "ccc"]);
local nonAscii = std.join("/", ["é", "λ", null, "🚀"]);
local allNull = std.join("ignored", [null, null]);

// Separator itself is non-ASCII while every part is ASCII.
local nonAsciiSep = std.join("é", ["a", "b", "c"]);

// Nulls at both ends of the array must not produce stray separators.
local edgeNulls = std.join("-", [null, "a", null, "b", null]);

// Exactly one string part: the separator is never emitted.
local single = std.join(", ", [null, "solo", null]);

// Empty strings are real parts (unlike null) and are still separated.
local emptyParts = std.join("-", ["", "a", ""]);

// No elements at all.
local emptyArr = std.join("-", []);

// 20 ASCII parts to force the presized path on a non-direct array.
local many = std.join(
  ", ",
  std.makeArray(20, function(i) std.toString(i)),
);

// 18 ASCII parts on a direct (literal) array to exercise joinDirectStringArray.
local direct18 = std.join(
  "|",
  ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j",
   "k", "l", "m", "n", "o", "p", "q", "r"],
);

// Mixed null + ASCII to exercise size pre-walk skipping.
local mixed = std.join(
  ":",
  std.makeArray(20, function(i) if i % 3 == 0 then null else std.toString(i)),
);

// Same shape as `mixed`, but the final element (index 19) is null as well.
local mixedTrailingNull = std.join(
  ":",
  std.makeArray(20, function(i) if i % 3 == 0 || i == 19 then null else std.toString(i)),
);

std.assertEqual(small, "a-bb-ccc") &&
std.assertEqual(nonAscii, "é/λ/🚀") &&
std.assertEqual(allNull, "") &&
std.assertEqual(nonAsciiSep, "aébéc") &&
std.assertEqual(edgeNulls, "a-b") &&
std.assertEqual(single, "solo") &&
std.assertEqual(emptyParts, "-a-") &&
std.assertEqual(emptyArr, "") &&
std.assertEqual(many, "0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19") &&
std.assertEqual(direct18, "a|b|c|d|e|f|g|h|i|j|k|l|m|n|o|p|q|r") &&
std.assertEqual(mixed, "1:2:4:5:7:8:10:11:13:14:16:17:19") &&
std.assertEqual(mixedTrailingNull, "1:2:4:5:7:8:10:11:13:14:16:17") &&
true
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
true
Loading