Skip to content

Commit

Permalink
proper handle of lang cols
Browse files Browse the repository at this point in the history
  • Loading branch information
mplatzer committed Feb 19, 2025
1 parent a5a5ab4 commit 7704136
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
6 changes: 3 additions & 3 deletions mostlyai/sdk/_data/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import pandas as pd

from mostlyai.sdk.domain import ModelEncodingType
from mostlyai.sdk.domain import ModelType

from mostlyai.sdk._data.base import Schema
from mostlyai.sdk._data.util.common import TABLE_COLUMN_INFIX, TEMPORARY_PRIMARY_KEY
Expand All @@ -37,7 +37,7 @@ def split_language_model(
:return: ctx_data, tgt_data
"""
enctypes = schema.tables[tgt].encoding_types
language_cols = [col for col in enctypes if enctypes[col] == ModelEncodingType.language_text]
language_cols = [col for col in enctypes if enctypes[col].startswith(ModelType.language)]
if len(language_cols) == 0:
# if no LANGUAGE columns are present, then leave data as-is
return ctx_data, tgt_data
Expand Down Expand Up @@ -88,7 +88,7 @@ def drop_language_columns_in_target(
tgt_table = schema.tables[tgt]
drop_columns = []
for col_name, encoding_type in tgt_table.encoding_types.items():
if encoding_type == ModelEncodingType.language_text:
if encoding_type.startswith(ModelType.language):
drop_columns.append(col_name)
if drop_columns:
_LOG.info(f"drop LANGUAGE columns from target: {drop_columns}")
Expand Down
14 changes: 11 additions & 3 deletions mostlyai/sdk/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

# generated by datamodel-codegen:
# timestamp: 2025-02-18T14:51:31+00:00
# timestamp: 2025-02-19T08:34:06+00:00

from __future__ import annotations

Expand Down Expand Up @@ -476,6 +476,10 @@ class RareCategoryReplacementMethod(str, Enum):


class TaskType(str, Enum):
"""
The type of the task.
"""

sync = "SYNC"
train_tabular = "TRAIN_TABULAR"
train_language = "TRAIN_LANGUAGE"
Expand All @@ -489,6 +493,10 @@ class TaskType(str, Enum):


class StepCode(str, Enum):
"""
The unique code for the step.
"""

pull_training_data = "PULL_TRAINING_DATA"
analyze_training_data = "ANALYZE_TRAINING_DATA"
encode_training_data = "ENCODE_TRAINING_DATA"
Expand Down Expand Up @@ -2163,8 +2171,8 @@ def add_model_configuration(cls, values):
keys.append(values.primary_key)
model_columns = [c for c in values.columns if c.name not in keys]
enc_types = [c.model_encoding_type or ModelEncodingType.auto for c in model_columns]
has_tabular_model = any(not enc_type.startswith("LANGUAGE_") for enc_type in enc_types)
has_language_model = any(enc_type.startswith("LANGUAGE_") for enc_type in enc_types)
has_tabular_model = any(not enc_type.startswith(ModelType.language) for enc_type in enc_types)
has_language_model = any(enc_type.startswith(ModelType.language) for enc_type in enc_types)
else:
has_tabular_model = True
has_language_model = False
Expand Down
4 changes: 2 additions & 2 deletions tools/custom_template/pydantic_v2/BaseModel.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,8 @@ class {{ class_name }}({{ base_class }}):{% if comment is defined %} # {{ comme
keys.append(values.primary_key)
model_columns = [c for c in values.columns if c.name not in keys]
enc_types = [c.model_encoding_type or ModelEncodingType.auto for c in model_columns]
has_tabular_model = any(not enc_type.startswith("LANGUAGE_") for enc_type in enc_types)
has_language_model = any(enc_type.startswith("LANGUAGE_") for enc_type in enc_types)
has_tabular_model = any(not enc_type.startswith(ModelType.language) for enc_type in enc_types)
has_language_model = any(enc_type.startswith(ModelType.language) for enc_type in enc_types)
else:
has_tabular_model = True
has_language_model = False
Expand Down
5 changes: 3 additions & 2 deletions tools/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ProgressStatus,
ModelConfiguration,
SyntheticDatasetReportType,
ModelType,
)


Expand Down Expand Up @@ -425,8 +426,8 @@ def add_model_configuration(cls, values):
keys.append(values.primary_key)
model_columns = [c for c in values.columns if c.name not in keys]
enc_types = [c.model_encoding_type or ModelEncodingType.auto for c in model_columns]
has_tabular_model = any(not enc_type.startswith("LANGUAGE_") for enc_type in enc_types)
has_language_model = any(enc_type.startswith("LANGUAGE_") for enc_type in enc_types)
has_tabular_model = any(not enc_type.startswith(ModelType.language) for enc_type in enc_types)
has_language_model = any(enc_type.startswith(ModelType.language) for enc_type in enc_types)
else:
has_tabular_model = True
has_language_model = False
Expand Down

0 comments on commit 7704136

Please sign in to comment.