8 #include <botan/internal/tls_handshake_io.h>
9 #include <botan/internal/tls_record.h>
10 #include <botan/internal/tls_seq_numbers.h>
11 #include <botan/tls_messages.h>
12 #include <botan/exceptn.h>
// Presumably decodes q[0..2] as a 24-bit big-endian integer. This is inferred
// from the name and from its use on the DTLS handshake length and offset
// fields. The body is not visible in this view — confirm.
21 inline size_t load_be24(
const uint8_t q[3])
// Writes three bytes derived from val into out[0..2].
// This assumes get_byte(i, x) returns byte i of x counted from the most
// significant byte (the definition is not shown here — confirm). Under that
// assumption, bytes 1..3 of the 32-bit value are written, which is the low
// 24 bits in big-endian order. Any bits of val above bit 23 are silently
// dropped, so callers must ensure val < 2^24.
29 void store_be24(uint8_t out[3],
size_t val)
31 out[0] =
get_byte(1, static_cast<uint32_t>(val));
32 out[1] =
get_byte(2, static_cast<uint32_t>(val));
33 out[2] =
get_byte(3, static_cast<uint32_t>(val));
// Returns the current reading of std::chrono::steady_clock in whole
// milliseconds.
// steady_clock is monotonic, so the difference between two readings is not
// affected by wall-clock adjustments. The absolute value (time since the
// clock's epoch) has no calendar meaning and is only useful for deltas.
36 uint64_t steady_clock_ms()
38 return std::chrono::duration_cast<std::chrono::milliseconds>(
39 std::chrono::steady_clock::now().time_since_epoch()).count();
// Estimates how many fragments a handshake message of msg_size bytes must be
// split into for the given MTU. Only the start of the function is visible
// here; the remainder, including the return, is outside this view.
42 size_t split_for_mtu(
size_t mtu,
size_t msg_size)
// Per-fragment header overhead. Presumably this is a 13-byte DTLS record
// header plus the 12-byte DTLS handshake header — confirm.
44 const size_t DTLS_HEADERS_SIZE = 25;
// With integer division, (msg_size + mtu) / mtu == msg_size / mtu + 1.
// That equals ceil(msg_size / mtu), except that it over-counts by one when
// msg_size is an exact multiple of mtu (including msg_size == 0).
// An mtu of 0 would divide by zero here.
46 const size_t parts = (msg_size + mtu) / mtu;
// NOTE(review): this adds the fragment *count* to the header size and compares
// the sum with the MTU. A per-fragment size check would look more like
// msg_size / parts + DTLS_HEADERS_SIZE > mtu. Confirm the intent against the
// part of the function outside this view.
48 if(parts + DTLS_HEADERS_SIZE > mtu)
// Fragment of an add_record implementation, presumably the stream (TLS)
// variant. The function header and the record-type dispatch are outside this
// view.
// Handshake bytes are appended to m_queue unchanged. get_next_record later
// removes complete messages from the front of the queue.
66 m_queue.insert(m_queue.end(), record.begin(), record.end());
// Presumably the ChangeCipherSpec branch. The record must be exactly one byte
// with value 1. What happens on a mismatch (presumably an exception) is not
// visible here.
70 if(record.size() != 1 || record[0] != 1)
// The bytes of ccs_hs (defined outside this view) are queued. Presumably this
// lets the CCS surface through the same queue as handshake messages — confirm
// against ccs_hs's definition.
75 m_queue.insert(m_queue.end(), ccs_hs, ccs_hs +
sizeof(ccs_hs));
// get_next_record for the m_queue-based (presumably stream/TLS) handshake IO.
// The function-name line and the fall-through returns are outside this view.
// Returns the next complete buffered message as (type, body), where the body
// excludes the 4-byte header.
81 std::pair<Handshake_Type, std::vector<uint8_t>>
// At least the 4-byte header must be buffered: 1 type byte plus a 3-byte
// length.
84 if(m_queue.size() >= 4)
// Bytes 1..3 hold the body length, with the top byte forced to 0. This
// assumes make_uint32 treats its first argument as the most significant byte.
86 const size_t length =
make_uint32(0, m_queue[1], m_queue[2], m_queue[3]);
// Proceed only once the whole body has arrived. Otherwise the bytes stay
// queued for a later call.
88 if(m_queue.size() >= length + 4)
// `type` is assigned on a line outside this view, presumably from m_queue[0].
92 std::vector<uint8_t> contents(m_queue.begin() + 4,
93 m_queue.begin() + 4 + length);
// Consume the header and body from the front of the queue.
95 m_queue.erase(m_queue.begin(), m_queue.begin() + 4 + length);
97 return std::make_pair(type, contents);
// Builds a handshake message as a 4-byte header followed by the body
// (presumably the stream/TLS variant of format()).
// Header layout: byte 0 is the handshake type, presumably assigned on a line
// outside this view; bytes 1..3 are the body length as 24 bits.
108 std::vector<uint8_t> send_buf(4 + msg.size());
110 const size_t buf_size = msg.size();
// NOTE(review): store_be24 presumably keeps only the low 24 bits, so a body of
// 2^24 bytes or more would get a wrong length field. No check is visible here.
114 store_be24(&send_buf[1], buf_size);
// The body follows the 4-byte header.
// NOTE(review): for an empty msg, &send_buf[4] indexes one past the end.
// Presumably a guard outside this view skips the copy in that case — confirm.
118 copy_mem(&send_buf[4], msg.data(), msg.size());
// Serializes msg and frames it with format() (presumably the stream/TLS
// variant of send()).
126 const std::vector<uint8_t> msg_bits = msg.
serialize();
// This early empty return is taken under a condition outside this view —
// presumably for a message that is not framed as a handshake message, such as
// ChangeCipherSpec. Confirm.
131 return std::vector<uint8_t>();
134 const std::vector<uint8_t> buf =
format(msg_bits, msg.
type());
// Retransmits the previous outgoing flight.
// With a single flight entry, flight 0 is chosen. Otherwise the second-to-last
// entry is chosen, presumably because the last entry is the (possibly empty)
// flight currently being built; get_next_record appends a fresh entry.
// NOTE(review): if m_flights were empty, size() - 2 would wrap, and
// retransmit_flight's m_flights.at() would throw std::out_of_range.
// Presumably m_flights always holds at least one entry — confirm.
144 void Datagram_Handshake_IO::retransmit_last_flight()
146 const size_t flight_idx = (m_flights.size() == 1) ? 0 : (m_flights.size() - 2);
147 retransmit_flight(flight_idx);
// Re-sends every message recorded in flight flight_idx, in stored order.
150 void Datagram_Handshake_IO::retransmit_flight(
size_t flight_idx)
// at() is bounds-checked, so an invalid index throws std::out_of_range.
152 const std::vector<uint16_t>& flight = m_flights.at(flight_idx);
154 BOTAN_ASSERT(flight.size() > 0,
"Nonempty flight to retransmit");
// Start from the first message's epoch so an epoch change inside the flight
// can be detected.
// NOTE(review): if m_flight_data is a std::map, operator[] default-inserts any
// missing sequence number instead of failing.
156 uint16_t epoch = m_flight_data[flight[0]].epoch;
158 for(
auto msg_seq : flight)
160 auto& msg = m_flight_data[msg_seq];
// On an epoch change, a one-byte payload of value 1 is built. Presumably it is
// sent as a ChangeCipherSpec, and `epoch` is updated, on lines outside this
// view before the first message of the new epoch — confirm.
162 if(msg.epoch != epoch)
165 std::vector<uint8_t> ccs(1, 1);
// Re-send the stored bytes with their original sequence number, epoch and
// type.
169 send_message(msg_seq, msg.epoch, msg.msg_type, msg.msg_bits);
// Fragment of timeout_check(); the header and several returns are outside this
// view.
// The first condition presumably bails out early when nothing has been written
// yet, or when more than one flight exists and the last entry is non-empty.
// Otherwise, once m_next_timeout ms have elapsed since the last write, the
// previous flight is retransmitted and the timeout backs off.
176 if(m_last_write == 0 || (m_flights.size() > 1 && !m_flights.rbegin()->empty()))
// Time elapsed since the last write, taken from the monotonic clock.
185 const uint64_t ms_since_write = steady_clock_ms() - m_last_write;
// Not yet timed out: presumably returns without retransmitting.
187 if(ms_since_write < m_next_timeout)
190 retransmit_last_flight();
// Exponential backoff: double the timeout, capped at m_max_timeout.
192 m_next_timeout =
std::min(2 * m_next_timeout, m_max_timeout);
198 uint64_t record_sequence)
// Body of the datagram add_record; the start of the signature is outside this
// view. It parses a record that may carry several DTLS handshake fragments and
// feeds each fragment into the reassembly buffer for its message sequence
// number. The cursor then advances by each fragment's total size.
// The epoch is the top 16 bits of the 64-bit record sequence.
200 const uint16_t epoch =
static_cast<uint16_t
>(record_sequence >> 48);
// Presumably inside the ChangeCipherSpec branch (the condition is not
// visible): remember that a CCS arrived for this epoch.
205 m_ccs_epochs.insert(epoch);
// Handshake header layout, from the offsets used below:
//   type(1) | length(3) | message_seq(2, presumably) | fragment_offset(3) |
//   fragment_length(3)
209 const size_t DTLS_HANDSHAKE_HEADER_LEN = 12;
// Cursor over the record, advanced after each fragment.
211 const uint8_t* record_bits = record.data();
212 size_t record_size = record.size();
// Truncated header. The handling (presumably an error) is outside this view.
216 if(record_size < DTLS_HANDSHAKE_HEADER_LEN)
219 const uint8_t msg_type = record_bits[0];
220 const size_t msg_len = load_be24(&record_bits[1]);
// message_seq (bytes 4..5) is read on a line outside this view.
222 const size_t fragment_offset = load_be24(&record_bits[6]);
223 const size_t fragment_length = load_be24(&record_bits[9]);
// Header plus payload. Assuming load_be24 yields at most 24 bits, this sum
// cannot overflow.
225 const size_t total_size = DTLS_HANDSHAKE_HEADER_LEN + fragment_length;
// The fragment claims more bytes than remain in the record. Handling is not
// visible here.
227 if(record_size < total_size)
// Fragments for already-delivered messages (sequence below m_in_message_seq)
// are skipped, presumably because they are retransmissions.
230 if(message_seq >= m_in_message_seq)
232 m_messages[message_seq].add_fragment(&record_bits[DTLS_HANDSHAKE_HEADER_LEN],
244 record_bits += total_size;
245 record_size -= total_size;
// Datagram get_next_record. The function-name line and some conditions and
// returns are outside this view.
// Returns a CCS marker or the next in-order, fully reassembled handshake
// message.
249 std::pair<Handshake_Type, std::vector<uint8_t>>
// If the last outgoing flight has messages, start a new empty flight entry.
// This presumably marks the previous flight as finished now that the peer's
// messages are being read — confirm.
253 if(!m_flights.rbegin()->empty())
254 m_flights.push_back(std::vector<uint16_t>());
// Presumably evaluated only when a CCS is expected (the condition is not
// visible). The epoch of the first buffered message is checked against the
// epochs for which a CCS record arrived. The first message is the lowest
// sequence number if m_messages is an ordered map.
258 if(!m_messages.empty())
260 const uint16_t current_epoch = m_messages.begin()->second.epoch();
262 if(m_ccs_epochs.count(current_epoch))
263 return std::make_pair(
HANDSHAKE_CCS, std::vector<uint8_t>());
// Deliver strictly in sequence order: only message m_in_message_seq, and only
// once it is fully reassembled.
268 auto i = m_messages.find(m_in_message_seq);
// Not available yet. The fall-through return is outside this view.
270 if(i == m_messages.end() || !i->second.complete())
273 m_in_message_seq += 1;
275 return i->second.message();
// Adds one fragment of a DTLS handshake message to this reassembly buffer.
// Some parameters (presumably msg_type, msg_length and epoch) and several
// statements are on lines outside this view.
278 void Datagram_Handshake_IO::Handshake_Reassembly::add_fragment(
279 const uint8_t fragment[],
280 size_t fragment_length,
281 size_t fragment_offset,
// First fragment seen (the guard is not visible): adopt its type and total
// length.
292 m_msg_type = msg_type;
293 m_msg_length = msg_length;
// Every later fragment must agree with the first on type, length and epoch.
296 if(msg_type != m_msg_type || msg_length != m_msg_length || epoch != m_epoch)
297 throw Decoding_Error(
"Inconsistent values in fragmented DTLS handshake header");
// Bounds checks against the declared message length. What happens on failure
// (presumably an exception) is outside this view.
299 if(fragment_offset > m_msg_length)
302 if(fragment_offset + fragment_length > m_msg_length)
// Fast path: this fragment is the entire message.
305 if(fragment_offset == 0 && fragment_length == m_msg_length)
308 m_message.assign(fragment, fragment+fragment_length);
// General path: store each byte under its absolute offset. Assuming
// m_fragments is keyed by offset (e.g. a std::map), re-received bytes
// overwrite the same key, so m_fragments.size() counts distinct offsets.
// NOTE(review): one container entry per byte is costly in memory and time for
// large messages. A byte buffer plus coverage tracking would be cheaper.
320 for(
size_t i = 0; i != fragment_length; ++i)
321 m_fragments[fragment_offset+i] = fragment[i];
// Given the bounds checks above, this means every offset in
// [0, m_msg_length) has been received, so the message can be materialized.
323 if(m_fragments.size() == m_msg_length)
325 m_message.resize(m_msg_length);
326 for(
size_t i = 0; i != m_msg_length; ++i)
327 m_message[i] = m_fragments[i];
// A message is complete once a type has been recorded (not HANDSHAKE_NONE)
// and the reassembled body has reached the declared length.
333 bool Datagram_Handshake_IO::Handshake_Reassembly::complete()
const
335 return (m_msg_type !=
HANDSHAKE_NONE && m_message.size() == m_msg_length);
// Returns (type, body) of the reassembled message. The body is returned by
// value, as a copy of m_message.
338 std::pair<Handshake_Type, std::vector<uint8_t>>
339 Datagram_Handshake_IO::Handshake_Reassembly::message()
const
// Presumably guarded by !complete() on a line outside this view.
342 throw Internal_Error(
"Datagram_Handshake_IO - message not complete");
344 return std::make_pair(static_cast<Handshake_Type>(m_msg_type), m_message);
// Serializes one DTLS handshake fragment: a 12-byte header followed by the
// fragment payload. Some parameters (presumably frag_len, msg_len and type)
// are on lines outside this view.
// NOTE(review): frag_offset is declared uint16_t although its wire field is
// 24 bits, so offsets above 65535 cannot be expressed.
348 Datagram_Handshake_IO::format_fragment(
const uint8_t fragment[],
350 uint16_t frag_offset,
353 uint16_t msg_sequence)
const
355 std::vector<uint8_t> send_buf(12 + frag_len);
// Byte 0 (the type) is presumably set on a line outside this view.
// Bytes 1..3: total message length (24 bits).
359 store_be24(&send_buf[1], msg_len);
// Bytes 4..5: message sequence number (2 bytes via store_be).
361 store_be(msg_sequence, &send_buf[4]);
// Bytes 6..8: offset of this fragment within the message (24 bits).
363 store_be24(&send_buf[6], frag_offset);
// Bytes 9..11: length of this fragment (24 bits).
364 store_be24(&send_buf[9], frag_len);
// NOTE(review): when frag_len == 0, &send_buf[12] indexes one past the end.
// Presumably a guard outside this view skips the copy then — confirm.
368 copy_mem(&send_buf[12], fragment, frag_len);
// Formats an entire handshake message as a single DTLS fragment (offset 0,
// fragment length equal to message length) with an explicit sequence number.
// NOTE(review): static_cast<uint16_t>(msg.size()) truncates messages larger
// than 65535 bytes, although the DTLS message-length field is 24 bits. Confirm
// whether larger handshake messages can reach this path.
375 Datagram_Handshake_IO::format_w_seq(
const std::vector<uint8_t>& msg,
377 uint16_t msg_sequence)
const
379 return format_fragment(msg.data(), msg.size(), 0,
static_cast<uint16_t
>(msg.size()), type, msg_sequence);
// Formats msg with sequence number m_in_message_seq - 1. That is the sequence
// number of the most recently delivered incoming message, because
// get_next_record increments m_in_message_seq after delivery. This is
// presumably used to re-encode received messages — confirm with callers.
// NOTE(review): if called before any message is delivered, the subtraction
// would wrap, depending on m_in_message_seq's type.
386 return format_w_seq(msg, type, m_in_message_seq - 1);
// Datagram send(): serializes msg, records it in the current outgoing flight
// for possible retransmission, and then transmits it. The lines that determine
// epoch and msg_type are outside this view.
393 const std::vector<uint8_t> msg_bits = msg.
serialize();
// This early empty return is taken under a condition outside this view.
400 return std::vector<uint8_t>();
// Record which sequence numbers belong to the current (last) flight. Keep a
// copy of the bytes so retransmit_flight can resend them.
404 m_flights.rbegin()->push_back(m_out_message_seq);
405 m_flight_data[m_out_message_seq] = Message_Info(epoch, msg_type, msg_bits);
407 m_out_message_seq += 1;
// Restart the retransmission timer at its initial value for this write.
408 m_last_write = steady_clock_ms();
409 m_next_timeout = m_initial_timeout;
// m_out_message_seq was already incremented, so - 1 is this message's number.
411 return send_message(m_out_message_seq - 1, epoch, msg_type, msg_bits);
// Transmits one handshake message (by sequence number, epoch and type),
// fragmenting it to fit m_mtu when needed. Some parameters, the fits-in-MTU
// check and the per-fragment send call are on lines outside this view.
// NOTE(review): frag_offset and msg_bits.size() are cast to uint16_t below,
// so messages larger than 65535 bytes would be encoded incorrectly.
414 std::vector<uint8_t> Datagram_Handshake_IO::send_message(uint16_t msg_seq,
417 const std::vector<uint8_t>& msg_bits)
// First build the unfragmented encoding.
419 const std::vector<uint8_t> no_fragment =
420 format_w_seq(msg_bits, msg_type, msg_seq);
// Presumably taken when no_fragment fits within the MTU; the check is not
// visible.
424 m_send_hs(epoch,
HANDSHAKE, no_fragment);
// Otherwise split msg_bits into roughly equal chunks.
428 const size_t parts = split_for_mtu(m_mtu, msg_bits.size());
// (size + parts) / parts == size / parts + 1, so parts_size >= 1. Assuming
// frag_len is min(remaining, parts_size) (the second argument is on a hidden
// line), every iteration consumes at least one byte and the loop ends exactly
// at msg_bits.size().
430 const size_t parts_size = (msg_bits.size() + parts) / parts;
432 size_t frag_offset = 0;
434 while(frag_offset != msg_bits.size())
436 const size_t frag_len =
437 std::min<size_t>(msg_bits.size() - frag_offset,
442 format_fragment(&msg_bits[frag_offset],
444 static_cast<uint16_t>(frag_offset),
445 static_cast<uint16_t>(msg_bits.size()),
449 frag_offset += frag_len;
// NOTE(review): the lines below appear to be bare declaration signatures of
// symbols referenced above (no bodies, some duplicated). They do not look like
// definitions belonging to this implementation and are likely an artifact of
// how this view was assembled. Confirm before relying on them.
bool timeout_check() override
void store_be(uint16_t in, uint8_t out[2])
Protocol_Version initial_record_version() const override
uint16_t load_be< uint16_t >(const uint8_t in[], size_t off)
Protocol_Version initial_record_version() const override
std::vector< uint8_t > send(const Handshake_Message &msg) override
std::string to_string(const BER_Object &obj)
#define BOTAN_ASSERT(expr, assertion_made)
std::vector< uint8_t > send(const Handshake_Message &msg) override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
std::pair< Handshake_Type, std::vector< uint8_t > > get_next_record(bool expecting_ccs) override
void add_record(const std::vector< uint8_t > &record, Record_Type type, uint64_t sequence_number) override
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
void copy_mem(T *out, const T *in, size_t n)
virtual Handshake_Type type() const =0
virtual uint16_t current_write_epoch() const =0
uint8_t get_byte(size_t byte_num, T input)
std::vector< uint8_t > format(const std::vector< uint8_t > &handshake_msg, Handshake_Type handshake_type) const override
void add_record(const std::vector< uint8_t > &record, Record_Type type, uint64_t sequence_number) override
virtual std::vector< uint8_t > serialize() const =0
uint32_t make_uint32(uint8_t i0, uint8_t i1, uint8_t i2, uint8_t i3)