1 : // Copyright 2015 Google Inc. All Rights Reserved.
2 : //
3 : // Licensed under the Apache License, Version 2.0 (the "License");
4 : // you may not use this file except in compliance with the License.
5 : // You may obtain a copy of the License at
6 : //
7 : // http://www.apache.org/licenses/LICENSE-2.0
8 : //
9 : // Unless required by applicable law or agreed to in writing, software
10 : // distributed under the License is distributed on an "AS IS" BASIS,
11 : // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 : // See the License for the specific language governing permissions and
13 : // limitations under the License.
14 :
15 : #include "syzygy/refinery/validators/vftable_ptr_validator.h"
16 :
17 : #include <string>
18 :
19 : #include "base/containers/hash_tables.h"
20 : #include "base/strings/string_util.h"
21 : #include "base/win/scoped_comptr.h"
22 : #include "gmock/gmock.h"
23 : #include "gtest/gtest.h"
24 : #include "syzygy/pe/pe_file.h"
25 : #include "syzygy/refinery/process_state/process_state.h"
26 : #include "syzygy/refinery/process_state/process_state_util.h"
27 :
28 m : namespace refinery {
29 :
30 m : using testing::_;
31 m : using testing::DoAll;
32 m : using testing::Return;
33 m : using testing::SetArgPointee;
34 :
35 m : namespace {
36 :
37 m : const Address kAddress = 1000ULL; // Fits 32-bit.
38 m : const Address kAddressOther = 2000ULL; // Fits 32-bit.
39 m : const Address kUdtAddress = 9000ULL; // Fits 32-bit.
40 m : const Size kSize = 42U;
41 m : const Size kSizeOther = 43U;
42 m : const uint32_t kChecksum = 11U;
43 m : const uint32_t kChecksumOther = 12U;
44 m : const uint32_t kTimestamp = 22U;
45 m : const wchar_t kPath[] = L"c:\\path\\ModuleName";
46 m : const wchar_t kPathOther[] = L"c:\\path\\ModuleNameOther";
47 :
48 m : class MockSymbolProvider : public SymbolProvider {
49 m : public:
50 m : MOCK_METHOD2(FindOrCreateTypeRepository,
51 m : bool(const pe::PEFile::Signature& signature,
52 m : scoped_refptr<TypeRepository>* type_repo));
53 m : MOCK_METHOD2(GetVFTableRVAs,
54 m : bool(const pe::PEFile::Signature& signature,
55 m : base::hash_set<RelativeAddress>* vftable_rvas));
56 m : };
57 :
58 m : class TestVftablePtrValidator : public VftablePtrValidator {
59 m : public:
60 m : using VftablePtrValidator::GetVFTableVAs;
61 m : };
62 :
63 m : void AddBytesRecord(ProcessState* state, Address address, uintptr_t value) {
64 m : DCHECK(state);
65 :
66 m : BytesLayerPtr bytes_layer;
67 m : state->FindOrCreateLayer(&bytes_layer);
68 m : BytesRecordPtr bytes_record;
69 m : bytes_layer->CreateRecord(AddressRange(address, sizeof(value)),
70 m : &bytes_record);
71 m : Bytes* bytes_proto = bytes_record->mutable_data();
72 m : std::string* buffer = bytes_proto->mutable_data();
73 m : memcpy(base::WriteInto(buffer, sizeof(value) + 1), &value, sizeof(value));
74 m : }
75 :
76 m : } // namespace
77 :
78 : // Sets up a process state with a single typed block. The bytes layer is empty
79 : // and up to the specific tests.
80 m : class VftablePtrValidatorSyntheticTest : public testing::Test {
81 m : protected:
82 m : void SetUp() override {
83 m : testing::Test::SetUp();
84 :
85 : // Add a module to the process state.
86 m : ModuleLayerAccessor accessor(&state_);
87 m : accessor.AddModuleRecord(AddressRange(kAddress, kSize), kChecksum,
88 m : kTimestamp, kPath);
89 m : ModuleId module_id = accessor.GetModuleId(kAddress);
90 m : pe::PEFile::Signature module_signature;
91 m : ASSERT_TRUE(accessor.GetModuleSignature(module_id, &module_signature));
92 :
93 : // Create a type repository for the module, then a UDT.
94 m : repository_ = new TypeRepository(module_signature);
95 m : UserDefinedTypePtr udt = AddSimpleUDTWithVfptr(repository_.get());
96 :
97 : // Add a typed block.
98 m : udt_range_ = AddressRange(kUdtAddress, udt->size());
99 m : AddTypedBlockRecord(udt_range_, L"udt", module_id, udt->type_id(), &state_);
100 :
101 : // Build the allowed set of vfptr rvas and set the expected vfptr value.
102 m : const Address kVfptrRva = 10U;
103 m : ASSERT_LT(kVfptrRva, kSize);
104 m : base::hash_set<RelativeAddress> rvas;
105 m : rvas.insert(kVfptrRva);
106 m : expected_vfptr_ = kAddress + kVfptrRva;
107 :
108 : // Ensure the bytes layer exists.
109 m : BytesLayerPtr bytes_layer;
110 m : state_.FindOrCreateLayer(&bytes_layer);
111 :
112 : // Create the symbol provider and set expectations.
113 m : mock_provider_ = new MockSymbolProvider();
114 m : EXPECT_CALL(*mock_provider_, GetVFTableRVAs(module_signature, testing::_))
115 m : .WillOnce(DoAll(SetArgPointee<1>(rvas), Return(true)));
116 m : EXPECT_CALL(*mock_provider_,
117 m : FindOrCreateTypeRepository(module_signature, _))
118 m : .Times(1)
119 m : .WillOnce(DoAll(SetArgPointee<1>(repository_), Return(true)));
120 m : }
121 :
122 m : UserDefinedTypePtr AddSimpleUDTWithVfptr(TypeRepository* repo) {
123 m : DCHECK(repo);
124 :
125 : // Create a vfptr type: a pointer to a vtshape type.
126 : // TODO(manzagop): update this to a vtshape type once it exists.
127 m : TypePtr vtshape_type = new WildcardType(L"vtshape", 4U);
128 m : repo->AddType(vtshape_type);
129 :
130 m : PointerTypePtr vfptr_type =
131 m : new PointerType(sizeof(uintptr_t), PointerType::PTR_MODE_PTR);
132 m : vfptr_type->Finalize(kNoTypeFlags, vtshape_type->type_id());
133 m : repo->AddType(vfptr_type);
134 :
135 : // Create a UDT. It (artificially) only has a vftptr.
136 m : UserDefinedTypePtr other_udt =
137 m : new UserDefinedType(L"other", L"decorated_other", vfptr_type->size(),
138 m : UserDefinedType::UDT_CLASS);
139 m : repo->AddType(other_udt);
140 m : {
141 m : UserDefinedType::Fields fields;
142 m : fields.push_back(
143 m : new UserDefinedType::VfptrField(0, vfptr_type->type_id(), repo));
144 m : UserDefinedType::Functions functions;
145 m : other_udt->Finalize(&fields, &functions);
146 m : }
147 :
148 : // Create another UDT. This also is an artificial type: it has the other UDT
149 : // as both base class and member, as well as a vfptr (yet not virtual
150 : // function).
151 m : UserDefinedTypePtr udt;
152 m : {
153 m : UserDefinedType::Fields fields;
154 m : UserDefinedType::Functions functions;
155 m : ptrdiff_t size = 0;
156 :
157 m : base_field_ =
158 m : new UserDefinedType::BaseClassField(size, other_udt->type_id(), repo);
159 m : fields.push_back(base_field_);
160 m : size += other_udt->size();
161 :
162 m : member_field_ = new UserDefinedType::MemberField(
163 m : L"member", size, kNoTypeFlags, 0, 0, other_udt->type_id(), repo);
164 m : fields.push_back(member_field_);
165 m : size += other_udt->size();
166 :
167 m : vfptr_field_ =
168 m : new UserDefinedType::VfptrField(size, vfptr_type->type_id(), repo);
169 m : fields.push_back(vfptr_field_);
170 m : size += vfptr_type->size();
171 :
172 m : udt = new UserDefinedType(L"foo", L"decorated_foo", size,
173 m : UserDefinedType::UDT_CLASS);
174 m : repo->AddType(udt);
175 m : udt->Finalize(&fields, &functions);
176 m : }
177 :
178 m : return udt;
179 m : }
180 :
181 m : void Validate(bool expect_error) {
182 m : VftablePtrValidator validator(mock_provider_);
183 m : ValidationReport report;
184 m : ASSERT_EQ(Validator::VALIDATION_COMPLETE,
185 m : validator.Validate(&state_, &report));
186 :
187 m : if (expect_error) {
188 m : ASSERT_EQ(1, report.error_size());
189 m : ASSERT_EQ(VIOLATION_VFPTR, report.error(0).type());
190 m : } else {
191 m : ASSERT_EQ(0, report.error_size());
192 m : }
193 m : }
194 :
195 m : ProcessState state_;
196 m : scoped_refptr<TypeRepository> repository_;
197 m : scoped_refptr<MockSymbolProvider> mock_provider_;
198 :
199 m : BaseClassFieldPtr base_field_;
200 m : MemberFieldPtr member_field_;
201 m : VfptrFieldPtr vfptr_field_;
202 :
203 m : AddressRange udt_range_;
204 m : Address expected_vfptr_;
205 m : };
206 :
207 m : TEST_F(VftablePtrValidatorSyntheticTest, NoBytesCase) {
208 : // No bytes to validate against. Expect no error.
209 m : ASSERT_NO_FATAL_FAILURE(Validate(false));
210 m : }
211 :
212 m : TEST_F(VftablePtrValidatorSyntheticTest, ValidBytesCase) {
213 : // Valid bytes. Expect no error.
214 m : AddBytesRecord(&state_, kUdtAddress + vfptr_field_->offset(),
215 m : expected_vfptr_);
216 m : ASSERT_NO_FATAL_FAILURE(Validate(false));
217 m : }
218 :
219 m : TEST_F(VftablePtrValidatorSyntheticTest, InvalidBytesCase) {
220 : // Invalid bytes. Expect an error.
221 m : AddBytesRecord(&state_, kUdtAddress + vfptr_field_->offset(),
222 m : expected_vfptr_ + 1);
223 m : ASSERT_NO_FATAL_FAILURE(Validate(true));
224 m : }
225 :
226 m : TEST_F(VftablePtrValidatorSyntheticTest, BaseClassValidBytesCase) {
227 : // Valid bytes. Expect no error.
228 m : AddBytesRecord(&state_, kUdtAddress + base_field_->offset(),
229 m : expected_vfptr_);
230 m : ASSERT_NO_FATAL_FAILURE(Validate(false));
231 m : }
232 :
233 m : TEST_F(VftablePtrValidatorSyntheticTest, BaseClassInvalidBytesCase) {
234 : // Invalid bytes. Expect an error.
235 m : AddBytesRecord(&state_, kUdtAddress + base_field_->offset(),
236 m : expected_vfptr_ + 1);
237 m : ASSERT_NO_FATAL_FAILURE(Validate(true));
238 m : }
239 :
240 m : TEST_F(VftablePtrValidatorSyntheticTest, MemberValidBytesCase) {
241 : // Valid bytes. Expect no error.
242 m : AddBytesRecord(&state_, kUdtAddress + member_field_->offset(),
243 m : expected_vfptr_);
244 m : ASSERT_NO_FATAL_FAILURE(Validate(false));
245 m : }
246 :
247 m : TEST_F(VftablePtrValidatorSyntheticTest, MemberInvalidBytesCase) {
248 : // Invalid bytes. Expect an error.
249 m : AddBytesRecord(&state_, kUdtAddress + member_field_->offset(),
250 m : expected_vfptr_ + 1);
251 m : ASSERT_NO_FATAL_FAILURE(Validate(true));
252 m : }
253 :
254 m : TEST(VftablePtrValidatorTest, GetVFTableVAs) {
255 : // Create a process state with 2 modules.
256 m : ProcessState state;
257 m : ModuleLayerAccessor accessor(&state);
258 m : accessor.AddModuleRecord(AddressRange(kAddress, kSize), kChecksum, kTimestamp,
259 m : kPath);
260 m : accessor.AddModuleRecord(AddressRange(kAddressOther, kSizeOther),
261 m : kChecksumOther, kTimestamp, kPathOther);
262 :
263 : // Set up the symbol provider.
264 m : scoped_refptr<MockSymbolProvider> provider = new MockSymbolProvider();
265 :
266 m : pe::PEFile::Signature signature;
267 m : ASSERT_TRUE(accessor.GetModuleSignature(kAddress, &signature));
268 m : signature.base_address = core::AbsoluteAddress(0U);
269 m : base::hash_set<RelativeAddress> rvas;
270 m : rvas.insert(1ULL);
271 m : rvas.insert(2ULL);
272 m : EXPECT_CALL(*provider, GetVFTableRVAs(signature, testing::_))
273 m : .WillOnce(DoAll(SetArgPointee<1>(rvas), Return(true)));
274 :
275 m : pe::PEFile::Signature signature_other;
276 m : ASSERT_TRUE(accessor.GetModuleSignature(kAddressOther, &signature_other));
277 m : signature_other.base_address = core::AbsoluteAddress(0U);
278 m : base::hash_set<RelativeAddress> rvas_other;
279 m : rvas_other.insert(3ULL);
280 m : rvas_other.insert(4ULL);
281 m : EXPECT_CALL(*provider, GetVFTableRVAs(signature_other, testing::_))
282 m : .WillOnce(DoAll(SetArgPointee<1>(rvas_other), Return(true)));
283 :
284 : // Retrieve VAs and validate.
285 m : base::hash_set<RelativeAddress> vftable_vas;
286 m : ASSERT_TRUE(TestVftablePtrValidator::GetVFTableVAs(&state, provider.get(),
287 m : &vftable_vas));
288 :
289 m : base::hash_set<RelativeAddress> expected_vftable_vas;
290 m : expected_vftable_vas.insert(kAddress + 1ULL);
291 m : expected_vftable_vas.insert(kAddress + 2ULL);
292 m : expected_vftable_vas.insert(kAddressOther + 3ULL);
293 m : expected_vftable_vas.insert(kAddressOther + 4ULL);
294 :
295 m : ASSERT_EQ(expected_vftable_vas, vftable_vas);
296 m : }
297 :
298 m : } // namespace refinery
|