Skip to content

Commit

Permalink
Temporary cache dir
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-savelyevv committed Nov 25, 2024
1 parent 2b2481c commit 3fe8ed2
Showing 1 changed file with 36 additions and 21 deletions.
57 changes: 36 additions & 21 deletions tests/openvino/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,16 +392,21 @@ def test_textual_inversion(self):
inputs = self.generate_inputs()
inputs["prompt"] = "A <cat-toy> backpack"

diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
diffusers_pipeline.load_textual_inversion(ti_id)
with TemporaryDirectory() as temp_dir:
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(
model_id, safety_checker=None, cache_dir=temp_dir
)
diffusers_pipeline.load_textual_inversion(ti_id)

ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id, compile=False, safety_checker=None)
ov_pipeline.load_textual_inversion(ti_id)
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(
model_id, compile=False, safety_checker=None, cache_dir=temp_dir
)
ov_pipeline.load_textual_inversion(ti_id)

diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)


class OVPipelineForImage2ImageTest(unittest.TestCase):
Expand Down Expand Up @@ -636,16 +641,21 @@ def test_textual_inversion(self):
inputs = self.generate_inputs(model_type="stable-diffusion")
inputs["prompt"] = "A <cat-toy> backpack"

diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
diffusers_pipeline.load_textual_inversion(ti_id)
with TemporaryDirectory() as temp_dir:
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(
model_id, safety_checker=None, cache_dir=temp_dir
)
diffusers_pipeline.load_textual_inversion(ti_id)

ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id, compile=False, safety_checker=None)
ov_pipeline.load_textual_inversion(ti_id)
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(
model_id, compile=False, safety_checker=None, cache_dir=temp_dir
)
ov_pipeline.load_textual_inversion(ti_id)

diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)


class OVPipelineForInpaintingTest(unittest.TestCase):
Expand Down Expand Up @@ -880,13 +890,18 @@ def test_textual_inversion(self):
inputs = self.generate_inputs()
inputs["prompt"] = "A <cat-toy> backpack"

diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(model_id, safety_checker=None)
diffusers_pipeline.load_textual_inversion(ti_id)
with TemporaryDirectory() as temp_dir:
diffusers_pipeline = self.AUTOMODEL_CLASS.from_pretrained(
model_id, safety_checker=None, cache_dir=temp_dir
)
diffusers_pipeline.load_textual_inversion(ti_id)

ov_pipeline = self.OVMODEL_CLASS.from_pretrained(model_id, compile=False, safety_checker=None)
ov_pipeline.load_textual_inversion(ti_id)
ov_pipeline = self.OVMODEL_CLASS.from_pretrained(
model_id, compile=False, safety_checker=None, cache_dir=temp_dir
)
ov_pipeline.load_textual_inversion(ti_id)

diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images
diffusers_output = diffusers_pipeline(**inputs, generator=get_generator("pt", SEED)).images
ov_output = ov_pipeline(**inputs, generator=get_generator("pt", SEED)).images

np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)
np.testing.assert_allclose(ov_output, diffusers_output, atol=1e-4, rtol=1e-2)

0 comments on commit 3fe8ed2

Please sign in to comment.