From 7d1f1d69588b771c5ec393c86976008a352ddcc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E0=A4=95=E0=A4=BE=E0=A4=B0=E0=A4=A4=E0=A5=8B=E0=A4=AB?= =?UTF-8?q?=E0=A5=8D=E0=A4=AB=E0=A5=87=E0=A4=B2=E0=A4=B8=E0=A5=8D=E0=A4=95?= =?UTF-8?q?=E0=A5=8D=E0=A4=B0=E0=A4=BF=E0=A4=AA=E0=A5=8D=E0=A4=9F=E2=84=A2?= Date: Tue, 7 Feb 2023 15:01:04 +0100 Subject: [PATCH] feat: support for SQL aggregate functions SUM, AVG, MIN, and MAX to the Repository API (#9737) * feat: Add support for SQL aggregate functions SUM, AVG, MIN, and MAX to the Repository API * rename field name to make tests work in oracle * fix the comments * update the docs * escape column name * address PR comment * format the code --- docs/repository-api.md | 24 +++++ src/common/PickKeysByType.ts | 7 ++ src/entity-manager/EntityManager.ts | 64 ++++++++++++++ src/repository/BaseEntity.ts | 45 ++++++++++ src/repository/Repository.ts | 41 +++++++++ .../aggregate-methods/entity/Post.ts | 12 +++ .../repository-aggregate-methods.ts | 87 +++++++++++++++++++ 7 files changed, 280 insertions(+) create mode 100644 src/common/PickKeysByType.ts create mode 100644 test/functional/repository/aggregate-methods/entity/Post.ts create mode 100644 test/functional/repository/aggregate-methods/repository-aggregate-methods.ts diff --git a/docs/repository-api.md b/docs/repository-api.md index f2f2b6aac0..c9d2ca899e 100644 --- a/docs/repository-api.md +++ b/docs/repository-api.md @@ -273,6 +273,30 @@ const count = await repository.count({ const count = await repository.countBy({ firstName: "Timber" }) ``` +- `sum` - Returns the sum of a numeric field for all entities that match `FindOptionsWhere`. + +```typescript +const sum = await repository.sum("age", { firstName: "Timber" }) +``` + +- `average` - Returns the average of a numeric field for all entities that match `FindOptionsWhere`. + +```typescript +const average = await repository.average("age", { firstName: "Timber" }) +``` + +- `minimum` - Returns the minimum of a numeric field for all entities that match `FindOptionsWhere`. + +```typescript +const minimum = await repository.minimum("age", { firstName: "Timber" }) +``` + +- `maximum` - Returns the maximum of a numeric field for all entities that match `FindOptionsWhere`. + +```typescript +const maximum = await repository.maximum("age", { firstName: "Timber" }) +``` + - `find` - Finds entities that match given `FindOptions`. ```typescript diff --git a/src/common/PickKeysByType.ts b/src/common/PickKeysByType.ts new file mode 100644 index 0000000000..fd4f9d24c9 --- /dev/null +++ b/src/common/PickKeysByType.ts @@ -0,0 +1,7 @@ +/** + * Pick only the keys that match the Type `U` + */ +export type PickKeysByType = string & + keyof { + [P in keyof T as T[P] extends U ? P : never]: T[P] + } diff --git a/src/entity-manager/EntityManager.ts b/src/entity-manager/EntityManager.ts index 76c315075f..6ace02a6ce 100644 --- a/src/entity-manager/EntityManager.ts +++ b/src/entity-manager/EntityManager.ts @@ -37,6 +37,7 @@ import { getMetadataArgsStorage } from "../globals" import { UpsertOptions } from "../repository/UpsertOptions" import { InstanceChecker } from "../util/InstanceChecker" import { ObjectLiteral } from "../common/ObjectLiteral" +import { PickKeysByType } from "../common/PickKeysByType" /** * Entity manager supposed to work with any entity, automatically find its repository and call its methods, @@ -1001,6 +1002,69 @@ export class EntityManager { .getCount() } + /** + * Return the SUM of a column + */ + sum( + entityClass: EntityTarget, + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.callAggregateFun(entityClass, "SUM", columnName, where) + } + + /** + * Return the AVG of a column + */ + average( + entityClass: EntityTarget, + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.callAggregateFun(entityClass, "AVG", columnName, where) + } + + /** + * Return the MIN of a column + */ + minimum( + entityClass: EntityTarget, + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.callAggregateFun(entityClass, "MIN", columnName, where) + } + + /** + * Return the MAX of a column + */ + maximum( + entityClass: EntityTarget, + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.callAggregateFun(entityClass, "MAX", columnName, where) + } + + private async callAggregateFun( + entityClass: EntityTarget, + fnName: "SUM" | "AVG" | "MIN" | "MAX", + columnName: PickKeysByType, + where: FindOptionsWhere | FindOptionsWhere[] = {}, + ): Promise { + const metadata = this.connection.getMetadata(entityClass) + const result = await this.createQueryBuilder(entityClass, metadata.name) + .setFindOptions({ where }) + .select( + `${fnName}(${this.connection.driver.escape( + String(columnName), + )})`, + fnName, + ) + .getRawOne() + return result[fnName] === null ? null : parseFloat(result[fnName]) + } + /** * Finds entities that match given find options. */ diff --git a/src/repository/BaseEntity.ts b/src/repository/BaseEntity.ts index 89e13c1a80..dc3ee4a4dd 100644 --- a/src/repository/BaseEntity.ts +++ b/src/repository/BaseEntity.ts @@ -15,6 +15,7 @@ import { ObjectUtils } from "../util/ObjectUtils" import { QueryDeepPartialEntity } from "../query-builder/QueryPartialEntity" import { UpsertOptions } from "./UpsertOptions" import { EntityTarget } from "../common/EntityTarget" +import { PickKeysByType } from "../common/PickKeysByType" /** * Base abstract entity for all entities, used in ActiveRecord patterns. @@ -408,6 +409,50 @@ export class BaseEntity { return this.getRepository().countBy(where) } + /** + * Return the SUM of a column + */ + static sum( + this: { new (): T } & typeof BaseEntity, + columnName: PickKeysByType, + where: FindOptionsWhere, + ): Promise { + return this.getRepository().sum(columnName, where) + } + + /** + * Return the AVG of a column + */ + static average( + this: { new (): T } & typeof BaseEntity, + columnName: PickKeysByType, + where: FindOptionsWhere, + ): Promise { + return this.getRepository().average(columnName, where) + } + + /** + * Return the MIN of a column + */ + static minimum( + this: { new (): T } & typeof BaseEntity, + columnName: PickKeysByType, + where: FindOptionsWhere, + ): Promise { + return this.getRepository().minimum(columnName, where) + } + + /** + * Return the MAX of a column + */ + static maximum( + this: { new (): T } & typeof BaseEntity, + columnName: PickKeysByType, + where: FindOptionsWhere, + ): Promise { + return this.getRepository().maximum(columnName, where) + } + /** * Finds entities that match given options. */ diff --git a/src/repository/Repository.ts b/src/repository/Repository.ts index dc3bb3b62a..4d78c30ef9 100644 --- a/src/repository/Repository.ts +++ b/src/repository/Repository.ts @@ -15,6 +15,7 @@ import { ObjectID } from "../driver/mongodb/typings" import { FindOptionsWhere } from "../find-options/FindOptionsWhere" import { UpsertOptions } from "./UpsertOptions" import { EntityTarget } from "../common/EntityTarget" +import { PickKeysByType } from "../common/PickKeysByType" /** * Repository is supposed to work with your entity objects. Find entities, insert, update, delete, etc. @@ -476,6 +477,46 @@ export class Repository { return this.manager.countBy(this.metadata.target, where) } + /** + * Return the SUM of a column + */ + sum( + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.manager.sum(this.metadata.target, columnName, where) + } + + /** + * Return the AVG of a column + */ + average( + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.manager.average(this.metadata.target, columnName, where) + } + + /** + * Return the MIN of a column + */ + minimum( + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.manager.minimum(this.metadata.target, columnName, where) + } + + /** + * Return the MAX of a column + */ + maximum( + columnName: PickKeysByType, + where?: FindOptionsWhere | FindOptionsWhere[], + ): Promise { + return this.manager.maximum(this.metadata.target, columnName, where) + } + /** * Finds entities that match given find options. */ diff --git a/test/functional/repository/aggregate-methods/entity/Post.ts b/test/functional/repository/aggregate-methods/entity/Post.ts new file mode 100644 index 0000000000..d922cacccb --- /dev/null +++ b/test/functional/repository/aggregate-methods/entity/Post.ts @@ -0,0 +1,12 @@ +import { Entity } from "../../../../../src/decorator/entity/Entity" +import { Column } from "../../../../../src/decorator/columns/Column" +import { PrimaryColumn } from "../../../../../src/decorator/columns/PrimaryColumn" + +@Entity() +export class Post { + @PrimaryColumn() + id: number + + @Column() + counter: number +} diff --git a/test/functional/repository/aggregate-methods/repository-aggregate-methods.ts b/test/functional/repository/aggregate-methods/repository-aggregate-methods.ts new file mode 100644 index 0000000000..89049475cc --- /dev/null +++ b/test/functional/repository/aggregate-methods/repository-aggregate-methods.ts @@ -0,0 +1,87 @@ +import "reflect-metadata" +import { + closeTestingConnections, + createTestingConnections, +} from "../../../utils/test-utils" +import { Repository } from "../../../../src/repository/Repository" +import { DataSource } from "../../../../src/data-source/DataSource" +import { Post } from "./entity/Post" +import { LessThan } from "../../../../src" +import { expect } from "chai" + +describe("repository > aggregate methods", () => { + debugger + let connections: DataSource[] + let repository: Repository + + before(async () => { + connections = await createTestingConnections({ + entities: [Post], + schemaCreate: true, + dropSchema: true, + }) + repository = connections[0].getRepository(Post) + for (let i = 0; i < 100; i++) { + const post = new Post() + post.id = i + post.counter = i + 1 + await repository.save(post) + } + }) + + after(() => closeTestingConnections(connections)) + + describe("sum", () => { + it("should return the aggregate sum", async () => { + const sum = await repository.sum("counter") + expect(sum).to.equal(5050) + }) + + it("should return null when 0 rows match the query", async () => { + const sum = await repository.sum("counter", { id: LessThan(0) }) + expect(sum).to.be.null + }) + }) + + describe("average", () => { + it("should return the aggregate average", async () => { + const average = await repository.average("counter") + expect(average).to.equal(50.5) + }) + + it("should return null when 0 rows match the query", async () => { + const average = await repository.average("counter", { + id: LessThan(0), + }) + expect(average).to.be.null + }) + }) + + describe("minimum", () => { + it("should return the aggregate minimum", async () => { + const minimum = await repository.minimum("counter") + expect(minimum).to.equal(1) + }) + + it("should return null when 0 rows match the query", async () => { + const minimum = await repository.minimum("counter", { + id: LessThan(0), + }) + expect(minimum).to.be.null + }) + }) + + describe("maximum", () => { + it("should return the aggregate maximum", async () => { + const maximum = await repository.maximum("counter") + expect(maximum).to.equal(100) + }) + + it("should return null when 0 rows match the query", async () => { + const maximum = await repository.maximum("counter", { + id: LessThan(0), + }) + expect(maximum).to.be.null + }) + }) +})