A look at Flink's ParallelIteratorInputFormat

Posted by go4it on 2019-03-01

This article takes a look at Flink's ParallelIteratorInputFormat.

Example

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Long> dataSet = env.generateSequence(15,106)
                .setParallelism(3);
        dataSet.print();
  • Here ExecutionEnvironment's generateSequence method creates a ParallelIteratorInputFormat backed by a NumberSequenceIterator

ParallelIteratorInputFormat

flink-java-1.6.2-sources.jar!/org/apache/flink/api/java/io/ParallelIteratorInputFormat.java

/**
 * An input format that generates data in parallel through a {@link SplittableIterator}.
 */
@PublicEvolving
public class ParallelIteratorInputFormat<T> extends GenericInputFormat<T> {

	private static final long serialVersionUID = 1L;

	private final SplittableIterator<T> source;

	private transient Iterator<T> splitIterator;

	public ParallelIteratorInputFormat(SplittableIterator<T> iterator) {
		this.source = iterator;
	}

	@Override
	public void open(GenericInputSplit split) throws IOException {
		super.open(split);

		this.splitIterator = this.source.getSplit(split.getSplitNumber(), split.getTotalNumberOfSplits());
	}

	@Override
	public boolean reachedEnd() {
		return !this.splitIterator.hasNext();
	}

	@Override
	public T nextRecord(T reuse) {
		return this.splitIterator.next();
	}
}
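  • ParallelIteratorInputFormat extends GenericInputFormat. GenericInputFormat has four other subclasses (CRowValuesInputFormat, CollectionInputFormat, IteratorInputFormat, and ValuesInputFormat), and what those four have in common is that they all implement the NonParallelInput interface, which ParallelIteratorInputFormat does not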
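
The open / reachedEnd / nextRecord contract shown in the class above can be illustrated without Flink on the classpath. The following is a minimal plain-Java sketch, not Flink code: MiniFormat and readAll are hypothetical names, and the pre-split lists stand in for what source.getSplit(...) would return.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Hypothetical stand-in for the open/reachedEnd/nextRecord contract; not Flink code.
class MiniFormat<T> {
    private final List<List<T>> splits;   // pre-split data, one list per split
    private Iterator<T> splitIterator;    // iterator over the split that is currently open

    MiniFormat(List<List<T>> splits) {
        this.splits = splits;
    }

    // Mirrors open(split): pick the iterator that belongs to the given split number.
    void open(int splitNumber) {
        this.splitIterator = splits.get(splitNumber).iterator();
    }

    boolean reachedEnd() {
        return !splitIterator.hasNext();
    }

    T nextRecord() {
        return splitIterator.next();
    }

    // Reads one split the way a reader would: open it, then drain it until the end.
    static <T> List<T> readAll(MiniFormat<T> format, int splitNumber) {
        List<T> out = new ArrayList<>();
        format.open(splitNumber);
        while (!format.reachedEnd()) {
            out.add(format.nextRecord());
        }
        return out;
    }

    public static void main(String[] args) {
        MiniFormat<Long> format = new MiniFormat<>(Arrays.asList(
                Arrays.asList(1L, 2L), Arrays.asList(3L)));
        System.out.println(readAll(format, 0));
        System.out.println(readAll(format, 1));
    }
}
```

Each parallel reader only ever sees the iterator of its own split, which is why the format can keep a single transient splitIterator field.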

NonParallelInput

flink-core-1.6.2-sources.jar!/org/apache/flink/api/common/io/NonParallelInput.java

/**
 * This interface acts as a marker for input formats for inputs which cannot be split.
 * Data sources with a non-parallel input formats are always executed with a parallelism
 * of one.
 * 
 * @see InputFormat
 */
@Public
public interface NonParallelInput {
}
  • This interface declares no methods; it is purely a marker indicating that the InputFormat cannot be split

GenericInputFormat.createInputSplits

flink-core-1.6.2-sources.jar!/org/apache/flink/api/common/io/GenericInputFormat.java

	@Override
	public GenericInputSplit[] createInputSplits(int numSplits) throws IOException {
		if (numSplits < 1) {
			throw new IllegalArgumentException("Number of input splits has to be at least 1.");
		}

		numSplits = (this instanceof NonParallelInput) ? 1 : numSplits;
		GenericInputSplit[] splits = new GenericInputSplit[numSplits];
		for (int i = 0; i < splits.length; i++) {
			splits[i] = new GenericInputSplit(i, numSplits);
		}
		return splits;
	}
  • GenericInputFormat's createInputSplits method validates the requested numSplits: if it is less than 1, an IllegalArgumentException is thrown; if the current InputFormat implements NonParallelInput, numSplits is reset to 1
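
The marker-interface check can be tried out in isolation. The sketch below is plain Java with hypothetical names (SingleSplitOnly, SplittableFormat, UnsplittableFormat, effectiveSplits); it only mirrors the instanceof test in createInputSplits and is not Flink code.

```java
// Hypothetical names throughout; this mirrors the marker-interface check only.
interface SingleSplitOnly { }          // marker: declares no methods

class SplittableFormat { }

class UnsplittableFormat implements SingleSplitOnly { }

class SplitCountSketch {
    // Returns the number of splits that would actually be created for the given format.
    static int effectiveSplits(Object format, int requested) {
        if (requested < 1) {
            throw new IllegalArgumentException("Number of input splits has to be at least 1.");
        }
        // A format tagged with the marker is always read as a single split.
        return (format instanceof SingleSplitOnly) ? 1 : requested;
    }

    public static void main(String[] args) {
        System.out.println(effectiveSplits(new SplittableFormat(), 3));
        System.out.println(effectiveSplits(new UnsplittableFormat(), 3));
    }
}
```

Because the check lives in the shared base class, a subclass opts out of parallel reading simply by implementing the marker, with no method to override.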

ExecutionEnvironment.fromParallelCollection

flink-java-1.6.2-sources.jar!/org/apache/flink/api/java/ExecutionEnvironment.java

	/**
	 * Creates a new data set that contains elements in the iterator. The iterator is splittable, allowing the
	 * framework to create a parallel data source that returns the elements in the iterator.
	 *
	 * <p>Because the iterator will remain unmodified until the actual execution happens, the type of data
	 * returned by the iterator must be given explicitly in the form of the type class (this is due to the
	 * fact that the Java compiler erases the generic type information).
	 *
	 * @param iterator The iterator that produces the elements of the data set.
	 * @param type The class of the data produced by the iterator. Must not be a generic class.
	 * @return A DataSet representing the elements in the iterator.
	 *
	 * @see #fromParallelCollection(SplittableIterator, TypeInformation)
	 */
	public <X> DataSource<X> fromParallelCollection(SplittableIterator<X> iterator, Class<X> type) {
		return fromParallelCollection(iterator, TypeExtractor.getForClass(type));
	}

	/**
	 * Creates a new data set that contains elements in the iterator. The iterator is splittable, allowing the
	 * framework to create a parallel data source that returns the elements in the iterator.
	 *
	 * <p>Because the iterator will remain unmodified until the actual execution happens, the type of data
	 * returned by the iterator must be given explicitly in the form of the type information.
	 * This method is useful for cases where the type is generic. In that case, the type class
	 * (as given in {@link #fromParallelCollection(SplittableIterator, Class)} does not supply all type information.
	 *
	 * @param iterator The iterator that produces the elements of the data set.
	 * @param type The TypeInformation for the produced data set.
	 * @return A DataSet representing the elements in the iterator.
	 *
	 * @see #fromParallelCollection(SplittableIterator, Class)
	 */
	public <X> DataSource<X> fromParallelCollection(SplittableIterator<X> iterator, TypeInformation<X> type) {
		return fromParallelCollection(iterator, type, Utils.getCallLocationName());
	}

	// private helper for passing different call location names
	private <X> DataSource<X> fromParallelCollection(SplittableIterator<X> iterator, TypeInformation<X> type, String callLocationName) {
		return new DataSource<>(this, new ParallelIteratorInputFormat<>(iterator), type, callLocationName);
	}

	/**
	 * Creates a new data set that contains a sequence of numbers. The data set will be created in parallel,
	 * so there is no guarantee about the order of the elements.
	 *
	 * @param from The number to start at (inclusive).
	 * @param to The number to stop at (inclusive).
	 * @return A DataSet, containing all number in the {@code [from, to]} interval.
	 */
	public DataSource<Long> generateSequence(long from, long to) {
		return fromParallelCollection(new NumberSequenceIterator(from, to), BasicTypeInfo.LONG_TYPE_INFO, Utils.getCallLocationName());
	}
  • ExecutionEnvironment's fromParallelCollection method takes a SplittableIterator and creates a ParallelIteratorInputFormat for it; generateSequence also calls fromParallelCollection, passing a NumberSequenceIterator (a subclass of SplittableIterator)

SplittableIterator

flink-core-1.6.2-sources.jar!/org/apache/flink/util/SplittableIterator.java

/**
 * Abstract base class for iterators that can split themselves into multiple disjoint
 * iterators. The union of these iterators returns the original iterator values.
 *
 * @param <T> The type of elements returned by the iterator.
 */
@Public
public abstract class SplittableIterator<T> implements Iterator<T>, Serializable {

	private static final long serialVersionUID = 200377674313072307L;

	/**
	 * Splits this iterator into a number disjoint iterators.
	 * The union of these iterators returns the original iterator values.
	 *
	 * @param numPartitions The number of iterators to split into.
	 * @return An array with the split iterators.
	 */
	public abstract Iterator<T>[] split(int numPartitions);

	/**
	 * Splits this iterator into <i>n</i> partitions and returns the <i>i-th</i> partition
	 * out of those.
	 *
	 * @param num The partition to return (<i>i</i>).
	 * @param numPartitions The number of partitions to split into (<i>n</i>).
	 * @return The iterator for the partition.
	 */
	public Iterator<T> getSplit(int num, int numPartitions) {
		if (numPartitions < 1 || num < 0 || num >= numPartitions) {
			throw new IllegalArgumentException();
		}

		return split(numPartitions)[num];
	}

	/**
	 * The maximum number of splits into which this iterator can be split up.
	 *
	 * @return The maximum number of splits into which this iterator can be split up.
	 */
	public abstract int getMaximumNumberOfSplits();
}
  • SplittableIterator is an abstract class that declares the abstract methods split and getMaximumNumberOfSplits; it has two implementations, LongValueSequenceIterator and NumberSequenceIterator. Here we look at NumberSequenceIterator
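
To make the split / getSplit contract concrete (disjoint partitions whose union is the original data), here is a plain-Java stand-in. It is a sketch under assumptions: RoundRobinSplitter is a hypothetical class that does not extend Flink's SplittableIterator, and it partitions round-robin, which is a different strategy from the contiguous ranges NumberSequenceIterator produces.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

// Plain-Java stand-in for the split/getSplit contract; hypothetical, not a Flink class.
class RoundRobinSplitter<T> {
    private final List<T> elements;

    RoundRobinSplitter(List<T> elements) {
        this.elements = elements;
    }

    // Splits the elements into numPartitions disjoint lists whose union is the original data.
    List<List<T>> split(int numPartitions) {
        if (numPartitions < 1) {
            throw new IllegalArgumentException("The number of partitions must be at least 1.");
        }
        List<List<T>> parts = new ArrayList<>();
        for (int p = 0; p < numPartitions; p++) {
            parts.add(new ArrayList<>());
        }
        for (int i = 0; i < elements.size(); i++) {
            // Element i goes to partition i mod numPartitions.
            parts.get(i % numPartitions).add(elements.get(i));
        }
        return parts;
    }

    // Same argument checks as getSplit: numPartitions >= 1 and 0 <= num < numPartitions.
    Iterator<T> getSplit(int num, int numPartitions) {
        if (numPartitions < 1 || num < 0 || num >= numPartitions) {
            throw new IllegalArgumentException();
        }
        return split(numPartitions).get(num).iterator();
    }

    public static void main(String[] args) {
        RoundRobinSplitter<String> s =
                new RoundRobinSplitter<>(Arrays.asList("a", "b", "c", "d", "e"));
        System.out.println(s.split(2));
    }
}
```

As in the abstract class, getSplit recomputes the full split and returns a single partition, so every parallel reader can derive its own share independently from the same serialized source.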

NumberSequenceIterator

flink-core-1.6.2-sources.jar!/org/apache/flink/util/NumberSequenceIterator.java

/**
 * The {@code NumberSequenceIterator} is an iterator that returns a sequence of numbers (as {@code Long})s.
 * The iterator is splittable (as defined by {@link SplittableIterator}, i.e., it can be divided into multiple
 * iterators that each return a subsequence of the number sequence.
 */
@Public
public class NumberSequenceIterator extends SplittableIterator<Long> {

	private static final long serialVersionUID = 1L;

	/** The last number returned by the iterator. */
	private final long to;

	/** The next number to be returned. */
	private long current;


	/**
	 * Creates a new splittable iterator, returning the range [from, to].
	 * Both boundaries of the interval are inclusive.
	 *
	 * @param from The first number returned by the iterator.
	 * @param to The last number returned by the iterator.
	 */
	public NumberSequenceIterator(long from, long to) {
		if (from > to) {
			throw new IllegalArgumentException("The `to` value must not be smaller than the `from` value.");
		}

		this.current = from;
		this.to = to;
	}


	@Override
	public boolean hasNext() {
		return current <= to;
	}

	@Override
	public Long next() {
		if (current <= to) {
			return current++;
		} else {
			throw new NoSuchElementException();
		}
	}

	@Override
	public NumberSequenceIterator[] split(int numPartitions) {
		if (numPartitions < 1) {
			throw new IllegalArgumentException("The number of partitions must be at least 1.");
		}

		if (numPartitions == 1) {
			return new NumberSequenceIterator[] { new NumberSequenceIterator(current, to) };
		}

		// here, numPartitions >= 2 !!!

		long elementsPerSplit;

		if (to - current + 1 >= 0) {
			elementsPerSplit = (to - current + 1) / numPartitions;
		}
		else {
			// long overflow of the range.
			// we compute based on half the distance, to prevent the overflow.
			// in most cases it holds that: current < 0 and to > 0, except for: to == 0 and current == Long.MIN_VALUE
			// the later needs a special case
			final long halfDiff; // must be positive

			if (current == Long.MIN_VALUE) {
				// this means to >= 0
				halfDiff = (Long.MAX_VALUE / 2 + 1) + to / 2;
			} else {
				long posFrom = -current;
				if (posFrom > to) {
					halfDiff = to + ((posFrom - to) / 2);
				} else {
					halfDiff = posFrom + ((to - posFrom) / 2);
				}
			}
			elementsPerSplit = halfDiff / numPartitions * 2;
		}

		if (elementsPerSplit < Long.MAX_VALUE) {
			// figure out how many get one in addition
			long numWithExtra = -(elementsPerSplit * numPartitions) + to - current + 1;

			// based on rounding errors, we may have lost one)
			if (numWithExtra > numPartitions) {
				elementsPerSplit++;
				numWithExtra -= numPartitions;

				if (numWithExtra > numPartitions) {
					throw new RuntimeException("Bug in splitting logic. To much rounding loss.");
				}
			}

			NumberSequenceIterator[] iters = new NumberSequenceIterator[numPartitions];
			long curr = current;
			int i = 0;
			for (; i < numWithExtra; i++) {
				long next = curr + elementsPerSplit + 1;
				iters[i] = new NumberSequenceIterator(curr, next - 1);
				curr = next;
			}
			for (; i < numPartitions; i++) {
				long next = curr + elementsPerSplit;
				iters[i] = new NumberSequenceIterator(curr, next - 1, true);
				curr = next;
			}

			return iters;
		}
		else {
			// this can only be the case when there are two partitions
			if (numPartitions != 2) {
				throw new RuntimeException("Bug in splitting logic.");
			}

			return new NumberSequenceIterator[] {
				new NumberSequenceIterator(current, current + elementsPerSplit),
				new NumberSequenceIterator(current + elementsPerSplit, to)
			};
		}
	}

	@Override
	public int getMaximumNumberOfSplits() {
		if (to >= Integer.MAX_VALUE || current <= Integer.MIN_VALUE || to - current + 1 >= Integer.MAX_VALUE) {
			return Integer.MAX_VALUE;
		}
		else {
			return (int) (to - current + 1);
		}
	}

	//......
}
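  • NumberSequenceIterator's constructor takes two arguments, from and to; internally it keeps a current value that initially equals from
  • The split method first computes elementsPerSplit from numPartitions; when to - current + 1 >= 0 (that is, the length of the range does not overflow a long), the formula is (to - current + 1) / numPartitions
  • It then uses elementsPerSplit to compute numWithExtra. Because elementsPerSplit comes from integer division, giving every split exactly elementsPerSplit elements may leave some elements over, and numWithExtra is that leftover count. If it is greater than numPartitions (which can only result from rounding loss in the overflow branch), elementsPerSplit is incremented by 1 and numPartitions is subtracted from numWithExtra
  • Finally, the first numWithExtra splits each receive elementsPerSplit + 1 elements, which hands out the leftover elements one per split; the remaining splits, up to numPartitions, receive exactly elementsPerSplit elements, with each split's to computed as from + elementsPerSplit - 1
  • getMaximumNumberOfSplits returns the maximum number of splits: it returns Integer.MAX_VALUE when (to >= Integer.MAX_VALUE || current <= Integer.MIN_VALUE || to - current + 1 >= Integer.MAX_VALUE) holds, and (int) (to - current + 1) otherwise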
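
The arithmetic described above can be checked against the opening example, generateSequence(15, 106) with a parallelism of 3. The sketch below is plain Java and re-implements only the non-overflow branch; RangeSplitSketch, splitBounds, and maxSplits are hypothetical names, and the long-overflow handling of the real split method is deliberately left out.

```java
// Re-implements only the non-overflow branch of the arithmetic, for illustration.
// Class and method names are hypothetical; this is not Flink code.
class RangeSplitSketch {

    // Returns one {from, to} pair (both inclusive) per partition for the range [from, to].
    // A partition that receives no elements is represented as {x, x - 1}.
    static long[][] splitBounds(long from, long to, int numPartitions) {
        if (from > to || numPartitions < 1) {
            throw new IllegalArgumentException();
        }
        long total = to - from + 1;                     // assumes the length fits in a long
        long elementsPerSplit = total / numPartitions;  // integer division rounds down
        long numWithExtra = total - elementsPerSplit * numPartitions;  // the remainder

        long[][] bounds = new long[numPartitions][];
        long curr = from;
        for (int i = 0; i < numPartitions; i++) {
            // The first numWithExtra partitions take one additional element each.
            long size = elementsPerSplit + (i < numWithExtra ? 1 : 0);
            bounds[i] = new long[] { curr, curr + size - 1 };
            curr += size;
        }
        return bounds;
    }

    // Same condition as getMaximumNumberOfSplits: cap the result at Integer.MAX_VALUE.
    static int maxSplits(long from, long to) {
        if (to >= Integer.MAX_VALUE || from <= Integer.MIN_VALUE
                || to - from + 1 >= Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        return (int) (to - from + 1);
    }

    public static void main(String[] args) {
        for (long[] b : splitBounds(15, 106, 3)) {
            System.out.println(b[0] + ".." + b[1]);
        }
    }
}
```

For the range [15, 106] there are 92 elements, so elementsPerSplit is 30 and numWithExtra is 2: the three splits come out as 15..45, 46..76, and 77..106, with 31, 31, and 30 elements.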

Summary

  • GenericInputFormat has five subclasses. Besides ParallelIteratorInputFormat, they are CRowValuesInputFormat, CollectionInputFormat, IteratorInputFormat, and ValuesInputFormat, and these four all implement the NonParallelInput interface
  • GenericInputFormat's createInputSplits validates the requested numSplits and forces it to 1 for a NonParallelInput format
  • NumberSequenceIterator is one implementation of SplittableIterator. ExecutionEnvironment's fromParallelCollection method, and generateSequence (which passes a NumberSequenceIterator), create a ParallelIteratorInputFormat for a SplittableIterator. NumberSequenceIterator's split method first computes elementsPerSplit, then numWithExtra, gives one extra element to each of the first numWithExtra splits, and gives each remaining split exactly elementsPerSplit elements
