fix: tree truncation function

This commit is contained in:
Rene Vergara 2024-11-19 07:26:27 -06:00
parent f23c222edc
commit 42f6b6becb
No known key found for this signature in database
GPG key ID: 65122AD495A7F5B2

View file

@ -9,9 +9,11 @@
module Zenith.Tree where module Zenith.Tree where
import Codec.Borsh import Codec.Borsh
import Control.Monad.Logger (LoggingT, logDebugN)
import Data.HexString import Data.HexString
import Data.Int (Int32, Int64, Int8) import Data.Int (Int32, Int64, Int8)
import Data.Maybe (fromJust, isNothing) import Data.Maybe (fromJust, isNothing)
import qualified Data.Text as T
import qualified GHC.Generics as GHC import qualified GHC.Generics as GHC
import qualified Generics.SOP as SOP import qualified Generics.SOP as SOP
import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue) import ZcashHaskell.Orchard (combineOrchardNodes, getOrchardNodeValue)
@ -179,14 +181,49 @@ getNotePosition (Branch _ x y) i
| otherwise = Nothing | otherwise = Nothing
getNotePosition _ _ = Nothing getNotePosition _ _ = Nothing
truncateTree :: Monoid v => Node v => Tree v -> Int64 -> Tree v truncateTree :: Monoid v => Node v => Tree v -> Int64 -> LoggingT IO (Tree v)
truncateTree (Branch s x y) i truncateTree (Branch s x y) i
| getLevel s == 1 && getIndex (value x) == i = branch x EmptyLeaf | getLevel s == 1 && getIndex (value x) == i = do
| getLevel s == 1 && getIndex (value y) == i = branch x y logDebugN $ T.pack $ show (getLevel s) ++ " Trunc to left leaf"
| getIndex (value x) >= i = return $ branch x EmptyLeaf
branch (truncateTree x i) (getEmptyRoot (getLevel s)) | getLevel s == 1 && getIndex (value y) == i = do
| getIndex (value y) >= i = branch x (truncateTree y i) logDebugN $ T.pack $ show (getLevel s) ++ " Trunc to right leaf"
truncateTree x _ = x return $ branch x y
| getIndex (value x) >= i = do
logDebugN $
T.pack $
show (getLevel s) ++
": " ++ show i ++ " left i: " ++ show (getIndex (value x))
l <- truncateTree x i
return $ branch (l) (getEmptyRoot (getLevel (value x)))
| getIndex (value y) /= 0 && getIndex (value y) >= i = do
logDebugN $
T.pack $
show (getLevel s) ++
": " ++ show i ++ " right i: " ++ show (getIndex (value y))
r <- truncateTree y i
return $ branch x (r)
| otherwise = do
logDebugN $
T.pack $
show (getLevel s) ++
": " ++
show (getIndex (value x)) ++ " catchall " ++ show (getIndex (value y))
return InvalidTree
truncateTree x _ = return x
countLeaves :: Node v => Tree v -> Int64
countLeaves (Branch s x y) =
if isFull s
then 2 ^ getLevel s
else countLeaves x + countLeaves y
countLeaves (PrunedBranch x) =
if isFull x
then 2 ^ getLevel x
else 0
countLeaves (Leaf _) = 1
countLeaves EmptyLeaf = 0
countLeaves InvalidTree = 0
data SaplingNode = SaplingNode data SaplingNode = SaplingNode
{ sn_position :: !Position { sn_position :: !Position