// ProtoInputStream.java
package de.schegge.rosinante.io;
import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
/**
 * Stream decorator that decodes the primitive elements of a Protocol Buffers style wire format
 * (tags, varints, little-endian fixed-width values and length-delimited payloads) from an
 * underlying {@link InputStream}.
 *
 * <p>Callers obtain a {@link WireType} from {@link #readType()} and hand it to one of the
 * {@code read...(WireType)} methods, each of which verifies that the wire type is one it can decode.
 *
 * <p>The class keeps mutable state (the field number of the last tag) and performs no
 * synchronization.
 */
public class ProtoInputStream extends FilterInputStream {

    /** A 64-bit value split into 7-bit groups occupies at most ten bytes. */
    private static final int MAX_VARINT_BYTES = 10;

    /** Number of low tag bits that carry the wire type. */
    private static final int WIRE_TYPE_BITS = 3;

    /** Mask selecting the wire type bits of a tag. */
    private static final int WIRE_TYPE_MASK = 0b00000111;

    /** Field number extracted from the most recently read tag. */
    private int lastFieldNumber;

    /**
     * Creates a decoder on top of the given stream.
     *
     * @param in the stream to read encoded data from
     */
    public ProtoInputStream(InputStream in) {
        super(in);
    }

    /**
     * Reads the next tag, remembers its field number and returns its wire type.
     *
     * <p>The wire type is looked up by indexing {@code WireType.values()} with the low three bits
     * of the tag. NOTE(review): this assumes the enum constants are declared in wire-id order and
     * that every id which can occur has a constant; otherwise an
     * {@link ArrayIndexOutOfBoundsException} escapes - confirm against {@code WireType}.
     *
     * @return the wire type of the tag, or {@code null} when the tag could not be read because of
     *     end of input or any other {@link IOException}
     */
    public WireType readType() {
        try {
            int tag = readVariableByteInt();
            // Unsigned shift: a tag with bit 31 set must not produce a negative field number.
            lastFieldNumber = tag >>> WIRE_TYPE_BITS;
            return WireType.values()[tag & WIRE_TYPE_MASK];
        } catch (IOException ignored) {
            // The signature carries no checked exception, so failure to read a tag is reported
            // as null; callers cannot distinguish a clean end of stream from an I/O error.
            return null;
        }
    }

    /**
     * Reads a varint and returns its low 32 bits.
     *
     * @return the decoded value truncated to 32 bits
     * @throws EOFException if the stream ends inside the varint
     * @throws IOException if the varint is longer than ten bytes or the underlying read fails
     */
    protected int readVariableByteInt() throws IOException {
        return (int) readRawVarint();
    }

    /**
     * Reads a varint as a 64-bit value.
     *
     * <p>The boxed return type is kept so that existing overriding or calling code keeps compiling.
     *
     * @return the decoded value
     * @throws EOFException if the stream ends inside the varint
     * @throws IOException if the varint is longer than ten bytes or the underlying read fails
     */
    protected Long readVariableByteLong() throws IOException {
        return readRawVarint();
    }

    /**
     * Reads four bytes and assembles them least-significant byte first.
     *
     * @return the decoded 32-bit value
     * @throws EOFException if fewer than four bytes are available
     * @throws IOException if the underlying read fails
     */
    protected int readLittleEndian32() throws IOException {
        int result = 0;
        for (int shift = 0; shift < Integer.SIZE; shift += Byte.SIZE) {
            result |= readRequiredByte() << shift;
        }
        return result;
    }

    /**
     * Reads eight bytes and assembles them least-significant byte first.
     *
     * @return the decoded 64-bit value
     * @throws EOFException if fewer than eight bytes are available
     * @throws IOException if the underlying read fails
     */
    protected long readLittleEndian64() throws IOException {
        long result = 0L;
        for (int shift = 0; shift < Long.SIZE; shift += Byte.SIZE) {
            result |= (long) readRequiredByte() << shift;
        }
        return result;
    }

    /**
     * Reads a single varint field as a 32-bit integer.
     *
     * @param wireType the wire type of the field, must be {@code VARINT}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public int readInteger(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return readVariableByteInt();
    }

    /**
     * Reads a single varint field as a 64-bit integer.
     *
     * @param wireType the wire type of the field, must be {@code VARINT}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public long readLong(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return readRawVarint();
    }

    /**
     * Reads a single zigzag-encoded 32-bit integer.
     *
     * @param wireType the wire type of the field, must be {@code VARINT}
     * @return the decoded signed value
     * @throws IOException if the wire type does not match or reading fails
     */
    public int readZigZagInteger(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return decodeZigZag32(readVariableByteInt());
    }

    /**
     * Reads a single zigzag-encoded 64-bit integer.
     *
     * @param wireType the wire type of the field, must be {@code VARINT}
     * @return the decoded signed value
     * @throws IOException if the wire type does not match or reading fails
     */
    public long readZigZagLong(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return decodeZigZag64(readRawVarint());
    }

    /**
     * Reads a single boolean field; every non-zero varint is interpreted as {@code true}.
     *
     * @param wireType the wire type of the field, must be {@code VARINT}
     * @return {@code false} for a zero varint, {@code true} otherwise
     * @throws IOException if the wire type does not match or reading fails
     */
    public boolean readBoolean(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        // Compare the full 64-bit value so that set high bits are not lost by truncation.
        return readRawVarint() != 0L;
    }

    /**
     * Reads a length-delimited field and decodes its payload as UTF-8 text.
     *
     * @param wireType the wire type of the field, must be {@code LEN}
     * @return the decoded string
     * @throws EOFException if the stream ends before the announced number of bytes was read
     * @throws IOException if the wire type does not match, the length is invalid or reading fails
     */
    public String readString(WireType wireType) throws IOException {
        validate(wireType, WireType.LEN);
        return new String(readExactly(readLength()), StandardCharsets.UTF_8);
    }

    /**
     * Reads a length-delimited field and returns its raw payload.
     *
     * @param wireType the wire type of the field, must be {@code LEN}
     * @return the payload bytes
     * @throws EOFException if the stream ends before the announced number of bytes was read
     * @throws IOException if the wire type does not match, the length is invalid or reading fails
     */
    public byte[] readBytes(WireType wireType) throws IOException {
        validate(wireType, WireType.LEN);
        return readExactly(readLength());
    }

    /**
     * Reads a single 32-bit floating point field.
     *
     * @param wireType the wire type of the field, must be {@code I32}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public float readFloat(WireType wireType) throws IOException {
        validate(wireType, WireType.I32);
        return Float.intBitsToFloat(readLittleEndian32());
    }

    /**
     * Reads a single 64-bit floating point field.
     *
     * @param wireType the wire type of the field, must be {@code I64}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public double readDouble(WireType wireType) throws IOException {
        validate(wireType, WireType.I64);
        return Double.longBitsToDouble(readLittleEndian64());
    }

    /**
     * Reads a single fixed-width 32-bit field.
     *
     * @param wireType the wire type of the field, must be {@code I32}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public int readFixed32(WireType wireType) throws IOException {
        validate(wireType, WireType.I32);
        return readLittleEndian32();
    }

    /**
     * Reads a single fixed-width 64-bit field.
     *
     * @param wireType the wire type of the field, must be {@code I64}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public long readFixed64(WireType wireType) throws IOException {
        validate(wireType, WireType.I64);
        return readLittleEndian64();
    }

    /**
     * Reads a single signed fixed-width 32-bit field; decoding is identical to
     * {@link #readFixed32(WireType)}.
     *
     * @param wireType the wire type of the field, must be {@code I32}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public int readSFixed32(WireType wireType) throws IOException {
        return readFixed32(wireType);
    }

    /**
     * Reads a single signed fixed-width 64-bit field; decoding is identical to
     * {@link #readFixed64(WireType)}.
     *
     * @param wireType the wire type of the field, must be {@code I64}
     * @return the decoded value
     * @throws IOException if the wire type does not match or reading fails
     */
    public long readSFixed64(WireType wireType) throws IOException {
        return readFixed64(wireType);
    }

    /**
     * Returns the field number of the tag most recently read by {@link #readType()}.
     *
     * @return the last field number, {@code 0} if no tag has been read yet
     */
    public int getLastFieldNumber() {
        return lastFieldNumber;
    }

    /**
     * Reads 32-bit varints, either a single value ({@code VARINT}) or a packed run ({@code LEN}).
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code VARINT} nor {@code LEN} or reading fails
     */
    public List<Integer> readIntegers(WireType wireType) throws IOException {
        return readVarint32Elements(wireType, Function.identity());
    }

    /**
     * Reads zigzag-encoded 32-bit integers, single or packed.
     *
     * @param wireType the wire type of the field
     * @return the decoded signed values
     * @throws IOException if the wire type is neither {@code VARINT} nor {@code LEN} or reading fails
     */
    public List<Integer> readZigZagIntegers(WireType wireType) throws IOException {
        return readVarint32Elements(wireType, ProtoInputStream::decodeZigZag32);
    }

    /**
     * Reads 64-bit varints, single or packed.
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code VARINT} nor {@code LEN} or reading fails
     */
    public List<Long> readLongs(WireType wireType) throws IOException {
        return readVarint64Elements(wireType, Function.identity());
    }

    /**
     * Reads zigzag-encoded 64-bit integers, single or packed.
     *
     * @param wireType the wire type of the field
     * @return the decoded signed values
     * @throws IOException if the wire type is neither {@code VARINT} nor {@code LEN} or reading fails
     */
    public List<Long> readZigZagLongs(WireType wireType) throws IOException {
        return readVarint64Elements(wireType, ProtoInputStream::decodeZigZag64);
    }

    /**
     * Reads booleans, single or packed; every non-zero varint is interpreted as {@code true}.
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code VARINT} nor {@code LEN} or reading fails
     */
    public List<Boolean> readBooleans(WireType wireType) throws IOException {
        return readVarint64Elements(wireType, value -> value != 0L);
    }

    /**
     * Reads 32-bit floating point values, single ({@code I32}) or packed ({@code LEN}).
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code I32} nor {@code LEN} or reading fails
     */
    public List<Float> readFloats(WireType wireType) throws IOException {
        return readFixed32Elements(wireType, Float::intBitsToFloat);
    }

    /**
     * Reads 64-bit floating point values, single ({@code I64}) or packed ({@code LEN}).
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code I64} nor {@code LEN} or reading fails
     */
    public List<Double> readDoubles(WireType wireType) throws IOException {
        return readFixed64Elements(wireType, Double::longBitsToDouble);
    }

    /**
     * Reads fixed-width 32-bit values, single ({@code I32}) or packed ({@code LEN}).
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code I32} nor {@code LEN} or reading fails
     */
    public List<Integer> readFixed32s(WireType wireType) throws IOException {
        return readFixed32Elements(wireType, Function.identity());
    }

    /**
     * Reads fixed-width 64-bit values, single ({@code I64}) or packed ({@code LEN}).
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code I64} nor {@code LEN} or reading fails
     */
    public List<Long> readFixed64s(WireType wireType) throws IOException {
        return readFixed64Elements(wireType, Function.identity());
    }

    /**
     * Reads signed fixed-width 32-bit values; decoding is identical to
     * {@link #readFixed32s(WireType)}.
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code I32} nor {@code LEN} or reading fails
     */
    public List<Integer> readSFixed32s(WireType wireType) throws IOException {
        return readFixed32s(wireType);
    }

    /**
     * Reads signed fixed-width 64-bit values; decoding is identical to
     * {@link #readFixed64s(WireType)}.
     *
     * @param wireType the wire type of the field
     * @return the decoded values
     * @throws IOException if the wire type is neither {@code I64} nor {@code LEN} or reading fails
     */
    public List<Long> readSFixed64s(WireType wireType) throws IOException {
        return readFixed64s(wireType);
    }

    /** Decodes one or many 32-bit varints and maps each through {@code converter}. */
    private <T> List<T> readVarint32Elements(WireType wireType, Function<Integer, T> converter)
            throws IOException {
        if (wireType == WireType.VARINT) {
            return List.of(converter.apply(readVariableByteInt()));
        }
        if (wireType == WireType.LEN) {
            int length = readLength();
            List<T> result = new ArrayList<>();
            if (length == 0) {
                // An empty packed payload holds no elements; do not attempt to read one.
                return result;
            }
            try (LimitedProtoInputStream limited = new LimitedProtoInputStream(this, length)) {
                do {
                    result.add(converter.apply(limited.readVariableByteInt()));
                } while (limited.available() > 0);
            }
            return result;
        }
        throw new IOException("wrong type: " + wireType);
    }

    /** Decodes one or many 64-bit varints and maps each through {@code converter}. */
    private <T> List<T> readVarint64Elements(WireType wireType, Function<Long, T> converter)
            throws IOException {
        if (wireType == WireType.VARINT) {
            return List.of(converter.apply(readVariableByteLong()));
        }
        if (wireType == WireType.LEN) {
            int length = readLength();
            List<T> result = new ArrayList<>();
            if (length == 0) {
                // An empty packed payload holds no elements; do not attempt to read one.
                return result;
            }
            try (LimitedProtoInputStream limited = new LimitedProtoInputStream(this, length)) {
                do {
                    result.add(converter.apply(limited.readVariableByteLong()));
                } while (limited.available() > 0);
            }
            return result;
        }
        throw new IOException("wrong type: " + wireType);
    }

    /** Decodes one or many little-endian 32-bit values and maps each through {@code converter}. */
    private <T> List<T> readFixed32Elements(WireType wireType, Function<Integer, T> converter)
            throws IOException {
        if (wireType == WireType.I32) {
            return List.of(converter.apply(readLittleEndian32()));
        }
        if (wireType == WireType.LEN) {
            int length = readLength();
            List<T> result = new ArrayList<>();
            if (length == 0) {
                // An empty packed payload holds no elements; do not attempt to read one.
                return result;
            }
            try (LimitedProtoInputStream limited = new LimitedProtoInputStream(this, length)) {
                do {
                    result.add(converter.apply(limited.readLittleEndian32()));
                } while (limited.available() > 0);
            }
            return result;
        }
        throw new IOException("wrong type: " + wireType);
    }

    /** Decodes one or many little-endian 64-bit values and maps each through {@code converter}. */
    private <T> List<T> readFixed64Elements(WireType wireType, Function<Long, T> converter)
            throws IOException {
        if (wireType == WireType.I64) {
            return List.of(converter.apply(readLittleEndian64()));
        }
        if (wireType == WireType.LEN) {
            int length = readLength();
            List<T> result = new ArrayList<>();
            if (length == 0) {
                // An empty packed payload holds no elements; do not attempt to read one.
                return result;
            }
            try (LimitedProtoInputStream limited = new LimitedProtoInputStream(this, length)) {
                do {
                    result.add(converter.apply(limited.readLittleEndian64()));
                } while (limited.available() > 0);
            }
            return result;
        }
        throw new IOException("wrong type: " + wireType);
    }

    /** Rejects a field whose wire type differs from the one the calling reader can decode. */
    private void validate(WireType actual, WireType expected) throws IOException {
        if (actual != expected) {
            throw new IOException("wrong type: " + actual);
        }
    }

    /** Reads one byte and turns the end-of-stream marker into an {@link EOFException}. */
    private int readRequiredByte() throws IOException {
        int value = read();
        if (value == -1) {
            throw new EOFException();
        }
        return value;
    }

    /**
     * Decodes a varint of at most {@link #MAX_VARINT_BYTES} bytes: seven payload bits per byte,
     * least significant group first, high bit set on every byte except the last.
     */
    private long readRawVarint() throws IOException {
        long result = 0L;
        for (int index = 0; index < MAX_VARINT_BYTES; index++) {
            int value = readRequiredByte();
            // Bits shifted beyond position 63 are dropped, matching 64-bit wrap-around.
            result |= (long) (value & 0x7F) << (7 * index);
            if ((value & 0x80) == 0) {
                return result;
            }
        }
        throw new IOException("malformed varint: more than " + MAX_VARINT_BYTES + " bytes");
    }

    /** Reads the length prefix of a length-delimited field and rejects values outside {@code int}. */
    private int readLength() throws IOException {
        long length = readRawVarint();
        if (length < 0L || length > Integer.MAX_VALUE) {
            throw new IOException("invalid length: " + length);
        }
        return (int) length;
    }

    /** Reads exactly {@code length} bytes or fails instead of returning a truncated payload. */
    private byte[] readExactly(int length) throws IOException {
        byte[] bytes = readNBytes(length);
        if (bytes.length != length) {
            throw new EOFException();
        }
        return bytes;
    }

    /** Maps a zigzag-encoded 32-bit value back to its signed form. */
    private static int decodeZigZag32(int value) {
        return (value >>> 1) ^ -(value & 1);
    }

    /** Maps a zigzag-encoded 64-bit value back to its signed form. */
    private static long decodeZigZag64(long value) {
        return (value >>> 1) ^ -(value & 1L);
    }
}