Skip to content

Commit

Permalink
predict: formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
koush committed Nov 21, 2024
1 parent c6f4c1a commit cd0ab10
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 14 deletions.
41 changes: 28 additions & 13 deletions plugins/onnx/src/ort/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,20 @@
"scrypted_yolov8n_320",
]


def parse_labels(names):
j = ast.literal_eval(names)
ret = {}
for k, v in j.items():
ret[int(k)] = v
return ret


class ONNXPlugin(
PredictPlugin, scrypted_sdk.BufferConverter, scrypted_sdk.Settings, scrypted_sdk.DeviceProvider
PredictPlugin,
scrypted_sdk.BufferConverter,
scrypted_sdk.Settings,
scrypted_sdk.DeviceProvider,
):
def __init__(self, nativeId: str | None = None, forked: bool = False):
super().__init__(nativeId=nativeId, forked=forked)
Expand All @@ -67,7 +72,11 @@ def __init__(self, nativeId: str | None = None, forked: bool = False):

print(f"model {model}")

onnxmodel = model if self.scrypted_yolo_nas else "best" if self.scrypted_model else model
onnxmodel = (
model
if self.scrypted_yolo_nas
else "best" if self.scrypted_model else model
)

model_version = "v3"
onnxfile = self.downloadFile(
Expand All @@ -92,22 +101,28 @@ def __init__(self, nativeId: str | None = None, forked: bool = False):
sess_options = onnxruntime.SessionOptions()

providers: list[str] = []
if sys.platform == 'darwin':
if sys.platform == "darwin":
providers.append("CoreMLExecutionProvider")

if ('linux' in sys.platform or 'win' in sys.platform) and (platform.machine() == 'x86_64' or platform.machine() == 'AMD64'):
if ("linux" in sys.platform or "win" in sys.platform) and (
platform.machine() == "x86_64" or platform.machine() == "AMD64"
):
deviceId = int(deviceId)
providers.append(("CUDAExecutionProvider", { "device_id": deviceId }))
providers.append(("CUDAExecutionProvider", {"device_id": deviceId}))

providers.append('CPUExecutionProvider')
providers.append("CPUExecutionProvider")

compiled_model = onnxruntime.InferenceSession(onnxfile, sess_options=sess_options, providers=providers)
compiled_model = onnxruntime.InferenceSession(
onnxfile, sess_options=sess_options, providers=providers
)
compiled_models.append(compiled_model)

input = compiled_model.get_inputs()[0]
self.model_dim = input.shape[2]
self.input_name = input.name
self.labels = parse_labels(compiled_model.get_modelmeta().custom_metadata_map['names'])
self.labels = parse_labels(
compiled_model.get_modelmeta().custom_metadata_map["names"]
)

except:
import traceback
Expand All @@ -130,7 +145,7 @@ def executor_initializer():
providers.remove("CPUExecutionProvider")
# join the remaining providers string
self.provider = ", ".join(providers)
print('Runtime initialized on thread {}'.format(thread_name))
print("Runtime initialized on thread {}".format(thread_name))

self.executor = concurrent.futures.ThreadPoolExecutor(
initializer=executor_initializer,
Expand Down Expand Up @@ -222,11 +237,11 @@ async def getSettings(self) -> list[Setting]:
"title": "Execution Device",
"readonly": True,
"value": self.provider,
}
},
]

async def putSetting(self, key: str, value: SettingValue):
if (key == 'deviceIds'):
if key == "deviceIds":
value = json.dumps(value)
self.storage.setItem(key, value)
await self.onDeviceEvent(scrypted_sdk.ScryptedInterface.Settings.value, None)
Expand All @@ -240,7 +255,7 @@ def get_input_size(self) -> Tuple[int, int]:
return [self.model_dim, self.model_dim]

async def detect_once(self, input: Image.Image, settings: Any, src_size, cvss):
def prepare():
def prepare():
im = np.array(input)
im = np.expand_dims(input, axis=0)
im = im.transpose((0, 3, 1, 2)) # BHWC to BCHW, (n, 3, h, w)
Expand All @@ -250,7 +265,7 @@ def prepare():

def predict(input_tensor):
compiled_model = self.compiled_models[threading.current_thread().name]
output_tensors = compiled_model.run(None, { self.input_name: input_tensor })
output_tensors = compiled_model.run(None, {self.input_name: input_tensor})
if self.scrypted_yolov10:
return yolo.parse_yolov10(output_tensors[0][0])
if self.scrypted_yolo_nas:
Expand Down
6 changes: 5 additions & 1 deletion plugins/openvino/src/ov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,11 @@ def __init__(self, nativeId: str | None = None, forked: bool = False):
self.sigmoid = model == "yolo-v4-tiny-tf"
self.modelName = model

ovmodel = "best-converted" if self.scrypted_yolov9 else "best" if self.scrypted_model else model
ovmodel = (
"best-converted"
if self.scrypted_yolov9
else "best" if self.scrypted_model else model
)

model_version = "v7"
xmlFile = self.downloadFile(
Expand Down

0 comments on commit cd0ab10

Please sign in to comment.