Skip to content

Commit 611dfe7

Browse files
committed
Main process manages models, child process loads models.
1 parent 6f1cad6 commit 611dfe7

File tree

11 files changed

+139
-156
lines changed

11 files changed

+139
-156
lines changed

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
import torch
2222

2323
from iotdb.ainode.core.exception import InferenceModelInternalError
24-
from iotdb.ainode.core.manager.model_manager import ModelManager
24+
from iotdb.ainode.core.model.model_loader import load_model
2525

2626

2727
class BasicPipeline(ABC):
28-
def __init__(self, model_id, **infer_kwargs):
29-
self.model_id = model_id
28+
def __init__(self, model_info, **infer_kwargs):
29+
self.model_info = model_info
3030
self.device = infer_kwargs.get("device", "cpu")
31-
self.model = ModelManager().load_model(model_id, device_map=self.device)
31+
self.model = load_model(model_info, device_map=self.device)
3232

3333
def _preprocess(self, inputs):
3434
"""
@@ -45,8 +45,8 @@ def _postprocess(self, output: torch.Tensor):
4545

4646

4747
class ForecastPipeline(BasicPipeline):
48-
def __init__(self, model_id, **infer_kwargs):
49-
super().__init__(model_id, infer_kwargs=infer_kwargs)
48+
def __init__(self, model_info, **infer_kwargs):
49+
super().__init__(model_info, **infer_kwargs)  # unpack so BasicPipeline's infer_kwargs.get("device", ...) sees the caller's device
5050

5151
def _preprocess(self, inputs):
5252
if len(inputs.shape) != 2:
@@ -63,8 +63,8 @@ def _postprocess(self, output: torch.Tensor):
6363

6464

6565
class ClassificationPipeline(BasicPipeline):
66-
def __init__(self, model_id, **infer_kwargs):
67-
super().__init__(model_id, infer_kwargs=infer_kwargs)
66+
def __init__(self, model_info, **infer_kwargs):
67+
super().__init__(model_info, **infer_kwargs)  # unpack so BasicPipeline's infer_kwargs.get("device", ...) sees the caller's device
6868

6969
def _preprocess(self, inputs):
7070
pass
@@ -80,8 +80,8 @@ def _postprocess(self, output: torch.Tensor):
8080

8181

8282
class ChatPipeline(BasicPipeline):
83-
def __init__(self, model_id, **infer_kwargs):
84-
super().__init__(model_id, infer_kwargs=infer_kwargs)
83+
def __init__(self, model_info, **infer_kwargs):
84+
super().__init__(model_info, **infer_kwargs)  # unpack so BasicPipeline's infer_kwargs.get("device", ...) sees the caller's device
8585

8686
def _preprocess(self, inputs):
8787
pass

iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,4 @@ def load_pipeline(model_info: ModelInfo, device: str, **kwargs):
4949
model_info.model_id, model_info.pipeline_cls
5050
)
5151

52-
return pipeline_cls(model_info.model_id, device=device)
52+
return pipeline_cls(model_info, device=device)

iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from iotdb.ainode.core.constant import TSStatusCode
2222
from iotdb.ainode.core.exception import BuiltInModelDeletionError
2323
from iotdb.ainode.core.log import Logger
24-
from iotdb.ainode.core.model.model_loader import ModelLoader
24+
from iotdb.ainode.core.model.model_loader import load_model
2525
from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage
2626
from iotdb.ainode.core.rpc.status import get_status
2727
from iotdb.ainode.core.util.decorator import singleton
@@ -41,7 +41,6 @@
4141
class ModelManager:
4242
def __init__(self):
4343
self._model_storage = ModelStorage()
44-
self._model_loader = ModelLoader(storage=self._model_storage)
4544

4645
def register_model(
4746
self,
@@ -75,7 +74,8 @@ def delete_model(self, req: TDeleteModelReq) -> TSStatus:
7574
return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
7675

7776
def load_model(self, model_id: str, **kwargs) -> Any:
78-
return self._model_loader.load_model(model_id=model_id, **kwargs)
77+
model_info = self.get_model_info(model_id)
78+
return load_model(model_info=model_info, **kwargs)
7979

8080
def get_model_info(
8181
self,

iotdb-core/ainode/iotdb/ainode/core/manager/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from iotdb.ainode.core.exception import ModelNotExistError
2626
from iotdb.ainode.core.log import Logger
2727
from iotdb.ainode.core.manager.model_manager import ModelManager
28+
from iotdb.ainode.core.model.model_loader import load_model
2829

2930
logger = Logger()
3031

@@ -46,7 +47,8 @@ def measure_model_memory(device: torch.device, model_id: str) -> int:
4647
torch.cuda.synchronize(device)
4748
start = torch.cuda.memory_reserved(device)
4849

49-
model = ModelManager().load_model(model_id).to(device)
50+
model_info = ModelManager().get_model_info(model_id)
51+
model = load_model(model_info).to(device)
5052
torch.cuda.synchronize(device)
5153
end = torch.cuda.memory_reserved(device)
5254
usage = end - start

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def __init__(
2828
category: ModelCategory,
2929
state: ModelStates,
3030
model_type: str = "",
31+
config_cls: str = "",
3132
model_cls: str = "",
3233
pipeline_cls: str = "",
3334
repo_id: str = "",
@@ -38,6 +39,7 @@ def __init__(
3839
self.model_type = model_type
3940
self.category = category
4041
self.state = state
42+
self.config_cls = config_cls
4143
self.model_cls = model_cls
4244
self.pipeline_cls = pipeline_cls
4345
self.repo_id = repo_id
@@ -112,6 +114,7 @@ def __repr__(self):
112114
category=ModelCategory.BUILTIN,
113115
state=ModelStates.INACTIVE,
114116
model_type="timer",
117+
config_cls="configuration_timer.TimerConfig",
115118
model_cls="modeling_timer.TimerForPrediction",
116119
pipeline_cls="pipeline_timer.TimerPipeline",
117120
repo_id="thuml/timer-base-84m",
@@ -121,6 +124,7 @@ def __repr__(self):
121124
category=ModelCategory.BUILTIN,
122125
state=ModelStates.INACTIVE,
123126
model_type="sundial",
127+
config_cls="configuration_sundial.SundialConfig",
124128
model_cls="modeling_sundial.SundialForPrediction",
125129
pipeline_cls="pipeline_sundial.SundialPipeline",
126130
repo_id="thuml/sundial-base-128m",

0 commit comments

Comments
 (0)