diff --git a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormatV1.java b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormatV1.java index 98ef69e8..1371665b 100644 --- a/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormatV1.java +++ b/hadoop/src/main/java/org/tensorflow/hadoop/io/TFRecordFileOutputFormatV1.java @@ -19,15 +19,18 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.BytesWritable; -import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.compress.CompressionCodec; +import org.apache.hadoop.io.compress.GzipCodec; import org.apache.hadoop.mapred.FileOutputFormat; import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapred.RecordWriter; import org.apache.hadoop.mapred.Reporter; import org.apache.hadoop.util.Progressable; +import org.apache.hadoop.util.ReflectionUtils; import org.tensorflow.hadoop.util.TFRecordWriter; +import java.io.DataOutputStream; import java.io.IOException; public class TFRecordFileOutputFormatV1 extends FileOutputFormat { @@ -35,11 +38,20 @@ public class TFRecordFileOutputFormatV1 extends FileOutputFormat getRecordWriter(FileSystem ignored, JobConf job, String name, Progressable progress) throws IOException { - Path file = FileOutputFormat.getTaskOutputPath(job, name); + boolean isCompressed = getCompressOutput(job); + CompressionCodec codec = null; + String extension = ""; + if (isCompressed) { + Class codecClass = getOutputCompressorClass(job, GzipCodec.class); + codec = ReflectionUtils.newInstance(codecClass, job); + extension = codec.getDefaultExtension(); + } + Path file = FileOutputFormat.getTaskOutputPath(job, name + extension); FileSystem fs = file.getFileSystem(job); - int bufferSize = TFRecordIOConf.getBufferSize(job); - final FSDataOutputStream fsdos = fs.create(file, true, bufferSize); + FSDataOutputStream fsDataOutputStream = fs.create(file, true, bufferSize); + final DataOutputStream fsdos = isCompressed ? + new DataOutputStream(codec.createOutputStream(fsDataOutputStream)) : fsDataOutputStream; final TFRecordWriter writer = new TFRecordWriter(fsdos); return new RecordWriter() { @Override @@ -55,5 +67,4 @@ public void close(Reporter reporter) } }; } - }