diff --git a/README.md b/README.md
index d19a96b..11b1c9d 100644
--- a/README.md
+++ b/README.md
@@ -150,7 +150,7 @@
 AI._SCRIPTSCAN | N/A
 AI.DAGRUN | N/A
 AI.DAGRUN_RO | N/A
 AI.INFO | info and infoResetStat (for resetting stats)
-AI.CONFIG * | N/A
+AI.CONFIG * | configLoadBackend and configBackendsPath
 
 ### Running tests
diff --git a/src/client.ts b/src/client.ts
index 6677104..bd0c8b7 100644
--- a/src/client.ts
+++ b/src/client.ts
@@ -49,22 +49,8 @@ export class Client {
     });
   }
 
-  public modelset(keName: string, m: Model): Promise<any> {
-    const args: any[] = [keName, m.backend.toString(), m.device];
-    if (m.tag !== undefined) {
-      args.push('TAG');
-      args.push(m.tag.toString());
-    }
-    if (m.inputs.length > 0) {
-      args.push('INPUTS');
-      m.inputs.forEach((value) => args.push(value));
-    }
-    if (m.outputs.length > 0) {
-      args.push('OUTPUTS');
-      m.outputs.forEach((value) => args.push(value));
-    }
-    args.push('BLOB');
-    args.push(m.blob);
+  public modelset(keyName: string, m: Model): Promise<any> {
+    const args: any[] = m.modelSetFlatArgs(keyName);
     return this._sendCommand('ai.modelset', args);
   }
diff --git a/src/model.ts b/src/model.ts
index 45f4980..d1751e0 100644
--- a/src/model.ts
+++ b/src/model.ts
@@ -11,14 +11,32 @@
   * @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow models)
   * @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow models)
   * @param blob - the Protobuf-serialized model
+  * @param batchsize - when provided with a batchsize that is greater than 0, the engine will batch incoming requests from multiple clients that use the model with input tensors of the same shape.
+  * @param minbatchsize - when provided with a minbatchsize that is greater than 0, the engine will postpone calls to AI.MODELRUN until the batch's size has reached minbatchsize
   */
-  constructor(backend: Backend, device: string, inputs: string[], outputs: string[], blob: Buffer | undefined) {
+  constructor(
+    backend: Backend,
+    device: string,
+    inputs: string[],
+    outputs: string[],
+    blob: Buffer | undefined,
+    batchsize?: number,
+    minbatchsize?: number,
+  ) {
     this._backend = backend;
     this._device = device;
     this._inputs = inputs;
     this._outputs = outputs;
     this._blob = blob;
     this._tag = undefined;
+    this._batchsize = batchsize || 0;
+    if (this._batchsize < 0) {
+      this._batchsize = 0;
+    }
+    this._minbatchsize = minbatchsize || 0;
+    if (this._minbatchsize < 0) {
+      this._minbatchsize = 0;
+    }
   }
 
   // tag is an optional string for tagging the model such as a version number or any arbitrary identifier
@@ -86,14 +104,39 @@
     this._blob = value;
   }
 
+  private _batchsize: number;
+
+  get batchsize(): number {
+    return this._batchsize;
+  }
+
+  set batchsize(value: number) {
+    this._batchsize = value;
+  }
+
+  private _minbatchsize: number;
+
+  get minbatchsize(): number {
+    return this._minbatchsize;
+  }
+
+  set minbatchsize(value: number) {
+    this._minbatchsize = value;
+  }
+
   static NewModelFromModelGetReply(reply: any[]) {
     let backend = null;
     let device = null;
     let tag = null;
     let blob = null;
+    let batchsize: number = 0;
+    let minbatchsize: number = 0;
+    const inputs: string[] = [];
+    const outputs: string[] = [];
     for (let i = 0; i < reply.length; i += 2) {
       const key = reply[i];
       const obj = reply[i + 1];
+
       switch (key.toString()) {
         case 'backend':
           backend = BackendMap[obj.toString()];
           break;
         case 'device':
           device = obj.toString();
           break;
@@ -106,9 +149,20 @@
         case 'tag':
           tag = obj.toString();
           break;
         case 'blob':
-          // blob = obj;
           blob = Buffer.from(obj);
           break;
+        case 'batchsize':
+          batchsize = parseInt(obj.toString(), 10);
+          break;
+        case 'minbatchsize':
+          minbatchsize = parseInt(obj.toString(), 10);
+          break;
+        case 'inputs':
+          obj.forEach((input) => inputs.push(input));
+          break;
+        case 'outputs':
+          obj.forEach((output) => outputs.push(output));
+          break;
       }
     }
     if (backend == null || device == null || blob == null) {
@@ -126,10 +180,37 @@
         'AI.MODELGET reply did not had the full elements to build the Model. Missing ' + missingArr.join(',') + '.',
       );
     }
-    const model = new Model(backend, device, [], [], blob);
+    const model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
     if (tag !== null) {
       model.tag = tag;
     }
     return model;
   }
+
+  modelSetFlatArgs(keyName: string) {
+    const args: any[] = [keyName, this.backend.toString(), this.device];
+    if (this.tag !== undefined) {
+      args.push('TAG');
+      args.push(this.tag.toString());
+    }
+    if (this.batchsize > 0) {
+      args.push('BATCHSIZE');
+      args.push(this.batchsize);
+      if (this.minbatchsize > 0) {
+        args.push('MINBATCHSIZE');
+        args.push(this.minbatchsize);
+      }
+    }
+    if (this.inputs.length > 0) {
+      args.push('INPUTS');
+      this.inputs.forEach((value) => args.push(value));
+    }
+    if (this.outputs.length > 0) {
+      args.push('OUTPUTS');
+      this.outputs.forEach((value) => args.push(value));
+    }
+    args.push('BLOB');
+    args.push(this.blob);
+    return args;
+  }
 }
diff --git a/tests/test_client.ts b/tests/test_client.ts
index 5f45b45..a68cb09 100644
--- a/tests/test_client.ts
+++ b/tests/test_client.ts
@@ -331,13 +331,77 @@ it(
     const aiclient = new Client(nativeClient);
 
     const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
-    const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], modelBlob);
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
     model.tag = 'test_tag';
     const resultModelSet = await aiclient.modelset('mymodel', model);
     expect(resultModelSet).to.equal('OK');
-    const modelOut = await aiclient.modelget('mymodel');
+    const modelOut: Model = await aiclient.modelget('mymodel');
     expect(modelOut.blob.toString()).to.equal(modelBlob.toString());
+    for (let index = 0; index < modelOut.outputs.length; index++) {
+      expect(modelOut.outputs[index]).to.equal(outputs[index]);
+      expect(modelOut.outputs[index]).to.equal(model.outputs[index]);
+    }
+    for (let index = 0; index < modelOut.inputs.length; index++) {
+      expect(modelOut.inputs[index]).to.equal(inputs[index]);
+      expect(modelOut.inputs[index]).to.equal(model.inputs[index]);
+    }
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+    aiclient.end(true);
+  }),
+);
+
+it(
+  'ai.modelget batching positive testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+
+    const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
+    model.tag = 'test_tag';
+    model.batchsize = 100;
+    model.minbatchsize = 5;
+    const resultModelSet = await aiclient.modelset('mymodel-batching', model);
+    expect(resultModelSet).to.equal('OK');
+    const modelOut: Model = await aiclient.modelget('mymodel-batching');
+    const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop', modelOut);
+    expect(resultModelSet2).to.equal('OK');
+    const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop');
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+    aiclient.end(true);
+  }),
+);
+
+it(
+  'ai.modelget batching via constructor positive testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+
+    const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
+    const inputs: string[] = ['a', 'b'];
+    const outputs: string[] = ['c'];
+    const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 100, 5);
+    model.tag = 'test_tag';
+    const resultModelSet = await aiclient.modelset('mymodel-batching-t2', model);
+    expect(resultModelSet).to.equal('OK');
+    const modelOut: Model = await aiclient.modelget('mymodel-batching-t2');
+    const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop-t2', modelOut);
+    expect(resultModelSet2).to.equal('OK');
+    const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop-t2');
+    expect(modelOut.batchsize).to.equal(model.batchsize);
+    expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
+
+    const model2 = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 1000);
+    expect(model2.batchsize).to.equal(1000);
+    expect(model2.minbatchsize).to.equal(0);
     aiclient.end(true);
   }),
 );
@@ -624,26 +688,26 @@
 );
 
 it(
-  'ai.config positive and negative testing',
-  mochaAsync(async () => {
-    const nativeClient = createClient();
-    const aiclient = new Client(nativeClient);
-    const result = await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
-    expect(result).to.equal('OK');
-    // negative test
-    try {
-      const loadReply = await aiclient.configLoadBackend(Backend.TF, 'notexist/redisai_tensorflow.so');
-    } catch (e) {
-      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
-    }
-
-    try {
-      // may throw error if backend already loaded
-      const loadResult = await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
-      expect(loadResult).to.equal('OK');
-    } catch (e) {
-      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
-    }
-    aiclient.end(true);
-  }),
-);
\ No newline at end of file
+  'ai.config positive and negative testing',
+  mochaAsync(async () => {
+    const nativeClient = createClient();
+    const aiclient = new Client(nativeClient);
+    const result = await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
+    expect(result).to.equal('OK');
+    // negative test
+    try {
+      const loadReply = await aiclient.configLoadBackend(Backend.TF, 'notexist/redisai_tensorflow.so');
+    } catch (e) {
+      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
+    }
+
+    try {
+      // may throw error if backend already loaded
+      const loadResult = await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
+      expect(loadResult).to.equal('OK');
+    } catch (e) {
+      expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
+    }
+    aiclient.end(true);
+  }),
+);
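
For reference, a minimal usage sketch of the batching support introduced by this change. Assumptions not taken from the diff: the package is consumed as redisai-js, the underlying connection comes from createClient() in the node_redis package, and the key name and graph path are placeholders borrowed from the tests.

import * as fs from 'fs';
import { createClient } from 'redis';
import { Client, Model, Backend } from 'redisai-js';

async function setBatchedModel(): Promise<void> {
  // Wrap a plain node_redis connection, as the tests do via their createClient() helper.
  const aiclient = new Client(createClient());
  const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');

  // batchsize = 100 and minbatchsize = 5 flow through the new optional constructor
  // arguments and are serialized by modelSetFlatArgs() as BATCHSIZE / MINBATCHSIZE.
  const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], modelBlob, 100, 5);
  await aiclient.modelset('mymodel', model);

  // AI.MODELGET replies now carry batchsize/minbatchsize back into the Model object.
  const stored: Model = await aiclient.modelget('mymodel');
  console.log(stored.batchsize, stored.minbatchsize); // 100 5
  aiclient.end(true);
}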