// ProtoInputStream.java

package de.schegge.rosinante.io;

import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

/**
 * A {@link FilterInputStream} that decodes values encoded in the Protocol Buffers binary wire
 * format.
 *
 * <p>Typical use: call {@link #readType()} to consume a field tag (this also records the field
 * number, see {@link #getLastFieldNumber()}), then call the {@code read...} method matching the
 * field's declared type, passing the wire type returned by {@code readType()} so it can be
 * checked.
 *
 * <p>Instances hold mutable state ({@code lastFieldNumber}) and are not synchronized.
 */
public class ProtoInputStream extends FilterInputStream {

    /** Maximum encoded size of a varint: a 64-bit value needs at most ten 7-bit groups. */
    private static final int MAX_VARINT_BYTES = 10;

    private int lastFieldNumber;

    /**
     * Creates a decoder reading from the given stream.
     *
     * @param in the underlying stream containing protobuf-encoded data
     */
    public ProtoInputStream(InputStream in) {
        super(in);
    }

    /**
     * Reads the next field tag, remembers its field number and returns its wire type.
     *
     * @return the wire type of the next field, or {@code null} if the tag could not be read
     *     (for example because the end of the stream was reached)
     */
    public WireType readType() {
        try {
            int tag = readVariableByteInt();
            // Unsigned shift: tags of field numbers >= 2^28 have the sign bit set.
            lastFieldNumber = tag >>> 3;
            // NOTE(review): indexes WireType.values() by the low three tag bits; assumes WireType
            // declares its constants in wire-id order and covers every id that can occur.
            return WireType.values()[tag & 0b00000111];
        } catch (IOException e) {
            // Deliberate: callers use null as the "no further field" signal (normal end of stream).
            return null;
        }
    }

    /**
     * Reads a base-128 varint and returns its low 32 bits.
     *
     * <p>Groups beyond bit 31 are discarded, so negative {@code int32} values (sign-extended to ten
     * bytes on the wire) decode to the correct {@code int}.
     *
     * @return the decoded value truncated to 32 bits
     * @throws EOFException if the stream ends inside the varint
     * @throws IOException if the varint is longer than ten bytes or the stream fails
     */
    protected int readVariableByteInt() throws IOException {
        int result = 0;
        int shift = 0;
        for (int count = 0; count < MAX_VARINT_BYTES; count++) {
            int value = read();
            if (value == -1) {
                throw new EOFException();
            }
            // Java masks int shift distances to 5 bits, so shifts >= 32 must be skipped explicitly.
            if (shift < Integer.SIZE) {
                result |= (value & 0x7F) << shift;
            }
            if ((value & 0x80) == 0) {
                return result;
            }
            shift += 7;
        }
        throw new IOException("malformed varint");
    }

    /**
     * Reads a base-128 varint as a 64-bit value.
     *
     * <p>The boxed return type is kept for compatibility with existing overrides.
     *
     * @return the decoded 64-bit value
     * @throws EOFException if the stream ends inside the varint
     * @throws IOException if the varint is longer than ten bytes or the stream fails
     */
    protected Long readVariableByteLong() throws IOException {
        long result = 0;
        int shift = 0;
        for (int count = 0; count < MAX_VARINT_BYTES; count++) {
            int value = read();
            if (value == -1) {
                throw new EOFException();
            }
            // The tenth group lands at shift 63; only its lowest bit fits, as in the wire format.
            result |= (long) (value & 0x7F) << shift;
            if ((value & 0x80) == 0) {
                return result;
            }
            shift += 7;
        }
        throw new IOException("malformed varint");
    }

    /**
     * Reads the length prefix of a LEN field.
     *
     * @return the non-negative length in bytes
     * @throws IOException if the prefix is negative (corrupt input) or cannot be read
     */
    private int readLength() throws IOException {
        int length = readVariableByteInt();
        if (length < 0) {
            throw new IOException("negative length: " + length);
        }
        return length;
    }

    /**
     * Reads exactly {@code length} bytes.
     *
     * @throws EOFException if the stream ends before {@code length} bytes were read
     */
    private byte[] readExactly(int length) throws IOException {
        byte[] bytes = readNBytes(length);
        if (bytes.length != length) {
            throw new EOFException();
        }
        return bytes;
    }

    /**
     * Reads one VARINT value, or every value of a packed LEN field, as 32-bit varints.
     *
     * @param wireType the wire type of the current field
     * @param converter maps each decoded {@code int} to the element type
     * @return the converted values (empty for a zero-length packed field)
     * @throws IOException if the wire type is neither VARINT nor LEN, or on read failure
     */
    private <T> List<T> readIntegers(WireType wireType, Function<Integer, T> converter) throws IOException {
        if (wireType == WireType.VARINT) {
            return List.of(converter.apply(readVariableByteInt()));
        }
        if (wireType == WireType.LEN) {
            int length = readLength();
            try (LimitedProtoInputStream packed = new LimitedProtoInputStream(this, length)) {
                List<T> result = new ArrayList<>();
                // Check before reading: a zero-length packed field contains no elements.
                while (packed.available() > 0) {
                    result.add(converter.apply(packed.readVariableByteInt()));
                }
                return result;
            }
        }
        throw new IOException("wrong type: " + wireType);
    }

    /**
     * Reads one VARINT value, or every value of a packed LEN field, as 64-bit varints.
     *
     * @param wireType the wire type of the current field
     * @param converter applied to each decoded value
     * @return the converted values (empty for a zero-length packed field)
     * @throws IOException if the wire type is neither VARINT nor LEN, or on read failure
     */
    private List<Long> readLongs(WireType wireType, Function<Long, Long> converter) throws IOException {
        if (wireType == WireType.VARINT) {
            return List.of(converter.apply(readVariableByteLong()));
        }
        if (wireType == WireType.LEN) {
            int length = readLength();
            try (LimitedProtoInputStream packed = new LimitedProtoInputStream(this, length)) {
                List<Long> result = new ArrayList<>();
                // Check before reading: a zero-length packed field contains no elements.
                while (packed.available() > 0) {
                    result.add(converter.apply(packed.readVariableByteLong()));
                }
                return result;
            }
        }
        throw new IOException("wrong type: " + wireType);
    }

    /**
     * Reads an {@code int32}/{@code uint32} field value.
     *
     * @throws IOException if {@code wireType} is not VARINT or on read failure
     */
    public int readInteger(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return readVariableByteInt();
    }

    /**
     * Reads an {@code int64}/{@code uint64} field value.
     *
     * @throws IOException if {@code wireType} is not VARINT or on read failure
     */
    public long readLong(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return readVariableByteLong();
    }

    /**
     * Reads a ZigZag-encoded {@code sint32} field value.
     *
     * @throws IOException if {@code wireType} is not VARINT or on read failure
     */
    public int readZigZagInteger(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        int value = readVariableByteInt();
        return (value >>> 1) ^ -(value & 1);
    }

    /**
     * Reads a ZigZag-encoded {@code sint64} field value.
     *
     * @throws IOException if {@code wireType} is not VARINT or on read failure
     */
    public long readZigZagLong(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        long value = readVariableByteLong();
        return (value >>> 1) ^ -(value & 1);
    }

    /**
     * Reads a {@code bool} field value; any non-zero varint is {@code true}, as in protobuf's own
     * decoders.
     *
     * @throws IOException if {@code wireType} is not VARINT or on read failure
     */
    public boolean readBoolean(WireType wireType) throws IOException {
        validate(wireType, WireType.VARINT);
        return readVariableByteInt() != 0;
    }

    /**
     * Reads a length-prefixed UTF-8 {@code string} field value.
     *
     * @throws EOFException if the stream ends before the announced number of bytes
     * @throws IOException if {@code wireType} is not LEN, the length is negative, or on read failure
     */
    public String readString(WireType wireType) throws IOException {
        validate(wireType, WireType.LEN);
        return new String(readExactly(readLength()), StandardCharsets.UTF_8);
    }

    /**
     * Reads a length-prefixed {@code bytes} field value.
     *
     * @throws EOFException if the stream ends before the announced number of bytes
     * @throws IOException if {@code wireType} is not LEN, the length is negative, or on read failure
     */
    public byte[] readBytes(WireType wireType) throws IOException {
        validate(wireType, WireType.LEN);
        return readExactly(readLength());
    }

    /**
     * Returns the field number of the tag most recently consumed by {@link #readType()}.
     */
    public int getLastFieldNumber() {
        return lastFieldNumber;
    }

    /** Fails with {@code "wrong type: <actual>"} unless {@code actual} equals {@code expected}. */
    private void validate(WireType actual, WireType expected) throws IOException {
        if (actual != expected) {
            throw new IOException("wrong type: " + actual);
        }
    }

    /** Reads a (possibly packed) repeated {@code int32}/{@code uint32} field. */
    public List<Integer> readIntegers(WireType wireType) throws IOException {
        return readIntegers(wireType, Function.identity());
    }

    /** Reads a (possibly packed) repeated ZigZag-encoded {@code sint32} field. */
    public List<Integer> readZigZagIntegers(WireType wireType) throws IOException {
        return readIntegers(wireType, value -> (value >>> 1) ^ -(value & 1));
    }

    /** Reads a (possibly packed) repeated {@code int64}/{@code uint64} field. */
    public List<Long> readLongs(WireType wireType) throws IOException {
        return readLongs(wireType, Function.identity());
    }

    /** Reads a (possibly packed) repeated ZigZag-encoded {@code sint64} field. */
    public List<Long> readZigZagLongs(WireType wireType) throws IOException {
        return readLongs(wireType, value -> (value >>> 1) ^ -(value & 1));
    }

    /** Reads a (possibly packed) repeated {@code bool} field; non-zero values are {@code true}. */
    public List<Boolean> readBooleans(WireType wireType) throws IOException {
        return readIntegers(wireType, value -> value != 0);
    }
}