Skip to content

Commit

Permalink
feat(controller): support online eval in fine-tune space (#3031)
Browse files Browse the repository at this point in the history
  • Loading branch information
goldenxinxing authored Nov 27, 2023
1 parent 94fe64d commit 908f6d9
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 5 deletions.
8 changes: 8 additions & 0 deletions client/starwhale/base/client/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class BizType(Enum):

class Type1(Enum):
evaluation = 'EVALUATION'
online_eval = 'ONLINE_EVAL'
train = 'TRAIN'
fine_tune = 'FINE_TUNE'
serving = 'SERVING'
Expand Down Expand Up @@ -834,6 +835,7 @@ class ExposedLinkVo(SwBaseModel):

class JobType(Enum):
evaluation = 'EVALUATION'
online_eval = 'ONLINE_EVAL'
train = 'TRAIN'
fine_tune = 'FINE_TUNE'
serving = 'SERVING'
Expand Down Expand Up @@ -1614,6 +1616,12 @@ class ResponseMessageGraph(SwBaseModel):
data: Graph


class ResponseMessageListJobVo(SwBaseModel):
code: str
message: str
data: List[JobVo]


class FineTuneVo(SwBaseModel):
id: int
job: JobVo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ai.starwhale.mlops.api.protocol.ft.FineTuneMigrationRequest;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceCreateRequest;
import ai.starwhale.mlops.api.protocol.ft.FineTuneSpaceVo;
import ai.starwhale.mlops.api.protocol.job.JobVo;
import ai.starwhale.mlops.api.protocol.model.ModelViewVo;
import ai.starwhale.mlops.common.IdConverter;
import ai.starwhale.mlops.configuration.FeaturesProperties;
Expand Down Expand Up @@ -156,6 +157,19 @@ public ResponseEntity<ResponseMessage<PageInfo<FineTuneVo>>> listFineTune(
return ResponseEntity.ok(Code.success.asResponse(pageInfo));
}

@Operation(summary = "List online eval")
@GetMapping(
value = "/project/{projectId}/ftspace/{spaceId}/online-eval",
produces = MediaType.APPLICATION_JSON_VALUE
)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')")
public ResponseEntity<ResponseMessage<List<JobVo>>> listOnlineEval(
@PathVariable("projectId") Long projectId,
@PathVariable("spaceId") Long spaceId
) {
return ResponseEntity.ok(Code.success.asResponse(fineTuneAppService.listOnlineEval(projectId, spaceId)));
}

@Operation(summary = "Get fine-tune info")
@GetMapping(value = "/project/{projectId}/ftspace/{spaceId}/ft/{ftId}", produces = MediaType.APPLICATION_JSON_VALUE)
@PreAuthorize("hasAnyRole('OWNER', 'MAINTAINER', 'GUEST')")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ public ResponseEntity<ResponseMessage<String>> createJob(
jobId = fineTuneAppService.createFineTune(projectUrl, spaceId, jobRequest);
} else if (jobRequest.getType() == JobType.EVALUATION) {
jobId = fineTuneAppService.createEvaluationJob(projectUrl, spaceId, jobRequest);
} else if (jobRequest.getType() == JobType.ONLINE_EVAL) {
jobId = jobServiceForWeb.createJob(userJobConverter.convert(projectUrl, jobRequest));
}
} else {
jobId = jobServiceForWeb.createJob(userJobConverter.convert(projectUrl, jobRequest));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import static ai.starwhale.mlops.domain.evaluation.EvaluationService.TABLE_NAME_FORMAT;

import ai.starwhale.mlops.api.protocol.job.JobRequest;
import ai.starwhale.mlops.api.protocol.job.JobVo;
import ai.starwhale.mlops.api.protocol.model.ModelViewVo;
import ai.starwhale.mlops.api.protocol.model.ModelVo;
import ai.starwhale.mlops.common.Constants;
Expand Down Expand Up @@ -431,6 +432,18 @@ public void releaseFt(
}
}

public List<JobVo> listOnlineEval(Long projectId, Long spaceId) {
var onlineEvaluations = jobMapper.listBizJobs(
projectId,
BizType.FINE_TUNE.name(),
String.valueOf(spaceId),
JobType.ONLINE_EVAL.name(),
null
);
return onlineEvaluations.stream()
.map(jobConverter::convert).collect(Collectors.toList());
}

public List<ModelViewVo> listModelVersionView(Long projectId, Long spaceId) {
return modelService.listFtSpaceModelVersionView(String.valueOf(projectId), spaceId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
package ai.starwhale.mlops.domain.job;

public enum JobType {
EVALUATION, TRAIN, FINE_TUNE, SERVING, BUILT_IN
EVALUATION, ONLINE_EVAL, TRAIN, FINE_TUNE, SERVING, BUILT_IN
}
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ public JobConverter(
this.webServerInTask = webServerInTask;
}

private ModelVo findModelByVersionIds(Long versionId) {
private ModelVo findModelByVersionId(Long versionId) {
if (null == versionId) {
return null;
}
Expand All @@ -99,7 +99,7 @@ private ModelVo findModelByVersionIds(Long versionId) {
return modelVos.get(0);
}

private RuntimeVo findRuntimeByVersionIds(Long versionId) {
private RuntimeVo findRuntimeByVersionId(Long versionId) {
if (null == versionId) {
return null;
}
Expand Down Expand Up @@ -186,7 +186,7 @@ private List<ExposedLinkVo> generateJobExposedLinks(Long jobId) {
}

public JobVo convert(JobEntity jobEntity) throws ConvertException {
var runtime = findRuntimeByVersionIds(jobEntity.getRuntimeVersionId());
var runtime = findRuntimeByVersionId(jobEntity.getRuntimeVersionId());
var datasetList = findDatasetVersionsByJobId(jobEntity.getId());
Long pinnedTime = jobEntity.getPinnedTime() != null ? jobEntity.getPinnedTime().getTime() : null;

Expand All @@ -197,7 +197,7 @@ public JobVo convert(JobEntity jobEntity) throws ConvertException {
.owner(UserVo.fromEntity(jobEntity.getOwner(), idConvertor))
.modelName(jobEntity.getModelName())
.modelVersion(jobEntity.getModelVersion().getVersionName())
.model(findModelByVersionIds(jobEntity.getModelVersionId()))
.model(findModelByVersionId(jobEntity.getModelVersionId()))
.jobName(extractJobName(jobEntity.getStepSpec()))
.createdTime(jobEntity.getCreatedTime().getTime())
.runtime(runtime)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@
import ai.starwhale.mlops.domain.ft.po.FineTuneEntity;
import ai.starwhale.mlops.domain.ft.po.FineTuneSpaceEntity;
import ai.starwhale.mlops.domain.ft.vo.FineTuneVo;
import ai.starwhale.mlops.domain.job.BizType;
import ai.starwhale.mlops.domain.job.JobCreator;
import ai.starwhale.mlops.domain.job.JobType;
import ai.starwhale.mlops.domain.job.bo.Job;
import ai.starwhale.mlops.domain.job.bo.UserJobCreateRequest;
import ai.starwhale.mlops.domain.job.converter.JobConverter;
Expand Down Expand Up @@ -174,6 +176,14 @@ void listFt() {
assertEquals(1, fineTuneAppService.list(1L, 1, 1).getSize());
}

@Test
void listFtOnlineEval() {
when(jobMapper.listBizJobs(1L, BizType.FINE_TUNE.name(), "1", JobType.ONLINE_EVAL.name(), null))
.thenReturn(List.of(JobEntity.builder().id(1L).build(), JobEntity.builder().id(2L).build()));
when(jobConverter.convert(any())).thenReturn(JobVo.builder().build());
assertEquals(2, fineTuneAppService.listOnlineEval(1L, 1L).size());
}

@Test
void ftInfo() {
when(fineTuneMapper.findById(anyLong(), anyLong())).thenReturn(FineTuneEntity.builder().jobId(1L).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ public void testListBizJobs() {
jobs = jobMapper.listBizJobs(project.getId(), BizType.FINE_TUNE.name(), null, JobType.EVALUATION.name(), null);
Assertions.assertEquals(1, jobs.size());

jobs = jobMapper.listBizJobs(project.getId(), BizType.FINE_TUNE.name(), null, JobType.ONLINE_EVAL.name(), null);
Assertions.assertEquals(0, jobs.size());

jobs = jobMapper.listBizJobs(
project.getId(),
BizType.FINE_TUNE.name(),
Expand Down

0 comments on commit 908f6d9

Please sign in to comment.