ProtoGenerator.java

package de.schegge.rosinante.generator;

import de.schegge.rosinante.core.ProtoDestinations;
import de.schegge.rosinante.core.ProtoSources;
import de.schegge.rosinante.parser.ProtocolbuffersParser;
import org.freshmarker.Configuration;
import org.freshmarker.Template;

import java.io.IOException;
import java.io.StringWriter;
import java.util.HashMap;
import java.util.Map;

/**
 * Generates Java sources from a Protocol Buffers definition.
 *
 * <p>For every enum collected by the {@link GeneratorVisitor} a Java {@code enum} is written, and for
 * every message a Java {@code record} plus a companion {@code <Name>Proto} class (builder, reader and
 * writer). Sources are rendered with FreshMarker templates and handed to a {@link ProtoDestinations}.
 */
public class ProtoGenerator {

    private static final GeneratorVisitor GENERATOR_VISITOR = new GeneratorVisitor();

    /**
     * Template for a proto enum. Model: {@code package} (the package declaration is omitted when it
     * is missing) and {@code pattern} (the {@link EnumPattern}).
     *
     * <p>{@code byFieldNumber} and {@code getFieldNumber} are inverses of each other: both use the
     * declared proto field number, so an enum value survives a write/read round trip even when the
     * numbers are not contiguous from 0. Aliases (several names per number) map back to the number.
     */
    private static final String ENUM_TEMPLATE = """
            <#if package??>
            package ${package};

            </#if>
            public enum ${pattern.name} {
              <#list pattern.fields as field with loop1>
              <#list field.names as name with loop2>
              ${name}<#if loop1?has_next>,<#elseif loop2?has_next>,<#else>;</#if>
              </#list>
              </#list>

              public static ${pattern.name} byFieldNumber(int fieldNumber) {
                return switch (fieldNumber) {
              <#list pattern.fields as field>
                  case ${field.fieldNumber} -> ${field.names[0]};
              </#list>
                  default -> throw new IllegalArgumentException("invalid field number for ${pattern.name}: " + fieldNumber);
                };
              }

              public static int getFieldNumber(${pattern.name} value) {
                return switch (value) {
              <#list pattern.fields as field>
                  case <#list field.names as name with loop2>${name}<#if loop2?has_next>, </#if></#list> -> ${field.fieldNumber};
              </#list>
                };
              }
            }
            """;

    /**
     * Template for the data record of a proto message. Model: {@code package} and {@code message}
     * (the {@link MessagePattern}). Optional fields additionally get an {@code Optional}-returning
     * {@code getX()} accessor.
     */
    private static final String RECORD_TEMPLATE = """
            <#if package??>
            package ${package};

            </#if>
            import java.util.Optional;
            import java.util.List;

            public record ${message.name}(<#list message.fields as field with loop><#if field.repeated>List<${field.type.wrapper}><#elseif field.optional>${field.type.wrapper}<#else>${field.type.type}</#if> ${field.name}<#if loop?has_next>, </#if></#list>) {
              <#list message.fields as field with loop>
                <#if field.optional>
              public Optional<${field.type.wrapper}> get${field.name?capitalize}() {
                return Optional.ofNullable(${field.name}());
              }
                </#if>
              </#list>
            }
            """;

    /**
     * Template for the {@code <Name>Proto} serializer of a proto message. Model: {@code package} and
     * {@code message}. The wire field number of each field is its 1-based position in the message
     * ({@code loop?counter}), both for reading and writing.
     *
     * <p>{@code field.type.complex}: 2 = enum (encoded via the generated enum's
     * {@code getFieldNumber}/{@code byFieldNumber}), 1 = embedded message (length-delimited, via the
     * generated {@code readX}/{@code writeX} helpers), 0 = scalar.
     *
     * <p>Notes on the reader: case bodies of the generated classic {@code switch} share one scope, so no
     * local variables are declared there; repeated enum values are appended to their list.
     */
    private static final String PROTO_TEMPLATE = """
            <#if package??>
            package ${package};

            </#if>
            import de.schegge.rosinante.io.Builder;
            import de.schegge.rosinante.io.LimitedProtoInputStream;
            import de.schegge.rosinante.io.Proto;
            import de.schegge.rosinante.io.ProtoInputStream;
            import de.schegge.rosinante.io.ProtoOutputStream;
            import de.schegge.rosinante.io.WireType;

            import java.io.ByteArrayOutputStream;
            import java.io.IOException;
            import java.io.InputStream;
            import java.io.OutputStream;
            import java.util.Objects;
            import java.util.Optional;
            import java.util.BitSet;
            import java.util.List;
            import java.util.ArrayList;

            public final class ${message.name}Proto implements Proto<${message.name}> {
                  public static class ${message.name}Builder implements Builder<${message.name}> {
                      <#list message.fields as field with loop>
                        <#if field.repeated>
                      private List<${field.type.wrapper}> ${field.name} = new ArrayList<>();
                        <#elseif field.optional>
                      private ${field.type.wrapper} ${field.name} = null;
                        <#else>
                      private ${field.type.type} ${field.name} = ${field.type.initial};
                        </#if>
                      </#list>
                      <#if (message.mandatoryFields > 0)>
                      BitSet initialized = new BitSet();
                      </#if>

                      <#list message.fields as field with loop>
                      <#if field.repeated>
                      public ${message.name}Builder set${field.name?capitalize}(List<${field.type.wrapper}> ${field.name}) {
                          this.${field.name} = Objects.requireNonNull(${field.name});
                          initialized.set(${loop?counter});
                          return this;
                      }
                      <#else>
                      public ${message.name}Builder set${field.name?capitalize}(<#if field.optional>${field.type.wrapper}<#else>${field.type.type}</#if> ${field.name}) {
                          <#if field.optional>
                          this.${field.name} = Objects.requireNonNull(${field.name});
                          <#else>
                            <#if (field.type.type != field.type.wrapper)>
                          this.${field.name} = ${field.name};
                            <#else>
                          this.${field.name} = Objects.requireNonNull(${field.name});
                            </#if>
                          initialized.set(${loop?counter});
                          </#if>
                          return this;
                      }
                      </#if>
                      </#list>

                      @Override
                      public ${message.name} build() {
                          <#if (message.mandatoryFields > 0)>
                          if (initialized.cardinality() != ${message.mandatoryFields}) {
                              throw new IllegalArgumentException("some fields are not initialized");
                          }
                          </#if>
                          return new ${message.name}(<#list message.fields as field with loop>${field.name}<#if loop?has_next>, </#if></#list>);
                      }
                  }

                  @Override
                  public ${message.name}Builder builder() {
                      return new ${message.name}Builder();
                  }

                  @Override
                  public ${message.name} read(InputStream inputStream) throws IOException {
                      ProtoInputStream protoInputStream = new ProtoInputStream(inputStream);
                      <#list message.fields as field with loop>
                        <#if field.repeated>
                      List<${field.type.wrapper}> ${field.name} = new ArrayList<>();
                        <#elseif field.optional>
                      ${field.type.wrapper} ${field.name} = null;
                        <#else>
                      ${field.type.type} ${field.name} = ${field.type.initial};
                        </#if>
                      </#list>
                      <#if (message.mandatoryFields > 0)>
                      BitSet initialized = new BitSet();
                      </#if>

                      WireType wireType;
                      while ((wireType = protoInputStream.readType()) != null) {
                          int lastFieldNumber = protoInputStream.getLastFieldNumber();
                          switch (lastFieldNumber) {
                           <#list message.fields as field with loop>
                              case ${loop?counter}:
                                  <#if (field.type.complex == 2)>
                                    <#if (field.repeated)>
                                  ${field.name}.add(${field.type.type}.byFieldNumber(protoInputStream.readInteger(wireType)));
                                    <#else>
                                  ${field.name} = ${field.type.type}.byFieldNumber(protoInputStream.readInteger(wireType));
                                    </#if>
                                  <#elseif (field.type.complex == 1)>
                                    <#if (field.repeated)>
                                  ${field.name}.add(read${field.name?capitalize}(protoInputStream));
                                    <#else>
                                  ${field.name} = read${field.name?capitalize}(protoInputStream);
                                    </#if>
                                  <#elseif (field.type.complex == 0)>
                                  <#if (field.repeated)>
                                    <#if (field.type.packable)>
                                  ${field.name}.addAll(protoInputStream.read${field.type.wrapper}s(wireType));
                                    <#else>
                                  ${field.name}.add(protoInputStream.read${field.type.io}(wireType));
                                    </#if>
                                  <#else>
                                  ${field.name} = protoInputStream.read${field.type.io}(wireType);
                                  </#if>
                                  </#if>
                                  <#if !field.optional>
                                  initialized.set(${loop?counter});
                                  </#if>
                                  break;
                           </#list>
                              default:
                                  throw new IOException("invalid field number: " + lastFieldNumber);
                          }
                      }
                      <#if (message.mandatoryFields > 0)>
                      if (initialized.cardinality() != ${message.mandatoryFields}) {
                          throw new IOException("some fields are not initialized");
                      }
                      </#if>
                      return new ${message.name}(<#list message.fields as field with loop>${field.name}<#if loop?has_next>, </#if></#list>);
                  }

                  @Override
                  public void write(OutputStream outputStream, ${message.name} value) throws IOException {
                      ProtoOutputStream protoOutputStream = new ProtoOutputStream(outputStream);
                      <#list message.fields as field with loop>
                      <#if field.optional>
                      if (value.get${field.name?capitalize}().isPresent()) {
                        <#if (field.type.complex == 2)>
                        protoOutputStream.writeInteger(${loop?counter}, ${field.type.type}.getFieldNumber(value.get${field.name?capitalize}().get()));
                        <#elseif (field.type.complex == 1)>
                        write${field.name?capitalize}(value.get${field.name?capitalize}().get(), protoOutputStream);
                        <#elseif (field.type.complex == 0)>
                        protoOutputStream.write${field.type.io}(${loop?counter}, value.get${field.name?capitalize}().get());
                        </#if>
                      }
                      <#else>
                        <#if (field.type.complex == 2)>
                          <#if (field.repeated)>
                      for (${field.type.type} ${field.name} : value.${field.name}()) {
                        protoOutputStream.writeInteger(${loop?counter}, ${field.type.type}.getFieldNumber(${field.name}));
                      }
                          <#else>
                      protoOutputStream.writeInteger(${loop?counter}, ${field.type.type}.getFieldNumber(value.${field.name}()));
                          </#if>
                        <#elseif (field.type.complex == 1)>
                          <#if (field.repeated)>
                      for (${field.type.type} ${field.name} : value.${field.name}()) {
                        write${field.name?capitalize}(${field.name}, protoOutputStream);
                      }
                          <#else>
                      write${field.name?capitalize}(value.${field.name}(), protoOutputStream);
                          </#if>
                        <#elseif (field.type.complex == 0)>
                      protoOutputStream.write${field.type.io}(${loop?counter}, value.${field.name}());
                        </#if>
                      </#if>
                      </#list>
                  }
                  <#list message.fields as field with loop>
                    <#if (field.type.complex == 1)>

                  private ${field.type.type} read${field.name?capitalize}(ProtoInputStream protoInputStream) throws IOException {
                    int length = protoInputStream.readInteger(WireType.VARINT);
                    try (LimitedProtoInputStream limitedProtoInputStream = new LimitedProtoInputStream(protoInputStream, length)) {
                      return new ${field.type.type}Proto().read(limitedProtoInputStream);
                    }
                  }

                  private void write${field.name?capitalize}(${field.type.type} value, ProtoOutputStream protoOutputStream) throws IOException {
                    ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
                    new ${field.type.type}Proto().write(new ProtoOutputStream(byteArrayOutputStream), value);
                    protoOutputStream.writeBytes(${loop?counter},  byteArrayOutputStream.toByteArray());
                  }
                    </#if>
                  </#list>
            }
            """;

    private final Configuration configuration = new Configuration();

    /**
     * Parses a proto definition and writes the generated Java sources for all its enums and messages.
     *
     * @param source       the proto source handed to {@link ProtocolbuffersParser} (presumably a name
     *                     resolved through {@code protoSources} — verify against the parser)
     * @param protoSources where the parser looks up proto sources
     * @param destinations receives the generated {@code .java} files
     * @throws IOException if the parser signals an I/O error; failures while writing generated files
     *                     are reported as {@link ProtoWriteException} instead
     */
    public void generate(String source, ProtoSources protoSources, ProtoDestinations destinations) throws IOException {
        ProtocolbuffersParser parser = new ProtocolbuffersParser(source, protoSources);
        parser.Proto();
        GeneratorContext context = new GeneratorContext();
        parser.rootNode().accept(GENERATOR_VISITOR, context);
        context.getEnums().forEach(pattern -> createEnumClass(pattern, context, configuration, destinations));
        context.getMessages().forEach(message -> createMessageFile(message, context, configuration, destinations));
    }

    /** Renders {@code pattern} as a Java enum and writes it as {@code <Name>.java}. */
    private static void createEnumClass(EnumPattern pattern, GeneratorContext context, Configuration configuration, ProtoDestinations destinations) {
        Template template = configuration.getTemplate("enum", ENUM_TEMPLATE);
        renderToFile(template, "pattern", pattern, context, pattern.getName() + ".java", destinations);
    }

    /** Writes both sources for one message: the data record and its {@code <Name>Proto} serializer. */
    private static void createMessageFile(MessagePattern message, GeneratorContext context, Configuration configuration, ProtoDestinations destinations) {
        createMessageClass(message, context, configuration, destinations);
        createMessageProto(message, context, configuration, destinations);
    }

    /** Renders the serializer of {@code message} and writes it as {@code <Name>Proto.java}. */
    private static void createMessageProto(MessagePattern message, GeneratorContext context, Configuration configuration, ProtoDestinations destinations) {
        Template template = configuration.getTemplate("proto", PROTO_TEMPLATE);
        renderToFile(template, "message", message, context, message.getName() + "Proto.java", destinations);
    }

    /** Renders the data record of {@code message} and writes it as {@code <Name>.java}. */
    private static void createMessageClass(MessagePattern message, GeneratorContext context, Configuration configuration, ProtoDestinations destinations) {
        Template template = configuration.getTemplate("record", RECORD_TEMPLATE);
        renderToFile(template, "message", message, context, message.getName() + ".java", destinations);
    }

    /**
     * Processes {@code template} with a model holding the package name plus one template-specific
     * entry, then writes the output to {@code fileName} in the context's package.
     *
     * @param template     the parsed template to process
     * @param modelKey     name under which {@code modelValue} is visible to the template
     * @param modelValue   the pattern the template renders
     * @param context      supplies the package name for the model and the destination
     * @param fileName     name of the generated file
     * @param destinations receives the rendered source
     * @throws ProtoWriteException wrapping the {@link IOException} if the destination fails to write
     */
    private static void renderToFile(Template template, String modelKey, Object modelValue, GeneratorContext context,
                                     String fileName, ProtoDestinations destinations) {
        // HashMap instead of Map.of: Map.of rejects null values, and the templates guard the package
        // with "package??", so a null package name (presumably: proto without package) must be allowed.
        Map<String, Object> model = new HashMap<>();
        model.put("package", context.getPackageName());
        model.put(modelKey, modelValue);
        StringWriter writer = new StringWriter();
        template.process(model, writer);
        try {
            destinations.writeFile(context.getPackageName(), fileName, writer.toString());
        } catch (IOException e) {
            throw new ProtoWriteException(e);
        }
    }
}