CMU15445 之 Project#0 - C++ Primer 詳解

之一Yo發表於2022-06-26

前言

這個實驗主要用來測試大家對現代 C++ 的掌握程度,實驗要求如下:

實驗要求

簡單翻譯一下上述要求,就是我們需要實現定義在 src/include/primer/p0_starter.h 中的三個類 MatrixRowMatrixRowMatrixOperations,其中 MatrixRowMatrix 的父類,RowMatrixOperations 定義了三個用於陣列運算的成員函式:AddMultiplyGEMM(就是 \(\boldsymbol{A}*\boldsymbol{B} + \boldsymbol{C}\))。

程式碼實現

Matrix 類

抽象基類 Matrix 需要我們編寫的程式碼很少,只要完成建構函式和解構函式即可,下面省略了一些不需要我們寫的程式碼:

template <typename T>
class Matrix {
 protected:
  /**
   *
   * Construct a new Matrix instance.
   * @param rows The number of rows
   * @param cols The number of columns
   *
   */
  Matrix(int rows, int cols) : rows_(rows), cols_(cols), linear_(new T[rows * cols]) {}
  int rows_;
  int cols_;
  T *linear_;

 public:
  /**
   * Destroy a matrix instance.
   * TODO(P0): Add implementation
   */
  virtual ~Matrix() { delete[] linear_; }
};

linear_ 指向一個由二維矩陣展平而得的一維陣列,裡面共有 rows * cols 個型別為 T 的元素。由於我們在堆上分配陣列的空間使用的是 new T[],所以刪除的時候也得用 delete[]

RowMatrix 類

這個類用於表示二維矩陣,需要實現父類 Matrix 中的所有純虛擬函式,為了方便訪問資料元素,RowMatrix 多定義了一個指標陣列 data_,裡面的每個元素分別指向了二維矩陣每行首元素的地址:

template <typename T>
class RowMatrix : public Matrix<T> {
 public:
  /**
   * Construct a new RowMatrix instance.
   * @param rows The number of rows
   * @param cols The number of columns
   */
  RowMatrix(int rows, int cols) : Matrix<T>(rows, cols) {
    data_ = new T *[rows];
    for (int i = 0; i < rows; ++i) {
      data_[i] = &this->linear_[i * cols];
    }
  }

  /**
   * @return The number of rows in the matrix
   */
  auto GetRowCount() const -> int override { return this->rows_; }

  /**
   * @return The number of columns in the matrix
   */
  auto GetColumnCount() const -> int override { return this->cols_; }

  /**
   * Get the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @return The (i,j)th matrix element
   * @throws OUT_OF_RANGE if either index is out of range
   */
  auto GetElement(int i, int j) const -> T override {
    if (i < 0 || i >= GetRowCount() || j < 0 || j >= GetColumnCount()) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "The index out of range");
    }

    return data_[i][j];
  }

  /**
   * Set the (i,j)th matrix element.
   *
   * Throw OUT_OF_RANGE if either index is out of range.
   *
   * @param i The row index
   * @param j The column index
   * @param val The value to insert
   * @throws OUT_OF_RANGE if either index is out of range
   */
  void SetElement(int i, int j, T val) override {
    if (i < 0 || i >= GetRowCount() || j < 0 || j >= GetColumnCount()) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "The index out of range");
    }

    data_[i][j] = val;
  }

  /**
   * Fill the elements of the matrix from `source`.
   *
   * Throw OUT_OF_RANGE in the event that `source`
   * does not contain the required number of elements.
   *
   * @param source The source container
   * @throws OUT_OF_RANGE if `source` is incorrect size
   */
  void FillFrom(const std::vector<T> &source) override {
    if (static_cast<int>(source.size()) != GetRowCount() * GetColumnCount()) {
      throw Exception(ExceptionType::OUT_OF_RANGE, "The number of elements of `source` is different from matrix");
    }

    for (int i = 0; i < GetRowCount(); ++i) {
      for (int j = 0; j < GetColumnCount(); ++j) {
        data_[i][j] = source[i * GetColumnCount() + j];
      }
    }
  }

  /**
   * Destroy a RowMatrix instance.
   */
  ~RowMatrix() override { delete[] data_; }

 private:
  T **data_;
};

需要注意的是,在 RowMatrix 中訪問基類部分的成員(非虛擬函式)時需要加上 this 指標,不然編譯時會報錯說找不到指定的成員。

RowMatrixOperations 類

實現該類的三個成員函式之前應該檢查資料維度是否匹配,不匹配就返回空指標,否則開個迴圈遍歷二維矩陣完成相關操作即可:

template <typename T>
class RowMatrixOperations {
 public:
  /**
   * Compute (`matrixA` + `matrixB`) and return the result.
   * Return `nullptr` if dimensions mismatch for input matrices.
   * @param matrixA Input matrix
   * @param matrixB Input matrix
   * @return The result of matrix addition
   */
  static auto Add(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) -> std::unique_ptr<RowMatrix<T>> {
    if (matrixA->GetRowCount() != matrixB->GetRowCount() || matrixA->GetColumnCount() != matrixB->GetColumnCount()) {
      return std::unique_ptr<RowMatrix<T>>(nullptr);
    }

    auto rows = matrixA->GetRowCount();
    auto cols = matrixA->GetColumnCount();
    auto matrix = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int i = 0; i < rows; ++i) {
      for (int j = 0; j < cols; ++j) {
        matrix->SetElement(i, j, matrixA->GetElement(i, j) + matrixB->GetElement(i, j));
      }
    }

    return matrix;
  }

  /**
   * Compute the matrix multiplication (`matrixA` * `matrixB` and return the result.
   * Return `nullptr` if dimensions mismatch for input matrices.
   * @param matrixA Input matrix
   * @param matrixB Input matrix
   * @return The result of matrix multiplication
   */
  static auto Multiply(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB) -> std::unique_ptr<RowMatrix<T>> {
    if (matrixA->GetColumnCount() != matrixB->GetRowCount()) {
      return std::unique_ptr<RowMatrix<T>>(nullptr);
    }

    auto rows = matrixA->GetRowCount();
    auto cols = matrixB->GetColumnCount();
    auto matrix = std::make_unique<RowMatrix<T>>(rows, cols);
    for (int i = 0; i < rows; ++i) {
      for (int j = 0; j < cols; ++j) {
        T sum = 0;
        for (int k = 0; k < matrixA->GetColumnCount(); ++k) {
          sum += matrixA->GetElement(i, k) * matrixB->GetElement(k, j);
        }
        matrix->SetElement(i, j, sum);
      }
    }

    return matrix;
  }

  /**
   * Simplified General Matrix Multiply operation. Compute (`matrixA` * `matrixB` + `matrixC`).
   * Return `nullptr` if dimensions mismatch for input matrices.
   * @param matrixA Input matrix
   * @param matrixB Input matrix
   * @param matrixC Input matrix
   * @return The result of general matrix multiply
   */
  static auto GEMM(const RowMatrix<T> *matrixA, const RowMatrix<T> *matrixB, const RowMatrix<T> *matrixC)
      -> std::unique_ptr<RowMatrix<T>> {
    if (matrixA->GetColumnCount() != matrixB->GetRowCount()) {
      return std::unique_ptr<RowMatrix<T>>(nullptr);
    }
    if (matrixA->GetRowCount() != matrixC->GetRowCount() || matrixB->GetColumnCount() != matrixC->GetColumnCount()) {
      return std::unique_ptr<RowMatrix<T>>(nullptr);
    }

    return Add(Multiply(matrixA, matrixB).get(), matrixC);
  }
};

測試

開啟 test/primer/starter_test.cpp,將各個測試用例裡面的 DISABLED_ 字首移除,比如 TEST(StarterTest, DISABLED_SampleTest) 改為 TEST(StarterTest, SampleTest),之後執行下述命令:

mkdir build
cd build
cmake ..
make starter_test
./test/starter_test

測試結果如下圖所示:

測試結果

總結

這次實驗感覺比較簡單,主要考察虛擬函式、模板和動態記憶體(包括智慧指標)的知識,就是沒搞明白為什麼函式都用尾置返回型別,而且 Google 風格也讓人很不習慣,縮排居然只有兩格,函式居然開頭大寫。以上~~

相關文章