-- |
-- Module      : Network.TLS.Receiving
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- The Receiving module contains functions for unmarshalling received packets
-- according to the current TLS state.
--
{-# LANGUAGE FlexibleContexts #-}

module Network.TLS.Receiving
    ( processPacket
    , decodeRecordM
    ) where

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Util
import Network.TLS.Wire

import Control.Concurrent.MVar
import Control.Monad.State.Strict

-- | Interpret a plaintext record according to its content type.
--
-- * 'ProtocolType_AppData': the fragment bytes are returned as 'AppData'.
--
-- * 'ProtocolType_Alert': the fragment is decoded with 'decodeAlerts'
--   into a list of (level, description) pairs.
--
-- * 'ProtocolType_ChangeCipherSpec': once the fragment is validated by
--   'decodeChangeCipherSpec', the receive side is switched to the pending
--   record state ('switchRxEncryption'). That switch may throw in 'IO'
--   rather than return 'Left'.
--
-- * 'ProtocolType_Handshake': the fragment is parsed into as many handshake
--   messages as it contains.
--   - An incomplete trailing message is saved as a continuation in
--     'stHandshakeRecordCont'.
--   - That continuation is resumed when the next handshake record arrives.
--
-- * 'ProtocolType_DeprecatedHandshake': the fragment is decoded with
--   'decodeDeprecatedHandshake' into a single handshake message.
--
-- Decoding failures are reported as 'Left'.
processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)
processPacket _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData $ fragmentGetBytes fragment

processPacket _ (Record ProtocolType_Alert _ fragment) =
    return (Alert `fmapEither` decodeAlerts (fragmentGetBytes fragment))

processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
    case decodeChangeCipherSpec $ fragmentGetBytes fragment of
        Left err -> return $ Left err
        Right _  -> do
            switchRxEncryption ctx
            return $ Right ChangeCipherSpec

processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
    -- The key-exchange type of the pending cipher (if any) is passed to
    -- 'decodeHandshake' through 'CurrentParams'.
    mhst <- getHState ctx
    let keyxchg = cipherKeyExchange <$> (mhst >>= hstPendingCipher)
    usingState ctx $ do
        let currentParams = CurrentParams
                { cParamsVersion     = ver
                , cParamsKeyXchgType = keyxchg
                }
        -- Take back the optional continuation left by the previous record
        -- (clearing it in the state), then parse as many handshake messages
        -- as possible from this fragment.
        mCont <- gets stHandshakeRecordCont
        modify (\st -> st { stHandshakeRecordCont = Nothing })
        Handshake <$> parseMany currentParams mCont (fragmentGetBytes fragment)
  where
    -- Decode handshake messages from @bs@.
    -- The first message is decoded with the saved continuation @mCont@
    -- when there is one, otherwise with 'decodeHandshakeRecord'.
    -- If the bytes end mid-message, the continuation is stored back into
    -- 'stHandshakeRecordCont' and the messages decoded so far are returned.
    parseMany currentParams mCont bs =
        case fromMaybe decodeHandshakeRecord mCont bs of
            GotError err -> throwError err
            GotPartial cont -> do
                modify (\st -> st { stHandshakeRecordCont = Just cont })
                return []
            GotSuccess (ty, content) ->
                either throwError (return . (: [])) $
                    decodeHandshake currentParams ty content
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake currentParams ty content of
                    Left err -> throwError err
                    Right hh -> (hh :) <$> parseMany currentParams Nothing left

processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
    case decodeDeprecatedHandshake $ fragmentGetBytes fragment of
        Left err -> return $ Left err
        Right hs -> return $ Right $ Handshake [hs]

-- | Make the pending receive record state the context's active receive
-- state.
--
-- Called when a ChangeCipherSpec record is processed. The pending state is
-- read from the handshake state ('hstPendingRxState'). If none is present,
-- this fails with the @"rx-state"@ error raised by 'fromJust'.
--
-- The new state is forced with '$!' before it is stored, for two reasons:
--
-- * The failure is raised here, inside 'modifyMVar_', which leaves
--   'ctxRxState' unchanged when its action throws.
-- * Otherwise a lazy error thunk would be stored in the MVar, and it would
--   only explode on some later, unrelated read of the receive state.
switchRxEncryption :: Context -> IO ()
switchRxEncryption ctx = do
    rx <- usingHState ctx (gets hstPendingRxState)
    modifyMVar_ (ctxRxState ctx) (\_ -> return $! fromJust "rx-state" rx)

-- | Turn a record header and its raw payload into a plaintext record.
--
-- The payload is wrapped as a 'Ciphertext' fragment and combined with the
-- header by 'rawToRecord'. The result is then passed to 'disengageRecord',
-- which runs in 'RecordM' and yields the 'Plaintext' record.
decodeRecordM :: Header -> ByteString -> RecordM (Record Plaintext)
decodeRecordM header =
    disengageRecord . rawToRecord header . fragmentCiphertext