@@ -111,6 +111,7 @@ module ts {
111111 let globalTemplateStringsArrayType: ObjectType;
112112 let globalESSymbolType: ObjectType;
113113 let globalIterableType: GenericType;
114+ let globalIteratorType: GenericType;
114115 let globalIterableIteratorType: GenericType;
115116
116117 let anyArrayType: Type;
@@ -2119,7 +2120,7 @@ module ts {
21192120 // checkRightHandSideOfForOf will return undefined if the for-of expression type was
21202121 // missing properties/signatures required to get its iteratedType (like
21212122 // [Symbol.iterator] or next). This may be because we accessed properties from anyType,
2122- // or it may have led to an error inside getIteratedType .
2123+ // or it may have led to an error inside getElementTypeFromIterable .
21232124 return checkRightHandSideOfForOf((<ForOfStatement>declaration.parent.parent).expression) || anyType;
21242125 }
21252126 if (isBindingPattern(declaration.parent)) {
@@ -5854,7 +5855,7 @@ module ts {
58545855 let index = indexOf(arrayLiteral.elements, node);
58555856 return getTypeOfPropertyOfContextualType(type, "" + index)
58565857 || getIndexTypeOfContextualType(type, IndexKind.Number)
5857- || (languageVersion >= ScriptTarget.ES6 ? getIteratedType (type, /*expressionForError*/ undefined) : undefined);
5858+ || (languageVersion >= ScriptTarget.ES6 ? getElementTypeFromIterable (type, /*expressionForError*/ undefined) : undefined);
58585859 }
58595860 return undefined;
58605861 }
@@ -6041,7 +6042,7 @@ module ts {
60416042 // if there is no index type / iterated type.
60426043 let restArrayType = checkExpression((<SpreadElementExpression>e).expression, contextualMapper);
60436044 let restElementType = getIndexTypeOfType(restArrayType, IndexKind.Number) ||
6044- (languageVersion >= ScriptTarget.ES6 ? getIteratedType (restArrayType, /*expressionForError*/ undefined) : undefined);
6045+ (languageVersion >= ScriptTarget.ES6 ? getElementTypeFromIterable (restArrayType, /*expressionForError*/ undefined) : undefined);
60456046
60466047 if (restElementType) {
60476048 elementTypes.push(restElementType);
@@ -8188,6 +8189,22 @@ module ts {
81888189 break;
81898190 }
81908191 }
8192+
8193+ if (node.type) {
8194+ if (languageVersion >= ScriptTarget.ES6 && isSyntacticallyValidGenerator(node)) {
8195+ let returnType = getTypeFromTypeNode(node.type);
8196+ let generatorElementType = getElementTypeFromIterableIterator(returnType, /*errorNode*/ undefined) || anyType;
8197+ let iterableIteratorInstantiation = createIterableIteratorType(generatorElementType);
8198+
8199+ // Naively, one could check that IterableIterator<any> is assignable to the return type annotation.
8200+ // However, that would not catch the error in the following case.
8201+ //
8202+ // interface BadGenerator extends Iterable<number>, Iterator<string> { }
8203+ // function* g(): BadGenerator { } // Iterable and Iterator have different types!
8204+ //
8205+ checkTypeAssignableTo(iterableIteratorInstantiation, returnType, node.type);
8206+ }
8207+ }
81918208 }
81928209
81938210 checkSpecializedSignatureDeclaration(node);
@@ -9385,7 +9402,7 @@ module ts {
93859402 // iteratedType will be undefined if the rightType was missing properties/signatures
93869403 // required to get its iteratedType (like [Symbol.iterator] or next). This may be
93879404 // because we accessed properties from anyType, or it may have led to an error inside
9388- // getIteratedType .
9405+ // getElementTypeFromIterable .
93899406 if (iteratedType) {
93909407 checkTypeAssignableTo(iteratedType, leftType, varExpr, /*headMessage*/ undefined);
93919408 }
@@ -9483,30 +9500,24 @@ module ts {
94839500 * When errorNode is undefined, it means we should not report any errors.
94849501 */
94859502 function checkIteratedType(iterable: Type, errorNode: Node): Type {
9486- let iteratedType = getIteratedType (iterable, errorNode);
9503+ let elementType = getElementTypeFromIterable (iterable, errorNode);
94879504 // Now even though we have extracted the iteratedType, we will have to validate that the type
94889505 // passed in is actually an Iterable.
9489- if (errorNode && iteratedType ) {
9490- checkTypeAssignableTo(iterable, createIterableType(iteratedType ), errorNode);
9506+ if (errorNode && elementType ) {
9507+ checkTypeAssignableTo(iterable, createIterableType(elementType ), errorNode);
94919508 }
94929509
9493- return iteratedType ;
9510+ return elementType ;
94949511 }
94959512
9496- function getIteratedType (iterable: Type, errorNode: Node) {
9513+ function getElementTypeFromIterable (iterable: Type, errorNode: Node): Type {
94979514 Debug.assert(languageVersion >= ScriptTarget.ES6);
94989515 // We want to treat type as an iterable, and get the type it is an iterable of. The iterable
94999516 // must have the following structure (annotated with the names of the variables below):
95009517 //
95019518 // { // iterable
95029519 // [Symbol.iterator]: { // iteratorFunction
9503- // (): { // iterator
9504- // next: { // iteratorNextFunction
9505- // (): { // iteratorNextResult
9506- // value: T // iteratorNextValue
9507- // }
9508- // }
9509- // }
9520+ // (): Iterator<T>
95109521 // }
95119522 // }
95129523 //
@@ -9544,11 +9555,31 @@ module ts {
95449555 return undefined;
95459556 }
95469557
9547- let iterator = getUnionType(map(iteratorFunctionSignatures, getReturnTypeOfSignature));
9558+ return getElementTypeFromIterator(getUnionType(map(iteratorFunctionSignatures, getReturnTypeOfSignature)), errorNode);
9559+ }
9560+
9561+ function getElementTypeFromIterator(iterator: Type, errorNode: Node): Type {
9562+ // This function has very similar logic as getElementTypeFromIterable, except that it operates on
9563+ // Iterators instead of Iterables. Here is the structure:
9564+ //
9565+ // { // iterator
9566+ // next: { // iteratorNextFunction
9567+ // (): { // iteratorNextResult
9568+ // value: T // iteratorNextValue
9569+ // }
9570+ // }
9571+ // }
9572+ //
95489573 if (allConstituentTypesHaveKind(iterator, TypeFlags.Any)) {
95499574 return undefined;
95509575 }
95519576
9577+ // As an optimization, if the type is instantiated directly using the globalIteratorType (Iterator<number>),
9578+ // then just grab its type argument.
9579+ if ((iterator.flags & TypeFlags.Reference) && (<GenericType>iterator).target === globalIteratorType) {
9580+ return (<GenericType>iterator).typeArguments[0];
9581+ }
9582+
95529583 let iteratorNextFunction = getTypeOfPropertyOfType(iterator, "next");
95539584 if (iteratorNextFunction && allConstituentTypesHaveKind(iteratorNextFunction, TypeFlags.Any)) {
95549585 return undefined;
@@ -9578,6 +9609,21 @@ module ts {
95789609 return iteratorNextValue;
95799610 }
95809611
9612+ function getElementTypeFromIterableIterator(iterableIterator: Type, errorNode: Node): Type {
9613+ if (allConstituentTypesHaveKind(iterableIterator, TypeFlags.Any)) {
9614+ return undefined;
9615+ }
9616+
9617+ // As an optimization, if the type is instantiated directly using the globalIterableIteratorType (IterableIterator<number>),
9618+ // then just grab its type argument.
9619+ if ((iterableIterator.flags & TypeFlags.Reference) && (<GenericType>iterableIterator).target === globalIterableIteratorType) {
9620+ return (<GenericType>iterableIterator).typeArguments[0];
9621+ }
9622+
9623+ return getElementTypeFromIterable(iterableIterator, errorNode) ||
9624+ getElementTypeFromIterator(iterableIterator, errorNode);
9625+ }
9626+
95819627 /**
95829628 * This function does the following steps:
95839629 * 1. Break up arrayOrStringType (possibly a union) into its string constituents and array constituents.
@@ -12000,6 +12046,7 @@ module ts {
1200012046 globalESSymbolType = getGlobalType("Symbol");
1200112047 globalESSymbolConstructorSymbol = getGlobalValueSymbol("Symbol");
1200212048 globalIterableType = <GenericType>getGlobalType("Iterable", /*arity*/ 1);
12049+ globalIteratorType = <GenericType>getGlobalType("Iterator", /*arity*/ 1);
1200312050 globalIterableIteratorType = <GenericType>getGlobalType("IterableIterator", /*arity*/ 1);
1200412051 }
1200512052 else {
0 commit comments