Skip to content

Commit 4452365

Browse files
committed
Compare based on keys on hash diff
1 parent f05b791 commit 4452365

File tree

4 files changed

+30
-18
lines changed

4 files changed

+30
-18
lines changed

src/lib/src/api/local/compare.rs

+4-2
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ pub fn compare_files(
4646
targets: Vec<String>,
4747
output: Option<PathBuf>,
4848
) -> Result<CompareResult, OxenError> {
49-
// Assert that the files exist in their respective commits and are tabular.
49+
// Assert that the files exist in their respective commits.
5050
let file_1 = get_version_file(repo, &compare_entry_1)?;
5151
let file_2 = get_version_file(repo, &compare_entry_2)?;
5252

@@ -414,7 +414,9 @@ fn compute_row_comparison(
414414
let mut dupes = CompareDupes { left: 0, right: 0 };
415415

416416
let dataframes = match strategy {
417-
CompareStrategy::Hash => hash_compare::compare(df_1, df_2, &schema_1, &schema_2)?,
417+
CompareStrategy::Hash => {
418+
hash_compare::compare(df_1, df_2, &schema_1, &schema_2, keys.to_owned())?
419+
}
418420
CompareStrategy::Join => {
419421
// TODO: unsure if hash comparison or join is faster here - would guess join, could use some testing
420422
let (df_1, df_2) = hash_dfs(

src/lib/src/api/local/compare/hash_compare.rs

+5-3
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@ pub fn compare(
1313
head_df: &DataFrame,
1414
schema_1: &Schema,
1515
schema_2: &Schema,
16+
keys: Vec<&str>,
1617
) -> Result<(DataFrame, DataFrame, DataFrame, DataFrame), OxenError> {
1718
if schema_1.hash != schema_2.hash {
1819
return Err(OxenError::invalid_file_type(
@@ -21,7 +22,7 @@ pub fn compare(
2122
}
2223

2324
// Compute row indices
24-
let (added_indices, removed_indices) = compute_new_row_indices(base_df, head_df)?;
25+
let (added_indices, removed_indices) = compute_new_row_indices(base_df, head_df, keys)?;
2526

2627
// Take added from the current df
2728
let added_rows = if !added_indices.is_empty() {
@@ -52,10 +53,11 @@ pub fn compare(
5253
fn compute_new_row_indices(
5354
base_df: &DataFrame,
5455
head_df: &DataFrame,
56+
keys: Vec<&str>,
5557
) -> Result<(Vec<u32>, Vec<u32>), OxenError> {
5658
// Hash the rows
57-
let base_df = tabular::df_hash_rows(base_df.clone())?;
58-
let head_df = tabular::df_hash_rows(head_df.clone())?;
59+
let base_df = tabular::df_hash_rows(base_df.clone(), Some(keys.clone()))?;
60+
let head_df = tabular::df_hash_rows(head_df.clone(), Some(keys.clone()))?;
5961

6062
log::debug!("diff_current got current hashes base_df {:?}", base_df);
6163
log::debug!("diff_current got current hashes head_df {:?}", head_df);

src/lib/src/api/local/diff.rs

+8-10
Original file line number | Diff line number | Diff line change
@@ -3,21 +3,19 @@ use serde::{Deserialize, Serialize};
33
use crate::core::df::tabular;
44
use crate::core::index::object_db_reader::ObjectDBReader;
55
use crate::core::index::CommitDirEntryReader;
6+
use crate::core::index::CommitEntryReader;
67
use crate::error::OxenError;
78
use crate::model::diff::diff_entry_status::DiffEntryStatus;
8-
use crate::model::diff::generic_diff::GenericDiff;
99
use crate::model::{Commit, CommitEntry, DataFrameDiff, DiffEntry, LocalRepository, Schema};
1010
use crate::opts::DFOpts;
1111
use crate::view::compare::AddRemoveModifyCounts;
1212
use crate::view::Pagination;
1313
use crate::{constants, util};
1414

15-
use crate::core::index::CommitEntryReader;
16-
use colored::Colorize;
17-
use difference::{Changeset, Difference};
1815
use polars::export::ahash::HashMap;
1916
use polars::prelude::DataFrame;
2017
use polars::prelude::IntoLazy;
18+
2119
use std::collections::HashSet;
2220
use std::path::{Path, PathBuf};
2321
use std::str::FromStr;
@@ -79,8 +77,8 @@ pub fn get_version_file_from_commit_id(
7977

8078
pub fn count_added_rows(base_df: DataFrame, head_df: DataFrame) -> Result<usize, OxenError> {
8179
// Hash the rows
82-
let base_df = tabular::df_hash_rows(base_df)?;
83-
let head_df = tabular::df_hash_rows(head_df)?;
80+
let base_df = tabular::df_hash_rows(base_df, None)?;
81+
let head_df = tabular::df_hash_rows(head_df, None)?;
8482

8583
// log::debug!("count_added_rows got base_df {}", base_df);
8684
// log::debug!("count_added_rows got head_df {}", head_df);
@@ -111,8 +109,8 @@ pub fn count_added_rows(base_df: DataFrame, head_df: DataFrame) -> Result<usize,
111109

112110
pub fn count_removed_rows(base_df: DataFrame, head_df: DataFrame) -> Result<usize, OxenError> {
113111
// Hash the rows
114-
let base_df = tabular::df_hash_rows(base_df)?;
115-
let head_df = tabular::df_hash_rows(head_df)?;
112+
let base_df = tabular::df_hash_rows(base_df, None)?;
113+
let head_df = tabular::df_hash_rows(head_df, None)?;
116114

117115
// log::debug!("count_removed_rows got base_df {}", base_df);
118116
// log::debug!("count_removed_rows got head_df {}", head_df);
@@ -149,8 +147,8 @@ pub fn compute_new_row_indices(
149147
head_df: &DataFrame,
150148
) -> Result<(Vec<u32>, Vec<u32>), OxenError> {
151149
// Hash the rows
152-
let base_df = tabular::df_hash_rows(base_df.clone())?;
153-
let head_df = tabular::df_hash_rows(head_df.clone())?;
150+
let base_df = tabular::df_hash_rows(base_df.clone(), None)?;
151+
let head_df = tabular::df_hash_rows(head_df.clone(), None)?;
154152

155153
log::debug!("diff_current got current hashes base_df {:?}", base_df);
156154
log::debug!("diff_current got current hashes head_df {:?}", head_df);

src/lib/src/core/df/tabular.rs

+13-3
Original file line number | Diff line number | Diff line change
@@ -700,13 +700,23 @@ pub fn any_val_to_bytes(value: &AnyValue) -> Vec<u8> {
700700
}
701701
}
702702

703-
pub fn df_hash_rows(df: DataFrame) -> Result<DataFrame, OxenError> {
703+
pub fn df_hash_rows(df: DataFrame, keys: Option<Vec<&str>>) -> Result<DataFrame, OxenError> {
704704
let num_rows = df.height() as i64;
705705

706706
let mut col_names = vec![];
707707
let schema = df.schema();
708-
for field in schema.iter_fields() {
709-
col_names.push(col(field.name()));
708+
709+
match keys {
710+
Some(keys) => {
711+
for key in keys {
712+
col_names.push(col(key));
713+
}
714+
}
715+
None => {
716+
for field in schema.iter_fields() {
717+
col_names.push(col(field.name()));
718+
}
719+
}
710720
}
711721
// println!("Hashing: {:?}", col_names);
712722
// println!("{:?}", df);

0 commit comments

Comments
 (0)