{-# LANGUAGE FlexibleContexts #-}
module Network.TLS.Receiving
( processPacket
, decodeRecordM
) where
import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Util
import Network.TLS.Wire
import Control.Concurrent.MVar
import Control.Monad.State.Strict
-- | Turn a plaintext record into a 'Packet', dispatching on the record's
-- protocol type.
--
-- * Application data: the fragment bytes are wrapped in 'AppData' unchanged.
--
-- * Alert: the fragment bytes are decoded with 'decodeAlerts'.
--
-- * ChangeCipherSpec: the fragment bytes are checked with
--   'decodeChangeCipherSpec'; on success the receiving record state is
--   switched (see 'switchRxEncryption') before 'ChangeCipherSpec' is returned.
--
-- * Handshake: every handshake message contained in the fragment is decoded.
--   An incomplete trailing message is stored as a continuation in
--   'stHandshakeRecordCont' and resumed when the next handshake record arrives.
--
-- * Deprecated handshake: the fragment is decoded with
--   'decodeDeprecatedHandshake' into a single handshake message.
--
-- Decoding errors are returned as 'Left'.
processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)

processPacket _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData $ fragmentGetBytes fragment

processPacket _ (Record ProtocolType_Alert _ fragment) =
    return (Alert <$> decodeAlerts (fragmentGetBytes fragment))

processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
    case decodeChangeCipherSpec $ fragmentGetBytes fragment of
        Left err -> return $ Left err
        Right _  -> do
            -- Only switch the receiving state once the message is known to be
            -- well formed.
            switchRxEncryption ctx
            return $ Right ChangeCipherSpec

processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
    -- Key exchange type of the pending cipher, if a handshake state and a
    -- pending cipher exist; it is handed to the decoder via 'CurrentParams'.
    mhs <- getHState ctx
    let keyxchg = cipherKeyExchange <$> (mhs >>= hstPendingCipher)
    usingState ctx $ do
        let currentParams = CurrentParams
                { cParamsVersion     = ver
                , cParamsKeyXchgType = keyxchg
                }
        -- Take the optional continuation left by a previous record out of the
        -- state, then parse as many handshake messages as possible.
        mCont <- gets stHandshakeRecordCont
        modify (\st -> st { stHandshakeRecordCont = Nothing })
        Handshake <$> parseMany currentParams mCont (fragmentGetBytes fragment)
  where
    -- Decode handshake messages from @bs@ until the input is exhausted,
    -- starting from the continuation @mCont@ when one is given and from
    -- 'decodeHandshakeRecord' otherwise.
    parseMany currentParams mCont bs =
        case fromMaybe decodeHandshakeRecord mCont bs of
            GotError err ->
                throwError err
            GotPartial cont -> do
                -- Not enough bytes for a full message: remember where to
                -- resume and report no message for now.
                modify (\st -> st { stHandshakeRecordCont = Just cont })
                return []
            GotSuccess (ty, content) ->
                either throwError (return . (:[])) $
                    decodeHandshake currentParams ty content
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake currentParams ty content of
                    Left err -> throwError err
                    -- Bytes remain after this message: keep parsing them with
                    -- a fresh decoder.
                    Right hh -> (hh :) <$> parseMany currentParams Nothing left

processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
    case decodeDeprecatedHandshake $ fragmentGetBytes fragment of
        Left err -> return $ Left err
        Right hs -> return $ Right $ Handshake [hs]
-- | Replace the context's receiving record state ('ctxRxState') with the
-- pending receiving state taken from the handshake state
-- ('hstPendingRxState').
--
-- The previous content of the 'MVar' is discarded. The new value is
-- @fromJust "rx-state" rx@ and is stored without being forced here, so a
-- missing pending state is not detected by this function itself.
-- NOTE(review): presumably 'fromJust' raises an error tagged with
-- @"rx-state"@ when the value is later evaluated -- confirm against
-- "Network.TLS.Util".
switchRxEncryption :: Context -> IO ()
switchRxEncryption ctx = do
    rx <- usingHState ctx (gets hstPendingRxState)
    modifyMVar_ (ctxRxState ctx) (\_ -> return $ fromJust "rx-state" rx)
-- | Build a ciphertext record from a record header and the raw fragment
-- bytes, then pass it to 'disengageRecord' to obtain the plaintext record
-- inside 'RecordM'.
decodeRecordM :: Header -> ByteString -> RecordM (Record Plaintext)
decodeRecordM header content = disengageRecord erecord
  where
    -- The received bytes, tagged as ciphertext and paired with their header.
    erecord = rawToRecord header (fragmentCiphertext content)