// ProtoInputStream.java

package de.schegge.rosinante.io;

import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * Reads protobuf wire-format data from an underlying stream: field tags, base-128 varints,
 * ZigZag varints, little-endian fixed-width values, length-delimited strings/bytes and packed
 * repeated fields.
 *
 * <p>{@link #readType()} reads a tag and records its field number (see
 * {@link #getLastFieldNumber()}); the typed {@code read*} methods take the wire type of the
 * field and reject a mismatching one with an {@link IOException}.
 *
 * <p>This class performs no synchronization; {@link #readType()} mutates instance state.
 */
public class ProtoInputStream extends FilterInputStream {

    /**
     * Cached {@link WireType#values()}, indexed by the low three bits of a tag. Caching avoids
     * allocating a fresh array for every tag read.
     */
    private static final WireType[] WIRE_TYPES = WireType.values();

    /** Field number decoded from the most recent tag read by {@link #readType()}. */
    private int lastFieldNumber;

    /**
     * Creates a reader on top of the given stream.
     *
     * @param in the stream supplying protobuf-encoded bytes
     */
    public ProtoInputStream(InputStream in) {
        super(in);
    }

    /**
     * Reads the next field tag, stores its field number and returns its wire type.
     *
     * @return the wire type encoded in the low three bits of the tag, or {@code null} if the tag
     *     could not be read (end of stream, truncated or malformed tag, or any other
     *     {@link IOException} raised by the underlying stream)
     */
    public WireType readType() {
        try {
            int tag = readVariableByteInt();
            // Unsigned shift: large field numbers (up to 2^29 - 1) set bit 31 of the tag.
            lastFieldNumber = tag >>> 3;
            // NOTE(review): assumes WireType constants are declared in wire-type-number order.
            // Values 6 and 7 are not valid protobuf wire types; if WireType has fewer than
            // eight constants they raise ArrayIndexOutOfBoundsException here.
            return WIRE_TYPES[tag & 0b111];
        } catch (IOException ignored) {
            // Best-effort by design: failures are reported as null rather than thrown,
            // presumably so callers can use null as the end-of-message signal.
            return null;
        }
    }

    /**
     * Reads a base-128 varint and returns its low 32 bits.
     *
     * <p>Protobuf encodes negative {@code int32} values as ten-byte (64-bit) varints; truncating
     * to 32 bits recovers the original value.
     *
     * @return the low 32 bits of the decoded varint
     * @throws EOFException if the stream ends inside the varint
     * @throws IOException if the varint is longer than ten bytes or the stream fails
     */
    protected int readVariableByteInt() throws IOException {
        return (int) readRawVarint64();
    }

    /**
     * Reads a base-128 varint as a 64-bit value.
     *
     * @return the decoded varint (boxed, to keep the established signature)
     * @throws EOFException if the stream ends inside the varint
     * @throws IOException if the varint is longer than ten bytes or the stream fails
     */
    protected Long readVariableByteLong() throws IOException {
        return readRawVarint64();
    }

    /**
     * Decodes one base-128 varint: seven payload bits per byte, least-significant group first,
     * with the high bit of each byte marking a continuation.
     */
    private long readRawVarint64() throws IOException {
        long result = 0;
        // At most ten bytes: shifts 0, 7, ..., 63 cover all 64 bits.
        for (int shift = 0; shift < Long.SIZE; shift += 7) {
            int b = read();
            if (b == -1) {
                throw new EOFException();
            }
            result |= (long) (b & 0x7F) << shift;
            if ((b & 0x80) == 0) {
                return result;
            }
        }
        throw new IOException("malformed varint: more than 10 bytes");
    }

    /**
     * Reads four bytes as a little-endian 32-bit value.
     *
     * @throws EOFException if fewer than four bytes remain
     */
    protected int readLittleEndian32() throws IOException {
        int b0 = read();
        int b1 = read();
        int b2 = read();
        int b3 = read();
        if (b0 == -1 || b1 == -1 || b2 == -1 || b3 == -1) {
            throw new EOFException();
        }
        return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24);
    }

    /**
     * Reads eight bytes as a little-endian 64-bit value.
     *
     * @throws EOFException if fewer than eight bytes remain
     */
    protected long readLittleEndian64() throws IOException {
        long b0 = read();
        long b1 = read();
        long b2 = read();
        long b3 = read();
        long b4 = read();
        long b5 = read();
        long b6 = read();
        long b7 = read();
        if (b0 == -1 || b1 == -1 || b2 == -1 || b3 == -1 || b4 == -1 || b5 == -1 || b6 == -1 || b7 == -1) {
            throw new EOFException();
        }
        return b0 | (b1 << 8) | (b2 << 16) | (b3 << 24) | (b4 << 32) | (b5 << 40) | (b6 << 48) | (b7 << 56);
    }

    /** Reads one element of a packed repeated field from the length-bounded sub-stream. */
    @FunctionalInterface
    private interface ElementReader<T> {
        T read(LimitedProtoInputStream stream) throws IOException;
    }

    /**
     * Reads a length prefix, then decodes elements with {@code elementReader} until the
     * length-bounded sub-stream reports no remaining bytes.
     *
     * <p>NOTE(review): relies on {@code LimitedProtoInputStream.available()} reporting the bytes
     * left in the packed payload — presumably; verify against that class.
     */
    private <T> List<T> readPacked(ElementReader<T> elementReader) throws IOException {
        int length = readLength();
        try (LimitedProtoInputStream limited = new LimitedProtoInputStream(this, length)) {
            List<T> result = new ArrayList<>();
            // A pre-checked loop, so an empty (zero-length) packed field yields an empty list.
            while (limited.available() > 0) {
                result.add(elementReader.read(limited));
            }
            return result;
        }
    }

    /** Reads a single VARINT value or a packed run of varints, each truncated to 32 bits. */
    private <T> List<T> readIntegers(WireType wireType, Function<Integer, T> converter) throws IOException {
        if (wireType == WireType.VARINT) {
            return List.of(converter.apply(readVariableByteInt()));
        }
        if (wireType == WireType.LEN) {
            return readPacked(stream -> converter.apply(stream.readVariableByteInt()));
        }
        throw new IOException("wrong type: " + wireType);
    }

    /** Reads a single VARINT value or a packed run of varints as 64-bit values. */
    private <T> List<T> readLongs(WireType wireType, Function<Long, T> converter) throws IOException {
        if (wireType == WireType.VARINT) {
            return List.of(converter.apply(readVariableByteLong()));
        }
        if (wireType == WireType.LEN) {
            return readPacked(stream -> converter.apply(stream.readVariableByteLong()));
        }
        throw new IOException("wrong type: " + wireType);
    }

    /** Reads a single I32 value or a packed run of little-endian 32-bit values. */
    private <T> List<T> readFixed32Elements(WireType wireType, Function<Integer, T> converter) throws IOException {
        if (wireType == WireType.I32) {
            return List.of(converter.apply(readLittleEndian32()));
        }
        if (wireType == WireType.LEN) {
            return readPacked(stream -> converter.apply(stream.readLittleEndian32()));
        }
        throw new IOException("wrong type: " + wireType);
    }

    /** Reads a single I64 value or a packed run of little-endian 64-bit values. */
    private <T> List<T> readFixed64Elements(WireType wireType, Function<Long, T> converter) throws IOException {
        if (wireType == WireType.I64) {
            return List.of(converter.apply(readLittleEndian64()));
        }
        if (wireType == WireType.LEN) {
            return readPacked(stream -> converter.apply(stream.readLittleEndian64()));
        }
        throw new IOException("wrong type: " + wireType);
    }

    /**
     * Reads a varint length prefix.
     *
     * @throws IOException if the decoded length is negative (malformed input)
     */
    private int readLength() throws IOException {
        int length = readVariableByteInt();
        if (length < 0) {
            throw new IOException("negative length: " + length);
        }
        return length;
    }

    /**
     * Reads exactly {@code length} bytes.
     *
     * @throws EOFException if the stream ends first ({@code readNBytes} alone would silently
     *     return a shorter array)
     */
    private byte[] readExactly(int length) throws IOException {
        byte[] data = readNBytes(length);
        if (data.length < length) {
            throw new EOFException();
        }
        return data;
    }

    /** Decodes a 32-bit ZigZag value (0, -1, 1, -2, ... encoded as 0, 1, 2, 3, ...). */
    private static int decodeZigZag32(int value) {
        return (value >>> 1) ^ -(value & 1);
    }

    /** Decodes a 64-bit ZigZag value (0, -1, 1, -2, ... encoded as 0, 1, 2, 3, ...). */
    private static long decodeZigZag64(long value) {
        return (value >>> 1) ^ -(value & 1);
    }

    /**
     * Reads a varint field as an {@code int} (low 32 bits).
     *
     * @param wireType the wire type of the field; must be {@code VARINT}
     * @throws IOException on a wire-type mismatch, malformed varint or premature end of stream
     */
    public int readInteger(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return readVariableByteInt();
    }

    /**
     * Reads a varint field as a {@code long}.
     *
     * @param wireType the wire type of the field; must be {@code VARINT}
     * @throws IOException on a wire-type mismatch, malformed varint or premature end of stream
     */
    public long readLong(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return readVariableByteLong();
    }

    /**
     * Reads a ZigZag-encoded varint field as an {@code int}.
     *
     * @param wireType the wire type of the field; must be {@code VARINT}
     * @throws IOException on a wire-type mismatch, malformed varint or premature end of stream
     */
    public int readZigZagInteger(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return decodeZigZag32(readVariableByteInt());
    }

    /**
     * Reads a ZigZag-encoded varint field as a {@code long}.
     *
     * @param wireType the wire type of the field; must be {@code VARINT}
     * @throws IOException on a wire-type mismatch, malformed varint or premature end of stream
     */
    public long readZigZagLong(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return decodeZigZag64(readVariableByteLong());
    }

    /**
     * Reads a varint field as a boolean; any non-zero value is {@code true}.
     *
     * @param wireType the wire type of the field; must be {@code VARINT}
     * @throws IOException on a wire-type mismatch, malformed varint or premature end of stream
     */
    public boolean readBoolean(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        // Decoded as 64 bits so that a non-zero value is never truncated to zero.
        return readVariableByteLong() != 0;
    }

    /**
     * Reads a length-delimited field as a UTF-8 string.
     *
     * @param wireType the wire type of the field; must be {@code LEN}
     * @throws IOException on a wire-type mismatch, negative length or premature end of stream
     */
    public String readString(WireType wireType) throws IOException {
        validate(wireType, WireType.LEN);
        return new String(readExactly(readLength()), StandardCharsets.UTF_8);
    }

    /**
     * Reads a length-delimited field as raw bytes.
     *
     * @param wireType the wire type of the field; must be {@code LEN}
     * @throws IOException on a wire-type mismatch, negative length or premature end of stream
     */
    public byte[] readBytes(WireType wireType) throws IOException {
        validate(wireType, WireType.LEN);
        return readExactly(readLength());
    }

    /**
     * Reads a 32-bit IEEE 754 float from an {@code I32} field.
     *
     * @throws IOException on a wire-type mismatch or premature end of stream
     */
    public float readFloat(WireType wireType) throws IOException {
        validate(wireType, WireType.I32);
        return Float.intBitsToFloat(readLittleEndian32());
    }

    /**
     * Reads a 64-bit IEEE 754 double from an {@code I64} field.
     *
     * @throws IOException on a wire-type mismatch or premature end of stream
     */
    public double readDouble(WireType wireType) throws IOException {
        validate(wireType, WireType.I64);
        return Double.longBitsToDouble(readLittleEndian64());
    }

    /**
     * Reads a little-endian 32-bit value from an {@code I32} field.
     *
     * @throws IOException on a wire-type mismatch or premature end of stream
     */
    public int readFixed32(WireType wireType) throws IOException {
        validate(wireType, WireType.I32);
        return readLittleEndian32();
    }

    /**
     * Reads a little-endian 64-bit value from an {@code I64} field.
     *
     * @throws IOException on a wire-type mismatch or premature end of stream
     */
    public long readFixed64(WireType wireType) throws IOException {
        validate(wireType, WireType.I64);
        return readLittleEndian64();
    }

    /** Reads an {@code sfixed32} field; it shares the encoding of {@link #readFixed32}. */
    public int readSFixed32(WireType wireType) throws IOException {
        return readFixed32(wireType);
    }

    /** Reads an {@code sfixed64} field; it shares the encoding of {@link #readFixed64}. */
    public long readSFixed64(WireType wireType) throws IOException {
        return readFixed64(wireType);
    }

    /** Returns the field number of the tag most recently read by {@link #readType()}. */
    public int getLastFieldNumber() {
        return lastFieldNumber;
    }

    /** Throws if the field's wire type is not the one the requested value type requires. */
    private void validate(WireType actual, WireType expected) throws IOException {
        if (actual != expected) {
            throw new IOException("wrong type: " + actual);
        }
    }

    /** Reads a repeated varint field (single or packed) as {@code int}s. */
    public List<Integer> readIntegers(WireType wireType) throws IOException {
        return readIntegers(wireType, Function.identity());
    }

    /** Reads a repeated ZigZag varint field (single or packed) as {@code int}s. */
    public List<Integer> readZigZagIntegers(WireType wireType) throws IOException {
        return readIntegers(wireType, ProtoInputStream::decodeZigZag32);
    }

    /** Reads a repeated varint field (single or packed) as {@code long}s. */
    public List<Long> readLongs(WireType wireType) throws IOException {
        return readLongs(wireType, Function.identity());
    }

    /** Reads a repeated ZigZag varint field (single or packed) as {@code long}s. */
    public List<Long> readZigZagLongs(WireType wireType) throws IOException {
        return readLongs(wireType, ProtoInputStream::decodeZigZag64);
    }

    /** Reads a repeated bool field (single or packed); any non-zero varint is {@code true}. */
    public List<Boolean> readBooleans(WireType wireType) throws IOException {
        return readLongs(wireType, value -> value != 0);
    }

    /** Reads a repeated float field (single {@code I32} or packed). */
    public List<Float> readFloats(WireType wireType) throws IOException {
        return readFixed32Elements(wireType, Float::intBitsToFloat);
    }

    /** Reads a repeated double field (single {@code I64} or packed). */
    public List<Double> readDoubles(WireType wireType) throws IOException {
        return readFixed64Elements(wireType, Double::longBitsToDouble);
    }

    /** Reads a repeated fixed32 field (single {@code I32} or packed). */
    public List<Integer> readFixed32s(WireType wireType) throws IOException {
        return readFixed32Elements(wireType, Function.identity());
    }

    /** Reads a repeated fixed64 field (single {@code I64} or packed). */
    public List<Long> readFixed64s(WireType wireType) throws IOException {
        return readFixed64Elements(wireType, Function.identity());
    }

    /** Reads a repeated sfixed32 field; it shares the encoding of {@link #readFixed32s}. */
    public List<Integer> readSFixed32s(WireType wireType) throws IOException {
        return readFixed32s(wireType);
    }

    /** Reads a repeated sfixed64 field; it shares the encoding of {@link #readFixed64s}. */
    public List<Long> readSFixed64s(WireType wireType) throws IOException {
        return readFixed64s(wireType);
    }
}