26 #ifndef DCCLFIELDCODECARITHMETIC20120726H
27 #define DCCLFIELDCODECARITHMETIC20120726H
33 #include "../field_codec_typed.h"
35 #include "dccl/arithmetic/protobuf/arithmetic.pb.h"
36 #include "dccl/arithmetic/protobuf/arithmetic_extensions.pb.h"
38 #include "../logger.h"
40 #include "../binary.h"
41 #include "../thread_safety.h"
64 using symbol_type = int;
65 using value_type = double;
67 static constexpr symbol_type OUT_OF_RANGE_SYMBOL = -1;
68 static constexpr symbol_type EOF_SYMBOL = -2;
69 static constexpr symbol_type MIN_SYMBOL = EOF_SYMBOL;
71 static constexpr
int CODE_VALUE_BITS = 32;
72 static constexpr
int FREQUENCY_BITS = CODE_VALUE_BITS - 2;
74 static constexpr freq_type MAX_FREQUENCY = (1 << FREQUENCY_BITS) - 1;
// With DCCL_THREAD_SUPPORT, LOCK_LAST_BITS_MAP_MUTEX declares a std::lock_guard on
// last_bits_map_mutex for the rest of the enclosing scope. The second, empty definition makes it
// a no-op otherwise.
// NOTE(review): the `#else` / `#endif` separating and closing these two definitions are not
// visible here -- verify they exist.
76 #if DCCL_THREAD_SUPPORT
77 static std::recursive_mutex last_bits_map_mutex;
78 #define LOCK_LAST_BITS_MAP_MUTEX \
79 std::lock_guard<std::recursive_mutex> l(dccl::arith::Model::last_bits_map_mutex);
81 #define LOCK_LAST_BITS_MAP_MUTEX
// Shared, class-wide map of Bitsets. The encoder and decoder take LOCK_LAST_BITS_MAP_MUTEX
// before touching it. Key meanings are not visible here -- presumably message/field names,
// TODO confirm.
84 static std::map<std::string, std::map<std::string, Bitset>> last_bits_map;
// Convert between wire values and model symbols (definitions are not in this header).
// value_to_symbol may return OUT_OF_RANGE_SYMBOL; the encoder checks for it.
// Presumably both use the `value_bound` table, which is validated to hold one more entry than
// there are user symbols -- TODO confirm.
94 symbol_type value_to_symbol(value_type value)
const;
95 value_type symbol_to_value(symbol_type symbol)
const;
96 symbol_type total_symbols()
98 return encoder_cumulative_freqs_.size();
103 symbol_type max_symbol()
const {
return user_model_.frequency_size() - 1; }
105 freq_type total_freq(ModelState state)
const
107 const auto& c_freqs =
108 (state == ENCODER) ? encoder_cumulative_freqs_ : decoder_cumulative_freqs_;
110 return c_freqs.at(max_symbol());
// Adapts the cumulative frequency table for `state` after `symbol` has been coded.
// The definition is not in this header -- TODO confirm its behavior for non-adaptive models.
113 void update_model(symbol_type symbol, ModelState state);
// Returns the half-open cumulative-frequency range [first, second) that `symbol` occupies in
// the table for `state`. The coder scales this by range / total_freq(state).
115 std::pair<freq_type, freq_type> symbol_to_cumulative_freq(symbol_type symbol,
116 ModelState state)
const;
// Maps a pair of cumulative frequencies back to the pair of symbols they fall in.
// bits_to_symbol treats equal results as an unambiguous decode.
117 std::pair<symbol_type, symbol_type>
118 cumulative_freq_to_symbol(std::pair<freq_type, freq_type> c_freq_pair, ModelState state)
const;
// Inclusive running frequency totals keyed by symbol, from MIN_SYMBOL to max_symbol().
// The decoder table starts as a copy of the encoder table. They are kept separate presumably so
// each side can adapt independently via update_model.
124 std::map<symbol_type, freq_type> encoder_cumulative_freqs_;
125 std::map<symbol_type, freq_type> decoder_cumulative_freqs_;
// Looks up a previously registered model by name.
// Throws Exception("Cannot find model called: <name>") if no model with that name exists.
133 Model& find(
const std::string& name)
135 auto it = arithmetic_models_.find(name);
136 if (it == arithmetic_models_.end())
137 throw(
Exception(
"Cannot find model called: " + name));
// Registers (or replaces) a model. The copy is built and validated before anything is removed,
// so an invalid model throws and leaves any existing entry with the same name intact.
145 Model new_model(model);
146 _create_and_validate_model(&new_model);
// std::map::insert does not overwrite an existing key, so the old entry is erased first.
147 if (arithmetic_models_.count(model.name()))
148 arithmetic_models_.erase(model.name());
149 arithmetic_models_.insert(std::make_pair(model.name(), new_model));
152 void _create_and_validate_model(
Model* model)
154 if (!model->user_model_.IsInitialized())
156 throw(
Exception(
"Invalid model: " + model->user_model_.DebugString() +
157 "Missing fields: " + model->user_model_.InitializationErrorString()));
160 Model::freq_type cumulative_freq = 0;
161 for (Model::symbol_type symbol = Model::MIN_SYMBOL, n = model->user_model_.frequency_size();
162 symbol < n; ++symbol)
164 Model::freq_type freq;
165 if (symbol == Model::EOF_SYMBOL)
166 freq = model->user_model_.eof_frequency();
167 else if (symbol == Model::OUT_OF_RANGE_SYMBOL)
168 freq = model->user_model_.out_of_range_frequency();
170 freq = model->user_model_.frequency(symbol);
172 if (freq == 0 && symbol != Model::OUT_OF_RANGE_SYMBOL && symbol != Model::EOF_SYMBOL)
174 throw(
Exception(
"Invalid model: " + model->user_model_.DebugString() +
175 "All frequencies must be nonzero."));
177 cumulative_freq += freq;
178 model->encoder_cumulative_freqs_.insert(std::make_pair(symbol, cumulative_freq));
182 model->decoder_cumulative_freqs_ = model->encoder_cumulative_freqs_;
184 if (model->total_freq(Model::ENCODER) > Model::MAX_FREQUENCY)
186 throw(
Exception(
"Invalid model: " + model->user_model_.DebugString() +
187 "Sum of all frequencies must be less than " +
188 std::to_string(Model::MAX_FREQUENCY) +
189 " in order to use 64 bit arithmetic"));
192 if (model->user_model_.value_bound_size() != model->user_model_.frequency_size() + 1)
194 throw(
Exception(
"Invalid model: " + model->user_model_.DebugString() +
195 "`value_bound` size must be exactly 1 more than number of symbols (= "
196 "size of `frequency`)."));
200 if (std::adjacent_find(
201 model->user_model_.value_bound().begin(), model->user_model_.value_bound().end(),
202 std::greater_equal<Model::value_type>()) != model->user_model_.value_bound().end())
204 throw(
Exception(
"Invalid model: " + model->user_model_.DebugString() +
205 "`value_bound` must be monotonically increasing."));
// Registered, validated models keyed by their protobuf `name()`.
210 std::map<std::string, Model> arithmetic_models_;
// ArithmeticFieldCodecBase: shared arithmetic-coding logic. FieldType is the type of the field
// values being coded; it defaults to the coder's wire type, Model::value_type.
213 template <
typename FieldType = Model::value_type>
217 static constexpr
uint64 TOP_VALUE =
218 (
static_cast<uint64>(1) << Model::CODE_VALUE_BITS) - 1;
219 static constexpr
uint64 HALF =
220 (
static_cast<uint64>(1) << (Model::CODE_VALUE_BITS - 1));
221 static constexpr
uint64 FIRST_QTR = HALF >> 1;
222 static constexpr
uint64 THIRD_QTR = HALF + FIRST_QTR;
224 Bitset encode_repeated(
const std::vector<Model::value_type>& wire_value)
override
226 return encode_repeated(wire_value,
true);
// Arithmetic-encodes up to max_repeat() values from `wire_value`.
// Positions beyond wire_value.size() are coded as EOF_SYMBOL.
// `update_model` is true from the single-argument overload and false from size_repeated.
// It presumably guards the encoder-side model.update_model call below -- TODO confirm.
// NOTE(review): `bits`, `low` and `high` are used below, but their declarations are not
// visible here.
229 Bitset encode_repeated(
const std::vector<Model::value_type>& wire_value,
bool update_model)
232 using namespace dccl::logger;
233 Model& model = current_model();
// Number of pending opposite bits from middle-straddle expansions; consumed by bit_plus_follow.
237 int bits_to_follow = 0;
240 for (
unsigned value_index = 0, n = max_repeat(); value_index < n; ++value_index)
// Positions past the end of the input default to EOF.
242 Model::symbol_type symbol = Model::EOF_SYMBOL;
244 if (wire_value.size() > value_index)
246 Model::value_type value = wire_value[value_index];
247 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) value is : " << value
250 symbol = model.value_to_symbol(value);
// An out-of-range value needs a nonzero OUT_OF_RANGE frequency to be codable; otherwise
// encoding is ended by substituting EOF.
254 if (symbol == Model::OUT_OF_RANGE_SYMBOL &&
255 model.user_model().out_of_range_frequency() == 0)
257 dlog.
is(DEBUG2) && dlog <<
"(ArithmeticFieldCodec) out of range symbol, but no "
258 "frequency given; ending encoding"
261 symbol = Model::EOF_SYMBOL;
// If EOF itself has zero frequency it cannot be coded either, so a user symbol is emitted in
// its place.
// NOTE(review): `*std::max_element(frequency)` yields the largest frequency *value*, not the
// index of the most probable symbol. The result is then used as a symbol index by
// symbol_to_cumulative_freq. The likely intent is
// `std::max_element(...) - model.user_model().frequency().begin()`.
// This also dereferences end() when `frequency` is empty.
265 if (symbol == Model::EOF_SYMBOL && model.user_model().eof_frequency() == 0)
267 dlog.
is(DEBUG2) && dlog <<
"(ArithmeticFieldCodec) end of file, but no frequency "
268 "given; filling with most probable symbol"
270 symbol = *std::max_element(model.user_model().frequency().begin(),
271 model.user_model().frequency().end());
274 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) symbol is : " << symbol << std::endl;
276 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) current interval: ["
277 << (double)low / TOP_VALUE <<
"," << (
double)high / TOP_VALUE
// Narrow [low, high] to the symbol's share of the current range. The share is the cumulative
// interval [first, second) scaled by range / total_freq(ENCODER), using 64-bit products.
280 uint64 range = (high - low) + 1;
282 std::pair<Model::freq_type, Model::freq_type> c_freq_range =
283 model.symbol_to_cumulative_freq(symbol, Model::ENCODER);
285 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) input symbol (" << symbol
286 <<
") cumulative freq: [" << c_freq_range.first <<
","
287 << c_freq_range.second <<
")" << std::endl;
289 high = low + (range * c_freq_range.second) / model.total_freq(Model::ENCODER) - 1;
290 low += (range * c_freq_range.first) / model.total_freq(Model::ENCODER);
292 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) input symbol (" << symbol
293 <<
") interval: [" << (double)low / TOP_VALUE <<
","
294 << (
double)high / TOP_VALUE <<
")" << std::endl;
297 dlog <<
"(ArithmeticFieldCodec) Q1: " <<
Bitset(Model::CODE_VALUE_BITS, FIRST_QTR)
298 <<
", Q2: " <<
Bitset(Model::CODE_VALUE_BITS, HALF)
299 <<
", Q3 : " <<
Bitset(Model::CODE_VALUE_BITS, THIRD_QTR)
300 <<
", top: " <<
Bitset(Model::CODE_VALUE_BITS, TOP_VALUE) << std::endl;
302 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) low: "
304 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) high: "
// Adapt the encoder-side model after coding `symbol`.
309 model.update_model(symbol, Model::ENCODER);
// Renormalization, which expands the interval:
//   - entirely in the lower half: emit 0;
//   - entirely in the upper half: emit 1;
//   - straddling the middle quarter: no bit is emitted yet. The deferred bit is presumably
//     counted in bits_to_follow (those lines are not visible here).
315 bit_plus_follow(&bits, &bits_to_follow, 0);
317 dlog <<
"(ArithmeticFieldCodec): completely in [0, 0.5): EXPAND"
320 else if (low >= HALF)
322 bit_plus_follow(&bits, &bits_to_follow, 1);
326 dlog <<
"(ArithmeticFieldCodec): completely in [0.5, 1): EXPAND"
329 else if (low >= FIRST_QTR && high < THIRD_QTR)
332 dlog <<
"(ArithmeticFieldCodec): straddle middle [0.25, 0.75): EXPAND"
346 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) low: "
349 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) high: "
353 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) current interval: ["
354 << (double)low / TOP_VALUE <<
","
355 << (
double)high / TOP_VALUE <<
")" << std::endl;
// Input exhausted; the handling of this case is not fully visible here.
359 if (value_index == wire_value.size())
// Final flush: emit enough bits to identify a point inside the last interval. The bit choice
// depends on where high/low sit relative to TOP_VALUE and FIRST_QTR.
370 if (high != TOP_VALUE || bits_to_follow > 0)
371 bit_plus_follow(&bits, &bits_to_follow, 0);
375 else if (high == TOP_VALUE)
377 bit_plus_follow(&bits, &bits_to_follow, 1);
385 bit_plus_follow(&bits, &bits_to_follow, (low < FIRST_QTR) ? 0 : 1);
// Takes last_bits_map_mutex (a no-op without thread support) for the following
// last_bits_map access, which is not visible here -- TODO confirm.
390 LOCK_LAST_BITS_MAP_MUTEX
399 void bit_plus_follow(
Bitset* bits,
int* bits_to_follow,
bool bit)
401 bits->push_back(bit);
402 dccl::dlog.
is(dccl::logger::DEBUG3) &&
403 dccl::dlog <<
"(ArithmeticFieldCodec): emitted bit: " << bit << std::endl;
405 while (*bits_to_follow)
407 dccl::dlog.
is(dccl::logger::DEBUG3) &&
408 dccl::dlog <<
"(ArithmeticFieldCodec): emitted bit (from follow): " << !bit
411 bits->push_back(!bit);
412 (*bits_to_follow) -= 1;
// Decoder body; the function signature is not visible here.
// Reconstructs the values by mirroring the encoder's interval arithmetic on the DECODER copy
// of the model.
419 using namespace dccl::logger;
421 std::vector<Model::value_type> values;
422 Model& model = current_model();
// Load the first CODE_VALUE_BITS bits of the stream into `value`, with (*bits)[0] as the most
// significant bit. If fewer bits are available, the low-order positions are left unset,
// presumably zero (the initialization of `value` is not visible). bit_stream_offset records
// how many positions are missing.
// NOTE(review): int minus size_t here is computed in unsigned arithmetic and converted back to
// int.
430 int bit_stream_offset = Model::CODE_VALUE_BITS - bits->size();
432 for (
int i = 0, n = Model::CODE_VALUE_BITS; i < n; ++i)
434 if (i >= bit_stream_offset)
436 (
static_cast<uint64>((*bits)[bits->size() - (i - bit_stream_offset) - 1]) << i);
439 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec): starting value: "
442 for (
unsigned value_index = 0, n = max_repeat(); value_index < n; ++value_index)
// Identify the next symbol, then narrow [low, high] exactly as the encoder did so both sides
// stay in step.
444 uint64 range = (high - low) + 1;
446 Model::symbol_type symbol = bits_to_symbol(bits, value, bit_stream_offset, low, range);
448 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) symbol is: " << symbol << std::endl;
450 std::pair<Model::freq_type, Model::freq_type> c_freq_range =
451 model.symbol_to_cumulative_freq(symbol, Model::DECODER);
453 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) input symbol (" << symbol
454 <<
") cumulative freq: [" << c_freq_range.first <<
","
455 << c_freq_range.second <<
")" << std::endl;
457 high = low + (range * c_freq_range.second) / model.total_freq(Model::DECODER) - 1;
458 low += (range * c_freq_range.first) / model.total_freq(Model::DECODER);
460 model.update_model(symbol, Model::DECODER);
// EOF ends decoding (the branch body is not visible here).
// Other symbols are converted back to values.
462 if (symbol == Model::EOF_SYMBOL)
465 values.push_back(model.symbol_to_value(symbol));
467 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) value is: " << values.back()
// Renormalization mirrors the encoder's lower-half / upper-half / middle-straddle cases.
476 else if (low >= HALF)
482 else if (low >= FIRST_QTR && high < THIRD_QTR)
495 bit_stream_offset += 1;
// Takes last_bits_map_mutex (a no-op without thread support) for the following
// last_bits_map access, which is not visible here -- TODO confirm.
502 LOCK_LAST_BITS_MAP_MUTEX
507 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) bits used is (" << bits->size()
508 <<
"): " << *bits << std::endl;
509 dlog.
is(DEBUG3) && dlog <<
"(ArithmeticFieldCodec) bits original is (" << in.size()
510 <<
"): " << in << std::endl;
518 unsigned size_repeated(
const std::vector<Model::value_type>& wire_values)
override
521 return encode_repeated(wire_values,
false).size();
// Upper bound on the encoded size (the function signature is not visible here).
// A symbol with frequency f costs about log2(total_freq / f) bits. The worst case is the larger
// of two options:
//   - max_repeat() least-probable symbols;
//   - max_repeat() - 1 least-probable symbols followed by EOF.
// One bit is added to the result.
529 Model& model = current_model();
// With a zero OUT_OF_RANGE frequency, out-of-range values are never coded (the encoder
// substitutes EOF). MAX_FREQUENCY is used so it cannot be the minimum below.
533 Model::freq_type out_of_range_freq = model.user_model().out_of_range_frequency();
534 if (out_of_range_freq == 0)
535 out_of_range_freq = Model::MAX_FREQUENCY;
// NOTE(review): *std::min_element dereferences end() if `frequency` is empty.
537 Model::value_type lowest_frequency =
538 std::min(out_of_range_freq, *std::min_element(model.user_model().frequency().begin(),
539 model.user_model().frequency().end()));
542 auto size_least_probable = (unsigned)(std::ceil(
543 max_repeat() * (log2(model.total_freq(Model::ENCODER)) - log2(lowest_frequency))));
545 dccl::dlog.
is(dccl::logger::DEBUG3) &&
546 dccl::dlog <<
"(ArithmeticFieldCodec) size_least_probable: " << size_least_probable
549 Model::freq_type eof_freq = model.user_model().eof_frequency();
551 auto size_least_probable_plus_eof =
552 (unsigned)((eof_freq != 0)
553 ? std::ceil(max_repeat() * log2(model.total_freq(Model::ENCODER)) -
554 (max_repeat() - 1) * log2(lowest_frequency) - log2(eof_freq))
557 dccl::dlog.
is(dccl::logger::DEBUG3) &&
558 dccl::dlog <<
"(ArithmeticFieldCodec) size_least_probable_plus_eof: "
559 << size_least_probable_plus_eof << std::endl;
// The +1 presumably covers the final flush bits emitted by the encoder -- TODO confirm.
561 return std::max(size_least_probable_plus_eof, size_least_probable) + 1;
// Lower bound on the encoded size (the function signature is not visible here).
// It is the smaller of two costs:
//   - a message that codes EOF immediately (size_empty);
//   - max_repeat() copies of the most probable symbol (size_most_probable).
// Adaptive models take a separate branch whose body is not visible here.
567 const Model& model = current_model();
569 if (model.user_model().is_adaptive())
// A zero OUT_OF_RANGE frequency becomes 1, so it cannot exceed any (nonzero) user frequency
// in the std::max below.
574 Model::freq_type out_of_range_freq = model.user_model().out_of_range_frequency();
575 if (out_of_range_freq == 0)
576 out_of_range_freq = 1;
// Without an EOF frequency an immediately-terminated message is impossible, so size_empty is
// set to UINT_MAX and never wins the std::min.
578 Model::freq_type eof_freq = model.user_model().eof_frequency();
581 (unsigned)((eof_freq != 0)
582 ? std::ceil(log2(model.total_freq(Model::ENCODER)) - log2(eof_freq))
583 : std::numeric_limits<unsigned>::max());
585 dccl::dlog.
is(dccl::logger::DEBUG3) &&
586 dccl::dlog <<
"(ArithmeticFieldCodec) size_empty: " << size_empty << std::endl;
// NOTE(review): *std::max_element dereferences end() if `frequency` is empty.
589 Model::value_type highest_frequency =
590 std::max(out_of_range_freq, *std::max_element(model.user_model().frequency().begin(),
591 model.user_model().frequency().end()));
593 auto size_most_probable = (unsigned)(std::ceil(
594 max_repeat() * (log2(model.total_freq(Model::ENCODER)) - log2(highest_frequency))));
596 dccl::dlog.
is(dccl::logger::DEBUG3) &&
597 dccl::dlog <<
"(ArithmeticFieldCodec) size_most_probable: " << size_most_probable
600 return std::min(size_empty, size_most_probable);
// Validation: the field must carry the (dccl.field).arithmetic option, and the named model must
// be registered. model_manager().find() throws Exception if it is not.
606 "missing (dccl.field).arithmetic");
608 std::string model_name =
612 model_manager().find(model_name);
617 model_name +
"\" loaded.");
// Determines the next symbol from a possibly truncated bit stream.
// While bit_stream_offset > 0, the low-order bits of `value` are unknown, so the true code value
// lies in [value, value_high]. Both ends are mapped to symbols:
//   - if they agree, that symbol is returned;
//   - otherwise another bit is folded into `value` (`value |= bits->back() << bit_stream_offset`)
//     and the lookup is retried.
// The retry loop and the `low` / `range` parameters are not fully visible here -- TODO confirm.
623 Model::symbol_type bits_to_symbol(
Bitset* bits,
uint64& value,
int& bit_stream_offset,
626 Model& model = current_model();
630 uint64 value_high = (bit_stream_offset > 0)
631 ? value + ((
static_cast<uint64>(1) << bit_stream_offset) - 1)
634 dccl::dlog.
is(dccl::logger::DEBUG3) &&
635 dccl::dlog <<
"(ArithmeticFieldCodec): value range: ["
636 <<
Bitset(Model::CODE_VALUE_BITS, value) <<
","
637 <<
Bitset(Model::CODE_VALUE_BITS, value_high) <<
")" << std::endl;
// Inverse of the encoder's interval narrowing: the cumulative frequency that a code value at
// offset (value - low) within `range` falls into.
639 Model::freq_type cumulative_freq =
640 ((value - low + 1) * model.total_freq(Model::DECODER) - 1) / range;
641 Model::freq_type cumulative_freq_high =
642 ((value_high - low + 1) * model.total_freq(Model::DECODER) - 1) / range;
644 dccl::dlog.
is(dccl::logger::DEBUG3) &&
645 dccl::dlog <<
"(ArithmeticFieldCodec): c_freq: " << cumulative_freq
646 <<
", c_freq_high: " << cumulative_freq_high << std::endl;
648 std::pair<Model::symbol_type, Model::symbol_type> symbol_pair =
649 model.cumulative_freq_to_symbol(
650 std::make_pair(cumulative_freq, cumulative_freq_high), Model::DECODER);
652 dccl::dlog.
is(dccl::logger::DEBUG3) &&
653 dccl::dlog <<
"(ArithmeticFieldCodec): symbol: " << symbol_pair.first <<
", "
654 << symbol_pair.second << std::endl;
// Unambiguous: every possible completion of the unknown bits decodes to the same symbol.
656 if (symbol_pair.first == symbol_pair.second)
657 return symbol_pair.first;
662 dccl::dlog.
is(dccl::logger::DEBUG3) &&
663 dccl::dlog <<
"(ArithmeticFieldCodec): bits: " << *bits << std::endl;
// Ambiguous: fold the newest stream bit into the unknown region of `value`.
666 value |=
static_cast<uint64>(bits->back()) << bit_stream_offset;
668 dccl::dlog.
is(dccl::logger::DEBUG3) &&
669 dccl::dlog <<
"(ArithmeticFieldCodec): ambiguous (symbol could be "
670 << symbol_pair.first <<
" or " << symbol_pair.second <<
")" << std::endl;
// Returns the model registered under `name`; find() throws Exception if there is none.
// The computation of `name` is not visible here -- presumably the field's
// (dccl.field).arithmetic model option; TODO confirm.
683 Model& current_model()
686 return model_manager().find(
name);
689 ModelManager& model_manager() {
return dccl::arith::model_manager(this->manager()); }
693 template <
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::TOP_VALUE;
694 template <
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::FIRST_QTR;
695 template <
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::HALF;
696 template <
typename FieldType>
const uint64 ArithmeticFieldCodecBase<FieldType>::THIRD_QTR;
// ArithmeticFieldCodec: the generic codec for numeric field types. pre_encode and post_decode
// convert field values to and from Model::value_type with static_cast.
698 template <
typename FieldType>
701 Model::value_type
pre_encode(
const FieldType& field_value)
override
703 return static_cast<Model::value_type
>(field_value);
706 FieldType
post_decode(
const Model::value_type& wire_value)
override
708 return static_cast<FieldType
>(wire_value);
// Enumerations are coded by their numeric value (EnumValueDescriptor::number()).
718 pre_encode(
const google::protobuf::EnumValueDescriptor*
const& field_value)
override
720 return field_value->number();
723 const google::protobuf::EnumValueDescriptor*
724 post_decode(
const Model::value_type& wire_value)
override
727 const google::protobuf::EnumValueDescriptor* return_value =
728 e->FindValueByNumber((
int)wire_value);