diff --git a/chain/cosmos/poll.go b/chain/cosmos/poll.go
index f77724d82..8302f756a 100644
--- a/chain/cosmos/poll.go
+++ b/chain/cosmos/poll.go
@@ -10,7 +10,7 @@ import (
 // PollForProposalStatus attempts to find a proposal with matching ID and status.
 func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight, maxHeight uint64, proposalID string, status string) (ProposalResponse, error) {
 	var zero ProposalResponse
-	doPoll := func(ctx context.Context, height uint64) (any, error) {
+	doPoll := func(ctx context.Context, height uint64) (ProposalResponse, error) {
 		p, err := chain.QueryProposal(ctx, proposalID)
 		if err != nil {
 			return zero, err
@@ -20,10 +20,6 @@ func PollForProposalStatus(ctx context.Context, chain *CosmosChain, startHeight,
 		}
 		return *p, nil
 	}
-	bp := test.BlockPoller{CurrentHeight: chain.Height, PollFunc: doPoll}
-	p, err := bp.DoPoll(ctx, startHeight, maxHeight)
-	if err != nil {
-		return zero, err
-	}
-	return p.(ProposalResponse), nil
+	bp := test.BlockPoller[ProposalResponse]{CurrentHeight: chain.Height, PollFunc: doPoll}
+	return bp.DoPoll(ctx, startHeight, maxHeight)
 }
diff --git a/test/poll_for_state.go b/test/poll_for_state.go
index 5155c4af5..cf1f63325 100644
--- a/test/poll_for_state.go
+++ b/test/poll_for_state.go
@@ -12,23 +12,26 @@ import (
 
 var ErrNotFound = errors.New("not found")
 
-type BlockPoller struct {
+type BlockPoller[T any] struct {
 	CurrentHeight func(ctx context.Context) (uint64, error)
-	PollFunc      func(ctx context.Context, height uint64) (any, error)
+	PollFunc      func(ctx context.Context, height uint64) (T, error)
 }
 
-func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (any, error) {
+func (p BlockPoller[T]) DoPoll(ctx context.Context, startHeight, maxHeight uint64) (T, error) {
 	if maxHeight < startHeight {
 		panic("maxHeight must be greater than or equal to startHeight")
 	}
 
-	var pollErr error
+	var (
+		pollErr error
+		zero    T
+	)
 
 	cursor := startHeight
 	for cursor <= maxHeight {
 		curHeight, err := p.CurrentHeight(ctx)
 		if err != nil {
-			return nil, err
+			return zero, err
 		}
 		if cursor > curHeight {
 			continue
@@ -44,7 +47,7 @@ func (p BlockPoller) DoPoll(ctx context.Context, startHeight, maxHeight uint64)
 
 		return found, nil
 	}
-	return nil, pollErr
+	return zero, pollErr
 }
 
 // ChainAcker is a chain that can get its acknowledgements at a specified height
@@ -60,7 +63,7 @@ type ChainAcker interface {
 func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketAcknowledgement, error) {
 	var zero ibc.PacketAcknowledgement
 	pollError := &packetPollError{targetPacket: packet}
-	poll := func(ctx context.Context, height uint64) (any, error) {
+	poll := func(ctx context.Context, height uint64) (ibc.PacketAcknowledgement, error) {
 		acks, err := chain.Acknowledgements(ctx, height)
 		if err != nil {
 			return zero, err
@@ -74,13 +77,13 @@ func PollForAck(ctx context.Context, chain ChainAcker, startHeight, maxHeight ui
 		return zero, ErrNotFound
 	}
 
-	poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll}
+	poller := BlockPoller[ibc.PacketAcknowledgement]{CurrentHeight: chain.Height, PollFunc: poll}
 	found, err := poller.DoPoll(ctx, startHeight, maxHeight)
 	if err != nil {
 		pollError.SetErr(err)
 		return zero, pollError
 	}
-	return found.(ibc.PacketAcknowledgement), nil
+	return found, nil
 }
 
 // ChainTimeouter is a chain that can get its timeouts at a specified height
@@ -94,7 +97,7 @@ type ChainTimeouter interface {
 func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxHeight uint64, packet ibc.Packet) (ibc.PacketTimeout, error) {
 	pollError := &packetPollError{targetPacket: packet}
 	var zero ibc.PacketTimeout
-	poll := func(ctx context.Context, height uint64) (any, error) {
+	poll := func(ctx context.Context, height uint64) (ibc.PacketTimeout, error) {
 		timeouts, err := chain.Timeouts(ctx, height)
 		if err != nil {
 			return zero, err
@@ -108,13 +111,13 @@ func PollForTimeout(ctx context.Context, chain ChainTimeouter, startHeight, maxH
 		return zero, ErrNotFound
 	}
 
-	poller := BlockPoller{CurrentHeight: chain.Height, PollFunc: poll}
+	poller := BlockPoller[ibc.PacketTimeout]{CurrentHeight: chain.Height, PollFunc: poll}
 	found, err := poller.DoPoll(ctx, startHeight, maxHeight)
 	if err != nil {
 		pollError.SetErr(err)
 		return zero, pollError
 	}
-	return found.(ibc.PacketTimeout), nil
+	return found, nil
 }
 
 type packetPollError struct {