# Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Optional

import mlrun.errors
from mlrun import get_current_project, code_to_function, mlconf
from mlrun.runtimes import ServingRuntime
from mlrun.serving import ModelRunnerStep
from mlrun.datastore.datastore_profile import (
DatastoreProfileV3io,
DatastoreProfileKafkaStream,
DatastoreProfileTDEngine,
)
from mlrun.utils import logger


class AgentDeployer:
def __init__(
self,
agent_name: str,
model_class_name: str,
function: str,
result_path: Optional[str] = None,
inputs_path: Optional[str] = None,
outputs: Optional[list[str]] = None,
requirements: Optional[list[str]] = None,
image: str = "mlrun/mlrun",
set_model_monitoring: bool = False,
**model_params,
):
"""
Class to deploy an agent as a serving function in MLRun.
:param agent_name: Name of the agent
:param model_class_name: Model class name. If LLModel is chosen
(either by name `LLModel` or by its full path, e.g. mlrun.serving.states.LLModel),
outputs will be overridden with UsageResponseKeys fields.
:param function: Path to the function file.
        :param result_path: When specified, selects the key/path in the output event to use as the
                            model monitoring outputs. This requires the output event body to behave
                            like a dict; nested keys are addressed with dot notation (e.g. "data.d").
        :param inputs_path: When specified, selects the key/path in the event to use as the model
                            monitoring inputs. This requires the event body to behave like a dict;
                            nested keys are addressed with dot notation (e.g. "data.d").
        :param outputs: List of the model outputs (e.g. labels). If provided, it overrides the outputs
                        configured in the model artifact. Note that these outputs must match the
                        model_class predict method outputs in length and order.
:param requirements: List of additional requirements for the function
:param image: Docker image to be used for the function
:param set_model_monitoring: Whether to configure model monitoring
:param model_params: Parameters for model instantiation
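
        Example (a minimal sketch; the agent name, file path, and parameter values below are
        illustrative)::

            deployer = AgentDeployer(
                agent_name="my-agent",
                model_class_name="LLModel",
                function="./agent_serving.py",
                set_model_monitoring=True,
            )
            serving_function = deployer.deploy_function(enable_tracking=True)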
"""
self._function = None
self._project = None
self._project_name = None
self.agent_name = agent_name
self.model_class_name = model_class_name
self.function_file = function
self.requirements = requirements or []
self.model_params = model_params or {}
self.result_path = result_path
self.inputs_path = inputs_path
self.output_schema = outputs
self.image = image
if set_model_monitoring:
self.configure_model_monitoring()

    def configure_model_monitoring(self):
"""Configure model monitoring for the active project."""
if not self.project:
raise mlrun.errors.MLRunInvalidArgumentError(
"No active project detected, unable to set model monitoring"
)
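        # Register datastore profiles for the monitoring infrastructure. In CE (community edition)
        # deployments, the TSDB is TDEngine and the stream is Kafka, both reached via in-cluster
        # service DNS; otherwise V3IO profiles are used for both.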
if mlconf.is_ce_mode():
mlrun_namespace = os.environ.get("MLRUN_NAMESPACE", "mlrun")
tsdb_profile = DatastoreProfileTDEngine(
name="tdengine-tsdb-profile",
user="root",
password="taosdata",
host=f"tdengine-tsdb.{mlrun_namespace}.svc.cluster.local",
port="6041",
)
stream_profile = DatastoreProfileKafkaStream(
name="kafka-stream-profile",
brokers=f"kafka-stream.{mlrun_namespace}.svc.cluster.local:9092",
topics=[],
)
else:
tsdb_profile = DatastoreProfileV3io(
name="v3io-tsdb-profile",
v3io_access_key=mlconf.get_v3io_access_key(),
)
stream_profile = DatastoreProfileV3io(
name="v3io-stream-profile",
v3io_access_key=mlconf.get_v3io_access_key(),
)
self.project.register_datastore_profile(tsdb_profile)
self.project.register_datastore_profile(stream_profile)
self.project.set_model_monitoring_credentials(
stream_profile_name=stream_profile.name,
tsdb_profile_name=tsdb_profile.name,
replace_creds=True,
)
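        # If model monitoring is already enabled for the project, the backend may answer with a
        # conflict/HTTP error; this is expected and only logged.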
try:
self.project.enable_model_monitoring(
base_period=10, deploy_histogram_data_drift_app=False
)
except (mlrun.errors.MLRunConflictError, mlrun.errors.MLRunHTTPError) as e:
logger.info(
"While calling enable_model_monitoring, caught expected exception:",
error=str(e),
)

    @property
def project(self):
"""Get the current MLRun project."""
if self._project:
return self._project
self._project = get_current_project(silent=True)
return self._project

    @property
def project_name(self):
"""Get the name of the current MLRun project."""
if self._project_name:
return self._project_name
if self.project:
self._project_name = self.project.metadata.name
return self._project_name
raise mlrun.errors.MLRunInvalidArgumentError(
"No current project found to get project name"
)

    def get_function(self) -> ServingRuntime:
"""
Get the serving function, loading it if necessary.
"""
if self._function is None:
self._load_function()
return self._function

    def deploy_function(self, enable_tracking: bool) -> ServingRuntime:
"""
Deploy the agent as a serving function in MLRun.
        :param enable_tracking: Whether to enable tracking for the function.
        :return: The deployed serving function.
        """
function = self.get_function()
function.set_tracking(enable_tracking=enable_tracking)
function.deploy()
return function

    def _load_function(self) -> ServingRuntime:
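        # Package the agent's code file as an MLRun serving function with the requested image
        # and requirements.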
self._function = code_to_function(
name=f"{self.agent_name}_serving_function",
filename=self.function_file,
project=self.project_name,
kind="serving",
image=self.image,
requirements=self.requirements,
)
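        # Build an async "flow" topology whose single step is a ModelRunnerStep wrapping the agent
        # model; the step's output is returned as the response.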
graph = self._function.set_topology(topology="flow", engine="async")
model_runner_step = ModelRunnerStep()
model_runner_step.add_model(
model_class=self.model_class_name,
endpoint_name=self.agent_name,
result_path=self.result_path,
input_path=self.inputs_path,
outputs=self.output_schema,
execution_mechanism="naive",
**self.model_params,
)
graph.to(model_runner_step).respond()
return self._function