Skip to content

Commit

Permalink
Close a hole in type guards of property access with string literals. …
Browse files Browse the repository at this point in the history
…Add basic but sound switch statement guards over the theses same property access.
  • Loading branch information
Nevor committed Nov 5, 2014
1 parent d82f8f5 commit fd5c401
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 18 deletions.
72 changes: 60 additions & 12 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4196,13 +4196,20 @@ module ts {
return type;
}

function keepAssignableTypes(type : Type, targetType : Type, assumeAssignable: boolean): Type {
function keepAssignablePropertyTypes(type : Type, propertyName : string, targetPropertyType : Type, assumeAssignable: boolean): Type {
if(type.flags & TypeFlags.Union) {
var types = (<UnionType>type).types;
} else {
var types = [type];
}
var remainingTypes = filter(types, t => assumeAssignable ? isTypeAssignableTo(t, targetType) : !isTypeAssignableTo(t, targetType));
var remainingTypes = filter(types, t => {
var propertyType = getTypeOfPropertyOfContextualType(t, propertyName);
if(propertyType) {
return assumeAssignable ? isTypeAssignableTo(targetPropertyType, propertyType) : !isTypeAssignableTo(propertyType, targetPropertyType);
} else {
return !assumeAssignable;
}
});
if(remainingTypes.length > 0) {
return getUnionType(remainingTypes);
}
Expand Down Expand Up @@ -4322,6 +4329,11 @@ module ts {
}
}
break;
case SyntaxKind.SwitchStatement:
if (child !== (<SwitchStatement>node).expression) {
narrowedType = narrowTypeInCaseClause(type, <SwitchStatement>node, <CaseOrDefaultClause>child);
}
break;
}
// Only use narrowed type if construct contains no assignments to variable
if (narrowedType !== type) {
Expand Down Expand Up @@ -4360,9 +4372,7 @@ module ts {
}
}

function narrowPropTypeByStringTypeEquality(type : Type, expr: BinaryExpression, assumeTrue: boolean): Type {
var left = <PropertyAccess>expr.left;
var right = expr.right;
function narrowPropTypeByStringTypeEquality(type : Type, left : PropertyAccess, right : Expression, assumeTrue: boolean): Type {
var right_t = checkExpression(right);
if (left.kind !== SyntaxKind.PropertyAccess || left.left.kind !== SyntaxKind.Identifier ||
!(right_t.flags & TypeFlags.StringLiteral) ||
Expand All @@ -4374,12 +4384,49 @@ module ts {
if (isTypeAssignableTo(right_t, t)) {
smallerType = right_t;
}
var dummyProperties: SymbolTable = {};
var dummyProperty = <TransientSymbol>createSymbol(SymbolFlags.Property | SymbolFlags.Transient, left.right.text);
dummyProperty.type = smallerType;
dummyProperties[dummyProperty.name] = dummyProperty;
var dummyType = createAnonymousType(undefined, dummyProperties, emptyArray, emptyArray, undefined, undefined);
return keepAssignableTypes(type, dummyType, assumeTrue);
var propertyName = left.right.text;
return keepAssignablePropertyTypes(type, propertyName, smallerType, assumeTrue);
}

function narrowTypeInCaseClause(type : Type, switchNode : SwitchStatement, caseClause : CaseOrDefaultClause) : Type {
var propertyAccess = <PropertyAccess>switchNode.expression;
if(switchNode.expression.kind !== SyntaxKind.PropertyAccess ||
getResolvedSymbol(<Identifier>propertyAccess.left) !== symbol) {
console.log("return from alien switch for " + symbol.name);
return type;
}
var narrowedType = type;
var remainingType = type;
var typesBeforeBreak : Type[] = [];
for (var i = 0; i < switchNode.clauses.length; i++) {
var clause = switchNode.clauses[i];
if (clause.expression) {
narrowedType = narrowPropTypeByStringTypeEquality(remainingType, <PropertyAccess>switchNode.expression, clause.expression, /* assumeTrue */ true);
typesBeforeBreak.push(narrowedType);
narrowedType = getUnionType(typesBeforeBreak);
remainingType = narrowPropTypeByStringTypeEquality(remainingType, <PropertyAccess>switchNode.expression, clause.expression, /* assumeTrue */ false);

} else {
narrowedType = remainingType;
}
console.log("clause id : " + clause.id + " while waiting for " + caseClause.id);
if (clause.id === caseClause.id) {
console.log("returning in clause : " + typeToString(narrowedType));
return narrowedType;
}
if(clause.statements && clause.statements.length > 0) {
var statements = clause.statements;
var last = statements[statements.length - 1];
if (last.kind === SyntaxKind.ReturnStatement ||
last.kind === SyntaxKind.BreakStatement) {
typesBeforeBreak = [];
}
}
}

console.log("attained default clause");

return narrowedType;
}

function narrowTypeByAnd(type: Type, expr: BinaryExpression, assumeTrue: boolean): Type {
Expand Down Expand Up @@ -4440,7 +4487,8 @@ module ts {
var operator = (<BinaryExpression>expr).operator;
if (operator === SyntaxKind.EqualsEqualsEqualsToken || operator === SyntaxKind.ExclamationEqualsEqualsToken) {
if((<BinaryExpression>expr).left.kind === SyntaxKind.PropertyAccess) {
return narrowPropTypeByStringTypeEquality(type, <BinaryExpression>expr, assumeTrue);
var binary_expr = <BinaryExpression>expr;
return narrowPropTypeByStringTypeEquality(type, <PropertyAccess>binary_expr.left, binary_expr.right, assumeTrue);
} else {
return narrowTypeByEquality(type, <BinaryExpression>expr, assumeTrue);
}
Expand Down
15 changes: 9 additions & 6 deletions src/compiler/parser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1240,15 +1240,16 @@ module ts {
}

// Parses a list of elements
function parseList<T extends Node>(kind: ParsingContext, checkForStrictMode: boolean, parseElement: () => T): NodeArray<T> {
function parseList<T extends Node>(kind: ParsingContext, checkForStrictMode: boolean, parseElement: (i:number) => T): NodeArray<T> {
var saveParsingContext = parsingContext;
parsingContext |= 1 << kind;
var result = <NodeArray<T>>[];
result.pos = getNodePos();
var saveIsInStrictMode = isInStrictMode;
var i = 0;
while (!isListTerminator(kind)) {
if (isListElement(kind, /* inErrorRecovery */ false)) {
var element = parseElement();
var element = parseElement(i++);
result.push(element);
// test elements only if we are not already in strict mode
if (!isInStrictMode && checkForStrictMode) {
Expand Down Expand Up @@ -2993,25 +2994,27 @@ module ts {
return node;
}

function parseCaseClause(): CaseOrDefaultClause {
function parseCaseClause(i: number): CaseOrDefaultClause {
var node = <CaseOrDefaultClause>createNode(SyntaxKind.CaseClause);
parseExpected(SyntaxKind.CaseKeyword);
node.expression = parseExpression();
parseExpected(SyntaxKind.ColonToken);
node.statements = parseList(ParsingContext.SwitchClauseStatements, /*checkForStrictMode*/ false, parseStatementAllowingLetDeclaration);
node.id = i;
return finishNode(node);
}

function parseDefaultClause(): CaseOrDefaultClause {
function parseDefaultClause(i: number): CaseOrDefaultClause {
var node = <CaseOrDefaultClause>createNode(SyntaxKind.DefaultClause);
parseExpected(SyntaxKind.DefaultKeyword);
parseExpected(SyntaxKind.ColonToken);
node.statements = parseList(ParsingContext.SwitchClauseStatements, /*checkForStrictMode*/ false, parseStatementAllowingLetDeclaration);
node.id = i;
return finishNode(node);
}

function parseCaseOrDefaultClause(): CaseOrDefaultClause {
return token === SyntaxKind.CaseKeyword ? parseCaseClause() : parseDefaultClause();
function parseCaseOrDefaultClause(i: number): CaseOrDefaultClause {
return token === SyntaxKind.CaseKeyword ? parseCaseClause(i) : parseDefaultClause(i);
}

function parseSwitchStatement(): SwitchStatement {
Expand Down
1 change: 1 addition & 0 deletions src/compiler/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ module ts {
export interface CaseOrDefaultClause extends Node {
expression?: Expression;
statements: NodeArray<Statement>;
id: number;
}

export interface LabeledStatement extends Statement {
Expand Down

0 comments on commit fd5c401

Please sign in to comment.