diff --git a/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java b/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java index 4f24c69f74b7..49f6ba89b350 100644 --- a/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java +++ b/examples/java/src/main/java/org/apache/beam/examples/snippets/Snippets.java @@ -43,7 +43,6 @@ import java.util.Map; import java.util.stream.Collectors; import javax.annotation.Nullable; -import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.DefaultCoder; @@ -60,7 +59,6 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryOptions; import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; import org.apache.beam.sdk.io.gcp.bigquery.InsertRetryPolicy; -import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord; import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; import org.apache.beam.sdk.io.gcp.bigquery.WriteResult; import org.apache.beam.sdk.io.range.OffsetRange; @@ -228,8 +226,7 @@ public static void modelBigQueryIO( // [START BigQueryReadFunction] PCollection maxTemperatures = p.apply( - BigQueryIO.read( - (SchemaAndRecord elem) -> (Double) elem.getRecord().get("max_temperature")) + BigQueryIO.parseGenericRecords(record -> (Double) record.get("max_temperature")) .from(tableSpec) .withCoder(DoubleCoder.of())); // [END BigQueryReadFunction] @@ -239,8 +236,7 @@ public static void modelBigQueryIO( // [START BigQueryReadQuery] PCollection maxTemperatures = p.apply( - BigQueryIO.read( - (SchemaAndRecord elem) -> (Double) elem.getRecord().get("max_temperature")) + BigQueryIO.parseGenericRecords(record -> (Double) record.get("max_temperature")) .fromQuery( "SELECT max_temperature FROM [apache-beam-testing.samples.weather_stations]") .withCoder(DoubleCoder.of())); @@ -251,8 +247,7 @@ public static void modelBigQueryIO( // [START BigQueryReadQueryStdSQL] PCollection maxTemperatures = p.apply( - BigQueryIO.read( - (SchemaAndRecord elem) -> (Double) elem.getRecord().get("max_temperature")) + BigQueryIO.parseGenericRecords(record -> (Double) record.get("max_temperature")) .fromQuery( "SELECT max_temperature FROM `clouddataflow-readonly.samples.weather_stations`") .usingStandardSql() @@ -392,15 +387,13 @@ public WeatherData(long year, long month, long day, double maxTemp) { PCollection weatherData = p.apply( - BigQueryIO.read( - (SchemaAndRecord elem) -> { - GenericRecord record = elem.getRecord(); - return new WeatherData( - (Long) record.get("year"), - (Long) record.get("month"), - (Long) record.get("day"), - (Double) record.get("max_temperature")); - }) + BigQueryIO.parseGenericRecords( + record -> + new WeatherData( + (Long) record.get("year"), + (Long) record.get("month"), + (Long) record.get("day"), + (Double) record.get("max_temperature"))) .fromQuery( "SELECT year, month, day, max_temperature " + "FROM [apache-beam-testing.samples.weather_stations] " diff --git a/examples/kotlin/src/main/java/org/apache/beam/examples/kotlin/snippets/Snippets.kt b/examples/kotlin/src/main/java/org/apache/beam/examples/kotlin/snippets/Snippets.kt index d2f58c215a56..a216f4e09748 100644 --- a/examples/kotlin/src/main/java/org/apache/beam/examples/kotlin/snippets/Snippets.kt +++ b/examples/kotlin/src/main/java/org/apache/beam/examples/kotlin/snippets/Snippets.kt @@ -121,7 +121,7 @@ object Snippets { val tableSpec = "apache-beam-testing.samples.weather_stations" // [START BigQueryReadFunction] val 
maxTemperatures = pipeline.apply( - BigQueryIO.read { it.record["max_temperature"] as Double? } + BigQueryIO.parseGenericRecords { it["max_temperature"] as Double? } .from(tableSpec) .withCoder(DoubleCoder.of())) // [END BigQueryReadFunction] @@ -130,7 +130,7 @@ object Snippets { run { // [START BigQueryReadQuery] val maxTemperatures = pipeline.apply( - BigQueryIO.read { it.record["max_temperature"] as Double? } + BigQueryIO.parseGenericRecords { it["max_temperature"] as Double? } .fromQuery( "SELECT max_temperature FROM [apache-beam-testing.samples.weather_stations]") .withCoder(DoubleCoder.of())) @@ -140,7 +140,7 @@ object Snippets { run { // [START BigQueryReadQueryStdSQL] val maxTemperatures = pipeline.apply( - BigQueryIO.read { it.record["max_temperature"] as Double? } + BigQueryIO.parseGenericRecords { it["max_temperature"] as Double? } .fromQuery( "SELECT max_temperature FROM `clouddataflow-readonly.samples.weather_stations`") .usingStandardSql() @@ -249,13 +249,12 @@ object Snippets { ) */ val weatherData = pipeline.apply( - BigQueryIO.read { - val record = it.record + BigQueryIO.parseGenericRecords { WeatherData( - record.get("year") as Long, - record.get("month") as Long, - record.get("day") as Long, - record.get("max_temperature") as Double) + it.get("year") as Long, + it.get("month") as Long, + it.get("day") as Long, + it.get("max_temperature") as Double) } .fromQuery(""" SELECT year, month, day, max_temperature diff --git a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java index 78ba610ad4d1..2bd2ca45f2f7 100644 --- a/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java +++ b/sdks/java/extensions/arrow/src/main/java/org/apache/beam/sdk/extensions/arrow/ArrowConversion.java @@ -252,6 +252,16 @@ public static RecordBatchRowIterator rowsFromSerializedRecordBatch( InputStream inputStream, RootAllocator allocator) throws IOException { + return rowsFromSerializedRecordBatch( + arrowSchema, ArrowSchemaTranslator.toBeamSchema(arrowSchema), inputStream, allocator); + } + + public static RecordBatchRowIterator rowsFromSerializedRecordBatch( + org.apache.arrow.vector.types.pojo.Schema arrowSchema, + Schema schema, + InputStream inputStream, + RootAllocator allocator) + throws IOException { VectorSchemaRoot vectorRoot = VectorSchemaRoot.create(arrowSchema, allocator); VectorLoader vectorLoader = new VectorLoader(vectorRoot); vectorRoot.clear(); @@ -261,7 +271,7 @@ public static RecordBatchRowIterator rowsFromSerializedRecordBatch( vectorLoader.load(arrowMessage); } } - return rowsFromRecordBatch(ArrowSchemaTranslator.toBeamSchema(arrowSchema), vectorRoot); + return rowsFromRecordBatch(schema, vectorRoot); } public static org.apache.arrow.vector.types.pojo.Schema arrowSchemaFromInput(InputStream input) diff --git a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java index c6c7fa426dbf..f94e02e8f891 100644 --- a/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java +++ b/sdks/java/extensions/avro/src/main/java/org/apache/beam/sdk/extensions/avro/io/AvroSource.java @@ -142,11 +142,11 @@ public interface DatumReaderFactory extends Serializable { // Use cases of AvroSource are: // 1) AvroSource Reading GenericRecord records 
with a specified schema. // 2) AvroSource Reading records of a generated Avro class Foo. - // 3) AvroSource Reading GenericRecord records with an unspecified schema + // 3) AvroSource Reading GenericRecord records with an (un)specified schema // and converting them to type T. // | Case 1 | Case 2 | Case 3 | // type | GenericRecord | Foo | GenericRecord | - // readerSchemaString | non-null | non-null | null | + // readerSchemaString | non-null | non-null | either | // parseFn | null | null | non-null | // outputCoder | either | either | non-null | // readerFactory | either | either | either | diff --git a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java index 2bd18ce32244..4c46c2f41543 100644 --- a/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java +++ b/sdks/java/extensions/zetasketch/src/test/java/org/apache/beam/sdk/extensions/zetasketch/BigQueryHllSketchCompatibilityIT.java @@ -26,7 +26,6 @@ import com.google.api.services.bigquery.model.TableReference; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; -import com.google.cloud.bigquery.storage.v1.DataFormat; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; @@ -34,12 +33,12 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; -import org.apache.beam.sdk.io.gcp.bigquery.SchemaAndRecord; import org.apache.beam.sdk.io.gcp.testing.BigqueryClient; import org.apache.beam.sdk.options.ApplicationNameOptions; import org.apache.beam.sdk.testing.PAssert; @@ -179,11 +178,10 @@ private void readSketchFromBigQuery(String tableId, Long expectedCount) { "SELECT HLL_COUNT.INIT(%s) AS %s FROM %s", DATA_FIELD_NAME, QUERY_RESULT_FIELD_NAME, tableSpec); - SerializableFunction parseQueryResultToByteArray = - input -> + SerializableFunction parseQueryResultToByteArray = + record -> // BigQuery BYTES type corresponds to Java java.nio.ByteBuffer type - HllCount.getSketchFromByteBuffer( - (ByteBuffer) input.getRecord().get(QUERY_RESULT_FIELD_NAME)); + HllCount.getSketchFromByteBuffer((ByteBuffer) record.get(QUERY_RESULT_FIELD_NAME)); TestPipelineOptions options = TestPipeline.testingPipelineOptions().as(TestPipelineOptions.class); @@ -191,8 +189,7 @@ private void readSketchFromBigQuery(String tableId, Long expectedCount) { Pipeline p = Pipeline.create(options); PCollection result = p.apply( - BigQueryIO.read(parseQueryResultToByteArray) - .withFormat(DataFormat.AVRO) + BigQueryIO.parseGenericRecords(parseQueryResultToByteArray) .fromQuery(query) .usingStandardSql() .withMethod(Method.DIRECT_READ) diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java index 1f45371b19ff..7c8383e5a218 100644 --- 
a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/AvroRowWriter.java @@ -20,7 +20,7 @@ import java.io.IOException; import org.apache.avro.Schema; import org.apache.avro.file.DataFileWriter; -import org.apache.avro.io.DatumWriter; +import org.apache.beam.sdk.extensions.avro.io.AvroSink; import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.MimeTypes; @@ -36,7 +36,7 @@ class AvroRowWriter extends BigQueryRowWriter { String basename, Schema schema, SerializableFunction, AvroT> toAvroRecord, - SerializableFunction> writerFactory) + AvroSink.DatumWriterFactory writerFactory) throws Exception { super(basename, MimeTypes.BINARY); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java index ca9dfdb65caf..371b24b1eb8c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIO.java @@ -52,7 +52,6 @@ import com.google.protobuf.DynamicMessage; import com.google.protobuf.Message; import java.io.IOException; -import java.io.Serializable; import java.lang.reflect.InvocationTargetException; import java.util.Arrays; import java.util.Collections; @@ -64,21 +63,23 @@ import java.util.function.Predicate; import java.util.regex.Pattern; import java.util.stream.Collectors; -import org.apache.avro.generic.GenericDatumReader; import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; -import org.apache.avro.io.DatumReader; import org.apache.avro.io.DatumWriter; -import org.apache.avro.io.Decoder; +import org.apache.avro.reflect.ReflectData; +import org.apache.avro.specific.SpecificRecord; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.coders.CannotProvideCoderException; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; import org.apache.beam.sdk.extensions.avro.io.AvroSource; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.extensions.gcp.util.Transport; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; @@ -142,13 +143,10 @@ import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.ValueInSingleWindow; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Function; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Predicates; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; -import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Suppliers; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -282,13 +280,13 @@ * BigQueryIO.readTableRows().from("apache-beam-testing.samples.weather_stations")); * } * - * Example: Reading rows of a table and parsing them into a custom type. + * Example: Reading rows of a table and parsing them into a custom type from avro. * *
<pre>{@code
* PCollection<WeatherRecord> weatherData = pipeline.apply(
  *    BigQueryIO
- *      .read(new SerializableFunction<SchemaAndRecord, WeatherRecord>() {
- *        public WeatherRecord apply(SchemaAndRecord schemaAndRecord) {
+ *      .parseGenericRecords(new SerializableFunction<GenericRecord, WeatherRecord>() {
+ *        public WeatherRecord apply(GenericRecord record) {
  *          return new WeatherRecord(...);
  *        }
  *      })
@@ -296,7 +294,7 @@
  *      .withCoder(SerializableCoder.of(WeatherRecord.class));
* }</pre>
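*
* <p>The same parse can be written more concisely with a lambda (a sketch mirroring the updated
* Snippets example in this change; the column name and coder are illustrative):
*
* <pre>{@code
* PCollection<Double> maxTemperatures = pipeline.apply(
*     BigQueryIO.parseGenericRecords(record -> (Double) record.get("max_temperature"))
*         .from("apache-beam-testing.samples.weather_stations")
*         .withCoder(DoubleCoder.of()));
* }</pre>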
*
- * <p>Note: When using {@link #read(SerializableFunction)}, you may sometimes need to use {@link
+ * <p>
Note: When using read API with a parse function, you may sometimes need to use {@link * TypedRead#withCoder(Coder)} to specify a {@link Coder} for the result type, if Beam fails to * infer it automatically. * @@ -617,13 +615,6 @@ public class BigQueryIO { static final SerializableFunction TABLE_ROW_IDENTITY_FORMATTER = SerializableFunctions.identity(); - /** - * A formatting function that maps a GenericRecord to itself. This allows sending a {@code - * PCollection} directly to BigQueryIO.Write. - */ - static final SerializableFunction, GenericRecord> - GENERIC_RECORD_IDENTITY_FORMATTER = AvroWriteRequest::getElement; - static final SerializableFunction> GENERIC_DATUM_WRITER_FACTORY = schema -> new GenericDatumWriter<>(); @@ -634,9 +625,9 @@ public class BigQueryIO { static final String STORAGE_URI = "storageUri"; /** - * @deprecated Use {@link #read(SerializableFunction)} or {@link #readTableRows} instead. {@link - * #readTableRows()} does exactly the same as {@link #read}, however {@link - * #read(SerializableFunction)} performs better. + * @deprecated Use {@link #parseGenericRecords(SerializableFunction)}, {@link + * #parseArrowRows(SerializableFunction)} or {@link #readTableRows} instead. {@link + * #readTableRows()} does exactly the same as {@link #read}. */ @Deprecated public static Read read() { @@ -644,67 +635,81 @@ public static Read read() { } /** - * Like {@link #read(SerializableFunction)} but represents each row as a {@link TableRow}. + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result, parsed from the BigQuery AVRO format, converted to a + * {@link TableRow}. * *
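*
* <p>For example (a minimal sketch; the table name is a placeholder):
*
* <pre>{@code
* PCollection<TableRow> rows =
*     p.apply(BigQueryIO.readTableRows().from("project-id:dataset_id.table_id"));
* }</pre>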

This method is more convenient to use in some cases, but usually has significantly lower - * performance than using {@link #read(SerializableFunction)} directly to parse data into a - * domain-specific type, due to the overhead of converting the rows to {@link TableRow}. + * performance than using {@link #parseGenericRecords(SerializableFunction)} or {@link + * #parseArrowRows(SerializableFunction)} directly to parse data into a domain-specific type, due + * to the overhead of converting the rows to {@link TableRow}. */ public static TypedRead readTableRows() { - return read(new TableRowParser()).withCoder(TableRowJsonCoder.of()); + return readTableRows(DataFormat.AVRO); + } + + /** + * Like {@link #readTableRows()} but with possibility to choose between BigQuery AVRO or ARROW + * format. + */ + public static TypedRead readTableRows(DataFormat dataFormat) { + if (dataFormat == DataFormat.AVRO) { + return readAvroImpl( + null, + true, + AvroDatumFactory.generic(), + input -> BigQueryAvroUtils.convertGenericRecordToTableRow(input.getElement()), + TableRowJsonCoder.of(), + TypeDescriptor.of(TableRow.class)); + } else if (dataFormat == DataFormat.ARROW) { + return readArrowImpl( + null, + input -> BigQueryUtils.toTableRow(input.getElement()), + TableRowJsonCoder.of(), + TypeDescriptor.of(TableRow.class)); + } else { + throw new IllegalArgumentException("Unsupported data format: " + dataFormat); + } } /** Like {@link #readTableRows()} but with {@link Schema} support. */ public static TypedRead readTableRowsWithSchema() { - return read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) + return readTableRowsWithSchema(DataFormat.AVRO); + } + + /** Like {@link #readTableRows(DataFormat)} but with {@link Schema} support. */ + public static TypedRead readTableRowsWithSchema(DataFormat dataFormat) { + return readTableRows(dataFormat) .withBeamRowConverters( TypeDescriptor.of(TableRow.class), BigQueryUtils.tableRowToBeamRow(), BigQueryUtils.tableRowFromBeamRow()); } - private static class TableSchemaFunction - implements Serializable, Function<@Nullable String, @Nullable TableSchema> { - @Override - public @Nullable TableSchema apply(@Nullable String input) { - return BigQueryHelpers.fromJsonString(input, TableSchema.class); - } + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result, parsed from the BigQuery AVRO format, converted to a + * {@link Row}. + * + *

This method is more convenient to use in some cases, but usually has significantly lower + * performance than using {@link #parseGenericRecords(SerializableFunction)} or {@link + * #parseArrowRows(SerializableFunction)} directly to parse data into a domain-specific type, due + * to the overhead of converting the rows to {@link Row}. + */ + public static TypedRead readRows() { + return readRows(DataFormat.AVRO); } - @VisibleForTesting - static class GenericDatumTransformer implements DatumReader { - private final SerializableFunction parseFn; - private final Supplier tableSchema; - private GenericDatumReader reader; - private org.apache.avro.Schema writerSchema; - - public GenericDatumTransformer( - SerializableFunction parseFn, - String tableSchema, - org.apache.avro.Schema writer) { - this.parseFn = parseFn; - this.tableSchema = - Suppliers.memoize( - Suppliers.compose(new TableSchemaFunction(), Suppliers.ofInstance(tableSchema))); - this.writerSchema = writer; - this.reader = new GenericDatumReader<>(this.writerSchema); - } - - @Override - public void setSchema(org.apache.avro.Schema schema) { - if (this.writerSchema.equals(schema)) { - return; - } - - this.writerSchema = schema; - this.reader = new GenericDatumReader<>(this.writerSchema); - } - - @Override - public T read(T reuse, Decoder in) throws IOException { - GenericRecord record = (GenericRecord) this.reader.read(reuse, in); - return parseFn.apply(new SchemaAndRecord(record, this.tableSchema.get())); + /** + * Like {@link #readRows()} but with possibility to choose between BigQuery AVRO or ARROW format. + */ + public static TypedRead readRows(DataFormat dataFormat) { + if (dataFormat == DataFormat.AVRO) { + return parseGenericRecords(new RowAvroParser()); + } else if (dataFormat == DataFormat.ARROW) { + return readArrowRows(); + } else { + throw new IllegalArgumentException("Unsupported data format: " + dataFormat); } } @@ -727,35 +732,13 @@ public T read(T reuse, Decoder in) throws IOException { * } * }).from("..."); * } + * + * @deprecated Use {@link #parseGenericRecords(SerializableFunction)} instead. */ + @Deprecated public static TypedRead read(SerializableFunction parseFn) { - return new AutoValue_BigQueryIO_TypedRead.Builder() - .setValidate(true) - .setWithTemplateCompatibility(false) - .setBigQueryServices(new BigQueryServicesImpl()) - .setDatumReaderFactory( - (SerializableFunction>) - input -> { - try { - String jsonTableSchema = BigQueryIO.JSON_FACTORY.toString(input); - return (AvroSource.DatumReaderFactory) - (writer, reader) -> - new GenericDatumTransformer<>(parseFn, jsonTableSchema, writer); - } catch (IOException e) { - LOG.warn( - String.format("Error while converting table schema %s to JSON!", input), e); - return null; - } - }) - // TODO: Remove setParseFn once https://github.com/apache/beam/issues/21076 is fixed. - .setParseFn(parseFn) - .setMethod(TypedRead.Method.DEFAULT) - .setUseAvroLogicalTypes(false) - .setFormat(DataFormat.AVRO) - .setProjectionPushdownApplied(false) - .setBadRecordErrorHandler(new DefaultErrorHandler<>()) - .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) - .build(); + return readAvroImpl( + null, false, AvroDatumFactory.generic(), parseFn, null, TypeDescriptors.outputOf(parseFn)); } /** @@ -766,36 +749,200 @@ public static TypedRead read(SerializableFunction par *

<pre>{@code
    * class ClickEvent { long userId; String url; ... }
    *
-   * p.apply(BigQueryIO.read(ClickEvent.class)).from("...")
-   * .read((AvroSource.DatumReaderFactory) (writer, reader) -> new ReflectDatumReader<>(ReflectData.get().getSchema(ClickEvent.class)));
+   * p.apply(BigQueryIO.readWithDatumReader(AvroDatumFactory.reflect(ClickEvent.class)).from("..."));
* }</pre>
*/ public static TypedRead readWithDatumReader( AvroSource.DatumReaderFactory readerFactory) { + TypeDescriptor td = null; + if (readerFactory instanceof AvroDatumFactory) { + td = TypeDescriptor.of(((AvroDatumFactory) readerFactory).getType()); + } + return readAvroImpl(null, false, readerFactory, SchemaAndElement::getElement, null, td); + } + + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result as {@link GenericRecord}. Logical type in Extract jobs + * will be enabled. + */ + public static TypedRead readGenericRecords() { + return readAvroImpl( + null, + true, + AvroDatumFactory.generic(), + SchemaAndElement::getRecord, + null, + TypeDescriptor.of(GenericRecord.class)); + } + + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result as {@link GenericRecord} with the desired schema. Logical + * type in Extract jobs will be enabled. + */ + public static TypedRead readGenericRecords(org.apache.avro.Schema schema) { + return readAvroImpl( + schema, + true, + AvroDatumFactory.generic(), + SchemaAndElement::getRecord, + AvroCoder.generic(schema), + TypeDescriptor.of(GenericRecord.class)); + } + + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result as input avro class. Logical type in Extract jobs will be + * enabled. + */ + public static TypedRead readSpecificRecords(Class recordClass) { + org.apache.avro.Schema schema = ReflectData.get().getSchema(recordClass); + AvroDatumFactory factory; + if (GenericRecord.class.equals(recordClass)) { + throw new IllegalArgumentException("TypedRead for GenericRecord requires a schema"); + } else if (SpecificRecord.class.isAssignableFrom(recordClass)) { + factory = AvroDatumFactory.specific(recordClass); + } else { + factory = AvroDatumFactory.reflect(recordClass); + } + AvroCoder coder = AvroCoder.of(factory, schema); + TypeDescriptor td = TypeDescriptor.of(recordClass); + return readAvroImpl(schema, true, factory, SchemaAndElement::getElement, coder, td); + } + + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result. This API directly deserializes BigQuery AVRO data to the + * input class, based on the appropriate {@link org.apache.avro.io.DatumReader} and schema. + * Logical type in Extract jobs will be enabled. + */ + public static TypedRead readRecords( + org.apache.avro.Schema schema, AvroSource.DatumReaderFactory readerFactory) { + TypeDescriptor td = null; + Coder coder = null; + if (readerFactory instanceof AvroDatumFactory) { + coder = AvroCoder.of((AvroDatumFactory) readerFactory, schema); + td = TypeDescriptor.of(((AvroDatumFactory) readerFactory).getType()); + } + return readAvroImpl(schema, true, readerFactory, SchemaAndElement::getElement, coder, td); + } + + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result, parsed from the BigQuery AVRO format using the specified + * function. Logical type in Extract jobs will be enabled. 
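*
* <p>For example (an illustrative sketch; WeatherData is assumed to be a Serializable user class
* and the column names follow the weather_stations sample used elsewhere in this file):
*
* <pre>{@code
* PCollection<WeatherData> weatherData =
*     p.apply(
*         BigQueryIO.parseGenericRecords(
*                 record ->
*                     new WeatherData(
*                         (Long) record.get("year"),
*                         (Long) record.get("month"),
*                         (Long) record.get("day"),
*                         (Double) record.get("max_temperature")))
*             .fromQuery(
*                 "SELECT year, month, day, max_temperature "
*                     + "FROM [apache-beam-testing.samples.weather_stations]")
*             .withCoder(SerializableCoder.of(WeatherData.class)));
* }</pre>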
+ */ + public static TypedRead parseGenericRecords( + SerializableFunction avroFormatFunction) { + return readAvroImpl( + null, + true, + AvroDatumFactory.generic(), + input -> avroFormatFunction.apply(input.getElement()), + null, + TypeDescriptors.outputOf(avroFormatFunction)); + } + + @SuppressWarnings("unchecked") + private static TypedRead readAvroImpl( + org.apache.avro.@Nullable Schema schema, // when null infer from TableSchema at runtime + Boolean useAvroLogicalTypes, + AvroSource.DatumReaderFactory readerFactory, + SerializableFunction, T> parseFn, + @Nullable Coder coder, + @Nullable TypeDescriptor typeDescriptor) { + + if (typeDescriptor != null && typeDescriptor.hasUnresolvedParameters()) { + // type extraction failed and will not be serializable + typeDescriptor = null; + } + return new AutoValue_BigQueryIO_TypedRead.Builder() .setValidate(true) .setWithTemplateCompatibility(false) .setBigQueryServices(new BigQueryServicesImpl()) - .setDatumReaderFactory( - (SerializableFunction>) - input -> readerFactory) .setMethod(TypedRead.Method.DEFAULT) - .setUseAvroLogicalTypes(false) .setFormat(DataFormat.AVRO) + .setAvroSchema(schema) + .setDatumReaderFactory(readerFactory) + .setParseFn(parseFn) + .setUseAvroLogicalTypes(useAvroLogicalTypes) .setProjectionPushdownApplied(false) .setBadRecordErrorHandler(new DefaultErrorHandler<>()) .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) + .setCoder(coder) + .setTypeDescriptor(typeDescriptor) .build(); } - @VisibleForTesting - static class TableRowParser implements SerializableFunction { + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result as {@link Row}. + */ + public static TypedRead readArrowRows() { + return readArrowImpl(null, SchemaAndRow::getElement, null, TypeDescriptor.of(Row.class)); + } - public static final TableRowParser INSTANCE = new TableRowParser(); + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result as {@link Row} with the desired schema. + */ + public static TypedRead readArrowRows(Schema schema) { + return readArrowImpl( + schema, SchemaAndRow::getElement, RowCoder.of(schema), TypeDescriptor.of(Row.class)); + } + + /** + * Reads from a BigQuery table or query and returns a {@link PCollection} with one element per + * each row of the table or query result, parsed from the BigQuery ARROW format using the + * specified function. 
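*
* <p>For example (an illustrative sketch; the column name is a placeholder, and the read uses the
* Storage Read API since ARROW is only available with Method.DIRECT_READ):
*
* <pre>{@code
* PCollection<Double> maxTemperatures =
*     p.apply(
*         BigQueryIO.parseArrowRows(row -> row.getDouble("max_temperature"))
*             .from("project-id:dataset_id.table_id")
*             .withCoder(DoubleCoder.of()));
* }</pre>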
+ */ + public static TypedRead parseArrowRows(SerializableFunction arrowFormatFunction) { + return readArrowImpl( + null, + input -> arrowFormatFunction.apply(input.getElement()), + null, + TypeDescriptors.outputOf(arrowFormatFunction)); + } + + private static TypedRead readArrowImpl( + @Nullable Schema schema, // when null infer from TableSchema at runtime + SerializableFunction parseFn, + @Nullable Coder coder, + TypeDescriptor typeDescriptor) { + + if (typeDescriptor != null && typeDescriptor.hasUnresolvedParameters()) { + // type extraction failed and will not be serializable + typeDescriptor = null; + } + return new AutoValue_BigQueryIO_TypedRead.Builder() + .setValidate(true) + .setWithTemplateCompatibility(false) + .setBigQueryServices(new BigQueryServicesImpl()) + .setMethod(TypedRead.Method.DIRECT_READ) // arrow is only available in direct read + .setFormat(DataFormat.ARROW) + .setArrowSchema(schema) + .setArrowParseFn(parseFn) + .setUseAvroLogicalTypes(false) + .setProjectionPushdownApplied(false) + .setBadRecordErrorHandler(new DefaultErrorHandler<>()) + .setBadRecordRouter(BadRecordRouter.THROWING_ROUTER) + .setCoder(coder) + .setTypeDescriptor(typeDescriptor) + .build(); + } + + static class RowAvroParser implements SerializableFunction { + + private transient Schema schema; @Override - public TableRow apply(SchemaAndRecord schemaAndRecord) { - return BigQueryAvroUtils.convertGenericRecordToTableRow(schemaAndRecord.getRecord()); + public Row apply(GenericRecord record) { + if (schema == null) { + schema = AvroUtils.toBeamSchema(record.getSchema()); + } + return AvroUtils.toBeamRowStrict(record, schema); } } @@ -804,7 +951,7 @@ public static class Read extends PTransform> { private final TypedRead inner; Read() { - this(BigQueryIO.read(TableRowParser.INSTANCE).withCoder(TableRowJsonCoder.of())); + this(BigQueryIO.readTableRows()); } Read(TypedRead inner) { @@ -984,10 +1131,16 @@ abstract static class Builder { abstract TypedRead build(); - abstract Builder setParseFn(SerializableFunction parseFn); + abstract Builder setAvroSchema(org.apache.avro.Schema avroSchema); + + abstract Builder setDatumReaderFactory(AvroSource.DatumReaderFactory readerFactory); + + abstract Builder setParseFn( + SerializableFunction, T> parseFn); - abstract Builder setDatumReaderFactory( - SerializableFunction> factoryFn); + abstract Builder setArrowSchema(Schema arrowSchema); + + abstract Builder setArrowParseFn(SerializableFunction parseFn); abstract Builder setCoder(Coder coder); @@ -1023,10 +1176,15 @@ abstract Builder setBadRecordErrorHandler( abstract BigQueryServices getBigQueryServices(); - abstract @Nullable SerializableFunction getParseFn(); + abstract org.apache.avro.@Nullable Schema getAvroSchema(); + + abstract AvroSource.@Nullable DatumReaderFactory getDatumReaderFactory(); + + abstract @Nullable SerializableFunction, T> getParseFn(); - abstract @Nullable SerializableFunction> - getDatumReaderFactory(); + abstract @Nullable Schema getArrowSchema(); + + abstract @Nullable SerializableFunction getArrowParseFn(); abstract @Nullable QueryPriority getQueryPriority(); @@ -1094,11 +1252,16 @@ Coder inferCoder(CoderRegistry coderRegistry) { } try { - return coderRegistry.getCoder(TypeDescriptors.outputOf(getParseFn())); + TypeDescriptor td = getTypeDescriptor(); + if (td != null) { + return coderRegistry.getCoder(td); + } else { + throw new IllegalArgumentException( + "Unable to infer coder for output. 
Specify it explicitly using withCoder()."); + } } catch (CannotProvideCoderException e) { throw new IllegalArgumentException( - "Unable to infer coder for output of parseFn. Specify it explicitly using withCoder().", - e); + "Unable to infer coder for output. Specify it explicitly using withCoder().", e); } } @@ -1123,7 +1286,7 @@ private BigQuerySourceDef createSourceDef() { } private BigQueryStorageQuerySource createStorageQuerySource( - String stepUuid, Coder outputCoder) { + String stepUuid, BigQueryReaderFactory bqReaderFactory, Coder outputCoder) { return BigQueryStorageQuerySource.create( stepUuid, getQuery(), @@ -1135,7 +1298,7 @@ private BigQueryStorageQuerySource createStorageQuerySource( getQueryTempProject(), getKmsKey(), getFormat(), - getParseFn(), + bqReaderFactory, outputCoder, getBigQueryServices()); } @@ -1267,7 +1430,70 @@ public PCollection expand(PBegin input) { checkArgument(getUseLegacySql() != null, "useLegacySql should not be null if query is set"); } - checkArgument(getDatumReaderFactory() != null, "A readerDatumFactory is required"); + if (getMethod() != TypedRead.Method.DIRECT_READ) { + checkArgument( + getSelectedFields() == null, + "Invalid BigQueryIO.Read: Specifies selected fields, " + + "which only applies when using Method.DIRECT_READ"); + + checkArgument( + getRowRestriction() == null, + "Invalid BigQueryIO.Read: Specifies row restriction, " + + "which only applies when using Method.DIRECT_READ"); + } else if (getTableProvider() == null) { + checkArgument( + getSelectedFields() == null, + "Invalid BigQueryIO.Read: Specifies selected fields, " + + "which only applies when reading from a table"); + + checkArgument( + getRowRestriction() == null, + "Invalid BigQueryIO.Read: Specifies row restriction, " + + "which only applies when reading from a table"); + } + + BigQueryReaderFactory bqReaderFactory; + switch (getFormat()) { + case ARROW: + checkArgument(getArrowParseFn() != null, "Arrow parseFn is required"); + + @Nullable Schema arrowSchema = getArrowSchema(); + SerializableFunction arrowParseFn = getArrowParseFn(); + + if (arrowParseFn == null) { + checkArgument(getParseFn() != null, "Arrow or Avro parseFn is required"); + LOG.warn( + "Reading ARROW from AVRO. 
Consider using readArrow() instead of withFormat(DataFormat.ARROW)"); + // withFormat() was probably used + SerializableFunction parseFn = + (SerializableFunction) getParseFn(); + arrowParseFn = + arrowInput -> { + GenericRecord record = AvroUtils.toGenericRecord(arrowInput.getElement()); + return parseFn.apply(new SchemaAndRecord(record, arrowInput.getTableSchema())); + }; + } + + bqReaderFactory = BigQueryReaderFactory.arrow(arrowSchema, arrowParseFn); + break; + case AVRO: + checkArgument(getDatumReaderFactory() != null, "Avro datumReaderFactory is required"); + checkArgument(getParseFn() != null, "Avro parseFn is required"); + + org.apache.avro.@Nullable Schema avroSchema = getAvroSchema(); + AvroSource.DatumReaderFactory datumFactory = getDatumReaderFactory(); + SerializableFunction, T> avroParseFn = getParseFn(); + boolean useAvroLogicalTypes = getUseAvroLogicalTypes(); + bqReaderFactory = + BigQueryReaderFactory.avro( + avroSchema, + useAvroLogicalTypes, + (AvroSource.DatumReaderFactory) datumFactory, + (SerializableFunction) avroParseFn); + break; + default: + throw new IllegalArgumentException("Unsupported format: " + getFormat()); + } // if both toRowFn and fromRowFn values are set, enable Beam schema support Pipeline p = input.getPipeline(); @@ -1287,19 +1513,9 @@ public PCollection expand(PBegin input) { final Coder coder = inferCoder(p.getCoderRegistry()); if (getMethod() == TypedRead.Method.DIRECT_READ) { - return expandForDirectRead(input, coder, beamSchema, bqOptions); + return expandForDirectRead(input, coder, beamSchema, bqReaderFactory, bqOptions); } - checkArgument( - getSelectedFields() == null, - "Invalid BigQueryIO.Read: Specifies selected fields, " - + "which only applies when using Method.DIRECT_READ"); - - checkArgument( - getRowRestriction() == null, - "Invalid BigQueryIO.Read: Specifies row restriction, " - + "which only applies when using Method.DIRECT_READ"); - final PCollectionView jobIdTokenView; PCollection jobIdTokenCollection; PCollection rows; @@ -1314,7 +1530,7 @@ public PCollection expand(PBegin input) { p.apply( org.apache.beam.sdk.io.Read.from( sourceDef.toSource( - staticJobUuid, coder, getDatumReaderFactory(), getUseAvroLogicalTypes()))); + staticJobUuid, coder, bqReaderFactory, getUseAvroLogicalTypes()))); } else { // Create a singleton job ID token at execution time. 
jobIdTokenCollection = @@ -1342,10 +1558,7 @@ public void processElement(ProcessContext c) throws Exception { String jobUuid = c.element(); BigQuerySourceBase source = sourceDef.toSource( - jobUuid, - coder, - getDatumReaderFactory(), - getUseAvroLogicalTypes()); + jobUuid, coder, bqReaderFactory, getUseAvroLogicalTypes()); BigQueryOptions options = c.getPipelineOptions().as(BigQueryOptions.class); ExtractResult res = source.extractFiles(options); @@ -1378,10 +1591,7 @@ public void processElement(ProcessContext c) throws Exception { String jobUuid = c.sideInput(jobIdTokenView); BigQuerySourceBase source = sourceDef.toSource( - jobUuid, - coder, - getDatumReaderFactory(), - getUseAvroLogicalTypes()); + jobUuid, coder, bqReaderFactory, getUseAvroLogicalTypes()); List> sources = source.createSources( ImmutableList.of( @@ -1391,6 +1601,7 @@ public void processElement(ProcessContext c) throws Exception { null); checkArgument(sources.size() == 1, "Expected exactly one source."); BoundedSource avroSource = sources.get(0); + BoundedSource.BoundedReader reader = avroSource.createReader(c.getPipelineOptions()); for (boolean more = reader.start(); more; more = reader.advance()) { @@ -1447,7 +1658,11 @@ void cleanup(PassThroughThenCleanup.ContextContainer c) throws Exception { } private PCollection expandForDirectRead( - PBegin input, Coder outputCoder, Schema beamSchema, BigQueryOptions bqOptions) { + PBegin input, + Coder outputCoder, + Schema beamSchema, + BigQueryReaderFactory bqReaderFactory, + BigQueryOptions bqOptions) { ValueProvider tableProvider = getTableProvider(); Pipeline p = input.getPipeline(); if (tableProvider != null) { @@ -1463,7 +1678,7 @@ private PCollection expandForDirectRead( getFormat(), getSelectedFields(), getRowRestriction(), - getParseFn(), + bqReaderFactory, outputCoder, getBigQueryServices(), getProjectionPushdownApplied()))); @@ -1484,11 +1699,11 @@ private PCollection expandForDirectRead( getFormat(), getSelectedFields(), getRowRestriction(), - getParseFn(), + bqReaderFactory, outputCoder, getBigQueryServices(), getProjectionPushdownApplied()); - List> sources; + List> sources; try { // This splitting logic taken from the SDF implementation of Read long estimatedSize = source.getEstimatedSizeBytes(bqOptions); @@ -1505,13 +1720,28 @@ private PCollection expandForDirectRead( } catch (Exception e) { throw new RuntimeException("Unable to split TableSource", e); } + TupleTag rowTag = new TupleTag<>(); PCollectionTuple resultTuple = p.apply(Create.of(sources)) .apply( - "Read Storage Table Source", - ParDo.of(new ReadTableSource(rowTag, getParseFn(), getBadRecordRouter())) + ParDo.of( + new DoFn, T>() { + @ProcessElement + public void processElement( + ProcessContext c, MultiOutputReceiver outputReceiver) + throws Exception { + BigQueryStorageStreamSource streamSource = c.element(); + readSource( + c.getPipelineOptions(), + rowTag, + outputReceiver, + streamSource, + getBadRecordRouter()); + } + }) .withOutputTags(rowTag, TupleTagList.of(BAD_RECORD_TAG))); + getBadRecordErrorHandler() .addErrorCollection( resultTuple @@ -1522,16 +1752,6 @@ private PCollection expandForDirectRead( } } - checkArgument( - getSelectedFields() == null, - "Invalid BigQueryIO.Read: Specifies selected fields, " - + "which only applies when reading from a table"); - - checkArgument( - getRowRestriction() == null, - "Invalid BigQueryIO.Read: Specifies row restriction, " - + "which only applies when reading from a table"); - // // N.B. 
All of the code below exists because the BigQuery storage API can't (yet) read from // all anonymous tables, so we need the job ID to reason about the name of the destination @@ -1555,7 +1775,7 @@ && getBadRecordRouter() instanceof ThrowingBadRecordRouter) { rows = p.apply( org.apache.beam.sdk.io.Read.from( - createStorageQuerySource(staticJobUuid, outputCoder))); + createStorageQuerySource(staticJobUuid, bqReaderFactory, outputCoder))); } else { // Create a singleton job ID token at pipeline execution time. PCollection jobIdTokenCollection = @@ -1578,7 +1798,12 @@ public String apply(String input) { PCollectionTuple tuple = createTupleForDirectRead( - jobIdTokenCollection, outputCoder, readStreamsTag, readSessionTag, tableSchemaTag); + jobIdTokenCollection, + bqReaderFactory, + outputCoder, + readStreamsTag, + readSessionTag, + tableSchemaTag); tuple.get(readStreamsTag).setCoder(ProtoCoder.of(ReadStream.class)); tuple.get(readSessionTag).setCoder(ProtoCoder.of(ReadSession.class)); tuple.get(tableSchemaTag).setCoder(StringUtf8Coder.of()); @@ -1590,7 +1815,12 @@ public String apply(String input) { rows = createPCollectionForDirectRead( - tuple, outputCoder, readStreamsTag, readSessionView, tableSchemaView); + tuple, + bqReaderFactory, + outputCoder, + readStreamsTag, + readSessionView, + tableSchemaView); } PassThroughThenCleanup.CleanupOperation cleanupOperation = @@ -1640,50 +1870,9 @@ void cleanup(ContextContainer c) throws Exception { return rows.apply(new PassThroughThenCleanup<>(cleanupOperation, jobIdTokenView)); } - private static class ReadTableSource extends DoFn, T> { - - private final TupleTag rowTag; - - private final SerializableFunction parseFn; - - private final BadRecordRouter badRecordRouter; - - public ReadTableSource( - TupleTag rowTag, - SerializableFunction parseFn, - BadRecordRouter badRecordRouter) { - this.rowTag = rowTag; - this.parseFn = parseFn; - this.badRecordRouter = badRecordRouter; - } - - @ProcessElement - public void processElement( - @Element BoundedSource boundedSource, - MultiOutputReceiver outputReceiver, - PipelineOptions options) - throws Exception { - ErrorHandlingParseFn errorHandlingParseFn = new ErrorHandlingParseFn(parseFn); - BoundedSource sourceWithErrorHandlingParseFn; - if (boundedSource instanceof BigQueryStorageStreamSource) { - sourceWithErrorHandlingParseFn = - ((BigQueryStorageStreamSource) boundedSource).fromExisting(errorHandlingParseFn); - } else { - throw new RuntimeException( - "Bounded Source is not BigQueryStorageStreamSource, unable to read"); - } - readSource( - options, - rowTag, - outputReceiver, - sourceWithErrorHandlingParseFn, - errorHandlingParseFn, - badRecordRouter); - } - } - private PCollectionTuple createTupleForDirectRead( PCollection jobIdTokenCollection, + BigQueryReaderFactory bqReaderFactory, Coder outputCoder, TupleTag readStreamsTag, TupleTag readSessionTag, @@ -1702,7 +1891,7 @@ public void processElement(ProcessContext c) throws Exception { // The getTargetTable call runs a new instance of the query and returns // the destination table created to hold the results. 
BigQueryStorageQuerySource querySource = - createStorageQuerySource(jobUuid, outputCoder); + createStorageQuerySource(jobUuid, bqReaderFactory, outputCoder); Table queryResultTable = querySource.getTargetTable(options); // Create a read session without specifying a desired stream count and @@ -1719,7 +1908,7 @@ public void processElement(ProcessContext c) throws Exception { .setTable( BigQueryHelpers.toTableResourceName( queryResultTable.getTableReference())) - .setDataFormat(DataFormat.AVRO)) + .setDataFormat(getFormat())) .setMaxStreamCount(0) .build(); @@ -1745,39 +1934,9 @@ public void processElement(ProcessContext c) throws Exception { return tuple; } - private static class ErrorHandlingParseFn - implements SerializableFunction { - private final SerializableFunction parseFn; - - private transient SchemaAndRecord schemaAndRecord = null; - - private ErrorHandlingParseFn(SerializableFunction parseFn) { - this.parseFn = parseFn; - } - - @Override - public T apply(SchemaAndRecord input) { - schemaAndRecord = input; - try { - return parseFn.apply(input); - } catch (Exception e) { - throw new ParseException(e); - } - } - - public SchemaAndRecord getSchemaAndRecord() { - return schemaAndRecord; - } - } - - private static class ParseException extends RuntimeException { - public ParseException(Exception e) { - super(e); - } - } - private PCollection createPCollectionForDirectRead( PCollectionTuple tuple, + BigQueryReaderFactory bqReaderFactory, Coder outputCoder, TupleTag readStreamsTag, PCollectionView readSessionView, @@ -1800,15 +1959,12 @@ public void processElement( c.sideInput(tableSchemaView), TableSchema.class); ReadStream readStream = c.element(); - ErrorHandlingParseFn errorHandlingParseFn = - new ErrorHandlingParseFn(getParseFn()); - BigQueryStorageStreamSource streamSource = BigQueryStorageStreamSource.create( readSession, readStream, tableSchema, - errorHandlingParseFn, + bqReaderFactory, outputCoder, getBigQueryServices()); @@ -1817,7 +1973,6 @@ public void processElement( rowTag, outputReceiver, streamSource, - errorHandlingParseFn, getBadRecordRouter()); } }) @@ -1835,14 +1990,14 @@ public static void readSource( PipelineOptions options, TupleTag rowTag, MultiOutputReceiver outputReceiver, - BoundedSource streamSource, - ErrorHandlingParseFn errorHandlingParseFn, + BigQueryStorageStreamSource streamSource, BadRecordRouter badRecordRouter) throws Exception { // Read all the data from the stream. In the event that this work // item fails and is rescheduled, the same rows will be returned in // the same order. 
- BoundedSource.BoundedReader reader = streamSource.createReader(options); + BigQueryStorageStreamSource.BigQueryStorageStreamReader reader = + streamSource.createReader(options); try { if (reader.start()) { @@ -1850,12 +2005,12 @@ public static void readSource( } else { return; } - } catch (ParseException e) { - GenericRecord record = errorHandlingParseFn.getSchemaAndRecord().getRecord(); + } catch (BigQueryStorageReader.ReadException e) { + BigQueryStorageReader storageReader = reader.getStorageReader(); badRecordRouter.route( outputReceiver, - record, - AvroCoder.of(record.getSchema()), + storageReader.getLastBadRecord(), + (Coder) storageReader.getBadRecordCoder(), (Exception) e.getCause(), "Unable to parse record reading from BigQuery"); } @@ -1867,12 +2022,12 @@ public static void readSource( } else { return; } - } catch (ParseException e) { - GenericRecord record = errorHandlingParseFn.getSchemaAndRecord().getRecord(); + } catch (BigQueryStorageReader.ReadException e) { + BigQueryStorageReader storageReader = reader.getStorageReader(); badRecordRouter.route( outputReceiver, - record, - AvroCoder.of(record.getSchema()), + storageReader.getLastBadRecord(), + (Coder) storageReader.getBadRecordCoder(), (Exception) e.getCause(), "Unable to parse record reading from BigQuery"); } @@ -2049,7 +2204,12 @@ public TypedRead withMethod(TypedRead.Method method) { return toBuilder().setMethod(method).build(); } - /** See {@link DataFormat}. */ + /** + * See {@link DataFormat}. + * + * @deprecated User {@link #readGenericRecords()} or {@link #readArrowRows()} instead + */ + @Deprecated public TypedRead withFormat(DataFormat format) { return toBuilder().setFormat(format).build(); } @@ -2101,6 +2261,12 @@ public TypedRead withTestServices(BigQueryServices testServices) { return toBuilder().setBigQueryServices(testServices).build(); } + /** + * Enable the logical type in Extract jobs. + * + * @see BQ + * avro export + */ public TypedRead useAvroLogicalTypes() { return toBuilder().setUseAvroLogicalTypes(true).build(); } @@ -2269,8 +2435,15 @@ public static Write applyRowMutations() { * GenericRecords} to a BigQuery table. */ public static Write writeGenericRecords() { - return BigQueryIO.write() - .withAvroFormatFunction(GENERIC_RECORD_IDENTITY_FORMATTER); + return BigQueryIO.write().withAvroWriter(GENERIC_DATUM_WRITER_FACTORY); + } + + /** + * A {@link PTransform} that writes a {@link PCollection} containing {@link SpecificRecord + * SpecificRecord} to a BigQuery table. 
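*
* <p>For example (an illustrative sketch; WeatherRecord stands for an Avro-generated
* SpecificRecord class and the destination table is assumed to already exist):
*
* <pre>{@code
* records.apply(
*     BigQueryIO.writeSpecificRecords(WeatherRecord.class)
*         .to("project-id:dataset_id.table_id")
*         .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER));
* }</pre>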
+ */ + public static Write writeSpecificRecords(Class type) { + return BigQueryIO.write().withAvroWriter(AvroDatumFactory.specific(type)::apply); } /** diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java index 1da47156dda7..d8c80f141196 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslation.java @@ -21,11 +21,13 @@ import static org.apache.beam.sdk.util.construction.TransformUpgrader.toByteArray; import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.service.AutoService; import com.google.cloud.bigquery.storage.v1.AppendRowsRequest.MissingValueInterpretation; import com.google.cloud.bigquery.storage.v1.DataFormat; import java.io.IOException; import java.io.InvalidClassException; +import java.io.Serializable; import java.time.Duration; import java.util.Collection; import java.util.Collections; @@ -38,6 +40,9 @@ import org.apache.beam.model.pipeline.v1.RunnerApi; import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; +import org.apache.beam.sdk.extensions.avro.io.AvroSource; +import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.FromBeamRowFunction; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.QueryPriority; @@ -91,8 +96,11 @@ static class BigQueryIOReadTranslator implements TransformPayloadTranslator transform) { if (transform.getBigQueryServices() != null) { fieldValues.put("bigquery_services", toByteArray(transform.getBigQueryServices())); } - if (transform.getParseFn() != null) { - fieldValues.put("parse_fn", toByteArray(transform.getParseFn())); + if (transform.getAvroSchema() != null) { + org.apache.avro.Schema avroSchema = transform.getAvroSchema(); + // avro 1.8 Schema is not serializable + if (avroSchema instanceof Serializable) { + fieldValues.put("avro_schema", toByteArray(transform.getAvroSchema())); + } else { + String avroSchemaStr = avroSchema.toString(); + fieldValues.put("avro_schema", toByteArray(avroSchemaStr)); + } } if (transform.getDatumReaderFactory() != null) { fieldValues.put("datum_reader_factory", toByteArray(transform.getDatumReaderFactory())); } + if (transform.getParseFn() != null) { + fieldValues.put("parse_fn", toByteArray(transform.getParseFn())); + } + if (transform.getArrowSchema() != null) { + fieldValues.put("arrow_schema", toByteArray(transform.getArrowSchema())); + } + if (transform.getArrowParseFn() != null) { + fieldValues.put("arrow_parse_fn", toByteArray(transform.getArrowParseFn())); + } if (transform.getQueryPriority() != null) { fieldValues.put("query_priority", toByteArray(transform.getQueryPriority())); } @@ -254,15 +278,68 @@ public TypedRead fromConfigRow(Row configRow, PipelineOptions options) { builder.setBigQueryServices(new BigQueryServicesImpl()); } } + byte[] formatBytes = configRow.getBytes("format"); + DataFormat format = null; + if (formatBytes != null) { + format = (DataFormat) fromByteArray(formatBytes); + builder 
= builder.setFormat(format); + } + byte[] avroSchemaBytes = configRow.getBytes("avro_schema"); + if (avroSchemaBytes != null) { + Object avroSchemaObj = fromByteArray(avroSchemaBytes); + if (avroSchemaObj instanceof org.apache.avro.Schema) { + builder = builder.setAvroSchema((org.apache.avro.Schema) avroSchemaObj); + } else { + String avroSchemaStr = (String) avroSchemaObj; + org.apache.avro.Schema avroSchema = + new org.apache.avro.Schema.Parser().parse(avroSchemaStr); + builder = builder.setAvroSchema(avroSchema); + } + } byte[] parseFnBytes = configRow.getBytes("parse_fn"); if (parseFnBytes != null) { builder = builder.setParseFn((SerializableFunction) fromByteArray(parseFnBytes)); } byte[] datumReaderFactoryBytes = configRow.getBytes("datum_reader_factory"); if (datumReaderFactoryBytes != null) { - builder = - builder.setDatumReaderFactory( - (SerializableFunction) fromByteArray(datumReaderFactoryBytes)); + if (TransformUpgrader.compareVersions(updateCompatibilityBeamVersion, "2.62.0") < 0) { + // on old version, readWithDatumReader sets a SerializableFunction with unused parameter + // when parseFnBytes was set, just read as GenericRecord + if (parseFnBytes == null) { + SerializableFunction> + datumReaderFactoryFn = + (SerializableFunction>) + fromByteArray(datumReaderFactoryBytes); + builder = builder.setDatumReaderFactory(datumReaderFactoryFn.apply(null)); + } else { + builder = builder.setDatumReaderFactory(AvroDatumFactory.generic()); + } + } else { + builder = + builder.setDatumReaderFactory( + (AvroSource.DatumReaderFactory) fromByteArray(datumReaderFactoryBytes)); + } + } + byte[] arrowSchemaBytes = configRow.getBytes("arrow_schema"); + if (arrowSchemaBytes != null) { + builder = builder.setArrowSchema((Schema) fromByteArray(avroSchemaBytes)); + } + byte[] arrowParseFnBytes = configRow.getBytes("arrow_parse_fn"); + if (arrowParseFnBytes != null) { + builder = builder.setParseFn((SerializableFunction) fromByteArray(parseFnBytes)); + } else if (format == DataFormat.ARROW + && TransformUpgrader.compareVersions(updateCompatibilityBeamVersion, "2.62.0") < 0) { + if (parseFnBytes != null) { + // on old version, arrow was read from avro record + SerializableFunction avroFn = + (SerializableFunction) fromByteArray(parseFnBytes); + SerializableFunction arrowFn = + input -> + avroFn.apply( + new SchemaAndRecord( + AvroUtils.toGenericRecord(input.getElement()), input.getTableSchema())); + builder = builder.setArrowParseFn(arrowFn); + } } byte[] queryPriorityBytes = configRow.getBytes("query_priority"); if (queryPriorityBytes != null) { @@ -290,10 +367,6 @@ public TypedRead fromConfigRow(Row configRow, PipelineOptions options) { if (methodBytes != null) { builder = builder.setMethod((TypedRead.Method) fromByteArray(methodBytes)); } - byte[] formatBytes = configRow.getBytes("format"); - if (formatBytes != null) { - builder = builder.setFormat((DataFormat) fromByteArray(formatBytes)); - } Collection selectedFields = configRow.getArray("selected_fields"); if (selectedFields != null && !selectedFields.isEmpty()) { builder.setSelectedFields(StaticValueProvider.of(ImmutableList.of(selectedFields))); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java index fc882b1c2a4f..518a6eca56c4 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java +++ 
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySource.java @@ -18,12 +18,9 @@ package org.apache.beam.sdk.io.gcp.bigquery; import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableSchema; import java.io.IOException; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; @@ -36,7 +33,7 @@ static BigQueryQuerySource create( BigQueryQuerySourceDef queryDef, BigQueryServices bqServices, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes) { return new BigQueryQuerySource<>( stepUuid, queryDef, bqServices, coder, readerFactory, useAvroLogicalTypes); @@ -49,7 +46,7 @@ private BigQueryQuerySource( BigQueryQuerySourceDef queryDef, BigQueryServices bqServices, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes) { super(stepUuid, bqServices, coder, readerFactory, useAvroLogicalTypes); this.queryDef = queryDef; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySourceDef.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySourceDef.java index 25f274d708b5..ef7879f4cc36 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySourceDef.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryQuerySourceDef.java @@ -28,10 +28,8 @@ import java.util.Optional; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryResourceNaming.JobType; import org.apache.beam.sdk.options.ValueProvider; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -169,7 +167,7 @@ void cleanupTempResource(BigQueryOptions bqOptions, String stepUuid) throws Exce public BigQuerySourceBase toSource( String stepUuid, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes) { return BigQueryQuerySource.create( stepUuid, this, bqServices, coder, readerFactory, useAvroLogicalTypes); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryReaderFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryReaderFactory.java new file mode 100644 index 000000000000..b9c4296b13ac --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryReaderFactory.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.bigquery.storage.v1.ReadSession; +import java.io.IOException; +import java.io.InputStream; +import java.io.Serializable; +import org.apache.avro.generic.GenericRecord; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.arrow.ArrowConversion; +import org.apache.beam.sdk.extensions.avro.io.AvroSource; +import org.apache.beam.sdk.io.BoundedSource; +import org.apache.beam.sdk.io.fs.MatchResult; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.SerializableSupplier; +import org.apache.beam.sdk.values.Row; +import org.checkerframework.checker.nullness.qual.Nullable; + +abstract class BigQueryReaderFactory implements BigQueryStorageReaderFactory, Serializable { + + abstract BoundedSource getSource( + MatchResult.Metadata metadata, TableSchema tableSchema, Coder coder); + + abstract BoundedSource getSource( + String fileNameOrPattern, TableSchema tableSchema, Coder coder); + + static BigQueryReaderFactory avro( + org.apache.avro.@Nullable Schema schema, + boolean extractWithLogicalTypes, + AvroSource.DatumReaderFactory readerFactory, + SerializableFunction, T> fromAvro) { + return new BigQueryAvroReaderFactory<>( + schema, extractWithLogicalTypes, readerFactory, fromAvro); + } + + static BigQueryReaderFactory arrow( + @Nullable Schema schema, SerializableFunction fromArrow) { + return new BigQueryArrowReaderFactory<>(schema, fromArrow); + } + + ///////////////////////////////////////////////////////////////////////////// + // Avro + ///////////////////////////////////////////////////////////////////////////// + private static class SerializableSchemaSupplier + implements SerializableSupplier { + private transient org.apache.avro.Schema schema; + private final String jsonSchema; + + SerializableSchemaSupplier(org.apache.avro.Schema schema) { + this.schema = schema; + this.jsonSchema = schema.toString(); + } + + @Override + public org.apache.avro.Schema get() { + if (schema == null) { + schema = new org.apache.avro.Schema.Parser().parse(jsonSchema); + } + return schema; + } + } + + static class BigQueryAvroReaderFactory extends BigQueryReaderFactory { + private final @Nullable SerializableSchemaSupplier + schemaSupplier; // avro 1.8 schema is not serializable + private final boolean extractWithLogicalTypes; + private final AvroSource.DatumReaderFactory readerFactory; + private final SerializableFunction, T> fromAvro; + + BigQueryAvroReaderFactory( + org.apache.avro.@Nullable Schema schema, + boolean extractWithLogicalTypes, + AvroSource.DatumReaderFactory readerFactory, + SerializableFunction, T> fromAvro) { + + this.schemaSupplier = schema == null ? 
null : new SerializableSchemaSupplier(schema); + this.extractWithLogicalTypes = extractWithLogicalTypes; + this.readerFactory = readerFactory; + this.fromAvro = fromAvro; + } + + @Override + public AvroSource getSource( + MatchResult.Metadata metadata, TableSchema tableSchema, Coder coder) { + return getSource(AvroSource.from(metadata), tableSchema, coder); + } + + @Override + public AvroSource getSource( + String fileNameOrPattern, TableSchema tableSchema, Coder coder) { + return getSource(AvroSource.from(fileNameOrPattern), tableSchema, coder); + } + + private AvroSource getSource( + AvroSource source, TableSchema tableSchema, Coder coder) { + org.apache.avro.Schema readerSchema; + if (schemaSupplier != null) { + readerSchema = schemaSupplier.get(); + } else { + readerSchema = BigQueryUtils.toGenericAvroSchema(tableSchema, extractWithLogicalTypes); + } + SerializableFunction parseFn = + (r) -> fromAvro.apply(new SchemaAndElement<>((AvroT) r, tableSchema)); + return source + .withSchema(readerSchema) + .withDatumReaderFactory(readerFactory) + .withParseFn(parseFn, coder); + } + + @Override + public BigQueryStorageAvroReader getReader( + TableSchema tableSchema, ReadSession readSession) throws IOException { + org.apache.avro.Schema writerSchema = + new org.apache.avro.Schema.Parser().parse(readSession.getAvroSchema().getSchema()); + org.apache.avro.Schema readerSchema; + if (schemaSupplier != null) { + readerSchema = schemaSupplier.get(); + } else { + // BQ storage always uses logical-types + readerSchema = BigQueryUtils.toGenericAvroSchema(tableSchema, true); + } + SerializableFunction fromAvroRecord = + (r) -> fromAvro.apply(new SchemaAndElement<>(r, tableSchema)); + return new BigQueryStorageAvroReader<>( + writerSchema, readerSchema, readerFactory, fromAvroRecord); + } + } + + ///////////////////////////////////////////////////////////////////////////// + // Arrow + ///////////////////////////////////////////////////////////////////////////// + static class BigQueryArrowReaderFactory extends BigQueryReaderFactory { + private final SerializableFunction schemaFactory; + private final SerializableFunction parseFn; + + BigQueryArrowReaderFactory( + @Nullable Schema schema, SerializableFunction parseFn) { + this.parseFn = parseFn; + if (schema == null) { + this.schemaFactory = BigQueryUtils::fromTableSchema; + } else { + this.schemaFactory = tableSchema -> schema; + } + } + + @Override + BoundedSource getSource( + MatchResult.Metadata metadata, TableSchema tableSchema, Coder coder) { + throw new UnsupportedOperationException("Arrow file source not supported"); + } + + @Override + BoundedSource getSource(String fileNameOrPattern, TableSchema tableSchema, Coder coder) { + throw new UnsupportedOperationException("Arrow file source not supported"); + } + + @Override + public BigQueryStorageArrowReader getReader(TableSchema tableSchema, ReadSession readSession) + throws IOException { + try (InputStream input = readSession.getArrowSchema().getSerializedSchema().newInput()) { + org.apache.arrow.vector.types.pojo.Schema writerSchema = + ArrowConversion.arrowSchemaFromInput(input); + Schema readerSchema = schemaFactory.apply(tableSchema); + SerializableFunction fromRow = + (r) -> parseFn.apply(new SchemaAndRow(r, tableSchema)); + return new BigQueryStorageArrowReader<>(writerSchema, readerSchema, fromRow); + } + } + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java 
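For orientation, here is a minimal sketch of how the new BigQueryReaderFactory is meant to be wired, mirroring the construction used later in BigQueryIOReadTest. It is illustrative only: it assumes code living in the org.apache.beam.sdk.io.gcp.bigquery package (the factory is package-private), the generic parameters are inferred since type arguments do not survive in this diff text, and the file pattern, tableSchema, and readSession values are placeholders rather than anything defined in this PR.

// Sketch only: an Avro-backed factory that parses GenericRecord into TableRow,
// matching the readerFactory built in BigQueryIOReadTest.
BigQueryReaderFactory<TableRow> readerFactory =
    BigQueryReaderFactory.avro(
        null,                        // no explicit reader schema: derive it from the TableSchema
        false,                       // extractWithLogicalTypes
        AvroDatumFactory.generic(),  // read records as GenericRecord
        input -> BigQueryAvroUtils.convertGenericRecordToTableRow(input.getElement()));

// File-export path: one call yields a fully configured AvroSource (schema, datum reader, parse fn, coder).
// The GCS pattern below is a placeholder.
BoundedSource<TableRow> exportSource =
    readerFactory.getSource("gs://my-bucket/extract/part-*.avro", tableSchema, TableRowJsonCoder.of());

// Storage Read API path: the same factory builds the per-stream reader from the ReadSession.
// (getReader declares IOException; assume an enclosing method that propagates it.)
BigQueryStorageReader<TableRow> streamReader = readerFactory.getReader(tableSchema, readSession);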
b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java index b7b83dccaece..48d58a247656 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceBase.java @@ -32,9 +32,7 @@ import java.io.IOException; import java.util.List; import java.util.stream.Stream; -import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.fs.MatchResult; import org.apache.beam.sdk.io.fs.ResourceId; @@ -43,7 +41,6 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.JobService; import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.slf4j.Logger; @@ -73,7 +70,7 @@ abstract class BigQuerySourceBase extends BoundedSource { protected final BigQueryServices bqServices; private transient @Nullable List> cachedSplitResult = null; - private SerializableFunction> readerFactory; + private BigQueryReaderFactory readerFactory; private Coder coder; private final boolean useAvroLogicalTypes; @@ -81,7 +78,7 @@ abstract class BigQuerySourceBase extends BoundedSource { String stepUuid, BigQueryServices bqServices, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes) { this.stepUuid = checkArgumentNotNull(stepUuid, "stepUuid"); this.bqServices = checkArgumentNotNull(bqServices, "bqServices"); @@ -243,23 +240,14 @@ private List executeExtract( List> createSources( List files, TableSchema schema, @Nullable List metadata) throws IOException, InterruptedException { - String avroSchema = BigQueryAvroUtils.toGenericAvroSchema(schema).toString(); - - AvroSource.DatumReaderFactory factory = readerFactory.apply(schema); - - Stream> avroSources; - // If metadata is available, create AvroSources with said metadata in SINGLE_FILE_OR_SUBRANGE - // mode. 
+ Stream> sources; + // If metadata is available, create source with said metadata if (metadata != null) { - avroSources = metadata.stream().map(AvroSource::from); + sources = metadata.stream().map(m -> readerFactory.getSource(m, schema, coder)); } else { - avroSources = files.stream().map(ResourceId::toString).map(AvroSource::from); + sources = files.stream().map(f -> readerFactory.getSource(f.toString(), schema, coder)); } - return avroSources - .map(s -> s.withSchema(avroSchema)) - .map(s -> (AvroSource) s.withDatumReaderFactory(factory)) - .map(s -> s.withCoder(coder)) - .collect(collectingAndThen(toList(), ImmutableList::copyOf)); + return sources.collect(collectingAndThen(toList(), ImmutableList::copyOf)); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceDef.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceDef.java index a9c4c5af283c..18690e2c222c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceDef.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQuerySourceDef.java @@ -20,7 +20,6 @@ import com.google.api.services.bigquery.model.TableSchema; import java.io.Serializable; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.transforms.SerializableFunction; /** @@ -41,7 +40,7 @@ interface BigQuerySourceDef extends Serializable { BigQuerySourceBase toSource( String stepUuid, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes); /** diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageArrowReader.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageArrowReader.java index 70703cf0082e..332e1650ee41 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageArrowReader.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageArrowReader.java @@ -17,29 +17,38 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; -import com.google.cloud.bigquery.storage.v1.ArrowSchema; import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; -import com.google.cloud.bigquery.storage.v1.ReadSession; import java.io.IOException; -import java.io.InputStream; -import javax.annotation.Nullable; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.avro.generic.GenericRecord; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.extensions.arrow.ArrowConversion; import org.apache.beam.sdk.extensions.arrow.ArrowConversion.RecordBatchRowIterator; -import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.Row; +import org.checkerframework.checker.nullness.qual.Nullable; -class BigQueryStorageArrowReader implements BigQueryStorageReader { +class BigQueryStorageArrowReader implements BigQueryStorageReader { + private final org.apache.arrow.vector.types.pojo.Schema arrowSchema; + private final Schema schema; + private final SerializableFunction fromRow; + private 
final Coder badRecordCoder; private @Nullable RecordBatchRowIterator recordBatchIterator; private long rowCount; - private ArrowSchema protoSchema; private @Nullable RootAllocator alloc; - BigQueryStorageArrowReader(ReadSession readSession) throws IOException { - protoSchema = readSession.getArrowSchema(); + private transient @Nullable Row badRecord = null; + + BigQueryStorageArrowReader( + org.apache.arrow.vector.types.pojo.Schema writerSchema, + Schema readerSchema, + SerializableFunction fromRow) { + this.arrowSchema = writerSchema; + this.schema = readerSchema; + this.fromRow = fromRow; + this.badRecordCoder = RowCoder.of(readerSchema); this.rowCount = 0; this.alloc = null; } @@ -49,13 +58,11 @@ public void processReadRowsResponse(ReadRowsResponse readRowsResponse) throws IO com.google.cloud.bigquery.storage.v1.ArrowRecordBatch recordBatch = readRowsResponse.getArrowRecordBatch(); rowCount = recordBatch.getRowCount(); - InputStream input = protoSchema.getSerializedSchema().newInput(); - Schema arrowSchema = ArrowConversion.arrowSchemaFromInput(input); RootAllocator alloc = new RootAllocator(Long.MAX_VALUE); this.alloc = alloc; this.recordBatchIterator = ArrowConversion.rowsFromSerializedRecordBatch( - arrowSchema, recordBatch.getSerializedRecordBatch().newInput(), alloc); + arrowSchema, schema, recordBatch.getSerializedRecordBatch().newInput(), alloc); } @Override @@ -64,15 +71,27 @@ public long getRowCount() { } @Override - public GenericRecord readSingleRecord() throws IOException { + public T readSingleRecord() throws IOException { if (recordBatchIterator == null) { throw new IOException("Not Initialized"); } Row row = recordBatchIterator.next(); - // TODO(https://github.com/apache/beam/issues/21076): Update this interface to expect a Row, and - // avoid converting Arrow data to - // GenericRecord. 
- return AvroUtils.toGenericRecord(row, null); + try { + return fromRow.apply(row); + } catch (Exception e) { + badRecord = row; + throw new ReadException(e); + } + } + + @Override + public @Nullable Row getLastBadRecord() { + return badRecord; + } + + @Override + public Coder getBadRecordCoder() { + return badRecordCoder; } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageAvroReader.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageAvroReader.java index 50ce6a89f7a9..26d6cc61da6b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageAvroReader.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageAvroReader.java @@ -19,29 +19,43 @@ import com.google.cloud.bigquery.storage.v1.AvroRows; import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; -import com.google.cloud.bigquery.storage.v1.ReadSession; import java.io.IOException; import org.apache.avro.Schema; -import org.apache.avro.generic.GenericDatumReader; -import org.apache.avro.generic.GenericRecord; import org.apache.avro.io.BinaryDecoder; import org.apache.avro.io.DatumReader; import org.apache.avro.io.DecoderFactory; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; +import org.apache.beam.sdk.extensions.avro.io.AvroSource; +import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.Preconditions; import org.checkerframework.checker.nullness.qual.Nullable; -class BigQueryStorageAvroReader implements BigQueryStorageReader { +class BigQueryStorageAvroReader implements BigQueryStorageReader { - private final Schema avroSchema; - private final DatumReader datumReader; + private final DatumReader datumReader; + private final SerializableFunction fromAvroRecord; + private final @Nullable AvroCoder badRecordCoder; private @Nullable BinaryDecoder decoder; private long rowCount; - BigQueryStorageAvroReader(ReadSession readSession) { - this.avroSchema = new Schema.Parser().parse(readSession.getAvroSchema().getSchema()); - this.datumReader = new GenericDatumReader<>(avroSchema); + private transient @Nullable AvroT badRecord = null; + + BigQueryStorageAvroReader( + Schema writerSchema, + Schema readerSchema, + AvroSource.DatumReaderFactory readerFactory, + SerializableFunction fromAvroRecord) { + this.datumReader = readerFactory.apply(writerSchema, readerSchema); + this.fromAvroRecord = fromAvroRecord; this.rowCount = 0; - decoder = null; + this.decoder = null; + if (readerFactory instanceof AvroDatumFactory) { + this.badRecordCoder = AvroCoder.of((AvroDatumFactory) readerFactory, readerSchema); + } else { + this.badRecordCoder = null; + } } @Override @@ -63,14 +77,29 @@ public long getRowCount() { } @Override - public GenericRecord readSingleRecord() throws IOException { + public T readSingleRecord() throws IOException { Preconditions.checkStateNotNull(decoder); @SuppressWarnings({ "nullness" // reused record is null but avro not annotated }) // record should not be reused, mutating outputted values is unsafe - GenericRecord newRecord = datumReader.read(/*reuse=*/ null, decoder); - return newRecord; + AvroT avroRecord = datumReader.read(/*reuse=*/ null, decoder); + try { + return fromAvroRecord.apply(avroRecord); + } catch (Exception e) { + 
badRecord = avroRecord; + throw new ReadException(e); + } + } + + @Override + public @Nullable Object getLastBadRecord() { + return badRecord; + } + + @Override + public @Nullable Coder getBadRecordCoder() { + return badRecordCoder; } @Override diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageQuerySource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageQuerySource.java index a2350ef19a74..97a80856d0cd 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageQuerySource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageQuerySource.java @@ -31,7 +31,6 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.checkerframework.checker.nullness.qual.Nullable; @@ -49,7 +48,7 @@ public static BigQueryStorageQuerySource create( @Nullable String queryTempProject, @Nullable String kmsKey, @Nullable DataFormat format, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices) { return new BigQueryStorageQuerySource<>( @@ -63,7 +62,7 @@ public static BigQueryStorageQuerySource create( queryTempProject, kmsKey, format, - parseFn, + readerFactory, outputCoder, bqServices); } @@ -76,7 +75,7 @@ public static BigQueryStorageQuerySource create( QueryPriority priority, @Nullable String location, @Nullable String kmsKey, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices) { return new BigQueryStorageQuerySource<>( @@ -90,7 +89,7 @@ public static BigQueryStorageQuerySource create( null, kmsKey, null, - parseFn, + readerFactory, outputCoder, bqServices); } @@ -119,10 +118,10 @@ private BigQueryStorageQuerySource( @Nullable String queryTempProject, @Nullable String kmsKey, @Nullable DataFormat format, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices) { - super(format, null, null, parseFn, outputCoder, bqServices); + super(format, null, null, readerFactory, outputCoder, bqServices); this.stepUuid = checkNotNull(stepUuid, "stepUuid"); this.queryProvider = checkNotNull(queryProvider, "queryProvider"); this.flattenResults = checkNotNull(flattenResults, "flattenResults"); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReader.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReader.java index e13a0bdd9d65..c00a75e37c46 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReader.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReader.java @@ -19,17 +19,22 @@ import com.google.cloud.bigquery.storage.v1.ReadRowsResponse; import java.io.IOException; -import org.apache.avro.generic.GenericRecord; +import org.apache.beam.sdk.coders.Coder; +import org.checkerframework.checker.nullness.qual.Nullable; -interface BigQueryStorageReader extends AutoCloseable { +interface BigQueryStorageReader 
extends AutoCloseable { void processReadRowsResponse(ReadRowsResponse readRowsResponse) throws IOException; long getRowCount(); - // TODO(https://github.com/apache/beam/issues/21076): BigQueryStorageReader should produce Rows, - // rather than GenericRecords - GenericRecord readSingleRecord() throws IOException; + T readSingleRecord() throws IOException; + + @Nullable + Object getLastBadRecord(); + + @Nullable + Coder getBadRecordCoder(); boolean readyForNextReadResponse() throws IOException; @@ -37,4 +42,10 @@ interface BigQueryStorageReader extends AutoCloseable { @Override void close(); + + class ReadException extends RuntimeException { + public ReadException(Throwable cause) { + super(cause); + } + } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderFactory.java index fba06d020699..63c54c32564b 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderFactory.java @@ -17,19 +17,10 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; +import com.google.api.services.bigquery.model.TableSchema; import com.google.cloud.bigquery.storage.v1.ReadSession; import java.io.IOException; -class BigQueryStorageReaderFactory { - - private BigQueryStorageReaderFactory() {} - - public static BigQueryStorageReader getReader(ReadSession readSession) throws IOException { - if (readSession.hasAvroSchema()) { - return new BigQueryStorageAvroReader(readSession); - } else if (readSession.hasArrowSchema()) { - return new BigQueryStorageArrowReader(readSession); - } - throw new IllegalStateException("Read session does not have Avro/Arrow schema set."); - } +interface BigQueryStorageReaderFactory { + BigQueryStorageReader getReader(TableSchema table, ReadSession readSession) throws IOException; } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java index d0bc655b311a..cb0881fc8d08 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageSourceBase.java @@ -34,7 +34,6 @@ import org.apache.beam.sdk.metrics.Lineage; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; @@ -66,7 +65,7 @@ abstract class BigQueryStorageSourceBase extends BoundedSource { protected final @Nullable DataFormat format; protected final @Nullable ValueProvider> selectedFieldsProvider; protected final @Nullable ValueProvider rowRestrictionProvider; - protected final SerializableFunction parseFn; + protected final BigQueryStorageReaderFactory readerFactory; protected final Coder outputCoder; protected final BigQueryServices bqServices; @@ -74,13 +73,13 @@ abstract class 
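The storage reader now applies the parse function itself and keeps hold of the last element that failed, so a caller can surface it as a dead-letter record instead of failing the bundle. The consuming pattern below is a hedged illustration, not code from this PR (the hunk for BigQueryStorageStreamReader shown here does not include the catch); reader is assumed to be a BigQueryStorageReader<TableRow>, and readSingleRecord's IOException is assumed to be declared by the enclosing method.

// Hedged caller sketch only.
try {
  TableRow row = reader.readSingleRecord();        // parsing now happens inside the reader
  // ... emit row downstream ...
} catch (BigQueryStorageReader.ReadException e) {
  Object bad = reader.getLastBadRecord();          // the raw GenericRecord / Row that failed to parse
  Coder<?> badCoder = reader.getBadRecordCoder();  // may be null when no coder can be derived
  // route (bad, badCoder, e.getCause()) to an error or dead-letter output
}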
BigQueryStorageSourceBase extends BoundedSource { @Nullable DataFormat format, @Nullable ValueProvider> selectedFieldsProvider, @Nullable ValueProvider rowRestrictionProvider, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices) { this.format = format; this.selectedFieldsProvider = selectedFieldsProvider; this.rowRestrictionProvider = rowRestrictionProvider; - this.parseFn = checkNotNull(parseFn, "parseFn"); + this.readerFactory = readerFactory; this.outputCoder = checkNotNull(outputCoder, "outputCoder"); this.bqServices = checkNotNull(bqServices, "bqServices"); } @@ -180,8 +179,9 @@ public List> split( // TODO: this is inconsistent with method above, where it can be null Preconditions.checkStateNotNull(targetTable); + TableSchema tableSchema = targetTable.getSchema(); - if (selectedFieldsProvider != null && selectedFieldsProvider.isAccessible()) { + if (selectedFieldsProvider != null) { tableSchema = BigQueryUtils.trimSchema(tableSchema, selectedFieldsProvider.get()); } @@ -189,7 +189,7 @@ public List> split( for (ReadStream readStream : readSession.getStreamsList()) { sources.add( BigQueryStorageStreamSource.create( - readSession, readStream, tableSchema, parseFn, outputCoder, bqServices)); + readSession, readStream, tableSchema, readerFactory, outputCoder, bqServices)); } return ImmutableList.copyOf(sources); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java index adc0933defed..152e504cb3eb 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageStreamSource.java @@ -17,8 +17,6 @@ */ package org.apache.beam.sdk.io.gcp.bigquery; -import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.fromJsonString; -import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryHelpers.toJsonString; import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; import com.google.api.gax.rpc.ApiException; @@ -49,9 +47,9 @@ import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.util.Preconditions; +import org.apache.beam.sdk.util.SerializableSupplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.checkerframework.checker.nullness.qual.Nullable; import org.checkerframework.checker.nullness.qual.RequiresNonNull; @@ -67,51 +65,65 @@ public static BigQueryStorageStreamSource create( ReadSession readSession, ReadStream readStream, TableSchema tableSchema, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices) { return new BigQueryStorageStreamSource<>( readSession, readStream, - toJsonString(Preconditions.checkArgumentNotNull(tableSchema, "tableSchema")), - parseFn, + new SerializableTableSchemaSupplier(tableSchema), + readerFactory, outputCoder, bqServices); } + private static class SerializableTableSchemaSupplier + implements SerializableSupplier { + private 
transient TableSchema tableSchema; + private final String jsonSchema; + + SerializableTableSchemaSupplier(TableSchema tableSchema) { + this.tableSchema = tableSchema; + this.jsonSchema = BigQueryHelpers.toJsonString(tableSchema); + } + + @Override + public TableSchema get() { + if (tableSchema == null) { + tableSchema = BigQueryHelpers.fromJsonString(jsonSchema, TableSchema.class); + } + return tableSchema; + } + } + /** * Creates a new source with the same properties as this one, except with a different {@link * ReadStream}. */ public BigQueryStorageStreamSource fromExisting(ReadStream newReadStream) { return new BigQueryStorageStreamSource<>( - readSession, newReadStream, jsonTableSchema, parseFn, outputCoder, bqServices); - } - - public BigQueryStorageStreamSource fromExisting( - SerializableFunction parseFn) { - return new BigQueryStorageStreamSource<>( - readSession, readStream, jsonTableSchema, parseFn, outputCoder, bqServices); + readSession, newReadStream, tableSchemaSupplier, readerFactory, outputCoder, bqServices); } private final ReadSession readSession; private final ReadStream readStream; - private final String jsonTableSchema; - private final SerializableFunction parseFn; + private final SerializableTableSchemaSupplier tableSchemaSupplier; + private final BigQueryStorageReaderFactory readerFactory; private final Coder outputCoder; private final BigQueryServices bqServices; private BigQueryStorageStreamSource( ReadSession readSession, ReadStream readStream, - String jsonTableSchema, - SerializableFunction parseFn, + SerializableTableSchemaSupplier tableSchemaSupplier, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices) { this.readSession = Preconditions.checkArgumentNotNull(readSession, "readSession"); this.readStream = Preconditions.checkArgumentNotNull(readStream, "stream"); - this.jsonTableSchema = Preconditions.checkArgumentNotNull(jsonTableSchema, "jsonTableSchema"); - this.parseFn = Preconditions.checkArgumentNotNull(parseFn, "parseFn"); + this.tableSchemaSupplier = + Preconditions.checkArgumentNotNull(tableSchemaSupplier, "tableSchemaSupplier"); + this.readerFactory = Preconditions.checkArgumentNotNull(readerFactory, "readerFactory"); this.outputCoder = Preconditions.checkArgumentNotNull(outputCoder, "outputCoder"); this.bqServices = Preconditions.checkArgumentNotNull(bqServices, "bqServices"); } @@ -158,10 +170,8 @@ public String toString() { /** A {@link org.apache.beam.sdk.io.Source.Reader} which reads records from a stream. 
*/ public static class BigQueryStorageStreamReader extends BoundedSource.BoundedReader { - private final BigQueryStorageReader reader; - private final SerializableFunction parseFn; + private final BigQueryStorageReader reader; private final StorageClient storageClient; - private final TableSchema tableSchema; private BigQueryStorageStreamSource source; private @Nullable BigQueryServerStream responseStream = null; @@ -203,10 +213,9 @@ public static class BigQueryStorageStreamReader extends BoundedSource.Bounded private BigQueryStorageStreamReader( BigQueryStorageStreamSource source, BigQueryOptions options) throws IOException { this.source = source; - this.reader = BigQueryStorageReaderFactory.getReader(source.readSession); - this.parseFn = source.parseFn; + this.reader = + source.readerFactory.getReader(source.tableSchemaSupplier.get(), source.readSession); this.storageClient = source.bqServices.getStorageClient(options); - this.tableSchema = fromJsonString(source.jsonTableSchema, TableSchema.class); // number of stream determined from server side for storage read api v2 this.splitAllowed = !options.getEnableStorageReadApiV2(); this.fractionConsumed = 0d; @@ -311,9 +320,7 @@ private synchronized boolean readNextRecord() throws IOException { * 1.0 / totalRowsInCurrentResponse; - SchemaAndRecord schemaAndRecord = new SchemaAndRecord(reader.readSingleRecord(), tableSchema); - - current = parseFn.apply(schemaAndRecord); + current = reader.readSingleRecord(); return true; } @@ -451,5 +458,9 @@ public synchronized BigQueryStorageStreamSource getCurrentSource() { public synchronized Double getFractionConsumed() { return fractionConsumed; } + + BigQueryStorageReader getStorageReader() { + return reader; + } } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java index 909a2551b299..62893447c5b7 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageTableSource.java @@ -31,7 +31,6 @@ import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.ValueProvider; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.checkerframework.checker.nullness.qual.Nullable; @@ -53,7 +52,7 @@ public static BigQueryStorageTableSource create( DataFormat format, @Nullable ValueProvider> selectedFields, @Nullable ValueProvider rowRestriction, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices, boolean projectionPushdownApplied) { @@ -62,7 +61,7 @@ public static BigQueryStorageTableSource create( format, selectedFields, rowRestriction, - parseFn, + readerFactory, outputCoder, bqServices, projectionPushdownApplied); @@ -72,7 +71,7 @@ public static BigQueryStorageTableSource create( ValueProvider tableRefProvider, @Nullable ValueProvider> selectedFields, @Nullable ValueProvider rowRestriction, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices) { 
return new BigQueryStorageTableSource<>( @@ -80,7 +79,7 @@ public static BigQueryStorageTableSource create( null, selectedFields, rowRestriction, - parseFn, + readerFactory, outputCoder, bqServices, false); @@ -91,11 +90,11 @@ private BigQueryStorageTableSource( @Nullable DataFormat format, @Nullable ValueProvider> selectedFields, @Nullable ValueProvider rowRestriction, - SerializableFunction parseFn, + BigQueryStorageReaderFactory readerFactory, Coder outputCoder, BigQueryServices bqServices, boolean projectionPushdownApplied) { - super(format, selectedFields, rowRestriction, parseFn, outputCoder, bqServices); + super(format, selectedFields, rowRestriction, readerFactory, outputCoder, bqServices); this.tableReferenceProvider = checkNotNull(tableRefProvider, "tableRefProvider"); this.projectionPushdownApplied = projectionPushdownApplied; cachedTable = new AtomicReference<>(); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSource.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSource.java index 1b6aedf8cb17..7d89d6994922 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSource.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSource.java @@ -19,14 +19,11 @@ import com.google.api.services.bigquery.model.Table; import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableSchema; import java.io.IOException; import java.util.concurrent.atomic.AtomicReference; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.checkerframework.checker.nullness.qual.Nullable; @@ -40,7 +37,7 @@ static BigQueryTableSource create( BigQueryTableSourceDef tableDef, BigQueryServices bqServices, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes) { return new BigQueryTableSource<>( stepUuid, tableDef, bqServices, coder, readerFactory, useAvroLogicalTypes); @@ -54,7 +51,7 @@ private BigQueryTableSource( BigQueryTableSourceDef tableDef, BigQueryServices bqServices, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes) { super(stepUuid, bqServices, coder, readerFactory, useAvroLogicalTypes); this.tableDef = tableDef; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSourceDef.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSourceDef.java index a7299c6992fe..e79f2558ee2c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSourceDef.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryTableSourceDef.java @@ -25,10 +25,8 @@ import com.google.api.services.bigquery.model.TableSchema; import java.io.IOException; import org.apache.beam.sdk.coders.Coder; -import 
org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.options.ValueProvider; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Strings; import org.slf4j.Logger; @@ -93,7 +91,7 @@ ValueProvider getJsonTable() { public BigQuerySourceBase toSource( String stepUuid, Coder coder, - SerializableFunction> readerFactory, + BigQueryReaderFactory readerFactory, boolean useAvroLogicalTypes) { return BigQueryTableSource.create( stepUuid, this, bqServices, coder, readerFactory, useAvroLogicalTypes); diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/RowWriterFactory.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/RowWriterFactory.java index 21bf9ae74adf..fa27ff074461 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/RowWriterFactory.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/RowWriterFactory.java @@ -22,6 +22,7 @@ import java.io.Serializable; import org.apache.avro.Schema; import org.apache.avro.io.DatumWriter; +import org.apache.beam.sdk.extensions.avro.io.AvroSink; import org.apache.beam.sdk.transforms.SerializableFunction; import org.checkerframework.checker.nullness.qual.Nullable; @@ -43,7 +44,7 @@ abstract BigQueryRowWriter createRowWriter( static RowWriterFactory tableRows( SerializableFunction toRow, SerializableFunction toFailsafeRow) { - return new TableRowWriterFactory(toRow, toFailsafeRow); + return new TableRowWriterFactory<>(toRow, toFailsafeRow); } static final class TableRowWriterFactory @@ -91,20 +92,20 @@ String getSourceFormat() { AvroRowWriterFactory avroRecords( SerializableFunction, AvroT> toAvro, SerializableFunction> writerFactory) { - return new AvroRowWriterFactory<>(toAvro, writerFactory, null, null); + return new AvroRowWriterFactory<>(toAvro, writerFactory::apply, null, null); } static final class AvroRowWriterFactory extends RowWriterFactory { private final SerializableFunction, AvroT> toAvro; - private final SerializableFunction> writerFactory; + private final AvroSink.DatumWriterFactory writerFactory; private final @Nullable SerializableFunction<@Nullable TableSchema, Schema> schemaFactory; private final @Nullable DynamicDestinations dynamicDestinations; private AvroRowWriterFactory( SerializableFunction, AvroT> toAvro, - SerializableFunction> writerFactory, + AvroSink.DatumWriterFactory writerFactory, @Nullable SerializableFunction<@Nullable TableSchema, Schema> schemaFactory, @Nullable DynamicDestinations dynamicDestinations) { this.toAvro = toAvro; diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndElement.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndElement.java new file mode 100644 index 000000000000..ebae513ee8db --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndElement.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import com.google.api.services.bigquery.model.TableSchema; + +/** + * A wrapper for a record and the {@link TableSchema} representing the schema of the table (or + * query) it was generated from. + */ +public class SchemaAndElement { + private final T element; + private final TableSchema tableSchema; + + public SchemaAndElement(T record, TableSchema tableSchema) { + this.element = record; + this.tableSchema = tableSchema; + } + + public T getElement() { + return element; + } + + // getRecord is defined here so method is present when cast to SchemaAndRecord + public T getRecord() { + return element; + } + + // getRow is defined here so method is present when cast to SchemaAndRow + protected T getRow() { + return element; + } + + public TableSchema getTableSchema() { + return tableSchema; + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndRecord.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndRecord.java index e6811efd3d82..2716b5e34e1e 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndRecord.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndRecord.java @@ -24,20 +24,8 @@ * A wrapper for a {@link GenericRecord} and the {@link TableSchema} representing the schema of the * table (or query) it was generated from. */ -public class SchemaAndRecord { - private final GenericRecord record; - private final TableSchema tableSchema; - +public class SchemaAndRecord extends SchemaAndElement { public SchemaAndRecord(GenericRecord record, TableSchema tableSchema) { - this.record = record; - this.tableSchema = tableSchema; - } - - public GenericRecord getRecord() { - return record; - } - - public TableSchema getTableSchema() { - return tableSchema; + super(record, tableSchema); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndRow.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndRow.java new file mode 100644 index 000000000000..a79952d708d5 --- /dev/null +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/SchemaAndRow.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.gcp.bigquery; + +import com.google.api.services.bigquery.model.TableSchema; +import org.apache.beam.sdk.values.Row; + +/** + * A wrapper for an arrow {@link Row} and the {@link TableSchema} representing the schema of the + * table (or query) it was generated from. + */ +public class SchemaAndRow extends SchemaAndElement { + public SchemaAndRow(Row row, TableSchema tableSchema) { + super(row, tableSchema); + } +} diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java index a387495863a2..db2f77981623 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/bigquery/StorageApiDynamicDestinationsGenericRecord.java @@ -24,7 +24,6 @@ import com.google.protobuf.Message; import org.apache.avro.Schema; import org.apache.avro.generic.GenericRecord; -import org.apache.beam.sdk.extensions.avro.schemas.utils.AvroUtils; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.DatasetService; import org.apache.beam.sdk.transforms.SerializableFunction; import org.checkerframework.checker.nullness.qual.NonNull; @@ -64,13 +63,11 @@ class GenericRecordConverter implements MessageConverter { final com.google.cloud.bigquery.storage.v1.TableSchema protoTableSchema; final Schema avroSchema; - final TableSchema bqTableSchema; final Descriptor descriptor; final @javax.annotation.Nullable Descriptor cdcDescriptor; GenericRecordConverter(DestinationT destination) throws Exception { avroSchema = schemaFactory.apply(getSchema(destination)); - bqTableSchema = BigQueryUtils.toTableSchema(AvroUtils.toBeamSchema(avroSchema)); protoTableSchema = AvroGenericRecordToStorageApiProto.protoTableSchemaFromAvroSchema(avroSchema); descriptor = @@ -113,7 +110,7 @@ public TableRow toFailsafeTableRow(T element) { return formatRecordOnFailureFunction.apply(element); } else { return BigQueryUtils.convertGenericRecordToTableRow( - toGenericRecord.apply(new AvroWriteRequest<>(element, avroSchema)), bqTableSchema); + toGenericRecord.apply(new AvroWriteRequest<>(element, avroSchema))); } } diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java index ac09b11638de..705b964bb69c 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/testing/FakeJobService.java @@ -66,6 +66,7 @@ import org.apache.avro.generic.GenericDatumWriter; import org.apache.avro.generic.GenericRecord; import org.apache.avro.generic.GenericRecordBuilder; +import org.apache.avro.io.DatumReader; import 
org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; @@ -510,14 +511,16 @@ private List readJsonTableRows(String filename) throws IOException { private List readAvroTableRows(String filename, TableSchema tableSchema) throws IOException { List tableRows = Lists.newArrayList(); - FileReader dfr = - DataFileReader.openReader(new File(filename), new GenericDatumReader<>()); - - while (dfr.hasNext()) { - GenericRecord record = dfr.next(null); - tableRows.add(BigQueryUtils.convertGenericRecordToTableRow(record, tableSchema)); + Schema readerSchema = BigQueryUtils.toGenericAvroSchema(tableSchema, true); + DatumReader reader = new GenericDatumReader<>(); + reader.setSchema(readerSchema); + try (FileReader dfr = DataFileReader.openReader(new File(filename), reader)) { + while (dfr.hasNext()) { + GenericRecord record = dfr.next(null); + tableRows.add(BigQueryUtils.convertGenericRecordToTableRow(record)); + } + return tableRows; } - return tableRows; } private long writeRows( diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java index a8aca7570b33..fa0a709f2fbc 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOReadTest.java @@ -45,14 +45,14 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutionException; -import org.apache.avro.specific.SpecificDatumReader; -import org.apache.avro.specific.SpecificRecordBase; +import org.apache.avro.reflect.Nullable; import org.apache.beam.sdk.PipelineResult; import org.apache.beam.sdk.coders.CoderRegistry; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.VarIntCoder; -import org.apache.beam.sdk.extensions.avro.io.AvroSource; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder; import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; import org.apache.beam.sdk.io.BoundedSource; @@ -145,20 +145,12 @@ public void evaluate() throws Throwable { .withDatasetService(fakeDatasetService) .withJobService(fakeJobService); - private SerializableFunction> - datumReaderFactoryFn = - (SerializableFunction>) - input -> { - try { - String jsonSchema = BigQueryIO.JSON_FACTORY.toString(input); - return (AvroSource.DatumReaderFactory) - (writer, reader) -> - new BigQueryIO.GenericDatumTransformer<>( - BigQueryIO.TableRowParser.INSTANCE, jsonSchema, writer); - } catch (IOException e) { - return null; - } - }; + private BigQueryReaderFactory readerFactory = + BigQueryReaderFactory.avro( + null, + false, + AvroDatumFactory.generic(), + (input) -> BigQueryAvroUtils.convertGenericRecordToTableRow(input.getElement())); private static class MyData implements Serializable { private String name; @@ -650,48 +642,35 @@ public void testReadTableWithSchema() throws IOException, InterruptedException { p.run(); } - static class User extends SpecificRecordBase { - private static final org.apache.avro.Schema schema = - org.apache.avro.SchemaBuilder.record("User") - 
.namespace("org.apache.beam.sdk.io.gcp.bigquery.BigQueryIOReadTest$") - .fields() - .optionalString("name") - .endRecord(); + static class User { + @Nullable String name; - private String name; - - public String getName() { - return this.name; - } + User() {} - public void setName(String name) { + User(String name) { this.name = name; } - public User() {} - @Override - public void put(int i, Object v) { - if (i == 0) { - setName(((org.apache.avro.util.Utf8) v).toString()); - } + public String toString() { + return "User{" + "name='" + name + "'" + "}"; } @Override - public Object get(int i) { - if (i == 0) { - return getName(); + public boolean equals(Object o) { + if (this == o) { + return true; } - return null; + if (!(o instanceof User)) { + return false; + } + User user = (User) o; + return Objects.equals(name, user.name); } @Override - public org.apache.avro.Schema getSchema() { - return schema; - } - - public static org.apache.avro.Schema getAvroSchema() { - return schema; + public int hashCode() { + return Objects.hashCode(name); } } @@ -705,11 +684,11 @@ public void testReadTableWithReaderDatumFactory() throws IOException, Interrupte someTable.setTableReference( new TableReference() .setProjectId("non-executing-project") - .setDatasetId("schema_dataset") - .setTableId("schema_table")); + .setDatasetId("user_dataset") + .setTableId("user_table")); someTable.setNumBytes(1024L * 1024L); FakeDatasetService fakeDatasetService = new FakeDatasetService(); - fakeDatasetService.createDataset("non-executing-project", "schema_dataset", "", "", null); + fakeDatasetService.createDataset("non-executing-project", "user_dataset", "", "", null); fakeDatasetService.createTable(someTable); List records = @@ -727,24 +706,18 @@ public void testReadTableWithReaderDatumFactory() throws IOException, Interrupte .withDatasetService(fakeDatasetService); BigQueryIO.TypedRead read = - BigQueryIO.readWithDatumReader( - (AvroSource.DatumReaderFactory) - (writer, reader) -> new SpecificDatumReader<>(User.getAvroSchema())) - .from("non-executing-project:schema_dataset.schema_table") + BigQueryIO.readWithDatumReader(AvroDatumFactory.reflect(User.class)) + .from("non-executing-project:user_dataset.user_table") .withTestServices(fakeBqServices) .withoutValidation() - .withCoder(SerializableCoder.of(User.class)); + .withCoder(AvroCoder.reflect(User.class)); PCollection bqRows = p.apply(read); - User a = new User(); - a.setName("a"); - User b = new User(); - b.setName("b"); - User c = new User(); - c.setName("c"); - User d = new User(); - d.setName("d"); + User a = new User("a"); + User b = new User("b"); + User c = new User("c"); + User d = new User("d"); PAssert.that(bqRows).containsInAnyOrder(ImmutableList.of(a, b, c, d)); @@ -819,7 +792,7 @@ public void testBigQueryTableSourceInitSplit() throws Exception { String stepUuid = "testStepUuid"; BoundedSource bqSource = BigQueryTableSourceDef.create(fakeBqServices, ValueProvider.StaticValueProvider.of(table)) - .toSource(stepUuid, TableRowJsonCoder.of(), datumReaderFactoryFn, false); + .toSource(stepUuid, TableRowJsonCoder.of(), readerFactory, false); PipelineOptions options = PipelineOptionsFactory.create(); options.setTempLocation(testFolder.getRoot().getAbsolutePath()); @@ -874,7 +847,7 @@ public void testEstimatedSizeWithoutStreamingBuffer() throws Exception { String stepUuid = "testStepUuid"; BoundedSource bqSource = BigQueryTableSourceDef.create(fakeBqServices, ValueProvider.StaticValueProvider.of(table)) - .toSource(stepUuid, TableRowJsonCoder.of(), 
datumReaderFactoryFn, false); + .toSource(stepUuid, TableRowJsonCoder.of(), readerFactory, false); PipelineOptions options = PipelineOptionsFactory.create(); @@ -912,7 +885,7 @@ public void testEstimatedSizeWithStreamingBuffer() throws Exception { String stepUuid = "testStepUuid"; BoundedSource bqSource = BigQueryTableSourceDef.create(fakeBqServices, ValueProvider.StaticValueProvider.of(table)) - .toSource(stepUuid, TableRowJsonCoder.of(), datumReaderFactoryFn, false); + .toSource(stepUuid, TableRowJsonCoder.of(), readerFactory, false); PipelineOptions options = PipelineOptionsFactory.create(); @@ -944,7 +917,7 @@ public void testBigQueryQuerySourceEstimatedSize() throws Exception { null, null, null) - .toSource(stepUuid, TableRowJsonCoder.of(), datumReaderFactoryFn, false); + .toSource(stepUuid, TableRowJsonCoder.of(), readerFactory, false); fakeJobService.expectDryRunQuery( bqOptions.getProject(), @@ -1021,7 +994,7 @@ public void testBigQueryQuerySourceInitSplit() throws Exception { null, null, null) - .toSource(stepUuid, TableRowJsonCoder.of(), datumReaderFactoryFn, false); + .toSource(stepUuid, TableRowJsonCoder.of(), readerFactory, false); options.setTempLocation(testFolder.getRoot().getAbsolutePath()); @@ -1094,7 +1067,7 @@ public void testBigQueryQuerySourceInitSplit_NoReferencedTables() throws Excepti null, null, null) - .toSource(stepUuid, TableRowJsonCoder.of(), datumReaderFactoryFn, false); + .toSource(stepUuid, TableRowJsonCoder.of(), readerFactory, false); options.setTempLocation(testFolder.getRoot().getAbsolutePath()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java index 2b1c111269df..0f8df034da08 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryIT.java @@ -21,9 +21,9 @@ import com.google.api.services.bigquery.model.TableRow; import java.util.Map; +import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.ExperimentalOptions; @@ -97,7 +97,7 @@ private void runBigQueryIOStorageQueryPipeline() { PCollection count = p.apply( "Query", - BigQueryIO.read(TableRowParser.INSTANCE) + BigQueryIO.readTableRows() .fromQuery("SELECT * FROM `" + options.getInputTable() + "`") .usingStandardSql() .withMethod(Method.DIRECT_READ)) @@ -112,7 +112,7 @@ public void testBigQueryStorageQuery1G() throws Exception { runBigQueryIOStorageQueryPipeline(); } - static class FailingTableRowParser implements SerializableFunction { + static class FailingTableRowParser implements SerializableFunction { public static final BigQueryIOStorageReadIT.FailingTableRowParser INSTANCE = new BigQueryIOStorageReadIT.FailingTableRowParser(); @@ -120,12 +120,12 @@ static class FailingTableRowParser implements SerializableFunction count = p.apply( "Read", - BigQueryIO.read(FailingTableRowParser.INSTANCE) + BigQueryIO.parseGenericRecords(FailingTableRowParser.INSTANCE) .fromQuery("SELECT * FROM `" + options.getInputTable() + "`") 
.usingStandardSql() .withMethod(Method.DIRECT_READ) diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java index 4298c367936c..c78f7c0af4a4 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageQueryTest.java @@ -58,12 +58,8 @@ import org.apache.avro.generic.GenericRecord; import org.apache.avro.io.Encoder; import org.apache.avro.io.EncoderFactory; -import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.KvCoder; -import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder; -import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; import org.apache.beam.sdk.io.BoundedSource; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.QueryPriority; @@ -104,6 +100,13 @@ @RunWith(JUnit4.class) public class BigQueryIOStorageQueryTest { + private static final BigQueryReaderFactory TABLE_ROW_AVRO_READER_FACTORY = + BigQueryReaderFactory.avro( + null, + false, + AvroDatumFactory.generic(), + input -> BigQueryAvroUtils.convertGenericRecordToTableRow(input.getRecord())); + private transient BigQueryOptions options; private transient TemporaryFolder testFolder = new TemporaryFolder(); private transient TestPipeline p; @@ -170,7 +173,7 @@ public void testDefaultQueryBasedSource() throws Exception { @Test public void testQueryBasedSourceWithCustomQuery() throws Exception { TypedRead typedRead = - BigQueryIO.read(new TableRowParser()) + BigQueryIO.readTableRows() .fromQuery("SELECT * FROM `google.com:project.dataset.table`") .withCoder(TableRowJsonCoder.of()); checkTypedReadQueryObject(typedRead, "SELECT * FROM `google.com:project.dataset.table`"); @@ -227,7 +230,7 @@ public void testQueryBasedSourceWithTemplateCompatibility() throws Exception { } private TypedRead getDefaultTypedRead() { - return BigQueryIO.read(new TableRowParser()) + return BigQueryIO.readTableRows() .fromQuery(DEFAULT_QUERY) .withCoder(TableRowJsonCoder.of()) .withMethod(Method.DIRECT_READ); @@ -273,21 +276,6 @@ public void testName() { assertEquals("BigQueryIO.TypedRead", getDefaultTypedRead().getName()); } - @Test - public void testCoderInference() { - SerializableFunction> parseFn = - new SerializableFunction>() { - @Override - public KV apply(SchemaAndRecord input) { - return null; - } - }; - - assertEquals( - KvCoder.of(ByteStringCoder.of(), ProtoCoder.of(ReadSession.class)), - BigQueryIO.read(parseFn).inferCoder(CoderRegistry.createDefault())); - } - @Test public void testQuerySourceEstimatedSize() throws Exception { @@ -310,7 +298,7 @@ public void testQuerySourceEstimatedSize() throws Exception { /* queryTempProject = */ null, /* kmsKey = */ null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), fakeBigQueryServices); @@ -424,7 +412,7 @@ private void doQuerySourceInitialSplit( /* queryTempProject = */ null, /* kmsKey = */ null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, 
TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -526,7 +514,7 @@ public void testQuerySourceInitialSplit_NoReferencedTables() throws Exception { /* queryTempProject = */ null, /* kmsKey = */ null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -602,11 +590,10 @@ private static ReadRowsResponse createResponse( } private static final class ParseKeyValue - implements SerializableFunction> { + implements SerializableFunction> { @Override - public KV apply(SchemaAndRecord input) { - return KV.of( - input.getRecord().get("name").toString(), (Long) input.getRecord().get("number")); + public KV apply(GenericRecord record) { + return KV.of(record.get("name").toString(), (Long) record.get("number")); } } @@ -675,7 +662,7 @@ public void testQuerySourceInitialSplitWithBigQueryProject_EmptyResult() throws /* queryTempProject = */ null, /* kmsKey = */ null, DataFormat.AVRO, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -748,7 +735,7 @@ public void testQuerySourceInitialSplit_EmptyResult() throws Exception { /* queryTempProject = */ null, /* kmsKey = */ null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -773,7 +760,7 @@ public void testQuerySourceCreateReader() throws Exception { /* queryTempProject = */ null, /* kmsKey = */ null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), fakeBigQueryServices); @@ -783,7 +770,7 @@ public void testQuerySourceCreateReader() throws Exception { } public TypedRead> configureTypedRead( - SerializableFunction> parseFn) throws Exception { + SerializableFunction> parseFn) throws Exception { TableReference sourceTableRef = BigQueryHelpers.parseTableSpec("project:dataset.table"); fakeDatasetService.createDataset( @@ -843,7 +830,7 @@ public TypedRead> configureTypedRead( when(fakeStorageClient.readRows(expectedReadRowsRequest, "")) .thenReturn(new FakeBigQueryServerStream<>(readRowsResponses)); - return BigQueryIO.read(parseFn) + return BigQueryIO.parseGenericRecords(parseFn) .fromQuery(encodedQuery) .withMethod(Method.DIRECT_READ) .withTestServices( @@ -881,14 +868,13 @@ private void doReadFromBigQueryIO(boolean templateCompatibility) throws Exceptio } private static final class FailingParseKeyValue - implements SerializableFunction> { + implements SerializableFunction> { @Override - public KV apply(SchemaAndRecord input) { - if (input.getRecord().get("name").toString().equals("B")) { + public KV apply(GenericRecord record) { + if (record.get("name").toString().equals("B")) { throw new RuntimeException("ExpectedException"); } - return KV.of( - input.getRecord().get("name").toString(), (Long) input.getRecord().get("number")); + return KV.of(record.get("name").toString(), (Long) record.get("number")); } } diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java index 4e20d3634800..cd1a8629b617 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java +++ 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadIT.java @@ -23,9 +23,9 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.cloud.bigquery.storage.v1.DataFormat; import java.util.Map; +import org.apache.avro.generic.GenericRecord; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryUtils.ConversionOptions; import org.apache.beam.sdk.options.Description; @@ -116,28 +116,27 @@ private void runBigQueryIOStorageReadPipeline() { PCollection count = p.apply( "Read", - BigQueryIO.read(TableRowParser.INSTANCE) + BigQueryIO.readTableRows(options.getDataFormat()) .from(options.getInputTable()) - .withMethod(Method.DIRECT_READ) - .withFormat(options.getDataFormat())) + .withMethod(Method.DIRECT_READ)) .apply("Count", Count.globally()); PAssert.thatSingleton(count).isEqualTo(options.getNumRecords()); p.run().waitUntilFinish(); } - static class FailingTableRowParser implements SerializableFunction { + static class FailingTableRowParser implements SerializableFunction { public static final FailingTableRowParser INSTANCE = new FailingTableRowParser(); private int parseCount = 0; @Override - public TableRow apply(SchemaAndRecord schemaAndRecord) { + public TableRow apply(GenericRecord record) { parseCount++; if (parseCount % 50 == 0) { throw new RuntimeException("ExpectedException"); } - return TableRowParser.INSTANCE.apply(schemaAndRecord); + return BigQueryAvroUtils.convertGenericRecordToTableRow(record); } } @@ -148,10 +147,9 @@ private void runBigQueryIOStorageReadPipelineErrorHandling() throws Exception { PCollection count = p.apply( "Read", - BigQueryIO.read(FailingTableRowParser.INSTANCE) + BigQueryIO.parseGenericRecords(FailingTableRowParser.INSTANCE) .from(options.getInputTable()) .withMethod(Method.DIRECT_READ) - .withFormat(options.getDataFormat()) .withErrorHandler(errorHandler)) .apply("Count", Count.globally()); @@ -211,10 +209,9 @@ private void storageReadWithSchema(DataFormat format) { PCollection tableContents = p.apply( "Read", - BigQueryIO.readTableRowsWithSchema() + BigQueryIO.readTableRowsWithSchema(options.getDataFormat()) .from(options.getInputTable()) - .withMethod(Method.DIRECT_READ) - .withFormat(options.getDataFormat())) + .withMethod(Method.DIRECT_READ)) .apply(Convert.toRows()); PAssert.thatSingleton(tableContents.apply(Count.globally())).isEqualTo(options.getNumRecords()); assertEquals(tableContents.getSchema(), multiFieldSchema); @@ -240,15 +237,12 @@ public void testBigQueryStorageReadProjectionPushdown() throws Exception { PCollection count = p.apply( "Read", - BigQueryIO.read( + BigQueryIO.parseGenericRecords( record -> BigQueryUtils.toBeamRow( - record.getRecord(), - multiFieldSchema, - ConversionOptions.builder().build())) + record, multiFieldSchema, ConversionOptions.builder().build())) .from(options.getInputTable()) .withMethod(Method.DIRECT_READ) - .withFormat(options.getDataFormat()) .withCoder(SchemaCoder.of(multiFieldSchema))) .apply(ParDo.of(new GetIntField())) .apply("Count", Count.globally()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java 
b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java index 7998bac65055..fca97fc783c5 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTableRowIT.java @@ -25,7 +25,6 @@ import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; import org.apache.beam.sdk.io.FileSystems; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; import org.apache.beam.sdk.options.Description; import org.apache.beam.sdk.options.ExperimentalOptions; @@ -48,10 +47,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; -/** - * Integration tests for {@link BigQueryIO#readTableRows()} using {@link Method#DIRECT_READ} in - * combination with {@link TableRowParser} to generate output in {@link TableRow} form. - */ +/** Integration tests for {@link BigQueryIO#readTableRows()} using {@link Method#DIRECT_READ}. */ @RunWith(JUnit4.class) public class BigQueryIOStorageReadTableRowIT { diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java index 5b9e15f22b90..ecaad482d6a9 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOStorageReadTest.java @@ -86,14 +86,10 @@ import org.apache.avro.io.Encoder; import org.apache.avro.io.EncoderFactory; import org.apache.avro.util.Utf8; -import org.apache.beam.sdk.coders.CoderRegistry; -import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; -import org.apache.beam.sdk.extensions.protobuf.ByteStringCoder; -import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; import org.apache.beam.sdk.io.BoundedSource; import org.apache.beam.sdk.io.BoundedSource.BoundedReader; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TableRowParser; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead.Method; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryServices.StorageClient; @@ -140,6 +136,52 @@ @RunWith(JUnit4.class) public class BigQueryIOStorageReadTest { + private static final String AVRO_SCHEMA_STRING = + "{\"namespace\": \"example.avro\",\n" + + " \"type\": \"record\",\n" + + " \"name\": \"RowRecord\",\n" + + " \"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"},\n" + + " {\"name\": \"number\", \"type\": \"long\"}\n" + + " ]\n" + + "}"; + + private static final Schema AVRO_SCHEMA = new Schema.Parser().parse(AVRO_SCHEMA_STRING); + + private static final String TRIMMED_AVRO_SCHEMA_STRING = + "{\"namespace\": \"example.avro\",\n" + + "\"type\": \"record\",\n" + + "\"name\": \"RowRecord\",\n" + + "\"fields\": [\n" + + " {\"name\": \"name\", \"type\": \"string\"}\n" + + " ]\n" + + "}"; + + private static final Schema TRIMMED_AVRO_SCHEMA = + new Schema.Parser().parse(TRIMMED_AVRO_SCHEMA_STRING); + + private static final 
TableSchema TABLE_SCHEMA = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING").setMode("REQUIRED"), + new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"))); + + private static final org.apache.arrow.vector.types.pojo.Schema ARROW_SCHEMA = + new org.apache.arrow.vector.types.pojo.Schema( + asList( + field("name", new ArrowType.Utf8()), field("number", new ArrowType.Int(64, true)))); + + private static final BigQueryStorageReaderFactory TABLE_ROW_AVRO_READER_FACTORY = + BigQueryReaderFactory.avro( + null, + false, + AvroDatumFactory.generic(), + input -> BigQueryAvroUtils.convertGenericRecordToTableRow(input.getRecord())); + + private static final BigQueryStorageReaderFactory TABLE_ROW_ARROW_READER_FACTORY = + BigQueryReaderFactory.arrow(null, input -> BigQueryUtils.toTableRow(input.getRow())); + private transient PipelineOptions options; private final transient TemporaryFolder testFolder = new TemporaryFolder(); private transient TestPipeline p; @@ -193,8 +235,7 @@ public void teardown() { @Test public void testBuildTableBasedSource() { BigQueryIO.TypedRead typedRead = - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) + BigQueryIO.readTableRows() .withMethod(Method.DIRECT_READ) .from("foo.com:project:dataset.table"); checkTypedReadTableObject(typedRead, "foo.com:project", "dataset", "table"); @@ -204,8 +245,7 @@ public void testBuildTableBasedSource() { @Test public void testBuildTableBasedSourceWithoutValidation() { BigQueryIO.TypedRead typedRead = - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) + BigQueryIO.readTableRows() .withMethod(Method.DIRECT_READ) .from("foo.com:project:dataset.table") .withoutValidation(); @@ -216,10 +256,7 @@ public void testBuildTableBasedSourceWithoutValidation() { @Test public void testBuildTableBasedSourceWithDefaultProject() { BigQueryIO.TypedRead typedRead = - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) - .withMethod(Method.DIRECT_READ) - .from("myDataset.myTable"); + BigQueryIO.readTableRows().withMethod(Method.DIRECT_READ).from("myDataset.myTable"); checkTypedReadTableObject(typedRead, null, "myDataset", "myTable"); } @@ -231,10 +268,7 @@ public void testBuildTableBasedSourceWithTableReference() { .setDatasetId("dataset") .setTableId("table"); BigQueryIO.TypedRead typedRead = - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) - .withMethod(Method.DIRECT_READ) - .from(tableReference); + BigQueryIO.readTableRows().withMethod(Method.DIRECT_READ).from(tableReference); checkTypedReadTableObject(typedRead, "foo.com:project", "dataset", "table"); } @@ -255,8 +289,7 @@ public void testBuildSourceWithTableAndFlatten() { + " which only applies to queries"); p.apply( "ReadMyTable", - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) + BigQueryIO.readTableRows() .withMethod(Method.DIRECT_READ) .from("foo.com:project:dataset.table") .withoutResultFlattening()); @@ -271,8 +304,7 @@ public void testBuildSourceWithTableAndSqlDialect() { + " which only applies to queries"); p.apply( "ReadMyTable", - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) + BigQueryIO.readTableRows() .withMethod(Method.DIRECT_READ) .from("foo.com:project:dataset.table") .usingStandardSql()); @@ -283,8 +315,7 @@ public void testBuildSourceWithTableAndSqlDialect() { public void testDisplayData() { String tableSpec = "foo.com:project:dataset.table"; 
BigQueryIO.TypedRead typedRead = - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) + BigQueryIO.readTableRows() .withMethod(Method.DIRECT_READ) .withSelectedFields(ImmutableList.of("foo", "bar")) .withProjectionPushdownApplied() @@ -299,29 +330,12 @@ public void testDisplayData() { public void testName() { assertEquals( "BigQueryIO.TypedRead", - BigQueryIO.read(new TableRowParser()) - .withCoder(TableRowJsonCoder.of()) + BigQueryIO.readTableRows() .withMethod(Method.DIRECT_READ) .from("foo.com:project:dataset.table") .getName()); } - @Test - public void testCoderInference() { - // Lambdas erase too much type information -- use an anonymous class here. - SerializableFunction> parseFn = - new SerializableFunction>() { - @Override - public KV apply(SchemaAndRecord input) { - return null; - } - }; - - assertEquals( - KvCoder.of(ByteStringCoder.of(), ProtoCoder.of(ReadSession.class)), - BigQueryIO.read(parseFn).inferCoder(CoderRegistry.createDefault())); - } - @Test public void testTableSourceEstimatedSize() throws Exception { doTableSourceEstimatedSizeTest(false); @@ -347,7 +361,7 @@ private void doTableSourceEstimatedSizeTest(boolean useStreamingBuffer) throws E ValueProvider.StaticValueProvider.of(tableRef), null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withDatasetService(fakeDatasetService)); @@ -367,7 +381,7 @@ public void testTableSourceEstimatedSize_WithBigQueryProject() throws Exception ValueProvider.StaticValueProvider.of(BigQueryHelpers.parseTableSpec("dataset.table")), null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withDatasetService(fakeDatasetService)); @@ -386,7 +400,7 @@ public void testTableSourceEstimatedSize_WithDefaultProject() throws Exception { ValueProvider.StaticValueProvider.of(BigQueryHelpers.parseTableSpec("dataset.table")), null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withDatasetService(fakeDatasetService)); @@ -408,42 +422,6 @@ public void testTableSourceInitialSplit_MaxSplitCount() throws Exception { doTableSourceInitialSplitTest(10L, 10_000); } - private static final String AVRO_SCHEMA_STRING = - "{\"namespace\": \"example.avro\",\n" - + " \"type\": \"record\",\n" - + " \"name\": \"RowRecord\",\n" - + " \"fields\": [\n" - + " {\"name\": \"name\", \"type\": \"string\"},\n" - + " {\"name\": \"number\", \"type\": \"long\"}\n" - + " ]\n" - + "}"; - - private static final Schema AVRO_SCHEMA = new Schema.Parser().parse(AVRO_SCHEMA_STRING); - - private static final String TRIMMED_AVRO_SCHEMA_STRING = - "{\"namespace\": \"example.avro\",\n" - + "\"type\": \"record\",\n" - + "\"name\": \"RowRecord\",\n" - + "\"fields\": [\n" - + " {\"name\": \"name\", \"type\": \"string\"}\n" - + " ]\n" - + "}"; - - private static final Schema TRIMMED_AVRO_SCHEMA = - new Schema.Parser().parse(TRIMMED_AVRO_SCHEMA_STRING); - - private static final TableSchema TABLE_SCHEMA = - new TableSchema() - .setFields( - ImmutableList.of( - new TableFieldSchema().setName("name").setType("STRING").setMode("REQUIRED"), - new TableFieldSchema().setName("number").setType("INTEGER").setMode("REQUIRED"))); - - private static final org.apache.arrow.vector.types.pojo.Schema ARROW_SCHEMA = - new org.apache.arrow.vector.types.pojo.Schema( - asList( - field("name", new ArrowType.Utf8()), field("number", new ArrowType.Int(64, true)))); - private void 
doTableSourceInitialSplitTest(long bundleSize, int streamCount) throws Exception { fakeDatasetService.createDataset("foo.com:project", "dataset", "", "", null); TableReference tableRef = BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table"); @@ -479,7 +457,7 @@ private void doTableSourceInitialSplitTest(long bundleSize, int streamCount) thr ValueProvider.StaticValueProvider.of(tableRef), null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -527,7 +505,7 @@ public void testTableSourceInitialSplit_WithSelectedFieldsAndRowRestriction() th ValueProvider.StaticValueProvider.of(tableRef), StaticValueProvider.of(Lists.newArrayList("name")), StaticValueProvider.of("number > 5"), - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -573,7 +551,7 @@ public void testTableSourceInitialSplit_WithDefaultProject() throws Exception { ValueProvider.StaticValueProvider.of(BigQueryHelpers.parseTableSpec("dataset.table")), null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -615,7 +593,7 @@ public void testTableSourceInitialSplit_EmptyTable() throws Exception { ValueProvider.StaticValueProvider.of(tableRef), null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -633,7 +611,7 @@ public void testTableSourceCreateReader() throws Exception { BigQueryHelpers.parseTableSpec("foo.com:project:dataset.table")), null, null, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withDatasetService(fakeDatasetService)); @@ -746,13 +724,12 @@ private ReadRowsResponse createResponseArrow( @Test public void testStreamSourceEstimatedSizeBytes() throws Exception { - BigQueryStorageStreamSource streamSource = BigQueryStorageStreamSource.create( ReadSession.getDefaultInstance(), ReadStream.getDefaultInstance(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices()); @@ -767,7 +744,7 @@ public void testStreamSourceSplit() throws Exception { ReadSession.getDefaultInstance(), ReadStream.getDefaultInstance(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices()); @@ -796,15 +773,16 @@ public void testSplitReadStreamAtFraction() throws IOException { readSession, ReadStream.newBuilder().setName("readStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); PipelineOptions options = PipelineOptionsFactory.fromArgs("--enableStorageReadApiV2").create(); - BigQueryStorageStreamReader reader = streamSource.createReader(options); - reader.start(); - // Beam does not split storage read api v2 stream - assertNull(reader.splitAtFraction(0.5)); + try (BigQueryStorageStreamReader reader = streamSource.createReader(options)) { + reader.start(); + // Beam does not split storage read api v2 stream + assertNull(reader.splitAtFraction(0.5)); + } } @Test @@ -839,14 +817,15 @@ public void testReadFromStreamSource() throws Exception { readSession, ReadStream.newBuilder().setName("readStream").build(), TABLE_SCHEMA, - 
new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); List rows = new ArrayList<>(); - BoundedReader reader = streamSource.createReader(options); - for (boolean hasNext = reader.start(); hasNext; hasNext = reader.advance()) { - rows.add(reader.getCurrent()); + try (BoundedReader reader = streamSource.createReader(options)) { + for (boolean hasNext = reader.start(); hasNext; hasNext = reader.advance()) { + rows.add(reader.getCurrent()); + } } System.out.println("Rows: " + rows); @@ -895,7 +874,7 @@ public void testFractionConsumed() throws Exception { readSession, ReadStream.newBuilder().setName("readStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -980,7 +959,7 @@ public void testFractionConsumedWithSplit() throws Exception { readSession, ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1065,7 +1044,7 @@ public void testStreamSourceSplitAtFractionSucceeds() throws Exception { .build(), ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1202,7 +1181,7 @@ public void testStreamSourceSplitAtFractionRepeated() throws Exception { .build(), readStreams.get(0), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1266,7 +1245,7 @@ public void testStreamSourceSplitAtFractionFailsWhenSplitIsNotPossible() throws .build(), ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1355,7 +1334,7 @@ public void testStreamSourceSplitAtFractionFailsWhenParentIsPastSplitPoint() thr .build(), ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1379,13 +1358,21 @@ public void testStreamSourceSplitAtFractionFailsWhenParentIsPastSplitPoint() thr assertFalse(parent.advance()); } - private static final class ParseKeyValue - implements SerializableFunction> { + private static final class ParseAvroKeyValue + implements SerializableFunction> { @Override - public KV apply(SchemaAndRecord input) { - return KV.of( - input.getRecord().get("name").toString(), (Long) input.getRecord().get("number")); + public KV apply(GenericRecord record) { + return KV.of(record.get("name").toString(), (Long) record.get("number")); + } + } + + private static final class ParseArrowKeyValue + implements SerializableFunction> { + + @Override + public KV apply(Row row) { + return KV.of(row.getString("name"), row.getInt64("number")); } } @@ -1447,7 +1434,7 @@ public void testStreamSourceSplitAtFractionFailsWhenReaderRunning() throws Excep readSession, ReadStream.newBuilder().setName("readStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_AVRO_READER_FACTORY, TableRowJsonCoder.of(), new 
FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1521,10 +1508,9 @@ public void testReadFromBigQueryIO() throws Exception { PCollection> output = p.apply( - BigQueryIO.read(new ParseKeyValue()) + BigQueryIO.parseGenericRecords(new ParseAvroKeyValue()) .from("foo.com:project:dataset.table") .withMethod(Method.DIRECT_READ) - .withFormat(DataFormat.AVRO) .withTestServices( new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -1591,7 +1577,6 @@ public void testReadFromBigQueryIOWithTrimmedSchema() throws Exception { .from("foo.com:project:dataset.table") .withMethod(Method.DIRECT_READ) .withSelectedFields(Lists.newArrayList("name")) - .withFormat(DataFormat.AVRO) .withTestServices( new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -1662,7 +1647,6 @@ public void testReadFromBigQueryIOWithBeamSchema() throws Exception { .from("foo.com:project:dataset.table") .withMethod(Method.DIRECT_READ) .withSelectedFields(Lists.newArrayList("name")) - .withFormat(DataFormat.AVRO) .withTestServices( new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -1732,10 +1716,9 @@ public void testReadFromBigQueryIOArrow() throws Exception { PCollection> output = p.apply( - BigQueryIO.read(new ParseKeyValue()) + BigQueryIO.parseArrowRows(new ParseArrowKeyValue()) .from("foo.com:project:dataset.table") .withMethod(Method.DIRECT_READ) - .withFormat(DataFormat.ARROW) .withTestServices( new FakeBigQueryServices() .withDatasetService(fakeDatasetService) @@ -1750,7 +1733,6 @@ public void testReadFromBigQueryIOArrow() throws Exception { @Test public void testReadFromStreamSourceArrow() throws Exception { - ReadSession readSession = ReadSession.newBuilder() .setName("readSession") @@ -1781,7 +1763,7 @@ public void testReadFromStreamSourceArrow() throws Exception { readSession, ReadStream.newBuilder().setName("readStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_ARROW_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1830,7 +1812,7 @@ public void testFractionConsumedArrow() throws Exception { readSession, ReadStream.newBuilder().setName("readStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_ARROW_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1912,7 +1894,7 @@ public void testFractionConsumedWithSplitArrow() throws Exception { readSession, ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_ARROW_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -1993,7 +1975,7 @@ public void testStreamSourceSplitAtFractionSucceedsArrow() throws Exception { .build(), ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_ARROW_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -2112,7 +2094,7 @@ public void testStreamSourceSplitAtFractionRepeatedArrow() throws Exception { .build(), readStreams.get(0), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_ARROW_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -2172,7 +2154,7 @@ public void testStreamSourceSplitAtFractionFailsWhenSplitIsNotPossibleArrow() th .build(), ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_ARROW_READER_FACTORY, 
TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -2258,7 +2240,7 @@ public void testStreamSourceSplitAtFractionFailsWhenParentIsPastSplitPointArrow( .build(), ReadStream.newBuilder().setName("parentStream").build(), TABLE_SCHEMA, - new TableRowParser(), + TABLE_ROW_ARROW_READER_FACTORY, TableRowJsonCoder.of(), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); @@ -2293,7 +2275,7 @@ public void testActuateProjectionPushdown() { BigQueryIO.read( record -> BigQueryUtils.toBeamRow( - record.getRecord(), schema, ConversionOptions.builder().build())) + record.getElement(), schema, ConversionOptions.builder().build())) .withMethod(Method.DIRECT_READ) .withCoder(SchemaCoder.of(schema)); @@ -2319,7 +2301,7 @@ public void testReadFromQueryDoesNotSupportProjectionPushdown() { BigQueryIO.read( record -> BigQueryUtils.toBeamRow( - record.getRecord(), schema, ConversionOptions.builder().build())) + record.getElement(), schema, ConversionOptions.builder().build())) .fromQuery("SELECT bar FROM `dataset.table`") .withMethod(Method.DIRECT_READ) .withCoder(SchemaCoder.of(schema)); @@ -2356,13 +2338,16 @@ public void testReadFromBigQueryAvroObjectsMutation() throws Exception { when(fakeStorageClient.readRows(expectedRequest, "")) .thenReturn(new FakeBigQueryServerStream<>(responses)); + BigQueryStorageReaderFactory readerFactory = + BigQueryReaderFactory.avro( + null, false, AvroDatumFactory.generic(), SchemaAndElement::getRecord); BigQueryStorageStreamSource streamSource = BigQueryStorageStreamSource.create( readSession, ReadStream.newBuilder().setName("readStream").build(), TABLE_SCHEMA, - SchemaAndRecord::getRecord, - AvroCoder.of(AVRO_SCHEMA), + readerFactory, + AvroCoder.generic(AVRO_SCHEMA), new FakeBigQueryServices().withStorageClient(fakeStorageClient)); BoundedReader reader = streamSource.createReader(options); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java index 5b7b5d473190..8c15b0d6d401 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryIOTranslationTest.java @@ -24,11 +24,15 @@ import com.google.api.services.bigquery.model.Clustering; import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.cloud.bigquery.storage.v1.DataFormat; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; +import org.apache.beam.sdk.extensions.avro.io.AvroSource; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.TypedRead; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; @@ -36,10 +40,14 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.options.StreamingOptions; import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.construction.TransformUpgrader; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import 
org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +@RunWith(JUnit4.class) public class BigQueryIOTranslationTest { // A mapping from Read transform builder methods to the corresponding schema fields in @@ -55,8 +63,11 @@ public class BigQueryIOTranslationTest { READ_TRANSFORM_SCHEMA_MAPPING.put( "getWithTemplateCompatibility", "with_template_compatibility"); READ_TRANSFORM_SCHEMA_MAPPING.put("getBigQueryServices", "bigquery_services"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getAvroSchema", "avro_schema"); READ_TRANSFORM_SCHEMA_MAPPING.put("getParseFn", "parse_fn"); READ_TRANSFORM_SCHEMA_MAPPING.put("getDatumReaderFactory", "datum_reader_factory"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getArrowSchema", "arrow_schema"); + READ_TRANSFORM_SCHEMA_MAPPING.put("getArrowParseFn", "arrow_parse_fn"); READ_TRANSFORM_SCHEMA_MAPPING.put("getQueryPriority", "query_priority"); READ_TRANSFORM_SCHEMA_MAPPING.put("getQueryLocation", "query_location"); READ_TRANSFORM_SCHEMA_MAPPING.put("getQueryTempDataset", "query_temp_dataset"); @@ -140,9 +151,82 @@ public class BigQueryIOTranslationTest { WRITE_TRANSFORM_SCHEMA_MAPPING.put("getBadRecordErrorHandler", "bad_record_error_handler"); } + static class DummyParseFn implements SerializableFunction { + @Override + public Object apply(SchemaAndRecord input) { + return null; + } + } + + @Test + public void testReCreateReadTransformFromDeprecatedArrow() { + BigQueryIO.TypedRead readTransform = + BigQueryIO.read(new DummyParseFn()) + .withFormat(DataFormat.ARROW) + .from("dummyproject:dummydataset.dummytable") + .withMethod(TypedRead.Method.DIRECT_READ); + + BigQueryIOTranslation.BigQueryIOReadTranslator translator = + new BigQueryIOTranslation.BigQueryIOReadTranslator(); + + // old versions do not set arrow_parse_fn + Row row = + Row.fromRow(translator.toConfigRow(readTransform)) + .withFieldValue("arrow_parse_fn", null) + .build(); + + PipelineOptions options = PipelineOptionsFactory.create(); + options.as(StreamingOptions.class).setUpdateCompatibilityVersion("2.60.0"); + BigQueryIO.TypedRead readTransformFromRow = + (BigQueryIO.TypedRead) translator.fromConfigRow(row, options); + assertNotNull(readTransformFromRow.getTable()); + assertEquals("dummyproject", readTransformFromRow.getTable().getProjectId()); + assertEquals("dummydataset", readTransformFromRow.getTable().getDatasetId()); + assertEquals("dummytable", readTransformFromRow.getTable().getTableId()); + assertNotNull(readTransformFromRow.getArrowParseFn()); + assertEquals(TypedRead.Method.DIRECT_READ, readTransformFromRow.getMethod()); + } + + public static class DummyClass { + + public String name; + + @org.apache.avro.reflect.Nullable public Integer age; + } + + @Test + public void testReCreateReadTransformFromDatumReader() { + AvroSource.DatumReaderFactory readerFactory = + AvroDatumFactory.reflect(DummyClass.class); + BigQueryIO.TypedRead readTransform = + BigQueryIO.readWithDatumReader(readerFactory).from("dummyproject:dummydataset.dummytable"); + + BigQueryIOTranslation.BigQueryIOReadTranslator translator = + new BigQueryIOTranslation.BigQueryIOReadTranslator(); + + // old versions set a SerializableFunction with unused input and do not set parseFn + SerializableFunction> oldDatumFactory = + (schema) -> readerFactory; + Row row = + Row.fromRow(translator.toConfigRow(readTransform)) + .withFieldValue("datum_reader_factory", TransformUpgrader.toByteArray(oldDatumFactory)) + .withFieldValue("parse_fn", null) + .build(); + + PipelineOptions options = 
PipelineOptionsFactory.create(); + options.as(StreamingOptions.class).setUpdateCompatibilityVersion("2.60.0"); + BigQueryIO.TypedRead readTransformFromRow = + (BigQueryIO.TypedRead) translator.fromConfigRow(row, options); + assertNotNull(readTransformFromRow.getTable()); + assertEquals("dummyproject", readTransformFromRow.getTable().getProjectId()); + assertEquals("dummydataset", readTransformFromRow.getTable().getDatasetId()); + assertEquals("dummytable", readTransformFromRow.getTable().getTableId()); + assertTrue( + readTransformFromRow.getDatumReaderFactory() instanceof AvroSource.DatumReaderFactory); + } + @Test public void testReCreateReadTransformFromRowTable() { - // setting a subset of fields here. BigQueryIO.TypedRead readTransform = BigQueryIO.readTableRows() .from("dummyproject:dummydataset.dummytable") @@ -154,9 +238,8 @@ public void testReCreateReadTransformFromRowTable() { new BigQueryIOTranslation.BigQueryIOReadTranslator(); Row row = translator.toConfigRow(readTransform); - BigQueryIO.TypedRead readTransformFromRow = - (BigQueryIO.TypedRead) - translator.fromConfigRow(row, PipelineOptionsFactory.create()); + BigQueryIO.TypedRead readTransformFromRow = + translator.fromConfigRow(row, PipelineOptionsFactory.create()); assertNotNull(readTransformFromRow.getTable()); assertEquals("dummyproject", readTransformFromRow.getTable().getProjectId()); assertEquals("dummydataset", readTransformFromRow.getTable().getDatasetId()); @@ -166,16 +249,8 @@ public void testReCreateReadTransformFromRowTable() { assertTrue(readTransformFromRow.getWithTemplateCompatibility()); } - static class DummyParseFn implements SerializableFunction { - @Override - public Object apply(SchemaAndRecord input) { - return null; - } - } - @Test public void testReCreateReadTransformFromRowQuery() { - // setting a subset of fields here. 
BigQueryIO.TypedRead readTransform = BigQueryIO.read(new DummyParseFn()) .fromQuery("dummyquery") diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderTest.java index d98114071698..2cef22d7d64b 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/bigquery/BigQueryStorageReaderTest.java @@ -21,7 +21,10 @@ import static org.apache.beam.sdk.io.gcp.bigquery.BigQueryIOStorageReadTest.field; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableSchema; import com.google.cloud.bigquery.storage.v1.ArrowSchema; import com.google.cloud.bigquery.storage.v1.AvroSchema; import com.google.cloud.bigquery.storage.v1.ReadSession; @@ -32,6 +35,13 @@ import org.apache.arrow.vector.ipc.WriteChannel; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.avro.generic.GenericRecord; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.extensions.avro.coders.AvroCoder; +import org.apache.beam.sdk.extensions.avro.io.AvroDatumFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -39,10 +49,21 @@ @RunWith(JUnit4.class) public class BigQueryStorageReaderTest { + private static final Schema BEAM_SCHEMA = + Schema.builder().addStringField("name").addInt64Field("number").build(); + + private static final TableSchema TABLE_SCHEMA = + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("name").setType("STRING"), + new TableFieldSchema().setName("number").setType("INT64"))); + private static final org.apache.arrow.vector.types.pojo.Schema ARROW_SCHEMA = new org.apache.arrow.vector.types.pojo.Schema( asList( field("name", new ArrowType.Utf8()), field("number", new ArrowType.Int(64, true)))); + private static final ReadSession ARROW_READ_SESSION = ReadSession.newBuilder() .setName("readSession") @@ -51,31 +72,47 @@ public class BigQueryStorageReaderTest { .setSerializedSchema(serializeArrowSchema(ARROW_SCHEMA)) .build()) .build(); - private static final String AVRO_SCHEMA_STRING = - "{\"namespace\": \"example.avro\",\n" - + " \"type\": \"record\",\n" - + " \"name\": \"RowRecord\",\n" - + " \"fields\": [\n" - + " {\"name\": \"name\", \"type\": \"string\"},\n" - + " {\"name\": \"number\", \"type\": \"long\"}\n" - + " ]\n" - + "}"; + + private static final org.apache.avro.Schema AVRO_SCHEMA = + org.apache.avro.SchemaBuilder.builder() + .record("RowRecord") + .fields() + .name("name") + .type() + .stringType() + .noDefault() + .name("number") + .type() + .longType() + .noDefault() + .endRecord(); + private static final ReadSession AVRO_READ_SESSION = ReadSession.newBuilder() .setName("readSession") - .setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA_STRING)) + 
.setAvroSchema(AvroSchema.newBuilder().setSchema(AVRO_SCHEMA.toString())) .build(); @Test public void bigQueryStorageReaderFactory_arrowReader() throws Exception { - BigQueryStorageReader reader = BigQueryStorageReaderFactory.getReader(ARROW_READ_SESSION); + BigQueryReaderFactory factory = + BigQueryReaderFactory.arrow(BEAM_SCHEMA, SchemaAndRow::getRow); + + BigQueryStorageReader reader = factory.getReader(TABLE_SCHEMA, ARROW_READ_SESSION); assertThat(reader, instanceOf(BigQueryStorageArrowReader.class)); + assertEquals(RowCoder.of(BEAM_SCHEMA), reader.getBadRecordCoder()); } @Test public void bigQueryStorageReaderFactory_avroReader() throws Exception { - BigQueryStorageReader reader = BigQueryStorageReaderFactory.getReader(AVRO_READ_SESSION); + AvroDatumFactory datumFactory = AvroDatumFactory.generic(); + BigQueryReaderFactory factory = + BigQueryReaderFactory.avro(AVRO_SCHEMA, false, datumFactory, SchemaAndElement::getRecord); + + BigQueryStorageReader reader = + factory.getReader(TABLE_SCHEMA, AVRO_READ_SESSION); assertThat(reader, instanceOf(BigQueryStorageAvroReader.class)); + assertEquals(AvroCoder.of(datumFactory, AVRO_SCHEMA), reader.getBadRecordCoder()); } private static ByteString serializeArrowSchema(