FPGA開発日記

カテゴリ別記事インデックス https://msyksphinz.github.io/github_pages , English Version https://fpgadevdiary.hatenadiary.com/

分岐予測の評価キット Branch Prediction Championship Kit を試す (11. TAGEのモデルを読む)

TAGEの続き。

いろいろさがして、以下のC++コードを見つけた。

github.com

このソースコードを理解してみよう。

class Tage
 {
 private:
   /* data */
   int num_branches;                                               // Stores the number of branch instructions since the last useful reset
   uint8_t bimodal_table[TAGE_BIMODAL_TABLE_SIZE];                 // Array represent the counters of the bimodal table
   struct tage_predictor_table_entry predictor_table[TAGE_NUM_COMPONENTS][(1 << TAGE_MAX_INDEX_BITS)];
   uint8_t global_history[TAGE_GLOBAL_HISTORY_BUFFER_LENGTH];      // Stores the global branch history
   uint8_t path_history[TAGE_PATH_HISTORY_BUFFER_LENGTH];          // Stores the last bits of the last N branch PCs
   uint8_t use_alt_on_na;                                          // 4 bit counter to decide between alternate and provider component prediction
   int component_history_lengths[TAGE_NUM_COMPONENTS];             // History lengths used to compute hashes for different components
   uint8_t tage_pred, pred, alt_pred;                              // Final prediction , provider prediction, and alternate prediction
   int pred_comp, alt_comp;                                        // Provider and alternate component of last branch PC
   int STRONG;                                                     //Strength of provider prediction counter of last branch PC

 public:
   void init();                                                    // initialise the member variables
   uint8_t predict(uint64_t ip);                                   // return the prediction from tage
   void update(uint64_t ip, uint8_t taken);                        // updates the state of tage

   Index get_bimodal_index(uint64_t ip);                           // helper hash function to index into the bimodal table
   Index get_predictor_index(uint64_t ip, int component);          // helper hash function to index into the predictor table using histories
   Tag get_tag(uint64_t ip, int component);                        // helper hash function to get the tag of particular ip and component
   int get_match_below_n(uint64_t ip, int component);              // helper function to find the hit component strictly before the component argument
   void ctr_update(uint8_t &ctr, int cond, int low, int high);     // counter update helper function (including clipping)
   uint8_t get_prediction(uint64_t ip, int comp);                  // helper function for prediction
   Path get_path_history_hash(int component);                      // hepoer hash function to compress the path history
   History get_compressed_global_history(int inSize, int outSize); // Compress global history of last 'inSize' branches into 'outSize' by wrapping the history

   Tage();
   ~Tage();
 };
  • num_branches

    • 分岐が実行されるたびにカウントアップされる。一定回数実行されると、カウンタがリセットされる。
         // graceful resetting of useful counter
         num_branches++;
         if (num_branches % TAGE_RESET_USEFUL_INTERVAL == 0)
         {
           num_branches = 0;
           for (int i = 0; i < TAGE_NUM_COMPONENTS; i++)
           {
             for (int j = 0; j < (1 << TAGE_INDEX_BITS[i]); j++)
               predictor_table[i][j].useful >>= 1;
           }
         }
    
  • bimodal_table

    • ベースとなる分岐予測器 (base predictor)。TAGE_BIMODAL_TABLE_SIZEの大きさを持っている。これは単純な2ビットのカウンタになっている。
    • 予測条件。単純なBimodalカウンタ
      uint8_t Tage::get_prediction(uint64_t ip, int comp)
       {
         /*
           Get the prediction according to a specific component
         */
         if(comp == 0) // Check if component is the bimodal table
         {
           Index index = get_bimodal_index(ip); // Get bimodal index
           return bimodal_table[index] >= (1 << (TAGE_BASE_COUNTER_BITS - 1));
         }
    
  • global_history

    • グローバル履歴を保持している。TAGE_PATH_HISTORY_BUFFER_LENGTHの長さだけ保持している。
    • update()の時に更新する。
         // update global history
         for (int i = TAGE_GLOBAL_HISTORY_BUFFER_LENGTH - 1; i > 0; i--)
           global_history[i] = global_history[i - 1];
         global_history[0] = taken;
    
  • tage_predictor_table_entry predictor_table

    • predictor_tableはTAGE_NUM_COMPONENTSで示すコンポーネントの数だけ予測器を持っている。
      struct tage_predictor_table_entry
       {
         uint8_t ctr; // The counter on which prediction is based Range - 0-7
         Tag tag; // Stores the tag
         uint8_t useful; // Variable to store the usefulness of the entry Range - 0-3
       };
    
      struct tage_predictor_table_entry predictor_table[TAGE_NUM_COMPONENTS][(1 << TAGE_MAX_INDEX_BITS)];
    
    • まず、それぞれのpredictorは以下のパラメータで初期化される。
         for (int i = 0; i < TAGE_NUM_COMPONENTS; i++)
         {
           for (int j = 0; j < (1 << TAGE_INDEX_BITS[i]); j++)
           {
             predictor_table[i][j].ctr = (1 << (TAGE_COUNTER_BITS - 1)); // weakly taken
             predictor_table[i][j].useful = 0;                           // not useful
             predictor_table[i][j].tag = 0;
           }
         }
    

まずは予測だが、最も長い履歴を使うコンポーネントからタグがマッチするものを見つけていく。戻り値はコンポーネントのID。

 int Tage::get_match_below_n(uint64_t ip, int component)
 {
   /*
     Get component number of first predictor which has an entry for the IP below a specfic component number
   */
   for (int i = component - 1; i >= 1; i--)
   {
     Index index = get_predictor_index(ip, i);
     Tag tag = get_tag(ip, i);

     if (predictor_table[i - 1][index].tag == tag) // Compare tags at a specific index
     {
       return i;
     }
   }

   return 0; // Default to bimodal in case no match found
 }

これがpred_compとして最終的な予測になり、一方でそれよりも小さくてマッチするIDがalt_compとして登録される。

uint8_t Tage::predict(uint64_t ip)
 {
   pred_comp = get_match_below_n(ip, TAGE_NUM_COMPONENTS + 1); // Get the first predictor from the end which matches the PC
   alt_comp = get_match_below_n(ip, pred_comp); // Get the first predictor below the provider which matches the PC

   // Store predictions for both components for use in the update step
   pred = get_prediction(ip, pred_comp);
   alt_pred = get_prediction(ip, alt_comp);

その次のpredの選択条件は正直良く分からなくて、predではなくalt_predを使う条件も存在しているらしい。

   if(pred_comp == 0)
     tage_pred = pred;
   else
   {
     Index index = get_predictor_index(ip, pred_comp);
     STRONG = abs(2 * predictor_table[pred_comp - 1][index].ctr + 1 - (1 << TAGE_COUNTER_BITS)) > 1;
     if (use_alt_on_na < 8 || STRONG) // Use provider component only if USE_ALT_ON_NAs < 8 or the provider counter is strong
       tage_pred = pred;
     else
       tage_pred = alt_pred;
   }
   return tage_pred;
 }

アップデートの方式はこちら。そのままif分を日本語に変えていくと

  • 最終予測がBimodalのbase predictorではない場合
    • Weakly予測の場合 (STRONG予測ではない場合)
      • predとalt_predが異なる場合
        • use_alt_on_naをアップデートする。predの予測が正しくなかった場合、use_alt_on_naをインクリメントする。
        • このuse_alt_on_naは、状況によってalt_predを予測として使用できるようにする機能だと思う。
    • alt_compがBimodalのbase predictorではない場合
      • predのエントリがuseful = 0の場合(つまりpredのエントリが役に立たない場合)
        • takenの場合はカウンタをインクリメント、not takenの場合はデクリメントする
    • alt_compがBimodalのbase predictorである場合
      • predのエントリがuseful = 0の場合(つまりpredのエントリが役に立たない場合)
        • takenの場合はカウンタをインクリメント、not takenの場合はデクリメントする
    • predとalt_predが異なる場合
      • predの予測がヒットした場合
        • predのエントリのusefulをインクリメントする
      • predの予測がヒットしなかった場合 (つまり、alt_predがヒットした場合)
        • use_alt_on_na < 8以下、つまり、altの信用度が高い場合
          • predのエントリのusefulをデクリメントする
    • predのエントリのカウンタをアップデートする
  • 最終予測がBimodalのbase predictorである場合
    • Bimodal Tableをアップデートする

ここまで来て、これはTAGE-SC-Lではないか?という気がしてきた。