1 : // Copyright 2013 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/pe/transforms/coff_add_imports_transform.h"
16 :
17 : #include "gtest/gtest.h"
18 : #include "syzygy/block_graph/typed_block.h"
19 : #include "syzygy/block_graph/unittest_util.h"
20 : #include "syzygy/block_graph/orderers/original_orderer.h"
21 : #include "syzygy/core/unittest_util.h"
22 : #include "syzygy/pe/coff_decomposer.h"
23 : #include "syzygy/pe/coff_file.h"
24 : #include "syzygy/pe/coff_file_writer.h"
25 : #include "syzygy/pe/coff_image_layout_builder.h"
26 : #include "syzygy/pe/pe_utils.h"
27 : #include "syzygy/pe/unittest_util.h"
28 :
29 : namespace pe {
30 : namespace transforms {
31 :
32 : using block_graph::BlockGraph;
33 : using block_graph::ConstTypedBlock;
34 : using block_graph::OrderedBlockGraph;
35 : using core::RelativeAddress;
36 :
37 : namespace {
38 :
39 : class CoffAddImportsTransformTest : public testing::PELibUnitTest {
40 : public:
41 E : CoffAddImportsTransformTest() : image_layout_(&block_graph_) {
42 E : }
43 :
44 E : virtual void SetUp() OVERRIDE {
45 E : testing::PELibUnitTest::SetUp();
46 :
47 : test_dll_obj_path_ =
48 E : testing::GetExeTestDataRelativePath(testing::kTestDllCoffObjName);
49 E : ASSERT_NO_FATAL_FAILURE(CreateTemporaryDir(&temp_dir_path_));
50 E : new_test_dll_obj_path_ = temp_dir_path_.Append(L"test_dll.obj");
51 :
52 E : ASSERT_NO_FATAL_FAILURE(DecomposeOriginal());
53 E : }
54 :
55 : protected:
56 : // Decompose test_dll.coff_obj.
57 E : void DecomposeOriginal() {
58 E : ASSERT_TRUE(image_file_.Init(test_dll_obj_path_));
59 E : CoffDecomposer decomposer(image_file_);
60 E : ASSERT_TRUE(decomposer.Decompose(&image_layout_));
61 :
62 E : headers_block_ = image_layout_.blocks.GetBlockByAddress(RelativeAddress(0));
63 E : ASSERT_TRUE(headers_block_ != NULL);
64 E : }
65 :
66 : // Reorder and lay out test_dll.coff_obj into a new object file, located
67 : // at new_test_dll_obj_path_.
68 E : void LayoutAndWriteNew(block_graph::BlockGraphOrdererInterface* orderer) {
69 E : DCHECK(orderer != NULL);
70 :
71 : // Cast headers block.
72 E : ConstTypedBlock<IMAGE_FILE_HEADER> file_header;
73 E : ASSERT_TRUE(file_header.Init(0, headers_block_));
74 :
75 : // Reorder using the specified ordering.
76 E : OrderedBlockGraph ordered_graph(&block_graph_);
77 E : ASSERT_TRUE(orderer->OrderBlockGraph(&ordered_graph, headers_block_));
78 :
79 : // Wipe references from headers, so we can remove relocation blocks
80 : // during laying out.
81 E : ASSERT_TRUE(headers_block_->RemoveAllReferences());
82 :
83 : // Lay out new image.
84 E : ImageLayout new_image_layout(&block_graph_);
85 E : CoffImageLayoutBuilder layout_builder(&new_image_layout);
86 E : ASSERT_TRUE(layout_builder.LayoutImage(ordered_graph));
87 :
88 : // Write temporary image file.
89 E : CoffFileWriter writer(&new_image_layout);
90 E : ASSERT_TRUE(writer.WriteImage(new_test_dll_obj_path_));
91 E : }
92 :
93 : // Check that symbols in @p module have been assigned a reference, and
94 : // that writing and parsing the file again yields a symbol table that
95 : // contains them.
96 E : void TestSymbols(const ImportedModule& module) {
97 : // Check resulting references.
98 E : for (size_t i = 0; i < module.size(); ++i) {
99 E : BlockGraph::Reference ref;
100 E : EXPECT_TRUE(module.GetSymbolReference(i, &ref));
101 E : EXPECT_TRUE(ref.referenced() != NULL);
102 E : EXPECT_GE(ref.offset(), 0);
103 E : EXPECT_LT(ref.offset(),
104 : static_cast<BlockGraph::Offset>(ref.referenced()->size()));
105 E : }
106 :
107 : // Rewrite file and parse new symbol table.
108 E : block_graph::orderers::OriginalOrderer orig_orderer;
109 E : ASSERT_NO_FATAL_FAILURE(LayoutAndWriteNew(&orig_orderer));
110 E : CoffFile image_file;
111 E : ASSERT_TRUE(image_file.Init(new_test_dll_obj_path_));
112 :
113 E : size_t num_found = 0;
114 E : size_t num_symbols = image_file.file_header()->NumberOfSymbols;
115 E : const IMAGE_SYMBOL* symbol = NULL;
116 E : for (size_t i = 0; i < num_symbols; i += 1 + symbol->NumberOfAuxSymbols) {
117 E : symbol = image_file.symbol(i);
118 E : const char* name = image_file.GetSymbolName(i);
119 E : for (size_t j = 0; j < module.size(); ++j) {
120 E : if (module.GetSymbolName(j) == name)
121 E : ++num_found;
122 E : }
123 E : }
124 E : EXPECT_EQ(module.size(), num_found);
125 E : }
126 :
127 : base::FilePath test_dll_obj_path_;
128 : base::FilePath new_test_dll_obj_path_;
129 : base::FilePath temp_dir_path_;
130 :
131 : // Original image details.
132 : testing::DummyTransformPolicy policy_;
133 : CoffFile image_file_;
134 : BlockGraph block_graph_;
135 : ImageLayout image_layout_;
136 : BlockGraph::Block* headers_block_;
137 : };
138 :
// Decorated symbol names exercised by the tests. function1 is looked up via
// its "__imp_"-prefixed symbol (presumably the import-pointer symbol —
// confirm against test_dll). The tests expect function1 and function3 to
// already exist in the original object, and function4 not to.
const char kFunction1Name[] = "__imp_" "?function1@@YAHXZ";
const char kFunction3Name[] = "?function3" "@@YAHXZ";
const char kFunction4Name[] = "?function4" "@@YAHXZ";
142 :
143 : } // namespace
144 :
145 E : TEST_F(CoffAddImportsTransformTest, AddImportsExisting) {
146 E : ImportedModule module("export_dll.dll");
147 : size_t function1 = module.AddSymbol(kFunction1Name,
148 E : ImportedModule::kAlwaysImport);
149 : size_t function3 = module.AddSymbol(kFunction3Name,
150 E : ImportedModule::kAlwaysImport);
151 :
152 E : CoffAddImportsTransform transform;
153 E : transform.AddModule(&module);
154 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
155 E : &transform, &policy_, &block_graph_, headers_block_));
156 E : EXPECT_EQ(0u, transform.modules_added());
157 E : EXPECT_EQ(0u, transform.symbols_added());
158 :
159 E : EXPECT_TRUE(module.ModuleIsImported());
160 E : EXPECT_TRUE(module.SymbolIsImported(function1));
161 E : EXPECT_TRUE(module.SymbolIsImported(function3));
162 :
163 E : EXPECT_FALSE(module.ModuleWasAdded());
164 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
165 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
166 :
167 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
168 E : module.GetSymbolImportIndex(function1));
169 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
170 E : module.GetSymbolImportIndex(function3));
171 :
172 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
173 E : }
174 :
175 E : TEST_F(CoffAddImportsTransformTest, AddImportsNewSymbol) {
176 E : ImportedModule module("export_dll.dll");
177 : size_t function1 = module.AddSymbol(kFunction1Name,
178 E : ImportedModule::kAlwaysImport);
179 : size_t function3 = module.AddSymbol(kFunction3Name,
180 E : ImportedModule::kAlwaysImport);
181 : size_t function4 = module.AddSymbol(kFunction4Name,
182 E : ImportedModule::kAlwaysImport);
183 :
184 E : CoffAddImportsTransform transform;
185 E : transform.AddModule(&module);
186 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
187 E : &transform, &policy_, &block_graph_, headers_block_));
188 E : EXPECT_EQ(0u, transform.modules_added());
189 E : EXPECT_EQ(1u, transform.symbols_added());
190 :
191 E : EXPECT_TRUE(module.ModuleIsImported());
192 E : EXPECT_TRUE(module.SymbolIsImported(function1));
193 E : EXPECT_TRUE(module.SymbolIsImported(function3));
194 E : EXPECT_TRUE(module.SymbolIsImported(function4));
195 :
196 E : EXPECT_FALSE(module.ModuleWasAdded());
197 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
198 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
199 E : EXPECT_TRUE(module.SymbolWasAdded(function4));
200 :
201 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
202 E : module.GetSymbolImportIndex(function1));
203 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
204 E : module.GetSymbolImportIndex(function3));
205 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
206 E : module.GetSymbolImportIndex(function4));
207 :
208 E : EXPECT_NO_FATAL_FAILURE(TestSymbols(module));
209 E : }
210 :
211 E : TEST_F(CoffAddImportsTransformTest, FindImportsExisting) {
212 E : ImportedModule module("export_dll.dll");
213 : size_t function1 = module.AddSymbol(kFunction1Name,
214 E : ImportedModule::kFindOnly);
215 : size_t function3 = module.AddSymbol(kFunction3Name,
216 E : ImportedModule::kFindOnly);
217 :
218 E : CoffAddImportsTransform transform;
219 E : transform.AddModule(&module);
220 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
221 E : &transform, &policy_, &block_graph_, headers_block_));
222 E : EXPECT_EQ(0u, transform.modules_added());
223 E : EXPECT_EQ(0u, transform.symbols_added());
224 :
225 E : EXPECT_TRUE(module.ModuleIsImported());
226 E : EXPECT_TRUE(module.SymbolIsImported(function1));
227 E : EXPECT_TRUE(module.SymbolIsImported(function3));
228 :
229 E : EXPECT_FALSE(module.ModuleWasAdded());
230 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
231 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
232 :
233 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
234 E : module.GetSymbolImportIndex(function1));
235 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
236 E : module.GetSymbolImportIndex(function3));
237 E : }
238 :
239 E : TEST_F(CoffAddImportsTransformTest, FindImportsNewSymbol) {
240 E : ImportedModule module("export_dll.dll");
241 : size_t function1 = module.AddSymbol(kFunction1Name,
242 E : ImportedModule::kFindOnly);
243 : size_t function3 = module.AddSymbol(kFunction3Name,
244 E : ImportedModule::kFindOnly);
245 : size_t function4 = module.AddSymbol(kFunction4Name,
246 E : ImportedModule::kFindOnly);
247 :
248 E : CoffAddImportsTransform transform;
249 E : transform.AddModule(&module);
250 : EXPECT_TRUE(block_graph::ApplyBlockGraphTransform(
251 E : &transform, &policy_, &block_graph_, headers_block_));
252 E : EXPECT_EQ(0u, transform.modules_added());
253 E : EXPECT_EQ(0u, transform.symbols_added());
254 :
255 E : EXPECT_TRUE(module.ModuleIsImported());
256 E : EXPECT_TRUE(module.SymbolIsImported(function1));
257 E : EXPECT_TRUE(module.SymbolIsImported(function3));
258 E : EXPECT_FALSE(module.SymbolIsImported(function4));
259 :
260 E : EXPECT_FALSE(module.ModuleWasAdded());
261 E : EXPECT_FALSE(module.SymbolWasAdded(function1));
262 E : EXPECT_FALSE(module.SymbolWasAdded(function3));
263 E : EXPECT_FALSE(module.SymbolWasAdded(function4));
264 :
265 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
266 E : module.GetSymbolImportIndex(function1));
267 : EXPECT_NE(ImportedModule::kInvalidImportIndex,
268 E : module.GetSymbolImportIndex(function3));
269 : EXPECT_EQ(ImportedModule::kInvalidImportIndex,
270 E : module.GetSymbolImportIndex(function4));
271 E : }
272 :
273 : } // namespace transforms
274 : } // namespace pe
|