diff --git a/.gitignore b/.gitignore index 2afb532..03c9302 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.d.ts handler/node_modules handler/*.js +handler.zip handler/*.d.ts .idea node_modules \ No newline at end of file diff --git a/handler/index.ts b/handler/index.ts index e7d689e..fb82e30 100644 --- a/handler/index.ts +++ b/handler/index.ts @@ -9,6 +9,7 @@ export interface CountResolverEvent { context: any; dynamo: DynamoFilter | null; tableName: string; + indexName: string | undefined | null; } interface Context { @@ -78,14 +79,44 @@ export function makeScanInput( }; } +export function makeScanInputFromIndex( + event: CountResolverEvent, + startKey: ScanCommandInput["ExclusiveStartKey"] +): ScanCommandInput { + return { + Select: "COUNT", + TableName: event.tableName, + ExclusiveStartKey: startKey, + FilterExpression: + event.dynamo && event.dynamo.expression.length > 0 + ? event.dynamo?.expression + : undefined, + ExpressionAttributeNames: notEmptyObject(event.dynamo?.expressionNames) + ? event.dynamo?.expressionNames + : undefined, + ExpressionAttributeValues: notEmptyObject(event.dynamo?.expressionValues) + ? primitivesToString(event.dynamo?.expressionValues) + : undefined, + }; +} + export const handler = async (event: CountResolverEvent) => { debug("Incoming event data from AppSync: %o", event); const dbClient = new DynamoDB({}); + const indexName = event.indexName; + let count = 0; let startKey = undefined; while (true) { - const scanArgs = makeScanInput(event, startKey); + let scanArgs!: ScanCommandInput; + + if (!indexName) { + scanArgs = makeScanInputFromIndex(event, startKey); + } else { + scanArgs = makeScanInput(event, startKey); + } + debug("Executing the following Dynamo scan: %o", scanArgs); const res: ScanCommandOutput = await dbClient.scan(scanArgs); count += res.Count || 0; diff --git a/index.ts b/index.ts index 37f80ed..cca39fd 100644 --- a/index.ts +++ b/index.ts @@ -1,4 +1,5 @@ import { + DirectiveWrapper, TransformerNestedStack, TransformerPluginBase, } from "@aws-amplify/graphql-transformer-core"; @@ -8,34 +9,94 @@ import { TransformerSchemaVisitStepContextProvider, TransformerTransformSchemaStepContextProvider, } from "@aws-amplify/graphql-transformer-interfaces"; +import * as appsync from "@aws-cdk/aws-appsync"; +import * as iam from "@aws-cdk/aws-iam"; +import * as lambda from "@aws-cdk/aws-lambda"; +import { + DirectiveNode, + FieldDefinitionNode, + InterfaceTypeDefinitionNode, + Kind, + ListTypeNode, + NamedTypeNode, + NonNullTypeNode, + ObjectTypeDefinitionNode, + TypeNode, +} from "graphql"; import { + getBaseType, makeField, makeInputValueDefinition, makeNamedType, + ResolverResourceIDs, toCamelCase, toPascalCase, } from "graphql-transformer-common"; -import { - DirectiveNode, - FieldDefinitionNode, - ObjectTypeDefinitionNode, -} from "graphql"; -import * as lambda from "@aws-cdk/aws-lambda"; import * as path from "path"; -import * as appsync from "@aws-cdk/aws-appsync"; -import * as iam from "@aws-cdk/aws-iam"; + +export function getCountAttributeName(type: string, field: string) { + return toCamelCase([type, field, "id"]); +} + +export interface FieldCountDirectiveConfiguration { + directiveName: string; + object: ObjectTypeDefinitionNode; + field: FieldDefinitionNode; + directive: DirectiveNode; + countObject: ObjectTypeDefinitionNode; + countField: string; + countNode: FieldDefinitionNode; + resolverTypeName: string; + resolverFieldName: string; + indexName: string; +} + +const directiveName = "fieldCount"; export default class CountTransformer extends TransformerPluginBase implements TransformerPluginProvider { models: ObjectTypeDefinitionNode[]; + fields: FieldCountDirectiveConfiguration[]; constructor() { - super("count", "directive @count on OBJECT"); + super( + "count", + ` + directive @count(type: CountType) on OBJECT + directive @${directiveName}(type: CountType, countField: String, indexName: String!) on FIELD_DEFINITION + enum CountType { + scan + distinct + } +` + ); this.models = []; + this.fields = []; } + field = ( + parent: ObjectTypeDefinitionNode | InterfaceTypeDefinitionNode, + field: FieldDefinitionNode, + directive: DirectiveNode, + context: TransformerSchemaVisitStepContextProvider + ) => { + const directiveWrapped = new DirectiveWrapper(directive); + const args = directiveWrapped.getArguments({ + directiveName, + object: parent as ObjectTypeDefinitionNode, + field: field, + resolverTypeName: parent.name.value, + resolverFieldName: field.name.value, + directive, + }) as FieldCountDirectiveConfiguration; + + /// Keep track of all fields annotated with @count + validate(args, context as TransformerContextProvider); + this.fields.push(args); + }; + object = ( definition: ObjectTypeDefinitionNode, directive: DirectiveNode, @@ -46,8 +107,15 @@ export default class CountTransformer }; transformSchema = (ctx: TransformerTransformSchemaStepContextProvider) => { + const context = ctx as TransformerContextProvider; + const fields: FieldDefinitionNode[] = []; + // For each field that has been annotated with @count + for (const config of this.fields) { + ensureHasCountField(config, context); + } + // For each model that has been annotated with @count for (const model of this.models) { if (!model.directives?.find((dir) => dir.name.value === "model")) { @@ -79,6 +147,8 @@ export default class CountTransformer }; generateResolvers = (ctx: TransformerContextProvider) => { + const createdResources = new Map(); + // Path on the local filesystem to the handler zip file const HANDLER_LOCAL_PATH = path.join(__dirname, "handler.zip"); const stack: TransformerNestedStack = ctx.stackManager.createStack( @@ -118,6 +188,81 @@ export default class CountTransformer stack ); + for (const config of this.fields) { + const { + field, + countField, + object, + resolverTypeName, + resolverFieldName, + countObject, + indexName, + } = config; + + // Find the table we want to scan + const tableDataSource = ctx.dataSources.get( + countObject + ) as appsync.DynamoDbDataSource; + const table = tableDataSource.ds + .dynamoDbConfig as appsync.CfnDataSource.DynamoDBConfigProperty; + + // Allow the lambda to access this table + funcRole.addToPolicy( + new iam.PolicyStatement({ + actions: ["dynamodb:Scan"], + effect: iam.Effect.ALLOW, + resources: [ + `arn:aws:dynamodb:${table.awsRegion}:${stack.account}:table/${table.tableName}`, + ], + }) + ); + + // Create the GraphQL resolvers. + const resolverId = ResolverResourceIDs.ResolverResourceID( + config.resolverTypeName, + config.countField + ); + let resolver = createdResources.get(resolverId); + + if (resolver === undefined) { + // TODO: update function to use resolver manager + resolver = new appsync.CfnResolver( + stack, + `${field.name.value}${object.name.value}CountResolver`, + { + apiId: ctx.api.apiId, + fieldName: config.countField, + typeName: config.resolverTypeName, + kind: "UNIT", + requestMappingTemplate: ` +$util.toJson({ + "version": "2018-05-29", + "operation": "Invoke", + "payload": { + "context": $ctx, + "dynamo": $util.parseJson($util.transform.toDynamoDBFilterExpression($ctx.arguments.filter)), + "tableName": "${table.tableName}", + "index": ${indexName} + } +}) + `, + responseMappingTemplate: ` +#if( $ctx.error ) + $util.error($ctx.error.message, $ctx.error.type) +#else + $util.toJson($ctx.result) + ${config.resolverTypeName}.${config.countField}.res.vtl +#end +`, + } + ); + + createdResources.set(resolverId, resolver); + } + + resolver.pipelineConfig.functions.push(func); + } + for (const model of this.models) { // Find the table we want to scan const tableDataSource = ctx.dataSources.get( @@ -173,3 +318,167 @@ $util.toJson({ } }; } + +export function ensureHasCountField( + config: FieldCountDirectiveConfiguration, + ctx: TransformerContextProvider +) { + const { field, countNode, object } = config; + + // If fields were explicitly provided to the directive, there is nothing else to do here. + if (countNode) { + return; + } + + const countAttributeName = getCountAttributeName( + object.name.value, + field.name.value + ); + + const typeObject = ctx.output.getType( + object.name.value + ) as ObjectTypeDefinitionNode; + + if (typeObject) { + const updated = updateTypeWithCountField(typeObject, countAttributeName); + ctx.output.putType(updated); + } + + config.countField = countAttributeName; +} + +export function isNonNullType(type: TypeNode): boolean { + return type.kind === Kind.NON_NULL_TYPE; +} + +function updateTypeWithCountField( + object: ObjectTypeDefinitionNode, + countFieldName: string +): ObjectTypeDefinitionNode { + const keyFieldExists = object.fields!.some( + (f) => f.name.value === countFieldName + ); + + // If the key field already exists then do not change the input. + if (keyFieldExists) { + return object; + } + + // Create a name for the filter key + const filterInputName = toPascalCase([ + "Model", + countFieldName, + "FilterInput", + ]); + + // Add the new field to the original model + const updatedFields = [ + ...object.fields!, + makeField( + countFieldName, + [makeInputValueDefinition("filter", makeNamedType(filterInputName))], + makeNonNullType(makeNamedType("Int")), + [] + ), + ]; + + return { + ...object, + fields: updatedFields, + }; +} + +export function makeNonNullType( + type: NamedTypeNode | ListTypeNode +): NonNullTypeNode { + return { + kind: Kind.NON_NULL_TYPE, + type, + }; +} + +function validate( + config: FieldCountDirectiveConfiguration, + ctx: TransformerContextProvider +): void { + const { field } = config; + + validateIndexDirective(config); + + if (!isListType(field.type)) { + throw new Error(`@${directiveName} cannot be used on non-lists.`); + } + + config.countNode = getCountNode(config, ctx); + config.countObject = getCountTableType(config, ctx); +} + +export function validateIndexDirective( + config: FieldCountDirectiveConfiguration +) { + if ( + !config.field.directives?.find((dir) => dir.name.value === "hasMany") || + !config.field.directives?.find((dir) => dir.name.value === "manyToMany") + ) { + throw new Error( + `Any field annotated with @${config.directiveName} must also be annoted with @hasMany or @manyToMany, as it uses their connecting tables for count.` + ); + } + + if (!getModelDirective(config.object)) { + throw new Error( + `@${config.directiveName} must be on an @model object type field.` + ); + } +} +export function getModelDirective(objectType: ObjectTypeDefinitionNode) { + return objectType.directives!.find((directive) => { + return directive.name.value === "model"; + }); +} + +export function isListType(type: TypeNode): boolean { + if (type.kind === Kind.NON_NULL_TYPE) { + return isListType(type.type); + } else if (type.kind === Kind.LIST_TYPE) { + return true; + } else { + return false; + } +} + +export function getCountNode( + config: FieldCountDirectiveConfiguration, + ctx: TransformerContextProvider +) { + const { countField, object } = config; + + const fieldNode = object.fields!.find( + (objectField) => objectField.name.value === countField + ); + + if (!fieldNode) { + throw new Error(`${countField} is not a field in ${object.name.value}`); + } + + return fieldNode; +} + +export function getCountTableType( + config: FieldCountDirectiveConfiguration, + ctx: TransformerContextProvider +) { + const { field } = config; + const countFieldTypeName = getBaseType(field.type); + const countFieldType = ctx.inputDocument.definitions.find( + (d: any) => + d.kind === Kind.OBJECT_TYPE_DEFINITION && + d.name.value === countFieldTypeName + ) as ObjectTypeDefinitionNode | undefined; + + if (!countFieldType) { + throw Error(`Unknown type name on field directive`); + } + + return countFieldType; +} diff --git a/package-lock.json b/package-lock.json index 10a1f9a..eeeb4f0 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "amplify-count-directive", - "version": "1.0.0", + "version": "1.1.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "amplify-count-directive", - "version": "1.0.0", + "version": "1.1.0", "license": "GPL-3.0", "dependencies": { "@aws-amplify/graphql-model-transformer": "^0.13.1", diff --git a/test/test.ts b/test/test.ts index cd4212b..d87fc64 100644 --- a/test/test.ts +++ b/test/test.ts @@ -1,21 +1,21 @@ -import CountTransformer from "../index"; +import { ModelTransformer } from "@aws-amplify/graphql-model-transformer"; import { GraphQLTransform, validateModelSchema, } from "@aws-amplify/graphql-transformer-core"; +import Template from "@aws-amplify/graphql-transformer-core/lib/transformation/types"; +import { countResources, expect as cdkExpect } from "@aws-cdk/assert"; import * as fs from "fs"; -import * as path from "path"; import { parse } from "graphql"; -import { countResources, expect as cdkExpect } from "@aws-cdk/assert"; -import { ModelTransformer } from "@aws-amplify/graphql-model-transformer"; -import Template from "@aws-amplify/graphql-transformer-core/lib/transformation/types"; +import * as path from "path"; import { - makeScanInput, - primitivesToString, - notEmptyObject, CountResolverEvent, DynamoFilter, + makeScanInput, + notEmptyObject, + primitivesToString, } from "../handler"; +import CountTransformer from "../index"; const test_schema = fs.readFileSync( path.resolve(__dirname, "./test_schema.graphql"), @@ -52,12 +52,84 @@ type Bar @count @model { } `; +const connectionSchema = ` +type Foo @count @model { + id: ID! + string_field: String + int_field: Int + float_field: Float + bool_field: Boolean + bar_id: ID @index(name: "byBar") +} + +type Bar @count @model { + id: ID! + string_field: String + int_field: Int + float_field: Float + bool_field: Boolean + foos_count: Int + foos: [Foo] @count(field: "foos_count", indexName: "byBar") @hasMany(indexName: "byBar", fields: ["id"]) +}`; + const makeTransformer = () => new GraphQLTransform({ transformers: [new CountTransformer(), new ModelTransformer()], }); describe("cdk stack", () => { + test("field transformer fails when @model is not used on parent", () => { + const transformer = makeTransformer(); + expect(() => { + transformer.transform(` +type Foo @count @model { + id: ID! + string_field: String + int_field: Int + float_field: Float + bool_field: Boolean + bar_id: ID @index(name: "byBar") +} + +type Bar @count { + id: ID! + string_field: String + int_field: Int + float_field: Float + bool_field: Boolean + foos_count: Int + foos: [Foo] @count(field: "foos_count", indexName: "byBar") @hasMany(indexName: "byBar", fields: ["id"]) +} + `); + }).toThrow(/model/); + }); + + test("field transformer fails when @hasMany or @manyToMany is not used on field", () => { + const transformer = makeTransformer(); + expect(() => { + transformer.transform(` +type Foo @count @model { + id: ID! + string_field: String + int_field: Int + float_field: Float + bool_field: Boolean + bar_id: ID @index(name: "byBar") +} + +type Bar @count { + id: ID! + string_field: String + int_field: Int + float_field: Float + bool_field: Boolean + foos_count: Int + foos: [Foo] @count(field: "foos_count", indexName: "byBar") +} + `); + }).toThrow(/model/); + }); + test("transformer fails when @model is not used", () => { const transformer = makeTransformer(); expect(() => { @@ -120,6 +192,31 @@ function makeAppSyncEvent(dynamoFilter: DynamoFilter): CountResolverEvent { }, dynamo: dynamoFilter, tableName: "Foo-k36yt433bvewbbo5436kmt4ixa-countdev", + indexName: undefined, + }; +} + +function makeAppSyncEventWithIndex( + dynamoFilter: DynamoFilter +): CountResolverEvent { + return { + context: { + arguments: { + filter: {}, + }, + identity: null, + source: null, + result: null, + request: { headers: [], domainName: null }, + info: { fieldName: "countFoo", parentTypeName: "Query", variables: {} }, + error: null, + prev: null, + stash: {}, + outErrors: [], + }, + dynamo: dynamoFilter, + tableName: "Foo-k36yt433bvewbbo5436kmt4ixa-countdev", + indexName: undefined, }; }