Skip to content

Use Vector API in the Java Extension #824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
27 changes: 23 additions & 4 deletions Rakefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ end rescue nil
JAVA_DIR = "java/src/json/ext"
JAVA_RAGEL_PATH = "#{JAVA_DIR}/ParserConfig.rl"
JAVA_PARSER_SRC = "#{JAVA_DIR}/ParserConfig.java"
JAVA_SOURCES = FileList["#{JAVA_DIR}/*.java"]
JAVA_SOURCES = FileList["#{JAVA_DIR}/*.java"].exclude("#{JAVA_DIR}/Vectorized*.java")
JAVA_VEC_SOURCES = FileList["#{JAVA_DIR}/Vectorized*.java"]
JAVA_CLASSES = []
JRUBY_PARSER_JAR = File.expand_path("lib/json/ext/parser.jar")
JRUBY_GENERATOR_JAR = File.expand_path("lib/json/ext/generator.jar")
Expand Down Expand Up @@ -65,11 +66,26 @@ if defined?(RUBY_ENGINE) and RUBY_ENGINE == 'jruby'

JRUBY_JAR = File.join(CONFIG["libdir"], "jruby.jar")
if File.exist?(JRUBY_JAR)
classpath = (Dir['java/lib/*.jar'] << 'java/src' << JRUBY_JAR) * path_separator
JAVA_SOURCES.each do |src|
classpath = (Dir['java/lib/*.jar'] << 'java/src' << JRUBY_JAR) * path_separator
obj = src.sub(/\.java\Z/, '.class')
file obj => src do
sh 'javac', '-classpath', classpath, '-source', '1.8', '-target', '1.8', src
sh 'javac', '-classpath', classpath, '-source', '1.8', '-target', '1.8', src
# '--enable-preview',
end
JAVA_CLASSES << obj
end

JAVA_VEC_SOURCES.each do |src|
obj = src.sub(/\.java\Z/, '.class')
file obj => src do
sh 'javac', '--add-modules', 'jdk.incubator.vector', '-classpath', classpath, '--release', '16', src do |success, status|
if success
puts "*** 'jdk.incubator.vector' support enabled ***"
else
puts "*** 'jdk.incubator.vector' support disabled ***"
end
end
end
JAVA_CLASSES << obj
end
Expand Down Expand Up @@ -118,11 +134,14 @@ if defined?(RUBY_ENGINE) and RUBY_ENGINE == 'jruby'
generator_classes = FileList[
"json/ext/ByteList*.class",
"json/ext/OptionsReader*.class",
"json/ext/EscapeScanner*.class",
"json/ext/Generator*.class",
"json/ext/RuntimeInfo*.class",
"json/ext/StringEncoder*.class",
"json/ext/Utils*.class"
"json/ext/Utils*.class",
"json/ext/VectorizedEscapeScanner*.class"
]
puts "Creating generator jar with classes: #{generator_classes.join(', ')}"
sh 'jar', 'cf', File.basename(JRUBY_GENERATOR_JAR), *generator_classes
mv File.basename(JRUBY_GENERATOR_JAR), File.dirname(JRUBY_GENERATOR_JAR)
end
Expand Down
102 changes: 102 additions & 0 deletions java/src/json/ext/EscapeScanner.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package json.ext;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;

interface EscapeScanner {
static class State {
byte[] ptrBytes;
int ptr;
int len;
int pos;
int beg;
int ch;
}

static class VectorSupport {
private static String VECTORIZED_ESCAPE_SCANNER_CLASS = "json.ext.VectorizedEscapeScanner";
private static String VECTORIZED_SCANNER_PROP = "json.enableVectorizedEscapeScanner";
private static String VECTORIZED_SCANNER_DEFAULT = "false";
static final EscapeScanner VECTORIZED_ESCAPE_SCANNER;

static {
EscapeScanner scanner = null;
String enableVectorizedScanner = System.getProperty(VECTORIZED_SCANNER_PROP, VECTORIZED_SCANNER_DEFAULT);
if ("true".equalsIgnoreCase(enableVectorizedScanner) || "1".equalsIgnoreCase(enableVectorizedScanner)) {
try {
Class<?> vectorEscapeScannerClass = EscapeScanner.class.getClassLoader().loadClass(VECTORIZED_ESCAPE_SCANNER_CLASS);
Constructor<?> vectorizedEscapeScannerConstructor = vectorEscapeScannerClass.getDeclaredConstructor();
scanner = (EscapeScanner) vectorizedEscapeScannerConstructor.newInstance();
} catch (ClassNotFoundException | NoSuchMethodException | InstantiationException | IllegalAccessException | InvocationTargetException e) {
// Fallback to the ScalarEscapeScanner if we cannot load the VectorizedEscapeScanner.
System.err.println("Failed to load VectorizedEscapeScanner, falling back to ScalarEscapeScanner:");
e.printStackTrace();
scanner = null;
}
} else {
System.err.println("VectorizedEscapeScanner disabled.");
}
VECTORIZED_ESCAPE_SCANNER = scanner;
}
}

boolean scan(EscapeScanner.State state) throws java.io.IOException;

default State createState(byte[] ptrBytes, int ptr, int len, int beg) {
State state = new State();
state.ptrBytes = ptrBytes;
state.ptr = ptr;
state.len = len;
state.beg = beg;
state.pos = 0; // Start scanning from the beginning of the segment
return state;
}

public static EscapeScanner basicScanner() {
if (VectorSupport.VECTORIZED_ESCAPE_SCANNER != null) {
return VectorSupport.VECTORIZED_ESCAPE_SCANNER;
}

return new ScalarEscapeScanner(StringEncoder.ESCAPE_TABLE);
}

public static EscapeScanner create(byte[] escapeTable) {
return new ScalarEscapeScanner(escapeTable);
}

public static class BasicScanner implements EscapeScanner {
@Override
public boolean scan(EscapeScanner.State state) throws java.io.IOException {
while (state.pos < state.len) {
state.ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
if (state.ch >= 0 && (state.ch < ' ' || state.ch == '\"' || state.ch == '\\')) {
return true;
}
state.pos++;
}
return false;
}
}

public static class ScalarEscapeScanner implements EscapeScanner {
private final byte[] escapeTable;

public ScalarEscapeScanner(byte[] escapeTable) {
this.escapeTable = escapeTable;
}

@Override
public boolean scan(EscapeScanner.State state) throws java.io.IOException {
while (state.pos < state.len) {
state.ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
int ch_len = escapeTable[state.ch];
if (ch_len > 0) {
return true;
}
state.pos++;
}
return false;
}

}
}
2 changes: 1 addition & 1 deletion java/src/json/ext/Generator.java
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ public StringEncoder getStringEncoder(ThreadContext context) {
GeneratorState state = getState(context);
stringEncoder = state.asciiOnly() ?
new StringEncoderAsciiOnly(state.scriptSafe()) :
new StringEncoder(state.scriptSafe());
state.scriptSafe() ? StringEncoder.scriptSafeEncoder() : StringEncoder.basicEncoder();
}
return stringEncoder;
}
Expand Down
99 changes: 75 additions & 24 deletions java/src/json/ext/StringEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
*/
package json.ext;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;

import org.jcodings.Encoding;
import org.jcodings.specific.ASCIIEncoding;
import org.jcodings.specific.USASCIIEncoding;
Expand All @@ -17,10 +21,6 @@
import org.jruby.util.ByteList;
import org.jruby.util.StringSupport;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;

/**
* An encoder that reads from the given source and outputs its representation
* to another ByteList. The source string is fully checked for UTF-8 validity,
Expand Down Expand Up @@ -130,14 +130,22 @@ class StringEncoder extends ByteListTranscoder {
new byte[] {'0', '1', '2', '3', '4', '5', '6', '7',
'8', '9', 'a', 'b', 'c', 'd', 'e', 'f'};

StringEncoder(boolean scriptSafe) {
private StringEncoder(boolean scriptSafe) {
this(scriptSafe ? SCRIPT_SAFE_ESCAPE_TABLE : ESCAPE_TABLE);
}

StringEncoder(byte[] escapeTable) {
this.escapeTable = escapeTable;
}

public static StringEncoder scriptSafeEncoder() {
return new StringEncoder(SCRIPT_SAFE_ESCAPE_TABLE);
}

public static StringEncoder basicEncoder() {
return new StringEncoder(ESCAPE_TABLE);
}

// C: generate_json_string
void generate(ThreadContext context, RubyString object, OutputStream buffer) throws IOException {
object = ensureValidEncoding(context, object);
Expand Down Expand Up @@ -198,41 +206,83 @@ private static RubyString tryWeirdEncodings(ThreadContext context, RubyString st
return str;
}

boolean searchEscape(EscapeScanner.State state) throws IOException {
byte[] escapeTable = StringEncoder.this.escapeTable;

while (state.pos < state.len) {
state.ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
int ch_len = escapeTable[state.ch];

if (ch_len > 0) {
return true;
}

state.pos++;
}

return false;
}

void encodeBasic(ByteList src) throws IOException {
byte[] hexdig = HEX;
byte[] scratch = aux;

EscapeScanner scanner = EscapeScanner.basicScanner();
EscapeScanner.State state = scanner.createState(src.unsafeBytes(), src.begin(), src.realSize(), 0);

while(scanner.scan(state)) {
int ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 1);
escapeAscii(ch, scratch, hexdig);
}

if (state.beg < state.len) {
append(state.ptrBytes, state.ptr + state.beg, state.len - state.beg);
}
}

// C: convert_UTF8_to_JSON
void encode(ByteList src) throws IOException {
if (this.escapeTable == StringEncoder.ESCAPE_TABLE) {
encodeBasic(src);
return;
}

byte[] hexdig = HEX;
byte[] scratch = aux;
byte[] escapeTable = this.escapeTable;

byte[] ptrBytes = src.unsafeBytes();
int ptr = src.begin();
int len = src.realSize();

int beg = 0;
int pos = 0;

while (pos < len) {
int ch = Byte.toUnsignedInt(ptrBytes[ptr + pos]);
EscapeScanner.State state = new EscapeScanner.State();
state.ptrBytes = src.unsafeBytes();
state.ptr = src.begin();
state.len = src.realSize();
state.beg = 0;
state.pos = 0;

while(searchEscape(state)) {
// We found an escape character, so we need to flush up to this point
// and then handle the escape character.
state.beg = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 0);
int ch = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos]);
int ch_len = escapeTable[ch];
/* JSON encoding */

if (ch_len > 0) {
switch (ch_len) {
case 9: {
beg = pos = flushPos(pos, beg, ptrBytes, ptr, 1);
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 1);
escapeAscii(ch, scratch, hexdig);
break;
}
case 11: {
int b2 = Byte.toUnsignedInt(ptrBytes[ptr + pos + 1]);
int b2 = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos + 1]);
if (b2 == 0x80) {
int b3 = Byte.toUnsignedInt(ptrBytes[ptr + pos + 2]);
int b3 = Byte.toUnsignedInt(state.ptrBytes[state.ptr + state.pos + 2]);
if (b3 == 0xA8) {
beg = pos = flushPos(pos, beg, ptrBytes, ptr, 3);
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 3);
append(BACKSLASH_U2028, 0, 6);
break;
} else if (b3 == 0xA9) {
beg = pos = flushPos(pos, beg, ptrBytes, ptr, 3);
state.beg = state.pos = flushPos(state.pos, state.beg, state.ptrBytes, state.ptr, 3);
append(BACKSLASH_U2029, 0, 6);
break;
}
Expand All @@ -241,16 +291,17 @@ void encode(ByteList src) throws IOException {
// fallthrough
}
default:
pos += ch_len;
state.pos += ch_len;
break;
}
} else {
pos++;
// This should be unreachable.
state.pos++;
}
}

if (beg < len) {
append(ptrBytes, ptr + beg, len - beg);
if (state.beg < state.len) {
append(state.ptrBytes, state.ptr + state.beg, state.len - state.beg);
}
}

Expand Down
Loading
Loading