// ProtoInputStream.java
package de.schegge.rosinante.io;
import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
/**
 * An {@link InputStream} decorator that decodes protocol-buffer wire-format values: field tags,
 * base-128 varints (plain and zigzag), booleans, length-delimited strings/bytes, and packed
 * repeated varints.
 *
 * <p>The field number of the most recent tag is kept in plain (unsynchronized) mutable state.
 */
public class ProtoInputStream extends FilterInputStream {

  /** Maximum number of bytes a well-formed varint may occupy. */
  private static final int MAX_VARINT_BYTES = 10;

  /**
   * Cached {@code WireType.values()}, indexed by the 3-bit wire-type id of a tag. Assumes the
   * {@code WireType} constants are declared in wire-type id order — TODO confirm.
   */
  private static final WireType[] WIRE_TYPES = WireType.values();

  /** Field number decoded from the most recent tag read by {@link #readType()}. */
  private int lastFieldNumber;

  /**
   * Creates a decoder reading from {@code in}.
   *
   * @param in the stream to decode from
   */
  public ProtoInputStream(InputStream in) {
    super(in);
  }

  /**
   * Reads the next field tag, records its field number (see {@link #getLastFieldNumber()}) and
   * returns its wire type.
   *
   * @return the tag's wire type, or {@code null} if the stream ends before or inside the tag
   * @throws UncheckedIOException if the underlying stream fails, the tag varint is malformed, or
   *     the tag's wire-type id has no corresponding {@link WireType} constant
   */
  public WireType readType() {
    int tag;
    try {
      tag = readVariableByteInt();
    } catch (EOFException e) {
      // Running out of input is how a message ends; report it as null.
      return null;
    } catch (IOException e) {
      // A genuine I/O or decoding failure must not masquerade as a clean end of message.
      throw new UncheckedIOException(e);
    }
    // Tags are unsigned 32-bit values; an unsigned shift keeps large field numbers positive.
    lastFieldNumber = tag >>> 3;
    int wireTypeId = tag & 0b111;
    if (wireTypeId >= WIRE_TYPES.length) {
      throw new UncheckedIOException(new IOException("unknown wire type: " + wireTypeId));
    }
    return WIRE_TYPES[wireTypeId];
  }

  /**
   * Reads a base-128 varint and returns its low 32 bits, so a sign-extended negative value
   * decodes back to the negative {@code int}.
   *
   * @return the low 32 bits of the decoded varint
   * @throws EOFException if the stream ends inside the varint
   * @throws IOException if the varint exceeds ten bytes or the stream fails
   */
  protected int readVariableByteInt() throws IOException {
    return (int) readVarint();
  }

  /**
   * Reads a base-128 varint as a 64-bit value.
   *
   * @return the decoded varint (never {@code null})
   * @throws EOFException if the stream ends inside the varint
   * @throws IOException if the varint exceeds ten bytes or the stream fails
   */
  protected Long readVariableByteLong() throws IOException {
    return readVarint();
  }

  /** Decodes one varint of at most {@value #MAX_VARINT_BYTES} bytes into a {@code long}. */
  private long readVarint() throws IOException {
    long result = 0;
    for (int i = 0; i < MAX_VARINT_BYTES; i++) {
      int value = read();
      if (value == -1) {
        throw new EOFException();
      }
      // 7 payload bits per byte, least-significant group first; bits shifted past 63 are dropped.
      result |= (long) (value & 0x7F) << (7 * i);
      if ((value & 0x80) == 0) {
        return result;
      }
    }
    throw new IOException("malformed varint: more than " + MAX_VARINT_BYTES + " bytes");
  }

  /** Reads a varint length prefix and rejects negative values. */
  private int readLength() throws IOException {
    int length = readVariableByteInt();
    if (length < 0) {
      throw new IOException("negative length: " + length);
    }
    return length;
  }

  /** Reads a length prefix followed by exactly that many bytes. */
  private byte[] readLengthDelimited() throws IOException {
    int length = readLength();
    byte[] data = readNBytes(length);
    if (data.length < length) {
      // readNBytes returns a short array at end of stream instead of failing.
      throw new EOFException("expected " + length + " bytes, got " + data.length);
    }
    return data;
  }

  /**
   * Reads either a single varint ({@code VARINT}) or a packed run of varints ({@code LEN}) and
   * converts each 32-bit value with {@code converter}.
   */
  private <T> List<T> readIntegers(WireType wireType, Function<Integer, T> converter)
      throws IOException {
    if (wireType == WireType.VARINT) {
      return List.of(converter.apply(readVariableByteInt()));
    }
    if (wireType == WireType.LEN) {
      int length = readLength();
      List<T> result = new ArrayList<>();
      if (length == 0) {
        // An empty packed field holds no elements; do not read past it.
        return result;
      }
      // NOTE(review): assumes LimitedProtoInputStream exposes at most `length` bytes, reports the
      // bytes left via available(), and does not close this stream on close() (an inherited
      // FilterInputStream.close() would) — confirm.
      try (LimitedProtoInputStream limited = new LimitedProtoInputStream(this, length)) {
        do {
          result.add(converter.apply(limited.readVariableByteInt()));
        } while (limited.available() > 0);
      }
      return result;
    }
    throw new IOException("wrong type: " + wireType);
  }

  /**
   * Reads either a single varint ({@code VARINT}) or a packed run of varints ({@code LEN}) as
   * 64-bit values and converts each with {@code converter}.
   */
  private List<Long> readLongs(WireType wireType, Function<Long, Long> converter)
      throws IOException {
    if (wireType == WireType.VARINT) {
      return List.of(converter.apply(readVariableByteLong()));
    }
    if (wireType == WireType.LEN) {
      int length = readLength();
      List<Long> result = new ArrayList<>();
      if (length == 0) {
        // An empty packed field holds no elements; do not read past it.
        return result;
      }
      // NOTE(review): same LimitedProtoInputStream assumptions as in readIntegers — confirm.
      try (LimitedProtoInputStream limited = new LimitedProtoInputStream(this, length)) {
        do {
          result.add(converter.apply(limited.readVariableByteLong()));
        } while (limited.available() > 0);
      }
      return result;
    }
    throw new IOException("wrong type: " + wireType);
  }

  /**
   * Reads a single {@code VARINT} value as an {@code int}.
   *
   * @param wireType the wire type of the current field; must be {@code VARINT}
   * @return the low 32 bits of the varint
   * @throws IOException on a wire-type mismatch, malformed input or stream failure
   */
  public int readInteger(WireType wireType) throws IOException {
    validate(wireType, WireType.VARINT);
    return readVariableByteInt();
  }

  /**
   * Reads a single {@code VARINT} value as a {@code long}.
   *
   * @param wireType the wire type of the current field; must be {@code VARINT}
   * @return the decoded varint
   * @throws IOException on a wire-type mismatch, malformed input or stream failure
   */
  public long readLong(WireType wireType) throws IOException {
    validate(wireType, WireType.VARINT);
    return readVariableByteLong();
  }

  /**
   * Reads a single zigzag-encoded {@code VARINT} value as a signed {@code int}.
   *
   * @param wireType the wire type of the current field; must be {@code VARINT}
   * @return the zigzag-decoded value
   * @throws IOException on a wire-type mismatch, malformed input or stream failure
   */
  public int readZigZagInteger(WireType wireType) throws IOException {
    validate(wireType, WireType.VARINT);
    int value = readVariableByteInt();
    return (value >>> 1) ^ -(value & 1);
  }

  /**
   * Reads a single zigzag-encoded {@code VARINT} value as a signed {@code long}.
   *
   * @param wireType the wire type of the current field; must be {@code VARINT}
   * @return the zigzag-decoded value
   * @throws IOException on a wire-type mismatch, malformed input or stream failure
   */
  public long readZigZagLong(WireType wireType) throws IOException {
    validate(wireType, WireType.VARINT);
    long value = readVariableByteLong();
    return (value >>> 1) ^ -(value & 1);
  }

  /**
   * Reads a single {@code VARINT} value as a boolean; any nonzero value is {@code true}.
   *
   * @param wireType the wire type of the current field; must be {@code VARINT}
   * @return {@code false} for 0, {@code true} otherwise
   * @throws IOException on a wire-type mismatch, malformed input or stream failure
   */
  public boolean readBoolean(WireType wireType) throws IOException {
    validate(wireType, WireType.VARINT);
    return readVariableByteInt() != 0;
  }

  /**
   * Reads a length-delimited UTF-8 string.
   *
   * @param wireType the wire type of the current field; must be {@code LEN}
   * @return the decoded string
   * @throws EOFException if the stream ends before the declared length
   * @throws IOException on a wire-type mismatch, negative length or stream failure
   */
  public String readString(WireType wireType) throws IOException {
    validate(wireType, WireType.LEN);
    return new String(readLengthDelimited(), StandardCharsets.UTF_8);
  }

  /**
   * Reads a length-delimited byte array.
   *
   * @param wireType the wire type of the current field; must be {@code LEN}
   * @return exactly the declared number of bytes
   * @throws EOFException if the stream ends before the declared length
   * @throws IOException on a wire-type mismatch, negative length or stream failure
   */
  public byte[] readBytes(WireType wireType) throws IOException {
    validate(wireType, WireType.LEN);
    return readLengthDelimited();
  }

  /**
   * Returns the field number of the tag most recently read by {@link #readType()}.
   *
   * @return the last field number, or 0 if no tag has been read yet
   */
  public int getLastFieldNumber() {
    return lastFieldNumber;
  }

  /** Throws if the field's actual wire type differs from the one the caller expects. */
  private void validate(WireType actual, WireType expected) throws IOException {
    if (actual != expected) {
      throw new IOException("wrong type: " + actual);
    }
  }

  /**
   * Reads a single or packed repeated {@code int} field.
   *
   * @param wireType {@code VARINT} for one value or {@code LEN} for a packed run
   * @return the decoded values
   * @throws IOException on an unsupported wire type, malformed input or stream failure
   */
  public List<Integer> readIntegers(WireType wireType) throws IOException {
    return readIntegers(wireType, Function.identity());
  }

  /**
   * Reads a single or packed repeated zigzag-encoded {@code int} field.
   *
   * @param wireType {@code VARINT} for one value or {@code LEN} for a packed run
   * @return the zigzag-decoded values
   * @throws IOException on an unsupported wire type, malformed input or stream failure
   */
  public List<Integer> readZigZagIntegers(WireType wireType) throws IOException {
    return readIntegers(wireType, value -> (value >>> 1) ^ -(value & 1));
  }

  /**
   * Reads a single or packed repeated {@code long} field.
   *
   * @param wireType {@code VARINT} for one value or {@code LEN} for a packed run
   * @return the decoded values
   * @throws IOException on an unsupported wire type, malformed input or stream failure
   */
  public List<Long> readLongs(WireType wireType) throws IOException {
    return readLongs(wireType, Function.identity());
  }

  /**
   * Reads a single or packed repeated zigzag-encoded {@code long} field.
   *
   * @param wireType {@code VARINT} for one value or {@code LEN} for a packed run
   * @return the zigzag-decoded values
   * @throws IOException on an unsupported wire type, malformed input or stream failure
   */
  public List<Long> readZigZagLongs(WireType wireType) throws IOException {
    return readLongs(wireType, value -> (value >>> 1) ^ -(value & 1));
  }

  /**
   * Reads a single or packed repeated boolean field; any nonzero value is {@code true}.
   *
   * @param wireType {@code VARINT} for one value or {@code LEN} for a packed run
   * @return the decoded booleans
   * @throws IOException on an unsupported wire type, malformed input or stream failure
   */
  public List<Boolean> readBooleans(WireType wireType) throws IOException {
    return readIntegers(wireType, value -> value != 0);
  }
}