Support for BATCHSIZE, MINBATCHSIZE, INPUTS and OUTPUTS on AI.MODELGET #9

Merged: 4 commits merged on Jun 7, 2020

Changes from 2 commits
2 changes: 1 addition & 1 deletion README.md
@@ -150,7 +150,7 @@ AI._SCRIPTSCAN | N/A
AI.DAGRUN | N/A
AI.DAGRUN_RO | N/A
AI.INFO | info and infoResetStat (for resetting stats)
AI.CONFIG * | N/A
AI.CONFIG * | configLoadBackend and configBackendsPath


### Running tests
18 changes: 2 additions & 16 deletions src/client.ts
@@ -49,22 +49,8 @@ export class Client {
});
}

public modelset(keName: string, m: Model): Promise<any> {
const args: any[] = [keName, m.backend.toString(), m.device];
if (m.tag !== undefined) {
args.push('TAG');
args.push(m.tag.toString());
}
if (m.inputs.length > 0) {
args.push('INPUTS');
m.inputs.forEach((value) => args.push(value));
}
if (m.outputs.length > 0) {
args.push('OUTPUTS');
m.outputs.forEach((value) => args.push(value));
}
args.push('BLOB');
args.push(m.blob);
public modelset(keyName: string, m: Model): Promise<any> {
const args: any[] = m.modelSetFlatArgs(keyName);
return this._sendCommand('ai.modelset', args);
}

84 changes: 81 additions & 3 deletions src/model.ts
@@ -11,14 +11,26 @@ export class Model {
* @param inputs - one or more names of the model's input nodes (applicable only for TensorFlow models)
* @param outputs - one or more names of the model's output nodes (applicable only for TensorFlow models)
* @param blob - the Protobuf-serialized model
* @param batchsize - when provided with a batchsize greater than 0, the engine will batch incoming requests from multiple clients that use the model with input tensors of the same shape.
* @param minbatchsize - when provided with a minbatchsize greater than 0, the engine will postpone calls to AI.MODELRUN until the batch's size has reached minbatchsize
*/
constructor(backend: Backend, device: string, inputs: string[], outputs: string[], blob: Buffer | undefined) {
constructor(
backend: Backend,
device: string,
inputs: string[],
outputs: string[],
blob: Buffer | undefined,
batchsize?: number,
minbatchsize?: number,
) {
this._backend = backend;
this._device = device;
this._inputs = inputs;
this._outputs = outputs;
this._blob = blob;
this._tag = undefined;
this._batchsize = batchsize || 0;
this._minbatchsize = minbatchsize || 0;
}

// tag is an optional string for tagging the model such as a version number or any arbitrary identifier
@@ -86,14 +98,36 @@ export class Model {
this._blob = value;
}

get minbatchsize(): number {
return this._minbatchsize;
}

set minbatchsize(value: number) {
this._minbatchsize = value;
}
get batchsize(): number {
return this._batchsize;
}

set batchsize(value: number) {
this._batchsize = value;
}
private _batchsize: number;
private _minbatchsize: number;

static NewModelFromModelGetReply(reply: any[]) {
let backend = null;
let device = null;
let tag = null;
let blob = null;
let batchsize: number = 0;
let minbatchsize: number = 0;
const inputs: string[] = [];
const outputs: string[] = [];
for (let i = 0; i < reply.length; i += 2) {
const key = reply[i];
const obj = reply[i + 1];

switch (key.toString()) {
case 'backend':
backend = BackendMap[obj.toString()];
@@ -106,9 +140,26 @@
tag = obj.toString();
break;
case 'blob':
// blob = obj;
blob = Buffer.from(obj);
break;
case 'batchsize':
batchsize = parseInt(obj.toString(), 10);
break;
case 'minbatchsize':
minbatchsize = parseInt(obj.toString(), 10);
break;
case 'inputs':
// tslint:disable-next-line:prefer-for-of
for (let j = 0; j < obj.length; j++) {
inputs.push(obj[j].toString());
}
break;
case 'outputs':
// tslint:disable-next-line:prefer-for-of
for (let j = 0; j < obj.length; j++) {
outputs.push(obj[j].toString());
}
break;
}
}
if (backend == null || device == null || blob == null) {
@@ -126,10 +177,37 @@
'AI.MODELGET reply did not have the full elements to build the Model. Missing ' + missingArr.join(',') + '.',
);
}
const model = new Model(backend, device, [], [], blob);
const model = new Model(backend, device, inputs, outputs, blob, batchsize, minbatchsize);
if (tag !== null) {
model.tag = tag;
}
return model;
}

modelSetFlatArgs(keyName: string) {
const args: any[] = [keyName, this.backend.toString(), this.device];
if (this.tag !== undefined) {
args.push('TAG');
args.push(this.tag.toString());
}
if (this.batchsize > 0) {
args.push('BATCHSIZE');
args.push(this.batchsize);
if (this.minbatchsize > 0) {
args.push('MINBATCHSIZE');
args.push(this.minbatchsize);
}
}
if (this.inputs.length > 0) {
args.push('INPUTS');
this.inputs.forEach((value) => args.push(value));
}
if (this.outputs.length > 0) {
args.push('OUTPUTS');
this.outputs.forEach((value) => args.push(value));
}
args.push('BLOB');
args.push(this.blob);
return args;
}
}
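The new modelSetFlatArgs helper now owns the AI.MODELSET argument flattening that Client.modelset previously did inline (see the src/client.ts change above). The sketch below illustrates the argument layout it produces for a TensorFlow model with batching enabled; the import path, key name, tag, and node names are placeholders, not part of this PR:

```ts
import * as fs from 'fs';
import { Model, Backend } from 'redisai-js'; // assumed entry point

const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], modelBlob, 100, 5);
model.tag = 'v1';

// Flattened AI.MODELSET arguments, following the ordering in modelSetFlatArgs above:
// ['mymodel', model.backend.toString(), 'CPU',
//  'TAG', 'v1',
//  'BATCHSIZE', 100, 'MINBATCHSIZE', 5,
//  'INPUTS', 'a', 'b',
//  'OUTPUTS', 'c',
//  'BLOB', modelBlob]
const args = model.modelSetFlatArgs('mymodel');
```

MINBATCHSIZE is only appended when batchsize is greater than 0, matching the nesting of the two pushes in the method body.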
114 changes: 89 additions & 25 deletions tests/test_client.ts
@@ -331,13 +331,77 @@ it(
const aiclient = new Client(nativeClient);

const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
const model = new Model(Backend.TF, 'CPU', ['a', 'b'], ['c'], modelBlob);
const inputs: string[] = ['a', 'b'];
const outputs: string[] = ['c'];
const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
model.tag = 'test_tag';
const resultModelSet = await aiclient.modelset('mymodel', model);
expect(resultModelSet).to.equal('OK');

const modelOut = await aiclient.modelget('mymodel');
const modelOut: Model = await aiclient.modelget('mymodel');
expect(modelOut.blob.toString()).to.equal(modelBlob.toString());
for (let index = 0; index < modelOut.outputs.length; index++) {
expect(modelOut.outputs[index]).to.equal(outputs[index]);
expect(modelOut.outputs[index]).to.equal(model.outputs[index]);
}
for (let index = 0; index < modelOut.inputs.length; index++) {
expect(modelOut.inputs[index]).to.equal(inputs[index]);
expect(modelOut.inputs[index]).to.equal(model.inputs[index]);
}
expect(modelOut.batchsize).to.equal(model.batchsize);
expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
aiclient.end(true);
}),
);

it(
'ai.modelget batching positive testing',
mochaAsync(async () => {
const nativeClient = createClient();
const aiclient = new Client(nativeClient);

const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
const inputs: string[] = ['a', 'b'];
const outputs: string[] = ['c'];
const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob);
model.tag = 'test_tag';
model.batchsize = 100;
model.minbatchsize = 5;
const resultModelSet = await aiclient.modelset('mymodel-batching', model);
expect(resultModelSet).to.equal('OK');
const modelOut: Model = await aiclient.modelget('mymodel-batching');
const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop', modelOut);
expect(resultModelSet2).to.equal('OK');
const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop');
expect(modelOut.batchsize).to.equal(model.batchsize);
expect(modelOut.minbatchsize).to.equal(model.minbatchsize);
aiclient.end(true);
}),
);

it(
'ai.modelget batching via constructor positive testing',
mochaAsync(async () => {
const nativeClient = createClient();
const aiclient = new Client(nativeClient);

const modelBlob: Buffer = fs.readFileSync('./tests/test_data/graph.pb');
const inputs: string[] = ['a', 'b'];
const outputs: string[] = ['c'];
const model = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 100, 5);
model.tag = 'test_tag';
const resultModelSet = await aiclient.modelset('mymodel-batching-t2', model);
expect(resultModelSet).to.equal('OK');
const modelOut: Model = await aiclient.modelget('mymodel-batching-t2');
const resultModelSet2 = await aiclient.modelset('mymodel-batching-loop-t2', modelOut);
expect(resultModelSet2).to.equal('OK');
const modelOut2: Model = await aiclient.modelget('mymodel-batching-loop-t2');
expect(modelOut.batchsize).to.equal(model.batchsize);
expect(modelOut.minbatchsize).to.equal(model.minbatchsize);

const model2 = new Model(Backend.TF, 'CPU', inputs, outputs, modelBlob, 1000);
expect(model2.batchsize).to.equal(1000);
expect(model2.minbatchsize).to.equal(0);
aiclient.end(true);
}),
);
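The round-trip tests above depend on NewModelFromModelGetReply rebuilding a Model from the raw AI.MODELGET reply. Below is a sketch of the flat key/value array the parser walks in pairs; the keys match the switch cases in src/model.ts, while the values and import path are illustrative:

```ts
import { Model } from 'redisai-js'; // assumed entry point

// Hypothetical reply shaped like an AI.MODELGET response; values are placeholders.
const exampleReply: any[] = [
  'backend', 'TF',
  'device', 'CPU',
  'tag', 'test_tag',
  'batchsize', '100',
  'minbatchsize', '5',
  'inputs', ['a', 'b'],
  'outputs', ['c'],
  'blob', Buffer.from([0x0a, 0x0b]), // placeholder bytes
];

const rebuilt = Model.NewModelFromModelGetReply(exampleReply);
// rebuilt.batchsize === 100, rebuilt.minbatchsize === 5
// rebuilt.inputs -> ['a', 'b'], rebuilt.outputs -> ['c']
```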
@@ -624,26 +688,26 @@ it(
);

it(
'ai.config positive and negative testing',
mochaAsync(async () => {
const nativeClient = createClient();
const aiclient = new Client(nativeClient);
const result = await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
expect(result).to.equal('OK');
// negative test
try {
const loadReply = await aiclient.configLoadBackend(Backend.TF, 'notexist/redisai_tensorflow.so');
} catch (e) {
expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
}

try {
// may throw error if backend already loaded
const loadResult = await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
expect(loadResult).to.equal('OK');
} catch (e) {
expect(e.toString()).to.equal('ReplyError: ERR error loading backend');
}
aiclient.end(true);
}),
);
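The ai.config test above exercises both README mappings for AI.CONFIG. A condensed usage sketch follows, assuming the package entry point, the node_redis wiring, and the backend paths (none of which are fixed by this PR):

```ts
import * as redis from 'redis';
import { Client, Backend } from 'redisai-js'; // assumed entry point

async function configureBackends(): Promise<void> {
  const aiclient = new Client(redis.createClient());
  // AI.CONFIG BACKENDSPATH <path>: directory where backend libraries live (path is illustrative)
  const pathReply = await aiclient.configBackendsPath('/usr/lib/redis/modules/backends/');
  // AI.CONFIG LOADBACKEND TF <path>: load the TensorFlow backend, resolved relative to BACKENDSPATH
  const loadReply = await aiclient.configLoadBackend(Backend.TF, 'redisai_tensorflow/redisai_tensorflow.so');
  console.log(pathReply, loadReply); // both 'OK' on success
  aiclient.end(true);
}
```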