
Commit

Update inference stage to emit the proper output messages [no ci]
dagardner-nv committed Dec 20, 2023
1 parent a52e534 commit b85776c
Showing 2 changed files with 22 additions and 24 deletions.
40 changes: 19 additions & 21 deletions examples/log_parsing/inference.py
@@ -24,7 +24,7 @@
from morpheus.config import PipelineModes
from morpheus.messages import InferenceMemory
from morpheus.messages import MultiInferenceMessage
from morpheus.messages import MultiInferenceNLPMessage
from morpheus.messages import MultiResponseMessage
from morpheus.messages import ResponseMemory
from morpheus.pipeline.stage_schema import StageSchema
from morpheus.stages.inference.triton_inference_stage import TritonInferenceStage
@@ -57,12 +57,12 @@ class TritonInferenceLogParsing(TritonInferenceWorker):
Determines whether a logits calculation is needed for the value returned by the Triton inference response.
"""

def build_output_message(self, x: MultiInferenceMessage) -> MultiInferenceMessage:
def build_output_message(self, x: MultiInferenceMessage) -> MultiResponseMessage:
seq_ids = cp.zeros((x.count, 3), dtype=cp.uint32)
seq_ids[:, 0] = cp.arange(x.mess_offset, x.mess_offset + x.count, dtype=cp.uint32)
seq_ids[:, 2] = x.seq_ids[:, 2]

memory = InferenceMemory(
memory = ResponseMemory(
count=x.count,
tensors={
'confidences': cp.zeros((x.count, self._inputs[list(self._inputs.keys())[0]].shape[1])),
@@ -71,12 +71,12 @@ def build_output_message(self, x: MultiInferenceMessage) -> MultiInferenceMessage:
'seq_ids': seq_ids
})

return MultiInferenceMessage(meta=x.meta,
mess_offset=x.mess_offset,
mess_count=x.mess_count,
memory=memory,
offset=0,
count=x.count)
return MultiResponseMessage(meta=x.meta,
mess_offset=x.mess_offset,
mess_count=x.mess_count,
memory=memory,
offset=0,
count=x.count)

def _build_response(self, batch: MultiInferenceMessage, result: tritonclient.InferResult) -> ResponseMemory:

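With this change build_output_message() pre-allocates the response tensors up front and hands back a MultiResponseMessage instead of a MultiInferenceMessage. The sketch below is not part of the commit; it only restates the shape of the container being built, using the constructors visible in the hunks above. The seq_len and hidden_dim placeholders and the exact set of tensor names are assumptions (the stage derives the widths from the Triton model inputs, and part of the tensors dict is collapsed in this view).

    # Sketch only: an empty response container like the one build_output_message() now returns.
    # `seq_len` and `hidden_dim` are hypothetical placeholders, not values from the commit.
    import cupy as cp

    from morpheus.messages import MultiResponseMessage
    from morpheus.messages import ResponseMemory


    def make_empty_response(meta, mess_offset: int, mess_count: int, count: int,
                            seq_len: int, hidden_dim: int) -> MultiResponseMessage:
        # Column 0 of seq_ids records which original message row each inference row belongs to.
        seq_ids = cp.zeros((count, 3), dtype=cp.uint32)
        seq_ids[:, 0] = cp.arange(mess_offset, mess_offset + count, dtype=cp.uint32)

        memory = ResponseMemory(count=count,
                                tensors={
                                    'confidences': cp.zeros((count, hidden_dim)),
                                    'labels': cp.zeros((count, hidden_dim)),
                                    'input_ids': cp.zeros((count, seq_len)),
                                    'seq_ids': seq_ids
                                })

        return MultiResponseMessage(meta=meta,
                                    mess_offset=mess_offset,
                                    mess_count=mess_count,
                                    memory=memory,
                                    offset=0,
                                    count=count)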
@@ -140,32 +140,30 @@ def supports_cpp_node(self) -> bool:
return False

def compute_schema(self, schema: StageSchema):
schema.output_schema.set_type(MultiInferenceMessage)
schema.output_schema.set_type(MultiResponseMessage)

@staticmethod
def _convert_one_response(output: MultiInferenceMessage, inf: MultiInferenceNLPMessage,
res: ResponseMemory) -> MultiInferenceMessage:
def _convert_one_response(output: MultiResponseMessage, inf: MultiInferenceMessage,
res: ResponseMemory) -> MultiResponseMessage:

output.get_input('input_ids')[inf.offset:inf.count + inf.offset, :] = inf.input_ids
output.get_input('seq_ids')[inf.offset:inf.count + inf.offset, :] = inf.seq_ids
output.input_ids[inf.offset:inf.count + inf.offset, :] = inf.input_ids
output.seq_ids[inf.offset:inf.count + inf.offset, :] = inf.seq_ids

# Two scenarios:
if (inf.mess_count == inf.count):
output.get_input('confidences')[inf.offset:inf.count + inf.offset, :] = res.get_output('confidences')
output.get_input('labels')[inf.offset:inf.count + inf.offset, :] = res.get_output('labels')
output.confidences[inf.offset:inf.count + inf.offset, :] = res.get_output('confidences')
output.labels[inf.offset:inf.count + inf.offset, :] = res.get_output('labels')
else:
assert inf.count == res.count

mess_ids = inf.seq_ids[:, 0].get().tolist()

# Out message has more responses, so we have to do key based blending of probs
for i, idx in enumerate(mess_ids):
output.get_input('confidences')[idx, :] = cp.maximum(
output.get_input('confidences')[idx, :], res.get_output('confidences')[i, :])
output.get_input('labels')[idx, :] = cp.maximum(
output.get_input('labels')[idx, :], res.get_output('labels')[i, :])
output.confidences[idx, :] = cp.maximum(output.confidences[idx, :], res.get_output('confidences')[i, :])
output.labels[idx, :] = cp.maximum(output.labels[idx, :], res.get_output('labels')[i, :])

return MultiInferenceMessage.from_message(inf, memory=output.memory, offset=inf.offset, count=inf.mess_count)
return MultiResponseMessage.from_message(inf, memory=output.memory, offset=inf.offset, count=inf.mess_count)

def _get_worker_class(self) -> type[TritonInferenceWorker]:
return TritonInferenceLogParsing
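The rewritten _convert_one_response() writes through the MultiResponseMessage tensor attributes directly (output.confidences, output.labels) instead of get_input() lookups. When a single source row was split across several inference rows (inf.mess_count != inf.count), the results are folded back with an element-wise maximum keyed on the original row index stored in seq_ids[:, 0]. A standalone illustration of that blending with made-up CuPy values, independent of the Morpheus message classes:

    import cupy as cp

    # Made-up example: two inference rows that both map back to message row 0.
    confidences = cp.zeros((1, 3))             # one output row per original message row
    res_conf = cp.asarray([[0.2, 0.9, 0.1],    # response row 0
                           [0.5, 0.4, 0.3]])   # response row 1
    mess_ids = [0, 0]                          # seq_ids[:, 0] for the two response rows

    for i, idx in enumerate(mess_ids):
        confidences[idx, :] = cp.maximum(confidences[idx, :], res_conf[i, :])

    print(confidences)  # [[0.5 0.9 0.3]] -- element-wise max across the two rows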
6 changes: 3 additions & 3 deletions examples/log_parsing/postprocessing.py
@@ -26,7 +26,7 @@
from morpheus.config import Config
from morpheus.config import PipelineModes
from morpheus.messages import MessageMeta
from morpheus.messages import MultiInferenceMessage
from morpheus.messages import MultiResponseMessage
from morpheus.pipeline.single_port_stage import SinglePortStage
from morpheus.pipeline.stage_schema import StageSchema

@@ -73,12 +73,12 @@ def supports_cpp_node(self):
return False

def accepted_types(self) -> typing.Tuple:
return (MultiInferenceMessage, )
return (MultiResponseMessage, )

def compute_schema(self, schema: StageSchema):
schema.output_schema.set_type(MessageMeta)

def _postprocess(self, x: MultiInferenceMessage):
def _postprocess(self, x: MultiResponseMessage):

infer_pdf = pd.DataFrame(x.seq_ids.get()).astype(int)
infer_pdf.columns = ["doc", "start", "stop"]
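Downstream, the post-processing stage now accepts the MultiResponseMessage emitted by the inference stage; _postprocess() copies the seq_ids tensor back to the host and frames it as doc/start/stop columns. A minimal illustration of that conversion with synthetic values (the column names come from the hunk above; the numbers are invented):

    import cupy as cp
    import pandas as pd

    # Synthetic stand-in for x.seq_ids: one row per inference result.
    seq_ids = cp.asarray([[0, 0, 128],
                          [1, 0, 128]], dtype=cp.uint32)

    infer_pdf = pd.DataFrame(seq_ids.get()).astype(int)
    infer_pdf.columns = ["doc", "start", "stop"]
    print(infer_pdf)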
